In [1]:
# ==== Install Dependencies

!pip install -q efficientnet-pytorch
!pip install -q albumentations
!pip install -q pytorch-fanatics 
!pip install -q pytorch_ranger

You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m


In [2]:
# ==== Import Libraries

import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F
import torch.nn as nn
import seaborn as sns
import random
import os


import albumentations as aug
from albumentations.pytorch.transforms import ToTensor
import matplotlib.pyplot as plt

from efficientnet_pytorch import EfficientNet
from tqdm import tqdm

from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset,DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

# from pytorch_fanatics.dataloader import Cloader
from pytorch_fanatics.utils import EarlyStop 
from pytorch_fanatics.trainer import Trainer
from pytorch_fanatics.logger import Logger

import warnings
warnings.filterwarnings("ignore") 
warnings.filterwarnings("ignore", category=DeprecationWarning) 

from pytorch_ranger import Ranger

In [3]:
SEED = 42
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(SEED)

In [4]:
X_train = pd.read_csv('../input/pnemonia/train.csv')
X_train.sample(frac=1).reset_index(drop=True)
X_val   = pd.read_csv('../input/pnemonia/val (1).csv')
X_val.sample(frac=1).reset_index(drop=True)

Unnamed: 0,image_id,label
0,../input/chest-xray-pneumonia/chest_xray/test/...,1
1,../input/chest-xray-pneumonia/chest_xray/test/...,1
2,../input/chest-xray-pneumonia/chest_xray/test/...,1
3,../input/chest-xray-pneumonia/chest_xray/test/...,1
4,../input/chest-xray-pneumonia/chest_xray/test/...,1
...,...,...
619,../input/chest-xray-pneumonia/chest_xray/test/...,0
620,../input/chest-xray-pneumonia/chest_xray/test/...,1
621,../input/chest-xray-pneumonia/chest_xray/test/...,1
622,../input/chest-xray-pneumonia/chest_xray/test/...,1


In [5]:
X_train

Unnamed: 0,image_id,label
0,../input/chest-xray-pneumonia/chest_xray/train...,1
1,../input/chest-xray-pneumonia/chest_xray/train...,1
2,../input/chest-xray-pneumonia/chest_xray/train...,1
3,../input/chest-xray-pneumonia/chest_xray/train...,1
4,../input/chest-xray-pneumonia/chest_xray/train...,1
...,...,...
5211,../input/chest-xray-pneumonia/chest_xray/train...,0
5212,../input/chest-xray-pneumonia/chest_xray/train...,0
5213,../input/chest-xray-pneumonia/chest_xray/train...,0
5214,../input/chest-xray-pneumonia/chest_xray/train...,0


In [6]:
# ===== Augmentations

mean       = (0.485, 0.456, 0.406)
std        = (0.229, 0.224, 0.225)
train_tfms = aug.Compose([
            aug.Resize(224,224),
            aug.HorizontalFlip(p=0.5),
            aug.RandomBrightnessContrast(0.1,0.1),
            aug.ShiftScaleRotate(rotate_limit=(-45,45)),
            aug.GaussNoise(p=0.35),
            aug.IAASharpen(),
            aug.Normalize(mean,std,max_pixel_value=255.0,always_apply=True),
            ])

test_tfms  = aug.Compose([
            aug.Resize(224,224),
            aug.Normalize(mean,std,max_pixel_value=255.0,always_apply=True),
            ])

In [7]:
import numpy as np
from PIL import Image
from PIL import ImageFile   #ImageFile contains support for PIL to open and save images
ImageFile.LOAD_TRUNCATED_IMAGES=True  #If image is truncated(or corrupt) then also load it..
import torch

class Cloader:
    def __init__(self,image_path,targets,resize=None,transforms=None):
        self.image_path=image_path
        self.targets=targets
        self.resize=resize
        self.transforms=transforms
        
    def __len__(self):
        return len(self.image_path)
    
    def __getitem__(self,idx):
        image = Image.open(self.image_path[idx]).convert("RGB")
        targets = self.targets[idx]
        if self.resize is not None:
            image = image.resize(
                (self.resize[1], self.resize[0]), resample=Image.BILINEAR
            )
        image = np.array(image)
        if self.transforms is not None:
            augmented = self.transforms(image=image)
            image = augmented["image"]
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        return {
            "image": torch.tensor(image, dtype=torch.float),
            "targets": torch.tensor(targets, dtype=torch.long),
        }

In [8]:
train_images     = X_train.image_id.values.tolist()

test_images      = X_val.image_id.values.tolist()

train_dataset    = Cloader(train_images,X_train.label.values,None,train_tfms)
test_dataset     = Cloader(test_images,X_val.label.values,None,test_tfms)

train_dataloader = DataLoader(train_dataset,batch_size=32,shuffle=True,num_workers=4)
val_dataloader   = DataLoader(test_dataset,batch_size=32,shuffle=False,num_workers=4)

device           = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [9]:
# ===== Define model

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.base_model = EfficientNet.from_pretrained('efficientnet-b0',num_classes=2)
    def forward(self, image, targets):
        batch_size, _, _, _ = image.shape
        out = self.base_model(image)
        targets = torch.tensor(targets,dtype=torch.int64)
        loss = nn.CrossEntropyLoss()(out.view(batch_size,2), targets)
        return out, loss

model = Net()
model.to(device);

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b0-355c32eb.pth


