### Multi Layer Perceptron - forward and backward pass

In [2]:
import pickle
import gzip
import math
import torch
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from torch import tensor

In [3]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [4]:
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['figure.figsize'] = [4.0, 3.0]
plt.rcParams['font.size'] = 8

In [5]:
torch.manual_seed(11)
torch.set_printoptions(precision=3, linewidth=140, sci_mode=False)
np.set_printoptions(precision=3, linewidth=140)

In [6]:
path_data = Path('data')
path_gz = path_data/'mnist.pkl.gz'
with gzip.open(path_gz, 'rb') as f: 
    ((x_train, y_train),(x_test, y_test),_) = pickle.load(f, encoding='latin-1')

In [7]:
x_train, y_train, x_test, y_test = map(tensor,[x_train, y_train, x_test, y_test])

In [8]:
u = [5,3,7]
v = 8

tuple(zip([v]+u, u))

((8, 5), (5, 3), (3, 7))

In [11]:
class MLP:
    def __init__(self, nx, nl):
        self.nx = nx
        self.nl = nl
        self.build_nn()
    
    def build_nn(self):
        self.W = []
        self.B = []
        for j,k in zip([self.nx] + self.nl, self.nl):
            self.W.append(torch.randn(j,k))
            self.B.append(torch.randn(k))
    
    def backwards(self):
        pass
        
    def forward(self, x): 
        for w, b in zip(self.W, self.B):
            x = x@w + b
        return x

In [12]:
x_train[:5].shape

torch.Size([5, 784])

In [13]:
nl = [256, 64, 10]

In [14]:
model = MLP(x_train.shape[1], nl)

In [15]:
model.forward(x_train[:5])

tensor([[ 2592.791,  1274.272,   144.264,  2289.282,   928.843,  1150.213,   932.392,    18.566,  -166.792,   454.910],
        [ 1982.107,  1113.323, -1364.185,  1635.586,   921.450,   765.490,  1228.645,   783.224,   125.286,  1622.450],
        [ 1059.275,  -195.095,   -35.324,   416.384,    24.821,  1962.227,  -367.846, -1617.746, -1111.161,  1213.927],
        [-1018.900,  -360.771,  2017.314,  1833.248,   991.787,   740.458,   717.676,   805.373,  1263.424,  1532.452],
        [  700.308,  1704.129,  1577.885,   871.914,   127.479,   668.522,    -9.767, -1963.807,   452.376,  2143.573]])