In [1]:
# Torch
import torch
from torch.utils.data import DataLoader
from torch import optim, nn

# Custom
import data
import model

In [2]:
"""
Data Split for First Layer classification task (normal vs infected) 
"""

fl_labels = {
0 : "Normal",
1 : "Infected"
}

fl_train = data.Lung_Dataset('train', verbose = 0)
fl_test = data.Lung_Dataset('test', verbose = 0)
fl_val = data.Lung_Dataset('val', verbose = 0)

"""
Data Split for Second Layer classification task (COVID vs Non-COVID) 
"""

sl_labels = {
0 : "COVID",
1 : "Non-COVID"
}

sl_train = data.Lung_Dataset('train', verbose = 2)
sl_test = data.Lung_Dataset('test', verbose = 2)
sl_val = data.Lung_Dataset('val', verbose = 2)

In [3]:
def _train_and_report(train_set, val_set, test_set, device, train_banner, test_banner,
                      n_epoch, l_rate, batch_size, patience):
    """Train a fresh ``model.CNN`` on one binary task and report its test accuracy.

    Args:
        train_set, val_set, test_set: datasets for the three splits.
        device: ``torch.device`` the network is moved to before training.
        train_banner: line printed just before training starts.
        test_banner: line printed just before the test evaluation.
        n_epoch: forwarded to ``model.train`` as the epoch count.
        l_rate: learning rate for the Adam optimizer.
        batch_size: batch size used by all three loaders.
        patience: forwarded to ``model.train`` (early-stopping patience,
            judging by the name -- confirm in ``model.train``).

    Returns:
        The trained network (still on ``device``).
    """
    # Only the training data is shuffled; evaluation does not benefit from a
    # random order, and a fixed order keeps validation/test passes deterministic.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    net = model.CNN().to(device)
    optimizer = optim.Adam(net.parameters(), lr=l_rate)

    print(train_banner)
    # nn.BCELoss expects probabilities in [0, 1]; this assumes model.CNN ends
    # with a sigmoid -- TODO confirm in model.py.
    model.train(net, device, nn.BCELoss(), optimizer, train_loader, val_loader, n_epoch, patience)

    print("\n\n")
    print(test_banner)
    model.test(net, device, test_loader)

    return net


