In [14]:
import numpy as np
import matplotlib.pyplot as plt
import math

import torch
import torch.optim
import torch.functional as F

import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms

from torch.nn.functional import conv2d, max_pool2d


mb_size = 100 # mini-batch size of 100


trans = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5),
                                                 (0.5, 0.5, 0.5))])


dataset = dset.MNIST("./", download = True,
                     train = True,
                     transform = trans)


dataloader = torch.utils.data.DataLoader(dataset, batch_size=mb_size,
                                         shuffle=True, num_workers=1,
                                         pin_memory=True)



def init_weights(shape):
    w = torch.randn(size=shape)*0.01
    w.requires_grad = True
    return w

def rectify(X):
    return torch.max(torch.zeros_like(X), X)


# you can also use torch.nn.functional.softmax on future sheets
def softmax(X):
    c = torch.max(X, dim=1)[0].reshape(mb_size, 1)
    # this avoids a blow up of the exponentials
    # but calculates the same formula
    stabelized = X-c
    exp = torch.exp(stabelized)
    return exp/torch.sum(exp, dim=1).reshape(mb_size, 1)


# this is an example as a reduced version of the pytorch internal RMSprop optimizer
class RMSprop(torch.optim.Optimizer):
    def __init__(self, params, lr=1e-3, alpha=0.9, eps=1e-8):
        defaults = dict(lr=lr, alpha=alpha, eps=eps)
        super(RMSprop, self).__init__(params, defaults)

    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                grad = p.grad.data
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['square_avg'] = torch.zeros_like(p.data)

                square_avg = state['square_avg']
                alpha = group['alpha']

                # update running averages
                square_avg.mul_(alpha).addcmul_(1 - alpha, grad, grad)
                avg = square_avg.sqrt().add_(group['eps'])

                # gradient update
                p.data.addcdiv_(-group['lr'], grad, avg)


def model(X, w_h, w_h2, w_o, p_drop_input, p_drop_hidden):
    #X = dropout(X, p_drop_input)
    h = rectify(X @ w_h)
    #h_ = dropout(h, p_drop_hidden)
    h2 = rectify(h @ w_h2)
    #h2_ = dropout(h2, p_drop_hidden)
    pre_softmax = h2 @ w_o
    return pre_softmax.transpose(0,1)


w_h = init_weights((784, 625))
w_h2 = init_weights((625, 625))
w_o = init_weights((625, 10))

optimizer = RMSprop([w_h, w_h2, w_o])




# put this into a training loop over 100 epochs
for (_, (X, y)) in enumerate(dataloader, 0):
    noise_py_x = model(X.reshape(mb_size, 784), w_h, w_h2, w_o, 0.8, 0.7)
    print(np.shape(noise_py_x), np.shape(y))
    noise_py_x = noise_py_x.transpose(0,1)
    print(np.shape(noise_py_x), np.shape(y))
    cost = torch.nn.functional.cross_entropy(noise_py_x, y)
    cost.backward()
    print("Loss: {}".format(cost))
    optimizer.step()

torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 2.302578926086426
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 2.5175278186798096
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 2.312056064605713
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 2.3570504188537598
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 2.250384569168091
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 2.2449934482574463
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 2.1825501918792725
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 2.1023566722869873
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 1.9342840909957886
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 1

Loss: 0.3975578248500824
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.5971046686172485
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.5934029817581177
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.48539066314697266
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.5373384356498718
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.501455545425415
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.471505731344223
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.48631808161735535
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.5728931427001953
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.36840343475341797
torch.Size([10, 100]) torch.S

Loss: 0.5086359977722168
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.6703182458877563
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.2059687376022339
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.37499481439590454
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.31844040751457214
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.20343074202537537
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.2697656750679016
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.3935304880142212
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.3377548158168793
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.3875012695789337
torch.Size([10, 100]) torch

Loss: 0.3510689437389374
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.44054344296455383
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.5392363667488098
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.4107312262058258
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.27933409810066223
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.2602881193161011
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.36009272933006287
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.37063392996788025
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.3176676630973816
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.4947277009487152
torch.Size([10, 100]) torc

torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.17307716608047485
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.3525333106517792
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.2738361954689026
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.2582901418209076
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.491456001996994
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.3042663633823395
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.29298943281173706
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.3150053024291992
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.43709641695022583
torch.Size([10, 100]) torch.Size([100])
torch.Size([1

torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.3484710454940796
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.3789461851119995
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.6037936210632324
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.7240298986434937
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.2562873363494873
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.24521836638450623
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.15013593435287476
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.8571891784667969
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.34080180525779724
torch.Size([10, 100]) torch.Size([100])
torch.Size([

Loss: 0.35107576847076416
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.40302059054374695
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.4212396740913391
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.2851231098175049
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.18388153612613678
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.31380903720855713
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.23038750886917114
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.27438822388648987
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.2899872958660126
torch.Size([10, 100]) torch.Size([100])
torch.Size([100, 10]) torch.Size([100])
Loss: 0.4996764361858368
torch.Size([10, 100]) to