In [69]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random
import numpy as np
import matplotlib.pyplot as plt

In [70]:
# Two training inputs (one per row) and their shared target value of 1.
x = torch.tensor([[1.0, 0.0],
                  [0.0, 0.5]])
y_true = torch.ones(2, 1)

In [108]:
class Net(nn.Module):
    """Bias-free two-layer linear network: y = theta3 * (theta1 * x1 + theta2 * x2).

    fc1 maps 2 inputs to 1 hidden unit and fc2 maps that unit to 1 output,
    both without bias. Training is plain full-batch gradient descent on the
    L1 (mean absolute error) loss; per-epoch gradients and residuals are
    recorded on the instance for later inspection.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 1, bias=False)
        self.fc2 = nn.Linear(1, 1, bias=False)

        # Gradient history, one list per parameter keyed by qualified name
        # ('fc1.weight', 'fc2.weight'). Derived from named_parameters() so the
        # keys stay correct if the layers are given biases again.
        self.gradients = {name: [] for name, _ in self.named_parameters()}
        # Per-epoch residuals (y_pred - y_true), detached from autograd.
        self.loss = []
        # Not populated by this class.
        self.ntk = []

    @staticmethod
    def _init_linear(layer, weight, bias):
        """Copy ``weight`` (and ``bias``, if both it and ``layer.bias`` exist) into ``layer``."""
        # copy_ (instead of rebinding .data) keeps the parameter's own storage,
        # so later training never mutates the caller's tensor. It also converts
        # to the parameter's dtype and raises on an incompatible shape.
        with torch.no_grad():
            layer.weight.copy_(weight)
            if bias is not None and layer.bias is not None:
                layer.bias.copy_(bias)

    def init_layer1(self, weight, bias=None):
        """Set fc1's weight. ``bias`` is ignored while fc1 has no bias."""
        self._init_linear(self.fc1, weight, bias)

    def init_layer2(self, weight, bias=None):
        """Set fc2's weight. ``bias`` is ignored while fc2 has no bias."""
        self._init_linear(self.fc2, weight, bias)

    def get_layer1(self):
        """Return fc1's weight, detached (shares storage with the parameter)."""
        return self.fc1.weight.detach()

    def get_layer2(self):
        """Return fc2's weight, detached (shares storage with the parameter)."""
        return self.fc2.weight.detach()

    def forward(self, x):
        """Apply fc1 then fc2."""
        return self.fc2(self.fc1(x))

    def train_epoch(self, x, y_true, eta):
        """Run one full-batch gradient-descent step on the L1 loss.

        Args:
            x: input batch passed to forward().
            y_true: targets compared against forward(x).
            eta: learning rate (step size).

        Returns:
            The loss evaluated before the update, detached from the graph.
        """
        y_pred = self.forward(x)
        loss = F.l1_loss(y_pred, y_true)
        self.zero_grad()
        loss.backward()
        with torch.no_grad():
            for param in self.parameters():
                param -= eta * param.grad

        # Snapshot each gradient. Depending on the PyTorch version, zero_grad()
        # may zero .grad in place, which would make stored references alias.
        for name, param in self.named_parameters():
            self.gradients[name].append(param.grad.detach().clone())
        # Detach so stored residuals don't keep every epoch's graph alive.
        self.loss.append((y_pred - y_true).detach())

        return loss.detach()

    def train(self, x=True, y_true=None, eta=None, epochs=None,
              tol=1e-6, max_epochs=1000):
        """Train with gradient descent, printing the loss after each epoch.

        This method shadows ``nn.Module.train(mode)``, which ``nn.Module.eval()``
        calls internally. A call with a single bool argument (and no targets)
        is therefore forwarded to ``nn.Module.train``.

        Args:
            x, y_true, eta: forwarded to train_epoch().
            epochs: if given, run exactly this many epochs. If None, run until
                the loss drops below ``tol`` or ``max_epochs`` epochs have run.
            tol: early-stopping threshold, used only when ``epochs`` is None.
            max_epochs: epoch cap, used only when ``epochs`` is None.

        Returns:
            The loss of the last epoch run (None if no epoch ran), or ``self``
            when forwarded to ``nn.Module.train``.

        Raises:
            TypeError: if ``y_true`` or ``eta`` is missing for a training call.
        """
        if y_true is None and isinstance(x, bool):
            return super().train(x)
        if y_true is None or eta is None:
            raise TypeError("train() requires x, y_true and eta")

        early_stop = epochs is None
        n_epochs = max_epochs if early_stop else epochs
        loss = None
        for epoch in range(1, n_epochs + 1):
            loss = self.train_epoch(x, y_true, eta)
            print(f'Epoch {epoch}: loss {loss}')
            # With L1 loss and a fixed step the loss may oscillate rather than
            # reach tol, in which case all max_epochs epochs run.
            if early_stop and loss < tol:
                break
        return loss