def main(n_epoch=200, l_rate=0.001, batch_size=32, patience=5):
    """Train and evaluate the two classifiers back to back.

    The first model is trained on the ``fl_*`` datasets (normal vs infected),
    the second on the ``sl_*`` datasets (COVID vs non-COVID). Both datasets are
    read from the notebook's global namespace. Progress and accuracies are
    printed by ``model.train`` / ``model.test``; nothing is returned.

    Args:
        n_epoch: epoch count passed to ``model.train``.
        l_rate: Adam learning rate.
        batch_size: batch size for every DataLoader.
        patience: patience value passed to ``model.train``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    fl_model = _train_and_report(
        fl_train, fl_val, fl_test, device,
        "Training the first model to classify normal and infected images",
        "Test Accuracy of the first model:",
        n_epoch, l_rate, batch_size, patience,
    )

    # Move the finished first model off the accelerator before the second
    # model is created on it.
    fl_model.to("cpu")

    # Second Model
    print("\n\n")
    _train_and_report(
        sl_train, sl_val, sl_test, device,
        "Training the second model to classify COVID and non-COVID images",
        "Test Accuracy of the second model:",
        n_epoch, l_rate, batch_size, patience,
    )
    
# Entry-point guard: run the whole train/evaluate pipeline only when this code
# is executed directly (a notebook cell or a script), not when it is imported.
if __name__ == '__main__':
    main()

Training the first model to classify normal and infected images
Epoch 1


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


Validation loss: 0.7909051775932312
Validation set accuracy:  68.0 %

Epoch 2


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


Validation loss: 1.283989667892456
Validation set accuracy:  72.0 %

Epoch 3


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


Validation loss: 0.5638637542724609
Validation set accuracy:  80.0 %

Epoch 4


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


Validation loss: 0.22470298409461975
Validation set accuracy:  88.0 %

Epoch 5


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


Validation loss: 0.3365747034549713
Validation set accuracy:  88.0 %

Epoch 6


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


Validation loss: 0.6223231554031372
Validation set accuracy:  84.0 %

Epoch 7


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


Validation loss: 0.17938606441020966
Validation set accuracy:  88.0 %

Epoch 8


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


Validation loss: 0.38574036955833435
Validation set accuracy:  88.0 %

Epoch 9


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


Validation loss: 0.3254808485507965
Validation set accuracy:  84.0 %

Epoch 10


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


Validation loss: 0.0829935148358345
Validation set accuracy:  96.0 %

Epoch 11


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


Validation loss: 0.28695565462112427
Validation set accuracy:  92.0 %

Epoch 12


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


Validation loss: 0.9076387286186218
Validation set accuracy:  80.0 %

Epoch 13


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


Validation loss: 0.327330082654953
Validation set accuracy:  84.0 %

Epoch 14


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


Validation loss: 0.11998473107814789
Validation set accuracy:  96.0 %

Epoch 15


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))


Validation loss: 0.15012872219085693
Validation set accuracy:  96.0 %




Test Accuracy of the first model:
Test set accuracy:  76.58536585365853 %



Training the second model to classify COVID and non-COVID images
Epoch 1


HBox(children=(FloatProgress(value=0.0, max=122.0), HTML(value='')))


Validation loss: 0.815004289150238
Validation set accuracy:  58.8235294117647 %

Epoch 2


HBox(children=(FloatProgress(value=0.0, max=122.0), HTML(value='')))


Validation loss: 0.887675940990448
Validation set accuracy:  47.05882352941177 %

Epoch 3


HBox(children=(FloatProgress(value=0.0, max=122.0), HTML(value='')))


Validation loss: 0.8344905376434326
Validation set accuracy:  52.94117647058823 %

Epoch 4


HBox(children=(FloatProgress(value=0.0, max=122.0), HTML(value='')))


Validation loss: 0.9076252579689026
Validation set accuracy:  58.8235294117647 %

Epoch 5


HBox(children=(FloatProgress(value=0.0, max=122.0), HTML(value='')))


Validation loss: 0.8071427345275879
Validation set accuracy:  64.70588235294117 %

Epoch 6


HBox(children=(FloatProgress(value=0.0, max=122.0), HTML(value='')))


Validation loss: 0.851006805896759
Validation set accuracy:  47.05882352941177 %

Epoch 7


HBox(children=(FloatProgress(value=0.0, max=122.0), HTML(value='')))


Validation loss: 0.917757511138916
Validation set accuracy:  41.1764705882353 %

Epoch 8


HBox(children=(FloatProgress(value=0.0, max=122.0), HTML(value='')))


Validation loss: 0.890629231929779
Validation set accuracy:  64.70588235294117 %

Epoch 9


HBox(children=(FloatProgress(value=0.0, max=122.0), HTML(value='')))


Validation loss: 1.0558222532272339
Validation set accuracy:  35.294117647058826 %

Epoch 10


HBox(children=(FloatProgress(value=0.0, max=122.0), HTML(value='')))


Validation loss: 1.007980465888977
Validation set accuracy:  58.8235294117647 %




Test Accuracy of the second model:
Test set accuracy:  92.1259842519685 %


In [4]:
# NOTE(review): this cell is a bare string literal, not executable code. It
# parks a plotting snippet for later use; being the last expression in the
# cell, the string is echoed back as the cell's output.
# The snippet references `plot`, `pred`, `np` and `plt`, none of which are
# defined or imported in the visible cells, and it indexes `data`, which in
# this notebook is the imported `data` module. Resolve those before promoting
# it to real code.
'''
Placeholder - Code to generate images and the corresponding labels
if plot == True:
        example_data = np.zeros([24, 150, 150])
        example_pred = np.zeros(24)
        
        for i in range(24):
            example_data[i] = data[i][0].to("cpu").numpy()
            example_pred[i] = pred[i].to("cpu").numpy()
                    
        for i in range(24):
            plt.subplot(5,5,i+1)
            plt.imshow(example_data[i], cmap='gray', interpolation='none')
            plt.title(fl_labels[example_pred[i]])
            plt.xticks([])
            plt.yticks([])
        plt.show()
'''

'\nPlaceholder - Code to generate images and the corresponding labels\nif plot == True:\n        example_data = np.zeros([24, 150, 150])\n        example_pred = np.zeros(24)\n        \n        for i in range(24):\n            example_data[i] = data[i][0].to("cpu").numpy()\n            example_pred[i] = pred[i].to("cpu").numpy()\n                    \n        for i in range(24):\n            plt.subplot(5,5,i+1)\n            plt.imshow(example_data[i], cmap=\'gray\', interpolation=\'none\')\n            plt.title(fl_labels[example_pred[i]])\n            plt.xticks([])\n            plt.yticks([])\n        plt.show()\n'