# Training a binary classifier with a VGG16 network written from scratch

* The Python file `model.py` contains the custom-written VGG16 network (`MyVGG16`).
* This notebook also trains a pretrained ResNet50, a deeper and more sophisticated network than VGG16.
* The results of both models are evaluated and compared in the evaluation step.
* Alongside the model, a `train` function (written in `train.py` and imported here) wraps the training loop: in each epoch it runs the forward and backward passes over the training set, then a forward-only pass over the validation set.
* `get_loader` (in `dataset.py`) is a custom wrapper that builds a PyTorch dataset from a label DataFrame and wraps it in a `DataLoader`.
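The role described for `get_loader` can be sketched as follows. This is a hypothetical stand-in, not the code in `dataset.py`: the image-path column name (`path`), the 0/1 encoding of the label columns, and the conversion of those columns to class indices are all assumptions. It also shows where the progress-bar counts in the logs come from: a `DataLoader` yields ceil(N / batch_size) batches, so the 1519 + 401 = 1920 training images at `BATCH_SIZE = 4` give 480 iterations per epoch.

```python
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset


def _load_rgb(path):
    # Imported lazily so the sketch can be exercised without Pillow or real image files
    from PIL import Image
    return Image.open(path).convert("RGB")


class LabelFrameDataset(Dataset):
    """Hypothetical Dataset over a label DataFrame, one row per image.

    Assumes an image-path column called 'path' (the real column name is not shown)
    and one 0/1 column per class in label_cols.
    """

    def __init__(self, data_df, label_cols, transform=None, path_col="path", load_fn=_load_rgb):
        self.df = data_df.reset_index(drop=True)
        self.label_cols = list(label_cols)
        self.transform = transform
        self.path_col = path_col
        self.load_fn = load_fn

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image = self.load_fn(row[self.path_col])
        if self.transform is not None:
            image = self.transform(image)
        # Collapse the one-hot label columns to a class index, the target format
        # nn.CrossEntropyLoss accepts (index 0 = first entry of label_cols)
        label = int(row[self.label_cols].to_numpy(dtype=float).argmax())
        return image, label


def get_loader_sketch(data_df, label_cols, batch_size, transform=None, shuffle=False, **dataset_kwargs):
    """Wrap the DataFrame in a Dataset, then in a DataLoader."""
    dataset = LabelFrameDataset(data_df, label_cols, transform=transform, **dataset_kwargs)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


# Usage on a toy frame: a stub load_fn returns a blank tensor instead of reading a file
toy = pd.DataFrame({
    "path": ["a", "b", "c", "d", "e"],
    "Healthy": [1, 0, 0, 1, 0],
    "Disease_Risk": [0, 1, 1, 0, 1],
})
loader = get_loader_sketch(toy, ["Healthy", "Disease_Risk"], batch_size=4,
                           load_fn=lambda p: torch.zeros(3, 8, 8))
images, labels = next(iter(loader))
print(len(loader), tuple(images.shape), labels.tolist())
```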

In [4]:
import os
os.chdir('/Users/simpleai/Desktop/az_task/scripts/')
import pandas as pd
from torchvision import transforms
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
import warnings
warnings.filterwarnings('ignore')

from dataset import get_loader # custom written for this task
from model import MyVGG16      # custom written for this task
from train import train        # custom written for this task

In [24]:
# hyperparameters
NUM_CLASSES = 2
NUM_EPOCHS = 40
BATCH_SIZE = 4
LEARNING_RATE = 0.0001
WEIGHT_DECAY = 1E-4
CLASS_WEIGHT = torch.tensor([0.8,0.2])
# on Apple silicon (my M1), PyTorch can run on the GPU through the Metal Performance Shaders (MPS) backend
device = "mps" if torch.backends.mps.is_available() else "cpu"
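The `CLASS_WEIGHT` values can be derived from the class counts quoted in the data-loading cell (1519 Disease, 401 Healthy). The sketch below computes normalized inverse-frequency weights, then checks on made-up logits how `nn.CrossEntropyLoss` applies them to class-index targets under the default `reduction='mean'`: each sample's loss is scaled by the weight of its true class, and the sum is divided by the total weight in the batch. It assumes index 0 is `Healthy` and index 1 is `Disease_Risk`, following `label_cols`.

```python
import torch
import torch.nn as nn

# Training-split counts, ordered like label_cols: [Healthy, Disease_Risk]
counts = torch.tensor([401.0, 1519.0])

# Inverse-frequency weights normalised to sum to 1: the rarer class gets the larger weight
inverse = 1.0 / counts
class_weight = inverse / inverse.sum()  # about [0.791, 0.209], close to CLASS_WEIGHT = [0.8, 0.2]

# Weighted loss with class-index targets and the default reduction='mean'
criterion = nn.CrossEntropyLoss(weight=class_weight)
logits = torch.tensor([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]])  # made-up scores for 3 samples
targets = torch.tensor([0, 1, 1])
loss = criterion(logits, targets)

# The same value by hand: per-sample negative log-likelihood times w[target],
# summed and divided by the sum of the weights that were used
nll = -torch.log_softmax(logits, dim=1)[torch.arange(len(targets)), targets]
w = class_weight[targets]
manual = (w * nll).sum() / w.sum()
print(class_weight, loss.item(), manual.item())
```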

In [27]:
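# defining the transforms to be applied to the images
# on-the-fly augmentations will be added to this transform later
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256, 256)),  # 256x256 to fit within the limited RAM available on my Mac
    # these mean/std values are the commonly used CIFAR-10 channel statistics; torchvision's
    # ImageNet-pretrained weights (e.g. the ResNet50 below) were trained with
    # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    ),
])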

# load label_df, which holds both the labels and the image paths
label_df_path = '/Users/simpleai/Desktop/az_task/results/task1/binary_classification_metadata.csv'
label_df = pd.read_csv(label_df_path, index_col=0)

# split label_df into train, val and test sets using its 'type' column
train_df = label_df.loc[label_df['type']=='Train', :]  # 1519 Disease, 401 Healthy: roughly 0.8:0.2 imbalance
#train_df = train_df.sample(frac=0.25) # select a random quarter of the data for speed
train_df = train_df.reset_index(drop=True)

val_df = label_df.loc[label_df['type']=='Val', :]
#val_df = val_df.sample(frac=0.25) # select a random quarter of the data for speed
val_df = val_df.reset_index(drop=True)

test_df = label_df.loc[label_df['type']=='Test', :]
test_df = test_df.reset_index(drop=True)

# defining the dataloader dictionary
dataloader_dict = {
    'train': get_loader(
        data_df = train_df, 
        label_cols = ['Healthy', 'Disease_Risk'],
        batch_size = BATCH_SIZE, 
        transform = transform, 
        shuffle = True),
    'val': get_loader(
        data_df = val_df, 
        label_cols = ['Healthy', 'Disease_Risk'],
        batch_size = BATCH_SIZE, 
        transform = transform, 
        shuffle = False),
    }
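The `Normalize` values in this cell are the widely used CIFAR-10 channel statistics. Normalization is a per-channel (x - mean) / std, so the choice of statistics shifts every input the network sees. The sketch below reproduces that arithmetic with plain tensor operations and compares the CIFAR-10 values with the ImageNet statistics that torchvision's pretrained weights were trained with. torchvision also exposes the matching preset preprocessing as `ResNet50_Weights.IMAGENET1K_V1.transforms()`.

```python
import torch


def normalize(image, mean, std):
    """Per-channel (x - mean) / std on a C x H x W tensor, the arithmetic transforms.Normalize applies."""
    mean = torch.as_tensor(mean).view(-1, 1, 1)
    std = torch.as_tensor(std).view(-1, 1, 1)
    return (image - mean) / std


cifar_mean, cifar_std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]  # values used above
imagenet_mean, imagenet_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]  # torchvision ImageNet weights

# A flat mid-grey image as ToTensor would produce it (values in [0, 1])
grey = torch.full((3, 4, 4), 0.5)
cifar_pixel = normalize(grey, cifar_mean, cifar_std)[:, 0, 0]
imagenet_pixel = normalize(grey, imagenet_mean, imagenet_std)[:, 0, 0]
print(cifar_pixel, imagenet_pixel)
```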

### Training the VGG16 model
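`MyVGG16` takes `input_height` and `input_width`, presumably because the first fully connected layer of a VGG16 classifier has to know how many features leave the convolutional stack (the implementation in `model.py` is not shown here). In the standard VGG16 layout, 3x3 convolutions with padding 1 preserve the spatial size and each of the five 2x2 max-pools halves it, so a 256x256 input ends as a 512 x 8 x 8 map, i.e. 32,768 features. The sketch below assumes that standard layout; torchvision's own `vgg16` avoids the dependence on input size with an `nn.AdaptiveAvgPool2d((7, 7))` before its classifier.

```python
import torch
import torch.nn as nn

# Standard VGG16 conv layout: numbers are output channels, "M" is a 2x2 max-pool
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"]


def vgg16_features(cfg=VGG16_CFG, in_channels=3):
    layers = []
    for v in cfg:
        if v == "M":
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))  # halves H and W
        else:
            # 3x3 conv with padding 1 keeps H and W unchanged
            layers += [nn.Conv2d(in_channels, v, kernel_size=3, padding=1), nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


def flattened_size(height, width, cfg=VGG16_CFG):
    """Features entering the first fully connected layer: last channels x (H / 2^pools) x (W / 2^pools)."""
    pools = cfg.count("M")
    channels = [v for v in cfg if v != "M"][-1]
    return channels * (height // 2 ** pools) * (width // 2 ** pools)


print(flattened_size(256, 256))  # 512 * 8 * 8

# Cross-check the arithmetic with a real forward pass on a small input (64x64 -> 512 x 2 x 2)
features = vgg16_features()
with torch.no_grad():
    out = features(torch.zeros(1, 3, 64, 64))
print(tuple(out.shape))
```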

In [28]:
# defining the model
model = MyVGG16(num_classes=NUM_CLASSES, input_height=256, input_width=256).to(device)


# defining criterion, optimizer
criterion = nn.CrossEntropyLoss(weight=CLASS_WEIGHT).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# training the model
model, train_val_df = train(
    model, 
    dataloader_dict, 
    criterion, 
    optimizer, 
    device, 
    checkpoint_path='/Users/simpleai/Desktop/az_task/results/task2/vgg16_binary_classifier.pth', 
    train_val_info_path='/Users/simpleai/Desktop/az_task/results/task2/vgg16_binary_classifier_info.csv',
    num_epochs=NUM_EPOCHS,
    vgg=True
    )


Epoch 1/40
__________


100%|█████████████████████████████████████████| 480/480 [06:16<00:00,  1.28it/s]


epoch 1 train >>> loss: 0.264, accuracy: 0.620


100%|█████████████████████████████████████████| 160/160 [01:04<00:00,  2.49it/s]


epoch 1 val >>> loss: 0.165, accuracy: 0.786
saving the best model in /Users/simpleai/Desktop/az_task/results/task2/vgg16_binary_classifier.pth
Epoch 2/40
__________


100%|█████████████████████████████████████████| 480/480 [06:07<00:00,  1.31it/s]


epoch 2 train >>> loss: 0.173, accuracy: 0.677


100%|█████████████████████████████████████████| 160/160 [01:03<00:00,  2.53it/s]


epoch 2 val >>> loss: 0.151, accuracy: 0.791
saving the best model in /Users/simpleai/Desktop/az_task/results/task2/vgg16_binary_classifier.pth
Epoch 3/40
__________


100%|█████████████████████████████████████████| 480/480 [06:12<00:00,  1.29it/s]


epoch 3 train >>> loss: 0.165, accuracy: 0.692


100%|█████████████████████████████████████████| 160/160 [01:03<00:00,  2.54it/s]


epoch 3 val >>> loss: 0.158, accuracy: 0.705
Epoch 4/40
__________


100%|█████████████████████████████████████████| 480/480 [06:00<00:00,  1.33it/s]


epoch 4 train >>> loss: 0.162, accuracy: 0.711


100%|█████████████████████████████████████████| 160/160 [01:03<00:00,  2.52it/s]


epoch 4 val >>> loss: 0.153, accuracy: 0.791
Epoch 5/40
__________


100%|█████████████████████████████████████████| 480/480 [06:11<00:00,  1.29it/s]


epoch 5 train >>> loss: 0.158, accuracy: 0.732


100%|█████████████████████████████████████████| 160/160 [01:03<00:00,  2.52it/s]


epoch 5 val >>> loss: 0.154, accuracy: 0.631
Epoch 6/40
__________


 31%|████████████▉                            | 151/480 [01:51<04:03,  1.35it/s]


KeyboardInterrupt: 

### Training the pretrained ResNet50 model

In [30]:
# Using pretrained weights:
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

# replace the 1000-class ImageNet head with a NUM_CLASSES-way linear layer
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)  # in_features is 2048 for ResNet50
model = model.to(device)

# defining criterion, optimizer
criterion = nn.CrossEntropyLoss(weight=CLASS_WEIGHT).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# training the model
model, train_val_df = train(
    model, 
    dataloader_dict, 
    criterion, 
    optimizer, 
    device, 
    checkpoint_path='/Users/simpleai/Desktop/az_task/results/task2/resnet50_binary_classifier.pth', 
    train_val_info_path='/Users/simpleai/Desktop/az_task/results/task2/resnet50_binary_classifier_info.csv',
    num_epochs=NUM_EPOCHS
    )

Epoch 1/40
__________


100%|█████████████████████████████████████████| 480/480 [03:47<00:00,  2.11it/s]


epoch 1 train >>> loss: 0.132, accuracy: 0.778


100%|█████████████████████████████████████████| 160/160 [00:56<00:00,  2.85it/s]


epoch 1 val >>> loss: 0.086, accuracy: 0.842
saving the best model in /Users/simpleai/Desktop/az_task/results/task2/resnet50_binary_classifier.pth
Epoch 2/40
__________


100%|█████████████████████████████████████████| 480/480 [03:51<00:00,  2.07it/s]


epoch 2 train >>> loss: 0.101, accuracy: 0.840


100%|█████████████████████████████████████████| 160/160 [00:55<00:00,  2.86it/s]


epoch 2 val >>> loss: 0.074, accuracy: 0.864
saving the best model in /Users/simpleai/Desktop/az_task/results/task2/resnet50_binary_classifier.pth
Epoch 3/40
__________


100%|█████████████████████████████████████████| 480/480 [03:48<00:00,  2.10it/s]


epoch 3 train >>> loss: 0.086, accuracy: 0.868


100%|█████████████████████████████████████████| 160/160 [00:57<00:00,  2.80it/s]


epoch 3 val >>> loss: 0.120, accuracy: 0.761
Epoch 4/40
__________


100%|█████████████████████████████████████████| 480/480 [03:49<00:00,  2.09it/s]


epoch 4 train >>> loss: 0.077, accuracy: 0.880


100%|█████████████████████████████████████████| 160/160 [00:56<00:00,  2.85it/s]


epoch 4 val >>> loss: 0.081, accuracy: 0.853
Epoch 5/40
__________


100%|█████████████████████████████████████████| 480/480 [03:45<00:00,  2.13it/s]


epoch 5 train >>> loss: 0.068, accuracy: 0.901


100%|█████████████████████████████████████████| 160/160 [00:56<00:00,  2.84it/s]


epoch 5 val >>> loss: 0.086, accuracy: 0.858
Epoch 6/40
__________


100%|█████████████████████████████████████████| 480/480 [03:46<00:00,  2.12it/s]


epoch 6 train >>> loss: 0.055, accuracy: 0.928


100%|█████████████████████████████████████████| 160/160 [00:56<00:00,  2.84it/s]


epoch 6 val >>> loss: 0.088, accuracy: 0.850
Epoch 7/40
__________


100%|█████████████████████████████████████████| 480/480 [03:49<00:00,  2.09it/s]


epoch 7 train >>> loss: 0.042, accuracy: 0.940


100%|█████████████████████████████████████████| 160/160 [00:57<00:00,  2.78it/s]


epoch 7 val >>> loss: 0.109, accuracy: 0.838
Epoch 8/40
__________


100%|█████████████████████████████████████████| 480/480 [03:47<00:00,  2.11it/s]


epoch 8 train >>> loss: 0.039, accuracy: 0.939


100%|█████████████████████████████████████████| 160/160 [00:56<00:00,  2.85it/s]


epoch 8 val >>> loss: 0.127, accuracy: 0.855
Epoch 9/40
__________


100%|█████████████████████████████████████████| 480/480 [03:49<00:00,  2.09it/s]


epoch 9 train >>> loss: 0.027, accuracy: 0.966


100%|█████████████████████████████████████████| 160/160 [00:56<00:00,  2.85it/s]


epoch 9 val >>> loss: 0.108, accuracy: 0.847
Epoch 10/40
__________


100%|█████████████████████████████████████████| 480/480 [03:47<00:00,  2.11it/s]


epoch 10 train >>> loss: 0.036, accuracy: 0.951


100%|█████████████████████████████████████████| 160/160 [00:55<00:00,  2.86it/s]


epoch 10 val >>> loss: 0.155, accuracy: 0.794
Epoch 11/40
__________


100%|█████████████████████████████████████████| 480/480 [03:48<00:00,  2.10it/s]


epoch 11 train >>> loss: 0.027, accuracy: 0.964


100%|█████████████████████████████████████████| 160/160 [00:57<00:00,  2.76it/s]


epoch 11 val >>> loss: 0.180, accuracy: 0.811
Epoch 12/40
__________


100%|█████████████████████████████████████████| 480/480 [03:49<00:00,  2.09it/s]


epoch 12 train >>> loss: 0.013, accuracy: 0.982


100%|█████████████████████████████████████████| 160/160 [00:55<00:00,  2.88it/s]


epoch 12 val >>> loss: 0.143, accuracy: 0.816
Epoch 13/40
__________


100%|█████████████████████████████████████████| 480/480 [03:46<00:00,  2.12it/s]


epoch 13 train >>> loss: 0.017, accuracy: 0.978


100%|█████████████████████████████████████████| 160/160 [00:56<00:00,  2.82it/s]


epoch 13 val >>> loss: 0.206, accuracy: 0.853
Epoch 14/40
__________


100%|█████████████████████████████████████████| 480/480 [03:48<00:00,  2.10it/s]


epoch 14 train >>> loss: 0.025, accuracy: 0.965


100%|█████████████████████████████████████████| 160/160 [00:56<00:00,  2.83it/s]


epoch 14 val >>> loss: 0.141, accuracy: 0.852
Epoch 15/40
__________


100%|█████████████████████████████████████████| 480/480 [03:46<00:00,  2.12it/s]


epoch 15 train >>> loss: 0.009, accuracy: 0.986


100%|█████████████████████████████████████████| 160/160 [00:56<00:00,  2.82it/s]


epoch 15 val >>> loss: 0.137, accuracy: 0.855
Epoch 16/40
__________


100%|█████████████████████████████████████████| 480/480 [04:11<00:00,  1.91it/s]


epoch 16 train >>> loss: 0.012, accuracy: 0.985


100%|█████████████████████████████████████████| 160/160 [01:04<00:00,  2.47it/s]


epoch 16 val >>> loss: 0.145, accuracy: 0.852
Epoch 17/40
__________


100%|█████████████████████████████████████████| 480/480 [04:17<00:00,  1.87it/s]


epoch 17 train >>> loss: 0.024, accuracy: 0.967


100%|█████████████████████████████████████████| 160/160 [01:05<00:00,  2.44it/s]


epoch 17 val >>> loss: 0.129, accuracy: 0.847
Epoch 18/40
__________


100%|█████████████████████████████████████████| 480/480 [04:18<00:00,  1.86it/s]


epoch 18 train >>> loss: 0.013, accuracy: 0.984


100%|█████████████████████████████████████████| 160/160 [01:05<00:00,  2.45it/s]


epoch 18 val >>> loss: 0.115, accuracy: 0.859
Epoch 19/40
__________


100%|█████████████████████████████████████████| 480/480 [04:20<00:00,  1.84it/s]


epoch 19 train >>> loss: 0.003, accuracy: 0.996


100%|█████████████████████████████████████████| 160/160 [01:05<00:00,  2.44it/s]


epoch 19 val >>> loss: 0.120, accuracy: 0.887
Epoch 20/40
__________


100%|█████████████████████████████████████████| 480/480 [04:16<00:00,  1.87it/s]


epoch 20 train >>> loss: 0.002, accuracy: 0.998


100%|█████████████████████████████████████████| 160/160 [01:04<00:00,  2.46it/s]


epoch 20 val >>> loss: 0.146, accuracy: 0.872
Epoch 21/40
__________


100%|█████████████████████████████████████████| 480/480 [04:17<00:00,  1.86it/s]


epoch 21 train >>> loss: 0.000, accuracy: 0.999


100%|█████████████████████████████████████████| 160/160 [01:05<00:00,  2.45it/s]


epoch 21 val >>> loss: 0.149, accuracy: 0.870
Epoch 22/40
__________


100%|█████████████████████████████████████████| 480/480 [04:20<00:00,  1.84it/s]


epoch 22 train >>> loss: 0.000, accuracy: 1.000


100%|█████████████████████████████████████████| 160/160 [01:06<00:00,  2.41it/s]


epoch 22 val >>> loss: 0.170, accuracy: 0.877
Epoch 23/40
__________


100%|█████████████████████████████████████████| 480/480 [04:19<00:00,  1.85it/s]


epoch 23 train >>> loss: 0.036, accuracy: 0.952


100%|█████████████████████████████████████████| 160/160 [01:05<00:00,  2.43it/s]


epoch 23 val >>> loss: 0.121, accuracy: 0.848
Epoch 24/40
__________


100%|█████████████████████████████████████████| 480/480 [04:20<00:00,  1.85it/s]


epoch 24 train >>> loss: 0.024, accuracy: 0.972


100%|█████████████████████████████████████████| 160/160 [01:06<00:00,  2.41it/s]


epoch 24 val >>> loss: 0.092, accuracy: 0.872
Epoch 25/40
__________


100%|█████████████████████████████████████████| 480/480 [04:22<00:00,  1.83it/s]


epoch 25 train >>> loss: 0.013, accuracy: 0.986


100%|█████████████████████████████████████████| 160/160 [01:05<00:00,  2.42it/s]


epoch 25 val >>> loss: 0.129, accuracy: 0.878
Epoch 26/40
__________


100%|█████████████████████████████████████████| 480/480 [04:23<00:00,  1.83it/s]


epoch 26 train >>> loss: 0.006, accuracy: 0.992


100%|█████████████████████████████████████████| 160/160 [01:06<00:00,  2.40it/s]


epoch 26 val >>> loss: 0.152, accuracy: 0.830
Epoch 27/40
__________


100%|█████████████████████████████████████████| 480/480 [04:18<00:00,  1.85it/s]


epoch 27 train >>> loss: 0.003, accuracy: 0.995


100%|█████████████████████████████████████████| 160/160 [01:04<00:00,  2.48it/s]


epoch 27 val >>> loss: 0.144, accuracy: 0.858
Epoch 28/40
__________


100%|█████████████████████████████████████████| 480/480 [04:11<00:00,  1.91it/s]


epoch 28 train >>> loss: 0.021, accuracy: 0.970


100%|█████████████████████████████████████████| 160/160 [01:04<00:00,  2.47it/s]


epoch 28 val >>> loss: 0.112, accuracy: 0.853
Epoch 29/40
__________


100%|█████████████████████████████████████████| 480/480 [04:11<00:00,  1.91it/s]


epoch 29 train >>> loss: 0.008, accuracy: 0.989


100%|█████████████████████████████████████████| 160/160 [01:05<00:00,  2.46it/s]


epoch 29 val >>> loss: 0.150, accuracy: 0.859
Epoch 30/40
__________


100%|█████████████████████████████████████████| 480/480 [04:09<00:00,  1.92it/s]


epoch 30 train >>> loss: 0.011, accuracy: 0.989


100%|█████████████████████████████████████████| 160/160 [01:04<00:00,  2.47it/s]


epoch 30 val >>> loss: 0.133, accuracy: 0.864
Epoch 31/40
__________


100%|█████████████████████████████████████████| 480/480 [04:11<00:00,  1.91it/s]


epoch 31 train >>> loss: 0.003, accuracy: 0.997


100%|█████████████████████████████████████████| 160/160 [01:05<00:00,  2.46it/s]


epoch 31 val >>> loss: 0.128, accuracy: 0.878
Epoch 32/40
__________


100%|█████████████████████████████████████████| 480/480 [04:10<00:00,  1.91it/s]


epoch 32 train >>> loss: 0.001, accuracy: 1.000


100%|█████████████████████████████████████████| 160/160 [01:04<00:00,  2.46it/s]


epoch 32 val >>> loss: 0.159, accuracy: 0.872
Epoch 33/40
__________


100%|█████████████████████████████████████████| 480/480 [04:11<00:00,  1.91it/s]


epoch 33 train >>> loss: 0.001, accuracy: 1.000


100%|█████████████████████████████████████████| 160/160 [01:04<00:00,  2.46it/s]


epoch 33 val >>> loss: 0.151, accuracy: 0.858
Epoch 34/40
__________


100%|█████████████████████████████████████████| 480/480 [04:13<00:00,  1.90it/s]


epoch 34 train >>> loss: 0.030, accuracy: 0.961


100%|█████████████████████████████████████████| 160/160 [01:05<00:00,  2.44it/s]


epoch 34 val >>> loss: 0.138, accuracy: 0.839
Epoch 35/40
__________


100%|█████████████████████████████████████████| 480/480 [04:13<00:00,  1.89it/s]


epoch 35 train >>> loss: 0.011, accuracy: 0.989


100%|█████████████████████████████████████████| 160/160 [01:05<00:00,  2.43it/s]


epoch 35 val >>> loss: 0.123, accuracy: 0.866
Epoch 36/40
__________


100%|█████████████████████████████████████████| 480/480 [04:12<00:00,  1.90it/s]


epoch 36 train >>> loss: 0.009, accuracy: 0.991


100%|█████████████████████████████████████████| 160/160 [01:05<00:00,  2.44it/s]


epoch 36 val >>> loss: 0.126, accuracy: 0.863
Epoch 37/40
__________


100%|█████████████████████████████████████████| 480/480 [04:11<00:00,  1.91it/s]


epoch 37 train >>> loss: 0.004, accuracy: 0.996


100%|█████████████████████████████████████████| 160/160 [01:05<00:00,  2.46it/s]


epoch 37 val >>> loss: 0.129, accuracy: 0.853
Epoch 38/40
__________


100%|█████████████████████████████████████████| 480/480 [04:12<00:00,  1.90it/s]


epoch 38 train >>> loss: 0.009, accuracy: 0.988


100%|█████████████████████████████████████████| 160/160 [01:05<00:00,  2.43it/s]


epoch 38 val >>> loss: 0.136, accuracy: 0.850
Epoch 39/40
__________


100%|█████████████████████████████████████████| 480/480 [04:11<00:00,  1.91it/s]


epoch 39 train >>> loss: 0.009, accuracy: 0.989


100%|█████████████████████████████████████████| 160/160 [01:05<00:00,  2.46it/s]


epoch 39 val >>> loss: 0.151, accuracy: 0.825
Epoch 40/40
__________


100%|█████████████████████████████████████████| 480/480 [04:12<00:00,  1.90it/s]


epoch 40 train >>> loss: 0.001, accuracy: 0.999


100%|█████████████████████████████████████████| 160/160 [01:05<00:00,  2.45it/s]

epoch 40 val >>> loss: 0.177, accuracy: 0.836
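The ResNet50 log is consistent with `train()` keeping the checkpoint with the lowest validation loss rather than the highest validation accuracy: only epochs 1 and 2 print "saving the best model", even though validation accuracy later peaks at 0.887 (epoch 19). Meanwhile training accuracy climbs to 1.000 while validation loss stays above its epoch-2 value of 0.074, which suggests overfitting. The sketch below is a hypothetical reconstruction of that loop pattern, not the code in `train.py`: it returns the per-epoch history as a DataFrame instead of writing it to `train_val_info_path`, does not model the `vgg=True` flag, and adds an optional `patience` for early stopping, a technique the original `train()` does not use.

```python
import copy

import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_sketch(model, dataloader_dict, criterion, optimizer, device, num_epochs, patience=None):
    """Per epoch: a 'train' phase (forward + backward + step), then a 'val' phase (forward only).
    Keeps the weights with the lowest validation loss; patience=None disables early stopping."""
    best_loss, best_state, stale_epochs = float("inf"), copy.deepcopy(model.state_dict()), 0
    history = []
    for epoch in range(1, num_epochs + 1):
        for phase in ("train", "val"):
            is_train = phase == "train"
            model.train(is_train)  # dropout / batch-norm in training or evaluation mode
            total_loss, correct, seen = 0.0, 0, 0
            for inputs, labels in dataloader_dict[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                with torch.set_grad_enabled(is_train):  # no autograd graph during validation
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                total_loss += loss.item() * inputs.size(0)
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                seen += inputs.size(0)
            history.append({"epoch": epoch, "phase": phase,
                            "loss": total_loss / seen, "accuracy": correct / seen})
        val_loss = history[-1]["loss"]
        if val_loss < best_loss:
            best_loss, best_state, stale_epochs = val_loss, copy.deepcopy(model.state_dict()), 0
            # the original train() also saves a checkpoint to checkpoint_path at this point
        else:
            stale_epochs += 1
            if patience is not None and stale_epochs >= patience:
                break  # early stop: no val-loss improvement for `patience` epochs
    model.load_state_dict(best_state)
    return model, pd.DataFrame(history)


# Usage on toy data: a linear classifier on 2-D points
torch.manual_seed(0)
x = torch.randn(64, 2)
y = (x[:, 0] > 0).long()  # class 1 when the first coordinate is positive
loaders = {
    "train": DataLoader(TensorDataset(x[:48], y[:48]), batch_size=8, shuffle=True),
    "val": DataLoader(TensorDataset(x[48:], y[48:]), batch_size=8),
}
crit = nn.CrossEntropyLoss()
net = nn.Linear(2, 2)
net, hist = train_sketch(net, loaders, crit, torch.optim.SGD(net.parameters(), lr=0.1), "cpu", num_epochs=5)
print(hist[hist.phase == "val"])
```

With `patience=5`, this rule would have stopped the ResNet50 run above after epoch 7, since none of epochs 3 to 7 (nor any later epoch) beats the epoch-2 validation loss of 0.074.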



