# Detecting COVID-19 With Chest X Ray Using Resnet 18, densenet and loTenet In PyTorch

Image classification of Chest X Rays in one of three classes: Normal, Viral Pneumonia, COVID-19

Dataset from COVID-19 Radiography Dataset on Kaggle

# Importing Libraries

In [1]:
%matplotlib inline

import os
import shutil
import random
import torch
import torchvision
import numpy as np
import time
import torch
from models.lotenet import loTeNet
from torchvision import transforms, datasets
import pdb
from PIL import Image
from matplotlib import pyplot as plt
from models.Densenet import *
import argparse
torch.manual_seed(0)

print('Using PyTorch version', torch.__version__)

Using PyTorch version 1.3.1


# Preparing Training and Test Sets

In [2]:
class_names = ['normal', 'viral', 'covid']
root_dir = '/home/mashjunior/loTeNet_pytorch/Covid-19_images/COVID-19_Radiography/COVID-19 Radiography Database'
source_dirs = ['normal', 'viral', 'covid']


#if os.path.isdir(os.path.join(root_dir, source_dirs[1])):
#    os.mkdir(os.path.join(root_dir, 'valid'))

#    for i, d in enumerate(source_dirs):
#        os.rename(os.path.join(root_dir, d), os.path.join(root_dir, class_names[i]))

#    for c in class_names:
#        os.mkdir(os.path.join(root_dir, 'valid', c))

#    for c in class_names:
#        images = [x for x in os.listdir(os.path.join(root_dir, c)) if x.lower().endswith('png')]
#        selected_images = random.sample(images, 400)
#        for image in selected_images:
#            source_path = os.path.join(root_dir, c, image)
#            target_path = os.path.join(root_dir, 'valid', c, image)
#            shutil.move(source_path, target_path)

# Creating Custom Dataset

In [3]:
class ChestXRayDataset(torch.utils.data.Dataset):
    def __init__(self, image_dirs,transform):
        
        def get_images(class_name):
            images = [x for x in os.listdir(image_dirs[class_name]) if x.lower().endswith('png')]
            print(f'Found {len(images)}{class_name}')
            return images
        
        
        self.images={}
        self.class_names=['normal','viral','covid']
        
        for c in self.class_names:
            self.images[c]=get_images(c)
        self.image_dirs=image_dirs
        self.transform=transform
        
        
        
    def __len__(self):
        return sum([len(self.images[c]) for c in self.class_names])
    
    
    def __getitem__(self, index):
        class_name=random.choice(self.class_names)
        index=index%len(self.images[class_name])
        image_name=self.images[class_name][index]
        image_path =os.path.join(self.image_dirs[class_name], image_name)
        image=Image.open(image_path).convert('L')
        image = self.transform(image)
        image = image.type(torch.FloatTensor)/255.0
        print(image, image.shape)

        return image, self.class_names.index(class_name)

In [4]:
class LIDC(Dataset):
	def __init__(self, rater=4, split='Train', data_dir = '/home/mashjunior/loTeNet_pytorch', transform=None):
		super().__init__()

		self.data_dir = data_dir
		self.rater = rater
		self.transform = transform
		self.data, self.targets = torch.load(data_dir+split+'.pt')
		self.targets = self.targets.type(torch.FloatTensor)		   
	def __len__(self):
		return len(self.targets)

	def __getitem__(self, index):

		image, label = self.data[index], self.targets[index]
		if self.rater == 4:
			label = (label.sum() > 2).type_as(self.targets)
		else:
			label = label[self.rater]
		image = image.type(torch.FloatTensor)/255.0
		if self.transform is not None:
			image = self.transform(image)
		return image, label

NameError: name 'Dataset' is not defined

In [5]:
#images, labels =next(iter(dl_test))
#show_images(images, labels, labels)

# Tensor Model 

In [6]:
# Globally load device identifier
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
def evaluate(loader):
     ### Evaluation funcntion for validation/testing

    with torch.no_grad():
        vl_acc = 0.
        vl_loss = 0.
        labelsNp = np.zeros(1)
        predsNp = np.zeros(1)
        model.eval()

        for i, (inputs, labels) in enumerate(loader):

            inputs = inputs.to(device)
            labels = labels.to(device)
            labelsNp = np.concatenate((labelsNp, labels.cpu().numpy()))

            # Inference
            scores = torch.sigmoid(model(inputs))

            preds = scores
            loss = loss_fun(scores, labels)
            predsNp = np.concatenate((predsNp, preds.cpu().numpy()))
            vl_loss += loss.item()

        # Compute AUC over the full (valid/test) set
        vl_acc = computeAuc(labelsNp[1:],predsNp[1:])
        vl_loss = vl_loss/len(loader)

    return vl_acc, vl_loss

