In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

import sys
sys.path.append('../models')
from vgg import VGG16_model

import os
import random
from imutils import paths
from collections import defaultdict 
import numpy as np
import time

from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

### Fixing Random Seeds

In [2]:
SEED = 1234

# Apply the same seed to every RNG used here: Python's `random`, NumPy,
# and PyTorch (default generator plus the CUDA generator).
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(SEED)
del _seed_rng  # keep the loop variable out of the notebook namespace

# Request deterministic cuDNN kernels.
torch.backends.cudnn.deterministic = True

### Initializing parameters

In [3]:
# Input data and output locations (relative paths).
DATA_DIR = '../data'
CROSS_VAL_DIR = '../data/cross_validation'
MODEL_SAVE_DIR = '../trained_models'
MODEL_SAVE_NAME = 'vgg_trained.pt'

# Training configuration.
N_EPOCHS = 5
FOLD = 0  # fold whose images are held out as the test split
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
LR = 1e-4
BATCH_SIZE = 16
IMG_WIDTH, IMG_HEIGHT = 224, 224

# Create the model save dir if it doesn't already exist.
# exist_ok=True avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

### Load the Pocovid Dataset according to the fold passed

In [4]:
import sys
sys.path.append('../scripts')
from pocovid_dataset import PocovidDataset
from torch.utils.data import Dataset, DataLoader

In [9]:
# Mapping used later to turn numeric labels into class names when titling plots.
# NOTE(review): `trainset` is not defined at notebook level in the visible cells
# (it only appears as a local inside Trainer.get_train_test_loaders) — confirm
# where it comes from before running this cell on a fresh kernel.
class_map = trainset.get_class_map()

### The Training Class

