In [None]:
!nvidia-smi -L

In [None]:
!git clone "https://github.com/tanish-g/covid-chestxray-dataset.git" --quiet

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!unzip -qq "/content/drive/MyDrive/hacknagpur/covid-19-x-ray-10000-images.zip"

In [1]:
# ==== Install Dependencies
!pip install -q efficientnet-pytorch
!pip install -q albumentations
!pip install -U -q pytorch-fanatics 
!pip install -q pytorch_ranger

In [2]:
# ==== Import Libraries

import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F
import torch.nn as nn
import seaborn as sns
import random
import os


import albumentations as aug
from albumentations.pytorch.transforms import ToTensor
import matplotlib.pyplot as plt


from tqdm import tqdm

from sklearn.model_selection import GroupKFold
from sklearn.metrics import accuracy_score, roc_auc_score
from torch.utils.data import Dataset,DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

from pytorch_fanatics.dataloader import Cloader
from pytorch_fanatics.utils import EarlyStop ,LRFinder 
from pytorch_fanatics.trainer import Trainer
from pytorch_fanatics.logger import Logger

import warnings
warnings.filterwarnings("ignore") 
warnings.filterwarnings("ignore", category=DeprecationWarning) 

from pytorch_ranger import Ranger

from efficientnet_pytorch import EfficientNet
from pathlib import Path

from torchvision import transforms

In [3]:
# Load the per-image metadata table from the cloned covid-chestxray-dataset folder.
df=pd.read_csv('/content/covid-chestxray-dataset/metadata.csv')
# Show how many rows carry each RT_PCR_positive value.
df['RT_PCR_positive'].value_counts()

Y          371
Unclear    222
Name: RT_PCR_positive, dtype: int64

In [4]:
df.columns

Index(['patientid', 'offset', 'sex', 'age', 'finding', 'RT_PCR_positive',
       'survival', 'intubated', 'intubation_present', 'went_icu', 'in_icu',
       'needed_supplemental_O2', 'extubated', 'temperature', 'pO2_saturation',
       'leukocyte_count', 'neutrophil_count', 'lymphocyte_count', 'view',
       'modality', 'date', 'location', 'folder', 'filename', 'doi', 'url',
       'license', 'clinical_notes', 'other_notes', 'Unnamed: 29'],
      dtype='object')

In [5]:
# Binary target derived from the `finding` column: COVID-19 is class 0 and every
# other finding listed below is class 1.
# NOTE(review): 'Unknown' and 'todo' are folded into class 1 as well, and the key
# 'Pneumonia/Viral/Herpes ' keeps a trailing space — presumably to match the raw
# metadata values; confirm.
d = {
    'Pneumonia/Viral/COVID-19': 0,
    **dict.fromkeys(
        [
            'Pneumonia',
            'Pneumonia/Viral/SARS',
            'Pneumonia/Fungal/Pneumocystis',
            'Pneumonia/Bacterial/Streptococcus',
            'No Finding',
            'Pneumonia/Bacterial/Chlamydophila',
            'Pneumonia/Bacterial/E.Coli',
            'Pneumonia/Bacterial/Klebsiella',
            'Pneumonia/Bacterial/Legionella',
            'Unknown',
            'Pneumonia/Lipoid',
            'Pneumonia/Viral/Varicella',
            'Pneumonia/Bacterial',
            'Pneumonia/Bacterial/Mycoplasma',
            'Pneumonia/Viral/Influenza',
            'todo',
            'Tuberculosis',
            'Pneumonia/Viral/Influenza/H1N1',
            'Pneumonia/Fungal/Aspergillosis',
            'Pneumonia/Viral/Herpes ',
            'Pneumonia/Aspiration',
            'Pneumonia/Bacterial/Nocardia',
            'Pneumonia/Viral/MERS-CoV',
            'Pneumonia/Bacterial/Staphylococcus/MRSA',
        ],
        1,
    ),
}

