# Check resource availability and import relevant packages

In [None]:
# Check GPU availability by running the `nvidia-smi` command-line tool in a
# shell (the leading `!` is IPython's shell escape, so this only works inside
# a notebook/IPython session).
!nvidia-smi

In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision as tv
from torchvision import datasets, models, transforms

import numpy as np
import matplotlib.pyplot as plt

import pandas as pd
import time
import os
import copy
import requests
import io
import csv

plt.ion()   # interactive mode

import timm 
import tqdm

In [2]:
# Report the device PyTorch will use: the first CUDA GPU when CUDA is
# available, otherwise the CPU.
if torch.cuda.is_available():
    print(torch.device("cuda:0"))
else:
    print(torch.device("cpu"))

cuda:0


In [3]:
# List the timm model names that match the wildcard pattern '*vit*', restricted
# to entries with pretrained weights, and display the list as the cell output.
all_vit_models = timm.list_models('*vit*', pretrained=True)
all_vit_models

['convit_base.fb_in1k',
 'convit_small.fb_in1k',
 'convit_tiny.fb_in1k',
 'crossvit_9_240.in1k',
 'crossvit_9_dagger_240.in1k',
 'crossvit_15_240.in1k',
 'crossvit_15_dagger_240.in1k',
 'crossvit_15_dagger_408.in1k',
 'crossvit_18_240.in1k',
 'crossvit_18_dagger_240.in1k',
 'crossvit_18_dagger_408.in1k',
 'crossvit_base_240.in1k',
 'crossvit_small_240.in1k',
 'crossvit_tiny_240.in1k',
 'davit_base.msft_in1k',
 'davit_small.msft_in1k',
 'davit_tiny.msft_in1k',
 'efficientvit_b0.r224_in1k',
 'efficientvit_b1.r224_in1k',
 'efficientvit_b1.r256_in1k',
 'efficientvit_b1.r288_in1k',
 'efficientvit_b2.r224_in1k',
 'efficientvit_b2.r256_in1k',
 'efficientvit_b2.r288_in1k',
 'efficientvit_b3.r224_in1k',
 'efficientvit_b3.r256_in1k',
 'efficientvit_b3.r288_in1k',
 'efficientvit_l1.r224_in1k',
 'efficientvit_l2.r224_in1k',
 'efficientvit_l2.r256_in1k',
 'efficientvit_l2.r288_in1k',
 'efficientvit_l2.r384_in1k',
 'efficientvit_l3.r224_in1k',
 'efficientvit_l3.r256_in1k',
 'efficientvit_l3.r320_in1

# Create Data Loader

In [4]:
from timm.data import create_dataset, create_loader

In [5]:
# Configuration shared by the dataset, loader and model cells below.
num_classes = 15                       # one output per entry in class_map
batch_size = 32                        # samples per mini-batch
img_size = 224                         # side length of the square model input
input_size = (3, img_size, img_size)   # passed to create_loader as input_size

DEFAULT_CROP_PCT = 1                   # passed to the validation loader as crop_pct
interpolation = 'bicubic'              # passed to both loaders as interpolation

In [6]:
# Image folders for the two splits, relative to the notebook's working directory.
train_dir, val_dir = '../Dataset/images/train', '../Dataset/images/validation'

In [7]:
# Label name -> integer class index. The index is the position of the name in
# the tuple below, so 'Normal' maps to 0 and 'Hernia' maps to 14.
class_map = {
    label: index
    for index, label in enumerate((
        'Normal',
        'Atelectasis',
        'Cardiomegaly',
        'Effusion',
        'Infiltration',
        'Mass',
        'Nodule',
        'Pneumonia',
        'Pneumothorax',
        'Consolidation',
        'Edema',
        'Emphysema',
        'Fibrosis',
        'Pleural_Thickening',
        'Hernia',
    ))
}

In [8]:
# Build the training and validation datasets from their image folders, using
# the explicit label -> index mapping in class_map, then report their sizes.
train_dataset = create_dataset(
    name='',
    root=train_dir,
    split='train',
    is_training=True,
    batch_size=batch_size,
    class_map=class_map,
)
val_dataset = create_dataset(
    name='',
    root=val_dir,
    split='validation',
    is_training=False,
    batch_size=batch_size,
    class_map=class_map,
)
train_len = len(train_dataset)
val_len = len(val_dataset)
print(f'Training set size: {train_len}')
print(f'Validation set size: {val_len}')

Training set size: 91295
Validation set size: 13175


In [9]:
# Resize images to the 224x224 input of the pretrained model, convert them to
# tensors and normalise each channel with the mean/std values listed below.
# NOTE(review): the loader cell passes train_dataset to timm's create_loader,
# which appears to install its own transform on the dataset -- if so, the
# assignment at the end of this cell is overwritten and has no effect.
# TODO confirm against the installed timm version.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset.transform = train_transform

In [10]:
# Wrap both datasets in timm data loaders. The two calls share input size,
# batch size and interpolation; the training loader additionally requests
# 4 workers and the validation loader an explicit crop_pct.
loader_train = create_loader(
    train_dataset,
    is_training=True,
    batch_size=batch_size,
    input_size=input_size,
    num_workers=4,
    interpolation=interpolation,
)

loader_val = create_loader(
    val_dataset,
    is_training=False,
    batch_size=batch_size,
    input_size=input_size,
    crop_pct=DEFAULT_CROP_PCT,
    interpolation=interpolation,
)

In [11]:
# Sanity check: display the label -> index mapping the dataset reader actually
# holds, to compare it by eye against the class_map defined earlier.
train_dataset.reader.class_to_idx

{'Normal': 0,
 'Atelectasis': 1,
 'Cardiomegaly': 2,
 'Effusion': 3,
 'Infiltration': 4,
 'Mass': 5,
 'Nodule': 6,
 'Pneumonia': 7,
 'Pneumothorax': 8,
 'Consolidation': 9,
 'Edema': 10,
 'Emphysema': 11,
 'Fibrosis': 12,
 'Pleural_Thickening': 13,
 'Hernia': 14}

In [12]:
# Count how many training images fall into each class index. Compare the
# totals with the expected dataset statistics to confirm the images and labels
# were loaded properly.
# The counters are created from class_map itself rather than a hard-coded 15,
# so every class index gets a counter even if the label set changes.
class_images_num = dict.fromkeys(class_map.values(), 0)
for i in range(len(train_dataset.reader)):
    # each reader entry is unpacked as (sample, class index); only the index is needed
    _, class_idx = train_dataset.reader[i]
    class_images_num[class_idx] += 1

class_images_num

{0: 44379,
 1: 7250,
 2: 1505,
 3: 7475,
 4: 11958,
 5: 3471,
 6: 4067,
 7: 761,
 8: 2320,
 9: 2485,
 10: 1225,
 11: 1236,
 12: 1078,
 13: 1954,
 14: 131}

# Model

In [12]:
# Name of the pretrained timm model to fine-tune; it is passed to
# timm.create_model and also embedded in every output file name of this run.
model_name = 'vit_base_r50_s16_224.orig_in21k'

In [None]:
#model_name = 'vit_base_patch16_224.orig_in21k'

In [None]:
#model_name = 'resnet50.a1_in1k'

In [13]:
# Instantiate the pretrained network with a classification head sized for
# num_classes, pick the compute device, and move the model onto it.
model = timm.create_model(
    model_name,
    pretrained=True,
    num_classes=num_classes,
)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)  # 'cuda' means the GPU is used
model.to(device)