In [None]:
class Trainer():
    """
    Trains and validates a classifier on one cross-validation fold.

    The images under `cross_val_dir` are split by fold: the images of the
    selected fold form the test split and all remaining images form the
    training split. The weights with the lowest validation loss are saved to
    `<model_save_dir>/<model_name>_trained.pt`.
    """

    def __init__(self, model_name='vgg16', lr=LR, n_epochs=N_EPOCHS, batch_size=BATCH_SIZE,
                 image_width=IMG_WIDTH, image_height=IMG_HEIGHT, cross_val_dir=CROSS_VAL_DIR,
                 model_save_dir=MODEL_SAVE_DIR, fold=FOLD):
        """
        Args:
            model_name: 'vgg16' or 'resnet50'.
            lr: learning rate for the Adam optimizer.
            n_epochs: number of epochs run by `start_training`.
            batch_size: batch size for both data loaders.
            image_width, image_height: size every image is resized to.
            cross_val_dir: root directory of the cross-validation image folders.
            model_save_dir: directory the best checkpoint is written to.
            fold: fold held out as the test split.

        Raises:
            ValueError: if `model_name` is not one of the supported models.
        """
        if model_name == 'vgg16':
            self.model = VGG16_model().to(device)
        elif model_name == 'resnet50':
            # NOTE(review): RESNET50_model is not imported in the visible cells;
            # it must be in scope before this branch can be used.
            self.model = RESNET50_model().to(device)
        else:
            # Fail immediately instead of printing and crashing later on self.model.
            raise ValueError('Select models from the following:\n 1) vgg16\n 2) resnet50')

        # Keep every constructor argument on the instance; the methods below must
        # not rely on same-named globals.
        self.model_name = model_name
        self.lr = lr
        self.n_epochs = n_epochs
        self.batch_size = batch_size
        self.image_width = image_width
        self.image_height = image_height
        self.cross_val_dir = cross_val_dir
        self.model_save_dir = model_save_dir
        self.fold = fold

        self.criterion = nn.CrossEntropyLoss().to(device)
        self.optimizer = optim.Adam(params=self.model.parameters(), lr=self.lr)  # experiment with weight_decay
        # step_size=1: each scheduler.step() multiplies the learning rate by 0.95.
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, 1, gamma=0.95)

    def get_train_test_info(self):
        """
        Get information dictionaries for train and test data.

        Each image path is expected to end in `<fold_dir>/<label>/<file>`, where
        the last character of `<fold_dir>` is the fold number.

        Returns:
            (train_path_info, test_path_info): two dicts with the keys
            'path_list' and 'label_list'.
        """
        image_paths = list(paths.list_images(self.cross_val_dir))

        train_path_info = defaultdict(list)
        test_path_info = defaultdict(list)
        held_out_fold = str(self.fold)

        for image_path in image_paths:
            path_parts = image_path.split(os.path.sep)
            fold_number = path_parts[-3][-1]
            label = path_parts[-2]
            target = test_path_info if fold_number == held_out_fold else train_path_info
            target['path_list'].append(image_path)
            target['label_list'].append(label)

        return train_path_info, test_path_info

    def get_train_test_loaders(self, num_workers=2):
        """
        Get the train and test data loaders according to the fold.

        Training images get a random affine augmentation; test images are only
        resized and converted to tensors.

        Args:
            num_workers: worker processes used by each DataLoader.

        Returns:
            (train_loader, test_loader)
        """
        train_path_info, test_path_info = self.get_train_test_info()

        # torchvision's Resize expects the size as (height, width).
        target_size = (self.image_height, self.image_width)

        train_transform = transforms.Compose([transforms.Resize(target_size),
                                              transforms.RandomAffine(10, translate=(0.1, 0.1)),
                                              transforms.ToTensor()])

        test_transform = transforms.Compose([transforms.Resize(target_size),
                                             transforms.ToTensor()])

        trainset = PocovidDataset(train_path_info, transform=train_transform)
        testset = PocovidDataset(test_path_info, transform=test_transform)

        train_loader = torch.utils.data.DataLoader(trainset, num_workers=num_workers, shuffle=True,
                                                   batch_size=self.batch_size, drop_last=True)

        # Validation does not need shuffling; a fixed order keeps evaluation repeatable.
        test_loader = torch.utils.data.DataLoader(testset, num_workers=num_workers, shuffle=False,
                                                  batch_size=self.batch_size)

        return train_loader, test_loader

    def train(self, iterator):
        """
        Run one training epoch over `iterator`.

        Returns:
            The mean of the per-batch losses.
        """
        self.model.train()

        epoch_loss = 0

        for batch in iterator:
            inputs, labels = batch[0].to(device), batch[1].to(device)

            self.optimizer.zero_grad()

            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)

            loss.backward()
            self.optimizer.step()

            epoch_loss += loss.item()

        return epoch_loss / len(iterator)

    def evaluate(self, model, criterion, iterator):
        """
        Compute the mean per-batch loss of `model` over `iterator` without
        updating any weights.

        Args:
            model: network to evaluate (put into eval mode).
            criterion: loss function applied to (outputs, labels).
            iterator: data loader yielding (inputs, labels) batches.

        Returns:
            The mean of the per-batch losses.
        """
        model.eval()

        epoch_loss = 0

        with torch.no_grad():
            for batch in iterator:
                inputs, labels = batch[0].to(device), batch[1].to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                epoch_loss += loss.item()

        return epoch_loss / len(iterator)

    def epoch_time(self, start_time, end_time):
        """
        The utility function to measure the time taken by an epoch to run.

        Returns:
            (minutes, seconds) elapsed between the two timestamps, as ints.
        """
        elapsed_time = end_time - start_time
        elapsed_mins = int(elapsed_time / 60)
        elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
        return elapsed_mins, elapsed_secs

    def start_training(self):
        """
        Run the full training loop by calling the train and evaluate functions.

        After every epoch the validation loss is compared with the best one seen
        so far; on improvement the model weights are saved. If the loss has not
        improved for 5 consecutive epochs the scheduler is stepped once.
        """
        best_valid_loss = np.inf
        epochs_without_improvement = 0

        train_loader, test_loader = self.get_train_test_loaders()

        # A custom model_save_dir may not exist yet.
        os.makedirs(self.model_save_dir, exist_ok=True)
        checkpoint_path = os.path.join(self.model_save_dir, '{}_trained.pt'.format(self.model_name))

        for epoch in range(self.n_epochs):

            print(f'Epoch: {epoch+1:02}')

            start_time = time.time()

            train_loss = self.train(train_loader)
            valid_loss = self.evaluate(self.model, self.criterion, test_loader)

            epoch_mins, epoch_secs = self.epoch_time(start_time, time.time())

            epochs_without_improvement += 1
            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(self.model.state_dict(), checkpoint_path)
                epochs_without_improvement = 0

            if epochs_without_improvement > 4:
                # decrease lr if loss does not decrease after 5 steps
                self.scheduler.step()
                epochs_without_improvement = 0

            print(f'Time: {epoch_mins}m {epoch_secs}s')
            print(f'Train Loss: {train_loss:.3f}')
            print(f'Val   Loss: {valid_loss:.3f}')
            print('-'*100)
        print(best_valid_loss)

