Train a ResNet18 model for classification on even classes of CIFAR-10

Import required libraries

In [1]:
#Import all required libraries
import torch
import numpy as np
from torch import nn
from tqdm import tqdm
from time import time
from torch import optim
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, transforms, models
from tqdm import tqdm
import timeit
import math
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Install wandb library

In [2]:
%pip install wandb

Collecting wandb
  Downloading wandb-0.15.10-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.36-py3-none-any.whl (189 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m189.5/189.5 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.30.0-py2.py3-none-any.whl (218 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m218.8/218.8 kB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting pathtools (from wandb)
  Downloading pathtools-0.1.2.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.2-cp310-cp310-manylinux_2_5_x86_64.manyli

Import and login

In [3]:
import wandb
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

Apply transforms

In [4]:
# Define the transforms for data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the Fashion MNIST dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 61957776.12it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


Extract only even classes from the dataset

In [5]:
#Changing the labels of even number to 10 for further operations
train_dataset.targets = torch.tensor(train_dataset.targets)
train_dataset.targets[train_dataset.targets ==1] = 10
train_dataset.targets[train_dataset.targets ==3] = 10
train_dataset.targets[train_dataset.targets ==5] = 10
train_dataset.targets[train_dataset.targets ==7] = 10
train_dataset.targets[train_dataset.targets ==9] = 10

#changing the labels of odd number starting from 0 to 4 to have good indexing and labeling for model
train_dataset.targets[train_dataset.targets ==0] = 0
train_dataset.targets[train_dataset.targets ==2] = 1
train_dataset.targets[train_dataset.targets ==4] = 2
train_dataset.targets[train_dataset.targets ==6] = 3
train_dataset.targets[train_dataset.targets ==8] = 4
print(train_dataset.targets)

tensor([ 3, 10, 10,  ..., 10, 10, 10])


In [6]:
#Changing the labels of even number to 10 for further operations
test_dataset.targets = torch.tensor(test_dataset.targets)
test_dataset.targets[test_dataset.targets ==1] = 10
test_dataset.targets[test_dataset.targets ==3] = 10
test_dataset.targets[test_dataset.targets ==5] = 10
test_dataset.targets[test_dataset.targets ==7] = 10
test_dataset.targets[test_dataset.targets ==9] = 10

#changing the labels of odd number starting from 0 to 4 to have good indexing and labeling for model
test_dataset.targets[test_dataset.targets ==0] = 0
test_dataset.targets[test_dataset.targets ==2] = 1
test_dataset.targets[test_dataset.targets ==4] = 2
test_dataset.targets[test_dataset.targets ==6] = 3
test_dataset.targets[test_dataset.targets ==8] = 4

print(test_dataset.targets)

tensor([10,  4,  4,  ..., 10, 10, 10])


In [7]:
#Create the mask to get required data from tha main dataset
#If the label is other than 10 (Remember the even number labels are changed to 10) then make that true for training and testing data
train_mask = (torch.tensor(train_dataset.targets) != 10)
test_mask = (torch.tensor(test_dataset.targets) != 10)

#Get the indices of the labels which are marked true
train_indices = train_mask.nonzero().reshape(-1)
test_indices =test_mask.nonzero().reshape(-1)

#Select the subset from the mask that is created above
trainset_2 = torch.utils.data.Subset(train_dataset, train_indices)
testset_2 = torch.utils.data.Subset(test_dataset, test_indices)

  train_mask = (torch.tensor(train_dataset.targets) != 10)
  test_mask = (torch.tensor(test_dataset.targets) != 10)


Get resnet18 pretrained model

In [8]:
# Load the ResNet18 model pre-trained on ImageNet
model = models.resnet18(pretrained=True)

# Modify the last fully connected layer to output 10 classes
num_classes = 5
model.fc = nn.Linear(512, num_classes)
#get  model to device
model=model.to(device)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 106MB/s]


loss function

In [9]:
# Make the loss and optimizer
criterion = nn.CrossEntropyLoss()

toplot the prediction of ini=dividual image on the weights and bias

In [10]:
def log_image_table(images, predicted, labels, probs):
    "Log a wandb.Table with (img, pred, target, scores)"
    # 🐝 Create a wandb Table to log images, labels and predictions to
    table = wandb.Table(columns=["image", "pred", "target"]+[f"score_{i}" for i in range(5)])
    for img, pred, targ, prob in zip(images.to("cpu"), predicted.to("cpu"), labels.to("cpu"), probs.to("cpu")):
        table.add_data(wandb.Image(img[0].numpy()*255), pred, targ, *prob.numpy())
    wandb.log({"predictions_table":table}, commit=False)

Test function

In [11]:
def test_function(model, log_images=False, batch_idx=0):
    "Compute performance of the model on the validation dataset and log a wandb.Table"
    model.eval()
    val_loss = 0.
    with torch.inference_mode():
        correct = 0
        for i, (images, labels) in enumerate(test_loader):
            images, labels = images.to(device), labels.to(device)

            # Forward pass ➡
            outputs = model(images)
            val_loss += criterion(outputs, labels)*labels.size(0)

            # Compute accuracy and accumulate
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()

            # Log one batch of images to the dashboard, always same batch_idx.
            if i==batch_idx and log_images:
                log_image_table(images, predicted, labels, outputs.softmax(dim=1))
    return val_loss / len(test_loader.dataset), correct / len(test_loader.dataset)

Train the model

In [12]:
def main_fun(model, optimizer):
  # Training
  example_ct = 0
  step_ct = 0
  total_epochs=30
  for epoch in range(total_epochs):
      model.train()
      train_acc=0
      for step, (images, labels) in enumerate(train_loader):
          images, labels = images.to(device), labels.to(device)

          outputs = model(images)
          train_loss = criterion(outputs, labels)
          optimizer.zero_grad()
          train_loss.backward()
          optimizer.step()

          # Compute accuracy and accumulate
          total=labels.size(0)
          _, tr_pred = torch.max(outputs.data, 1)
          tr_correct = (tr_pred == labels).sum().item()
          train_acc+=tr_correct
          example_ct += len(images)
          metrics = {"train/train_loss": train_loss,"train/train_accuracy": tr_correct/total}

          if step + 1 < n_steps_per_epoch:
              # Log train metrics to wandb
              wandb.log(metrics)

          step_ct += 1

      val_loss, accuracy = test_function(model, log_images=(epoch==(total_epochs-1)))

      # Log train and validation metrics to wandb
      val_metrics = {"val/val_loss": val_loss,
                      "val/val_accuracy": accuracy}
      wandb.log({**metrics, **val_metrics})

      print(f"Train accuracy: {train_acc/len(train_loader.dataset)}, Train Loss: {train_loss:.3f}, val Loss: {val_loss:3f}, val accuracy: {accuracy:.2f}")

  # If you had a test set, this is how you could log it as a Summary metric
  wandb.summary['test_accuracy'] = 0.8

  # Close your wandb run
  wandb.finish()

###Driver code

####Hyperparameter Set1

In [13]:
# Create data loaders for training and testing
batch_size = 128
train_loader = torch.utils.data.DataLoader(trainset_2, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset_2, batch_size=batch_size, shuffle=False)
n_steps_per_epoch = math.ceil(len(train_loader.dataset) /batch_size)

In [14]:
# initialise a wandb run
wandb.init(
    project="resnet18_wandb_project",name="Hyper_param1",
)

[34m[1mwandb[0m: Currently logged in as: [33mavantivarude2000[0m ([33mavanti[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [15]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
main_fun(model, optimizer)

Train accuracy: 0.78524, Train Loss: 0.648, val Loss: 0.581873, val accuracy: 0.80
Train accuracy: 0.88156, Train Loss: 0.258, val Loss: 0.380121, val accuracy: 0.87
Train accuracy: 0.91628, Train Loss: 0.086, val Loss: 0.392378, val accuracy: 0.87
Train accuracy: 0.9388, Train Loss: 0.279, val Loss: 0.405331, val accuracy: 0.87
Train accuracy: 0.94992, Train Loss: 0.183, val Loss: 0.399076, val accuracy: 0.87
Train accuracy: 0.96284, Train Loss: 0.110, val Loss: 0.451942, val accuracy: 0.87
Train accuracy: 0.9668, Train Loss: 0.251, val Loss: 0.418286, val accuracy: 0.88
Train accuracy: 0.97392, Train Loss: 0.019, val Loss: 0.442639, val accuracy: 0.87
Train accuracy: 0.97668, Train Loss: 0.393, val Loss: 0.462765, val accuracy: 0.88
Train accuracy: 0.9752, Train Loss: 0.067, val Loss: 0.489602, val accuracy: 0.88
Train accuracy: 0.98284, Train Loss: 0.083, val Loss: 0.542273, val accuracy: 0.88
Train accuracy: 0.9828, Train Loss: 0.103, val Loss: 0.470788, val accuracy: 0.88
Train ac

0,1
train/train_accuracy,▁▃▃▅▆▄▇█▇▇▆▇█▇▇▇▇█████▇███▇▇██▇▇█▇██▇▇██
train/train_loss,█▇▆▅▃▅▃▂▂▂▃▂▂▂▃▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▂▂▁▂▂▁▁
val/val_accuracy,▁▆▆▆▆▆▇▆▇▇▇▇█▇▇▇▇█▇███▇▇█▇▇▇█▇
val/val_loss,▇▁▁▂▂▃▂▃▄▄▆▄▄▅▄▅▅▅▆▄▄▆▄▅▄▅█▅▅█

0,1
test_accuracy,0.8
train/train_accuracy,0.95
train/train_loss,0.16002
val/val_accuracy,0.876
val/val_loss,0.59874


####Hyperparameter set 2:

In [17]:
# Create data loaders for training and testing
batch_size = 64
train_loader = torch.utils.data.DataLoader(trainset_2, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset_2, batch_size=batch_size, shuffle=False)
n_steps_per_epoch = math.ceil(len(train_loader.dataset) /batch_size)

wandb.init(
    project="resnet18_wandb_project",name="Hyper_param2",
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
main_fun(model, optimizer)

0,1
train/train_accuracy,█▆▃▁▁▄▄▄▂▁▁▁▄▃▃▁▄▂▆▃▃▁▃▂▂▂▃▅▂▂▃▃▂▃▄▂▁▅▂▄
train/train_loss,█▃▄▂▂▄▂▄▂▂▂▄▂▄▄▂▁▂▁▁▃▅▁▃▃▂▂▂▁▂▁▂▃▂▁▁▃▁▂▁

0,1
train/train_accuracy,0.35938
train/train_loss,1.61735


Train accuracy: 0.48208, Train Loss: 1.294, val Loss: 1.259439, val accuracy: 0.51
Train accuracy: 0.57688, Train Loss: 1.090, val Loss: 0.925271, val accuracy: 0.62
Train accuracy: 0.656, Train Loss: 0.909, val Loss: 0.888508, val accuracy: 0.68
Train accuracy: 0.71484, Train Loss: 0.473, val Loss: 0.699629, val accuracy: 0.74
Train accuracy: 0.75652, Train Loss: 0.720, val Loss: 0.620912, val accuracy: 0.77
Train accuracy: 0.78576, Train Loss: 0.475, val Loss: 0.644418, val accuracy: 0.76
Train accuracy: 0.80556, Train Loss: 0.496, val Loss: 0.649045, val accuracy: 0.75
Train accuracy: 0.82384, Train Loss: 0.421, val Loss: 0.661013, val accuracy: 0.76
Train accuracy: 0.84528, Train Loss: 0.350, val Loss: 0.529666, val accuracy: 0.81
Train accuracy: 0.86664, Train Loss: 0.369, val Loss: 0.554757, val accuracy: 0.81
Train accuracy: 0.87992, Train Loss: 0.356, val Loss: 0.508617, val accuracy: 0.82
Train accuracy: 0.8984, Train Loss: 0.262, val Loss: 0.552544, val accuracy: 0.81
Train a

0,1
train/train_accuracy,▁▃▄▄▆▇▆▅▆▅▆▆▇▇▆▇█▇▇▇▇█▇▇▇▇███▇▇█████▇███
train/train_loss,█▅▄▄▃▃▃▃▃▃▃▂▂▂▃▂▁▂▂▂▂▁▁▂▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁
val/val_accuracy,▁▃▅▆▇▇▆▇█████▇██████████████▇█
val/val_loss,█▅▅▃▂▂▂▂▁▁▁▁▂▂▂▂▃▃▃▃▄▃▃▄▃▄▄▅▅▄

0,1
test_accuracy,0.8
train/train_accuracy,0.9
train/train_loss,0.13226
val/val_accuracy,0.8084
val/val_loss,0.87973


Hyperparameter Set 3:

In [18]:
# Create data loaders for training and testing
batch_size = 64
train_loader = torch.utils.data.DataLoader(trainset_2, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset_2, batch_size=batch_size, shuffle=False)
n_steps_per_epoch = math.ceil(len(train_loader.dataset) /batch_size)

wandb.init(
    project="resnet18_wandb_project",name="Hyper_param3",
)

optimizer = torch.optim.SGD(model.parameters(), lr=0.001,momentum=0.9)
main_fun(model, optimizer)

Train accuracy: 0.98436, Train Loss: 0.021, val Loss: 0.843864, val accuracy: 0.82
Train accuracy: 0.98672, Train Loss: 0.003, val Loss: 0.829377, val accuracy: 0.82
Train accuracy: 0.98804, Train Loss: 0.048, val Loss: 0.837503, val accuracy: 0.82
Train accuracy: 0.9892, Train Loss: 0.048, val Loss: 0.820192, val accuracy: 0.82
Train accuracy: 0.9908, Train Loss: 0.016, val Loss: 0.821591, val accuracy: 0.82
Train accuracy: 0.99032, Train Loss: 0.015, val Loss: 0.822770, val accuracy: 0.83
Train accuracy: 0.99128, Train Loss: 0.006, val Loss: 0.835963, val accuracy: 0.83
Train accuracy: 0.99208, Train Loss: 0.024, val Loss: 0.836923, val accuracy: 0.83
Train accuracy: 0.99192, Train Loss: 0.081, val Loss: 0.827444, val accuracy: 0.83
Train accuracy: 0.99232, Train Loss: 0.006, val Loss: 0.831061, val accuracy: 0.83
Train accuracy: 0.99332, Train Loss: 0.004, val Loss: 0.841351, val accuracy: 0.83
Train accuracy: 0.99404, Train Loss: 0.006, val Loss: 0.831057, val accuracy: 0.83
Train 

0,1
train/train_accuracy,▅█▅▁▅█▁▅██▅▅██▅██████▅██▅████▅█▅████▅██▅
train/train_loss,█▁▄▄▃▂▄▂▂▂▅▂▁▁▂▁▁▁▁▁▂▄▁▁▂▂▁▁▁▃▁▅▁▁▁▁▃▁▁▂
val/val_accuracy,▁▂▂▃▃▄▅▅▄▅▅▅▆▅▅█▆▆▇▅▇████▇█▇██
val/val_loss,▃▂▂▁▁▁▂▂▂▂▃▂▄▃▃▄▃▄▅▅▅▅▅▆▆▆▆▆█▆

0,1
test_accuracy,0.8
train/train_accuracy,1.0
train/train_loss,0.01871
val/val_accuracy,0.8346
val/val_loss,0.87995


Hyperparameter Set 4:

In [19]:
# Create data loaders for training and testing
batch_size = 128
train_loader = torch.utils.data.DataLoader(trainset_2, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset_2, batch_size=batch_size, shuffle=False)
n_steps_per_epoch = math.ceil(len(train_loader.dataset) /batch_size)


wandb.init(
    project="resnet18_wandb_project",name="Hyper_param4",
)

optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01,momentum=0.9)
main_fun(model, optimizer)

Train accuracy: 0.34444, Train Loss: 3.639, val Loss: 1.940185, val accuracy: 0.29
Train accuracy: 0.40256, Train Loss: 1.286, val Loss: 2.313123, val accuracy: 0.43
Train accuracy: 0.47476, Train Loss: 1.088, val Loss: 1.088516, val accuracy: 0.57
Train accuracy: 0.56384, Train Loss: 2.135, val Loss: 1.345823, val accuracy: 0.61
Train accuracy: 0.64348, Train Loss: 1.053, val Loss: 0.991835, val accuracy: 0.63
Train accuracy: 0.69772, Train Loss: 0.722, val Loss: 0.920502, val accuracy: 0.66
Train accuracy: 0.73404, Train Loss: 0.521, val Loss: 0.764405, val accuracy: 0.71
Train accuracy: 0.72104, Train Loss: 0.466, val Loss: 0.712273, val accuracy: 0.74
Train accuracy: 0.7458, Train Loss: 0.544, val Loss: 0.776872, val accuracy: 0.73
Train accuracy: 0.7308, Train Loss: 0.761, val Loss: 0.866811, val accuracy: 0.69
Train accuracy: 0.72156, Train Loss: 0.982, val Loss: 1.160334, val accuracy: 0.68
Train accuracy: 0.77792, Train Loss: 5.344, val Loss: 0.761498, val accuracy: 0.71
Train 

0,1
train/train_accuracy,▂▂▁▃▄▅▆▆▅▆▆▇▆▆▆▆▅▇▇▆▆▅▆▆▇▇█▆▇▆▇▇█▇▇▇█▇▇█
train/train_loss,█▄▆▄▃▃▂▂▂▂▂▂▂▂▂▂▃▂▁▂▃▆▄▂▂▂▁▂▂▃▂▂▁▂▁▁▁▁▁▁
val/val_accuracy,▁▃▅▆▆▆▇▇▇▇▆▇██▇▁▇▄▄▇█▇███████▇
val/val_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▂▁▁▁▁▁▁▁▁▁▁▁

0,1
test_accuracy,0.8
train/train_accuracy,0.8
train/train_loss,1.34635
val/val_accuracy,0.7408
val/val_loss,1.60103
