In [3]:
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl, numpy as np
from pathlib import Path
from torch import tensor
from fastcore.test import test_close
torch.manual_seed(42)

mpl.rcParams['image.cmap'] = 'gray'
torch.set_printoptions(precision=2, linewidth=125, sci_mode=False)
np.set_printoptions(precision=2, linewidth=125)

path_data = Path('data')
path_gz = path_data/'mnist.pkl.gz'
with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
x_train, y_train, x_valid, y_valid = map(tensor, [x_train, y_train, x_valid, y_valid])

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
nh = 50
n,m = x_train.shape

In [5]:
w1 = torch.randn(m,nh)
b1 = torch.zeros(nh)
w2 = torch.randn(nh,1)
b2 = torch.zeros(1)

In [6]:
def lin(x,w,b):
    return x@w + b

t = lin(x_valid , w1,b1)
t.shape

torch.Size([10000, 50])

In [7]:
def relu(x):
    return x.clamp_min(0.)
t = relu(t)

In [8]:
def model(xb):
    l1 = lin(xb,w1,b1)
    l2=relu(l1)
    return lin(l2,w2,b2)
model(x_valid).shape

torch.Size([10000, 1])

In [9]:
def mse(output,target):
    return (output[:,0] - target).pow(2).mean()

In [10]:
def lin_grad(inp , out , w,b):
    print(inp.shape,out.g.shape,w.T.shape,b.shape)
    inp.g = out.g @ w.T
    w.g =  inp.T @ out.g
    b.g = out.g



def lin_grad(inp, out, w, b):
    # grad of matmul with respect to input
    inp.g = out.g @ w.t()
    w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)
    b.g = out.g.sum(0)


In [11]:
def forward_backward(inp,targ):
    #forward pass
    l1 = lin(inp , w1, b1)
    l2 = relu(l1)
    out = lin(l2,w2,b2)
    diff = out[:,0] - targ
    loss = diff.pow(2).mean()

    #backward pass
    out.g = 2.*diff[:,None]/inp.shape[0]
    print(out.g.shape)
    lin_grad(l2,out,w2,b2)
    l1.g = (l1>0).float()*l2.g
    print(5)
    lin_grad(inp , l1,w1,b1)


In [12]:
forward_backward(x_train,y_train)

In [13]:
def get_grad(x):
    return x.g.clone()

chks = w1,w2,b1,b2,x_train

ptgrad = w1g,w2g,b1g,b2g,ig = tuple(map(get_grad,chks))

In [14]:
def mkgrad(x): return x.clone().requires_grad_(True)
ptgrads = w12,w22,b12,b22,xt2 = tuple(map(mkgrad, chks))

In [15]:
def forward(inp, targ):
    # forward pass:
    l1 = lin(inp, w12, b12)
    l2 = relu(l1)
    out = lin(l2, w22, b22)
    return mse(out,targ)

In [16]:
loss = forward(xt2 , y_train)
loss.backward()

In [17]:
w12.grad

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

## layers as class

In [18]:
class Relu():
    def __call__(self,inp):
        self.inp = inp
        self.out = inp.clamp_min(0.)
        return self.out
    def backward(self):
        self.inp.g = (self.inp>0).float() * self.out.g


In [24]:
class Lin():
    def __init__(self, w,b):
        self.w , self.b =w,b
    def __call__(self,inp):
        self.inp=inp
        self.out = lin(inp,self.w,self.b)
        return self.out
    def backward(self):
        self.inp.g = self.out.g * self.w.T
        self.w.g = self.inp.T @ self.out.g
        self.b.g = self.out.g
        

In [29]:
class Mse():
    def __call__(self,out,tar):
        self.out = out
        self.target = tar
        self.diff = mse(out,tar)
        return self.diff
    def backward(self):
        self.inp.g = 2. * (self.inp.unsqueeze() - self.tar.unsqueeze(-1)) / self.out.shape


In [30]:
class Model():
    def __init__(self,w1,b1,w2,b2):
        self.layers = [Lin(w1,b1) , Relu() , Lin(w2,b2)]
        self.loss=Mse()
    def __call__(self ,x,target):
        for l in self.layers:
            x = l(x)
        return self.loss(x , target)
    def backward(self):
        self.loss.backward()
        for l in reversed(self.layers):
            l.backward()

In [31]:
model = Model(w1,b1,w2,b2)

In [32]:
loss = model(x_valid , y_valid)

torch.Size([10000, 784])


class Relu():
def __call__(self,inp):
self.inp = inp
self.out = inp.clamp_min(0.)
return self.out


### nn.Module

In [33]:
from typing import Any


class Module():
    def __call__(self, *args):
        self.args=args
        self.out = self.forward(*args)
        return self.out
    def forward(self):
        raise Exception("not implemented")
    def backward(self):
        self.bwd(self.out , *self.args)
    def bwd(self):
        raise Exception("not implemented")
    


In [57]:
class Relu(Module):
    def forward(self,inp):
        self.out = inp.clamp_min(0.)
        return self.out
    
    def bwd(self ,out,inp):
        inp.g = (inp>0).float() * self.out.g

In [85]:
class Lin(Module):
    def __init__(self, w,b):
        self.w , self.b =w,b
    def forward(self,inp):
        self.out = lin(inp,self.w,self.b)
        return self.out
    def bwd(self ,out,inp):
        print(self.w.shape,out.g.shape,inp.shape)
        inp.g = self.out.g @ self.w.T
        self.w.g = inp.T @ self.out.g
        self.b.g = self.out.g.sum(0)

In [86]:
class Mse(Module):
    def forward (self, inp, targ): return (inp.squeeze() - targ).pow(2).mean()
    def bwd(self, out, inp, targ): inp.g = 2*(inp.squeeze()-targ).unsqueeze(-1) / targ.shape[0]
     

In [87]:
class Model():
    def __init__(self,w1,b1,w2,b2):
        self.layers = [Lin(w1,b1) , Relu() , Lin(w2,b2)]
        self.loss=Mse()
    def __call__(self ,x,target):
        for l in self.layers:
            x = l(x)
        return self.loss(x , target)
    def backward(self):
        self.loss.backward()
        for l in reversed(self.layers):
            l.backward()

In [88]:
model = Model(w1,b1,w2,b2)

In [89]:
loss = model(x_train , y_train)
loss

tensor(4308.76)

In [90]:
model.backward()

torch.Size([50, 1]) torch.Size([50000, 1]) torch.Size([50000, 50])
torch.Size([784, 50]) torch.Size([50000, 50]) torch.Size([50000, 784])


In [91]:
from torch import nn
import torch.nn.functional as F

In [92]:
class Linear(nn.Module):
    def __init__(self , n_in , n_out):
        super().__init__()
        self.w = torch.randn(n_in , n_out).requires_grad_()
        self.b = torch.zeros(n_out).requires_grad_()
    def forward(self,inp):
        return inp@self.w +self.b
        