In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import sys
sys.path.append('../models')
from vgg import VGG16_model

import os
import re
import time
from collections import defaultdict

import numpy as np
from imutils import paths

from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt

### Initializing parameters

In [2]:
# Experiment configuration: every value a reader might want to tune lives here.
# TODO: convert these to command-line arguments (argparse) and add the remaining
# planned parameters listed at the bottom of this cell.

DATA_DIR = '../data'
CROSS_VAL_DIR = '../data/cross_validation'
MODEL_SAVE_DIR = '../trained_models'
N_EPOCHS = 5
FOLD = 0  # cross-validation fold held out as the test set
SEED = 42
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the torch and numpy RNGs so random draws (any random weight init,
# shuffling, augmentation) are repeatable between runs. cuDNN kernels may still
# be nondeterministic on GPU.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Planned parameters, not yet wired up:
# MODEL_NAME = args['model_name']
# LR = args['learning_rate']
# BATCH_SIZE = args['batch_size']
# MODEL_ID = args['model_id']  # for a model factory -- TODO implement
# TRAINABLE_BASE_LAYERS = args['trainable_base_layers']
# IMG_WIDTH, IMG_HEIGHT = args['img_width'], args['img_height']
# LOG_SOFTMAX = args['log_softmax']
# HIDDEN_SIZE = args['hidden_size']

In [3]:
# Create the model save dir if it doesn't already exist.
# exist_ok=True avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

### Load the POCOVID dataset, holding out the selected fold (`FOLD`) for testing

In [4]:
import sys
sys.path.append('../scripts')
from pocovid_dataset import PocovidDataset
from torch.utils.data import Dataset, DataLoader

In [5]:
def get_train_test_paths(fold=None, cross_val_dir=None, image_paths=None):
    """Split the cross-validation images into train and test sets for one fold.

    Each image path is expected to look like
    ``.../<fold dir>/<label>/<image file>``, where the fold directory name ends
    in the fold number (presumably ``split0``, ``split1``, ... -- confirm against
    the data layout). Images in the held-out fold go to the test set; all others
    go to the training set.

    Args:
        fold: fold number to hold out for testing. Defaults to the global ``FOLD``.
        cross_val_dir: root directory to scan for images. Defaults to the global
            ``CROSS_VAL_DIR``. Ignored when ``image_paths`` is given.
        image_paths: optional explicit iterable of image paths to split instead
            of scanning ``cross_val_dir``.

    Returns:
        ``(train_path_info, test_path_info)``: two ``defaultdict(list)`` objects,
        each with parallel ``'path_list'`` and ``'label_list'`` lists.
    """
    # Resolve defaults at call time so later changes to the globals are honoured.
    if fold is None:
        fold = FOLD
    if image_paths is None:
        if cross_val_dir is None:
            cross_val_dir = CROSS_VAL_DIR
        image_paths = paths.list_images(cross_val_dir)

    train_path_info = defaultdict(list)
    test_path_info = defaultdict(list)

    for image_path in image_paths:
        path_parts = image_path.split(os.path.sep)
        fold_dir, label = path_parts[-3], path_parts[-2]

        # Compare the whole trailing number, not just the last character, so
        # that e.g. "split10" is not mistaken for fold 0.
        match = re.search(r'\d+$', fold_dir)
        is_test = match is not None and int(match.group()) == int(fold)

        target = test_path_info if is_test else train_path_info
        target['path_list'].append(image_path)
        target['label_list'].append(label)

    return train_path_info, test_path_info

In [6]:
train_path_info, test_path_info = get_train_test_paths()

# Fail fast with a clear message if the fold directory layout is wrong, rather
# than letting the DataLoader fail later on an empty dataset.
assert train_path_info['path_list'], f'no training images found under {CROSS_VAL_DIR}'
assert test_path_info['path_list'], f'no test images found for fold {FOLD} under {CROSS_VAL_DIR}'

In [7]:
# Build the fold-specific train/test datasets and their DataLoaders.
# Training images get a mild random-affine augmentation; test images are only
# resized and converted to tensors.

batch_size = 16
num_workers = 2

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomAffine(10, translate=(0.1, 0.1)),
    transforms.ToTensor(),
])

