In [1]:
%matplotlib inline

# Matplotlib
import matplotlib.pyplot as plt
# Numpy
import numpy as np
# Pillow
from PIL import Image
# tqdm
from tqdm.notebook import tqdm
# Torch
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms
# OS
import os
# Optuna for Hyperparameter Optimization
import optuna
from optuna.trial import TrialState
# conda install -c conda-forge optuna

Data Loading using the dataloader class

In [2]:
class Lung_Dataset(Dataset):
    """
    Lung Dataset Consisting of Infected and Non-Infected.
    """

    def __init__(self, purpose, verbose=0):
        """
        Constructor for generic Dataset class - simply assembles
        the important parameters in attributes.
        
        Parameter:
        -purpose variable should be set to a string of either 'train', 'test' or 'val'
        -verbose takes an int of either 0,1 or 2. 0 will only differentiate between normal and infected, 1 will differentiate
            between normal, covid and non-covid while 2 will only differentiate between covid and non-covid
        """
        self.purpose = purpose
        self.verbose = verbose
        
        # All images are of size 150 x 150
        self.img_size = (150, 150)
            
        # The dataset has been split in training, testing and validation datasets
        self.groups = ['train', 'test', 'val']
        
        # Path to images for different parts of the dataset
        self.dataset_paths = {'train_normal': './dataset/train/normal/',
                              'train_infected': './dataset/train/infected/',
                              'train_infected_covid': './dataset/train/infected/covid',
                              'train_infected_non_covid': './dataset/train/infected/non-covid',
                              'test_normal': './dataset/test/normal/',
                              'test_infected': './dataset/test/infected/',
                              'test_infected_covid': './dataset/test/infected/covid',
                              'test_infected_non_covid': './dataset/test/infected/non-covid',
                              'val_normal': './dataset/val/normal/',
                              'val_infected': './dataset/val/infected/',
                              'val_infected_covid': './dataset/val/infected/covid',
                              'val_infected_non_covid': './dataset/val/infected/non-covid'}
        
        self.dataset_numbers = {}
        
        # Consider normal and infected only
        if verbose == 0:
            self.classes = {0: 'normal', 1: 'infected'}
            
            #Populate self.dataset_numbers
            for condition in self.classes.values():
                key = "{}_{}".format(self.purpose, condition)
                if condition == "normal":
                    file_path = self.dataset_paths[key]
                    count = len(os.listdir(file_path))
                    self.dataset_numbers[key] = count
                else:
                    key1 = key + "_covid"
                    key2 = key + "_non_covid"
                    file_path1 = self.dataset_paths[key1]
                    file_path2 = self.dataset_paths[key2]
                    count1 = len(os.listdir(file_path1))
                    count2 = len(os.listdir(file_path2))
                    count = count1 + count2
                    self.dataset_numbers[key] = count
                       
        #Consider normal, covid and non-covid
        elif verbose == 1:
            self.classes = {0: 'normal', 1: 'covid', 2: 'non_covid'}
        
            #Populate self.dataset_numbers
            for condition in self.classes.values():
                if condition == "normal":
                    key = "{}_{}".format(self.purpose, condition)
                    file_path = self.dataset_paths[key]
                    count = len(os.listdir(file_path))
                    self.dataset_numbers[key] = count
                else:
                    key = "{}_infected".format(self.purpose)
                    key1 = key + "_covid"
                    key2 = key + "_non_covid"
                    file_path1 = self.dataset_paths[key1]
                    file_path2 = self.dataset_paths[key2]
                    count1 = len(os.listdir(file_path1))
                    count2 = len(os.listdir(file_path2))
                    self.dataset_numbers[key1] = count1
                    self.dataset_numbers[key2] = count2
                
        #Consider covid and non-covid
        elif verbose == 2:
            self.classes = {0: 'covid', 1 :'non_covid' }

            #Populate self.dataset_numbers
            for condition in self.classes.values():
                key = "{}_infected".format(self.purpose)
                key1 = key + "_covid"
                key2 = key + "_non_covid"
                file_path1 = self.dataset_paths[key1]
                file_path2 = self.dataset_paths[key2]
                count1 = len(os.listdir(file_path1))
                count2 = len(os.listdir(file_path2))
                self.dataset_numbers[key1] = count1
                self.dataset_numbers[key2] = count2
            
        else:
            err_msg  = "Verbose argument only takes in an int of either 0,1 or 2"
            raise TypeError(err_msg)
        
        
    def describe(self):
        """
        Descriptor function.
        Will print details about the dataset when called.
        """
        
        # Generate description
        msg = "This is the Lung {} Dataset in the 50.039 Deep Learning class project".format(self.purpose)
        msg += " in Feb-March 2021. \n"
        msg += "It contains a total of {} images, ".format(sum(self.dataset_numbers.values()))
        msg += "of size {} by {}.\n".format(self.img_size[0], self.img_size[1])
        msg += "The images are stored in the following locations "
        msg += "and each one contains the following number of images:\n"
        for key, val in self.dataset_numbers.items():
            if key != 'infected':
                file_path = self.dataset_paths[key]
            else:
                file_path = self.dataset_paths
            msg += " - {}, in folder {}: {} images.\n".format(key, file_path, val)
        print(msg)
        
        
    def open_img(self, class_val, index_val):
        """
        Opens image with specified parameters.
        
        Parameters:
        - class_val variable should be set to 'normal' or 'infected'.
        - index_val should be an integer with values between 0 and the maximal number of images in dataset.
        
        Returns loaded image as a normalized Numpy array.
        """
        group_val = self.purpose
        err_msg = "Error - class_val variable should be set to 'normal', 'infected', 'covid' or 'non_covid'."
        assert class_val in self.classes.values(), err_msg
        
        if class_val == 'covid' or class_val == 'non_covid':
            class_val = 'infected_' + class_val
            
        max_val = self.dataset_numbers['{}_{}'.format(group_val, class_val)]
        err_msg = "Error - index_val variable should be an integer between 0 and the maximal number of images."
        err_msg += "\n(In {}/{}, you have {} images.)".format(group_val, class_val, max_val)
        assert isinstance(index_val, int), err_msg
        assert index_val >= 0 and index_val <= max_val, err_msg
        
        # Open file as before
        if class_val != "infected":
            path_to_file = '{}/{}.jpg'.format(self.dataset_paths['{}_{}'.format(group_val, class_val)], index_val)
        else:
            covid_count = len(os.listdir(self.dataset_paths['{}_{}_covid'.format(group_val, class_val)]))
            if index_val < covid_count:
                path_to_file = '{}/{}.jpg'.format(self.dataset_paths['{}_{}_covid'.format(group_val, class_val)], index_val)
            else:
                index_val = index_val - covid_count
                path_to_file = '{}/{}.jpg'.format(self.dataset_paths['{}_{}_non_covid'.format(group_val, class_val)], index_val)
        with open(path_to_file, 'rb') as f:
            im = np.asarray(Image.open(f))/255
        f.close()
        return im
    
    
    def show_img(self, class_val, index_val):
        """
        Opens, then displays image with specified parameters.
        
        Parameters:
        - class_val variable should be set to 'normal' or 'infected'.
        - index_val should be an integer with values between 0 and the maximal number of images in dataset.
        """
        # Open image
        im = self.open_img(class_val, index_val)
        
        # Display
        plt.imshow(im)
        
    def __len__(self):
        """
        Length special method, returns the number of images in dataset.
        """
        
        # Length function
        return sum(self.dataset_numbers.values())
    
    
    def __getitem__(self, index):
        """
        Getitem special method.
        
        Expects an integer value index, between 0 and len(self) - 1.
        
        Returns the image and its label as a one hot vector, both
        in torch tensor format in dataset.
        """
        #If we only have 2 classes
        if self.verbose == 0 or self.verbose == 2:
            first_val = int(list(self.dataset_numbers.values())[0])
            if index < first_val:
                class_val = self.classes[0]
                label = torch.Tensor([1, 0])
            else:
                class_val = self.classes[1]
                index = index - first_val
                label = torch.Tensor([0, 1])
            im = self.open_img(class_val, index)
            im = transforms.functional.to_tensor(np.array(im)).float()
          
        #If we have 3 classes to consider
        elif self.verbose == 1:
            first_val = int(list(self.dataset_numbers.values())[0])
            second_val = int(list(self.dataset_numbers.values())[1])
            if index < first_val:
                class_val = self.classes[0]
                label = torch.Tensor([1, 0, 0])
            elif index >= first_val and index < first_val + second_val:
                index = index - first_val
                class_val = self.classes[1]
                label = torch.Tensor([0,1,0])
            else:
                index = index-(first_val + second_val)
                class_val = self.classes[2]
                label = torch.Tensor([0,0,1])
            im = self.open_img(class_val, index)
            im = transforms.functional.to_tensor(np.array(im)).float()
                
        else:
            raise TypeError("Verbose value is not 0,1 or 2")
        return im, label

