In [1]:
import numpy as np
import scipy
import scipy.stats as stats
import math
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

# Original distribution of the Boston housing data; used only when the
# installed scikit-learn no longer ships the loader.
BOSTON_URL = "http://lib.stat.cmu.edu/datasets/boston"


def load_boston_xy():
    """Return the Boston housing data as ``(features, target)`` NumPy arrays.

    Uses ``sklearn.datasets.load_boston`` when it is importable. The loader was
    removed in scikit-learn 1.2, so on newer versions the raw file is parsed
    directly (this needs network access). Each record in that file is spread
    over two physical lines: 11 values on the first, then 2 more features and
    the target on the second.
    """
    try:
        from sklearn.datasets import load_boston
    except ImportError:
        import pandas as pd

        raw = pd.read_csv(BOSTON_URL, sep=r"\s+", skiprows=22, header=None).values
        features = np.hstack([raw[::2, :], raw[1::2, :2]])
        target = raw[1::2, 2]
        return features, target
    return load_boston(return_X_y=True)


X, Y = load_boston_xy()
# NOTE(review): the scaler is fit on every row before the train/test split in
# the next cell, so the test rows influence the scaling of the training data.
X = MinMaxScaler(feature_range=(-1, 1)).fit_transform(X)
# Two thirds of the rows are used for training; k is the number of features.
n_train, k = int(X.shape[0] * 2 / 3), X.shape[1]
n_test = X.shape[0] - n_train

In [3]:
# A fixed random_state makes the shuffle, and every score computed from it,
# repeatable across kernel restarts.
splits = train_test_split(X, Y, train_size=n_train, random_state=0)
X_train, X_test, Y_train, Y_test = (
    torch.tensor(part, dtype=torch.float64) for part in splits
)

# Targets are stored as (n, 1) columns so they line up with the model output.
Y_train = Y_train.view(-1, 1)
Y_test = Y_test.view(-1, 1)

In [4]:
# Shape and first two rows of the training features, then of the targets.
for train_tensor in (X_train, Y_train):
    print(train_tensor.shape)
    print(train_tensor[:2])

torch.Size([337, 13])
tensor([[-0.4170, -1.0000,  0.2933, -1.0000,  0.2099, -0.3321,  0.7755, -0.9058,
          1.0000,  0.8282,  0.6170, -0.3593,  0.3747],
        [-0.8977, -1.0000,  0.2933, -1.0000,  0.3704, -1.0000,  0.7508, -0.9120,
          1.0000,  0.8282,  0.6170,  0.7872, -0.7025]], dtype=torch.float64)
torch.Size([337, 1])
tensor([[10.4000],
        [27.5000]], dtype=torch.float64)


# Models

In [5]:
def init(module):
    """Re-initialise ``module`` in place if it is exactly an ``nn.Linear``.

    Weights are drawn from N(0, 1) and divided by sqrt(number of input
    features); the bias, when present, is drawn from N(0, 1) and multiplied by
    1e-5. Any other module type is left untouched. Intended for
    ``model.apply(init)``.
    """
    if type(module) != nn.Linear:
        return

    weights = module.weight.data
    fan_in = weights.shape[1]
    weights.normal_(0, 1)
    weights.mul_(1. / np.sqrt(fan_in))

    if module.bias is None:
        return
    offsets = module.bias.data
    offsets.normal_(0, 1)
    offsets.mul_(0.00001)

In [6]:
class SimpleRegression(nn.Module):
    """Fully connected regressor ``d -> 120 -> 84 -> 1`` in float64.

    ReLU follows the first two layers; the last layer is linear.
    """

    def __init__(self, d):
        """Create the three linear layers; ``d`` is the number of input features."""
        super().__init__()
        # Layers are built in float32 and then cast, in the order fc1, fc2, fc3.
        self.fc1 = nn.Linear(d, 120).double()
        self.fc2 = nn.Linear(120, 84).double()
        self.fc3 = nn.Linear(84, 1).double()

    def forward(self, x):
        """Return the prediction for ``x`` (last dimension of size ``d``)."""
        hidden = F.relu(self.fc2(F.relu(self.fc1(x))))
        return self.fc3(hidden)

    def nparams(self):
        """Return the total number of scalar parameters of the model."""
        return sum(p.nelement() for p in self.parameters())
        

# WideNN

