In [30]:
from typing import Any, Dict, List, Optional, Tuple, Type

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

In [25]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

device(type='cuda', index=0)

In [10]:
# Training pipeline: random geometric augmentation, then tensor conversion and
# normalisation. The crop is resized back to 28x28 so the tensor shape matches
# the un-augmented test pipeline below.
train_aug = torchvision.transforms.Compose([
    torchvision.transforms.RandomResizedCrop(size=(28, 28), scale=(0.5, 1.0), ratio=(0.9, 1.1)),
    torchvision.transforms.RandomRotation(degrees=15),
    torchvision.transforms.ToTensor(),
    # Single-channel mean/std; presumably rounded MNIST statistics (TODO confirm).
    torchvision.transforms.Normalize((0.15,), (0.31,)),
])

# Evaluation pipeline: no augmentation, same normalisation constants as training.
test_aug = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.15,), (0.31,)),
])

In [12]:
# root='' places the MNIST files relative to the current working directory;
# download=True allows torchvision to fetch them.
train_data = torchvision.datasets.MNIST(
    root='',
    train=True,
    download=True,
    transform=train_aug,
)
test_data = torchvision.datasets.MNIST(
    root='',
    train=False,
    download=True,
    transform=test_aug,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST\raw\train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting MNIST\raw\train-images-idx3-ubyte.gz to MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting MNIST\raw\train-labels-idx1-ubyte.gz to MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting MNIST\raw\t10k-images-idx3-ubyte.gz to MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST\raw\t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting MNIST\raw\t10k-labels-idx1-ubyte.gz to MNIST\raw



In [20]:
class Network(torch.nn.Module):
    """Fully connected network: Flatten -> (Linear -> activation)* -> Linear.

    Args:
        output_sizes: Output width of every linear layer, in order; the last
            entry is the width of the network output. The string ``'default'``
            (or ``None``) selects ``[512, 128, 64, 10]``.
        activation: Zero-argument callable (e.g. a ``torch.nn.Module`` class)
            that builds the activation placed after every linear layer except
            the last one.
        input_size: Number of features per sample once the input is flattened.
            Defaults to ``28 * 28``.
    """

    def __init__(self, output_sizes='default', activation=torch.nn.ReLU, input_size=28 * 28):
        super().__init__()
        # Test for the sentinel with isinstance first: comparing an array-like
        # directly against a string can yield an element-wise result whose
        # truth value is ambiguous.
        if output_sizes is None or (isinstance(output_sizes, str) and output_sizes == 'default'):
            self.output_sizes = [512, 128, 64, 10]
        else:
            # Copy into a fresh list of ints so later mutation of the caller's
            # sequence cannot drift away from the layers that were built.
            self.output_sizes = [int(size) for size in output_sizes]
        self.act = activation
        self.input_size = input_size

        self._make_net()

    def _make_net(self):
        """Assemble ``self.net`` from ``input_size``, ``output_sizes`` and ``act``."""
        blocks = [torch.nn.Flatten()]
        prev_size = self.input_size
        last_index = len(self.output_sizes) - 1

        for index, size in enumerate(self.output_sizes):
            blocks.append(torch.nn.Linear(in_features=prev_size, out_features=size))
            # No activation after the final layer: the network returns the raw
            # output of its last Linear module.
            if index != last_index:
                blocks.append(self.act())
            prev_size = size

        # With an empty ``output_sizes`` this leaves just the Flatten layer,
        # instead of slicing it away.
        self.net = torch.nn.Sequential(*blocks)

    def forward(self, batch):
        """Apply the sequential stack to ``batch`` and return its output."""
        return self.net(batch)

In [43]:
def train_single_epoch(model: torch.nn.Module,
                       optimizer: torch.optim.Optimizer,
                       train_dataloader: torch.utils.data.DataLoader,
                       loss_fn: torch.nn.Module,
                       tb_writer: SummaryWriter,
                       epoch: int) -> float:
    """Run one optimisation pass over ``train_dataloader``.

    Every batch is moved to the notebook-level ``device``, pushed through
    ``model``, and used for one optimizer step. After the pass, a histogram of
    each named parameter (and of its gradient, when one exists) is written to
    ``tb_writer`` under the ``weight/`` and ``grad/`` tags.

    Args:
        model: Network to train; switched to training mode.
        optimizer: Optimizer that updates the parameters of ``model``.
        train_dataloader: Iterable of ``(inputs, targets)`` batches.
        loss_fn: Callable returning a scalar loss for ``(prediction, targets)``.
        tb_writer: TensorBoard writer that receives the histograms.
        epoch: Value used as the global step of the histograms.

    Returns:
        The sum (not the mean) of the per-batch loss values over the epoch.
    """
    model.train()
    loss_value = 0.0

    with tqdm(total=len(train_dataloader)) as pbar:
        for step, (X, y) in enumerate(train_dataloader):
            X, y = X.to(device), y.to(device)

            prediction = model(X)
            loss = loss_fn(prediction, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value += loss.item()

            pbar.update()
            # The postfix shows the running sum of batch losses so far.
            pbar.set_postfix({'loss - ':loss_value})

    for tag, param in model.named_parameters():
        # Frozen or unused parameters have no gradient (``grad`` is None);
        # skip their gradient histogram instead of raising AttributeError.
        if param.grad is not None:
            tb_writer.add_histogram('grad/%s'%tag, param.grad.detach().cpu().numpy(), epoch)
        tb_writer.add_histogram('weight/%s'%tag, param.detach().cpu().numpy(), epoch)

    return loss_value

In [44]:
def validate_single_epoch(model: torch.nn.Module,
                          loss_function: torch.nn.Module, 
                          val_dataloader: torch.utils.data.DataLoader):
    """Evaluate ``model`` on ``val_dataloader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loss_function: Callable returning a scalar loss for ``(pred, targets)``.
        val_dataloader: Iterable of ``(inputs, targets)`` batches whose
            ``dataset`` attribute supports ``len``.

    Returns:
        A dict with ``'loss'`` (sum of per-batch losses divided by the number
        of batches) and ``'accuracy'`` (correct predictions divided by the
        dataset size).
    """
    model.eval()
    n_samples = len(val_dataloader.dataset)
    n_batches = len(val_dataloader)
    loss_sum = 0
    hits = 0

    with torch.no_grad():
        for inputs, targets in val_dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            scores = model(inputs)
            loss_sum += loss_function(scores, targets).item()

            # The predicted class is the index of the largest score along dim 1.
            is_correct = scores.argmax(1) == targets
            hits += is_correct.type(torch.float).sum().item()

    return {'loss': loss_sum / n_batches, 'accuracy' : hits / n_samples}

In [45]:
def train_model(model: torch.nn.Module, 
                train_dataset: torch.utils.data.Dataset,
                val_dataset: torch.utils.data.Dataset,
                loss_function: torch.nn.Module = torch.nn.CrossEntropyLoss(),
                optimizer_class: Type[torch.optim.Optimizer] = torch.optim.Adam,
                optimizer_params: Optional[Dict] = None,
                lr_scheduler_class: Any = torch.optim.lr_scheduler.ReduceLROnPlateau,
                lr_scheduler_params: Optional[Dict] = None,
                batch_size = 64,
                max_epochs = 100,
                early_stopping_patience = 10,
                is_save = False) -> List[float]:
    """Train ``model`` with per-epoch validation, LR scheduling and early stopping.

    Args:
        model: Network to train (expected to already live on the target device).
        train_dataset: Dataset used for optimisation; shuffled every epoch.
        val_dataset: Dataset used for the per-epoch validation metrics.
        loss_function: Loss used for both training and validation.
        optimizer_class: Optimizer class, built as
            ``optimizer_class(model.parameters(), **optimizer_params)``.
        optimizer_params: Extra keyword arguments for the optimizer
            (``None`` means none).
        lr_scheduler_class: Scheduler class, built as
            ``lr_scheduler_class(optimizer, **lr_scheduler_params)``.
        lr_scheduler_params: Extra keyword arguments for the scheduler
            (``None`` means none).
        batch_size: Batch size of both data loaders.
        max_epochs: Upper bound on the number of epochs.
        early_stopping_patience: Stop once the validation loss has not improved
            for this many consecutive epochs; ``None`` disables early stopping.
        is_save: When true, save the model to ``./best_model.pth`` every time
            the validation loss improves.

    Returns:
        The training loss returned by ``train_single_epoch`` for each epoch run.
    """
    # ``None`` defaults avoid sharing one mutable dict between calls.
    optimizer = optimizer_class(model.parameters(), **(optimizer_params or {}))
    lr_scheduler = lr_scheduler_class(optimizer, **(lr_scheduler_params or {}))

    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size)

    tb_writer = SummaryWriter()

    best_val_loss = None
    best_epoch = 0
    loss_history = []

    try:
        for epoch in tqdm(range(max_epochs)):
            train_loss_for_epoch = train_single_epoch(model, optimizer, train_loader, loss_function, tb_writer, epoch)
            loss_history.append(train_loss_for_epoch)
            val_metrics = validate_single_epoch(model, loss_function, val_loader)
            val_loss = val_metrics['loss']

            tb_writer.add_scalar('train_loss', train_loss_for_epoch, epoch)
            tb_writer.add_scalar('val_loss', val_loss, epoch)
            tb_writer.add_scalar('val_accuracy', val_metrics['accuracy'], epoch)

            print(f'Validation metrics: \n{val_metrics}')

            # Only ReduceLROnPlateau consumes a metric; other schedulers are
            # stepped without arguments.
            if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                lr_scheduler.step(val_loss)
            else:
                lr_scheduler.step()

            if best_val_loss is None or best_val_loss > val_loss:
                # Track the best epoch regardless of ``is_save`` so that early
                # stopping works even when no checkpoint is written.
                best_val_loss = val_loss
                best_epoch = epoch
                if is_save:
                    print('Best model yet, saving')
                    # NOTE(review): this pickles the whole module object, so
                    # loading it requires the same class definitions to be
                    # importable; a state_dict would be more portable.
                    torch.save(model, './best_model.pth')
            elif (early_stopping_patience is not None
                  and epoch - best_epoch >= early_stopping_patience):
                print(f'Early stopping: no improvement for {early_stopping_patience} epochs '
                      f'(best epoch: {best_epoch})')
                break
    finally:
        # Flush pending events and release the writer even if training fails.
        tb_writer.close()

    return loss_history

In [49]:
# Build the network with its default layer sizes and move it to the selected device.
net = Network()
net = net.to(device)
print(net)

# Count only the parameters that are flagged as requiring gradients.
print('Total number of trainable parameters : {}'.format(
    sum(p.numel() for p in net.parameters() if p.requires_grad)
))

Network(
  (net): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=784, out_features=512, bias=True)
    (2): ReLU()
    (3): Linear(in_features=512, out_features=128, bias=True)
    (4): ReLU()
    (5): Linear(in_features=128, out_features=64, bias=True)
    (6): ReLU()
    (7): Linear(in_features=64, out_features=10, bias=True)
  )
)
Total number of trainable parameters : 476490