In [6]:
# Replace each `finding` string with its binary label from `d`.
# NOTE: Series.map yields NaN for any value that is not a key of `d`, so re-running this
# cell (when the column already holds 0/1 instead of strings) turns the column into NaN.
df['finding']=df['finding'].map(d)

In [7]:
import os

# Collect every .png / .jpeg file found under the "normal" image folder.
# - Paths are built from `root`, so a file located in a sub-folder keeps its real
#   location (a bare file name joined onto the top-level folder would point at a
#   file that does not exist).
# - Entries are absolute, so a later os.path.join('/content/dataset/normal', entry)
#   returns them unchanged.
# - os.walk yields entries in arbitrary order; sorting keeps the row order, and
#   therefore any seeded shuffle/split built on it, identical between runs.
asps = []
for root, dirs, files in os.walk('/content/dataset/normal/'):
    for file in files:
        if file.endswith(('.png', '.jpeg')):
            asps.append(os.path.join(root, file))
asps.sort()

In [8]:
# Prefix every collected entry with the "normal" folder (in place, so `asps` itself
# holds the joined paths afterwards) and give each of those images label 1.
asps[:] = [os.path.join('/content/dataset/normal', entry) for entry in asps]
image_id = list(asps)
labels = [1] * len(asps)

In [9]:
# Table of the "normal" images: one row per path, all with label 1.
# NOTE(review): this rebinds `d`, which earlier held the finding -> label mapping;
# re-running the mapping cell after this one would use this dict instead.
d = dict(image_id=image_id, labels=labels)
df1 = pd.DataFrame(data=d)

In [10]:
# Keep only the file name and the mapped label, renamed to the column names used
# by the "normal" image table (image_id / labels).
df_final = (
    df[['filename', 'finding']]
    .rename(columns={'filename': 'image_id', 'finding': 'labels'})
)

In [11]:
# Turn the bare file names into paths inside the cloned dataset's image folder.
# One whole-column assignment replaces the per-row chained assignment
# `df_final['image_id'][x] = ...`: that form writes through an intermediate Series,
# which pandas reports with SettingWithCopyWarning (hidden here because the notebook
# silences all warnings) and which does not modify `df_final` under copy-on-write.
df_final['image_id'] = [
    os.path.join('/content/covid-chestxray-dataset/images', name)
    for name in df_final['image_id']
]

In [12]:
from PIL import Image

In [13]:
# Split row positions by whether the referenced file can be opened as an image.
list1=[]  # positions whose file is missing or not recognised as an image
list2=[]  # positions whose file opened successfully
# enumerate() gives positional indices, matching the positional `.iloc[list2]`
# selection applied to this frame afterwards.
for x, img_path in enumerate(df_final['image_id']):
  try:
    # The context manager closes the file handle immediately; Image.open leaves
    # the file open otherwise.
    with Image.open(img_path):
      pass
  except Exception:
    # `Exception` rather than a bare `except:` so that interrupting the cell
    # (KeyboardInterrupt) is not silently recorded as a bad image.
    list1.append(x)
  else:
    list2.append(x)

In [14]:
# Keep only the rows whose image could be opened; list2 holds their positions.
df_final=df_final.iloc[list2]

In [15]:
# Stack the COVID-dataset rows on top of the "normal" rows, shuffle, and renumber.
# random_state pins the shuffle: this cell runs before any global seeding in the
# notebook, so without it the row order — and the train/validation split derived
# from that order — would change on every run.
# NOTE: `df` is rebound here; it no longer refers to the raw metadata table.
df=pd.concat([df_final,df1],axis=0)
df=df.sample(frac=1,random_state=42)
df=df.reset_index(drop=True)

In [16]:
len(df)

957

In [17]:
df.head(5)