In [7]:
class WideNN:
    """Closed-form linearised training dynamics of a model around its current weights.

    With G the matrix of per-sample parameter gradients of the model on the
    training inputs and K = G G^T the resulting kernel, ``linearize(t, nu)``
    returns the predictor

        f(x) = f_0(x) + K(x, X) K^{-1} (I - exp(-nu * t * K)) (Y - f_0(X))

    where f_0 is the model as passed in and K(x, X) = grad f_0(x) G^T.
    """

    def __init__(self, model, X, Y):
        """Store the model and training data and pre-compute gradients and kernel.

        Args:
            model: ``nn.Module`` producing one output per input row.
            X: training inputs, shape (n, k).
            Y: training targets, shape (n, 1).
        """
        self.f_0 = model
        self.X = X
        self.Y = Y
        self.n = X.shape[0]
        self.k = X.shape[1]
        # Counted from parameters() because get_gradient concatenates exactly
        # those tensors; this also lifts the need for a custom model.nparams().
        self.nparams = sum(p.numel() for p in model.parameters())

        self.grad_0 = self.get_gradient(X) # df_0(X)/dtheta
        self.kernel_0 = torch.mm(self.grad_0, self.grad_0.T)
        # kernel_0 is symmetric by construction, so a single symmetric
        # eigendecomposition serves both nu_critical() and linearize().
        self.eigvals_0, self.eigvecs_0 = torch.linalg.eigh(self.kernel_0)

    def linearize(self, t, nu):
        """Return the linearised predictor after training time ``t`` at rate ``nu``.

        Args:
            t: training time.
            nu: learning rate; only the product ``nu * t`` enters the result.

        Returns:
            A function mapping inputs of shape (m, k) to predictions of shape
            (m, 1). The predictions are not attached to the autograd graph.
        """
        lam, vecs = self.eigvals_0, self.eigvecs_0
        horizon = nu * t

        # Spectrum of K^{-1} (I - exp(-nu t K)): (1 - e^{-nu t lam}) / lam.
        # expm1 keeps precision for small lam; an exactly zero eigenvalue
        # takes the limit value nu * t instead of dividing by zero.
        nonzero = lam != 0
        safe_lam = torch.where(nonzero, lam, torch.ones_like(lam))
        spectrum = torch.where(
            nonzero,
            -torch.expm1(-horizon * lam) / safe_lam,
            torch.full_like(lam, horizon),
        )

        def apply_operator(column):
            # V diag(spectrum) V^T @ column, without forming the n x n operator.
            return torch.mm(vecs, spectrum.unsqueeze(1) * torch.mm(vecs.T, column))

        with torch.no_grad():
            f_0_X = self.f_0(self.X) # f0(X)
        mu_partial = apply_operator(self.Y.to(lam.dtype))
        gamma_partial = apply_operator(f_0_X.to(lam.dtype))

        def f(x):
            kernel = torch.mm(self.get_gradient(x), self.grad_0.T)
            with torch.no_grad():
                f_0_x = self.f_0(x)
            mu = torch.mm(kernel, mu_partial)
            gamma = f_0_x - torch.mm(kernel, gamma_partial)
            return mu + gamma
        return f

    def nu_critical(self):
        """Return ``2 / (lambda_min + lambda_max)`` of the initial kernel."""
        eigens = self.eigvals_0
        return 2. / (torch.min(eigens) + torch.max(eigens)).numpy()

    def get_gradient(self, data):
        """Return per-sample gradients of the model output w.r.t. all parameters.

        Args:
            data: inputs of shape (m, k).

        Returns:
            float64 tensor of shape (m, nparams); row i is the flattened
            gradient of the model output on ``data[i]``, with parameters in
            ``parameters()`` order.
        """
        params = list(self.f_0.parameters())
        grad = torch.zeros((data.shape[0], self.nparams), dtype=torch.float64)
        # enable_grad lets this work even when called inside torch.no_grad();
        # autograd.grad leaves the parameters' .grad attributes untouched.
        with torch.enable_grad():
            for i in range(data.shape[0]):
                pred = self.f_0(data[i])
                grads = torch.autograd.grad(pred, params)
                grad[i, :] = torch.cat([g.reshape(-1) for g in grads])
        return grad
        


