### Assignment

Build a 2-layer MLP for MNIST digit classification. Feel free to experiment with the model architecture and see how training time and performance change, but to begin, try the following:

Image (784 dimensions) ->  
fully connected layer (500 hidden units) -> nonlinearity (ReLU) ->  
fully connected layer (10 output units) -> softmax

Try building the model first with basic PyTorch operations, and then again with the higher-level, object-oriented APIs. 
You should get similar results!
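
As a possible starting point for the higher-level version, the architecture above might be expressed with `torch.nn` modules roughly as follows. This is a minimal sketch rather than a reference solution: the class name `MNISTMLP` and the attribute names `fc1` and `fc2` are illustrative choices, and the forward pass is assumed to return logits (pre-softmax scores) rather than probabilities.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MNISTMLP(nn.Module):
    """784 -> 500 -> ReLU -> 10. The class name is hypothetical."""

    def __init__(self, in_dim=784, hidden_dim=500, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)       # first fully connected layer
        self.fc2 = nn.Linear(hidden_dim, num_classes)  # output layer

    def forward(self, images):
        x = images.flatten(start_dim=1)  # (batch, 1, 28, 28) -> (batch, 784)
        h = F.relu(self.fc1(x))
        return self.fc2(h)               # logits; no softmax applied here


model = MNISTMLP()

# model.parameters() yields both layers' weights and biases,
# so a single optimizer covers the whole network.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Shape check on a random batch shaped like MNIST images.
dummy_batch = torch.rand(4, 1, 28, 28)
logits = model(dummy_batch)
print(logits.shape)  # torch.Size([4, 10])
```

Registering the layers as attributes of an `nn.Module` is what lets `model.parameters()` find them automatically, which removes the need to list each tensor by hand.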


*Some hints*:
- Even as we add additional layers, we still only require a single optimizer to learn the parameters.
Just make sure to pass all parameters to it!
- As you'll calculate in the Short Answer, this MLP model has many more parameters than the logistic regression example, which makes it more challenging to train.
To get the best performance, you may want to play with the learning rate and increase the number of training epochs.
- Be careful using `torch.nn.CrossEntropyLoss()`. 
If you look at the [PyTorch documentation](https://pytorch.org/docs/stable/nn.html#crossentropyloss), you'll see that `torch.nn.CrossEntropyLoss()` combines the softmax operation with the cross-entropy.
This means you need to pass in the logits (predictions pre-softmax) to this loss.
Computing the softmax separately and feeding the result into `torch.nn.CrossEntropyLoss()` will significantly degrade your model's performance!
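
The last hint can be checked numerically. The sketch below uses made-up logits and labels (random values, not MNIST outputs) to compare three quantities: the loss computed from logits, the same loss computed by hand from log-softmax, and the loss obtained when softmax is mistakenly applied before `F.cross_entropy`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 10) * 3.0     # fake pre-softmax scores for 5 examples
labels = torch.randint(0, 10, (5,))   # fake class labels

# Intended usage: pass the logits directly.
loss_from_logits = F.cross_entropy(logits, labels)

# Equivalent by hand: log-softmax followed by negative log-likelihood.
loss_by_hand = F.nll_loss(F.log_softmax(logits, dim=1), labels)

# Mistake: softmax is applied here and then again inside cross_entropy.
loss_double_softmax = F.cross_entropy(F.softmax(logits, dim=1), labels)

print(loss_from_logits.item(), loss_by_hand.item(), loss_double_softmax.item())
```

The first two values agree. In the third case the inputs are probabilities in [0, 1], so the internal softmax sees nearly identical scores for every class and, with 10 classes, the loss stays within a narrow band (roughly 1.46 to 2.46) no matter how confident the model is.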

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from tqdm.notebook import tqdm

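# Run on the CPU by default. To try a GPU, replace the assignment below with
# the commented-out selection logic (MPS on Apple Silicon, otherwise CUDA).
# if torch.backends.mps.is_available():
#     device = torch.device('mps')
# elif torch.cuda.is_available():
#     device = torch.device('cuda')
# else:
#     device = torch.device('cpu')
device = torch.device('cpu')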

print(f'Using device: {device}')

# Load the data
mnist_train = datasets.MNIST(root="./datasets", train=True, transform=transforms.ToTensor(), download=True)
mnist_test = datasets.MNIST(root="./datasets", train=False, transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=100, shuffle=False)

## Training
# Initialize parameters
W = torch.randn(784, 500, device=device) / np.sqrt(784)
W.requires_grad_()
b = torch.zeros(500, device=device, requires_grad=True)

W1 = torch.randn(500, 10, device=device) / np.sqrt(500)
W1.requires_grad_()
b1 = torch.zeros(10, device=device, requires_grad=True)

# Optimizer
optimizer = torch.optim.SGD([W, b, W1, b1], lr=0.1)

# Iterate through train set minibatches
for images, labels in tqdm(train_loader):
    # Move data to device
    images, labels = images.to(device), labels.to(device)
    
    # Zero out the gradients
    optimizer.zero_grad()
    
    # Forward pass
    x = images.view(-1, 28*28)
    y_0 = torch.matmul(x, W) + b
    x_relu_F = F.relu(y_0)
    y_1 = torch.matmul(x_relu_F, W1) + b1

    # F.cross_entropy applies log-softmax internally, so pass the raw logits
    cross_entropy = F.cross_entropy(y_1, labels)

    # Backward pass and parameter update
    cross_entropy.backward()
    optimizer.step()

## Testing
correct = 0
total = len(mnist_test)

with torch.no_grad():
    # Iterate through test set minibatches
    for images, labels in tqdm(test_loader):
        # Move data to device
        images, labels = images.to(device), labels.to(device)
        
        # Forward pass
        x = images.view(-1, 28*28)
        y_0 = torch.matmul(x, W) + b
        x_relu_F = F.relu(y_0)
        y_1 = torch.matmul(x_relu_F, W1) + b1
        
        predictions = torch.argmax(y_1, dim=1)
        # Accumulate the number of correct predictions as a plain Python int
        correct += (predictions == labels).sum().item()

print(f'Test accuracy: {correct / total:.4f}')
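
The code above makes a single pass over the training set. To act on the hint about training for longer, the minibatch loop could be wrapped in an outer loop over epochs. The sketch below is one way to do that; the helper name `train_epochs` and its arguments are hypothetical, and small random tensors stand in for MNIST so the example runs without a download.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def train_epochs(params, forward_fn, loader, lr=0.1, num_epochs=3):
    """Run SGD for several passes over `loader`; return the mean loss per epoch."""
    optimizer = torch.optim.SGD(params, lr=lr)
    epoch_losses = []
    for _ in range(num_epochs):
        running, batches = 0.0, 0
        for images, labels in loader:
            optimizer.zero_grad()
            loss = F.cross_entropy(forward_fn(images), labels)
            loss.backward()
            optimizer.step()
            running += loss.item()
            batches += 1
        epoch_losses.append(running / batches)
    return epoch_losses


# Stand-in data with MNIST-like shapes (random values, for illustration only).
torch.manual_seed(0)
fake_images = torch.rand(200, 1, 28, 28)
fake_labels = torch.randint(0, 10, (200,))
loader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=100, shuffle=True)

# Same parameter shapes and initialization scheme as the model above.
W = (torch.randn(784, 500) / 784 ** 0.5).requires_grad_()
b = torch.zeros(500, requires_grad=True)
W1 = (torch.randn(500, 10) / 500 ** 0.5).requires_grad_()
b1 = torch.zeros(10, requires_grad=True)


def forward(images):
    x = images.view(-1, 28 * 28)
    hidden = F.relu(torch.matmul(x, W) + b)
    return torch.matmul(hidden, W1) + b1  # logits


losses = train_epochs([W, b, W1, b1], forward, loader, num_epochs=3)
print(losses)  # one mean training loss per epoch
```

With the real `train_loader` in place of the stand-in data, `num_epochs` and `lr` become the two knobs the hints suggest tuning.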