In [3]:
"""
Data Split for First Layer classification task (normal vs infected) 
"""

fl_labels = {
0 : "Normal",
1 : "Infected"
}

fl_train = Lung_Dataset('train', verbose = 0)
fl_test = Lung_Dataset('test', verbose = 0)
fl_val = Lung_Dataset('val', verbose = 0)

"""
Data Split for Second Layer classification task (COVID vs Non-COVID) 
"""

sl_labels = {
0 : "COVID",
1 : "Non-COVID"
}

sl_train = Lung_Dataset('train', verbose = 2)
sl_test = Lung_Dataset('test', verbose = 2)
sl_val = Lung_Dataset('val', verbose = 2)

Model Construction

In [4]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        
        """
        Insert each layer blocks. Same architecture will be used for the first layer and second layer CNNs
        """
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Dropout(p=0.05))
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Dropout(p=0.05))
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.Dropout(p=0.05))
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.BatchNorm2d(512),
            nn.Dropout(p=0.05))

        self.final = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten())

        self.fc = nn.Linear(512, 1)
        
        """
        Initialization of the Conv layers and FC layer using Kaiming initialization
        """
        self.conv1.apply(init_weights)
        self.conv2.apply(init_weights)
        self.conv3.apply(init_weights)
        self.conv4.apply(init_weights)
        self.fc.apply(init_weights)
    
    def forward(self,x):
        
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.final(x)
        x = self.fc(x)

        return torch.sigmoid(x)

