# ResNet-18 with MNIST

## Dependencies

In [1]:
# Basic
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

In [2]:
# PyTorch
import torch
import torchvision
from torchvision import transforms
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

## Configs

In [15]:
# configs
config = {
    "data": {
        "path": "../data",
        "in_channels": 1,
        "num_classes": 10,
        "batch_size": 512
    },
    "model": {
        
    },
    "training":{
        "learning_rate": 1e-5,
        "optimizer": "sgd",
        "epochs": 5
    }
}

In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Data

In [5]:
# transform for MNIST 
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # specific for MNIST
])

In [6]:
train_dataset = torchvision.datasets.MNIST(
    root = config['data']['path'],
    train = True,
    transform = transform,
    download = False
)

test_dataset = torchvision.datasets.MNIST(
    root = config['data']['path'],
    train = False,
    transform = transform
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|███████████████████████████| 9912422/9912422 [00:00<00:00, 57797747.13it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|███████████████████████████████| 28881/28881 [00:00<00:00, 26396969.67it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|███████████████████████████| 1648877/1648877 [00:00<00:00, 27105729.67it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████| 4542/4542 [00:00<00:00, 4836387.10it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [7]:
train_loader = torch.utils.data.DataLoader(
    dataset = train_dataset,
    batch_size = config['data']['batch_size'],
    shuffle = True
)
test_loader = torch.utils.data.DataLoader(
    dataset = test_dataset,
    batch_size = config['data']['batch_size'],
    shuffle=False
)

## Model

In [8]:
class BasicBlock(nn.Module):
    expansion = 1
    
    def __init__(self, in_channels, channels, stride=1):
        super(BasicBlock, self).__init__()
        self.expansion = 1
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != self.expansion*channels:
            # modify shortcut for correct channel output
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion*channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * channels)
            )
            
    def forward(self, x):
        out = nn.ReLU()(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = nn.ReLU()(out)
        return out
    
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, in_channels, num_classes):
        super(ResNet, self).__init__()
        self.in_channels = in_channels
        
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.in_channels = 64 # modify as it after first channel
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)
        
        
    def _make_layer(self, block, channels, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        
        for stride in strides:
            layers.append(block(self.in_channels, channels, stride))
            self.in_channels = channels * block.expansion
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = nn.ReLU()(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = nn.AdaptiveAvgPool2d(1)(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

In [9]:
model = ResNet(
    BasicBlock, 
    num_blocks = [2, 2, 2, 2], 
    in_channels = config['data']['in_channels'],
    num_classes = config['data']['num_classes']
).to(device)

## Training

In [10]:
# helper functions for training
def calculate_acc(logits, labels):
    """
    Given logits and correct labels:
    Return number of corrections and accuracy
    """
    pred = logits.argmax(dim=1, keepdim=True)
    correct = pred.eq(labels.view_as(pred)).sum().detach().cpu().item()
    return correct, 100 * correct / len(inputs)

In [11]:
criterion = nn.CrossEntropyLoss()
if config['training']['optimizer'] == 'sgd':
    optimizer = torch.optim.Adam(model.parameters(), lr=config['training']['learning_rate']) 
elif config['training']['topmizer'] == 'adam':
    optimizer = torch.optim.SGD(model.parameters(), lr=config['training']['learning_rate'])
# lr_scheduler = 

In [12]:
# TRAINING

# Assitant List for showing live loss/acc
live_losses = []
live_accs = []

epoch_describer = tqdm(range(config['training']['epochs']), desc=f"Epoch {1}", ncols=100)

training_losses = []
training_accs = []
validation_losses = []
validation_accs = []
    
for epoch in epoch_describer:
    
    training_loss = 0.0
    training_corr = 0.0
    
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        correct, accuracy = calculate_acc(outputs, labels)
        
        training_corr += correct
        training_loss += loss.detach().cpu().item()
        
        live_losses.append(loss.detach().cpu().item())
        live_accs.append(accuracy)
        if len(live_losses) > 100:
            live_losses.pop(0)
            live_accs.pop(0)
            
        # update 
        epoch_describer.\
        set_description(f"Epoch {epoch+1} [loss:{np.mean(live_losses):.2f}, acc:{np.mean(live_accs):.2f}]")
        
    # record training data per epoch
    training_losses.append(training_loss / len(train_loader))
    training_accs.append(training_corr / len(train_loader.dataset))
        
    # validate the result
    valid_loss = 0.0
    valid_corr = 0.0
    
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            correct, accuracy = calculate_acc(outputs, labels)
            
            valid_loss += loss.detach().cpu().item()
            valid_corr += correct
            
    validation_losses.append(valid_loss / len(test_loader))
    validation_accs.append(valid_corr / len(test_loader.dataset))                                         

Epoch 5 [loss:0.04, acc:99.28]: 100%|█████████████████████████████████| 5/5 [02:38<00:00, 31.68s/it]


In [13]:
training_accs

[0.7413833333333333,
 0.9613166666666667,
 0.97995,
 0.9874166666666667,
 0.9926666666666667]

In [14]:
validation_accs

[0.9458, 0.9736, 0.9814, 0.9854, 0.9868]