In [253]:
from sklearn.datasets import load_breast_cancer
# Load the breast-cancer dataset bundled with scikit-learn and pull out the
# feature matrix (x) and the label vector (y).
data = load_breast_cancer()
x, y = data['data'], data['target']
print("shape of x: {}\nshape of y: {}".format(x.shape, y.shape))

shape of x: (569, 30)
shape of y: (569,)


In [254]:
from sklearn.preprocessing import StandardScaler
# Fit the scaler on the full feature matrix, then overwrite x with the
# standardised features.
sc = StandardScaler()
x = sc.fit(x).transform(x)

In [255]:
from torch.utils.data import Dataset, DataLoader
import torch
class dataset(Dataset):
    """In-memory dataset pairing feature rows with their targets.

    Both inputs are converted to ``float32`` tensors once, at construction
    time; indexing returns the matching ``(features, target)`` pair.
    """

    def __init__(self, x, y):
        features = torch.tensor(x, dtype=torch.float32)
        targets = torch.tensor(y, dtype=torch.float32)
        self.x = features
        self.y = targets
        # Number of samples is the size of the leading dimension of x.
        self.length = features.size(0)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        sample = self.x[idx]
        label = self.y[idx]
        return sample, label
# Wrap the full (x, y) arrays in the tensor dataset defined above.
trainset = dataset(x, y)
# DataLoader: fixed-order mini-batches of 60 samples (no shuffling).
trainloader = DataLoader(dataset=trainset, batch_size=60, shuffle=False)

In [256]:
from torch import nn
from torch.nn import functional as F
class Net(nn.Module):
    """Three-layer fully connected network: input -> 32 -> 64 -> 1.

    ``forward`` returns the raw output of the last linear layer; no
    activation is applied after ``fc3``.
    """

    def __init__(self, input_shape):
        super().__init__()
        # Layers are created in fc1, fc2, fc3 order so that seeded
        # initialisation stays reproducible.
        self.fc1 = nn.Linear(in_features=input_shape, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=64)
        self.fc3 = nn.Linear(in_features=64, out_features=1)

    def forward(self, x):
        hidden = x
        # ReLU follows each of the two hidden layers.
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)

In [257]:
from torch import optim
# The network's input width is the number of feature columns in x.
network = Net(input_shape=x.shape[1])
# Adam over every parameter of the network.
optimizer = optim.Adam(params=network.parameters(), lr=1e-3)