cuda


VisionTransformer(
  (patch_embed): HybridEmbed(
    (backbone): ResNetV2(
      (stem): Sequential(
        (conv): StdConv2dSame(3, 64, kernel_size=(7, 7), stride=(2, 2), bias=False)
        (norm): GroupNormAct(
          32, 64, eps=1e-05, affine=True
          (drop): Identity()
          (act): ReLU(inplace=True)
        )
        (pool): MaxPool2dSame(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0), dilation=(1, 1), ceil_mode=False)
      )
      (stages): Sequential(
        (0): ResNetStage(
          (blocks): Sequential(
            (0): Bottleneck(
              (downsample): DownsampleConv(
                (conv): StdConv2dSame(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (norm): GroupNormAct(
                  32, 256, eps=1e-05, affine=True
                  (drop): Identity()
                  (act): Identity()
                )
              )
              (conv1): StdConv2dSame(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      

In [14]:
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    """Split the trainable parameters of `model` into two optimizer groups.

    A parameter is exempt from weight decay when it is one-dimensional, when
    its name ends in ".bias", or when its name is listed in `skip_list`.
    Parameters with `requires_grad == False` are left out entirely.

    Args:
        model: module whose `named_parameters()` are partitioned.
        weight_decay: decay value assigned to the non-exempt group.
        skip_list: container of parameter names that must not be decayed.

    Returns:
        A list of two parameter-group dicts: first the exempt parameters with
        'weight_decay' 0., then the remaining ones with `weight_decay`.
    """
    without_decay, with_decay = [], []
    trainable = (
        (name, param)
        for name, param in model.named_parameters()
        if param.requires_grad  # frozen weights are skipped
    )
    for name, param in trainable:
        exempt = param.ndim == 1 or name.endswith(".bias") or name in skip_list
        (without_decay if exempt else with_decay).append(param)
    return [
        {'params': without_decay, 'weight_decay': 0.},
        {'params': with_decay, 'weight_decay': weight_decay},
    ]

In [15]:
# Build the optimizer parameter groups. Names the model reports through
# no_weight_decay() (when it has that method) are exempted from decay.
skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else {}
parameters = add_weight_decay(model, 0.0001, skip)
# Each group above carries its own 'weight_decay', so the optimizer-level
# value handed to SGD is zero.
weight_decay = 0.

In [16]:
# Loss function and optimizer. The loss is moved to `device` (the same device
# the model was moved to) instead of calling .cuda() unconditionally, so this
# cell also runs on a CPU-only machine.
criterion = torch.nn.CrossEntropyLoss().to(device)
# SGD with Nesterov momentum over the per-group weight-decay settings built in
# `parameters`; `weight_decay` is only the optimizer-level default.
optimizer = torch.optim.SGD(parameters, momentum=0.9, nesterov=True, lr=0.01, weight_decay=weight_decay)

In [17]:
from timm.scheduler import StepLRScheduler

# Learning-rate schedule for the optimizer; the training loop advances it once
# per epoch via lr_scheduler.step(epoch + 1).
lr_scheduler = StepLRScheduler(
    optimizer,
    decay_t=30,             # step-decay period
    decay_rate=0.1,         # multiplicative factor applied at each decay step
    warmup_t=3,             # length of the warm-up phase
    warmup_lr_init=0.0001,  # learning rate at the start of warm-up
    noise_range_t=None,     # presumably disables LR noise, leaving the noise_* values unused -- TODO confirm
    noise_pct=0.67,
    noise_std=1.,
    noise_seed=42,
)

In [18]:
def eval_fn(model, eval_data):
    """Return the top-1 accuracy of `model` over every batch in `eval_data`.

    Args:
        model: a torch.nn.Module that maps a batch of inputs to per-class
            scores laid out as (batch, num_classes).
        eval_data: an iterable of (images, labels) batches, e.g. a DataLoader.

    Returns:
        The fraction of samples whose highest-scoring class equals the label,
        as a float in [0, 1]. Returns 0.0 when `eval_data` yields no samples.

    Note:
        The model is switched to eval mode and left there; callers that go on
        training must call model.train() afterwards.
    """
    model.eval()

    # Evaluate on whatever device the model lives on, so the function works
    # whether or not the loader already delivers tensors on that device.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        # Iterate over ALL batches; stopping after the first one would report
        # the accuracy of a single batch as the validation accuracy.
        for images, labels in eval_data:
            if target_device is not None:
                images = images.to(target_device)
                labels = labels.to(target_device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total_samples += labels.size(0)
            total_correct += (predicted == labels).sum().item()

    if total_samples == 0:
        return 0.0  # nothing to evaluate; avoid dividing by zero
    return total_correct / total_samples

In [19]:
# Number of passes over the training set.
num_epochs = 10
# Running measurements. `losses` and `accus_train` start with one empty list
# that collects the per-batch values of the first epoch; `accus_val` receives
# one value per epoch.
losses, accus_train, accus_val = [[]], [[]], []

In [20]:
from datetime import datetime

# Timestamp (YYYYMMDD_HHMMSS) used to tag every file written by this run.
# strftime is used rather than slicing str(datetime): the string form omits
# the fractional part when microsecond == 0, so cutting a fixed 7 characters
# off the end would then remove part of the time itself.
current_datetime = datetime.now()
date_time = current_datetime.strftime('%Y%m%d_%H%M%S')

In [21]:
# Destination of the final fine-tuned model, tagged with the model name and
# this run's timestamp; shown as the cell output for reference.
model_save_path = '/'.join([
    'model_result',
    'model_pth',
    f'MODEL_FINETUNE_{model_name}_{date_time}.pth',
])
model_save_path

'model_result/model_pth/MODEL_FINETUNE_vit_base_r50_s16_224.orig_in21k_20240330_125921.pth'

In [22]:
def output_log_writer(s, end='\r'):
    """Append `s` to this run's log file and echo it to stdout.

    Args:
        s: message text. It is always written to the log file followed by a
            newline, regardless of `end`.
        end: terminator for the console echo only; it is forwarded to print's
            `end` keyword. The default '\\r' moves the cursor back to the start
            of the line; pass end='' to keep printing on the same line or
            end='\\n' for an ordinary line.
    """
    # create the log directory on first use so a fresh checkout does not fail
    os.makedirs('model_result/log', exist_ok=True)
    with open(f'model_result/log/output_{model_name}_{date_time}.txt', 'a') as output_file:
        output_file.write(s+'\n')
    # `end` must be passed by keyword; given positionally, print would treat
    # it as a second value to print instead of using it as the terminator.
    print(s, end=end, flush=True)

In [None]:
# Fine-tuning loop: train for num_epochs epochs and, after each epoch, record
# the losses and accuracies, evaluate on the validation loader and save a
# checkpoint. The final model is saved once all epochs are done.

# Make sure every output directory exists before the first write.
for out_dir in ('model_result/log',
                'model_result/measurement',
                'model_result/model_checkpoint',
                os.path.dirname(model_save_path) or '.'):
    os.makedirs(out_dir, exist_ok=True)


def _append_csv_row(path, row):
    """Append one row to the CSV file at `path`, creating the file if needed."""
    # Append mode keeps the rows of earlier epochs; opening with 'w' would
    # leave only the last epoch in the file.
    with open(path, 'a', newline='') as csv_file:
        csv.writer(csv_file).writerow(row)


for epoch in range(num_epochs):
    output_log_writer(f'-------------------------------[Epoch {epoch+1}]---------------------------------')
    output_log_writer(f'[Epoch {epoch+1}] Training...', end='')
    model.train()  # make sure the model is in training mode, also on a re-run of this cell
    samples_seen = 0
    for images, labels in loader_train:
        print('=', end='')  # one progress mark per batch

        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()  # clear gradients left over from the previous step
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        losses[-1].append(loss.item()) # all losses for this epoch

        with torch.no_grad():
            # number of correct predictions in this batch, stored as a plain int
            accus_train[-1].append((outputs.argmax(dim=1) == labels).sum().item())
        samples_seen += labels.size(0)


    print('\r')
    output_log_writer(f'[Epoch {epoch+1}] Computing Train Measurement...')
    # save all batches loss
    _append_csv_row(f'model_result/measurement/batch_loss_{model_name}_{date_time}.csv',
                    [epoch] + losses[-1])
    # compute total loss after this epoch
    losses[-1] = sum(losses[-1])
    # save epoch loss
    _append_csv_row(f'model_result/measurement/epoch_loss_{model_name}_{date_time}.csv',
                    [epoch, losses[-1]])
    losses.append([])

    # compute average accuracy after this epoch; divide by the number of
    # samples actually seen, because the loader may not yield every sample of
    # the dataset (e.g. when the last partial batch is dropped)
    accus_train[-1] = sum(accus_train[-1]) / max(samples_seen, 1)
    _append_csv_row(f'model_result/measurement/acc_train_{model_name}_{date_time}.csv',
                    [epoch, float(accus_train[-1])])
    accus_train.append([])

    # step LR for next epoch
    lr_scheduler.step(epoch + 1)

    output_log_writer(f'[Epoch {epoch+1}] Computing Validation Measurement...')
    accus_val.append(eval_fn(model, loader_val))
    _append_csv_row(f'model_result/measurement/acc_val_{model_name}_{date_time}.csv',
                    [epoch, accus_val[-1]])

    model.train()  # eval_fn leaves the model in eval mode; switch back

    # print evaluation summary; end='\n' keeps this line visible in the console
    output_log_writer(f'[Epoch {epoch+1}] loss={losses[-2]:.2e} train accu={accus_train[-2]:.2%} validation accu={accus_val[-1]:.2%}', end='\n')

    # save model checkpoint (zero-padded epoch so the file name has no space)
    ckpt_save_path = f'model_result/model_checkpoint/MODEL_CKPT_{epoch:02d}_{model_name}_{date_time}.pt'
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': losses[-2],
            }, ckpt_save_path)
    output_log_writer(f'[Epoch {epoch+1}] Model checkpoint is saved.')


# NOTE(review): this pickles the whole model object, so loading it later needs
# the same model class definitions to be importable; the per-epoch checkpoints
# above store state_dicts instead.
torch.save(model, model_save_path)
output_log_writer('\n\nFinal model saved. Training Finished.')