In [2]:
# Relevant imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader


import time
import matplotlib.pyplot as plt

## Implementing Resnet from scratch.

### For more info check out this paper: https://arxiv.org/abs/1512.03385 which introduces the architecture.

### High level:

Really deep architectures suffer in training performance, this means that the network is suffering for reasons other than overfitting. The reason is mostly due to exploding or vanishing gradients. Resnets handle this by including the input of a layer (the previous layer's output) as into the output of the previous layer such that if the gradient explodes or vanishes, the input is still added.

Below I reproduced the model and ran it on the CIFAR data set.


In [3]:

class BasicBlock(nn.Module):
  def __init__(self,in_channels,out_channels,stride=1, downsample=None):
    ## a basic residual block is defined as...
    ### convolution layer 1, batchnorm1 relu (or some other activation) convolution 2, batch norm, then downsample if needed.
    ### Note you will only down sample if stride ir padding !=1
    super(BasicBlock, self).__init__()
    self.conv1 = nn.Conv2d(in_channels,out_channels, stride=stride, kernel_size = 3,padding=1,bias=False)
    self.batchnorm1 = nn.BatchNorm2d(out_channels)
    self.relu = nn.ReLU()
    self.conv2 = nn.Conv2d(out_channels,out_channels, stride=1, kernel_size = 3,padding=1,bias=False)
    self.batchnorm2 = nn.BatchNorm2d(out_channels)
    self.downsample = downsample




  def forward(self,x):
    identity = x
    out= self.conv1(x)
    out=self.batchnorm1(out)
    out= self.relu(out)
    out= self.conv2(out)
    out = self.batchnorm2(out)

    if self.downsample is not None:
      identity = self.downsample(x)

    out+=identity
    out = self.relu(out)
    return out

class ResNet18(nn.Module):
  def __init__(self):
    super(ResNet18, self).__init__()

    self.initial_conv = nn.Conv2d(3,64,kernel_size=3,stride=1,padding=1)
    self.bn1 = nn.BatchNorm2d(64)
    self.relu = nn.ReLU(inplace=True)

    self.first_subgroup = self._make_layer(64, 64, stride=1)

    self.second_subgroup = self._make_layer(64, 128, stride=2)


    self.third_subgroup = self._make_layer(128, 256,stride=2)

    self.fourth_subgroup = self._make_layer(256, 512, stride=2)

    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    self.linear = nn.Linear(512,10)

  def _make_layer(self, in_channels, out_channels, blocks=2, stride=1):
      downsample = None
      if stride != 1 or in_channels != out_channels:
          downsample = nn.Sequential(
              nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
              nn.BatchNorm2d(out_channels),
          )

      layers = []
      layers.append(BasicBlock(in_channels, out_channels, stride, downsample))
      for _ in range(1, blocks):
          layers.append(BasicBlock(out_channels, out_channels))

      return nn.Sequential(*layers)


  def forward(self,x):
    x = self.initial_conv(x)
    x = self.bn1(x)
    x= self.relu(x)
    x= self.first_subgroup(x)
    x= self.second_subgroup(x)
    x= self.third_subgroup(x)
    x= self.fourth_subgroup(x)
    x  = self.avgpool(x)
    x = torch.flatten(x, 1)
    output = self.linear(x)
    return output


In [5]:

import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

def create_dataloader(dataset, shuffle=True, batch_size=128, num_workers=2):
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers
    )


## applying the transformations they did in the paper
def get_dataset(datapath, batch_size=128, num_workers=2):
    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010))
    ])

    # Load the CIFAR-10 dataset with the transformations
    train_dataset = torchvision.datasets.CIFAR10(
        root=datapath,
        train=True,
        download=True,
        transform=transform
    )

    train_loader = create_dataloader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=num_workers
    )

    return train_loader



In [8]:
def train(args):

  device = torch.device(args.device)

  model = ResNet18().to(device)

  optimizer = args.optimizer

  if optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
  elif optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr,momentum=args.momentum, weight_decay=args.weight_decay)
  print(optimizer)

  # Load CIFAR10 data
  workers = args.device
  train_loader = get_dataset(args.datapath,args.batch_size,args.workers)

  # Training Loop
  epochs = args.epochs
  criterion = nn.CrossEntropyLoss()
  for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()  # Zero the gradients
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)


        output = model(batch_x)

        # Compute the loss
        loss = criterion(output, batch_y)

        # Backward pass: Compute gradients
        loss.backward()

        # Update model weights
        optimizer.step()

        # Accumulate the running loss
        running_loss += loss.item()

        # Predictions and accuracy
        _, predicted = torch.max(output, 1)  # Get the class with the highest score (Top-1)
        correct_predictions += (predicted == batch_y).sum().item()  # Count correct predictions
        total_predictions += batch_y.size(0)  # Total number of samples in this batch


    # Compute the average loss and accuracy for the epoch
    avg_loss = running_loss / len(train_loader)
    accuracy = 100.0 * correct_predictions / total_predictions

    print({"Per epoch top 1":(accuracy)})
    print({"Per epoch average loss":(avg_loss)})




### Wanted to try argparse (to simulate linux machine running model)

In [9]:
import argparse

def main():
    parser = argparse.ArgumentParser(description="CIFAR-10 Training Example with DataLoader")

    parser.add_argument('--optimizer', type=str, choices=['adam', 'sgd'], default='adam', help='Optimizer type (adam or sgd)')
    parser.add_argument('--device', type=str, choices=['cuda', 'cpu'], default='cuda', help='Device to train on')
    parser.add_argument('--datapath', type=str, default='./data', help='Path to the dataset')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training')
    parser.add_argument('--workers', type=int, default=2, help='Number of data loading workers')
    parser.add_argument('--epochs', type=int, default=5, help='Number of data loading workers')
    parser.add_argument('--lr', type=float, default=0.1, help='Learning rate (default: 0.1)')
    parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for SGD optimizer (default: 0.9)')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay (L2 penalty) (default: 5e-4)')


    # Parse the arguments
    args = parser.parse_known_args()
    print(args[0])

    my_dataloader = create_dataloader(args[0].datapath, args[0].batch_size, args[0].workers)

    print(f"Using device: {args[0].device}")
    print(f"Using optimizer: {args[0].optimizer}")

    # Call Training
    train(args[0])




In [10]:
main()

Namespace(optimizer='adam', device='cuda', datapath='./data', batch_size=128, workers=2, epochs=5, lr=0.1, momentum=0.9, weight_decay=0.0005)
Using device: cuda
Using optimizer: adam
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.1
    maximize: False
    weight_decay: 0.0005
)
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:13<00:00, 12879375.69it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
{'Per epoch top 1': 22.262}
{'Per epoch average loss': 2.217677504815104}
{'Per epoch top 1': 29.288}
{'Per epoch average loss': 1.8525692733657328}
{'Per epoch top 1': 29.882}
{'Per epoch average loss': 1.8284326861886417}
{'Per epoch top 1': 30.526}
{'Per epoch average loss': 1.8183220164550236}
{'Per epoch top 1': 31.334}
{'Per epoch average loss': 1.8140286448056742}