In [258]:
from torch.optim.optimizer import Optimizer
class EKFACDistilled(Optimizer):
    """Instrumented gradient preconditioner for the ``Linear`` layers of a network.

    For every ``Linear`` module in ``net`` the constructor registers

    * a forward pre-hook that records the layer input, and
    * a full backward hook that records the gradient w.r.t. the layer output
      (multiplied by the batch size).

    ``step`` then eigendecomposes the two per-layer second-moment matrices
    built from those recordings and overwrites ``weight.grad`` / ``bias.grad``
    with the gradient rescaled in the Kronecker-factored eigenbasis.  The
    parameters themselves are never modified here; only their ``.grad``.

    Args:
        net: module whose ``Linear`` sub-modules should be preconditioned.
        eps: damping added to the per-coordinate scalings before dividing.
        verbose: when ``True`` (default) print shape diagnostics and run the
            Kronecker reconstruction check on every ``step``.
    """

    def __init__(self, net, eps, verbose=True):
        self.eps = eps
        self.verbose = verbose
        self.params = []
        self._fwd_handles = []
        self._bwd_handles = []
        self.net = net
        for mod in net.modules():
            mod_class = mod.__class__.__name__
            if mod_class in ['Linear']:
                handle = mod.register_forward_pre_hook(self._save_input)
                self._fwd_handles.append(handle)
                handle = mod.register_full_backward_hook(self._save_grad_output)
                self._bwd_handles.append(handle)
                params = [mod.weight]
                if mod.bias is not None:
                    params.append(mod.bias)
                # One param group per layer, carrying the module so that the
                # recorded activations/gradients can be looked up in `step`.
                d = {'params': params, 'mod': mod, 'layer_type': mod_class}
                self.params.append(d)
        super(EKFACDistilled, self).__init__(self.params, {})

    def _log(self, *args):
        """Print diagnostics when ``verbose`` is enabled."""
        if self.verbose:
            print(*args)

    def step(self, update_stats=True, update_params=True):
        """Precondition the gradients of every tracked layer.

        Must be called after a forward and backward pass so that the hooks
        have recorded inputs and output gradients.

        Args:
            update_stats: recompute the eigenbases from the latest recordings.
            update_params: overwrite ``.grad`` with the preconditioned
                gradient (requires eigenbases from a previous or the current
                ``update_stats`` pass).
        """
        for group in self.param_groups:
            params = group['params']
            if len(params) == 2:
                weight, bias = params
            else:
                weight = params[0]
                bias = None
            state = self.state[weight]

            if update_stats:
                self._compute_kfe(group, state)

            if update_params:
                self._precond(weight, bias, group, state)

    def _compute_kfe(self, group, state):
        """Compute and store the eigenbases ``Qa`` / ``Qs`` for one layer.

        ``Qa`` diagonalises the second-moment matrix of the (bias-augmented)
        layer inputs, ``Qs`` that of the recorded output gradients.
        """
        mod = group['mod']
        x = self.state[mod]['x']
        self._log(f"Shape of x: {x.shape}")
        gy = self.state[mod]['gy']
        self._log(f"Shape of gy: {gy.shape}")

        # Computation of xxt: work on the transposed activations, detached
        # from the autograd graph.
        x = x.detach().t()

        # Append a row of ones to x if the layer has a bias, so the bias is
        # treated as an extra input column of the weight matrix.
        if mod.bias is not None:
            ones = torch.ones_like(x[:1])
            x = torch.cat([x, ones], dim=0)

        # Second-moment matrix of the activations (A_{l-1}), averaged over
        # the batch.
        xxt = torch.mm(x, x.t()) / float(x.shape[1])
        self._log(f'A cov matrix shape: {xxt.shape}')

        # Eigenvalues and eigenvectors of the activation matrix (lambdaA, QA)
        la, Qa = torch.linalg.eigh(xxt, UPLO='U')
        state['Qa'] = Qa
        self._log(f'Qa eigenvec shape: {Qa.shape}')
        self._log(f'LambdaA eigenval vec shape: {la.shape}')

        # Computation of ggt
        gy = gy.detach().t()

        # Second-moment matrix of the output gradients (S_{l})
        ggt = torch.mm(gy, gy.t()) / float(gy.shape[1])
        self._log(f'S cov matrix shape: {ggt.shape}')

        # Eigenvalues and eigenvectors of the gradient matrix (lambdaS, QS)
        ls, Qs = torch.linalg.eigh(ggt, UPLO='U')
        state['Qs'] = Qs

        # Outer product of the eigenvalue vectors, of shape (len(s) x len(a)).
        # Stored for inspection only; `_precond` re-estimates its scalings
        # from the current batch instead of reading this entry.
        state["m2"] = m2 = ls.unsqueeze(1) * la.unsqueeze(0)

        if self.verbose:
            # Diagnostic only: rebuild kron(xxt, ggt) from the eigenpairs via
            # kron(Qa, Qs) kron(diag(la), diag(ls)) kron(Qa, Qs)^T and print
            # the residual.  These matrices have (in+1)*out rows, so the
            # check is skipped entirely when verbose is off.
            G_real = torch.kron(xxt, ggt)
            print(f'AxS direct shape: {G_real.shape}')
            print(f'Qs eigenvec shape: {Qs.shape}')
            print(f'LambdaS eigenval vec shape: {ls.shape}')

            prod_as = torch.kron(Qa, Qs)
            print(f'Kroneker product of Qa * Qs: {prod_as.shape}')

            prod_eigval = torch.kron(torch.diag(la), torch.diag(ls))
            G = torch.matmul(prod_as, torch.matmul(prod_eigval, prod_as.t()))
            print(f'G SHAPE: {G.shape}')
            print(f'Kroneker product of LambdaA * LambdaS: {prod_eigval.shape}')
            print(f"eigenval outer product shape: {m2.shape}")
            print(G_real - G)

    def _precond(self, weight, bias, group, state):
        """Applies preconditioning.

        Projects the recorded inputs and output gradients into the stored
        eigenbases, rescales the projected gradient coordinate-wise, maps the
        result back to parameter space and writes it into ``weight.grad``
        (and ``bias.grad`` when the layer has a bias).
        """
        Qa = state['Qa']
        Qs = state['Qs']
        mod = group['mod']
        # Detach so the preconditioned gradient does not drag the forward
        # graph along when it is stored in `.grad`.
        x = self.state[mod]['x'].detach()
        self._log(x)
        gy = self.state[mod]['gy'].detach()
        s = weight.grad.shape
        bs = x.size(0)

        # Append column of ones to x if bias is not None
        if bias is not None:
            ones = torch.ones_like(x[:, :1])
            x = torch.cat([x, ones], dim=1)

        # Activations expressed in the eigenbasis Qa
        x_kfe = torch.mm(x, Qa)
        self._log(f"KFE of activations a shape: {x_kfe.shape}")

        # Output gradients expressed in the eigenbasis Qs
        gy_kfe = torch.mm(gy, Qs)
        self._log(f"KFE of outputs gy shape: {gy_kfe.shape}")

        # Batch-averaged squared per-sample gradient coordinates in the
        # eigenbasis; these are the diagonal scalings, shape (out, in[+1]).
        m2 = torch.mm(gy_kfe.t() ** 2, x_kfe ** 2) / bs
        self._log(f'kfe squared matrix idk shape: {m2.shape}')

        # Batch-averaged gradient in the eigenbasis, shape (out, in[+1]).
        g_kfe = torch.mm(gy_kfe.t(), x_kfe) / bs
        self._log(f'g_kfe shape: {g_kfe.shape}')

        g_nat_kfe = g_kfe / (m2 + self.eps)
        self._log(f'g_nat_kfe shape: {g_nat_kfe.shape}')

        # Back to parameter space.  g_kfe is Qs^T G Qa, so the inverse map is
        # Qs (.) Qa^T; multiplying by Qs^T on the right (as before) neither
        # matches shapes nor undoes the projection.
        g_nat = torch.mm(torch.mm(Qs, g_nat_kfe), Qa.t())

        if bias is not None:
            # The last column belongs to the appended ones column, i.e. bias.
            gb = g_nat[:, -1].contiguous().view(*bias.shape)
            bias.grad.data = gb
            g_nat = g_nat[:, :-1]

        g_nat = g_nat.contiguous().view(*s)
        weight.grad.data = g_nat

    def _save_input(self, mod, i):
        """Saves input of layer to compute covariance."""
        self.state[mod]['x'] = i[0]

    def _save_grad_output(self, mod, grad_input, grad_output):
        """Saves grad on output of layer to compute covariance."""
        # Scaled by the batch size (size of the leading dimension).
        self.state[mod]['gy'] = grad_output[0] * grad_output[0].size(0)

