# Mini-batch stochastic gradient descent

In [None]:
import torch
import torchvision

## Hyperparameters

In [None]:
# Hyperparameters for mini-batch SGD.
# Number of training samples consumed per parameter update.
batch_size = 32
# Number of complete passes over the training data.
num_epochs = 4
# Step size that scales each gradient before it is subtracted from a parameter.
learning_rate = 0.1

## 1. Prepare data

In [None]:
# Directory that holds the MNIST files (torchvision downloads them here when
# they are missing). The path is relative to the notebook's working directory.
data_directory_path = '../data/'

In [None]:
# Create data loaders

# Preprocessing applied to every image: convert it to a tensor, then
# standardize it with a fixed mean and standard deviation.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST pixel mean and
# standard deviation -- confirm before reusing with another dataset.
mnist_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Training data loader: reshuffled every epoch so that each mini-batch is a
# fresh random sample of the training set.
train_dataset = torchvision.datasets.MNIST(
    root=data_directory_path, train=True, download=True, transform=mnist_transforms
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)

# Validation data loader: the whole validation set is served as one batch.
valid_dataset = torchvision.datasets.MNIST(
    root=data_directory_path, train=False, download=True, transform=mnist_transforms
)

# Validation cost and accuracy do not depend on sample order, so shuffling is
# disabled: it would only add work and consume random numbers on every pass.
valid_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=len(valid_dataset), shuffle=False
)

In [None]:
import matplotlib.pyplot as plt

# Preview one training example: fetch the first mini-batch and display its
# first image, using that image's label as the plot title.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, labels = first_batch
    # squeeze() drops the size-1 dimensions so imshow receives a 2-D image.
    plt.imshow(images[0].squeeze())
    plt.title(labels[0].item(), fontsize=32)

## 2. Create a neural network

In [None]:
# Input size: one feature per pixel of a flattened 28 x 28 image.
nx = 28 * 28
# Output size: one score per class.
ny = 10

# The network is a single fully connected layer mapping the nx inputs to the
# ny class scores.
model = torch.nn.Sequential(
    torch.nn.Linear(in_features=nx, out_features=ny),
)

## 3. Create the loss function

In [None]:
# Cross-entropy between the model's raw class scores and the integer class
# labels. The reduction is spelled out (it is the default): the returned value
# is the mean over the samples in the batch, not their sum.
loss = torch.nn.CrossEntropyLoss(reduction='mean')

## 4. Implement mini-batch stochastic gradient descent

In [None]:
# Train the model with mini-batch SGD. After every epoch, print the epoch
# index, the average per-sample training cost, the average per-sample
# validation cost, and the validation accuracy.
for epoch in range(num_epochs):

    # Set model to training mode
    model.train()

    # Update the model for each batch
    train_count = 0
    train_cost = 0.0
    for X, y in train_loader:
        batch_count = X.shape[0]

        # Compute model cost (each image is flattened to a vector of nx values)
        yhat = model(X.view(-1, nx))
        cost = loss(yhat, y)

        # Compute gradients; clear the ones left over from the previous batch
        # first, because backward() accumulates into param.grad.
        model.zero_grad()
        cost.backward()

        # Update parameters (no_grad keeps the in-place update out of autograd)
        with torch.no_grad():
            for param in model.parameters():
                param -= learning_rate * param.grad

        # `cost` is the mean over this batch (CrossEntropyLoss default
        # reduction), so weight it by the batch size: the running total is
        # then a sum over samples, and dividing by train_count below gives a
        # true per-sample average.
        train_count += batch_count
        train_cost += cost.item() * batch_count

    # Set model to evaluation mode
    model.eval()

    # Test model on validation data
    valid_count = 0
    valid_cost = 0.0
    valid_correct = 0
    with torch.no_grad():
        for X, y in valid_loader:
            batch_count = X.shape[0]

            # Compute model cost
            yhat = model(X.view(-1, nx))
            cost = loss(yhat, y)

            # Convert model output into discrete predictions: the index of
            # the largest score in each row.
            predictions = yhat.argmax(dim=1)

            # Compute number correct
            correct = (predictions == y).sum().item()

            # Same batch-size weighting as for the training cost.
            valid_count += batch_count
            valid_cost += cost.item() * batch_count
            valid_correct += correct

    # Turn the accumulated sums into per-sample averages.
    train_cost /= train_count
    valid_cost /= valid_count
    valid_accuracy = valid_correct / valid_count

    print(epoch, train_cost, valid_cost, valid_accuracy)

print('Done.')