In [1]:
# ==== Install Dependencies
# `%pip` (not `!pip`) installs into the environment of the running kernel.
# albumentations is pinned below 1.0: later cells call `aug.IAASharpen`, which
# albumentations>=1.0 no longer exposes at the top level of the package.
# TODO: pin the other packages to the versions this notebook was run with.
%pip install -q efficientnet-pytorch "albumentations<1.0" pytorch-fanatics pytorch_ranger

You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m


In [2]:
# ==== Import Libraries

import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F
import torch.nn as nn
import seaborn as sns
import random
import os


import albumentations as aug
from albumentations.pytorch.transforms import ToTensor
import matplotlib.pyplot as plt

from efficientnet_pytorch import EfficientNet
from tqdm import tqdm

from sklearn.model_selection import train_test_split as tts
from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset,DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

from pytorch_fanatics.dataloader import Cloader
from pytorch_fanatics.utils import EarlyStop 
from pytorch_fanatics.trainer import Trainer
from pytorch_fanatics.logger import Logger

import warnings
warnings.filterwarnings("ignore") 
warnings.filterwarnings("ignore", category=DeprecationWarning) 

from pytorch_ranger import Ranger

In [3]:
SEED = 42


def seed_everything(seed: int) -> None:
    """Seed every RNG this notebook relies on so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and torch
    (CPU and every visible CUDA device), and configures cuDNN to select
    deterministic kernels.

    Args:
        seed: the seed value applied to every generator.
    """
    random.seed(seed)
    # Only affects interpreters started after this point; the hash seed of the
    # already-running kernel is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN autotune and pick (possibly nondeterministic)
    # kernels per input shape, which defeats `deterministic=True` above.
    torch.backends.cudnn.benchmark = False


seed_everything(SEED)

In [4]:
# ===== Load labels and make a stratified train / validation split
DATA_DIR           = "../input/aptos2019-blindness-detection"
training_data_path = os.path.join(DATA_DIR, "train_images")

df = pd.read_csv(os.path.join(DATA_DIR, "train.csv"))

# Stratify on the label so both splits keep the same class proportions.
# The split reuses the notebook-wide SEED instead of a second hard-coded value,
# so changing SEED changes every source of randomness consistently.
X_train, X_val, Y_train, Y_val = tts(
    df,
    df.diagnosis.values,
    test_size=0.20,
    random_state=SEED,
    stratify=df.diagnosis.values,
)
X_train = X_train.reset_index(drop=True)
X_val   = X_val.reset_index(drop=True)

In [5]:
# ===== Augmentations
# Standard ImageNet channel statistics, applied by Normalize in both pipelines.
mean = (0.485, 0.456, 0.406)
std  = (0.229, 0.224, 0.225)

IMG_SIZE = 224  # every image is resized to IMG_SIZE x IMG_SIZE

# Training pipeline: resize, stochastic photometric / geometric / noise /
# sharpening transforms, then normalisation (always applied).
train_tfms = aug.Compose(
    [
        aug.Resize(IMG_SIZE, IMG_SIZE),
        aug.HorizontalFlip(p=0.5),
        aug.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1),
        aug.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=10, val_shift_limit=10),
        aug.RGBShift(),
        aug.ShiftScaleRotate(rotate_limit=(-45, 45)),
        aug.GaussNoise(p=0.35),
        aug.IAASharpen(),
        aug.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
    ]
)

# Evaluation pipeline: no randomness, only resize + normalisation.
test_tfms = aug.Compose(
    [
        aug.Resize(IMG_SIZE, IMG_SIZE),
        aug.Normalize(mean, std, max_pixel_value=255.0, always_apply=True),
    ]
)

In [6]:
def _image_paths(frame):
    """Return the ``<training_data_path>/<id_code>.png`` path for every row of ``frame``."""
    return [os.path.join(training_data_path, code + ".png") for code in frame.id_code.values.tolist()]


train_images = _image_paths(X_train)
test_images  = _image_paths(X_val)

# Cloader(paths, labels, None, transforms) from pytorch_fanatics.
# NOTE(review): the third positional argument is always None here — presumably
# an optional resize setting; confirm against the pytorch_fanatics docs.
train_dataset = Cloader(train_images, X_train.diagnosis.values, None, train_tfms)
test_dataset  = Cloader(test_images, X_val.diagnosis.values, None, test_tfms)

