In [None]:
import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.utils.data as data_utils
from ucimlrepo import fetch_ucirepo
from torch.utils.data import DataLoader
from model import WideModel
from tqdm.auto import tqdm

In [None]:
# Load the Breast Cancer Wisconsin (Diagnostic) dataset (UCI repository id 17)
# and split it into train / test TensorDatasets.
# Requires the `ucimlrepo` package: pip install ucimlrepo

# fetch dataset
breast_cancer_wisconsin_diagnostic = fetch_ucirepo(id=17)

# data (as pandas dataframes), converted to float32 tensors
X = torch.tensor(breast_cancer_wisconsin_diagnostic.data.features.values, dtype=torch.float32)

# Encode the diagnosis as a binary label: "M" -> 1, "B" -> 0.
# `assign` returns a new frame, so the frame owned by the fetched dataset
# object is not written to (no SettingWithCopyWarning, no hidden mutation).
y = breast_cancer_wisconsin_diagnostic.data.targets
y = y.assign(Diagnosis=y["Diagnosis"].map({"M": 1, "B": 0}))
# `map` turns any label missing from the mapping into NaN; fail loudly here
# instead of silently training on NaN targets.
assert not y["Diagnosis"].isna().any(), "unexpected Diagnosis label (expected only 'M' or 'B')"
y = torch.tensor(y.values, dtype=torch.float32)

# The first n_train rows are used for training, the remaining rows for testing.
# NOTE(review): the split follows the dataset's row order (no shuffling) — confirm
# that this is intended.
n_train = 455
train = data_utils.TensorDataset(X[:n_train], y[:n_train])
test = data_utils.TensorDataset(X[n_train:], y[n_train:])

In [None]:
# Wrap both datasets in shuffling loaders. Batch size must be 1.
batch_size = 1
loader_options = dict(batch_size=batch_size, shuffle=True)
train_dataloader = DataLoader(train, **loader_options)
test_dataloader = DataLoader(test, **loader_options)

In [None]:
# Select the compute device: CUDA when torch reports it available, else CPU.
if torch.cuda.is_available():
    dev = "cuda"
else:
    dev = "cpu"
print("using device", dev)

# Build the model and move it onto the selected device.
model = WideModel(input_dim=30, hidden_dim_scale=20, output_dim=1)
model = model.to(dev)

# SGD optimizer over all model parameters.
lr = 1e-3
optimizer = torch.optim.SGD(list(model.parameters()), lr=lr)

In [None]:
# Linearize the model around its current parameters and halve the parameters
# until every linearized output satisfies |f| <= 1. Whenever a halving was
# needed, the outer loop re-linearizes around the new parameters and checks again.

# Safety cap for the inner halving loop. Halving w only shrinks the A w term of
# f = A w + B, so if |B| alone exceeds 1 the loop would never terminate.
# After 300 halvings any float32 value has underflowed to exactly zero, so
# further halvings cannot change f (assumes float32 parameters — TODO confirm).
MAX_HALVINGS = 300

updated = True
while updated:
    # get linearized models:
    num_params = len(model.flatten_parameters())

    # we reduce f(x,w) to Aw+B, where there is a different A,B per x.
    # Per-sample rows are collected in lists and concatenated once after the
    # loop: concatenating inside the loop re-copies the growing matrix on every
    # sample, and pre-allocated CPU accumulators break when tensors live on CUDA.
    A_rows, B_rows, y_rows = [], [], []

    for x, y in tqdm(train_dataloader):
        x = x.to(dev)

        # A = gradient matrix of logits
        A = model.flatten_gradient(x).unsqueeze(0)
        A_rows.append(A)

        # B = f(x,w) - A w
        B = model.forward(x) - A @ model.flatten_parameters()
        B_rows.append(B[:, 0])

        # Keep the labels on the same device as the linearization offsets.
        y_rows.append(y[:, 0].to(B.device))

        # NOTE(review): torch tensors hash by object identity, so these entries
        # can only be looked up with this exact tensor object — confirm that is
        # what the model expects.
        model.Adict[x] = A
        model.Bdict[x] = B

    As = torch.cat(A_rows, dim=0)
    Bs = torch.cat(B_rows, dim=0)
    ys = torch.cat(y_rows, dim=0)
    # Each gradient row must have one entry per flattened parameter.
    assert As.shape[1] == num_params, "gradient row length does not match parameter count"

    model.update_stored_linear_tensors(As, Bs)

    f = model.batched_linearized_forward(model.flatten_parameters())
    updated = False
    halvings = 0
    while f.abs().max() > 1:
        if halvings >= MAX_HALVINGS:
            raise RuntimeError(
                f"linearized outputs still exceed 1 in magnitude after {MAX_HALVINGS} "
                "parameter halvings; the linearization offset B is probably too large"
            )
        updated = True
        w = model.flatten_parameters() / 2
        model.update_parameters(w)
        f = model.batched_linearized_forward(model.flatten_parameters())
        halvings += 1
    

In [None]:
# Run 100 Newton updates starting from model.w0, reporting how far the
# parameter vector moves on each step.
w = model.w0
print(w)
for step in range(100):
    print(f"Starting Newton step {step}")
    # Keep the previous iterate so the step size can be reported.
    old_w, w = w, model.newton_update(w, ys)
    step_norm = torch.linalg.norm(w - old_w)
    print(f"change in w is: {step_norm}")

In [None]:
# Linearized model outputs at the current w, and the residuals (sigmoid(f) - ys)
# multiplied through model.Atensor (presumably the logistic-loss gradient with
# respect to w — confirm against the model's definition of Atensor).
f = model.batched_linearized_forward(w)
# torch.sigmoid is the overflow-safe form of exp(f)/(1+exp(f)): the explicit
# ratio evaluates to inf/inf = NaN as soon as exp(f) overflows.
dl = (torch.sigmoid(f) - ys) @ model.Atensor

In [None]:
# Sigmoid of the entry of f at index 3. torch.sigmoid avoids the inf/inf = NaN
# that exp(f)/(1+exp(f)) produces once exp(f) overflows.
torch.sigmoid(f[3])

In [None]:
# Display the entry of f at index 3.
f[3]

In [None]:
# Element-wise sigmoid of f. torch.sigmoid avoids the inf/inf = NaN that
# exp(f)/(1+exp(f)) produces once exp(f) overflows.
torch.sigmoid(f)

In [None]:
# Display dl as computed in an earlier cell.
dl

In [None]:
# Display the full tensor f.
f

In [None]:
# Residuals between sigmoid(f) and the labels ys. torch.sigmoid avoids the
# inf/inf = NaN that exp(f)/(1+exp(f)) produces once exp(f) overflows.
torch.sigmoid(f) - ys

In [None]:
# Display exp(f) element-wise; entries show as inf where exp overflows.
torch.exp(f)

In [None]:
# Display the label tensor ys.
ys