# Implementing a 3-layer network for MNIST, trained with the cross-entropy loss

In [9]:
import os, sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim

### Define the network, instantiate it and choose the loss:

### Define a class for the 3-layer network:

In [4]:
class three_layer_net(nn.Module):
    """Three bias-free linear layers with a ReLU after each of the first two.

    Args:
        input_size: number of features in each input row.
        hidden_size1: width of the first hidden layer.
        hidden_size2: width of the second hidden layer.
        output_size: number of scores produced per row.

    ``forward`` returns raw (unnormalised) scores; no softmax is applied,
    so the output can be fed directly to ``nn.CrossEntropyLoss``.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, output_size):
        super().__init__()
        # Creation order fixes parameter registration order
        # (layer1 -> layer2 -> layer3), and hence seeded initialisation.
        self.layer1 = nn.Linear(input_size, hidden_size1, bias=False)
        self.layer2 = nn.Linear(hidden_size1, hidden_size2, bias=False)
        self.layer3 = nn.Linear(hidden_size2, output_size, bias=False)

    def forward(self, x):
        """Map ``x`` (last dimension == input_size) to per-class scores."""
        hidden = F.relu(self.layer1(x))
        hidden = F.relu(self.layer2(hidden))
        return self.layer3(hidden)

### Create the network

Hyper-parameters:
- Input dim: 784
- 1st hidden layer: 1000
- 2nd hidden layer: 1500
- Output: 10

In [5]:
# 784 inputs -> 1000 -> 1500 hidden units -> 10 output scores.
net = three_layer_net(input_size=784, hidden_size1=1000, hidden_size2=1500, output_size=10)

### Define the criterion

In [8]:
# Cross-entropy on raw scores (log-softmax is applied internally), averaged over the minibatch.
criterion = nn.CrossEntropyLoss(reduction="mean")

### Create an optimization function
- We use plain stochastic gradient descent (SGD) as our optimizer (there are other SGD-like optimizers, such as SGD with momentum, RMSprop and Adam).
- The optimizer is given access to the parameters of the 3-layer network
- We choose the learning rate to be lr = 0.01

In [12]:
# Plain SGD over every parameter of the network, fixed learning rate 0.01.
optimizer = torch.optim.SGD(params=net.parameters(), lr=0.01)

### Training loop and backpropagation

In [None]:
bs: int = 200  # number of examples drawn per SGD minibatch

In [None]:
# NOTE(review): train_data / train_label are not defined in the cells shown
# here. This loop presumably expects train_data to be a float tensor whose
# first dimension indexes examples and whose remaining dims flatten to 784,
# and train_label to be the matching int64 class indices. TODO confirm.
n_train = train_data.size(0)  # sample from the whole set, not a hard-coded 60000

for step in range(5000):  # `step`, not `iter`, to avoid shadowing the builtin

    # Clear the gradients accumulated by the previous backward pass.
    optimizer.zero_grad()

    # Create the minibatch: bs indices drawn uniformly (with replacement).
    indices = torch.randint(0, n_train, (bs,))
    minibatch_data = train_data[indices]
    minibatch_label = train_label[indices]

    # Flatten each example to a 784-vector; unlike view, reshape also
    # works when the indexed tensor is not contiguous.
    inputs = minibatch_data.reshape(bs, 784)

    # No inputs.requires_grad_() here: autograd already records the ops
    # involving the network's parameters, and d(loss)/d(inputs) is never
    # used, so tracking it would only waste compute and memory.

    # Forward pass: compute the scores.
    scores = net(inputs)

    # Compute the cross-entropy criterion.
    loss = criterion(scores, minibatch_label)

    # Backward pass: compute the gradients w.r.t. the network parameters.
    loss.backward()

    # Do one step of SGD to update the weights.
    optimizer.step()