Unnamed: 0,image_id,labels
0,/content/covid-chestxray-dataset/images/1e5348...,1
1,/content/covid-chestxray-dataset/images/tpmd20...,0
2,/content/covid-chestxray-dataset/images/thnov1...,0
3,/content/covid-chestxray-dataset/images/5ed7d0...,0
4,/content/covid-chestxray-dataset/images/AR-1.jpg,0


In [18]:
from sklearn.model_selection import train_test_split as tts

In [19]:
def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy's global generator and PyTorch (CPU and all
    CUDA devices), and configures cuDNN for repeatable results.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker processes);
    # the hash seed of the running interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers every visible GPU, not just the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels between runs; keep it off.
    torch.backends.cudnn.benchmark = False
seed_everything(42)

In [20]:
df['labels'].value_counts()

0    563
1    394
Name: labels, dtype: int64

# Data Exploration

In [21]:
# -p: succeed silently when the folder already exists, so the cell can be re-run.
!mkdir -p "/content/drive/MyDrive/HackNagpur/"

mkdir: cannot create directory ‘/content/drive/MyDrive/HackNagpur/’: File exists


In [22]:
save_root = "/content/drive/MyDrive/HackNagpur/"

In [23]:
# Stratified 75/25 train/validation split on the binary label, with a fixed seed.
# NOTE(review): the split is per image; the source metadata has a `patientid`
# column, so images of one patient may end up on both sides — confirm whether a
# grouped split is wanted.
X_train, X_val, Y_train, Y_val = tts(
    df,
    df["labels"].values,
    test_size=0.25,
    random_state=42,
    stratify=df["labels"].values,
)
X_train = X_train.reset_index(drop=True)
X_val = X_val.reset_index(drop=True)

In [24]:
# ===== Augmentations

# Per-channel normalisation constants and the square input size fed to the network.
# NOTE(review): these look like the usual ImageNet statistics, presumably chosen to
# match the pretrained backbone — confirm.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
image_size = 224

# Training pipeline: resize, then random crop / flip / photometric / geometric
# perturbations, and finally normalisation. The order of the steps is significant.
train_tfms = aug.Compose([
    aug.Resize(height=image_size, width=image_size),
    aug.RandomSizedCrop(
        min_max_height=(64, image_size),
        height=image_size,
        width=image_size,
        p=0.5,
    ),
    aug.HorizontalFlip(p=0.5),
    aug.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1),
    aug.RGBShift(),
    aug.RandomContrast(limit=0.2),
    aug.RandomGamma(),
    aug.GaussNoise(p=0.35),
    aug.IAASharpen(),
    aug.ToGray(p=0.35),
    aug.CLAHE(clip_limit=4.0, p=0.7),
    aug.HueSaturationValue(
        hue_shift_limit=10,
        sat_shift_limit=20,
        val_shift_limit=10,
        p=0.5,
    ),
    aug.ShiftScaleRotate(
        shift_limit=0.1,
        scale_limit=0.1,
        rotate_limit=15,
        border_mode=0,
        p=0.85,
    ),
    aug.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
])

# Evaluation pipeline: deterministic resize + the same normalisation, no augmentation.
test_tfms = aug.Compose([
    aug.Resize(height=image_size, width=image_size),
    aug.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
])

In [None]:
# training_data_path='/content/train'

In [25]:
import numpy as np
from PIL import Image
from PIL import ImageFile   #ImageFile contains support for PIL to open and save images
ImageFile.LOAD_TRUNCATED_IMAGES=True  #If image is truncated(or corrupt) then also load it..
import torch

