In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.nn.functional import conv2d, max_pool2d, cross_entropy

plt.rc("figure", dpi=100)

# number of samples per mini-batch, shared by the train and test loaders
batch_size = 100

# convert images to tensors, then normalize each pixel as (x - 0.5) / 0.5
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# MNIST is downloaded into the current directory if it is not already there
train_dataset = datasets.MNIST("./", download=True, train=True, transform=transform)
test_dataset = datasets.MNIST("./", download=True, train=False, transform=transform)

# training batches are reshuffled every epoch
train_dataloader = DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True,
    num_workers=1, pin_memory=True,
)

# evaluation batches keep the dataset order
test_dataloader = DataLoader(
    dataset=test_dataset, batch_size=batch_size, shuffle=False,
    num_workers=1, pin_memory=True,
)

def init_weights(shape):
    """
    Create a trainable weight tensor with Kaiming He initialization.
    https://arxiv.org/abs/1502.01852

    Weights are drawn from N(0, 2 / fan_in), where fan_in is the number of
    inputs feeding one output unit:

    * 1-D / 2-D shapes, e.g. ``(n_in, n_out)`` used as ``x @ w``:
      ``fan_in = shape[0]``.
    * shapes with more than two dimensions, e.g. conv2d kernels laid out as
      ``(out_channels, in_channels, kernel_h, kernel_w)``:
      ``fan_in = in_channels * kernel_h * kernel_w``.

    :param shape: tuple of ints, shape of the weight tensor
    :return: float tensor of the given shape with ``requires_grad=True``
    """
    if len(shape) > 2:
        # conv kernel: every output value sums over all input channels and the
        # whole kernel window, so shape[0] (output channels) is not the fan-in
        fan_in = int(np.prod(shape[1:]))
    else:
        fan_in = shape[0]

    std = float(np.sqrt(2. / fan_in))
    w = torch.randn(size=shape) * std
    w.requires_grad = True
    return w


def rectify(x):
    """Rectified Linear Unit (ReLU): element-wise maximum of 0 and x."""
    floor = torch.zeros_like(x)
    return torch.max(floor, x)


class RMSprop(optim.Optimizer):
    """
    This is a reduced version of the PyTorch internal RMSprop optimizer
    It serves here as an example

    For every parameter a running average of the squared gradient is kept:

        square_avg = alpha * square_avg + (1 - alpha) * grad ** 2
        p          = p - lr * grad / (sqrt(square_avg) + eps)

    :param params: iterable of tensors (or parameter-group dicts) to optimize
    :param lr: learning rate
    :param alpha: smoothing constant of the squared-gradient running average
    :param eps: added to the denominator to avoid division by zero
    """
    def __init__(self, params, lr=1e-3, alpha=0.5, eps=1e-8):
        defaults = dict(lr=lr, alpha=alpha, eps=eps)
        super(RMSprop, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform a single optimization step.

        :param closure: optional callable that re-evaluates the model and
            returns the loss; it is called with gradients enabled
        :return: the value returned by ``closure``, or None without a closure
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            alpha = group['alpha']
            eps = group['eps']

            for p in group['params']:
                # parameters that took no part in the backward pass (or whose
                # gradient was reset to None) have nothing to apply
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]

                # state initialization
                if len(state) == 0:
                    state['square_avg'] = torch.zeros_like(p)

                square_avg = state['square_avg']

                # update running averages (in place, so the state persists)
                square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
                avg = square_avg.sqrt().add_(eps)

                # gradient update
                p.addcdiv_(grad, avg, value=-lr)

        return loss

In [None]:
def dropout(X, p_drop=0.5):
    """
    Inverted dropout: zero each element of X independently with probability
    p_drop and rescale the survivors by 1 / (1 - p_drop), so the expected
    value of the output equals X.

    :param X: input tensor
    :param p_drop: probability of dropping an element; any value outside the
        open interval (0, 1) disables dropout and X is returned unchanged
    :return: tensor of the same shape as X
    """
    if not 0 < p_drop < 1:
        return X

    keep_prob = 1 - p_drop
    # mask == 1 marks the elements that are KEPT (Bernoulli with p = keep_prob);
    # it is created on X's device so the function also works for GPU tensors
    mask = torch.bernoulli(torch.full(X.shape, keep_prob, device=X.device))
    X_drop = torch.where(mask == 1, X, torch.zeros_like(X)) / keep_prob
    return X_drop
    
def convolution_layer(previous_layer, weightvector, p_drop):
    """
    One convolutional block: conv2d -> ReLU -> 2x2 max-pooling -> dropout.

    :param previous_layer: input tensor passed to conv2d
    :param weightvector: convolution kernels passed to conv2d as weight
    :param p_drop: dropout probability applied to the pooled output; values
        outside (0, 1) make dropout() return its input unchanged
    :return: the pooled (and possibly dropped-out) feature maps
    """
    convolutional_layer = rectify(conv2d(previous_layer, weightvector))

    # reduces (2, 2) window to 1 pixel
    subsample_layer = max_pool2d(convolutional_layer, (2, 2))

    # use the p_drop argument, not a notebook-level global, so callers can
    # switch dropout off (e.g. for evaluation)
    out_layer = dropout(subsample_layer, p_drop)

    return out_layer

