# Imports

In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim import Adam
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader
import torch.nn.functional as F

import numpy as np
from sklearn.manifold import TSNE

import altair as alt
alt.data_transformers.disable_max_rows()
import pandas as pd


def plot_tsne(tsne_xy, dataloader, num_points: int = 1000, darkmode: bool = True):
    """Build a linked Altair view of a 2-D embedding: scatter plot + hovered image.

    The left panel shows one circle per sampled point; the right panel renders
    the image belonging to the point currently selected by mouse-over.

    Args:
        tsne_xy: 2-D array-like indexed as ``tsne_xy[:, 0]`` / ``tsne_xy[:, 1]``
            for the x / y coordinates. It is placed in the same DataFrame as
            the dataset's labels and images, so it is expected to have one row
            per element of ``dataloader.dataset`` (in dataset order).
        dataloader: Object exposing a ``.dataset`` that yields
            ``(image_tensor, label)`` pairs. Only ``.dataset`` is used; the
            loader's batching/shuffling is not.
        num_points: Maximum number of points to draw; capped at the dataset size.
        darkmode: Selects the embed theme and the colour scheme of the image panel.

    Returns:
        The horizontally concatenated Altair chart ``scatter | digit``.
    """
    # import IPython # Try to automatically detect darkmode - colab is blocking my DOM request
    # # html[theme=dark]
    # js_code = r'document.documentElement.getAttribute("theme");'
    # display(IPython.display.Javascript(js_code))

    # Walk the *entire* dataset once, keeping the first channel of every image
    # (with a trailing singleton axis) and its label.
    # NOTE(review): the `[0,:,:,None]` indexing needs each sample to have at
    # least three dimensions. The loaders built by MNIST_loaders in this
    # notebook flatten every image to 1-D, which this indexing would not
    # accept -- confirm which dataloader is meant to be passed here.
    images, labels = zip(*[(x[0].numpy()[0,:,:,None], x[1]) for x in dataloader.dataset])

    num_points = min(num_points, len(labels))
    data = pd.DataFrame({'x':tsne_xy[:, 0], 'y':tsne_xy[:, 1], 'label':labels,
                        'image': images})
    # Random subset without replacement; the original row index is kept.
    data = data.sample(n=num_points, replace=False)

    alt.renderers.set_embed_options(theme='dark' if darkmode else 'light')
    # Hover selection snapping to the nearest point, initialised to the first
    # sampled row so the image panel is not empty on first render.
    # NOTE(review): `clear` is given the *string* 'false', not the boolean
    # False -- confirm this is the intended way to disable clearing.
    selection = alt.selection_single(on='mouseover', clear='false', nearest=True,
                                    init={'x':data['x'][data.index[0]], 'y':data['y'][data.index[0]]})
    # NOTE(review): x and y are encoded as nominal (':N') rather than
    # quantitative (':Q') -- verify this is intentional for the coordinates.
    scatter = alt.Chart(data).mark_circle().encode(
        alt.X('x:N',axis=None),
        alt.Y('y:N',axis=None),
        # Points matching the selection are drawn light gray; all others are
        # coloured by their label.
        color=alt.condition(selection,
                            alt.value('lightgray'),
                            alt.Color('label:N')),
        # shape= alt.Shape('label:N', condition=selection,scale=alt.Scale(range=['circle','diamond'])),
        size=alt.value(100),
        tooltip='label:N'
    ).add_selection(
        selection
    ).properties(
        width=400,
        height=400
    )

    # Image panel: keep only the selected row(s), then unroll the nested
    # `image` array into one record per pixel with explicit row/column numbers
    # so it can be drawn as a grid of rectangles.
    digit  = alt.Chart(data).transform_filter(
        selection
    ).transform_window(
        index='count()'           # number each of the images
    ).transform_flatten(
        ['image']                 # extract rows from each image
    ).transform_window(
        row='count()',            # number the rows...
        groupby=['index']         # ...within each image
    ).transform_flatten(
        ['image']                 # extract the values from each row
    ).transform_window(
        column='count()',         # number the columns...
        groupby=['index', 'row']  # ...within each row & image
    ).mark_rect(stroke='black',strokeWidth=0).encode(
        alt.X('column:O', axis=None),
        alt.Y('row:O', axis=None),
        # Pixel value drives the colour; scheme depends on `darkmode`.
        alt.Color('image:Q',sort='descending',
            scale=alt.Scale(scheme=alt.SchemeParams('darkblue' if darkmode else 'lightgreyteal',
                            extent=[1, 0]),

            ),
            legend=None
        ),
    ).properties(
        width=400,
        height=400,
    )

    # `|` places the two charts side by side.
    return scatter | digit