In [5]:
def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d:
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    
    running_loss = 0
    
    tk0 = tqdm(train_loader, total=int(len(train_loader)))
    counter = 0
    
    for batch_idx, (data, target) in enumerate(tk0):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        
        output = model.forward(data)
        
        target = target.argmax(dim=1, keepdim=True).float()
        
        loss_criterion = nn.BCELoss()
        loss = loss_criterion(output, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        counter += 1
        tk0.set_postfix(loss=(running_loss / (counter * train_loader.batch_size)))
        
        if (batch_idx + 1) % 100 == 0:
            print('Epoch:', epoch, ',Training Loss:', running_loss / 100)
            running_loss = 0
        
def validate(model, device, val_loader, loss_criterion, plot=False):
    model.eval()

    correct = 0
    val_loss = 0

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)

            target = target.argmax(dim=1, keepdim=True).float()
            
            output = model.forward(data)
            val_loss += loss_criterion(output, target).item()
            
            pred = torch.round(output)
            equal_data = torch.sum(target.data == pred).item()
            correct += equal_data

    print("Validation loss: {}".format(val_loss / len(val_loader)))
    print('Validation set accuracy: ', 100. * correct / len(val_loader.dataset), '%')
    
    return (val_loss / len(val_loader))
    
def test(model, device, test_loader, plot=False):
    model.eval()
    
    correct = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)

            target = target.argmax(dim=1, keepdim=True).float()
            
            output = model.forward(data)
            pred = torch.round(output)

            equal_data = torch.sum(target.data == pred).item()
            correct += equal_data

    print('Test set accuracy: ', 100. * correct / len(test_loader.dataset), '%')

