In [None]:
import torch
from torch import nn
from torch import optim

from torch.utils.data.dataset import random_split
from torch.utils.data import DataLoader

from torchvision import datasets, transforms

In [None]:
# Seed torch's default generator so random draws made through it (weight init,
# dropout masks, etc.) are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# Quick demo: create a random tensor and move it to the best available device.
x = torch.randn(5)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
print(x.to(device))

In [None]:
# Baseline MLP on flattened 28x28 images: 784 -> 64 -> 64 -> 10 logits,
# ReLU activations, and dropout just before the output layer.
model = nn.Sequential(
    nn.Linear(in_features=28 * 28, out_features=64),
    nn.ReLU(),
    nn.Linear(in_features=64, out_features=64),
    nn.ReLU(),
    nn.Dropout(p=0.1),  # regularizer to reach for if the model overfits
    nn.Linear(in_features=64, out_features=10),
).to(device=device)

In [None]:
# Define a more flexible model: an MLP with a residual (skip) connection
class ResMLP(nn.Module):
    """Two-hidden-layer MLP classifier with a residual connection.

    The second hidden layer's output is added to the first's before dropout and
    the output layer, so both hidden layers share the same width `hidden`.

    Args:
        in_features: size of each flattened input vector (default 28*28 = 784).
        hidden: width of both hidden layers (default 64).
        out_features: number of output logits (default 10).
        dropout: dropout probability applied before the output layer (default 0.1).
    """

    def __init__(self, in_features=28 * 28, hidden=64, out_features=10, dropout=0.1):
        super().__init__()
        self.l1 = nn.Linear(in_features, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.l3 = nn.Linear(hidden, out_features)
        self.do = nn.Dropout(dropout)

    def forward(self, x):
        """Map `x` of shape (..., in_features) to logits of shape (..., out_features).

        The logits are raw (pre-softmax) scores.
        """
        h1 = nn.functional.relu(self.l1(x))
        h2 = nn.functional.relu(self.l2(h1))
        do = self.do(h2 + h1)  # residual connection: h1 skips past l2
        logits = self.l3(do)
        return logits

# Use the residual model defined above from here on; it replaces the Sequential baseline.
model = ResMLP().to(device=device)

In [None]:
# Plain SGD over every parameter of the current `model`, step size 0.01.
optimiser = optim.SGD(params=model.parameters(), lr=0.01)

In [None]:
# Cross-entropy on raw logits, averaged over the batch (PyTorch's default reduction).
loss_fct = nn.CrossEntropyLoss(reduction='mean')

In [None]:
# train/valid split
train_data = datasets.MNIST('./data', train=True, download=True, transform=transforms.ToTensor())
# Held-out validation size; the remainder (55000 for MNIST's training split) is used for training.
n_val = 5000
# A dedicated, seeded generator keeps the split identical across kernel restarts.
split_rng = torch.Generator().manual_seed(0)
train, val = random_split(train_data, [len(train_data) - n_val, n_val], generator=split_rng)
# Reshuffle training batches every epoch (standard for SGD); validation order does not matter.
train_loader = DataLoader(train, batch_size=32, shuffle=True)
val_loader = DataLoader(val, batch_size=32)

In [None]:
# Peek at one batch to sanity-check tensor shapes and device placement.
for images_cpu, labels_cpu in train_loader:
    x = images_cpu.to(device)
    y = labels_cpu.to(device)
    print(x.shape, y.shape)
    print(x.device, y.device)
    break

In [None]:
# my training loop
nb_epochs = 5


def run_epoch(loader, train):
    """Make one pass over `loader`, updating `model` when `train` is True.

    Reads the notebook globals `model`, `optimiser`, `loss_fct` and `device`.

    Returns:
        (mean_loss, accuracy), both averaged per *sample* so that a short final
        batch carries its true weight. Returns (nan, nan) for an empty loader.
    """
    model.train(train)  # train(False) == eval(): turns off the model's Dropout
    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    # Autograd graphs are only needed for the training pass.
    with torch.set_grad_enabled(train):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            batch_size = x.shape[0]
            x = x.view(batch_size, -1)  # (batch_size, 1, 28, 28) -> (batch_size, 1*28*28)

            # 1 forward + 2 objective
            logits = model(x)  # outputs before softmax (batch size, output size)
            J = loss_fct(logits, y)

            if train:
                optimiser.zero_grad()  # 3 clear gradients left from the previous step
                J.backward()           # 4 accumulate dJ/dparams
                optimiser.step()       # 5 step against the gradient

            # Assumes loss_fct uses the default reduction='mean'; re-weight by batch size.
            total_loss += J.item() * batch_size
            total_correct += logits.detach().argmax(dim=1).eq(y).sum().item()
            total_seen += batch_size

    if total_seen == 0:
        return float('nan'), float('nan')
    return total_loss / total_seen, total_correct / total_seen


for epoch in range(nb_epochs):
    print(f'Epoch {epoch+1}')

    # training
    train_loss, train_acc = run_epoch(train_loader, train=True)
    print(f'\ttrain loss: {train_loss:.2f}', end=', ')
    print(f'\ttrain acc: {train_acc:.2f}')

    # validation
    valid_loss, valid_acc = run_epoch(val_loader, train=False)
    print(f'\tvalid loss: {valid_loss:.2f}', end=', ')
    print(f'\tvalid acc: {valid_acc:.2f}')

    print()


In [None]:
"="*70