In [1]:
# ==== Install Dependencies

!pip install -q efficientnet_pytorch
!pip install -q albumentations
!pip install -q pytorch-fanatics 
!pip install -q pytorch_ranger

You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m


In [2]:
# ==== Import Libraries

import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F
import torch.nn as nn
import seaborn as sns
import random
import os


import albumentations as aug
from albumentations.pytorch.transforms import ToTensor
import matplotlib.pyplot as plt

from efficientnet_pytorch import EfficientNet
from tqdm import tqdm

from sklearn.model_selection import train_test_split as tts
from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset,DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

from pytorch_fanatics.dataloader import Cloader
from pytorch_fanatics.utils import EarlyStop 
from pytorch_fanatics.trainer import Trainer
from pytorch_fanatics.logger import Logger

import warnings
warnings.filterwarnings("ignore") 
warnings.filterwarnings("ignore", category=DeprecationWarning) 

from pytorch_ranger import Ranger

In [3]:
from glob import glob

In [4]:
train_path = []
train_label = []

val_path = []
val_label = []

In [5]:
for i in glob("../input/brain-tumor-classification-mri/Training/glioma_tumor/*"):
    train_path.append(i)
    train_label.append(0)
    
for i in glob("../input/brain-tumor-classification-mri/Training/meningioma_tumor/*"):
    train_path.append(i)
    train_label.append(1)

for i in glob("../input/brain-tumor-classification-mri/Training/no_tumor/*"):
    train_path.append(i)
    train_label.append(2)    
    
for i in glob("../input/brain-tumor-classification-mri/Training/pituitary_tumor/*"):
    train_path.append(i)
    train_label.append(3)

for i in glob("../input/brain-tumor-classification-mri/Testing/glioma_tumor/*"):
    val_path.append(i)
    val_label.append(0)
    
for i in glob("../input/brain-tumor-classification-mri/Testing/meningioma_tumor/*"):
    val_path.append(i)
    val_label.append(1)

for i in glob("../input/brain-tumor-classification-mri/Testing/no_tumor/*"):
    val_path.append(i)
    val_label.append(2)    
    
for i in glob("../input/brain-tumor-classification-mri/Testing/pituitary_tumor/*"):
    val_path.append(i)
    val_label.append(3)    
    

In [6]:
SEED = 42
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(SEED)

In [7]:
# ===== Augmentations

mean       = (0.485, 0.456, 0.406)
std        = (0.229, 0.224, 0.225)
train_tfms = aug.Compose([
            aug.Resize(224,224),
            aug.Flip(p=0.5),
            aug.ShiftScaleRotate(rotate_limit=(-45,45)),
            aug.GaussNoise(p=0.35),
            aug.Normalize(mean,std,max_pixel_value=255.0,always_apply=True),
            ])

test_tfms  = aug.Compose([
            aug.Resize(224,224),
            aug.Normalize(mean,std,max_pixel_value=255.0,always_apply=True),
            ])

In [8]:

train_dataset    = Cloader(train_path,train_label,None,train_tfms)
test_dataset     = Cloader(val_path,val_label,None,test_tfms)

train_dataloader = DataLoader(train_dataset,batch_size=32,shuffle=True,num_workers=4)
val_dataloader   = DataLoader(test_dataset,batch_size=32,shuffle=False,num_workers=4)

device           = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [9]:
# ===== Define model

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.base_model = EfficientNet.from_pretrained('efficientnet-b0',num_classes=4)
    def forward(self, image, targets):
        batch_size, _, _, _ = image.shape
        out = self.base_model(image)
        targets = torch.tensor(targets,dtype=torch.int64)
        loss = nn.CrossEntropyLoss()(out.view(batch_size,4), targets)
        return out, loss

model = Net()
model.to(device);

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b0-355c32eb.pth


HBox(children=(FloatProgress(value=0.0, max=21388428.0), HTML(value='')))


Loaded pretrained weights for efficientnet-b0


In [10]:
def softmax(array):
    return np.exp(array)/np.sum(np.exp(array),axis=1).reshape(-1,1)