def model(x, w_conv_1, w_conv_2, w_conv_3, w_h2, w_o, p_drop_input, p_drop_hidden):
    """
    Forward pass used for training: three convolution_layer blocks followed
    by one ReLU hidden layer and a linear output layer.

    Shape comments below are for a batch of B 28x28 single-channel images with
    the kernel shapes defined in this notebook.

    :param x: input batch
    :param w_conv_1: kernels of the first convolution block
    :param w_conv_2: kernels of the second convolution block
    :param w_conv_3: kernels of the third convolution block
    :param w_h2: weight matrix of the hidden layer (applied as ``h @ w_h2``)
    :param w_o: weight matrix of the output layer (applied as ``h2 @ w_o``)
    :param p_drop_input: dropout probability after the first block
    :param p_drop_hidden: dropout probability after the second and third block
    :return: pre-softmax outputs, one row per input sample
    """
    # x: [B, 1, 28, 28]
    h_conv_1 = convolution_layer(x, w_conv_1, p_drop_input)
    # h_conv_1: [B, 32, 12, 12]
    h_conv_2 = convolution_layer(h_conv_1, w_conv_2, p_drop_hidden)
    # h_conv_2: [B, 64, 4, 4]
    h_conv_3 = convolution_layer(h_conv_2, w_conv_3, p_drop_hidden)
    # h_conv_3: [B, 128, 1, 1]

    # flatten everything but the batch dimension; the batch size is read from
    # the tensor so batches of any size (e.g. a smaller last batch) work
    h_conv_3 = h_conv_3.reshape(h_conv_3.shape[0], -1)
    # h_conv_3: [B, 128]

    h2 = rectify(h_conv_3 @ w_h2)
    pre_softmax = h2 @ w_o
    return pre_softmax

def test_model(x, w_conv_1, w_conv_2, w_conv_3, w_h2, w_o):
    """
    Forward pass used for evaluation: same layers as model(), but every
    convolution block is called with a dropout probability of -1.

    :param x: input batch
    :param w_conv_1: kernels of the first convolution block
    :param w_conv_2: kernels of the second convolution block
    :param w_conv_3: kernels of the third convolution block
    :param w_h2: weight matrix of the hidden layer (applied as ``h @ w_h2``)
    :param w_o: weight matrix of the output layer (applied as ``h2 @ w_o``)
    :return: pre-softmax outputs, one row per input sample
    """
    # dropout of -1 means no dropout: dropout() only acts for 0 < p_drop < 1
    # NOTE(review): this relies on convolution_layer forwarding its p_drop
    # argument to dropout() -- confirm
    h_conv_1 = convolution_layer(x, w_conv_1, -1)
    h_conv_2 = convolution_layer(h_conv_1, w_conv_2, -1)
    h_conv_3 = convolution_layer(h_conv_2, w_conv_3, -1)

    # flatten everything but the batch dimension; the batch size is read from
    # the tensor instead of being hard-coded to 100
    h_conv_3 = h_conv_3.reshape(h_conv_3.shape[0], -1)

    h2 = rectify(h_conv_3 @ w_h2)
    pre_softmax = h2 @ w_o
    return pre_softmax

# initialize weights

# kernels of the three convolution blocks, created in order 1, 2, 3;
# each shape is (out_channels, in_channels, kernel_height, kernel_width)
w_conv_1, w_conv_2, w_conv_3 = (
    init_weights(kernel_shape)
    for kernel_shape in ((32, 1, 5, 5), (64, 32, 5, 5), (128, 64, 2, 2))
)

# number of features left per sample after the last convolution block
number_of_output_pixel = 128
# hidden layer with 625 neurons
w_h2 = init_weights((number_of_output_pixel, 625))
# output layer mapping the 625 hidden neurons to 10 outputs,
# so the output shape is (B, 10)
w_o = init_weights((625, 10))

# all five weight tensors are updated by the same optimizer
optimizer = RMSprop(params=[w_conv_1, w_conv_2, w_conv_3, w_h2, w_o])

In [None]:
n_epochs = 100
p_drop_input = 0.4
p_drop_hidden = 0.4
train_loss = []
test_loss = []

# the loop runs over epoch indices 0 .. n_epochs (inclusive)
for epoch in range(n_epochs + 1):
    train_loss_this_epoch = []
    for x, y in train_dataloader:
        # forward pass through the training model (with dropout probabilities)
        noise_py_x = model(x, w_conv_1, w_conv_2, w_conv_3, w_h2, w_o, p_drop_input, p_drop_hidden)

        # clear the gradients of the previous batch before the next backward pass
        optimizer.zero_grad()

        # cross_entropy already contains the softmax, so it gets the raw outputs
        loss = cross_entropy(noise_py_x, y, reduction="mean")
        train_loss_this_epoch.append(loss.item())

        # back-propagate, then let the optimizer update the weights
        loss.backward()
        optimizer.step()

    train_loss.append(np.mean(train_loss_this_epoch))

    # evaluation only happens on every 10th epoch
    if epoch % 10 != 0:
        continue

    print(f"Epoch: {epoch}")
    print(f"Mean Train Loss: {train_loss[-1]:.2e}")
    test_loss_this_epoch = []

    # gradients are not needed while evaluating
    with torch.no_grad():
        for x, y in test_dataloader:
            noise_py_x = test_model(x, w_conv_1, w_conv_2, w_conv_3, w_h2, w_o)
            loss = cross_entropy(noise_py_x, y, reduction="mean")
            test_loss_this_epoch.append(loss.item())

    test_loss.append(np.mean(test_loss_this_epoch))

    print(f"Mean Test Loss:  {test_loss[-1]:.2e}")

# one train-loss point per epoch, one test-loss point per evaluated epoch
plt.plot(np.arange(n_epochs + 1), train_loss, label="Train")
plt.plot(np.arange(1, n_epochs + 2, 10), test_loss, label="Test")
plt.title("Train and Test Loss over Training")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()