# Load Data

Outcome variables:
- **train_loader**
- **test_loader**

In [2]:
def MNIST_loaders(train_batch_size=50000, test_batch_size=10000):
    """Create the MNIST train and test ``DataLoader`` objects.

    Every image is converted to a tensor, normalised with the constants
    (0.1307,) / (0.3081,) -- presumably the MNIST pixel mean and standard
    deviation -- and flattened to a 1-D vector. The data is downloaded into
    ``./data/`` when it is not already present.

    Args:
        train_batch_size: Batch size of the (shuffled) training loader.
        test_batch_size: Batch size of the (unshuffled) test loader.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    preprocess = Compose([
        ToTensor(),
        Normalize((0.1307,), (0.3081,)),
        Lambda(lambda img: torch.flatten(img)),
    ])

    def _make_loader(is_train, batch_size):
        # Only the training split is shuffled, so `is_train` drives both flags.
        split = MNIST('./data/', train=is_train, download=True, transform=preprocess)
        return DataLoader(split, batch_size=batch_size, shuffle=is_train)

    train_loader = _make_loader(True, train_batch_size)
    test_loader = _make_loader(False, test_batch_size)
    return train_loader, test_loader

# Seed PyTorch's global RNG before the loaders are created so the run is
# repeatable.
torch.manual_seed(1234)
# Default batch sizes are used here: 50000 for training, 10000 for testing.
train_loader, test_loader = MNIST_loaders()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 30982373.13it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1996797.06it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 3209846.99it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3220170.52it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



# BP Baseline

In [3]:
# Take only the FIRST batch from each loader and keep it as in-memory tensors.
# With the default batch sizes used when the loaders were created, that is at
# most 50000 training samples and 10000 test samples; any samples beyond the
# first batch are never seen by the code below.
x, y = next(iter(train_loader))
x_te, y_te = next(iter(test_loader))

In [4]:

class BPNet(torch.nn.Module):
    """Fully connected ReLU classifier trained with ordinary backpropagation.

    The network is ``Linear -> ReLU -> ... -> Linear``: one ``nn.Linear`` per
    consecutive pair of sizes in ``dims``, with a ReLU between linear layers
    and none after the last one (the outputs are raw logits).

    Args:
        dims: Sequence of layer widths ``[input, hidden..., output]``; needs
            at least two entries.
        epoch: Number of passes over the data made by each ``train(x, y)`` call.

    Attributes set in ``__init__``: ``model_linear`` (the ``nn.Sequential``
    stack), ``opt`` (Adam, lr=0.01), ``epochs``, ``loss_func``
    (``F.cross_entropy``) and ``batch_size`` (240).
    """

    def __init__(self, dims, epoch):
        super(BPNet, self).__init__()
        if len(dims) < 2:
            raise ValueError("dims must contain at least an input and an output size")

        # Build the stack from consecutive size pairs instead of hard-coding
        # five entries. Layers are created in the same order as before, so a
        # 5-entry `dims` yields exactly the same modules and initialisation.
        layers = []
        for idx, (n_in, n_out) in enumerate(zip(dims[:-1], dims[1:])):
            if idx > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(n_in, n_out))
        self.model_linear = nn.Sequential(*layers)

        self.opt = torch.optim.Adam(self.model_linear.parameters(), lr=0.01)

        self.epochs = epoch
        self.loss_func = F.cross_entropy
        self.batch_size = 240

    def train(self, x=True, y=None):
        """Fit the network on ``(x, y)`` with sequential mini-batches.

        This method shares its name with ``nn.Module.train(mode)``, which
        ``nn.Module.eval()`` calls internally. To keep those working, a call
        without ``y`` and with a boolean first argument is forwarded to the
        base-class mode switch and returns ``self``.

        Args:
            x: Input tensor whose first dimension indexes samples (or a bool
                training-mode flag when ``y`` is omitted).
            y: Targets accepted by ``F.cross_entropy`` and aligned with ``x``
                on the first dimension: class indices, or per-class
                probabilities such as float one-hot rows.

        Raises:
            TypeError: If ``y`` is omitted and ``x`` is not a bool.
        """
        if y is None:
            if not isinstance(x, bool):
                raise TypeError("train(x, y) requires targets `y`; "
                                "use train(True/False) to switch module mode")
            return super(BPNet, self).train(x)

        batch_size = self.batch_size
        n_samples = x.shape[0]
        for _ in range(self.epochs):
            # Batches are taken in order (no reshuffling between epochs); the
            # final batch may be smaller than `batch_size`.
            for start_i in tqdm(range(0, n_samples, batch_size)):
                end_i = start_i + batch_size
                xb = x[start_i:end_i]
                yb = y[start_i:end_i]
                # Clear gradients first so nothing left over from earlier
                # backward passes leaks into this update.
                self.opt.zero_grad()
                loss = self.loss_func(self.model_linear(xb), yb)
                loss.backward()
                self.opt.step()  # Updating weights.

    @torch.no_grad()  # evaluation only: do not build an autograd graph
    def acc(self, x, y):
        """Return the fraction of rows in ``x`` whose arg-max logit equals ``y``.

        Args:
            x: Input tensor with samples along the first dimension.
            y: Integer class labels, one per sample.

        Returns:
            A 0-dim float tensor holding the mean accuracy.
        """
        y_pred = self.model_linear(x)
        return (torch.argmax(y_pred, dim=1) == y).float().mean()
# Build one network (784 inputs, hidden widths 500/300/300, 10 outputs,
# 1 training epoch) just to display its layer structure.
net_BP = BPNet([784, 500, 300, 300, 10], 1)
print(net_BP)

BPNet(
  (model_linear): Sequential(
    (0): Linear(in_features=784, out_features=500, bias=True)
    (1): ReLU()
    (2): Linear(in_features=500, out_features=300, bias=True)
    (3): ReLU()
    (4): Linear(in_features=300, out_features=300, bias=True)
    (5): ReLU()
    (6): Linear(in_features=300, out_features=10, bias=True)
  )
)


In [5]:
# Train five independently initialised networks and record each one's
# accuracy on the training batch and on the test batch.
acc_trs, acc_tes = [], []
for i in range(5):
    net_BP = BPNet([784, 500, 300, 300, 10], 1)
    # Pin the one-hot width to the network's 10 outputs rather than letting
    # F.one_hot infer it from the largest label present in `y`, which would
    # yield a mismatched target shape if the highest class were absent.
    net_BP.train(x, F.one_hot(y, num_classes=10).float())
    acc_trs.append(net_BP.acc(x, y).item())
    acc_tes.append(net_BP.acc(x_te, y_te).item())

100%|██████████| 209/209 [00:05<00:00, 39.86it/s]
100%|██████████| 209/209 [00:04<00:00, 50.68it/s]
100%|██████████| 209/209 [00:04<00:00, 45.81it/s]
100%|██████████| 209/209 [00:04<00:00, 47.92it/s]
100%|██████████| 209/209 [00:04<00:00, 48.80it/s]


In [None]:
# Mean and standard deviation of accuracy across the runs, scaled to percent.
# First line: training batch; second line: test batch.
# np.std is called with its defaults here (population std, ddof=0).
print(np.mean(acc_trs)*100, np.std(acc_trs)*100)
print(np.mean(acc_tes)*100, np.std(acc_tes)*100)

94.98999953269958 0.36118669944977455
94.36199903488159 0.4591886793352151