In [6]:
def main():
    N_EPOCH = 200
    L_RATE = 0.005
    BETAS = (0.9,0.999)
    bs_val = 32
    
    stop_threshold = 0.0005
    temp_loss = None
    early_stop = 0
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    fl_train_loader = DataLoader(fl_train, batch_size=bs_val, shuffle=True)
    fl_val_loader = DataLoader(fl_val, batch_size=bs_val, shuffle=True)
    fl_test_loader = DataLoader(fl_test, batch_size=bs_val, shuffle=True)
    
    fl_model = CNN().to(device)
    fl_optimizer = optim.Adam(fl_model.parameters(), lr=L_RATE, betas=BETAS)
    
    print("Training the first model to classify normal and infected images")
    for epoch in range(1,N_EPOCH+1):
        train(fl_model, device, fl_train_loader, fl_optimizer, epoch)
        val_loss = validate(fl_model, device, fl_val_loader, nn.BCELoss())
        
        """
        Early Stopping 
        """
        if temp_loss == None:
            temp_loss = val_loss
        else:
            if val_loss < temp_loss:
                early_stop = 0
                temp_loss = val_loss
            else:
                early_stop += 1
                
        if early_stop == 5:
            break

    print("\n\n")
    print("Test Accuracy of the first model:")
    test(fl_model, device, fl_test_loader)

    fl_model.to("cpu")
    
    
    #Second Model
    temp_loss = None
    early_stop = 0
    
    sl_train_loader = DataLoader(sl_train, batch_size=bs_val, shuffle=True)
    sl_val_loader = DataLoader(sl_val, batch_size=bs_val, shuffle=True)
    sl_test_loader = DataLoader(sl_test, batch_size=bs_val, shuffle=True)
    
    sl_model = CNN().to(device)
    
    sl_optimizer = optim.Adam(sl_model.parameters(), lr=L_RATE, betas=BETAS)
    
    print("\n\n")
    print("Training the second model to classify COVID and non-COVID images")
    for epoch in range(1,N_EPOCH+1):
        train(sl_model, device, sl_train_loader, sl_optimizer, epoch)
        val_loss = validate(sl_model, device, sl_val_loader, nn.BCELoss())
        
        """
        Early Stopping 
        """
        if temp_loss == None:
            temp_loss = val_loss
        else:
            if val_loss < temp_loss:
                early_stop = 0
                temp_loss = val_loss
            else:
                early_stop += 1
                
        if early_stop == 5:
            break
    
    print("\n\n")
    print("Test Accuracy of the second model:")
    test(sl_model, device, sl_test_loader)
    
if __name__ == '__main__':
    main()

Training the first model to classify normal and infected images


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 1 ,Training Loss: 0.406059917062521
Validation loss: 0.6342885494232178
Validation set accuracy:  72.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 2 ,Training Loss: 0.17371247582137583
Validation loss: 1.6552448272705078
Validation set accuracy:  44.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 3 ,Training Loss: 0.12315673748031258
Validation loss: 0.6368852853775024
Validation set accuracy:  76.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 4 ,Training Loss: 0.11258306334726513
Validation loss: 0.43794435262680054
Validation set accuracy:  84.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 5 ,Training Loss: 0.09918390599545092
Validation loss: 0.4787502586841583
Validation set accuracy:  80.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 6 ,Training Loss: 0.08280799758154898
Validation loss: 1.0345041751861572
Validation set accuracy:  68.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 7 ,Training Loss: 0.06955229019513354
Validation loss: 0.8159579038619995
Validation set accuracy:  76.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 8 ,Training Loss: 0.060979280386818574
Validation loss: 0.29998862743377686
Validation set accuracy:  84.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 9 ,Training Loss: 0.046801506727933886
Validation loss: 0.4144107401371002
Validation set accuracy:  80.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 10 ,Training Loss: 0.04497560428921133
Validation loss: 0.612773060798645
Validation set accuracy:  76.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 11 ,Training Loss: 0.04656482307123952
Validation loss: 0.19504080712795258
Validation set accuracy:  88.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 12 ,Training Loss: 0.03025804430595599
Validation loss: 0.22106721997261047
Validation set accuracy:  92.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 13 ,Training Loss: 0.028688356122584083
Validation loss: 0.22571447491645813
Validation set accuracy:  92.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 14 ,Training Loss: 0.04043141747417394
Validation loss: 0.13673973083496094
Validation set accuracy:  88.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 15 ,Training Loss: 0.033768204959342255
Validation loss: 0.20618373155593872
Validation set accuracy:  92.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 16 ,Training Loss: 0.034934489543084056
Validation loss: 0.2015175223350525
Validation set accuracy:  96.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 17 ,Training Loss: 0.018504678254103055
Validation loss: 0.5077541470527649
Validation set accuracy:  84.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 18 ,Training Loss: 0.03302537915966241
Validation loss: 0.814653754234314
Validation set accuracy:  76.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 19 ,Training Loss: 0.01909679183132539
Validation loss: 0.0425771027803421
Validation set accuracy:  96.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 20 ,Training Loss: 0.03124385985414847
Validation loss: 0.01478220522403717
Validation set accuracy:  100.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 21 ,Training Loss: 0.018870074216683862
Validation loss: 0.24802367389202118
Validation set accuracy:  92.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 22 ,Training Loss: 0.01669026617630152
Validation loss: 0.20362897217273712
Validation set accuracy:  92.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 23 ,Training Loss: 0.016115569342582604
Validation loss: 0.9228680729866028
Validation set accuracy:  80.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 24 ,Training Loss: 0.015922598126562663
Validation loss: 0.12833164632320404
Validation set accuracy:  92.0 %


  0%|          | 0/163 [00:00<?, ?it/s]

