In [1]:
import torch
import numpy as np
import torch.nn.functional as F
from torch.func import vmap      # torch>=2.0

In [2]:
def init_hdc(self, ratio, seed, flip_perc=None):
    """(Re)build the random projection tensors consumed by ``hdc``.

    The batch-norm scale is folded into the convolution weight, ``nHDC``
    Gaussian tensors shaped like that folded weight are drawn, and each one is
    multiplied by the sign of its per-output-channel inner product with the
    folded weight.

    Args:
        self: Object exposing ``bn.weight``, ``bn.running_var``, ``eps``,
            ``conv.weight`` (4-D), ``custom_round`` and ``flip_sign_``.
        ratio: When below 1000, the number of projections is
            ``custom_round(ratio * n)`` with ``n`` the element count of the
            folded weight; otherwise ``ratio`` itself is the count.
        seed: Passed to ``torch.manual_seed`` right before sampling. Note that
            this reseeds the global torch RNG.
        flip_perc: When not ``None`` and greater than 0, forwarded to
            ``self.flip_sign_(self.alpha1, flip_perc)``. The return value is
            ignored, so the helper presumably edits ``alpha1`` in place
            (TODO confirm against its definition).

    Side effects:
        Sets ``self.nHDC``, ``self.g``, ``self.alpha1``, ``self.size`` and
        ``self.alphag1``.
    """
    # Drop tensors from a previous call before allocating new ones. Each
    # attribute is removed independently so that one missing attribute does
    # not leave the others behind.
    for stale in ("alphag1", "g", "alpha1"):
        try:
            delattr(self, stale)
        except AttributeError:
            pass  # never set (e.g. first call); nothing to release

    # Fold the batch-norm scale into the conv weight, one factor per output channel.
    slope = self.bn.weight / torch.sqrt(self.bn.running_var + self.eps)
    w_bn = self.conv.weight * slope.view(-1, 1, 1, 1)
    w_bn = w_bn.unsqueeze(0)  # leading axis of size 1 broadcasts against the projections

    n = w_bn.shape[1:].numel()
    # Values below 1000 are a ratio relative to n; larger values are an absolute count.
    self.nHDC = int(self.custom_round(ratio * n)) if ratio < 1000 else int(ratio)

    torch.manual_seed(seed)
    self.g = torch.randn(self.nHDC, *w_bn.shape[1:], device=w_bn.device, dtype=w_bn.dtype)
    # Sign of <w_bn, g> per (projection, output channel); reduced over in-channel and kernel axes.
    self.alpha1 = torch.sign((w_bn * self.g).sum(dim=(2, 3, 4), keepdim=True))
    if flip_perc is not None and flip_perc > 0.0:
        self.flip_sign_(self.alpha1, flip_perc)

    signed_g = self.alpha1 * self.g
    self.size = signed_g.shape
    # Merge the projection and output-channel axes so the result is a conv2d weight.
    self.alphag1 = signed_g.view(-1, *w_bn.shape[2:])


def hdc(self, x):
    """Sign-based estimate computed from the projections built by ``init_hdc``.

    Args:
        self: Object exposing ``bias_trick_par``, ``padding``, ``stride``,
            ``alphag1``, ``size`` and ``nHDC``.
        x: 4-D input tensor; the first axis is treated as the batch axis.

    Returns:
        Tensor of shape ``(B, size[1], H_out, W_out)`` equal to
        ``pi / (2 * nHDC)`` times the sum, over the ``size[0]`` projections,
        of the sign of each projection's convolution response.
    """
    B, _, _, _ = x.shape  # also rejects inputs that are not 4-D

    x = x + self.bias_trick_par
    # Zero-pad explicitly (after the offset is added) and convolve without padding.
    x_p = F.pad(x, (self.padding, self.padding, self.padding, self.padding), value=0)

    # `F` is the file's alias for torch.nn.functional; a bare `nn` is not imported here.
    out = F.conv2d(x_p, self.alphag1, stride=self.stride, padding=0)
    # Split the merged filter axis back into (projection, output channel).
    out = out.view(B, self.size[0], self.size[1], out.size(2), out.size(3))

    zhat = (torch.pi / (2 * self.nHDC)) * torch.sign(out).sum(dim=1)
    return zhat

In [3]:
# Fixed seed so the random test batch is identical on every Restart & Run All.
torch.manual_seed(0)
# Test batch: 10 images, 3 channels, 28x28 pixels.
x = torch.randn((10, 3, 28, 28))
print(x.shape)

torch.Size([10, 3, 28, 28])


In [4]:
# Stand-in conv weight (64 filters, 3 input channels, 5x5 kernel) with a
# leading axis of size 1 so it broadcasts against the stack of projections.
w = torch.randn((64, 3, 5, 5)).unsqueeze(0)
nHDC = 1000
# One Gaussian tensor per projection; sized from nHDC so that the projection
# count and the later pi / (2 * nHDC) scale factor cannot drift apart.
G = torch.randn(nHDC, *w.shape[1:])
print(w.shape, G.shape)

torch.Size([1, 64, 3, 5, 5]) torch.Size([1000, 64, 3, 5, 5])


In [5]:
# Sign of the inner product between w and every projection, taken per output
# channel (reduced over the in-channel and kernel axes, which are kept as size 1).
alpha1 = (w * G).sum(dim=(2, 3, 4), keepdim=True).sign()

In [6]:
# Reference path: fold the signs into the projections, then run one large convolution.
alphag1 = alpha1 * G
size = alphag1.shape
# Merge the projection and output-channel axes into conv2d's filter axis.
alphag1 = alphag1.flatten(0, 1)
# Two-pixel zero border on each side, so the 5x5 kernels keep the 28x28 extent.
x_p = F.pad(x, (2, 2, 2, 2), value=0)
out1 = F.conv2d(x_p, alphag1, stride=1, padding=0)
# Split the filter axis back into (projection, output channel).
out1 = out1.unflatten(1, tuple(size[:2]))
zhat1 = (torch.pi / (2 * nHDC)) * out1.sign().sum(dim=1)

In [7]:
def apply_one(weight):
    """Convolve the pre-padded batch ``x_p`` with a single filter bank.

    ``x_p`` is read from the enclosing notebook scope. It is already padded,
    so stride=1 with padding=0 keeps the 28x28 spatial size.
    """
    return F.conv2d(input=x_p, weight=weight, stride=1, padding=0)
print(G.shape)
# Alternative path: convolve with each raw projection separately (vmap over
# G's leading axis) and apply the signs afterwards.
gx = vmap(apply_one, in_dims=0)(G)
print(gx.shape)
gx = gx.sign()
# Swap the batch and output-channel axes so alpha1 broadcasts over
# (projection, output channel).
out2 = alpha1 * gx.transpose(1, 2)
zhat2 = (torch.pi / (2 * nHDC)) * out2.sum(dim=0)
zhat2 = zhat2.transpose(0, 1)
# Relative difference between the two paths (displayed as the cell result).
(zhat2 - zhat1).norm() / zhat1.norm()

torch.Size([1000, 64, 3, 5, 5])
torch.Size([1000, 10, 64, 28, 28])


tensor(0.)