In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

### Sample linear layer

In [2]:
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc = nn.Linear(5, 3)  # Input size 5, Output size 3
        self.activation = nn.ReLU()
    def forward(self, x):
        fc_out = self.fc(x)
        return self.activation(fc_out)

In [None]:
# Dummy input

input_tensor = torch.randn(10, 5, requires_grad=True) # Batch size 10, Input size 5
target = torch.randn(10, 3)  # Target size matching the output size
model = SimpleNN()

In [4]:
activations = model(input_tensor)

In [None]:
loss_fn = nn.MSELoss()
loss = loss_fn(activations, target)
activations.retain_grad()   # Needed to access gradients of intermediates
loss.backward()  # Compute gradients

### Calculate covariance matrix of activations and gradients

In [6]:
layer_input = input_tensor.detach()  # This is the input to the fc layer
activation_mean = layer_input.mean(dim=0, keepdim=True)
A_KFAC = layer_input - activation_mean  # Centered activations
A_KFAC = A_KFAC.T @ A_KFAC / (activations.size(0) - 1)  # Covariance matrix

In [None]:
gradients = activations.grad   # Gradient wrt layer output after activation
gradient_mean = gradients.mean(dim=0, keepdim=True)
B_KFAC = gradients - gradient_mean  # Centered gradients
B_KFAC = B_KFAC.T @ B_KFAC / (gradients.size(0) - 1)  # Covariance matrix

### Compute FIM using K-FAC

In [8]:
F = torch.kron(A_KFAC, B_KFAC)

In [9]:
# Print results
print("Covariance Matrix of Activations (A_KFAC):")
print(A_KFAC)

print("\nCovariance Matrix of Gradients (B_KFAC):")
print(B_KFAC)

print("\nKronecker product F:")
print(F)

Covariance Matrix of Activations (A_KFAC):
tensor([[ 5.2808e-01, -6.2417e-01, -8.4895e-02, -6.5382e-02,  2.1640e-01],
        [-6.2417e-01,  1.4291e+00, -6.8364e-04, -3.7937e-01, -2.6783e-01],
        [-8.4895e-02, -6.8364e-04,  8.6800e-01,  5.2422e-01, -2.3668e-02],
        [-6.5382e-02, -3.7937e-01,  5.2422e-01,  6.3217e-01, -3.0087e-02],
        [ 2.1640e-01, -2.6783e-01, -2.3668e-02, -3.0087e-02,  3.7295e-01]])

Covariance Matrix of Gradients (B_KFAC):
tensor([[ 0.0052, -0.0001, -0.0014],
        [-0.0001,  0.0067, -0.0015],
        [-0.0014, -0.0015,  0.0083]])

Kronecker product F:
tensor([[ 2.7638e-03, -7.0235e-05, -7.2501e-04, -3.2667e-03,  8.3015e-05,
          8.5693e-04, -4.4431e-04,  1.1291e-05,  1.1655e-04, -3.4219e-04,
          8.6959e-06,  8.9764e-05,  1.1326e-03, -2.8781e-05, -2.9710e-04],
        [-7.0235e-05,  3.5560e-03, -8.0745e-04,  8.3015e-05, -4.2030e-03,
          9.5437e-04,  1.1291e-05, -5.7166e-04,  1.2981e-04,  8.6959e-06,
         -4.4027e-04,  9.9971e-05,

In [10]:
print(A_KFAC.shape)
print(B_KFAC.shape)
print(F.shape)

torch.Size([5, 5])
torch.Size([3, 3])
torch.Size([15, 15])