In [50]:
# Train for at most 10 epochs with the remaining defaults.
# NOTE(review): the MNIST test split is passed as the validation set here, so
# the validation metrics below are not an independent held-out estimate.
train_model(model=net,
            train_dataset=train_data,
            val_dataset=test_data,
            max_epochs=10)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

Validation metrics: 
{'loss': 0.1472013020104948, 'accuracy': 0.9546}


  0%|          | 0/938 [00:00<?, ?it/s]

Validation metrics: 
{'loss': 0.10115171148245976, 'accuracy': 0.9672}


  0%|          | 0/938 [00:00<?, ?it/s]

Validation metrics: 
{'loss': 0.08670728280083692, 'accuracy': 0.9706}


  0%|          | 0/938 [00:00<?, ?it/s]

Validation metrics: 
{'loss': 0.0856634212804626, 'accuracy': 0.9726}


  0%|          | 0/938 [00:00<?, ?it/s]

Validation metrics: 
{'loss': 0.06362745913255746, 'accuracy': 0.9796}


  0%|          | 0/938 [00:00<?, ?it/s]

Validation metrics: 
{'loss': 0.06373749558063473, 'accuracy': 0.9778}


  0%|          | 0/938 [00:00<?, ?it/s]

Validation metrics: 
{'loss': 0.06722666027734744, 'accuracy': 0.9777}


  0%|          | 0/938 [00:00<?, ?it/s]

Validation metrics: 
{'loss': 0.07851365542671032, 'accuracy': 0.9744}


  0%|          | 0/938 [00:00<?, ?it/s]

Validation metrics: 
{'loss': 0.06860259617678821, 'accuracy': 0.9797}


  0%|          | 0/938 [00:00<?, ?it/s]

Validation metrics: 
{'loss': 0.05766284259300898, 'accuracy': 0.9817}


[466.558920674026,
 217.4366431310773,
 176.32391147129238,
 158.9316128231585,
 142.55198414996266,
 133.33365505374968,
 128.31265699584037,
 120.78331385320053,
 117.179966757074,
 113.21249551605433]