### Make Necessary Imports

In [1]:
import argparse
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

import config
from models import Net
from utils import GrayscaleToRgb

Use a CUDA-enabled GPU if one is available; otherwise, fall back to the CPU.

In [2]:
# Pick the compute device once; later cells move the model and each batch onto it.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Define training and validation dataloaders

In [3]:
def create_dataloaders(batch_size, train_fraction=0.8, seed=None, num_workers=1):
    """Build train/validation loaders over a random split of the MNIST training set.

    Images are converted with ``GrayscaleToRgb`` and then ``ToTensor``. The split
    is made on indices, so both loaders draw from the same underlying dataset.

    Args:
        batch_size: Number of samples per batch for both loaders.
        train_fraction: Fraction of the MNIST training set used for training;
            the remainder becomes the validation set. Defaults to 0.8.
        seed: Optional integer seed for the index shuffle. ``None`` (default)
            draws from NumPy's global RNG, so ``np.random.seed`` still controls it.
        num_workers: Worker processes per DataLoader. Defaults to 1.

    Returns:
        A ``(train_loader, val_loader)`` tuple. The training loader drops the
        last incomplete batch; the validation loader keeps it.

    Raises:
        ValueError: If ``train_fraction`` is not strictly between 0 and 1.
    """
    if not 0.0 < train_fraction < 1.0:
        raise ValueError(f'train_fraction must be in (0, 1), got {train_fraction}')

    dataset = MNIST(config.DATA_DIR/'mnist', train=True, download=True,
                    transform=Compose([GrayscaleToRgb(), ToTensor()]))

    # Only create a dedicated generator when a seed is given; otherwise keep the
    # original behaviour of shuffling with NumPy's global RNG.
    rng = np.random if seed is None else np.random.default_rng(seed)
    shuffled_indices = rng.permutation(len(dataset))
    split = int(train_fraction * len(dataset))  # cut point computed once
    train_idx, val_idx = shuffled_indices[:split], shuffled_indices[split:]

    # Pinned host memory only speeds up host->GPU copies; skip it on CPU-only runs.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True,
                              sampler=SubsetRandomSampler(train_idx),
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(dataset, batch_size=batch_size, drop_last=False,
                            sampler=SubsetRandomSampler(val_idx),
                            num_workers=num_workers, pin_memory=pin_memory)
    return train_loader, val_loader

### Define a function that runs one training or evaluation epoch