BATCH_SIZE  = 32
NUM_WORKERS = 4

# Only the training loader shuffles; validation order stays fixed.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dataloader   = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [7]:
# ===== Define model

class Net(nn.Module):
    """EfficientNet-B0 classifier whose forward pass also returns its loss.

    ``forward(image, targets) -> (out, loss)`` is the contract this notebook
    relies on when handing the model to pytorch_fanatics' ``Trainer``.

    Args:
        num_classes: size of the classification output (default 5, the value
            that used to be hard-coded).
    """

    def __init__(self, num_classes=5):
        super().__init__()
        self.num_classes = num_classes
        # Loads pretrained EfficientNet-B0 weights with a `num_classes`-way output.
        self.base_model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=num_classes)

    def forward(self, image, targets=None):
        """Run the backbone and, when labels are given, the cross-entropy loss.

        Args:
            image: batch of images whose first dimension is the batch size
                (assumed NCHW float tensor — TODO confirm against Cloader).
            targets: one integer class label per image, as a tensor, NumPy
                array or list; ``None`` skips the loss (e.g. for inference).

        Returns:
            ``(out, loss)``: the backbone's raw outputs and the mean
            cross-entropy, or ``None`` for the loss when ``targets`` is None.
        """
        batch_size = image.shape[0]
        out = self.base_model(image)
        if targets is None:
            return out, None
        # as_tensor avoids torch.tensor()'s copy-construction of tensor input
        # (whose UserWarning is hidden by the global warnings filter) and puts
        # the labels on the same device as the logits.
        targets = torch.as_tensor(targets, dtype=torch.int64, device=out.device)
        loss = F.cross_entropy(out.view(batch_size, -1), targets)
        return out, loss

# Build the classifier and move it to the selected device
# (`Module.to` returns the module itself, so this is a single expression).
model = Net().to(device)

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b0-355c32eb.pth


HBox(children=(FloatProgress(value=0.0, max=21388428.0), HTML(value='')))


Loaded pretrained weights for efficientnet-b0


In [8]:
def softmax(array, axis=1):
    """Numerically stable softmax of ``array`` along ``axis`` (default: rows).

    Subtracting the per-slice maximum before exponentiating leaves the result
    mathematically unchanged but prevents ``np.exp`` from overflowing to
    ``inf`` (and the division from producing ``nan``) for large logits.

    Args:
        array: array-like of scores; for the default ``axis=1`` a 2-D
            ``(n_samples, n_classes)`` array.
        axis: axis along which the probabilities sum to 1.

    Returns:
        NumPy array of the same shape as ``array``.
    """
    array = np.asarray(array)
    shifted = array - np.max(array, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)

In [9]:
# Optimisation setup handed to pytorch_fanatics' Trainer.
LEARNING_RATE = 5e-4
LR_FACTOR     = 0.6  # ReduceLROnPlateau multiplies the LR by this factor ...
LR_PATIENCE   = 2    # ... after this many steps without improvement (mode="min")

optimizer = Ranger(model.parameters(), lr=LEARNING_RATE)
# NOTE(review): passed as `val_scheduler`, so Trainer presumably steps it with
# the validation loss — confirm in pytorch_fanatics.
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=LR_FACTOR, patience=LR_PATIENCE)

trainer = Trainer(model=model, optimizer=optimizer, device=device, val_scheduler=scheduler)
logger  = Logger()


In [10]:
# ===== Training loop
# Train for a fixed number of epochs, logging train/val loss and validation
# accuracy each epoch, and checkpoint the weights with the lowest val loss.
epochs = 30
# `np.inf` rather than `np.Inf`: the capitalised alias was removed in NumPy 2.0.
best_loss = np.inf
for epoch in range(epochs):
    logger.write(f"+ ===== Epoch {epoch+1}/{epochs} ===== +")
    train_loss               = trainer.train(train_dataloader)
    y_true, y_pred, val_loss = trainer.evaluate(val_dataloader)
    # Assumes evaluate() returns raw logits per sample — TODO confirm.
    # Softmax does not change the argmax; y_pred is left holding probabilities.
    y_pred                   = softmax(y_pred)
    accuracy                 = accuracy_score(y_true, np.argmax(y_pred, axis=1))

    logger.write(f"train_loss {train_loss} val_loss {val_loss} ")
    logger.write(f"val accuracy_score {accuracy} ")
    logger.write(" ")
    # Keep only the best checkpoint by validation loss (overwrites best.pth).
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), 'best.pth')