test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

trainset = PocovidDataset(train_path_info, transform=train_transform)
testset = PocovidDataset(test_path_info, transform=test_transform)

# drop_last=True discards an incomplete final training batch.
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=True,
)

test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

In [8]:
# Class name -> index assigned by the dataset (used to label results).
class_map = dict(
    covid=trainset.get_covid_class_idx(),
    pneumonia=trainset.get_pneu_class_idx(),
    regular=trainset.get_regular_class_idx(),
)

### Model

In [9]:
# Build the network from VGG16_model (imported from ../models/vgg) and move it
# to the configured device.
model = VGG16_model()
model = model.to(device)

### Loss function and Optimizer

In [10]:
# CrossEntropyLoss takes unnormalised class scores plus integer targets.
criterion = nn.CrossEntropyLoss().to(device)

# Adam at lr=1e-4. TODO: experiment with weight_decay.
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Each scheduler.step() multiplies the learning rate by 0.95.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)

### Train and Eval functions

In [11]:
def train(model, optimizer, criterion, iterator):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: the network to train; switched to training mode.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        criterion: loss function called as ``criterion(outputs, labels)``.
        iterator: iterable of batches where ``batch[0]`` is the input tensor and
            ``batch[1]`` the targets (e.g. a ``DataLoader``). Both are moved to
            the global ``device``.

    Returns:
        float: average of ``loss.item()`` over all batches seen.

    Raises:
        ValueError: if ``iterator`` yields no batches.
    """
    model.train()

    total_loss = 0.0
    n_batches = 0

    for batch in iterator:
        inputs, labels = batch[0].to(device), batch[1].to(device)

        # Standard step: clear stale gradients, forward, backprop, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Count batches ourselves so plain iterables (no len()) also work.
        n_batches += 1

    if n_batches == 0:
        raise ValueError('train() received an empty iterator; no batches to train on')
    return total_loss / n_batches

In [12]:
def evaluate(model, criterion, iterator):
    """Compute the mean per-batch loss of ``model`` over ``iterator`` without training.

    Args:
        model: the network to evaluate; switched to eval mode.
        criterion: loss function called as ``criterion(outputs, labels)``.
        iterator: iterable of batches where ``batch[0]`` is the input tensor and
            ``batch[1]`` the targets (e.g. a ``DataLoader``). Both are moved to
            the global ``device``.

    Returns:
        float: average of ``loss.item()`` over all batches seen.

    Raises:
        ValueError: if ``iterator`` yields no batches.
    """
    model.eval()

    total_loss = 0.0
    n_batches = 0

    # No autograd graph is needed for evaluation, which saves memory.
    with torch.no_grad():
        for batch in iterator:
            inputs, labels = batch[0].to(device), batch[1].to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            # Count batches ourselves so plain iterables (no len()) also work.
            n_batches += 1

    if n_batches == 0:
        raise ValueError('evaluate() received an empty iterator; no batches to score')
    return total_loss / n_batches

### Epoch timing helper

In [13]:
def epoch_time(start_time, end_time):
    """Split the interval between two ``time.time()`` readings into minutes and seconds.

    Returns:
        ``(minutes, seconds)`` as ints, both truncated; ``seconds`` is what
        remains after removing the whole minutes.
    """
    elapsed = end_time - start_time
    whole_minutes = int(elapsed / 60)
    leftover_seconds = int(elapsed - 60 * whole_minutes)
    return whole_minutes, leftover_seconds

### Training

In [15]:
# Train for N_EPOCHS epochs, checkpointing the model whenever the validation loss
# reaches a new minimum. After 5 consecutive epochs without improvement, the LR
# scheduler is stepped once and the patience counter restarts.
# NOTE(review): "validation" here is test_loader, i.e. the held-out test fold, so
# checkpoint selection sees the test data -- consider a separate validation split.
best_valid_loss = np.inf
epochs_without_improvement = 0
checkpoint_path = os.path.join(MODEL_SAVE_DIR, 'vgg_trained.pt')

for epoch in range(N_EPOCHS):

    print(f'Epoch: {epoch+1:02}')

    start_time = time.time()
    train_loss = train(model, optimizer, criterion, train_loader)
    valid_loss = evaluate(model, criterion, test_loader)
    epoch_mins, epoch_secs = epoch_time(start_time, time.time())

    if valid_loss < best_valid_loss:
        # New best: save weights and reset patience.
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), checkpoint_path)
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement > 4:
            # Plateau: decay the learning rate and start counting again.
            scheduler.step()
            epochs_without_improvement = 0

    print(f'Time: {epoch_mins}m {epoch_secs}s')
    print(f'Train Loss: {train_loss:.3f}')
    print(f'Val   Loss: {valid_loss:.3f}')
    print('-' * 100)

print(best_valid_loss)

### Testing

In [17]:
# Load the checkpoint written by the training loop (the weights with the lowest
# loss on test_loader) back into `model`.
# When running this section on a fresh kernel without training first, uncomment
# the next line so that `model` exists before loading.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (PyTorch >= 1.13 accepts weights_only=True to restrict what can be loaded).

# model = VGG16_model().to(device)
model.load_state_dict(torch.load(os.path.join(MODEL_SAVE_DIR, 'vgg_trained.pt'), map_location=device))

<All keys matched successfully>

In [1]:
def test(model, iterator):
    
    model.eval()
    
    images = []
    true_labels = []
    pred_labels = []
    
    with torch.no_grad():    
        for i, batch in enumerate(iterator):    
            
            inputs, labels = batch[0].to(device), batch[1].to(device)
            
            outputs = model(inputs)
            
            y_prob = F.softmax(outputs, dim = -1)
            
            top_preds = y_prob.argmax(1, keepdim = True)
            
            images.append(inputs.to(device))
            true_labels.append(labels.to(device))
            pred_labels.append(top_preds.to(device))
            
    images = torch.cat(images, dim=0)
    true_labels = torch.cat(true_labels, dim=0)
    pred_labels = torch.cat(pred_labels, dim=0)
            
    return images, true_labels, pred_labels

In [27]:
images, true_labels, pred_labels = test(model, test_loader)

In [None]:
def calculate_accuracy(y, y_pred):
    """Return the fraction of rows whose highest-scoring class matches the target.

    Args:
        y: tensor of target class indices, one per row of ``y_pred``.
        y_pred: tensor of per-class scores with the class axis on dim 1.

    Returns:
        0-dim float tensor: number of correct predictions divided by ``y.shape[0]``.
    """
    predicted = y_pred.argmax(1, keepdim=True)
    hits = predicted.eq(y.view_as(predicted))
    n_correct = hits.sum().float()
    return n_correct / y.shape[0]

### Plot Confusion Matrix

In [None]:
def plot_confusion_matrix(labels, pred_labels, classes):
    """Plot a confusion matrix of true vs. predicted class indices.

    Args:
        labels: array-like of true class indices in ``0 .. len(classes) - 1``.
        pred_labels: array-like of predicted class indices (1-D or a column vector).
        classes: display names; ``classes[i]`` labels class index ``i``.
    """
    # Pin the label set to 0..len(classes)-1 so the matrix is always
    # len(classes) x len(classes). Without this, sklearn drops a class that
    # appears in neither array, and the display labels no longer line up with
    # the matrix.
    counts = confusion_matrix(labels, pred_labels, labels=np.arange(len(classes)))

    _, ax = plt.subplots(figsize=(10, 10))
    display = ConfusionMatrixDisplay(counts, display_labels=classes)
    display.plot(values_format='d', cmap='Blues', ax=ax)
    ax.set_title('Confusion matrix')
    ax.tick_params(axis='x', labelrotation=20)

In [None]:
# Order the display names by the dataset's own class indices (class_map) rather
# than a hard-coded list, so each row/column gets the right name.
plot_confusion_matrix(
    true_labels.cpu().numpy(),
    pred_labels.cpu().numpy().ravel(),
    sorted(class_map, key=class_map.get),
)