In [113]:
# Start from theta1 = theta2 = 0 and theta3 = 1. Both layers are bias-free,
# so the bias arguments are accepted but have no effect.
net = Net()
net.init_layer1(weight=torch.tensor([[0, 0]]), bias=torch.tensor([0]))
net.init_layer2(weight=torch.tensor([[1]]), bias=torch.tensor([0]))
loss = net.train(x, y_true, eta=0.1)

Epoch 1: loss 1.0
Epoch 2: loss 0.96875
Epoch 3: loss 0.9373047351837158
Epoch 4: loss 0.9052725434303284
Epoch 5: loss 0.8722571134567261
Epoch 6: loss 0.8378521203994751
Epoch 7: loss 0.8016366362571716
Epoch 8: loss 0.763169527053833
Epoch 9: loss 0.721984326839447
Epoch 10: loss 0.6775836944580078
Epoch 11: loss 0.6294331550598145
Epoch 12: loss 0.5769545435905457
Epoch 13: loss 0.5195194482803345
Epoch 14: loss 0.4564414620399475
Epoch 15: loss 0.3869679570198059
Epoch 16: loss 0.41383740305900574
Epoch 17: loss 0.3591017425060272
Epoch 18: loss 0.3958168029785156
Epoch 19: loss 0.33080804347991943
Epoch 20: loss 0.37396302819252014
Epoch 21: loss 0.31340306997299194
Epoch 22: loss 0.34155625104904175
Epoch 23: loss 0.2955520749092102
Epoch 24: loss 0.30633029341697693
Epoch 25: loss 0.2770194113254547
Epoch 26: loss 0.26802048087120056
Epoch 27: loss 0.25755587220191956
Epoch 28: loss 0.2263418436050415
Epoch 29: loss 0.23689627647399902
Epoch 30: loss 0.18098655343055725
Epoch 3

In [110]:
# Unpack the trained weights into plain Python floats.
theta1, theta2 = net.get_layer1().reshape(-1).tolist()
theta3 = net.get_layer2().item()

In [111]:
# End-to-end coefficients of the linear map x -> theta3 * (theta1*x1 + theta2*x2).
tuple(theta * theta3 for theta in (theta1, theta2))

(0.47048945724394997, 1.3867693484556298)

In [103]:
# Display the per-epoch history that train_epoch appends to net.loss (y_pred - y_true).
net.loss

[tensor([[1.],
         [1.]], grad_fn=<PowBackward0>),
 tensor([[0.8100],
         [0.9506]], grad_fn=<PowBackward0>),
 tensor([[0.6526],
         [0.9026]], grad_fn=<PowBackward0>),
 tensor([[0.5180],
         [0.8543]], grad_fn=<PowBackward0>),
 tensor([[0.4020],
         [0.8049]], grad_fn=<PowBackward0>),
 tensor([[0.3030],
         [0.7540]], grad_fn=<PowBackward0>),
 tensor([[0.2204],
         [0.7019]], grad_fn=<PowBackward0>),
 tensor([[0.1540],
         [0.6491]], grad_fn=<PowBackward0>),
 tensor([[0.1027],
         [0.5963]], grad_fn=<PowBackward0>),
 tensor([[0.0650],
         [0.5443]], grad_fn=<PowBackward0>),
 tensor([[0.0387],
         [0.4939]], grad_fn=<PowBackward0>),
 tensor([[0.0214],
         [0.4456]], grad_fn=<PowBackward0>),
 tensor([[0.0108],
         [0.4000]], grad_fn=<PowBackward0>),
 tensor([[0.0047],
         [0.3574]], grad_fn=<PowBackward0>),
 tensor([[0.0016],
         [0.3179]], grad_fn=<PowBackward0>),
 tensor([[0.0003],
         [0.2816]], grad_fn=<