In [259]:
# Gradient preconditioner; constructing it registers hooks on every Linear
# layer of `network`.
precond = EKFACDistilled(network, eps=0.001)
# Net.forward returns raw logits (no sigmoid after fc3), so use the
# logits-based binary cross-entropy. Plain BCELoss expects probabilities in
# [0, 1] and rejects inputs outside that range.
criterion = torch.nn.BCEWithLogitsLoss()

# List every module the preconditioner iterates over (the container itself
# plus its sub-modules).
for mod in network.modules():
  mod_class = type(mod).__name__
  print(mod_class)
  print("**********************")

Net
**********************
Linear
**********************
Linear
**********************
Linear
**********************


In [262]:
# Debug pass over a single mini-batch: run the forward pass, backpropagate
# one output and inspect the resulting gradient shape.
for i, (inputs, targets) in enumerate(trainloader):
  # zero_grad must be *called*; the bare attribute access was a no-op, so
  # gradients from earlier runs of this cell kept accumulating.
  optimizer.zero_grad()
  print(f'Input Shape: {inputs.shape}')
  print(inputs)
  print(f'Target Shape: {targets.shape}')
  outputs = network(inputs)
  print(outputs.shape)
  # Backpropagate only the first sample's output (a one-element tensor), which
  # is enough to populate .grad and trigger the backward hooks.
  outputs[0].backward()
  #print(f'Output Shape: {targets.shape}')
  #loss = criterion(outputs, targets.reshape(-1,1))
  #loss.backward()

  print(network.fc1.weight.grad.shape)
  #precond.step()

  # Only the first batch is needed for this inspection.
  break

Input Shape: torch.Size([60, 30])
tensor([[ 1.0971, -2.0733,  1.2699,  ...,  2.2961,  2.7506,  1.9370],
        [ 1.8298, -0.3536,  1.6860,  ...,  1.0871, -0.2439,  0.2812],
        [ 1.5799,  0.4562,  1.5665,  ...,  1.9550,  1.1523,  0.2014],
        ...,
        [ 0.1655,  0.5353,  0.1475,  ...,  1.0475,  1.2898,  1.4106],
        [-0.3060,  0.0047, -0.3855,  ..., -1.5759, -0.7470, -1.1668],
        [-1.5647, -1.7452, -1.5499,  ..., -1.0722,  0.5165,  0.3499]])
Target Shape: torch.Size([60])
torch.Size([60, 1])
torch.Size([32, 30])