class Cloader:
    """Indexable image dataset that yields an image tensor with its label.

    Args:
        image_path: sequence of image file paths.
        targets: sequence of labels, aligned with `image_path`.
        resize: optional (height, width); when given, the image is resized with
            bilinear resampling before any transform runs.
        transforms: optional callable invoked as `transforms(image=array)` whose
            result is read back from the "image" key.
    """

    def __init__(self,image_path,targets,resize=None,transforms=None):
        self.image_path, self.targets = image_path, targets
        self.resize, self.transforms = resize, transforms

    def __len__(self):
        """Number of samples, i.e. the number of image paths."""
        return len(self.image_path)

    def __getitem__(self,idx):
        """Load sample `idx` and return {"image": float tensor, "targets": long tensor}."""
        picture = Image.open(self.image_path[idx]).convert("RGB")
        label = self.targets[idx]
        if self.resize is not None:
            # PIL expects (width, height) while `resize` is stored as (height, width).
            height, width = self.resize[0], self.resize[1]
            picture = picture.resize((width, height), resample=Image.BILINEAR)
        pixels = np.array(picture)
        if self.transforms is not None:
            pixels = self.transforms(image=pixels)["image"]
        # Move the channel axis first: (H, W, C) -> (C, H, W).
        channels_first = np.transpose(pixels, (2, 0, 1)).astype(np.float32)
        return {
            "image": torch.tensor(channels_first, dtype=torch.float),
            "targets": torch.tensor(label, dtype=torch.long),
        }

In [26]:
# Image paths for each split, taken straight from the split frames.
train_images = X_train["image_id"].tolist()
test_images = X_val["image_id"].tolist()

# Datasets: no pre-resize (the transform pipelines resize); augmentation is applied
# to the training split only.
train_dataset = Cloader(train_images, X_train["labels"].values, None, train_tfms)
test_dataset = Cloader(test_images, X_val["labels"].values, None, test_tfms)

# Batches of 32, loaded in the main process; only the training loader shuffles.
train_dataloader = DataLoader(
    train_dataset, batch_size=32, shuffle=True, num_workers=0
)
val_dataloader = DataLoader(
    test_dataset, batch_size=32, shuffle=False, num_workers=0
)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [27]:
# ===== Define model

class Net(nn.Module):
    """Pretrained EfficientNet-B0 classifier that also computes its own loss.

    Args:
        num_classes: size of the classification head (default 2, as before).
    """
    def __init__(self, num_classes=2):
        super(Net, self).__init__()
        self.num_classes = num_classes
        self.base_model=EfficientNet.from_pretrained('efficientnet-b0',num_classes=num_classes)

    def forward(self, image, targets):
        """Run the backbone and score its logits against `targets`.

        Args:
            image: batch of images; the first dimension is the batch size.
            targets: class indices, one per image (tensor or array-like).

        Returns:
            (out, loss): the logits from the backbone and the mean cross-entropy loss.
        """
        batch_size = image.shape[0]
        out = self.base_model(image)
        # as_tensor reuses an existing tensor when dtype/device already match and
        # converts otherwise; torch.tensor(existing_tensor) would always copy and
        # emit a UserWarning. Placing targets on the logits' device also avoids a
        # device mismatch in the loss.
        targets = torch.as_tensor(targets, dtype=torch.int64, device=out.device)
        # Functional form: no loss module has to be constructed on every call.
        loss = F.cross_entropy(out.view(batch_size, self.num_classes), targets)
        return out, loss

model = Net()
model.to(device);

Loaded pretrained weights for efficientnet-b0


In [28]:
def softmax(array):
    """Row-wise softmax of a 2-D array of scores.

    The maximum of each row is subtracted before exponentiating. This leaves the
    result mathematically unchanged but keeps np.exp from overflowing to inf
    (and the division from producing NaN) when scores are large.

    Args:
        array: 2-D array-like of shape (rows, classes).

    Returns:
        Array of the same shape whose rows are non-negative and sum to 1.
    """
    array = np.asarray(array)
    shifted = array - np.max(array, axis=1, keepdims=True)
    exponentials = np.exp(shifted)
    return exponentials / np.sum(exponentials, axis=1, keepdims=True)

