# RESNET

## The Wide ResNet Model


We'll be using a Wide Residual Network to train on this dataset, which is a convolutional neural network proven to perform very well in image classification challenges. Feel free to take some time to learn more about wide residual networks, the original residual networks they are based on, or about convolutional neural networks in general.


In the early days of CNNs, the community pushed towards very deep models (many tens or hundreds of layers). As computing power advanced and algorithms improved — in particular after the residual block was introduced — it became preferable to swing back towards shallower networks with wider layers, which was the primary innovation of the WideResNet family of models. The WideResNet-16-10 we use below has on the order of 10 million parameters, yet can reach accuracy competitive with much deeper networks that have more parameters.


In [None]:
import torch
import torch.nn as nn
import numpy as np
import os
import time
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

## Standard convolution block: convolution, batch normalization and ReLU

In [None]:
class cbrblock(nn.Module):
    """Conv -> BatchNorm -> ReLU unit.

    A 3x3, stride-1 convolution with 'same' padding (height and width are
    preserved) and no bias term, followed by 2-D batch normalization over
    ``output_channels`` and a ReLU activation.

    Args:
        input_channels: Number of channels of the incoming feature map.
        output_channels: Number of channels produced by the convolution.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        conv = nn.Conv2d(
            input_channels,
            output_channels,
            kernel_size=3,
            stride=(1, 1),
            padding='same',
            bias=False,  # BatchNorm2d right after provides its own learnable shift
        )
        norm = nn.BatchNorm2d(output_channels)
        # Kept under the attribute name ``cbr`` so state_dict keys
        # (cbr.0.*, cbr.1.*) stay stable for saved checkpoints.
        self.cbr = nn.Sequential(conv, norm, nn.ReLU())

    def forward(self, x):
        """Apply conv, batch norm and ReLU to ``x`` in sequence."""
        return self.cbr(x)

## Basic residual block

In [None]:
class conv_block(nn.Module):
    """Residual block made of two ``cbrblock`` units plus a skip connection.

    Main path: cbrblock(input -> output channels), dropout (p=0.01), then
    cbrblock(output -> output channels). The block's result is the main path
    plus a shortcut.

    Args:
        input_channels: Channels of the incoming feature map.
        output_channels: Channels of the produced feature map.
        scale_input: When truthy, the shortcut is passed through a 1x1
            convolution mapping ``input_channels`` to ``output_channels``.
            When falsy, the input is added unchanged, so the two channel
            counts must be equal.
    """

    def __init__(self, input_channels, output_channels, scale_input):
        super().__init__()
        self.scale_input = scale_input
        # Registration order (scale, layer1, dropout, layer2) is kept so that
        # parameter initialisation consumes the RNG in the same sequence.
        if scale_input:
            self.scale = nn.Conv2d(
                input_channels, output_channels,
                kernel_size=1, stride=(1, 1), padding='same',
            )
        self.layer1 = cbrblock(input_channels, output_channels)
        self.dropout = nn.Dropout(p=0.01)
        self.layer2 = cbrblock(output_channels, output_channels)

    def forward(self, x):
        """Return ``main_path(x) + shortcut(x)``."""
        shortcut = self.scale(x) if self.scale_input else x
        main = self.layer2(self.dropout(self.layer1(x)))
        return main + shortcut

## Model

In [None]:
class WideResNet(nn.Module):
    """Wide residual network classifier.

    Layout: a cbrblock stem (in_channels -> 16), then three stages of two
    residual ``conv_block``s with widths 160, 320 and 640. A 2x2 max pool
    sits between consecutive stages. The network ends with global average
    pooling, flattening and a linear layer.

    Args:
        num_classes: Number of logits produced per image.
        in_channels: Number of channels in the input images. Defaults to 3,
            matching the RGB CIFAR-10 tensors this notebook feeds the model.
            Pass 1 for single-channel (grayscale) images.

    Input:  tensor of shape (N, in_channels, H, W) with H, W >= 4 (two 2x2
            max pools must still leave at least a 1x1 map).
    Output: logits of shape (N, num_classes).
    """

    def __init__(self, num_classes, in_channels=3):
        super(WideResNet, self).__init__()
        # Channel widths: network input, stem output, then the three stages.
        nChannels = [in_channels, 16, 160, 320, 640]

        self.input_block = cbrblock(nChannels[0], nChannels[1])

        # Each stage opens with a block that changes the width, so its
        # shortcut is projected with a 1x1 conv (scale_input=1). The second
        # block keeps the width and uses an identity shortcut (scale_input=0).
        self.block1 = conv_block(nChannels[1], nChannels[2], 1)
        self.block2 = conv_block(nChannels[2], nChannels[2], 0)
        self.pool1 = nn.MaxPool2d(2)
        self.block3 = conv_block(nChannels[2], nChannels[3], 1)
        self.block4 = conv_block(nChannels[3], nChannels[3], 0)
        self.pool2 = nn.MaxPool2d(2)
        self.block5 = conv_block(nChannels[3], nChannels[4], 1)
        self.block6 = conv_block(nChannels[4], nChannels[4], 0)

        # Global average pooling over whatever spatial size reaches this
        # point. A fixed AvgPool2d(7) is only global for 7x7 maps (28x28
        # inputs). For 32x32 inputs the map is 8x8, and a 7x7 window would
        # silently ignore the last row and column.
        self.pool = nn.AdaptiveAvgPool2d(1)

        # Flatten (N, 640, 1, 1) -> (N, 640), then classify.
        self.flat = nn.Flatten()
        self.fc = nn.Linear(nChannels[4], num_classes)

    def forward(self, x):
        """Map a batch of images to class logits."""
        out = self.input_block(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.pool1(out)
        out = self.block3(out)
        out = self.block4(out)
        out = self.pool2(out)
        out = self.block5(out)
        out = self.block6(out)
        out = self.pool(out)
        out = self.flat(out)
        out = self.fc(out)

        return out

In [None]:
def train(model, optimizer, train_loader, loss_fn, device):
    """Run one training epoch: one optimizer step per batch of ``train_loader``.

    Args:
        model: Module to optimise; it is switched to training mode first.
        optimizer: Optimizer over ``model``'s parameters.
        train_loader: Iterable of ``(images, labels)`` batches.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar loss.
        device: Device each batch is moved to before the forward pass.

    Returns:
        None. The model parameters are updated in place.
    """
    model.train()
    for batch_images, batch_labels in train_loader:
        # Move the batch to the training device (GPU if available).
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Clear gradients left over from the previous step. The forward pass
        # does not touch .grad, so clearing here instead of after it is
        # equivalent.
        optimizer.zero_grad()

        # Forward, backward, then update the parameters.
        batch_loss = loss_fn(model(batch_images), batch_labels)
        batch_loss.backward()
        optimizer.step()


In [None]:
def test(model, test_loader, loss_fn, device):
    total_labels = 0
    correct_labels = 0
    loss_total = 0
    model.eval()
    
    with torch.no_grad():
        for images, labels in test_loader:
            # Transfering images and labels to GPU if available
            labels = labels.to(device)
            images = images.to(device)

            # Forward pass 
            outputs = model(images)
            loss = loss_fn(outputs, labels)

            # Extracting predicted label, and computing validation loss and validation accuracy
            predictions = torch.max(outputs, 1)[1]
            total_labels += len(labels)
            correct_labels += (predictions == labels).sum()
            loss_total += loss

    v_accuracy = correct_labels / total_labels
    v_loss = loss_total / len(test_loader)

    return v_accuracy, v_loss

## LOADING DATASET

In [None]:
 print('==> Running main part of our training')
if torch.cuda.is_available():
    print('=> Our gpu type and uploaded driver')
    os.system("nvidia-smi --query-gpu=gpu_name,driver_version --format=csv")
else:
    print('==> There is no CUDA capable GPU device here!')

In [None]:
# Training hyper-parameters:
#   batch_size      - samples per batch for both the training and validation loaders
#   epochs          - maximum number of training epochs
#   patience        - epochs to keep training after first exceeding target_accuracy
#   target_accuracy - validation accuracy threshold that arms early stopping
#   base_lr         - SGD learning rate
config = dict(
    batch_size=2048,
    epochs=40,
    patience=2,
    target_accuracy=0.85,
    base_lr=0.01,
)

In [None]:
 # transform the raw data into tensors
# Define transforms for the training and testing sets
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

# Train only on 1/5 of the dataset
# train_subset = torch.utils.data.Subset(train_set, list(range(0, 12000)))
# test_subset = torch.utils.data.Subset(test_set, list(range(0, 10000)))
train_subset = torch.utils.data.Subset(train_set, list(range(0, 10000)))
test_subset = torch.utils.data.Subset(test_set, list(range(0, 2000)))

# Training data loader
train_loader = torch.utils.data.DataLoader(train_subset,
                                           batch_size=config["batch_size"], drop_last=True)
# Validation data loader
test_loader = torch.utils.data.DataLoader(test_subset,
                                          batch_size=config["batch_size"], drop_last=True)

## INIT MODEL INSTANCE

In [None]:
 # Create the model and move to GPU device if available
num_classes = 10

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = WideResNet(num_classes).to(device)

# Define loss function
loss_fn = nn.CrossEntropyLoss()

# Define the SGD optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=config["base_lr"])

val_accuracy = []

# Make a change 3: initialize the variable total_time with a value of 0.
total_time = 0

## TRAIN LOOP

In [None]:
for epoch in range(config["epochs"]):

    # Snapshot GPU temperature and utilization before the timer starts, so the
    # nvidia-smi subprocess is not counted in the epoch time (it would inflate
    # Epoch Time and deflate Images/sec). Skip it entirely when not on a GPU.
    if device.type == "cuda":
        os.system("nvidia-smi --query-gpu=temperature.gpu,utilization.gpu,utilization.memory --format=csv")

    t0 = time.perf_counter()
    train(model, optimizer, train_loader, loss_fn, device)
    # CUDA kernels launch asynchronously. Wait for the queued work to finish
    # so the measured time covers the actual computation.
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    epoch_time = time.perf_counter() - t0
    total_time += epoch_time

    # Throughput assumes every training batch is full-sized (the training
    # loader is built with drop_last=True).
    images_per_sec = len(train_loader) * config["batch_size"] / epoch_time
    v_accuracy, v_loss = test(model, test_loader, loss_fn, device)

    val_accuracy.append(v_accuracy)

    print("Epoch = {:2d}: Epoch Time = {:5.3f}, Validation Loss = {:5.3f}, Validation Accuracy = {:5.3f}, Images/sec = {}, Cumulative Time = {:5.3f}".format(epoch+1, epoch_time, v_loss, val_accuracy[-1], images_per_sec, total_time))

    # Early stopping: stop once `patience` epochs have passed since the first
    # epoch whose validation accuracy exceeded the target.
    first_hit = next(
        (i for i, acc in enumerate(val_accuracy) if acc > config["target_accuracy"]),
        None,
    )
    if first_hit is not None and first_hit + config["patience"] <= epoch:
        print(f"Early stopping on epoch {epoch+1}!")
        break