Epoch: 25 ,Training Loss: 0.017015099059262864
Validation loss: 0.19462187588214874
Validation set accuracy:  92.0 %



Test Accuracy of the first model:
Test set accuracy:  82.76422764227642 %



Training the second model to classify COVID and non-COVID images


  0%|          | 0/122 [00:00<?, ?it/s]

Epoch: 1 ,Training Loss: 0.7310113555192947
Validation loss: 0.8027167916297913
Validation set accuracy:  58.8235294117647 %


  0%|          | 0/122 [00:00<?, ?it/s]

Epoch: 2 ,Training Loss: 0.5681099840998649
Validation loss: 0.7549899816513062
Validation set accuracy:  58.8235294117647 %


  0%|          | 0/122 [00:00<?, ?it/s]

Epoch: 3 ,Training Loss: 0.551237236559391
Validation loss: 0.7957115173339844
Validation set accuracy:  64.70588235294117 %


  0%|          | 0/122 [00:00<?, ?it/s]

Epoch: 4 ,Training Loss: 0.5361668160557747
Validation loss: 0.860979437828064
Validation set accuracy:  58.8235294117647 %


  0%|          | 0/122 [00:00<?, ?it/s]

Epoch: 5 ,Training Loss: 0.5309222292900085
Validation loss: 0.8926554918289185
Validation set accuracy:  52.94117647058823 %


  0%|          | 0/122 [00:00<?, ?it/s]

Epoch: 6 ,Training Loss: 0.5263352981209755
Validation loss: 0.9302441477775574
Validation set accuracy:  47.05882352941177 %


  0%|          | 0/122 [00:00<?, ?it/s]

Epoch: 7 ,Training Loss: 0.5138878846168518
Validation loss: 0.9794571995735168
Validation set accuracy:  52.94117647058823 %



Test Accuracy of the second model:
Test set accuracy:  87.4015748031496 %


In [7]:
'''
Placeholder - Code to generate images and the corresponding labels
if plot == True:
        example_data = np.zeros([24, 150, 150])
        example_pred = np.zeros(24)
        
        for i in range(24):
            example_data[i] = data[i][0].to("cpu").numpy()
            example_pred[i] = pred[i].to("cpu").numpy()
                    
        for i in range(24):
            plt.subplot(5,5,i+1)
            plt.imshow(example_data[i], cmap='gray', interpolation='none')
            plt.title(fl_labels[example_pred[i]])
            plt.xticks([])
            plt.yticks([])
        plt.show()
'''

'\nPlaceholder - Code to generate images and the corresponding labels\nif plot == True:\n        example_data = np.zeros([24, 150, 150])\n        example_pred = np.zeros(24)\n        \n        for i in range(24):\n            example_data[i] = data[i][0].to("cpu").numpy()\n            example_pred[i] = pred[i].to("cpu").numpy()\n                    \n        for i in range(24):\n            plt.subplot(5,5,i+1)\n            plt.imshow(example_data[i], cmap=\'gray\', interpolation=\'none\')\n            plt.title(fl_labels[example_pred[i]])\n            plt.xticks([])\n            plt.yticks([])\n        plt.show()\n'