In [4]:
def do_epoch(model, dataloader, criterion, optim=None):
    """Run one pass over ``dataloader``, updating the model if an optimizer is given.

    Each batch is moved to the module-level ``device`` before the forward pass.
    When ``optim`` is not None, every batch also gets a backward pass and an
    optimizer step; otherwise only forward passes are made. Switching the model
    to ``train()``/``eval()`` and wrapping evaluation in ``torch.no_grad()`` is
    left to the caller.

    Loss and accuracy are averaged per *sample*: each batch is weighted by its
    size. A smaller final batch (e.g. from ``drop_last=False``) therefore does
    not skew the epoch mean.

    Args:
        model: Module mapping a batch ``x`` to class scores; the predicted
            class is the argmax over dim 1.
        dataloader: Iterable of ``(x, y_true)`` batches.
        criterion: Loss function, assumed to return the *mean* loss over the
            batch (e.g. ``CrossEntropyLoss`` with its default reduction).
        optim: Optional optimizer; if given, the model is trained on each batch.

    Returns:
        ``(mean_loss, mean_accuracy)`` over all samples seen.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for x, y_true in tqdm(dataloader, leave=False):
        x, y_true = x.to(device), y_true.to(device)
        y_pred = model(x)
        loss = criterion(y_pred, y_true)

        if optim is not None:
            optim.zero_grad()
            loss.backward()
            optim.step()

        batch_size = y_true.size(0)
        # criterion yields a batch mean; rescale it to a batch sum so batches of
        # different sizes contribute in proportion to their sample count.
        total_loss += loss.item() * batch_size
        total_correct += (y_pred.argmax(dim=1) == y_true).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        raise ValueError('dataloader yielded no samples')

    mean_loss = total_loss / total_samples
    mean_accuracy = total_correct / total_samples
    return mean_loss, mean_accuracy

### Set hyperparameters and build the model, optimizer, and scheduler

In [5]:
# Seed the RNGs this run draws from, so Restart & Run All reproduces the same
# train/val split (create_dataloaders shuffles with NumPy's global RNG), the
# sampler order, and presumably Net's weight init (torch RNG).
# NOTE(review): some CUDA kernels may still be nondeterministic on GPU.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

batch_size = 64
epochs = 30

train_loader, val_loader = create_dataloaders(batch_size)

model = Net().to(device)
optim = torch.optim.Adam(model.parameters())  # PyTorch default lr (1e-3)
# Cut the LR whenever val loss stops improving for more than `patience` epochs.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases — confirm
# against the installed version before relying on the printed LR messages.
lr_schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, patience=1, verbose=True)
criterion = torch.nn.CrossEntropyLoss()

### Train the model

In [6]:
# Best-so-far weights (by validation accuracy) are written here.
CHECKPOINT_PATH = Path('trained_models') / 'source.pt'
# torch.save does not create missing directories; ensure the target exists so a
# fresh checkout does not fail at the first checkpoint.
CHECKPOINT_PATH.parent.mkdir(parents=True, exist_ok=True)

train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

best_accuracy = 0
for epoch in range(1, epochs+1):
    model.train()
    train_loss, train_accuracy = do_epoch(model, train_loader, criterion, optim=optim)
    train_losses.append(train_loss)
    train_accuracies.append(train_accuracy)

    model.eval()
    with torch.no_grad():
        val_loss, val_accuracy = do_epoch(model, val_loader, criterion, optim=None)
    val_losses.append(val_loss)
    val_accuracies.append(val_accuracy)

    tqdm.write(f'Epoch {epoch:03d}: train_loss={train_loss:.4f}, train_accuracy={train_accuracy:.4f} '
               f'val_loss={val_loss:.4f}, val_accuracy={val_accuracy:.4f}')

    # Keep only the checkpoint with the highest validation accuracy so far.
    if val_accuracy > best_accuracy:
        tqdm.write('Saving model...')  # tqdm.write avoids clobbering progress bars
        best_accuracy = val_accuracy
        torch.save(model.state_dict(), CHECKPOINT_PATH)

    # ReduceLROnPlateau watches validation loss (mode='min' by default).
    lr_schedule.step(val_loss)
    

                                                                                                                                                    

Epoch 001: train_loss=0.5734, train_accuracy=0.8191 val_loss=0.1213, val_accuracy=0.9646
Saving model...


                                                                                                                                                    

Epoch 002: train_loss=0.2656, train_accuracy=0.9216 val_loss=0.0868, val_accuracy=0.9751
Saving model...


                                                                                                                                                    

Epoch 003: train_loss=0.2189, train_accuracy=0.9353 val_loss=0.0760, val_accuracy=0.9786
Saving model...


                                                                                                                                                    

Epoch 004: train_loss=0.1962, train_accuracy=0.9417 val_loss=0.0704, val_accuracy=0.9787
Saving model...


                                                                                                                                                    

Epoch 005: train_loss=0.1807, train_accuracy=0.9463 val_loss=0.0724, val_accuracy=0.9790
Saving model...


                                                                                                                                                    

Epoch 006: train_loss=0.1673, train_accuracy=0.9514 val_loss=0.0657, val_accuracy=0.9809
Saving model...


                                                                                                                                                    

Epoch 007: train_loss=0.1606, train_accuracy=0.9536 val_loss=0.0595, val_accuracy=0.9820
Saving model...


                                                                                                                                                    

Epoch 008: train_loss=0.1509, train_accuracy=0.9557 val_loss=0.0523, val_accuracy=0.9840
Saving model...


                                                                                                                                                    

Epoch 009: train_loss=0.1443, train_accuracy=0.9570 val_loss=0.0520, val_accuracy=0.9846
Saving model...


                                                                                                                                                    

Epoch 010: train_loss=0.1430, train_accuracy=0.9589 val_loss=0.0545, val_accuracy=0.9836


                                                                                                                                                    

Epoch 011: train_loss=0.1369, train_accuracy=0.9597 val_loss=0.0503, val_accuracy=0.9845


                                                                                                                                                    

Epoch 012: train_loss=0.1304, train_accuracy=0.9619 val_loss=0.0461, val_accuracy=0.9859
Saving model...


                                                                                                                                                    

Epoch 013: train_loss=0.1261, train_accuracy=0.9624 val_loss=0.0480, val_accuracy=0.9858


                                                                                                                                                    

Epoch 014: train_loss=0.1263, train_accuracy=0.9626 val_loss=0.0473, val_accuracy=0.9864
Saving model...
Epoch    13: reducing learning rate of group 0 to 1.0000e-04.


                                                                                                                                                    

Epoch 015: train_loss=0.1127, train_accuracy=0.9670 val_loss=0.0449, val_accuracy=0.9870
Saving model...


                                                                                                                                                    

Epoch 016: train_loss=0.1065, train_accuracy=0.9686 val_loss=0.0432, val_accuracy=0.9870


                                                                                                                                                    

Epoch 017: train_loss=0.1066, train_accuracy=0.9691 val_loss=0.0441, val_accuracy=0.9872
Saving model...


                                                                                                                                                    

Epoch 018: train_loss=0.1003, train_accuracy=0.9704 val_loss=0.0439, val_accuracy=0.9875
Saving model...
Epoch    17: reducing learning rate of group 0 to 1.0000e-05.


                                                                                                                                                    

Epoch 019: train_loss=0.1066, train_accuracy=0.9697 val_loss=0.0433, val_accuracy=0.9875


                                                                                                                                                    

Epoch 020: train_loss=0.0997, train_accuracy=0.9713 val_loss=0.0434, val_accuracy=0.9874
Epoch    19: reducing learning rate of group 0 to 1.0000e-06.


                                                                                                                                                    

Epoch 021: train_loss=0.0948, train_accuracy=0.9714 val_loss=0.0431, val_accuracy=0.9875


                                                                                                                                                    

Epoch 022: train_loss=0.1001, train_accuracy=0.9706 val_loss=0.0431, val_accuracy=0.9875


                                                                                                                                                    

Epoch 023: train_loss=0.0991, train_accuracy=0.9709 val_loss=0.0431, val_accuracy=0.9875


                                                                                                                                                    

Epoch 024: train_loss=0.0998, train_accuracy=0.9707 val_loss=0.0431, val_accuracy=0.9875


                                                                                                                                                    

Epoch 025: train_loss=0.1001, train_accuracy=0.9704 val_loss=0.0431, val_accuracy=0.9875


                                                                                                                                                    

Epoch 026: train_loss=0.1020, train_accuracy=0.9698 val_loss=0.0432, val_accuracy=0.9874
Epoch    25: reducing learning rate of group 0 to 1.0000e-07.


                                                                                                                                                    

Epoch 027: train_loss=0.1008, train_accuracy=0.9701 val_loss=0.0432, val_accuracy=0.9874


                                                                                                                                                    

Epoch 028: train_loss=0.0999, train_accuracy=0.9703 val_loss=0.0430, val_accuracy=0.9875


                                                                                                                                                    

Epoch 029: train_loss=0.1036, train_accuracy=0.9697 val_loss=0.0431, val_accuracy=0.9874


                                                                                                                                                    

Epoch 030: train_loss=0.0997, train_accuracy=0.9701 val_loss=0.0430, val_accuracy=0.9875