In [8]:
# Miscellaneous initialization
torch.manual_seed(1)
start_time = time.time()

In [9]:
parser = argparse.ArgumentParser()

In [10]:
parser.add_argument('--num_epochs', type=int, default=5, help='Number of training epochs')
parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
parser.add_argument('--lr', type=float, default=5e-4, help='Learning rate')
parser.add_argument('--l2', type=float, default=0, help='L2 regularisation')
parser.add_argument('--aug', action='store_true', default=False, help='Use data augmentation')
parser.add_argument('--data_path', type=str, default=root_dir,help='Path to data.')
parser.add_argument('--bond_dim', type=int, default=5, help='MPS Bond dimension')
parser.add_argument('--nChannel', type=int, default=1, help='Number of input channels')

_StoreAction(option_strings=['--nChannel'], dest='nChannel', nargs=None, const=None, default=1, type=<class 'int'>, choices=None, help='Number of input channels', metavar=None)

In [11]:
args = parser.parse_args([])

In [12]:
batch_size = args.batch_size

# LoTeNet parameters
adaptive_mode = False 
periodic_bc   = False

kernel = 2 # Stride along spatial dimensions
output_dim = 1 # output dimension
 
feature_dim = 2

#logFile = time.strftime("%Y%m%d_%H_%M")+'.txt'
#makeLogFile(logFile)

normTensor = 0.5*torch.ones(args.nChannel)
### Data processing and loading....

In [13]:
### Data processing and loading....
valid_transform = transforms.Compose([transforms.Normalize(mean=normTensor,std=normTensor)])

train_transform = transforms.Compose([transforms.Resize(128),
                                      transforms.RandomHorizontalFlip(),
                                      transforms.RandomVerticalFlip(),
                                      transforms.RandomRotation(20),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=normTensor,std=normTensor)])
test_transform = valid_transform

# data loader

In [14]:
train_dirs = {
    'normal': root_dir + '/normal',
    'viral':  root_dir + '/viral',
    'covid': root_dir + '/covid'
}
train_dataset=ChestXRayDataset(train_dirs, train_transform)

Found 911normal
Found 915viral
Found 770covid


In [15]:
valid_dirs = {
    'normal': root_dir + '/valid/normal',
    'viral': root_dir + '/valid/viral',
    'covid': root_dir + '/valid/covid'
}

valid_dataset = ChestXRayDataset(valid_dirs, valid_transform)

Found 400normal
Found 400viral
Found 400covid


In [16]:
test_dirs = {
    'normal': root_dir + '/test/normal',
    'viral': root_dir + '/test/viral',
    'covid': root_dir + '/test/covid'
}

test_dataset = ChestXRayDataset(test_dirs, valid_transform)

Found 30normal
Found 30viral
Found 30covid


In [17]:
train_dataset.images