+ ===== Epoch 1/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 1.3939328990552742 val_loss 1.138433655966883 
val accuracy_score 0.504774897680764 
 
+ ===== Epoch 2/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.8837939850662067 val_loss 0.7619689625242483 
val accuracy_score 0.732605729877217 
 
+ ===== Epoch 3/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.7090144743737967 val_loss 0.6493037835411404 
val accuracy_score 0.7517053206002728 
 
+ ===== Epoch 4/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.633161683121453 val_loss 0.597512521173643 
val accuracy_score 0.7735334242837654 
 
+ ===== Epoch 5/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.5796511869715609 val_loss 0.5587652675483537 
val accuracy_score 0.781718963165075 
 
+ ===== Epoch 6/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.5373851196921389 val_loss 0.539426793222842 
val accuracy_score 0.7844474761255116 
 
+ ===== Epoch 7/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.5121779556831587 val_loss 0.5403107326963673 
val accuracy_score 0.7926330150068213 
 
+ ===== Epoch 8/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.48634982886521727 val_loss 0.542342702979627 
val accuracy_score 0.819918144611187 
 
+ ===== Epoch 9/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.4471913532394432 val_loss 0.525859165450801 
val accuracy_score 0.810368349249659 
 
+ ===== Epoch 10/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.44089465944663325 val_loss 0.5147706717252732 
val accuracy_score 0.8130968622100955 
 
+ ===== Epoch 11/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.42123729646529834 val_loss 0.49517441832500964 
val accuracy_score 0.8240109140518418 
 
+ ===== Epoch 12/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.40458202216288325 val_loss 0.5007943679457124 
val accuracy_score 0.8212824010914052 
 
+ ===== Epoch 13/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.3847458357720271 val_loss 0.5610736465972402 
val accuracy_score 0.8267394270122783 
 
+ ===== Epoch 14/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.3604671556664551 val_loss 0.5959009476329972 
val accuracy_score 0.8090040927694406 
 
+ ===== Epoch 15/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.3410887568379225 val_loss 0.5676258867201597 
val accuracy_score 0.8035470668485676 
 
+ ===== Epoch 16/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.3114290278728891 val_loss 0.5366094008735988 
val accuracy_score 0.8226466575716235 
 
+ ===== Epoch 17/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.294886461256639 val_loss 0.6467421443566033 
val accuracy_score 0.8267394270122783 
 
+ ===== Epoch 18/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.2807566000391607 val_loss 0.5970478018988734 
val accuracy_score 0.834924965893588 
 
+ ===== Epoch 19/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.24528858956435448 val_loss 0.5766482068144757 
val accuracy_score 0.8267394270122783 
 
+ ===== Epoch 20/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.2659053227499775 val_loss 0.6511934464392454 
val accuracy_score 0.8335607094133697 
 
+ ===== Epoch 21/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.24499450186672417 val_loss 0.6005987330623295 
val accuracy_score 0.8294679399727148 
 
+ ===== Epoch 22/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.2330396636186735 val_loss 0.6054006737211476 
val accuracy_score 0.8321964529331515 
 
+ ===== Epoch 23/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.20996340505940753 val_loss 0.5967893678209056 
val accuracy_score 0.8171896316507503 
 
+ ===== Epoch 24/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.2075895534666336 val_loss 0.6371063875115437 
val accuracy_score 0.82537517053206 
 
+ ===== Epoch 25/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.20605564676225196 val_loss 0.6258371394613516 
val accuracy_score 0.8294679399727148 
 
+ ===== Epoch 26/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.1929361527056798 val_loss 0.6398223755152329 
val accuracy_score 0.8267394270122783 
 
+ ===== Epoch 27/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.19186901088561054 val_loss 0.640695792177449 
val accuracy_score 0.8308321964529332 
 
+ ===== Epoch 28/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.18413919384550795 val_loss 0.6254686052384584 
val accuracy_score 0.82537517053206 
 
+ ===== Epoch 29/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.19880014933321782 val_loss 0.6318837080312814 
val accuracy_score 0.8321964529331515 
 
+ ===== Epoch 30/30 ===== +


HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))


train_loss 0.18849391767593185 val_loss 0.6391214443289716 
val accuracy_score 0.8267394270122783 
 
