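# Final Supervised Training on the Binary COVID-CT Classification Task
Fine-tunes a ResNet18 classifier initialised from MoCo-pretrained encoders and compares test-set performance across four pretraining setups.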

In [1]:
import torch, gc
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import os
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import StepLR
import numpy as np
from datetime import datetime
import pandas as pd
import random 
from torchvision.datasets import ImageFolder
from torchvision.models import resnet
import re
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import roc_auc_score
from skimage.io import imread, imsave
import skimage
from PIL import ImageFile
from PIL import Image
import argparse
from functools import partial

# Release memory held over from a previous run in this kernel.
gc.collect()
torch.cuda.empty_cache()

# Seed the Python, NumPy and PyTorch RNGs (used for shuffling, augmentation and
# weight init) so reruns are more repeatable; cuDNN kernels may still be non-deterministic.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

### Check GPU Info ### 

In [2]:
#show gpu info
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0))
torch.cuda.current_device()
device = 'cuda'

gpu_info = !nvidia-smi -i 0
gpu_info = '\n'.join(gpu_info)
print(gpu_info)

True
1
Tesla V100-SXM2-16GB
Wed Apr 20 14:39:56 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  On   | 00000000:18:00.0 Off |                    0 |
| N/A   45C    P0    44W / 300W |      0MiB / 16160MiB |      0%   E. Process |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------

### Define All Arguments ###

In [3]:
parser = argparse.ArgumentParser(description='Final Training Classifier on COVIDCT')

parser.add_argument('-a', '--arch', default='resnet18')
parser.add_argument('-bs','--batch-size', default=64, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--epochs', default=500, type=int, metavar='N', help='number of total epochs to run')
parser.add_argument('--results-dir', default='', type=str, metavar='PATH', help='path to cache (default: none)')
parser.add_argument('--resume',default=False,help='if resume training ')
parser.add_argument('--start-epoch',default = 1, type=int)
parser.add_argument('--total-epoch',default = 100, type=int)
parser.add_argument('--model-load',default = '',help='pretrained model file path')
args = parser.parse_args('')


args.results_dir = './Covid-cache/cache-' + 'train-Cifar-Luna-Covid'
args.model_load = './Covid-cache/cache-moco-Cifar-Luna-Covid/model_last.pth'
print(args)

Namespace(arch='resnet18', batch_size=64, epochs=500, model_load='./Covid-cache/cache-moco-Cifar-Luna-Covid/model_last.pth', results_dir='./Covid-cache/cache-train-Cifar-Luna-Covid', resume=False, start_epoch=1, total_epoch=100)


### Evaluation Metric ###

In [4]:
#define metric function 

def metric(predlist,tarlist,scorelist):
    """Confusion-matrix counts and summary metrics for a binary classifier.

    Label 1 is treated as the positive class.

    Args:
        predlist: hard predictions, one per sample (0 or 1; any other value is
            not counted in any confusion-matrix cell).
        tarlist: ground-truth labels (0 or 1), aligned with `predlist`.
        scorelist: score for class 1 per sample, passed to `roc_auc_score`.

    Returns:
        (TP, TN, FN, FP, p, r, F1, acc, AUC) with p = precision and r = recall.
        p, r, F1 and acc are reported as 0.0 when their denominator is zero
        (e.g. no positive predictions) instead of raising ZeroDivisionError.
        AUC comes directly from `roc_auc_score`, which raises if `tarlist`
        contains only one class.
    """
    preds = np.asarray(predlist).ravel()
    targets = np.asarray(tarlist).ravel()

    # Vectorised counts over the whole arrays instead of a per-element Python loop.
    TP = int(np.count_nonzero((preds == 1) & (targets == 1)))
    TN = int(np.count_nonzero((preds == 0) & (targets == 0)))
    FN = int(np.count_nonzero((preds == 0) & (targets == 1)))
    FP = int(np.count_nonzero((preds == 1) & (targets == 0)))

    p = TP / (TP + FP) if TP + FP else 0.0
    r = TP / (TP + FN) if TP + FN else 0.0
    F1 = 2 * r * p / (r + p) if r + p else 0.0
    total = TP + TN + FP + FN
    acc = (TP + TN) / total if total else 0.0
    AUC = roc_auc_score(tarlist,scorelist)
    return TP,TN,FN,FP,p,r,F1,acc,AUC

### The Image Augmentation and DataLoader ###

In [5]:
# Image augmentation for the training set and deterministic preprocessing for val & test.
# A single mean/std value is shared by all three channels.
normalize = transforms.Normalize(mean=[0.45271412] * 3, std=[0.33165374] * 3)

# Training: resize the shorter side to 128, take a random crop covering 50-100% of the
# area and resize it to 64x64, random horizontal flip, mild brightness/contrast jitter.
train_transformer = transforms.Compose(
    [
        transforms.Resize(128),
        transforms.RandomResizedCrop(64, scale=(0.5, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        normalize,
    ]
)

# Validation / test: resize straight to 64x64 (aspect ratio is not preserved).
val_transformer = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        normalize,
    ]
)

In [6]:

def read_txt(txt_path):
    """Return the non-blank lines of a text file, stripped of surrounding whitespace.

    Blank or whitespace-only lines (e.g. a stray empty line at the end of a split
    file) are skipped so they don't turn into bogus image paths.
    """
    with open(txt_path) as f:
        return [line.strip() for line in f if line.strip()]

class CovidCTDataset(Dataset):
    """COVID-CT image dataset built from one split file per class.

    Each non-empty line of a split file names one image inside the matching class
    directory under `root_dir`. The label is the class's index in `self.classes`:
    0 = 'CT_COVID', 1 = 'CT_NonCOVID'. NOTE(review): `metric` counts label 1 as
    the positive class, so precision/recall/F1 are reported for CT_NonCOVID.
    """

    def __init__(self, root_dir, txt_COVID, txt_NonCOVID, transform=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            txt_COVID (string): Path to the txt file listing CT_COVID image names, one per line.
            txt_NonCOVID (string): Path to the txt file listing CT_NonCOVID image names, one per line.
            transform: optional callable applied to each RGB PIL image (augmentation / tensor conversion).
        File structure:
        - root_dir
            - CT_COVID
                - img1.png
                - img2.png
                - ......
            - CT_NonCOVID
                - img1.png
                - img2.png
                - ......
        """
        self.root_dir = root_dir
        self.txt_path = [txt_COVID, txt_NonCOVID]
        self.classes = ['CT_COVID', 'CT_NonCOVID']
        self.num_cls = len(self.classes)
        # Flat list of [image_path, label] pairs, COVID entries first, then NonCOVID.
        self.img_list = []
        for label, class_dir in enumerate(self.classes):
            self.img_list += [[os.path.join(self.root_dir, class_dir, item), label]
                              for item in read_txt(self.txt_path[label])]
        self.transform = transform

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        """Return {'img': transformed image, 'label': int class index} for sample `idx`."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path, label = self.img_list[idx]
        # Context manager releases the file handle as soon as the pixels are decoded.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)
        return {'img': image,
                'label': int(label)}



    
if __name__ == '__main__':
    # Image folders (CT_COVID / CT_NonCOVID) and the per-split file lists.
    # Override with environment variables when the data lives elsewhere.
    data_root = os.environ.get('COVIDCT_DATA_ROOT',
                               '/projectnb/dl523/projects/COVIDCT2/dataset/COVIDCT')
    split_root = os.environ.get('COVIDCT_SPLIT_ROOT',
                                '/projectnb/dl523/projects/COVIDCT2/dataset/COVID-Data-split')

    def _make_split(split, transform):
        """Build the CovidCTDataset for `split` ('train', 'val' or 'test') from its two split files."""
        return CovidCTDataset(root_dir=data_root,
                              txt_COVID=os.path.join(split_root, 'COVID', split + 'CT_COVID.txt'),
                              txt_NonCOVID=os.path.join(split_root, 'NonCOVID', split + 'CT_NonCOVID.txt'),
                              transform=transform)

    trainset = _make_split('train', train_transformer)
    valset = _make_split('val', val_transformer)
    testset = _make_split('test', val_transformer)
    print(len(trainset))
    print(len(valset))
    print(len(testset))

    # drop_last=True gives every training batch exactly args.batch_size samples;
    # SplitBatchNorm's training-mode reshape needs the batch size divisible by its split count.
    num_workers = 16
    train_loader = DataLoader(trainset, batch_size=args.batch_size, drop_last=True, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(valset, batch_size=args.batch_size, drop_last=False, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(testset, batch_size=args.batch_size, drop_last=False, shuffle=False, num_workers=num_workers)
    

425
118
203


### Train, Validation and Test Function ###

In [7]:
def train(optimizer, epoch, log_interval=3):
    """Run one training epoch of the global `model` over `train_loader`.

    Args:
        optimizer: optimizer over `model.parameters()`.
        epoch: epoch number, used only in progress messages.
        log_interval: print a progress line every `log_interval` batches.

    Returns:
        Mean cross-entropy loss per training sample over the epoch, where each
        batch's loss is measured before that batch's optimizer step.
    """
    model.train()
    # The loss module holds no state, so build it once per epoch rather than per batch.
    criteria = nn.CrossEntropyLoss()

    train_loss = 0
    total_num = 0
    n_batches = len(train_loader)

    for batch_index, batch_samples in enumerate(train_loader):
        data = batch_samples['img'].to(device)
        target = batch_samples['label'].to(device).long()

        optimizer.zero_grad()
        output = model(data)
        loss = criteria(output, target)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss averages over the batch; weight by batch size for an exact per-sample mean.
        total_num += target.size(0)
        train_loss += loss.item() * target.size(0)

        if batch_index % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}'.format(
                epoch, batch_index, n_batches,
                100.0 * batch_index / n_batches, loss.item()))
    return train_loss / total_num

In [8]:
def val(epoch):
    """Score the global `model` on `val_loader` with gradients disabled.

    `epoch` is accepted for symmetry with `train` but is not used.

    Returns a 4-tuple:
        - the mean cross-entropy per validation sample;
        - the targets;
        - the softmax probability of class 1;
        - the argmax predictions.
    The last three are flattened into float64 numpy arrays in loader order.
    """
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    weighted_loss = 0
    seen = 0
    pred_chunks, score_chunks, target_chunks = [], [], []

    with torch.no_grad():
        for batch in val_loader:
            images = batch['img'].to(device)
            labels = batch['label'].to(device)
            logits = model(images)

            n = labels.size(0)
            seen += n
            # CrossEntropyLoss averages within the batch; re-weight by batch size
            # so the final division gives an exact per-sample mean.
            weighted_loss += loss_fn(logits, labels.long()).item() * n

            probs = F.softmax(logits, dim=1)
            pred_chunks.append(logits.argmax(dim=1).cpu().numpy())
            score_chunks.append(probs.cpu().numpy()[:, 1])
            target_chunks.append(labels.long().cpu().numpy())

    def _flat(chunks):
        # Join the per-batch arrays into one flat float64 array.
        return np.concatenate(chunks).astype(np.float64) if chunks else np.array([])

    return weighted_loss / seen, _flat(target_chunks), _flat(score_chunks), _flat(pred_chunks)
    

In [9]:
def test(epoch):
    
    model.eval()
    test_loss = 0
    total_num = 0 
    correct = 0
    results = []
    
    criteria = nn.CrossEntropyLoss()
    with torch.no_grad():
        
        predlist=[]
        scorelist=[]
        targetlist=[]
        for batch_index, batch_samples in enumerate(test_loader):
            data, target = batch_samples['img'].to(device), batch_samples['label'].to(device)
            output = model(data)
            test_loss += criteria(output, target.long())
            
            total_num += target.size(0)
            test_loss += criteria(output, target.long()).item()/target.size(0)
            
            score = F.softmax(output, dim=1)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.long().view_as(pred)).sum().item()
            targetcpu=target.long().cpu().numpy()
            predlist=np.append(predlist, pred.cpu().numpy())
            scorelist=np.append(scorelist, score.cpu().numpy()[:,1])
            targetlist=np.append(targetlist,targetcpu)
            
    return test_loss/total_num, targetlist, scorelist, predlist

### ResNet18 Backbone ### 

**SplitBatchNorm**: simulates multi-GPU batch normalization on a single GPU by computing batch statistics separately over several sub-batches of each training batch.

In [10]:
# Building blocks of the modified ResNet18 model.
class SplitBatchNorm(nn.BatchNorm2d):
    """BatchNorm2d that, in training mode, normalizes `num_splits` sub-batches independently.

    This emulates per-device batch statistics of multi-GPU training on a single GPU.
    In eval mode (with running stats tracked) it behaves like plain BatchNorm2d.

    Args:
        num_features: number of channels C.
        num_splits: number of sub-batches S. In training mode the batch size N must
            be divisible by S for the `view` in `forward` to succeed.
        **kw: forwarded to nn.BatchNorm2d (e.g. eps, momentum).
    """
    def __init__(self, num_features, num_splits, **kw):
        super().__init__(num_features, **kw)
        self.num_splits = num_splits
        
    def forward(self, input):
        """Normalize an (N, C, H, W) input; see the class docstring for train vs. eval behavior."""
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            # Tile the (C,) running stats and affine params S times so that each
            # sub-batch gets its own copy of the C channels.
            running_mean_split = self.running_mean.repeat(self.num_splits)
            running_var_split = self.running_var.repeat(self.num_splits)
            # view(-1, C*S, H, W) folds S samples into the channel axis: channel block j
            # holds samples j, j+S, j+2S, ..., so each block is normalized with its own
            # batch statistics. training=True also updates the tiled running stats in place.
            outcome = nn.functional.batch_norm(
                input.view(-1, C * self.num_splits, H, W), running_mean_split, running_var_split, 
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W)
            # Average the S per-split running stats back into the shared (C,) buffers.
            self.running_mean.data.copy_(running_mean_split.view(self.num_splits, C).mean(dim=0))
            self.running_var.data.copy_(running_var_split.view(self.num_splits, C).mean(dim=0))
            return outcome
        else:
            # Eval: ordinary batch norm using the stored running statistics (any batch size works).
            return nn.functional.batch_norm(
                input, self.running_mean, self.running_var, 
                self.weight, self.bias, False, self.momentum, self.eps)


class ModelBase(nn.Module):
    """
    Common CIFAR ResNet recipe.
    Comparing with ImageNet ResNet recipe, it:
    (i) replaces conv1 with kernel=3, str=1
    (ii) removes pool1

    Args:
        feature_dim: output size of the final Linear layer (passed to torchvision as
            num_classes). NOTE(review): this notebook builds the model with the default
            128 and feeds its raw output to CrossEntropyLoss against 0/1 labels, i.e. a
            128-way head is used as the binary classifier. Confirm this is intended.
        arch: name of a constructor in torchvision.models.resnet, e.g. 'resnet18'.
        bn_splits: if > 1, every norm layer is a SplitBatchNorm with this many splits;
            otherwise nn.BatchNorm2d.

    The layers are stored in `self.net` (an nn.Sequential), so state_dict keys look
    like 'net.<index>...'. Changing which layers are appended, or their order,
    changes those keys and breaks loading of saved checkpoints.
    """
    def __init__(self, feature_dim=128, arch=None, bn_splits=16):
        super(ModelBase, self).__init__()

        # use split batchnorm
        norm_layer = partial(SplitBatchNorm, num_splits=bn_splits) if bn_splits > 1 else nn.BatchNorm2d
        resnet_arch = getattr(resnet, arch)
        net = resnet_arch(num_classes=feature_dim, norm_layer=norm_layer)

        # Rebuild the torchvision model as a flat Sequential, child by child.
        self.net = []
        for name, module in net.named_children():
            if name == 'conv1':
                # Replace the stem conv with a 3x3, stride-1 conv (64 output channels).
                module = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            if isinstance(module, nn.MaxPool2d):
                # Drop the stem max-pool entirely.
                continue
            if isinstance(module, nn.Linear):
                # Flatten the pooled feature map to (N, features) before the Linear head.
                self.net.append(nn.Flatten(1))
            self.net.append(module)

        self.net = nn.Sequential(*self.net)

    def forward(self, x):
        """Return the raw output of the final Linear layer for input images x."""
        x = self.net(x)
        # note: not normalized here
        return x

### Load the MoCo-Pretrained Encoder and Start Final Training of the Classifier ###

In [11]:
# Build the classifier backbone and initialise it from the MoCo query encoder.
model = ModelBase(arch='resnet18', bn_splits=8)
model.to(device)

# map_location keeps the loaded tensors on the same device as the model.
checkpoint = torch.load(args.model_load, map_location=device)
pretrained_state = checkpoint['state_dict']

# The checkpoint is expected to hold the query encoder's weights under the 'encoder_q.'
# prefix. state_dict() tensors share storage with the model, so copy_ writes into it directly.
model_state = model.state_dict()
loaded_keys = []
for key, tensor in model_state.items():
    source_key = 'encoder_q.' + key
    if source_key in pretrained_state:
        tensor.copy_(pretrained_state[source_key].data)
        loaded_keys.append(key)

# Report how much was transferred: a prefix mismatch would otherwise silently leave
# the model randomly initialised.
print('Loaded {}/{} tensors from {}'.format(len(loaded_keys), len(model_state), args.model_load))
if not loaded_keys:
    print('WARNING: no encoder_q.* weights matched; the model is still randomly initialised.')

    


In [12]:
# train
args.resume = False
args.start_epoch = 1

# Validation metrics are computed from a majority vote over the predictions of the
# last `votenum` epochs (and the mean of their class-1 scores). Each logged row
# therefore summarises `votenum` consecutive model states, not the saved snapshot alone.
votenum = 10

train_loss = np.zeros(1)
val_loss = np.zeros(1)
vote_pred = np.zeros(len(valset))
vote_score = np.zeros(len(valset))

optimizer = optim.Adam(model.parameters(), lr=0.0001)

# makedirs(exist_ok=True) also creates a missing parent directory and is safe to re-run.
os.makedirs(args.results_dir, exist_ok=True)

# for resuming
if args.resume:
    # Convert the logged CSV back into a dict of lists so the `.append` calls below
    # extend it (pandas Series do not support in-place append).
    results = pd.read_csv(args.results_dir + '/train_val_log.csv', index_col=0).to_dict(orient='list')
    checkpoint = torch.load(args.results_dir + '/model_last.pth', map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Continue right after the saved epoch instead of restarting the count.
    args.start_epoch = checkpoint['epoch'] + 1
else:
    # 'val_accurary' (sic) is kept so existing train_val_log.csv files still line up on resume.
    results = {'epoch': [], 'train_loss': [], 'val_loss': [], 'val_accurary': [], 'val_F1': [], 'val_AUC': []}

for epoch in range(args.start_epoch, args.total_epoch + 1):
    trainloss = train(optimizer, epoch)
    valloss, targetlist, scorelist, predlist = val(epoch)

    train_loss += trainloss
    val_loss += valloss
    vote_pred += predlist
    vote_score += scorelist

    if epoch % votenum == 0:
        # Majority vote: predict 1 only if more than half of the epochs did (ties -> 0).
        vote_pred[vote_pred <= (votenum / 2)] = 0
        vote_pred[vote_pred > (votenum / 2)] = 1
        vote_score = vote_score / votenum

        TP, TN, FN, FP, p, r, F1, acc, AUC = metric(vote_pred, targetlist, vote_score)

        avg_train_loss = (train_loss / votenum).item()
        avg_val_loss = (val_loss / votenum).item()
        results['epoch'].append(epoch)
        results['train_loss'].append(avg_train_loss)
        results['val_loss'].append(avg_val_loss)
        results['val_accurary'].append(acc)
        results['val_F1'].append(F1)
        results['val_AUC'].append(AUC)
        pd.DataFrame(data=results).to_csv(args.results_dir + '/train_val_log.csv')

        # Save the most recent model.
        torch.save({'epoch': epoch, 'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(), }, args.results_dir + '/model_last.pth')

        print('\n The epoch is {}, {} epoch avg: train loss: {:.4f}, val loss: {:.4f}, acc: {:.4f}, F1: {:.4f}, AUC: {:.4f}'.format(
                epoch, votenum, avg_train_loss, avg_val_loss, acc, F1, AUC))

        # Reset the accumulators for the next voting window.
        train_loss = np.zeros(1)
        val_loss = np.zeros(1)
        vote_pred = np.zeros(len(valset))
        vote_score = np.zeros(len(valset))




 The epoch is 10, 10 epoch avg: train loss: 4.0427, val loss: 3.4868, acc: 0.6271, F1: 0.6857, AUC: 0.6325

 The epoch is 20, 10 epoch avg: train loss: 2.8745, val loss: 3.7045, acc: 0.6949, F1: 0.7097, AUC: 0.7928

 The epoch is 30, 10 epoch avg: train loss: 1.7224, val loss: 2.7551, acc: 0.7119, F1: 0.7018, AUC: 0.7911

 The epoch is 40, 10 epoch avg: train loss: 0.8686, val loss: 1.8085, acc: 0.7373, F1: 0.7304, AUC: 0.7882

 The epoch is 50, 10 epoch avg: train loss: 0.4503, val loss: 1.2226, acc: 0.7288, F1: 0.7288, AUC: 0.8138

 The epoch is 60, 10 epoch avg: train loss: 0.2500, val loss: 1.1151, acc: 0.7288, F1: 0.7091, AUC: 0.7954

 The epoch is 70, 10 epoch avg: train loss: 0.1713, val loss: 1.0027, acc: 0.7203, F1: 0.7273, AUC: 0.7974



 The epoch is 80, 10 epoch avg: train loss: 0.1271, val loss: 0.9215, acc: 0.7542, F1: 0.7563, AUC: 0.7968

 The epoch is 90, 10 epoch avg: train loss: 0.0954, val loss: 0.9370, acc: 0.7288, F1: 0.7333, AUC: 0.8069

 The epoch is 100, 10 epoch avg: train loss: 0.0747, val loss: 0.9014, acc: 0.7203, F1: 0.7317, AUC: 0.7954


# Final Results #

In [20]:
# test
# Fresh backbone for evaluation; each cell below loads one fine-tuned checkpoint into it.
model = ModelBase(arch='resnet18', bn_splits=8)
model.to(device)

# Fine-tuned classifiers, one per pretraining setup:
# random init, Cifar, Cifar + COVID, Cifar + LUNA + COVID.
checkpoint1, checkpoint2, checkpoint3, checkpoint4 = [
    torch.load(ckpt_path)
    for ckpt_path in (
        'Covid-cache/cache-train-random-ini/model_last.pth',
        'Covid-cache/cache-train-Cifar/model_last.pth',
        'Covid-cache/cache-train-Cifar-Covid/model_last.pth',
        'Covid-cache/cache-train-Cifar-Luna-Covid/model_last.pth',
    )
]


### 1. ResNet18 Random Initialization (without MoCo Pretraining) ###
Test-set performance of the last saved checkpoint (accuracy, F1, AUC): 0.552, 0.683, 0.715

In [21]:
# Resnet18 Random Initialization
# test_loader was built earlier from args.batch_size, so re-assigning args.batch_size
# here would not change how evaluation is batched; it is therefore not set.
model.load_state_dict(checkpoint1['state_dict'])

# Only one checkpoint is evaluated, so there is nothing to vote over: the
# predictions and scores are passed to `metric` directly.
testloss, targetlist, scorelist, predlist = test(1)
TP, TN, FN, FP, p, r, F1, acc, AUC = metric(predlist, targetlist, scorelist)
print(acc, F1, AUC)


0.5517241379310345 0.6829268292682926 0.7146744412050534


### 2. One round of MoCo pretraining on the CIFAR-10 dataset ###
Test-set performance of the last saved checkpoint (accuracy, F1, AUC): 0.695, 0.750, 0.818

In [22]:
# moco: Resnet18 pretrained on Cifar
# test_loader was built earlier from args.batch_size, so re-assigning args.batch_size
# here would not change how evaluation is batched; it is therefore not set.
model.load_state_dict(checkpoint2['state_dict'])

# Only one checkpoint is evaluated, so there is nothing to vote over: the
# predictions and scores are passed to `metric` directly.
testloss, targetlist, scorelist, predlist = test(1)
TP, TN, FN, FP, p, r, F1, acc, AUC = metric(predlist, targetlist, scorelist)
print(acc, F1, AUC)



0.6945812807881774 0.7499999999999999 0.8179786200194364


### 3. Two rounds of MoCo pretraining on the CIFAR-10 and unlabeled COVID-CT datasets ###
Test-set performance of the last saved checkpoint (accuracy, F1, AUC): 0.783, 0.796, 0.880

In [18]:
# moco: Resnet18 pretrained on Cifar and COVIDCT
# test_loader was built earlier from args.batch_size, so re-assigning args.batch_size
# here would not change how evaluation is batched; it is therefore not set.
model.load_state_dict(checkpoint3['state_dict'])

# Only one checkpoint is evaluated, so there is nothing to vote over: the
# predictions and scores are passed to `metric` directly.
testloss, targetlist, scorelist, predlist = test(1)
TP, TN, FN, FP, p, r, F1, acc, AUC = metric(predlist, targetlist, scorelist)
print(acc, F1, AUC)




0.7832512315270936 0.7962962962962963 0.8798833819241981


### 4. Three rounds of MoCo pretraining on the CIFAR-10, LUNA-CT and unlabeled COVID-CT datasets ###
Test-set performance of the last saved checkpoint (accuracy, F1, AUC): 0.695, 0.677, 0.770

In [19]:
# moco: Resnet18 pretrained on Cifar, Luna and COVIDCT
# test_loader was built earlier from args.batch_size, so re-assigning args.batch_size
# here would not change how evaluation is batched; it is therefore not set.
model.load_state_dict(checkpoint4['state_dict'])

# Only one checkpoint is evaluated, so there is nothing to vote over: the
# predictions and scores are passed to `metric` directly.
testloss, targetlist, scorelist, predlist = test(1)
TP, TN, FN, FP, p, r, F1, acc, AUC = metric(predlist, targetlist, scorelist)
print(acc, F1, AUC)





0.6945812807881774 0.6770833333333333 0.7699708454810495


### Parts of this code are adapted from: ###
[UCSD-AI4H/COVID-CT Self-Trans baseline notebook](https://github.com/UCSD-AI4H/COVID-CT/blob/master/baseline%20methods/Self-Trans/CT-predict-pretrain.ipynb)


@article{zhao2020COVID-CT-Dataset,
  title={COVID-CT-Dataset: a CT scan dataset about COVID-19},
  author={Zhao, Jinyu and Zhang, Yichen and He, Xuehai and Xie, Pengtao},
  journal={arXiv preprint arXiv:2003.13865}, 
  year={2020}
}