In [11]:
optimizer = Ranger(model.parameters(),lr=1e-4)
scheduler = ReduceLROnPlateau(optimizer,factor=0.8, mode="min",min_lr= 5e-6, patience=2)

trainer   = Trainer(model=model,optimizer=optimizer,device=device,val_scheduler=scheduler)
logger    = Logger()

es        = EarlyStop(patience=5,mode="min") # mode = min to minimise loss

In [12]:
epochs = 50
best_loss = np.Inf

for epoch in range(epochs):
    logger.write(f"+ ===== Epoch {epoch+1}/{epochs} ===== +")
    train_loss              = trainer.train(train_dataloader)
    y_true,y_pred ,val_loss = trainer.evaluate(val_dataloader)
    y_pred                  = softmax(y_pred)
    accuracy                = accuracy_score(y_true,np.argmax(y_pred,axis=1))
    logger.write(f"train_loss {train_loss} val_loss {val_loss} ")
    logger.write(f"val accuracy_score {accuracy} ")
    logger.write(" ")
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(),'best.pth')


+ ===== Epoch 1/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 1.3632821030086943 val_loss 1.3338165374902577 
val accuracy_score 0.33756345177664976 
 
+ ===== Epoch 2/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 1.2387222515212166 val_loss 1.2581836993877706 
val accuracy_score 0.5 
 
+ ===== Epoch 3/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 1.0439172983169558 val_loss 1.1604431867599487 
val accuracy_score 0.5406091370558376 
 
+ ===== Epoch 4/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.8189246210787033 val_loss 1.0637684005957384 
val accuracy_score 0.5964467005076142 
 
+ ===== Epoch 5/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.6244230535295273 val_loss 1.0042449923662038 
val accuracy_score 0.6243654822335025 
 
+ ===== Epoch 6/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.47523902853329963 val_loss 0.9510921056453998 
val accuracy_score 0.6294416243654822 
 
+ ===== Epoch 7/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.38428061306476585 val_loss 0.9502132557905636 
val accuracy_score 0.6548223350253807 
 
+ ===== Epoch 8/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.3212083910902341 val_loss 0.946543385203068 
val accuracy_score 0.6725888324873096 
 
+ ===== Epoch 9/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.2868098917106788 val_loss 0.9807196053174826 
val accuracy_score 0.6776649746192893 
 
+ ===== Epoch 10/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.2313024643394683 val_loss 0.9339298657499827 
val accuracy_score 0.7385786802030457 
 
+ ===== Epoch 11/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.21978503701587523 val_loss 0.92492159226766 
val accuracy_score 0.7436548223350253 
 
+ ===== Epoch 12/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.19331734085248584 val_loss 0.8784816245046946 
val accuracy_score 0.7639593908629442 
 
+ ===== Epoch 13/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.18358033899631762 val_loss 0.8270232263379372 
val accuracy_score 0.8071065989847716 
 
+ ===== Epoch 14/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.150186852655477 val_loss 0.9018686677400883 
val accuracy_score 0.7791878172588832 
 
+ ===== Epoch 15/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.14927311316132547 val_loss 0.868008213977401 
val accuracy_score 0.8071065989847716 
 
+ ===== Epoch 16/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.14485301433338063 val_loss 0.836107366789992 
val accuracy_score 0.8096446700507615 
 
+ ===== Epoch 17/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.12358163760768047 val_loss 0.8924434190759294 
val accuracy_score 0.7918781725888325 
 
+ ===== Epoch 18/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.10485137920412749 val_loss 0.8950011802192491 
val accuracy_score 0.7969543147208121 
 
+ ===== Epoch 19/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.09492873307317494 val_loss 0.8848329430016187 
val accuracy_score 0.8045685279187818 
 
+ ===== Epoch 20/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.08935698925827942 val_loss 0.8973477974963876 
val accuracy_score 0.8020304568527918 
 
+ ===== Epoch 21/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.10759255471121934 val_loss 0.9078660105140163 
val accuracy_score 0.7893401015228426 
 
