In [46]:
from fmodule import FModule
import torch.nn as nn
import torch.nn.functional as F
import torch

class MLP(FModule):
    """Two-layer perceptron (784 -> 128 -> 10) that keeps its intermediates.

    Every forward pass stores the flattened input, both pre-activations and
    the hidden activation as attributes on the module so that code outside
    the class can read them after the call.

    Attributes written by ``forward``:
        a0: input reshaped to ``(batch, -1)``.
        s1: output of ``fc1`` (pre-activation of the hidden layer).
        a1: ``relu(s1)``.
        s2: output of ``fc2``; also the return value of ``forward``.
        FIM_params: the list ``[s1, s2]``.
    """

    def __init__(self, bias=False):
        """Create the two linear layers; ``bias`` is forwarded to both."""
        super().__init__()
        self.fc1 = nn.Linear(784, 128, bias=bias)
        self.fc2 = nn.Linear(128, 10, bias=bias)

    def forward(self, x):
        """Run the network on ``x`` and return the output of ``fc2``.

        ``x`` is reshaped to ``(x.shape[0], -1)``, so each sample must hold
        784 values to match ``fc1``.
        """
        # Collapse everything after the leading dimension into one axis.
        flat_input = x.view(x.size(0), -1)
        self.a0 = flat_input

        hidden_pre = self.fc1(flat_input)
        self.s1 = hidden_pre

        hidden_act = F.relu(hidden_pre)
        self.a1 = hidden_act

        output_pre = self.fc2(hidden_act)
        self.s2 = output_pre

        # Pre-activations of both layers, exposed for gradient queries.
        self.FIM_params = [hidden_pre, output_pre]
        return output_pre

In [134]:
# Number of target classes; matches the 10 outputs of the model's last layer.
NUM_CLASSES = 10

# Toy batch: a single random 28x28 input and a single integer class label.
X = torch.randn((1, 28, 28))
# The upper bound of torch.randint is exclusive: an upper bound of 1 could
# only ever yield label 0, so draw from the full range 0..NUM_CLASSES-1.
Y = torch.randint(0, NUM_CLASSES, [1])

loss_func = F.mse_loss

In [135]:
# Build the network with its default constructor arguments (bias=False).
model = MLP()

In [None]:
def inverse(input, damping=0.01, fallback_damping=0.1):
    """Invert a square matrix after adding a multiple of the identity.

    Computes ``inv(input + damping * I)``. If that inversion fails, it is
    retried once with ``fallback_damping`` in place of ``damping``.

    Args:
        input: Square matrix ``(n, n)``; a leading batch of such matrices
            ``(..., n, n)`` is also accepted.
        damping: Multiple of the identity added before the first attempt.
        fallback_damping: Multiple of the identity used for the retry.

    Returns:
        The inverse of the damped matrix.

    Raises:
        RuntimeError: If the retry with ``fallback_damping`` fails as well.
    """
    size = input.shape[-1]
    # Build the identity on the input's own device and dtype so the addition
    # below does not fail with a device mismatch for non-default tensors.
    identity = torch.eye(size, dtype=input.dtype, device=input.device)
    try:
        return torch.linalg.inv(input + identity * damping)
    except RuntimeError:
        # torch.linalg.inv reports a singular matrix with an error derived
        # from RuntimeError; catching only that keeps KeyboardInterrupt and
        # unrelated failures visible instead of silently retrying.
        return torch.linalg.inv(input + identity * fallback_damping)

def _outer_product_mean(vectors):
    """Return the batch average of ``v v^T`` for a ``(batch, dim)`` tensor."""
    # (dim, batch) @ (batch, dim) sums the per-sample outer products directly,
    # without materialising the intermediate (batch, dim, dim) tensor.
    return vectors.t() @ vectors / vectors.shape[0]


def compute_natural_grads(model:MLP, X:torch.Tensor, Y:torch.Tensor, loss_fn=F.kl_div):
    """Precondition the two weight gradients with per-layer second moments.

    Runs a forward pass, takes the gradient of the loss with respect to both
    layer pre-activations and both weight matrices, and returns, per layer,
    ``inverse(G) @ dW @ inverse(A)`` where ``A`` is the batch-averaged outer
    product of the layer's input activation and ``G`` the batch-averaged outer
    product of the loss gradient at the layer's pre-activation. This has the
    form of a Kronecker-factored approximation of the natural gradient.

    Args:
        model: Network whose forward pass sets ``a0``, ``a1`` and
            ``FIM_params`` (two pre-activations) and which has exactly two
            parameters, i.e. two bias-free linear layers.
        X: Input batch, passed to ``model`` unchanged.
        Y: Integer class indices; one-hot encoded to the width of the model
            output before being handed to ``loss_fn``.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar loss.
            NOTE(review): the default ``F.kl_div`` expects log-probabilities
            as its first argument, whereas ``pred`` is the raw model output;
            pass a suitable loss explicitly if that is not intended.

    Returns:
        List of two tensors, shaped like the first and second weight matrix.
        They are computed under ``torch.no_grad()`` and carry no graph.

    Raises:
        ValueError: If the model does not expose exactly two pre-activations
            and two parameters.
    """
    pred = model(X)
    # One-hot width follows the prediction so the class count is not
    # hard-coded; dtype follows the prediction for the loss computation.
    target = F.one_hot(Y, pred.shape[-1]).to(pred.dtype)
    # Use the loss that was passed in rather than a notebook-level global.
    loss = loss_fn(pred, target)

    pre_activations = list(model.FIM_params)
    weights = list(model.parameters())
    if len(pre_activations) != 2 or len(weights) != 2:
        raise ValueError(
            "compute_natural_grads expects 2 pre-activations and 2 weight "
            f"matrices (bias-free layers), got {len(pre_activations)} and "
            f"{len(weights)}"
        )

    g1, g2, dw1, dw2 = torch.autograd.grad(
        loss, [*pre_activations, *weights], create_graph=False, retain_graph=False
    )

    # NOTE(review): g1/g2 are gradients of the reduced loss, so their scale
    # depends on loss_fn's reduction -- confirm this is the intended scaling.
    # The factors are plain statistics; no_grad keeps the results from
    # referencing the graph that autograd.grad has already freed.
    with torch.no_grad():
        A0 = _outer_product_mean(model.a0)
        A1 = _outer_product_mean(model.a1)
        G1 = _outer_product_mean(g1)
        G2 = _outer_product_mean(g2)

        natural_grads = [
            inverse(G1) @ dw1 @ inverse(A0),
            inverse(G2) @ dw2 @ inverse(A1),
        ]

    return natural_grads