### The Trained Model Class

In [15]:
class TrainedModel():
    """
    Convenience wrapper around a network whose weights were saved by Trainer.
    """

    def __init__(self, model_name='vgg16'):
        """
        To get the details of the pre-trained model.

        Args:
            model_name: model identifier forwarded to Trainer.
        """
        trainer = Trainer(model_name=model_name)
        # Remember the name: loadModel needs it to build the default checkpoint path.
        self.model_name = model_name
        self.model = trainer.model
        self.model_save_dir = trainer.model_save_dir

    def loadModel(self, modelPath=None):
        """
        To load the pre-trained model.

        Args:
            modelPath: path of the checkpoint file to load. When omitted, the
                file `<model_save_dir>/<model_name>_trained.pt` is used.

        Returns:
            The model with the loaded weights.
        """
        if modelPath is None:
            modelPath = os.path.join(self.model_save_dir, '{}_trained.pt'.format(self.model_name))
        self.model.load_state_dict(torch.load(modelPath, map_location=torch.device(device)))
        return self.model

    def countParameters(self):
        """
        To get the number of trainable parameters of the model.
        """
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

    def printModel(self):
        """
        To print the network architecture.
        """
        print(self.model)

### Testing

In [17]:
# Build the network and restore the saved weights.
# The model is instantiated here so the cell also works on a fresh kernel,
# where no `model` variable exists yet.
model = VGG16_model().to(device)
model.load_state_dict(torch.load(os.path.join(MODEL_SAVE_DIR, MODEL_SAVE_NAME), map_location=device))

<All keys matched successfully>

In [1]:
def test(model, iterator, proba=False, one_batch=False):
    """
    Run `model` in eval mode over `iterator` and collect inputs and predictions.

    Args:
        model: network to evaluate.
        iterator: yields batches whose first item is the inputs and whose
            second item is the labels.
        proba: when true, the softmax probabilities are returned as well.
        one_batch: when true, stop after the first batch.

    Returns:
        (images, true_labels, pred_labels), each concatenated over batches along
        dim 0, followed by pred_probs when `proba` is true. Predictions come from
        an argmax with keepdim=True, so they keep a trailing dimension of size 1.
    """
    model.eval()

    image_batches, label_batches, pred_batches, prob_batches = [], [], [], []

    with torch.no_grad():
        for batch in iterator:
            batch_inputs = batch[0].to(device)
            batch_labels = batch[1].to(device)

            class_probs = F.softmax(model(batch_inputs), dim=-1)
            best_class = class_probs.argmax(1, keepdim=True)

            # Inputs and labels were moved to `device` above; only the model
            # outputs still need the explicit move.
            image_batches.append(batch_inputs)
            label_batches.append(batch_labels)
            pred_batches.append(best_class.to(device))
            prob_batches.append(class_probs.to(device))

            if one_batch:
                break

    images, true_labels, pred_labels, pred_probs = (
        torch.cat(chunks, dim=0)
        for chunks in (image_batches, label_batches, pred_batches, prob_batches)
    )

    results = (images, true_labels, pred_labels)
    return results + (pred_probs,) if proba else results

