In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import random
import numpy as np
import matplotlib.pyplot as plt

In [8]:
# Scale applied to the second training input.
lambda_min = 0.95

# Two orthogonal training inputs (the second scaled by lambda_min), each with target 1.
x = torch.tensor([[1.0, 0.0], [0.0, lambda_min]])
y_true = torch.ones(2, 1)

In [9]:
class Net(nn.Module):
    """Two-layer, bias-free linear network ``fc2(fc1(x))`` for studying GD dynamics.

    ``fc1`` maps 2 inputs to 1 hidden unit and ``fc2`` maps that unit to 1 output.
    Training is plain full-batch gradient descent on the mean L1 loss. Per-step
    gradients, residuals and (in open-ended training) a 2x2 iteration-style matrix
    are recorded on the instance.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2, 1, bias=False)
        self.fc2 = nn.Linear(1, 1, bias=False)

        # Per-parameter gradient history, keyed by ``named_parameters()`` name.
        self.gradients = {
            'fc1.weight': [],
            'fc2.weight': [],
        }
        # Per-step residuals (y_pred - y_true), detached from the autograd graph.
        self.loss = []
        # Per-step 2x2 matrices computed by ``train`` when ``epochs`` is None.
        self.ntk = []

    def init_layer1(self, weight, bias=None):
        """Overwrite ``fc1``'s weight with ``weight`` cast to float32.

        ``bias`` is accepted for interface compatibility and ignored (``fc1`` has no bias).
        """
        self.fc1.weight.data = weight.float()

    def init_layer2(self, weight, bias=None):
        """Overwrite ``fc2``'s weight with ``weight`` cast to float32.

        ``bias`` is accepted for interface compatibility and ignored (``fc2`` has no bias).
        """
        self.fc2.weight.data = weight.float()

    def get_layer1(self):
        """Return ``fc1``'s weight tensor (shares storage with the parameter)."""
        return self.fc1.weight.data

    def get_layer2(self):
        """Return ``fc2``'s weight tensor (shares storage with the parameter)."""
        return self.fc2.weight.data

    def forward(self, x):
        """Apply ``fc1`` then ``fc2``; there is no nonlinearity between them."""
        return self.fc2(self.fc1(x))

    def train_epoch(self, x, y_true, eta):
        """Run one full-batch gradient-descent step on the mean L1 loss.

        Args:
            x: Input batch passed to ``forward``.
            y_true: Targets, same shape as the network output.
            eta: Learning rate.

        Returns:
            The scalar loss tensor evaluated *before* the parameter update.
        """
        y_pred = self.forward(x)
        loss = F.l1_loss(y_pred, y_true)
        self.zero_grad()
        loss.backward()
        with torch.no_grad():
            for param in self.parameters():
                param -= eta * param.grad

        # Snapshot each gradient: if zero_grad reuses the .grad tensor in place
        # (set_to_none=False), stored references would all alias the latest step.
        for name, param in self.named_parameters():
            self.gradients.setdefault(name, []).append(param.grad.detach().clone())
        # Detach so the stored residual does not keep this step's autograd graph alive.
        self.loss.append((y_pred - y_true).detach())

        return loss

    def train(self, x=True, y_true=None, eta=None, epochs=None, max_epochs=10, tol=1e-6):
        """Train with gradient descent, or toggle training mode like ``nn.Module.train``.

        ``nn.Module.eval()`` calls ``self.train(False)``; a bare bool (or no argument)
        is therefore forwarded to ``nn.Module.train`` so mode switching keeps working.

        Args:
            x: Input batch, or a bool training-mode flag (see above).
            y_true: Targets.
            eta: Learning rate.
            epochs: Number of steps to run. If None, run until the loss is below
                ``tol`` or ``max_epochs`` steps have run, printing per step the loss,
                the eigenvalues of a 2x2 matrix built from the weights, and the weights.
            max_epochs: Step cap used when ``epochs`` is None.
            tol: Loss threshold used when ``epochs`` is None.

        Returns:
            The loss tensor from the last step (None if no step ran), or ``self``
            when called in mode-toggle form.
        """
        if y_true is None and isinstance(x, bool):
            return super().train(x)

        loss = None
        if epochs is not None:
            for epoch in range(epochs):
                loss = self.train_epoch(x, y_true, eta)
                print(f'Epoch {epoch}: loss {loss}')
            return loss

        step = 1
        while True:
            loss = self.train_epoch(x, y_true, eta)

            theta1, theta2 = self.get_layer1().flatten().tolist()
            theta3 = self.get_layer2().flatten().item()
            # Matrix of the form I - eta*K built from the current weights.
            # NOTE(review): the entries use only theta values; the second sample's
            # input scale (x[1, 1], e.g. lambda_min) does not appear — confirm intended.
            ntk = [
                [1 - eta * (theta1**2 + theta3**2), -eta * theta1 * theta2],
                [-eta * theta1 * theta2, 1 - eta * (theta2**2 + theta3**2)],
            ]
            self.ntk.append(ntk)

            print(f'Epoch {step}: loss {loss}' + f' eigen {np.linalg.eigvals(ntk)}')
            print(f' theta1 {theta1} theta2 {theta2} theta3 {theta3}')
            print(f' theta1*theta3 {theta1*theta3} theta2*theta3 {theta2*theta3}')
            step += 1
            if loss < tol or step > max_epochs:
                break
        return loss

In [11]:
# Start from zero first-layer weights and a unit second-layer weight.
net = Net()
net.init_layer1(torch.zeros(1, 2), torch.zeros(1))
net.init_layer2(torch.ones(1, 1), torch.zeros(1))

# Base step size 2 / (1 + lambda_min**2); print 1 minus it, then train at 90% of it.
eta_base = 2 / (1 + lambda_min**2)
print(1 - eta_base)
eta = 0.9 * eta_base
loss = net.train(x, y_true, eta=eta)



-0.05124835742444156
Epoch 1: loss 1.0 eigen [-0.34894091  0.05387648]
 theta1 0.47306177020072937 theta2 0.4494086802005768 theta3 1.0
 theta1*theta3 0.47306177020072937 theta2*theta3 0.4494086802005768
Epoch 2: loss 0.5499999523162842 eigen [-2.53452966 -0.92326011]
 theta1 0.9461235404014587 theta2 0.8988173604011536 theta3 1.4257556200027466
 theta1*theta3 1.3489409549442755 theta2*theta3 1.2814939029479788
Epoch 3: loss 0.2831800580024719 eigen [0.55517776 0.68800945]
 theta1 0.27165305614471436 theta2 0.2580704092979431 theta3 0.5742444396018982
 theta1*theta3 0.15599525699196448 theta2*theta3 0.14819549756512984
Epoch 4: loss 0.8516095280647278 eigen [-0.16553458  0.36579219]
 theta1 0.5433061122894287 theta2 0.5161408185958862 theta3 0.8187322020530701
 theta1*theta3 0.4448222097036165 theta2*theta3 0.4225811089784841
Epoch 5: loss 0.5768628716468811 eigen [-2.17685183 -0.61796551]
 theta1 0.9306169748306274 theta2 0.8840861320495605 theta3 1.3077077865600586
 theta1*theta3 1.2