In [29]:
# Ranger optimizer over all model parameters, starting from a learning rate of 1e-3.
optimizer = Ranger(model.parameters(),lr=1e-3)
# Scale the learning rate by 0.8 when the monitored value (mode="min") has not
# decreased for more than 2 consecutive scheduler steps.
scheduler = ReduceLROnPlateau(optimizer,factor=0.8, mode="min", patience=2)

# presumably the Trainer steps `val_scheduler` with the validation loss — TODO confirm
# against the pytorch_fanatics Trainer implementation.
trainer   = Trainer(model=model,optimizer=optimizer,device=device,val_scheduler=scheduler)
logger    = Logger()

# Stop after 15 evaluations without improvement of the tracked value.
es        = EarlyStop(patience=15,mode="min") # mode = min to minimise loss

In [30]:
epochs = 30
# Checkpoint location, derived from `save_root` so the Drive folder is defined in
# exactly one place (it resolves to the same path that was previously hard-coded).
best_model_path = os.path.join(save_root, "best_covid_ct_images.pth")

for epoch in range(epochs):
    logger.write(f"+ ===== Epoch {epoch+1}/{epochs} ===== +")
    train_loss              = trainer.train(train_dataloader)
    y_true,y_pred ,val_loss = trainer.evaluate(val_dataloader)
    # presumably `y_pred` holds raw per-class scores here; softmax turns them into
    # probabilities (the argmax below is unaffected by it) — TODO confirm.
    y_pred                  = softmax(y_pred)
    accuracy                = accuracy_score(y_true,np.argmax(y_pred,axis=1))
    # Early stopping tracks the validation loss; its log output shows it saving the
    # model whenever that loss improves.
    es(val_loss,model,model_path=best_model_path)
    logger.write(f"train_loss {train_loss} val_loss {val_loss} ")
    logger.write(f"val accuracy_score {accuracy} ")
    logger.write(" ")
    if es.early_stop:
        break

+ ===== Epoch 1/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


Metric Validation score improved (inf --> 0.6946322917938232). Saving model!
train_loss 0.6939596248709637 val_loss 0.6946322917938232 
val accuracy_score 0.5291666666666667 
 
+ ===== Epoch 2/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


Metric Validation score improved (0.6946322917938232 --> 0.674707718193531). Saving model!
train_loss 0.6801984361980274 val_loss 0.674707718193531 
val accuracy_score 0.5333333333333333 
 
+ ===== Epoch 3/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


Metric Validation score improved (0.674707718193531 --> 0.6613356247544289). Saving model!
train_loss 0.6630075703496519 val_loss 0.6613356247544289 
val accuracy_score 0.5583333333333333 
 
+ ===== Epoch 4/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


Metric Validation score improved (0.6613356247544289 --> 0.645863339304924). Saving model!
train_loss 0.6350418795710023 val_loss 0.645863339304924 
val accuracy_score 0.575 
 
+ ===== Epoch 5/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


Metric Validation score improved (0.645863339304924 --> 0.6142282038927078). Saving model!
train_loss 0.6295748974965968 val_loss 0.6142282038927078 
val accuracy_score 0.65 
 
+ ===== Epoch 6/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


Metric Validation score improved (0.6142282038927078 --> 0.5900362469255924). Saving model!
train_loss 0.6010395962259042 val_loss 0.5900362469255924 
val accuracy_score 0.6458333333333334 
 
+ ===== Epoch 7/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


Metric Validation score improved (0.5900362469255924 --> 0.5650490745902061). Saving model!
train_loss 0.5852141781993534 val_loss 0.5650490745902061 
val accuracy_score 0.7041666666666667 
 
+ ===== Epoch 8/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 1 out of 15
train_loss 0.5500117333038994 val_loss 0.6014665439724922 
val accuracy_score 0.6416666666666667 
 
+ ===== Epoch 9/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


Metric Validation score improved (0.5650490745902061 --> 0.5053517855703831). Saving model!
train_loss 0.5737719898638517 val_loss 0.5053517855703831 
val accuracy_score 0.7708333333333334 
 