In [27]:
# Run the model over the test loader and keep the inputs, true labels and predictions.
# NOTE(review): `test_loader` is not defined at notebook level in the visible cells
# (Trainer.get_train_test_loaders returns one) — confirm it exists before running.
images, true_labels, pred_labels = test(model, test_loader)

### Plot Confusion Matrix

In [None]:
def plot_confusion_matrix(labels, pred_labels, classes):
    """
    Draw the confusion matrix of `pred_labels` against `labels`.

    Args:
        labels: true class labels.
        pred_labels: predicted class labels.
        classes: display names for the matrix axes.
    """
    fig, ax = plt.subplots(figsize=(10, 10))

    matrix_display = ConfusionMatrixDisplay(confusion_matrix(labels, pred_labels),
                                            display_labels=classes)
    matrix_display.plot(values_format='d', cmap='Blues', ax=ax)

    # Tilt the x tick labels slightly.
    plt.xticks(rotation=20)

In [None]:
# Plot the confusion matrix for the predictions collected above.
# NOTE(review): the class names are hard-coded in this order; presumably label
# indices 0/1/2 correspond to covid/pneumonia/regular — confirm against class_map.
plot_confusion_matrix(true_labels.cpu().numpy(), pred_labels.cpu().numpy(), ['covid', 'pneumonia', 'regular'])

# Use this in the training script
#cm = confusion_matrix(true_labels.cpu().numpy(), pred_labels.cpu().numpy())
#print(cm)

### Get the Classification Report

In [None]:
# Print the classification report for the predictions collected above.
# NOTE(review): target_names is hard-coded; presumably label indices 0/1/2
# correspond to covid/pneumonia/regular — confirm against class_map.
print(classification_report(
        true_labels.cpu().numpy(), pred_labels.cpu().numpy(), target_names= ['covid', 'pneumonia', 'regular']
    ))

### Visualize Sample Test Results

In [None]:
def visualize_test_samples(model, test_loader):
    """
    Show a square grid of images from one test batch, each titled with its true
    label, predicted label and the highest predicted probability.

    Titles are green when the prediction matches the true label and red
    otherwise. Label names are looked up in the notebook-level `class_map`.
    """
    images, true_labels, pred_labels, pred_probs = test(model, test_loader, proba=True, one_batch=True)

    true_labels, pred_labels, pred_probs = (
        tensor.cpu().numpy() for tensor in (true_labels, pred_labels, pred_probs)
    )

    # Largest square grid that fits in the batch; any leftover samples are not drawn.
    grid_side = int(np.sqrt(len(images)))

    fig = plt.figure(figsize=(25, 20))

    for cell in range(grid_side * grid_side):
        ax = fig.add_subplot(grid_side, grid_side, cell + 1)

        actual = true_labels[cell]
        predicted = pred_labels[cell][0]
        confidence = max(pred_probs[cell])
        title_color = 'green' if actual == predicted else 'red'

        # Move the first axis last before plotting
        # (presumably channel-first image tensors — confirm against the dataset).
        ax.imshow(images[cell].permute(1, 2, 0).cpu().numpy())
        ax.set_title(f'true label: {class_map[actual]}\n'
                     f'pred label: {class_map[predicted]} (Prob: {confidence:.3f})',
                     color=title_color)
        ax.axis('off')

    fig.subplots_adjust(hspace=0.4)

    plt.show()

In [None]:
# Display sample predictions from the first test batch.
# NOTE(review): relies on `model`, `test_loader` and `class_map` already existing
# at notebook level — confirm they are defined by earlier cells.
visualize_test_samples(model, test_loader)

In [None]:
# Todo:
# 1) Implement ROC graph with confidence values (J)
# 2) Tune the hyperparameters for best results (N)
# 3) Implement voting classifier from all 5 folds using 5 models (A)
# 4) Add code for early stopping (A)
# 5) Do visualization stuff (N)
# 6) Refactor all code and make the training and testing scripts with args (A)
# 7) Finalize this notebook with all code included (A)
# 8) Implement the flask api (A)
# 9) Test the script and prepare a notebook for Resnet model (J)