In [8]:
# Seed torch's RNG so the random initialisation, and every score derived from
# it, is the same on each run.
torch.manual_seed(0)
reg1 = SimpleRegression(k)
reg1.apply(init)

SimpleRegression(
  (fc1): Linear(in_features=13, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=1, bias=True)
)

In [9]:
# Linearise reg1 around its current weights. The constructor computes the
# per-sample parameter gradients on X_train and the kernel built from them.
wideNN = WideNN(reg1, X_train, Y_train)

In [11]:
# 2 / (lambda_min + lambda_max) of the initial kernel.
nu_critical = wideNN.nu_critical()
print(nu_critical)

0.00039580751771586505


In [12]:
# Learning rate and training time for the closed-form solution; linearize
# uses them only through the product nu * t.
nu = 1.
t = 1.
f_lin = wideNN.linearize(t, nu)

In [13]:
def r2(pred_y, y):
    """Coefficient of determination, ``1 - SS_res / SS_tot``, as a tensor.

    Args:
        pred_y: predictions; row i is compared with ``y[i]``. Rows beyond
            ``len(y)`` are ignored.
        y: targets; their mean over all elements is the baseline.

    Returns:
        Tensor with the shape of one broadcast row (``(1,)`` for column
        inputs). It stays attached to the autograd graph of ``pred_y``, so
        ``-r2(...)`` can be used as a loss.

    Raises:
        ValueError: if ``y`` has no rows.
        IndexError: if ``pred_y`` has fewer rows than ``y``.
    """
    n = y.shape[0]
    if n == 0:
        raise ValueError("r2 needs at least one sample")
    if pred_y.shape[0] < n:
        raise IndexError("pred_y has fewer rows than y")

    def pad_rank(tensor, rank):
        # Insert singleton dims after the row dim so that whole-tensor
        # broadcasting matches broadcasting row by row.
        while tensor.dim() < rank:
            tensor = tensor.unsqueeze(1)
        return tensor

    pred_rows = pred_y[:n]
    rank = max(pred_rows.dim(), y.dim())
    residual = pad_rank(pred_rows, rank) - pad_rank(y, rank)

    # One vectorised reduction instead of n Python-level autograd steps.
    s1 = (residual ** 2).sum(dim=0)
    s2 = ((y - torch.mean(y)) ** 2).sum(dim=0)
    return 1. - s1 / s2

In [14]:
# Score the linearised model on the data it was fitted to.
pred_Y_train = f_lin(X_train)
train_r2_lin = r2(pred_Y_train, Y_train).detach().numpy()[0]
print("Train")
print("r2", train_r2_lin)

Train
r2 0.9465517809511681


In [15]:
# Score the linearised model on the held-out rows.
pred_Y_test = f_lin(X_test)
test_r2_lin = r2(pred_Y_test, Y_test).detach().numpy()[0]
print("Test")
print("r2", test_r2_lin)

Test
r2 0.8440471765767307


# SGD

In [16]:
import torch.optim as optim

N_STEPS = 1000
LOG_EVERY = 100
LEARNING_RATE = 0.01

# Seed torch's RNG so the initial weights, and the training curve, repeat on
# every run.
torch.manual_seed(1)
reg2 = SimpleRegression(k)
reg2.apply(init)
optimizer = optim.SGD(reg2.parameters(), lr=LEARNING_RATE)

# Every update uses the whole training set and maximises r2 directly.
# The counter is called `step`, not `t`, so it does not overwrite the global
# `t` that the linearize cell uses as its time horizon.
for step in range(N_STEPS):
    output = reg2(X_train)
    loss = -r2(output, Y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % LOG_EVERY == 0:
        print("Step", step, "\tr2", -loss.detach().numpy()[0])


Step 0 	r2 -6.64848903177599
Step 100 	r2 0.17817834470008054
Step 200 	r2 0.40056396844043607
Step 300 	r2 0.5226518154870351
Step 400 	r2 0.6051399850870088
Step 500 	r2 0.6599199934299809
Step 600 	r2 0.6953469842502269
Step 700 	r2 0.7184822864455804
Step 800 	r2 0.7344633227374477
Step 900 	r2 0.7464076037571982


In [17]:
# Score the SGD-trained network on the held-out rows.
print("Test")
test_r2_sgd = r2(reg2(X_test), Y_test).detach().numpy()[0]
print("r2", test_r2_sgd)

Test
r2 0.7181211150081149
