In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

In [21]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
epochs = 10
lr = 0.001

def get_data(training_data, test_data):
    return (
        DataLoader(training_data), #, shuffle=True),
        DataLoader(test_data),
    )

train_dl, valid_dl = get_data(training_data, test_data)

Using cpu device


In [48]:
x1,y1 = next(iter(train_dl))
x1, y1 =x1.flatten(), y1.squeeze()
x1.shape, y1.shape

(torch.Size([784]), torch.Size([]))

In [335]:
state_dict = torch.load('model.pth')
state_dict.keys()

odict_keys(['linear_relu_stack.0.weight', 'linear_relu_stack.0.bias', 'linear_relu_stack.2.weight', 'linear_relu_stack.2.bias'])

In [336]:
state_dict['linear_relu_stack.0.weight']

tensor([[-0.0003,  0.0192, -0.0294,  ...,  0.0219,  0.0037,  0.0021],
        [-0.0198, -0.0150, -0.0104,  ..., -0.0203, -0.0060, -0.0299],
        [-0.0201,  0.0149, -0.0333,  ..., -0.0203,  0.0012,  0.0080],
        ...,
        [-0.0215,  0.0106,  0.0308,  ..., -0.0199,  0.0161, -0.0342],
        [ 0.0350, -0.0297, -0.0037,  ...,  0.0171,  0.0238, -0.0001],
        [ 0.0085,  0.0223, -0.0324,  ..., -0.0296,  0.0182, -0.0296]])

In [337]:
weights = []
weights.append(state_dict['linear_relu_stack.0.weight'].clone())
weights.append(state_dict['linear_relu_stack.2.weight'].clone())
bias = []
bias.append(state_dict['linear_relu_stack.0.bias'].clone())
bias.append(state_dict['linear_relu_stack.2.bias'].clone())

In [356]:
def relu(X):
    X[X<0] = 0
    return X

def d_relu(l):
    d = a[l].clone()
    d[a[l] < 0] = 0
    d[a[l] > 0] = 1
    return d

def softmax(x):
    return x.exp() / x.exp().sum()

def delta_l(y_pred, y, l):
    if l == len(weights): # This means it's the last layer... must be a vector
        y_ = torch.zeros(len(y_pred))
        y_[y] = 1
        return (y_pred - y_)[None,:]
    else:
        return (weights[l].T@(delta_l(y_pred,y,l+1)*d_relu(l+1)).T).T
    
activations = []
activations.append(relu)
activations.append(softmax)
    
def forward(X):
    a = [X]
    
    for w,b,func in zip(weights, bias, activations):
        a.append(func(a[-1]@w.T+b))
    return a

def cross_entropy(y_pred,y):
    return -y_pred[y].log()

def backward(y_pred,y):
    for l in range(len(weights)):
        print(delta_l(y_pred,y,len(weights)-l).shape, a[len(a)-l-2].shape)
        weights[len(weights)-l-1] -= lr*delta_l(y_pred,y,len(weights)-l).T@a[len(a)-l-2][None,:]
        bias[len(bias)-l-1] -= lr*delta_l(y_pred,y,len(bias)-l).squeeze()
        
    loss = cross_entropy(y_pred,y)



In [338]:
state_dict['linear_relu_stack.2.weight'][y1]

tensor([ 0.0453, -0.1708, -0.0112, -0.1373, -0.1919,  0.1684, -0.0890,  0.1894,
        -0.1370,  0.0977, -0.0159, -0.0262,  0.1361, -0.2134,  0.1970, -0.0483,
         0.1064, -0.1399])

In [343]:
a[-2]

tensor([0.1195, 0.1275, 0.0000, 0.0000, 0.0728, 0.0000, 0.0000, 0.0665, 0.2172,
        0.0000, 0.0826, 0.0000, 0.0000, 0.0000, 0.1340, 0.2454, 0.0387, 0.0000])

In [344]:
weights[-1][y1]

tensor([ 0.0453, -0.1708, -0.0112, -0.1373, -0.1919,  0.1684, -0.0890,  0.1894,
        -0.1370,  0.0977, -0.0159, -0.0262,  0.1361, -0.2134,  0.1970, -0.0483,
         0.1064, -0.1399])

In [349]:
bias[-1]

tensor([-0.1422, -0.0626, -0.0508, -0.2056, -0.0988, -0.1239,  0.0111,  0.2339,
         0.1431,  0.1194])

In [357]:
backward(a[-1],y1)

torch.Size([1, 10]) torch.Size([18])
torch.Size([1, 18]) torch.Size([784])


In [346]:
weights[-1][y1]

tensor([ 0.1546, -0.0542, -0.0112, -0.1373, -0.1254,  0.1684, -0.0890,  0.2502,
         0.0615,  0.0977,  0.0597, -0.0262,  0.1361, -0.2134,  0.3195,  0.1761,
         0.1419, -0.1399])

In [339]:
a = forward(x1)
[k.shape for k in a]

[torch.Size([784]), torch.Size([18]), torch.Size([10])]

In [245]:
y = torch.zeros(10)
y[y1] = 1
y, a[-1]

(tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]),
 tensor([0.0875, 0.0947, 0.0959, 0.0821, 0.0914, 0.0891, 0.1020, 0.1274, 0.1164,
         0.1136]))

In [220]:
dc_dz2 = a[-1] - y

In [233]:
dc_dw2 = dc_dz2[None, :].T @ a[-2][None,:]

In [235]:
dc_dw2.shape, weights[-1].shape

(torch.Size([10, 18]), torch.Size([10, 18]))