HBox(children=(FloatProgress(value=0.0, max=21388428.0), HTML(value='')))


Loaded pretrained weights for efficientnet-b0


In [10]:
def softmax(array):
    return np.exp(array)/np.sum(np.exp(array),axis=1).reshape(-1,1)

In [11]:
optimizer = Ranger(model.parameters(),lr=1e-4)
scheduler = ReduceLROnPlateau(optimizer,factor=0.8, mode="min", patience=2)

trainer   = Trainer(model=model,optimizer=optimizer,device=device,val_scheduler=scheduler)
logger    = Logger()

In [12]:
epochs = 30
best_loss = np.Inf
for epoch in range(epochs):
    logger.write(f"+ ===== Epoch {epoch+1}/{epochs} ===== +")
    train_loss              = trainer.train(train_dataloader)
    y_true,y_pred ,val_loss = trainer.evaluate(val_dataloader)
    y_pred                  = softmax(y_pred)
    accuracy                = accuracy_score(y_true,np.argmax(y_pred,axis=1))
    logger.write(f"train_loss {train_loss} val_loss {val_loss} ")
    logger.write(f"val accuracy_score {accuracy} ")
    logger.write(" ")
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(),'best.pth')


+ ===== Epoch 1/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.6338608051004584 val_loss 0.5110404878854751 
val accuracy_score 0.8525641025641025 
 
+ ===== Epoch 2/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.3767538220604503 val_loss 0.3118230486288666 
val accuracy_score 0.8717948717948718 
 
+ ===== Epoch 3/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.20477609126114407 val_loss 0.3700940219685436 
val accuracy_score 0.8685897435897436 
 
+ ===== Epoch 4/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.14035852942501473 val_loss 0.563861498539336 
val accuracy_score 0.8253205128205128 
 
+ ===== Epoch 5/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.1167056391392749 val_loss 0.6100306114880368 
val accuracy_score 0.8108974358974359 
 
+ ===== Epoch 6/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.09496130302914081 val_loss 0.6083718031761237 
val accuracy_score 0.8189102564102564 
 
+ ===== Epoch 7/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.08550663144424833 val_loss 0.6195417784503662 
val accuracy_score 0.8125 
 
+ ===== Epoch 8/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.08143165533893683 val_loss 0.5801811606390401 
val accuracy_score 0.8237179487179487 
 
+ ===== Epoch 9/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.07390026275091378 val_loss 0.5385564258904195 
val accuracy_score 0.8413461538461539 
 
+ ===== Epoch 10/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.06780300799620112 val_loss 0.6376335435605142 
val accuracy_score 0.8189102564102564 
 
+ ===== Epoch 11/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.06618326105623067 val_loss 0.6906296622852097 
val accuracy_score 0.8108974358974359 
 
+ ===== Epoch 12/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.06028239079608234 val_loss 0.5408856855647172 
val accuracy_score 0.8413461538461539 
 
+ ===== Epoch 13/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.06367646567123526 val_loss 0.7370089248142904 
val accuracy_score 0.8173076923076923 
 
+ ===== Epoch 14/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.05580680016599639 val_loss 0.686850395379588 
val accuracy_score 0.8189102564102564 
 
+ ===== Epoch 15/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.05176930207947693 val_loss 0.5785449397691991 
val accuracy_score 0.8413461538461539 
 
+ ===== Epoch 16/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.051635362129327074 val_loss 0.8085516896942864 
val accuracy_score 0.8076923076923077 
 
+ ===== Epoch 17/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.05436099278940561 val_loss 0.6978585362579907 
val accuracy_score 0.8269230769230769 
 
+ ===== Epoch 18/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.04875912012080978 val_loss 0.5858038174454122 
val accuracy_score 0.8413461538461539 
 
+ ===== Epoch 19/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.05237455331531866 val_loss 0.6408585846453206 
val accuracy_score 0.8333333333333334 
 
+ ===== Epoch 20/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.042187035288033206 val_loss 0.5877949920424727 
val accuracy_score 0.8461538461538461 
 
+ ===== Epoch 21/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.05076581625085027 val_loss 0.573681224961183 
val accuracy_score 0.8493589743589743 
 
+ ===== Epoch 22/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.04210535104906448 val_loss 0.7045532629796071 
val accuracy_score 0.8301282051282052 
 
+ ===== Epoch 23/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.03673768883208214 val_loss 0.8451621029453236 
val accuracy_score 0.8108974358974359 
 
+ ===== Epoch 24/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.039531350174727344 val_loss 0.7067622697170008 
val accuracy_score 0.8285256410256411 
 
+ ===== Epoch 25/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.039786249597734004 val_loss 0.8500765514108934 
val accuracy_score 0.8076923076923077 
 
+ ===== Epoch 26/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.03865060113024743 val_loss 0.7563808261606028 
val accuracy_score 0.8253205128205128 
 
+ ===== Epoch 27/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.037979561501279735 val_loss 0.7984897986018041 
val accuracy_score 0.8141025641025641 
 
+ ===== Epoch 28/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.03896314850965502 val_loss 0.8736202424108342 
val accuracy_score 0.8125 
 
+ ===== Epoch 29/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.04567175929111777 val_loss 0.8325038460323413 
val accuracy_score 0.8141025641025641 
 
+ ===== Epoch 30/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=163.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))


train_loss 0.03735215076319645 val_loss 0.746669452614151 
val accuracy_score 0.8253205128205128 
 