+ ===== Epoch 22/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.08987410788217354 val_loss 0.9075711324381139 
val accuracy_score 0.8045685279187818 
 
+ ===== Epoch 23/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.09395858448826601 val_loss 0.9362381704939673 
val accuracy_score 0.799492385786802 
 
+ ===== Epoch 24/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.08924757208054261 val_loss 0.9231995889522995 
val accuracy_score 0.7969543147208121 
 
+ ===== Epoch 25/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.0841008554316229 val_loss 0.9638376145695267 
val accuracy_score 0.7868020304568528 
 
+ ===== Epoch 26/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.07280968305551343 val_loss 0.9636451556848791 
val accuracy_score 0.7893401015228426 
 
+ ===== Epoch 27/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.06764663462009694 val_loss 0.9626357445779902 
val accuracy_score 0.7969543147208121 
 
+ ===== Epoch 28/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.0853968713949952 val_loss 0.9911916477319137 
val accuracy_score 0.7918781725888325 
 
+ ===== Epoch 29/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.07050448903400038 val_loss 0.9805495764415425 
val accuracy_score 0.7969543147208121 
 
+ ===== Epoch 30/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.06380761959072616 val_loss 0.9564943373239099 
val accuracy_score 0.799492385786802 
 
+ ===== Epoch 31/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.063240387945229 val_loss 0.9557188080731207 
val accuracy_score 0.8020304568527918 
 
+ ===== Epoch 32/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.06567518714194495 val_loss 0.9479144250239747 
val accuracy_score 0.7944162436548223 
 
+ ===== Epoch 33/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.05690234834328293 val_loss 0.991676981525066 
val accuracy_score 0.7969543147208121 
 
+ ===== Epoch 34/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.05673669761874611 val_loss 0.9836619168860263 
val accuracy_score 0.7969543147208121 
 
+ ===== Epoch 35/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.06122781761094099 val_loss 0.9741821769678678 
val accuracy_score 0.7969543147208121 
 
+ ===== Epoch 36/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.053841106479780534 val_loss 0.9624062671612661 
val accuracy_score 0.8045685279187818 
 
+ ===== Epoch 37/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.06767502970113935 val_loss 0.9714078412641988 
val accuracy_score 0.7969543147208121 
 
+ ===== Epoch 38/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.056387560693029734 val_loss 0.9717117136147303 
val accuracy_score 0.8045685279187818 
 
+ ===== Epoch 39/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.05354607549702957 val_loss 0.9920670038733919 
val accuracy_score 0.7969543147208121 
 
+ ===== Epoch 40/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.05074073168086925 val_loss 0.9885854475074805 
val accuracy_score 0.7969543147208121 
 
+ ===== Epoch 41/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.06721346508711576 val_loss 0.9884455701998938 
val accuracy_score 0.7969543147208121 
 
+ ===== Epoch 42/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.056245919130742554 val_loss 0.9856122320947739 
val accuracy_score 0.7918781725888325 
 
+ ===== Epoch 43/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.05047341765732399 val_loss 0.979161681741691 
val accuracy_score 0.7918781725888325 
 
+ ===== Epoch 44/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.055939342867996934 val_loss 0.9987579025328159 
val accuracy_score 0.7918781725888325 
 
+ ===== Epoch 45/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.05680089574824602 val_loss 0.996385161442539 
val accuracy_score 0.7969543147208121 
 
+ ===== Epoch 46/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.05624925991675506 val_loss 0.9899911562410684 
val accuracy_score 0.7969543147208121 
 
+ ===== Epoch 47/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.05410798910694817 val_loss 0.9908609696251985 
val accuracy_score 0.7969543147208121 
 
+ ===== Epoch 48/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.04592610070637116 val_loss 0.9890988227338171 
val accuracy_score 0.8020304568527918 
 
+ ===== Epoch 49/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.052408682141039116 val_loss 0.9830494107803902 
val accuracy_score 0.8020304568527918 
 
+ ===== Epoch 50/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=90.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))


train_loss 0.057560275758927075 val_loss 0.98839767555742 
val accuracy_score 0.8020304568527918 
 
