<a href="https://colab.research.google.com/github/MateiGrama/diss/blob/CovidNet/COVIDNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Covidx dataset extraction and training with dl models

## Clone the GitHub repositories and unzip the two datasets

In [0]:
# Fetch the covid-chestxray-dataset images/metadata and the COVID-Net repository.
# The cells below read /content/COVID-Net/rsna_test_patients_*.npy from the second clone.
! git clone https://github.com/ieee8023/covid-chestxray-dataset.git
! git clone https://github.com/IliasPap/COVID-Net.git

# When True, the split cells copy/convert images into the train/test folders;
# when False they only print a warning and stop processing the current class.
COPY_FILE = True

# # !pip install pydicom
# ! pip install -q kaggle
# ! mkdir ~/.kaggle

# ! pip install kaggle==1.5.6
# ! cp kaggle.json ~/.kaggle/
# ! chmod 600 ~/.kaggle/kaggle.json


# ! kaggle competitions download -c rsna-pneumonia-detection-challenge

Cloning into 'covid-chestxray-dataset'...
remote: Enumerating objects: 1079, done.[K
remote: Total 1079 (delta 0), reused 0 (delta 0), pack-reused 1079[K
Receiving objects: 100% (1079/1079), 188.60 MiB | 32.86 MiB/s, done.
Resolving deltas: 100% (492/492), done.
Checking out files: 100% (270/270), done.
Cloning into 'COVID-Net'...
remote: Enumerating objects: 7, done.[K
remote: Counting objects: 100% (7/7), done.[K
remote: Compressing objects: 100% (6/6), done.[K
remote: Total 202 (delta 1), reused 5 (delta 1), pack-reused 195[K
Receiving objects: 100% (202/202), 3.20 MiB | 6.86 MiB/s, done.
Resolving deltas: 100% (111/111), done.


## KAGGLE dataset from google drive

In [0]:
! mkdir /content/rsna_dataset
! unzip '/content/drive/My Drive/MEDICAL/rsna-pneumonia-detection-challenge.zip' -d /content/rsna_dataset/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/rsna_dataset/stage_2_train_images/34bf2fcd-131a-428c-9a21-cd2fa9041f9b.dcm  

In [0]:
! pip install pydicom
import numpy as np
import pandas as pd
import os
import random 
from shutil import copyfile
import pydicom as dicom
import cv2

In [0]:

seed = 0
np.random.seed(seed)  # Reset the seed so all runs are the same.
random.seed(seed)
MAXVAL = 255  # Range [0 255]
root = '/content/covid-chestxray-dataset'

# Create the output tree (<root>/data/{train,test}) once. exist_ok makes
# re-running this cell harmless and avoids the exists()/makedirs() race.
if COPY_FILE:
    for subdir in ('/data', '/data/train', '/data/test'):
        os.makedirs(root + subdir, exist_ok=True)

savepath = root + '/data'
# path to covid-19 dataset from https://github.com/ieee8023/covid-chestxray-dataset
imgpath = root + '/images'
csvpath = root + '/metadata.csv'

# path to https://www.kaggle.com/c/rsna-pneumonia-detection-challenge
# Must match the directory the archive is unzipped into earlier in the notebook.
kaggle_datapath = '/content/rsna_dataset'
kaggle_csvname = 'stage_2_detailed_class_info.csv'  # get all the normal from here
kaggle_csvname2 = 'stage_2_train_labels.csv'  # get all the 1s from here since 1 indicate pneumonia
kaggle_imgpath = 'stage_2_train_images'

# parameters for COVIDx dataset
train = []
test = []
test_count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}
train_count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}

# Source label -> COVIDx class. Labels missing from this table are ignored
# by the cells that consult it.
mapping = {
    'COVID-19': 'COVID-19',
    'SARS': 'pneumonia',
    'MERS': 'pneumonia',
    'Streptococcus': 'pneumonia',
    'Normal': 'normal',
    'Lung Opacity': 'pneumonia',
    '1': 'pneumonia',
}

# train/test split
split = 0.1

In [0]:
# adapted from https://github.com/mlmed/torchxrayvision/blob/master/torchxrayvision/datasets.py#L814
# Load the covid-chestxray-dataset metadata table (all rows).
csv = pd.read_csv(csvpath, nrows=None)
idx_pa = csv["view"] == "PA"  # Keep only the PA view
csv = csv[idx_pa]

# Label vocabularies carried over from the adapted source. They are not
# referenced again in the visible part of this notebook; the `mapping` dict
# is what drives the relabelling below.
pneumonias = ["COVID-19", "SARS", "MERS", "ARDS", "Streptococcus"]
pathologies = ["Pneumonia","Viral Pneumonia", "Bacterial Pneumonia", "No Finding"] + pneumonias
pathologies = sorted(pathologies)

## Data distribution covid-chestxray-dataset

In [0]:
# get non-COVID19 viral, bacteria, and COVID-19 infections from covid-chestxray-dataset
# stored as patient id, image filename and label
filename_label = {'normal': [], 'pneumonia': [], 'COVID-19': []}
count = {'normal': 0, 'pneumonia': 0, 'COVID-19': 0}
print(csv.keys())
for index, row in csv.iterrows():
    f = row['finding']
    # Only findings present in `mapping` are kept; every other row is skipped.
    if f in mapping:
        count[mapping[f]] += 1
        # entry layout: [patient id (int), image filename, COVIDx class]
        entry = [int(row['patientid']), row['filename'], mapping[f]]
        filename_label[mapping[f]].append(entry)

print('Data distribution from covid-chestxray-dataset:')
print(count)

## Add covid-chestxray-dataset into the COVIDx dataset

In [0]:
# add covid-chestxray-dataset into COVIDx dataset
# since covid-chestxray-dataset doesn't have test dataset
# split into train/test by patientid
# for COVIDx:
# patient 8 is used as non-COVID19 viral test
# patient 31 is used as bacterial test
# patients 19, 20, 36, 42, 86 are used as COVID-19 viral test

for key in filename_label.keys():
    arr = np.array(filename_label[key])
    if arr.size == 0:
        continue
    # Fixed hold-out patient ids per class (compared as strings, because the
    # mixed int/str rows become string arrays); all other patients go to train.
    test_patients = {
        'pneumonia': ['8', '31'],
        'COVID-19': ['19', '20', '36', '42', '86'],
    }.get(key, [])
    print('Key: ', key)
    print('Test patients: ', test_patients)
    # Each row is [patient id, filename, class].
    for patient in arr:
        if not COPY_FILE:
            # Copying disabled: warn once and stop processing this class.
            print("WARNING   :   passing copy file !!!!!!!!!!!!!!!!!!!!!!")
            break
        in_test = patient[0] in test_patients
        subset = 'test' if in_test else 'train'
        copyfile(os.path.join(imgpath, patient[1]), os.path.join(savepath, subset, patient[1]))
        if in_test:
            test.append(patient)
            test_count[patient[2]] += 1
        else:
            train.append(patient)
            train_count[patient[2]] += 1

print('test count: ', test_count)
print('train count: ', train_count)

## Copy kaggle data to train and test folders

In [0]:
# add normal and rest of pneumonia cases from https://www.kaggle.com/c/rsna-pneumonia-detection-challenge


# Directory the RSNA archive was unzipped into.
kaggle_datapath = '/content/rsna_dataset'

print(kaggle_datapath)
csv_normal = pd.read_csv(os.path.join(kaggle_datapath, kaggle_csvname), nrows=None)
csv_pneu = pd.read_csv(os.path.join(kaggle_datapath, kaggle_csvname2), nrows=None)
patients = {'normal': [], 'pneumonia': []}

# "normal" cases: rows whose class is exactly 'Normal'.
for index, row in csv_normal.iterrows():
    if row['class'] == 'Normal':
        patients['normal'].append(row['patientId'])

# "pneumonia" cases: rows with Target == 1.
# NOTE(review): if the labels file holds several rows per patientId, the same
# patient is appended (and later converted/listed) more than once - confirm.
for index, row in csv_pneu.iterrows():
    if int(row['Target']) == 1:
        patients['pneumonia'].append(row['patientId'])

for key in patients.keys():
    arr = np.array(patients[key])
    if arr.size == 0:
        continue
    # split by patients 
    # num_diff_patients = len(np.unique(arr))
    # num_test = max(1, round(split*num_diff_patients))
    #'/content/COVID-Net/'
    # The test patients are loaded from the .npy files in the cloned COVID-Net
    # repository instead of being sampled here.
    test_patients = np.load('/content/COVID-Net/rsna_test_patients_{}.npy'.format(key)) # random.sample(list(arr), num_test)
    # np.save('rsna_test_patients_{}.npy'.format(key), np.array(test_patients))
    for patient in arr:
        # Read the DICOM file and take its pixel data; it is written out as PNG below.
        ds = dicom.dcmread(os.path.join(kaggle_datapath, kaggle_imgpath, patient + '.dcm'))
        pixel_array_numpy = ds.pixel_array
        imgname = patient + '.png'
        if patient in test_patients:
            if (COPY_FILE):
                cv2.imwrite(os.path.join(savepath, 'test', imgname), pixel_array_numpy)
                test.append([patient, imgname, key])
                test_count[key] += 1
            else:
                # Copying disabled: warn once and stop processing this class.
                print("WARNING   :   passing copy file !!!!!!!!!!!!!!!!!!!!!!")
                break
        else:
            if (COPY_FILE):
                cv2.imwrite(os.path.join(savepath, 'train', imgname), pixel_array_numpy)
                train.append([patient, imgname, key])
                train_count[key] += 1
            else:
                # Copying disabled: warn once and stop processing this class.
                print("WARNING   :   passing copy file !!!!!!!!!!!!!!!!!!!!!!")
                break
print('test count: ', test_count)
print('train count: ', train_count)

## Final data stats

In [0]:
# final stats: per-class counters accumulated by the split cells above, plus
# the number of entries collected in the `train` and `test` lists.
print('Final stats')
print('Train count: ', train_count)
print('Test count: ', test_count)
print('Total length of train: ', len(train))
print('Total length of test: ', len(test))

## Train and test file extraction

In [0]:
# export to train and test csv
# format as patientid, filename, label, separated by a space
# `with` guarantees each file is flushed and closed even if a write fails.
with open("train_split_v2.txt", "w") as train_file:
    for sample in train:
        train_file.write(str(sample[0]) + ' ' + sample[1] + ' ' + sample[2] + '\n')

with open("test_split_v2.txt", "w") as test_file:
    for sample in test:
        test_file.write(str(sample[0]) + ' ' + sample[1] + ' ' + sample[2] + '\n')

In [0]:
# import glob

# images = glob.glob('/content/drive/My Drive/MEDICAL/data/*/*')
# print(len(images))

# train = glob.glob('/content/drive/My Drive/MEDICAL/data/train/*')
# test = glob.glob('/content/drive/My Drive/MEDICAL/data/test/*')
# print(len(train))
# print(len(test))

# Training on Covidx dataset

## Training imports


In [0]:
! pip install torchsummaryX


In [0]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Dataset
import torch.optim as optim
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import transforms,models
from torchsummaryX import summary
import numpy as np


import argparse
import csv
from PIL import Image


import os
import shutil
import time
from collections import OrderedDict
import json

## Utils 

In [0]:

def write_score(writer, iter, mode, metrics):
    """Log the loss and accuracy held in ``metrics`` under the ``mode`` prefix.

    Args:
        writer: object exposing ``add_scalar(tag, value, step)``.
        iter: step index passed through to the writer.
        mode: tag prefix; scalars are logged as ``<mode>/loss`` and ``<mode>/acc``.
        metrics: object whose ``data`` dict holds 'loss', 'correct' and 'total'.
    """
    stats = metrics.data
    writer.add_scalar(mode + '/loss', stats['loss'], iter)
    # Accuracy is the ratio of the accumulated correct and total counters.
    hit_ratio = stats['correct'] / stats['total']
    writer.add_scalar(mode + '/acc', hit_ratio, iter)


def write_train_val_score(writer, epoch, train_stats, val_stats):
    """Log six paired train/val scalars for one epoch.

    Element ``i`` of ``train_stats`` and ``val_stats`` is written to the chart
    named by position: 'Loss', 'Coeff', 'Air', 'CSF', 'GM', 'WM'.

    Args:
        writer: object exposing ``add_scalars(main_tag, tag_scalar_dict, step)``.
        epoch: step index passed through to the writer.
        train_stats: indexable with at least six entries.
        val_stats: indexable with at least six entries.
    """
    charts = ('Loss', 'Coeff', 'Air', 'CSF', 'GM', 'WM')
    for position, chart in enumerate(charts):
        pair = {'train': train_stats[position], 'val': val_stats[position]}
        writer.add_scalars(chart, pair, epoch)
    return


def showgradients(model):
    """Print the data type, size and current gradient of every parameter of ``model``."""
    for weight in model.parameters():
        data_type = type(weight.data)
        print(data_type, weight.size())
        print("GRADS= \n", weight.grad)





def datestr():
    """Return the current UTC time formatted as ``YYYYMMDD_HHMM``."""
    return time.strftime('%Y%m%d_%H%M', time.gmtime())


def save_checkpoint(state, is_best, path, filename='last'):
    """Serialize ``state`` to ``<path>/<filename>_checkpoint.pth.tar``.

    Args:
        state: object handed to ``torch.save``.
        is_best: accepted for call compatibility; not used by this function.
        path: destination directory.
        filename: prefix of the checkpoint file name.
    """
    suffix = '_checkpoint.pth.tar'
    target = os.path.join(path, filename + suffix)
    print(target)
    torch.save(state, target)



def save_model(model, optimizer, args, metrics, epoch, best_pred_loss, confusion_matrix):
    """Checkpoint the model and return the updated best loss.

    The run arguments are dumped to ``<args.save>/training_arguments.txt`` on
    every call. If ``metrics.data['loss']`` is lower than ``best_pred_loss`` the
    state is stored as ``<args.model>_best`` together with the confusion matrix;
    otherwise it is stored as ``<args.model>_last``.

    Returns:
        The smaller of the previous best loss and the current loss.
    """
    loss = metrics.data['loss']
    save_path = args.save
    make_dirs(save_path)

    with open(save_path + '/training_arguments.txt', 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    snapshot = {'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'metrics': metrics.data}

    if loss < best_pred_loss:
        best_pred_loss = loss
        save_checkpoint(snapshot, True, save_path, args.model + "_best")
        np.save(save_path + '/best_confusion_matrix.npy', confusion_matrix.cpu().numpy())
    else:
        save_checkpoint(snapshot, False, save_path, args.model + "_last")

    return best_pred_loss


def make_dirs(path):
    """Create ``path`` (including parents) unless something already exists there."""
    if os.path.exists(path):
        return
    os.makedirs(path)


def create_stats_files(path):
    """Open ``train.csv`` and ``val.csv`` inside ``path`` for writing.

    Returns:
        ``(train_file, val_file)``: two open text handles; the caller is
        responsible for closing them.
    """
    train_f, val_f = (open(os.path.join(path, name), 'w') for name in ('train.csv', 'val.csv'))
    return train_f, val_f


def read_json_file(fname):
    """Parse the JSON file ``fname``; every JSON object is returned as an ``OrderedDict``."""
    with open(fname, 'r') as handle:
        raw_text = handle.read()
        return json.loads(raw_text, object_hook=OrderedDict)


def write_json_file(content, fname):
    """Write ``content`` to ``fname`` as JSON, indented by 4 spaces, keys in their given order."""
    with open(fname, 'w') as handle:
        handle.write(json.dumps(content, indent=4, sort_keys=False))


def read_filepaths(file):
    """Read a split file of ``<patient id> <filename> <label>`` lines.

    The first whitespace-free token is the patient id and the last one is the
    label; everything in between is the filename, so filenames that contain
    spaces are parsed correctly. Blank lines are skipped.

    Args:
        file: path of the split text file.

    Returns:
        ``(paths, labels)``: two parallel lists of filenames and labels.

    Raises:
        ValueError: if a non-blank line has fewer than three fields.
    """
    paths, labels = [], []
    with open(file, 'r') as f:
        for line in f.read().splitlines():
            # Existing sentinel: stop reading at the first line containing
            # this marker (NOTE(review): origin of the marker is not visible
            # here - confirm it is still needed).
            if '/ c o' in line:
                break
            entry = line.strip()
            if not entry:
                continue
            subjid, remainder = entry.split(' ', 1)
            path, label = remainder.rsplit(' ', 1)

            paths.append(path)
            labels.append(label)
    return paths, labels



def select_model(args):
    """Build the network named by ``args.model`` with ``args.classes`` outputs.

    Recognised names are 'COVIDNet_small', 'COVIDNet_large' and 'resnet18';
    any other name yields ``None``.
    """
    for name, size in (('COVIDNet_small', 'small'), ('COVIDNet_large', 'large')):
        if args.model == name:
            return CovidNet(size, n_classes=args.classes)
    if args.model == 'resnet18':
        return CNN(args.classes, 'resnet18')
    return None


def select_optimizer(args, model):
    """Create the optimizer named by ``args.opt`` for the parameters of ``model``.

    'sgd' (momentum fixed at 0.5), 'adam' and 'rmsprop' are supported, all
    using ``args.lr`` and ``args.weight_decay``. Any other name yields ``None``.
    """
    kind = args.opt
    if kind not in ('sgd', 'adam', 'rmsprop'):
        return None
    params = model.parameters()
    shared = {'lr': args.lr, 'weight_decay': args.weight_decay}
    if kind == 'sgd':
        return optim.SGD(params, momentum=0.5, **shared)
    if kind == 'adam':
        return optim.Adam(params, **shared)
    return optim.RMSprop(params, **shared)


def print_stats(args, epoch, num_samples, trainloader, metrics):
    """Print a progress line when ``num_samples % args.log_interval == 1``.

    The line shows the epoch, ``num_samples`` out of
    ``len(trainloader) * args.batch_size``, the accumulated loss divided by
    ``num_samples`` and the accumulated correct/total ratio.
    """
    if num_samples % args.log_interval != 1:
        return
    stats = metrics.data
    expected_samples = len(trainloader) * args.batch_size
    scaled_loss = stats['loss'] / num_samples
    hit_ratio = stats['correct'] / stats['total']
    template = "Epoch:{:2d}\tSample:{:5d}/{:5d}\tLoss:{:.4f}\tAccuracy:{:.2f}"
    print(template.format(epoch, num_samples, expected_samples, scaled_loss, hit_ratio))


def print_summary(args, epoch, num_samples, metrics, mode=''):
    """Print an end-of-epoch summary line prefixed with ``mode``.

    Shows the epoch, ``num_samples`` (twice, as "seen/total"), the accumulated
    loss divided by ``num_samples`` and the accumulated correct/total ratio.
    ``args`` is accepted for signature symmetry and is not read.
    """
    stats = metrics.data
    scaled_loss = stats['loss'] / num_samples
    hit_ratio = stats['correct'] / stats['total']
    template = "\n SUMMARY EPOCH:{:2d}\tSample:{:5d}/{:5d}\tLoss:{:.4f}\tAccuracy:{:.2f}\n"
    body = template.format(epoch, num_samples, num_samples, scaled_loss, hit_ratio)
    print(mode + body)


def confusion_matrix(nb_classes):
    """Accumulate and print an ``nb_classes`` x ``nb_classes`` confusion matrix.

    Rows are indexed by the true class, columns by the predicted class.

    NOTE(review): this function reads the module-level names ``dataloaders``,
    ``device`` and ``model_ft``, none of which is defined in the visible part
    of this file - confirm they exist before calling it.
    """
    counts = torch.zeros(nb_classes, nb_classes)
    with torch.no_grad():
        for inputs, classes in dataloaders['val']:
            inputs = inputs.to(device)
            classes = classes.to(device)
            scores = model_ft(inputs)
            # Index of the highest score per sample is the predicted class.
            predicted = torch.max(scores, 1)[1]
            for truth, guess in zip(classes.view(-1), predicted.view(-1)):
                counts[truth.long(), guess.long()] += 1

    print(counts)


## METRICS

In [0]:



class Metrics:
    """Running accumulator for the 'correct', 'total', 'loss' and 'accuracy' counters."""

    def __init__(self, path, keys=None, writer=None):
        """Start all counters at zero.

        Args:
            path: file path used by :meth:`save`.
            keys: accepted but not used.
            writer: optional object with ``add_scalar(tag, value)``; when given,
                :meth:`update_key` forwards each value to it.
        """
        self.writer = writer
        self.data = dict.fromkeys(('correct', 'total', 'loss', 'accuracy'), 0)
        self.save_path = path

    def reset(self):
        """Set every counter back to zero, keeping the same ``data`` dict object."""
        self.data.update(dict.fromkeys(self.data, 0))

    def update_key(self, key, value, n=1):
        """Add ``value`` to counter ``key`` (``n`` is accepted but not used)."""
        logger = self.writer
        if logger is not None:
            logger.add_scalar(key, value)
        self.data[key] += value

    def update(self, values):
        """Add ``values[name]`` to every counter; ``values`` must contain all counter names."""
        for name in self.data:
            self.data[name] += values[name]

    def avg_acc(self):
        """Return accumulated correct / accumulated total."""
        counters = self.data
        return counters['correct'] / counters['total']

    def avg_loss(self):
        """Return accumulated loss / accumulated total."""
        counters = self.data
        return counters['loss'] / counters['total']

    def save(self):
        """Open ``save_path`` for writing; nothing is written yet (TODO)."""
        with open(self.save_path, 'w'):
            # TODO: write the counters (e.g. with csv.writer)
            pass


def accuracy(output, target):
    """Count how often the arg-max over dim 1 of ``output`` equals ``target``.

    Returns:
        ``(correct, total, correct / total)`` where ``total`` is ``len(target)``.
    """
    with torch.no_grad():
        predicted = output.argmax(dim=1)
        assert predicted.shape[0] == len(target)
        hits = (predicted == target).sum().item()
    batch = len(target)
    return hits, batch, hits / batch


def top_k_acc(output, target, k=3):
    """Return the fraction of samples whose target is among the ``k`` highest scores (dim 1)."""
    with torch.no_grad():
        top_indices = output.topk(k, dim=1)[1]
        assert top_indices.shape[0] == len(target)
        # Check each of the k ranked columns against the targets and add the hits.
        hits = sum(torch.sum(top_indices[:, rank] == target).item() for rank in range(k))
    return hits / len(target)


## LOSS

In [0]:
def nll_loss(output, target):
    """Return ``F.nll_loss`` of ``output`` against ``target`` with default settings."""
    loss_value = F.nll_loss(input=output, target=target)
    return loss_value


def crossentropy_loss(output, target):
    """Return ``F.cross_entropy`` of ``output`` against ``target`` with default settings."""
    loss_value = F.cross_entropy(input=output, target=target)
    return loss_value

def focal_loss(output, target):
    """Return the mean focal loss with fixed ``alpha = 0.25`` and ``gamma = 2``.

    Per sample: ``alpha * (1 - p_t) ** gamma * ce`` where ``ce`` is the
    unreduced cross-entropy and ``p_t = exp(-ce)``.
    """
    alpha, gamma = 0.25, 2
    per_sample_ce = F.cross_entropy(output, target, reduction='none')
    confidence = torch.exp(-per_sample_ce)
    # Down-weights samples the model already classifies confidently.
    modulating = (1 - confidence) ** gamma
    return (alpha * modulating * per_sample_ce).mean()

## CNN models

In [0]:


class CNN(nn.Module):
    """ImageNet-pretrained torchvision backbone with a new ``classes``-way output layer.

    Args:
        classes: number of output classes.
        model: backbone name, either 'resnet18' or 'mobilenet2'.

    Raises:
        ValueError: if ``model`` is not one of the supported names.
    """

    def __init__(self, classes, model='resnet18'):
        super(CNN, self).__init__()
        if model == 'resnet18':
            self.cnn = models.resnet18(pretrained=True)
            self.cnn.fc = nn.Linear(512, classes)
        elif model == 'mobilenet2':
            # The previous code loaded resnext50_32x4d here and attached a
            # `classifier` attribute that ResNeXt never calls, so the network
            # kept its 1000-way ImageNet head. MobileNetV2 does route its
            # 1280-dimensional features through `classifier`.
            self.cnn = models.mobilenet_v2(pretrained=True)
            self.cnn.classifier = nn.Linear(1280, classes)
        else:
            # Fail at construction time instead of leaving `self.cnn` undefined.
            raise ValueError("Unsupported model '{}'; expected 'resnet18' or 'mobilenet2'".format(model))

    def forward(self, x):
        """Run ``x`` through the wrapped backbone and return its output."""
        return self.cnn(x)

##  COVID-NET

In [0]:



class Flatten(nn.Module):
    """Module that reshapes its input to ``(input.size(0), -1)`` using ``view``."""

    def forward(self, input):
        leading = input.shape[0]
        return input.view(leading, -1)


class PEXP(nn.Module):
    """Projection-expansion-projection-extension block built from five convolutions.

    Stages, applied in order with no activation in between:

    1. first-stage projection: 1x1 conv, ``n_input`` -> ``n_input // 4``
    2. expansion: 1x1 conv, ``n_input // 4`` -> ``n_input // 2``
    3. depth-wise representation: 3x3 conv with padding 1 and one group per
       channel, ``n_input // 2`` -> ``n_input // 2``
    4. second-stage projection: 1x1 conv, ``n_input // 2`` -> ``n_input // 4``
    5. extension: 1x1 conv, ``n_input // 4`` -> ``n_out``

    Args:
        n_input: number of input channels.
        n_out: number of output channels.
    """

    def __init__(self, n_input, n_out):
        super(PEXP, self).__init__()
        squeezed = n_input // 4
        expanded = n_input // 2
        # Layers are created in forward order so parameter names
        # (network.0 ... network.4) and initialisation order are unchanged.
        stages = [
            nn.Conv2d(in_channels=n_input, out_channels=squeezed, kernel_size=1),
            nn.Conv2d(in_channels=squeezed, out_channels=expanded, kernel_size=1),
            nn.Conv2d(in_channels=expanded, out_channels=expanded,
                      kernel_size=3, groups=expanded, padding=1),
            nn.Conv2d(in_channels=expanded, out_channels=squeezed, kernel_size=1),
            nn.Conv2d(in_channels=squeezed, out_channels=n_out, kernel_size=1),
        ]
        self.network = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the five stages to ``x``."""
        return self.network(x)


class CovidNet(nn.Module):
    """COVID-Net style classifier built from densely connected PEXP blocks.

    The network is a strided 7x7 stem convolution, four stages of PEXP blocks
    (3, 4, 6 and 3 blocks) separated by 2x2 max-pooling, and a three-layer
    fully connected head. Within a stage every block receives the sum of all
    earlier block outputs of that stage. The 'large' variant additionally adds
    a 1x1-convolution shortcut per stage.

    The head expects a 7x7x2048 feature map; with the overall down-sampling
    factor of 32 this corresponds to 224x224 inputs.

    Args:
        model: 'large' selects the variant with 1x1 shortcuts; any other value
            selects the small variant.
        n_classes: number of output logits.
    """

    def __init__(self, model='small', n_classes=3):
        super(CovidNet, self).__init__()
        # Block name -> [input channels, output channels].
        filters = {
            'pexp1_1': [64, 256],
            'pexp1_2': [256, 256],
            'pexp1_3': [256, 256],
            'pexp2_1': [256, 512],
            'pexp2_2': [512, 512],
            'pexp2_3': [512, 512],
            'pexp2_4': [512, 512],
            'pexp3_1': [512, 1024],
            'pexp3_2': [1024, 1024],
            'pexp3_3': [1024, 1024],
            'pexp3_4': [1024, 1024],
            'pexp3_5': [1024, 1024],
            'pexp3_6': [1024, 1024],
            'pexp4_1': [1024, 2048],
            'pexp4_2': [2048, 2048],
            'pexp4_3': [2048, 2048],
        }

        # Registration order is kept so state_dict keys remain compatible.
        self.add_module('conv1', nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3))
        for key, (n_in, n_out) in filters.items():
            self.add_module(key, PEXP(n_in, n_out))

        if model == 'large':
            self.add_module('conv1_1x1', nn.Conv2d(in_channels=64, out_channels=256, kernel_size=1))
            self.add_module('conv2_1x1', nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1))
            self.add_module('conv3_1x1', nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1))
            self.add_module('conv4_1x1', nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1))

            self.__forward__ = self.forward_large_net
        else:
            self.__forward__ = self.forward_small_net
        self.add_module('flatten', Flatten())
        self.add_module('fc1', nn.Linear(7 * 7 * 2048, 1024))

        self.add_module('fc2', nn.Linear(1024, 256))
        self.add_module('classifier', nn.Linear(256, n_classes))

    def forward(self, x):
        """Dispatch to the forward pass chosen at construction time."""
        return self.__forward__(x)

    def forward_large_net(self, x):
        """Forward pass of the 'large' variant (dense sums plus 1x1 shortcuts); returns logits."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        out_conv1_1x1 = self.conv1_1x1(x)

        pepx11 = self.pexp1_1(x)
        pepx12 = self.pexp1_2(pepx11 + out_conv1_1x1)
        pepx13 = self.pexp1_3(pepx12 + pepx11 + out_conv1_1x1)

        out_conv2_1x1 = F.max_pool2d(self.conv2_1x1(pepx12 + pepx11 + pepx13 + out_conv1_1x1), 2)

        pepx21 = self.pexp2_1(
            F.max_pool2d(pepx13, 2) + F.max_pool2d(pepx11, 2) + F.max_pool2d(pepx12, 2)
            + F.max_pool2d(out_conv1_1x1, 2))
        pepx22 = self.pexp2_2(pepx21 + out_conv2_1x1)
        pepx23 = self.pexp2_3(pepx22 + pepx21 + out_conv2_1x1)
        pepx24 = self.pexp2_4(pepx23 + pepx21 + pepx22 + out_conv2_1x1)

        out_conv3_1x1 = F.max_pool2d(self.conv3_1x1(pepx22 + pepx21 + pepx23 + pepx24 + out_conv2_1x1), 2)

        pepx31 = self.pexp3_1(
            F.max_pool2d(pepx24, 2) + F.max_pool2d(pepx21, 2) + F.max_pool2d(pepx22, 2)
            + F.max_pool2d(pepx23, 2) + F.max_pool2d(out_conv2_1x1, 2))
        pepx32 = self.pexp3_2(pepx31 + out_conv3_1x1)
        pepx33 = self.pexp3_3(pepx31 + pepx32 + out_conv3_1x1)
        pepx34 = self.pexp3_4(pepx31 + pepx32 + pepx33 + out_conv3_1x1)
        pepx35 = self.pexp3_5(pepx31 + pepx32 + pepx33 + pepx34 + out_conv3_1x1)
        pepx36 = self.pexp3_6(pepx31 + pepx32 + pepx33 + pepx34 + pepx35 + out_conv3_1x1)

        out_conv4_1x1 = F.max_pool2d(
            self.conv4_1x1(pepx31 + pepx32 + pepx33 + pepx34 + pepx35 + pepx36 + out_conv3_1x1), 2)

        # Every stage-3 output feeds stage 4 exactly once (pepx33 used to be
        # missing and pepx32 was added twice).
        pepx41 = self.pexp4_1(
            F.max_pool2d(pepx31, 2) + F.max_pool2d(pepx32, 2) + F.max_pool2d(pepx33, 2)
            + F.max_pool2d(pepx34, 2) + F.max_pool2d(pepx35, 2) + F.max_pool2d(pepx36, 2)
            + F.max_pool2d(out_conv3_1x1, 2))
        pepx42 = self.pexp4_2(pepx41 + out_conv4_1x1)
        pepx43 = self.pexp4_3(pepx41 + pepx42 + out_conv4_1x1)
        flattened = self.flatten(pepx41 + pepx42 + pepx43 + out_conv4_1x1)

        fc1out = F.relu(self.fc1(flattened))
        fc2out = F.relu(self.fc2(fc1out))
        logits = self.classifier(fc2out)
        return logits

    def forward_small_net(self, x):
        """Forward pass of the small variant (dense sums only, no 1x1 shortcuts); returns logits."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)

        pepx11 = self.pexp1_1(x)
        pepx12 = self.pexp1_2(pepx11)
        pepx13 = self.pexp1_3(pepx12 + pepx11)

        pepx21 = self.pexp2_1(F.max_pool2d(pepx13, 2) + F.max_pool2d(pepx11, 2) + F.max_pool2d(pepx12, 2))
        pepx22 = self.pexp2_2(pepx21)
        pepx23 = self.pexp2_3(pepx22 + pepx21)
        pepx24 = self.pexp2_4(pepx23 + pepx21 + pepx22)

        pepx31 = self.pexp3_1(
            F.max_pool2d(pepx24, 2) + F.max_pool2d(pepx21, 2) + F.max_pool2d(pepx22, 2)
            + F.max_pool2d(pepx23, 2))
        pepx32 = self.pexp3_2(pepx31)
        pepx33 = self.pexp3_3(pepx31 + pepx32)
        pepx34 = self.pexp3_4(pepx31 + pepx32 + pepx33)
        pepx35 = self.pexp3_5(pepx31 + pepx32 + pepx33 + pepx34)
        pepx36 = self.pexp3_6(pepx31 + pepx32 + pepx33 + pepx34 + pepx35)

        # Every stage-3 output feeds stage 4 exactly once (pepx33 used to be
        # missing and pepx32 was added twice).
        pepx41 = self.pexp4_1(
            F.max_pool2d(pepx31, 2) + F.max_pool2d(pepx32, 2) + F.max_pool2d(pepx33, 2)
            + F.max_pool2d(pepx34, 2) + F.max_pool2d(pepx35, 2) + F.max_pool2d(pepx36, 2))
        pepx42 = self.pexp4_2(pepx41)
        pepx43 = self.pexp4_3(pepx41 + pepx42)
        flattened = self.flatten(pepx41 + pepx42 + pepx43)

        fc1out = F.relu(self.fc1(flattened))
        fc2out = F.relu(self.fc2(fc1out))
        logits = self.classifier(fc2out)
        return logits







'''
 FORWARD ONLY WITH SKIP CONNECTIONS

    def forward(self, x):
        x = self.pool1(self.conv1(x))
        out_conv1_1x1 = self.conv1_1x1(x)

        pepx11 = self.pexp1_1(x)
        pepx12 = self.pexp1_2(pepx11)
        pepx13 = self.pexp1_3(pepx12 + pepx11)

        pepx21 = self.pexp2_1(F.max_pool2d(pepx13, 2) + F.max_pool2d(pepx11, 2) + F.max_pool2d(pepx12, 2))
        pepx22 = self.pexp2_2(pepx21)
        pepx23 = self.pexp2_3(pepx22 + pepx21)
        pepx24 = self.pexp2_4(pepx23 + pepx21 + pepx22)

        pepx31 = self.pexp3_1(F.max_pool2d(pepx24, 2) + F.max_pool2d(pepx21, 2) + F.max_pool2d(pepx22,2) + F.max_pool2d(pepx23, 2))
        pepx32 = self.pexp3_2(pepx31)
        pepx33 = self.pexp3_3(pepx31 + pepx32)
        pepx34 = self.pexp3_4(pepx31 + pepx32 + pepx33)
        pepx35 = self.pexp3_5(pepx31 + pepx32 + pepx33 + pepx34)
        pepx36 = self.pexp3_6(pepx31 + pepx32 + pepx33 + pepx34 + pepx35)

        pepx41 = self.pexp4_1(F.max_pool2d(pepx31, 2) + F.max_pool2d(pepx32, 2) + F.max_pool2d(pepx32, 2) + F.max_pool2d(pepx34, 2)+ F.max_pool2d(pepx35, 2)+ F.max_pool2d(pepx36, 2))
        pepx42 = self.pexp4_2(pepx41)
        pepx43 = self.pexp4_3(pepx41 + pepx42)
        flattened = self.flatten(pepx41 + pepx42 + pepx43)

        fc1out = self.fc1(flattened)
        fc2out = self.fc2(fc1out)
        logits = self.classifier(fc2out)
        return x


'''

## Dataloader


In [0]:


class COVIDxDataset(Dataset):
    """Dataset over the COVIDx train/test split files.

    Each item is ``(image_tensor, label_tensor)``: the image is loaded from
    ``<dataset_path>/<mode>/<filename>``, converted to RGB, resized to ``dim``
    and normalised with the ImageNet mean/std; the label is the integer class
    index from ``COVIDxDICT``.

    Args:
        mode: 'train' or 'test'; selects the split file and the image sub-folder.
        n_classes: stored as ``CLASSES``.
        dataset_path: directory containing the ``train/`` and ``test/`` folders.
        dim: target image size passed to ``PIL.Image.resize``.

    Raises:
        ValueError: if ``mode`` is neither 'train' nor 'test'.
    """

    def __init__(self, mode, n_classes=3, dataset_path='./datasets', dim=(224, 224)):
        if mode not in ('train', 'test'):
            # Previously this surfaced later as an AttributeError on self.paths.
            raise ValueError("mode must be 'train' or 'test', got {!r}".format(mode))
        self.root = str(dataset_path)+'/'+mode+'/'

        self.CLASSES = n_classes
        self.dim = dim
        self.COVIDxDICT = {'pneumonia': 0, 'normal': 1, 'COVID-19': 2}
        testfile = '/content/test_split_v2.txt'
        trainfile = '/content/train_split_v2.txt'
        split_file = trainfile if mode == 'train' else testfile
        self.paths, self.labels = read_filepaths(split_file)
        print("{} examples =  {}".format(mode,len(self.paths)))
        self.mode = mode

    def __len__(self):
        """Number of samples listed in the split file."""
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(image_tensor, label_tensor)`` for sample ``index``."""
        image_tensor = self.load_image(self.root+self.paths[index], self.dim, augmentation=self.mode)
        label_tensor = torch.tensor(self.COVIDxDICT[self.labels[index]],dtype=torch.long)

        return image_tensor,label_tensor

    def load_image(self, img_path, dim, augmentation='test'):
        """Load ``img_path`` as a normalised RGB tensor resized to ``dim``.

        ``augmentation`` is accepted for call compatibility; no augmentation
        is applied. A missing file is reported on stdout before ``Image.open``
        raises.
        """
        if not os.path.exists(img_path):
            print("IMAGE DOES NOT EXIST {}".format(img_path))
        # One RGB conversion is enough: resize() keeps the image mode.
        image = Image.open(img_path).convert('RGB').resize(dim)

        to_tensor = transforms.ToTensor()
        # ImageNet statistics, matching the pretrained torchvision backbones.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        image_tensor = normalize(to_tensor(image))
        return image_tensor


# Trainer functions

In [0]:

def initialize(args):
    """Build the model, optimizer and train/validation data loaders from ``args``.

    Reads ``args.device``, ``args.cuda``, ``args.batch_size``, ``args.classes``
    and ``args.dataset`` (plus whatever ``select_model`` / ``select_optimizer``
    read).

    Returns:
        ``(model, optimizer, training_generator, val_generator)``.
    """
    if args.device is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device)
    model = select_model(args)

    # Move the model to the GPU before creating the optimizer, as the
    # torch.optim documentation recommends.
    if args.cuda:
        model.cuda()
    optimizer = select_optimizer(args, model)

    train_params = {'batch_size': args.batch_size,
                    'shuffle': True,
                    'num_workers': 2}

    test_params = {'batch_size': args.batch_size,
                   'shuffle': False,
                   'num_workers': 1}

    # The 'test' split doubles as the validation set.
    train_set = COVIDxDataset(mode='train', n_classes=args.classes, dataset_path=args.dataset,
                              dim=(224, 224))
    val_set = COVIDxDataset(mode='test', n_classes=args.classes, dataset_path=args.dataset,
                            dim=(224, 224))
    training_generator = DataLoader(train_set, **train_params)
    val_generator = DataLoader(val_set, **test_params)
    return model, optimizer, training_generator, val_generator


def train(args, model, trainloader, optimizer, epoch):
    """Run one training epoch and return the metrics gathered during it.

    The loss that is back-propagated is ``focal_loss(output, target)``.

    Args:
        args: Namespace read for ``cuda`` and ``batch_size``; also forwarded
            to ``print_stats`` and ``print_summary``.
        model: Network to optimise; switched to train mode here.
        trainloader: Iterable yielding ``(input_data, target)`` batches.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, forwarded to the printing helpers.

    Returns:
        The ``Metrics`` object updated with every batch of this epoch.
    """
    model.train()

    metrics = Metrics('')
    metrics.reset()
    for batch_idx, (input_data, target) in enumerate(trainloader):
        optimizer.zero_grad()
        if args.cuda:
            input_data = input_data.cuda()
            target = target.cuda()

        output = model(input_data)

        loss = focal_loss(output, target)
        loss.backward()
        optimizer.step()

        correct, total, acc = accuracy(output, target)

        num_samples = batch_idx * args.batch_size + 1
        metrics.update({'correct': correct, 'total': total, 'loss': loss.item(), 'accuracy': acc})
        print_stats(args, epoch, num_samples, trainloader, metrics)

    # NOTE: num_samples is only bound once the loader has produced at least
    # one batch; an empty loader makes the call below fail.
    print_summary(args, epoch, num_samples, metrics, mode="Training")
    return metrics


def validation(args, model, testloader, epoch):
    """Evaluate the model on ``testloader`` with gradients disabled.

    The reported loss is ``focal_loss(output, target)``.

    Args:
        args: Namespace read for ``classes``, ``cuda`` and ``batch_size``;
            also forwarded to ``print_summary``.
        model: Network to evaluate; switched to eval mode here.
        testloader: Iterable yielding ``(input_data, target)`` batches.
        epoch: Epoch number, forwarded to ``print_summary``.

    Returns:
        Tuple ``(metrics, confusion_matrix)``. ``confusion_matrix`` is an
        ``args.classes`` x ``args.classes`` tensor in which entry ``[t, p]``
        counts the samples whose target is ``t`` and whose prediction is ``p``.
    """
    model.eval()

    metrics = Metrics('')
    metrics.reset()
    confusion_matrix = torch.zeros(args.classes, args.classes)
    with torch.no_grad():
        for batch_idx, (input_data, target) in enumerate(testloader):
            if args.cuda:
                input_data = input_data.cuda()
                target = target.cuda()

            output = model(input_data)

            loss = focal_loss(output, target)

            correct, total, acc = accuracy(output, target)
            num_samples = batch_idx * args.batch_size + 1

            # The predicted class is the index of the largest output per row.
            _, preds = torch.max(output, 1)
            for t, p in zip(target.cpu().view(-1), preds.cpu().view(-1)):
                confusion_matrix[t.long(), p.long()] += 1
            metrics.update({'correct': correct, 'total': total, 'loss': loss.item(), 'accuracy': acc})

    # NOTE: num_samples is only bound once the loader has produced at least
    # one batch; an empty loader makes the call below fail.
    print_summary(args, epoch, num_samples, metrics, mode="Validation")
    return metrics, confusion_matrix



# MAIN

In [0]:


def main():
    """Seed the RNGs, build the training objects and run the epoch loop.

    Each epoch trains, validates, hands the validation result to
    ``save_model`` (whose return value becomes the new best loss), prints the
    confusion matrix and steps the plateau scheduler on the validation loss.
    """
    args = get_arguments()

    # Reproducibility: seed NumPy and torch, and pin cuDNN to deterministic mode.
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    if args.cuda:
        torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model, optimizer, training_generator, val_generator = initialize(args)
    print(model)

    best_pred_loss = 1000.0
    scheduler = ReduceLROnPlateau(
        optimizer,
        factor=0.5,
        patience=2,
        min_lr=1e-5,
        verbose=True,
    )
    print('Checkpoint folder ', args.save)

    for epoch in range(1, 1 + args.nEpochs):
        train(args, model, training_generator, optimizer, epoch)

        validation_result = validation(args, model, val_generator, epoch)
        val_metrics, confusion_matrix = validation_result

        best_pred_loss = save_model(
            model, optimizer, args, val_metrics, epoch, best_pred_loss, confusion_matrix
        )

        matrix_as_array = confusion_matrix.cpu().numpy()
        print(matrix_as_array)

        epoch_val_loss = val_metrics.avg_loss()
        scheduler.step(epoch_val_loss)
        


def get_arguments(argv=None):
    """Build the argument parser and return the parsed run configuration.

    Args:
        argv: Optional list of command-line style tokens to parse. When
            omitted an empty list is parsed, so every option takes its
            default and the process's own ``sys.argv`` is never consulted.

    Returns:
        ``argparse.Namespace`` holding the options defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--log_interval', type=int, default=1000)
    parser.add_argument('--dataset_name', type=str, default="COVIDx")
    parser.add_argument('--nEpochs', type=int, default=250)
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--seed', type=int, default=123)
    parser.add_argument('--classes', type=int, default=3)
    parser.add_argument('--inChannels', type=int, default=1)
    # Help texts state the defaults that are actually configured.
    parser.add_argument('--lr', default=2e-5, type=float,
                        help='learning rate (default: 2e-5)')
    parser.add_argument('--weight_decay', default=1e-7, type=float,
                        help='weight decay (default: 1e-7)')
    # NOTE: store_true combined with default=True means this option is True
    # whether or not the flag is passed.
    parser.add_argument('--cuda', action='store_true', default=True)
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    # choices must be a real tuple: a bare string makes argparse fall back to
    # substring matching, which would reject the default model name.
    # NOTE(review): extend the tuple if select_model accepts further names.
    parser.add_argument('--model', type=str, default='COVIDNet_large',
                        choices=('COVIDNet_small', 'COVIDNet_large'))
    parser.add_argument('--opt', type=str, default='adam',
                        choices=('sgd', 'adam', 'rmsprop'))
    parser.add_argument('--dataset', type=str, default='/content/covid-chestxray-dataset/data',
                        help='path to dataset ')
    # NOTE(review): the default folder name says "small" while the default
    # model is "large"; left untouched so existing checkpoint paths still work.
    parser.add_argument('--save', type=str, default='/content/drive/My Drive/MEDICAL/saved/COVIDNet_small'+datestr() ,
                        help='path to checkpoint ')
    args = parser.parse_args([] if argv is None else argv)
    return args


# Start the training run only when this module is the entry point.
if __name__ == "__main__":
    main()