+ ===== Epoch 10/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 1 out of 15
train_loss 0.5269400179386139 val_loss 0.6425060629844666 
val accuracy_score 0.6416666666666667 
 
+ ===== Epoch 11/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 2 out of 15
train_loss 0.5463348601175391 val_loss 0.8756523877382278 
val accuracy_score 0.5416666666666666 
 
+ ===== Epoch 12/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 3 out of 15
train_loss 0.49307192797246185 val_loss 0.6517661362886429 
val accuracy_score 0.6375 
 
+ ===== Epoch 13/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 4 out of 15
train_loss 0.5156545950018842 val_loss 0.5133348032832146 
val accuracy_score 0.7125 
 
+ ===== Epoch 14/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 5 out of 15
train_loss 0.4566842421241429 val_loss 0.5298049449920654 
val accuracy_score 0.7208333333333333 
 
+ ===== Epoch 15/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


Metric Validation score improved (0.5053517855703831 --> 0.49544215202331543). Saving model!
train_loss 0.48183324414750806 val_loss 0.49544215202331543 
val accuracy_score 0.7666666666666667 
 
+ ===== Epoch 16/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 1 out of 15
train_loss 0.46856865286827093 val_loss 0.5203756131231785 
val accuracy_score 0.7291666666666666 
 
+ ===== Epoch 17/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 2 out of 15
train_loss 0.44766351839770446 val_loss 0.5488722622394562 
val accuracy_score 0.7208333333333333 
 
+ ===== Epoch 18/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 3 out of 15
train_loss 0.4249121961386308 val_loss 0.5093664452433586 
val accuracy_score 0.7708333333333334 
 
+ ===== Epoch 19/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


Metric Validation score improved (0.49544215202331543 --> 0.4921174570918083). Saving model!
train_loss 0.45250558593998774 val_loss 0.4921174570918083 
val accuracy_score 0.7958333333333333 
 
+ ===== Epoch 20/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


Metric Validation score improved (0.4921174570918083 --> 0.48646094277501106). Saving model!
train_loss 0.4379995465278626 val_loss 0.48646094277501106 
val accuracy_score 0.7708333333333334 
 
+ ===== Epoch 21/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 1 out of 15
train_loss 0.41617934729741973 val_loss 0.5037427060306072 
val accuracy_score 0.7625 
 
+ ===== Epoch 22/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 2 out of 15
train_loss 0.35712587444678595 val_loss 0.5176238175481558 
val accuracy_score 0.7833333333333333 
 
+ ===== Epoch 23/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 3 out of 15
train_loss 0.37998071248116694 val_loss 0.5738617200404406 
val accuracy_score 0.7583333333333333 
 
+ ===== Epoch 24/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 4 out of 15
train_loss 0.34253627580145135 val_loss 0.6184986643493176 
val accuracy_score 0.7333333333333333 
 
+ ===== Epoch 25/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 5 out of 15
train_loss 0.39613611477872607 val_loss 0.5416919589042664 
val accuracy_score 0.7541666666666667 
 
+ ===== Epoch 26/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 6 out of 15
train_loss 0.3976355506026226 val_loss 0.5304510481655598 
val accuracy_score 0.7875 
 
+ ===== Epoch 27/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 7 out of 15
train_loss 0.3157059945490049 val_loss 0.6019870564341545 
val accuracy_score 0.7666666666666667 
 
+ ===== Epoch 28/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 8 out of 15
train_loss 0.3373527747133504 val_loss 0.6935960948467255 
val accuracy_score 0.725 
 
+ ===== Epoch 29/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 9 out of 15
train_loss 0.32344773930052056 val_loss 0.5851790197193623 
val accuracy_score 0.7875 
 
+ ===== Epoch 30/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))


EarlyStop count: 10 out of 15
train_loss 0.3336532005797261 val_loss 0.6011535227298737 
val accuracy_score 0.7833333333333333 
 