{'normal': ['NORMAL (812).png',
  'NORMAL (230).png',
  'NORMAL (164).png',
  'NORMAL (296).png',
  'NORMAL (36).png',
  'NORMAL (767).png',
  'NORMAL (519).png',
  'NORMAL (729).png',
  'NORMAL (907).png',
  'NORMAL (590).png',
  'NORMAL (1099).png',
  'NORMAL (831).png',
  'NORMAL (679).png',
  'NORMAL (1171).png',
  'NORMAL (280).png',
  'NORMAL (105).png',
  'NORMAL (963).png',
  'NORMAL (112).png',
  'NORMAL (63).png',
  'NORMAL (688).png',
  'NORMAL (540).png',
  'NORMAL (262).png',
  'NORMAL (425).png',
  'NORMAL (44).png',
  'NORMAL (141).png',
  'NORMAL (436).png',
  'NORMAL (488).png',
  'NORMAL (1216).png',
  'NORMAL (986).png',
  'NORMAL (871).png',
  'NORMAL (1002).png',
  'NORMAL (584).png',
  'NORMAL (1070).png',
  'NORMAL (936).png',
  'NORMAL (189).png',
  'NORMAL (1271).png',
  'NORMAL (1251).png',
  'NORMAL (1172).png',
  'NORMAL (832).png',
  'NORMAL (815).png',
  'NORMAL (494).png',
  'NORMAL (484).png',
  'NORMAL (453).png',
  'NORMAL (1042).png',
  'NORMAL (1203)

# data visualization

In [18]:
class_names=train_dataset.class_names
def show_images(images, labels, preds):
    plt.figure(figsize=(8,4))
    for i, image in enumerate(images):
        plt.subplot(1,6,i+1, xticks=[], yticks=[])
        image=image.numpy().transpose((1,2,0))
        mean=np.array([0.485,0.456,0.406])
        std= np.array([0.229, 0.224, 0.225])
        image=image*std/mean
        image=np.clip(image,0.,1.)
        plt.imshow(image)
        col = 'green' if preds[i]==labels[i] else 'red'
        plt.xlabel(f'{class_names[int(labels[i].numpy())]}')
        plt.ylabel(f'{class_names[int(preds[i].numpy())]}', color=col)
    plt.tight_layout()
    plt.show()

In [19]:
#batch_size=6
dl_train = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dl_valid = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
dl_test = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
print('Num of training batches', len(dl_train))
print('Num of validation batches', len(dl_valid))
print('Num of test batches', len(dl_test))

Num of training batches 21
Num of validation batches 10
Num of test batches 1


In [None]:
# Initiliaze input dimensions
dim = torch.ShortTensor(list(train_dataset[0][0].shape[1:]))
nCh = int(train_dataset[0][0].shape[0])

In [None]:
print(dim)

In [None]:
print(nCh)

In [None]:
train_dataset[0][0][0][0][0]

In [None]:
model = DenseNet(depth=40, growthRate=12, reduction=0.5,bottleneck=True,nClasses=output_dim)

In [None]:
# Choose loss function and optimizer
loss_fun = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, 
                             weight_decay=args.l2)

In [None]:
nParam = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of parameters:%d"%(nParam))
print(f"Maximum MPS bond dimension = {args.bond_dim}")

print("Bond dim: %d"%(args.bond_dim))
print("Number of parameters:%d"%(nParam),)


In [None]:
print(f"Using Adam w/ learning rate = {args.lr:.1e}")
print("Feature_dim: %d, nCh: %d, B:%d"%(feature_dim,nCh,batch_size))

model = model.to(device)
nValid = len(dl_valid)
nTrain = len(dl_train)
nTest = len(dl_test)

maxAuc = 0
minLoss = 1e3
convCheck = 5
convIter = 0

In [None]:
# Let's start training!
for epoch in range(args.num_epochs):
    running_loss = 0.
    running_acc = 0.
    #t = time.time()
    model.train()
    predsNp = np.zeros(1)
    labelsNp = np.zeros(1)

    for i, (inputs, labels) in enumerate(dl_train):

        inputs = inputs.to(device)
        labels = labels.to(device)
        labelsNp = np.concatenate((labelsNp, labels.cpu().numpy()))

        scores = torch.sigmoid(model(inputs))

        preds = scores
        loss = loss_fun(scores, labels)

        with torch.no_grad():
            predsNp = np.concatenate((predsNp, preds.detach().cpu().numpy()))
            running_loss += loss

        # Backpropagate and update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 5 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, args.num_epochs, i+1, nTrain, loss.item()))

    accuracy = computeAuc(labelsNp,predsNp)

    # Evaluate on Validation set 
    with torch.no_grad():

        vl_acc, vl_loss = evaluate(dl_valid)
        if vl_acc > maxAuc or vl_loss < minLoss:
            if vl_loss < minLoss:
                minLoss = vl_loss
            if vl_acc > maxAuc:
                ### Predict on test set
                ts_acc, ts_loss = evaluate(dl_test)
                maxAuc = vl_acc
                print('New Max: %.4f'%maxAuc)
                print('Test Set Loss:%.4f	Auc:%.4f'%(ts_loss, ts_acc))
                print('Test Set Loss:%.4f	Auc:%.4f'%(ts_loss, ts_acc))
            convEpoch = epoch
            convIter = 0
        else:
            convIter += 1
        if convIter == convCheck:
            if not args.dense_net:
                print("MPS")
            else:
                print("DenseNet")
            print("Converged at epoch:%d with AUC:%.4f"%(convEpoch+1,maxAuc))

            break
#	writeLog(logFile, epoch, running_loss/nTrain, accuracy,
#			vl_loss, vl_acc, time.time()-t)