In [1]:
# ==== Install Dependencies

!pip install -q efficientnet_pytorch
!pip install -q albumentations
!pip install -q pytorch-fanatics 
!pip install -q pytorch_ranger

You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip install --upgrade pip' command.[0m


In [2]:
# ==== Import Libraries

import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F
import torch.nn as nn
import seaborn as sns
import random
import os


import albumentations as aug
from albumentations.pytorch.transforms import ToTensor
import matplotlib.pyplot as plt

from efficientnet_pytorch import EfficientNet
from tqdm import tqdm

from sklearn.model_selection import train_test_split as tts
from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset,DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

# from pytorch_fanatics.dataloader import Cloader
from pytorch_fanatics.utils import EarlyStop 
from pytorch_fanatics.trainer import Trainer
from pytorch_fanatics.logger import Logger

import warnings
warnings.filterwarnings("ignore") 
warnings.filterwarnings("ignore", category=DeprecationWarning) 

from pytorch_ranger import Ranger

In [3]:
from glob import glob

In [4]:
train_path = []
train_label = []

val_path = []
val_label = []

In [5]:
for i in glob("../input/alzheimers-dataset-4-class-of-images/Alzheimer_s Dataset/train/MildDemented/*"):
    train_path.append(i)
    train_label.append(0)
    
for i in glob("../input/alzheimers-dataset-4-class-of-images/Alzheimer_s Dataset/train/ModerateDemented/*"):
    train_path.append(i)
    train_label.append(1)

for i in glob("../input/alzheimers-dataset-4-class-of-images/Alzheimer_s Dataset/train/NonDemented/*"):
    train_path.append(i)
    train_label.append(2)    
    
for i in glob("../input/alzheimers-dataset-4-class-of-images/Alzheimer_s Dataset/train/VeryMildDemented/*"):
    train_path.append(i)
    train_label.append(3)

for i in glob("../input/alzheimers-dataset-4-class-of-images/Alzheimer_s Dataset/test/MildDemented/*"):
    val_path.append(i)
    val_label.append(0)
    
for i in glob("../input/alzheimers-dataset-4-class-of-images/Alzheimer_s Dataset/test/ModerateDemented/*"):
    val_path.append(i)
    val_label.append(1)

for i in glob("../input/alzheimers-dataset-4-class-of-images/Alzheimer_s Dataset/test/NonDemented/*"):
    val_path.append(i)
    val_label.append(2)    
    
for i in glob("../input/alzheimers-dataset-4-class-of-images/Alzheimer_s Dataset/test/VeryMildDemented/*"):
    val_path.append(i)
    val_label.append(3)    
    

In [6]:
SEED = 42
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(SEED)

In [7]:
# ===== Augmentations

mean       = (0.485, 0.456, 0.406)
std        = (0.229, 0.224, 0.225)
train_tfms = aug.Compose([
            aug.Resize(224,224),
            aug.Flip(p=0.5),
            aug.ShiftScaleRotate(rotate_limit=(-45,45)),
            aug.GaussNoise(p=0.35),
            aug.Normalize(mean,std,max_pixel_value=255.0,always_apply=True),
            ])

test_tfms  = aug.Compose([
            aug.Resize(224,224),
            aug.Normalize(mean,std,max_pixel_value=255.0,always_apply=True),
            ])

In [8]:
import numpy as np
from PIL import Image
from PIL import ImageFile   #ImageFile contains support for PIL to open and save images
ImageFile.LOAD_TRUNCATED_IMAGES=True  #If image is truncated(or corrupt) then also load it..
import torch

class Cloader:
    def __init__(self,image_path,targets,resize=None,transforms=None):
        self.image_path=image_path
        self.targets=targets
        self.resize=resize
        self.transforms=transforms
        
    def __len__(self):
        return len(self.image_path)
    
    def __getitem__(self,idx):
        image = Image.open(self.image_path[idx]).convert('RGB')
        targets = self.targets[idx]
        if self.resize is not None:
            image = image.resize(
                (self.resize[1], self.resize[0]), resample=Image.BILINEAR
            )
        image = np.array(image)
        if self.transforms is not None:
            augmented = self.transforms(image=image)
            image = augmented["image"]
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        return {
            "image": torch.tensor(image, dtype=torch.float),
            "targets": torch.tensor(targets, dtype=torch.long),
        }


In [9]:

train_dataset    = Cloader(train_path,train_label,None,train_tfms)
test_dataset     = Cloader(val_path,val_label,None,test_tfms)

train_dataloader = DataLoader(train_dataset,batch_size=32,shuffle=True,num_workers=4)
val_dataloader   = DataLoader(test_dataset,batch_size=32,shuffle=False,num_workers=4)

device           = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [10]:
# ===== Define model

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.base_model = EfficientNet.from_pretrained('efficientnet-b0',num_classes=4)
    def forward(self, image, targets):
        batch_size, _, _, _ = image.shape
        out = self.base_model(image)
        targets = torch.tensor(targets,dtype=torch.int64)
        loss = nn.CrossEntropyLoss()(out.view(batch_size,4), targets)
        return out, loss

model = Net()
model.to(device);

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b0-355c32eb.pth


HBox(children=(FloatProgress(value=0.0, max=21388428.0), HTML(value='')))


Loaded pretrained weights for efficientnet-b0


In [11]:
def softmax(array):
    return np.exp(array)/np.sum(np.exp(array),axis=1).reshape(-1,1)

In [12]:
optimizer = Ranger(model.parameters(),lr=1e-4)
scheduler = ReduceLROnPlateau(optimizer,factor=0.8, mode="min",min_lr= 5e-6, patience=2)

trainer   = Trainer(model=model,optimizer=optimizer,device=device,val_scheduler=scheduler)
logger    = Logger()

es        = EarlyStop(patience=5,mode="min") # mode = min to minimise loss

In [13]:
epochs = 50
best_loss = np.Inf

for epoch in range(epochs):
    logger.write(f"+ ===== Epoch {epoch+1}/{epochs} ===== +")
    train_loss              = trainer.train(train_dataloader)
    y_true,y_pred ,val_loss = trainer.evaluate(val_dataloader)
    y_pred                  = softmax(y_pred)
    accuracy                = accuracy_score(y_true,np.argmax(y_pred,axis=1))
    logger.write(f"train_loss {train_loss} val_loss {val_loss} ")
    logger.write(f"val accuracy_score {accuracy} ")
    logger.write(" ")
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(),'best.pth')


+ ===== Epoch 1/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 1.332680420845932 val_loss 1.1896885231137277 
val accuracy_score 0.5027365129007036 
 
+ ===== Epoch 2/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 1.1195236500005545 val_loss 1.0179789379239081 
val accuracy_score 0.5230648944487881 
 
+ ===== Epoch 3/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.9707580941804449 val_loss 0.9744964979588983 
val accuracy_score 0.5418295543393276 
 
+ ===== Epoch 4/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.8957206754951004 val_loss 0.9223350264132025 
val accuracy_score 0.5598123534010946 
 
+ ===== Epoch 5/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.8541670660054459 val_loss 0.9649281807243825 
val accuracy_score 0.5488663017982799 
 
+ ===== Epoch 6/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.8061070445901859 val_loss 0.8543689988553523 
val accuracy_score 0.6020328381548085 
 
+ ===== Epoch 7/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.7758421024180345 val_loss 0.8179149933159349 
val accuracy_score 0.6286161063330727 
 
+ ===== Epoch 8/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.7461396600148699 val_loss 0.8174756415188312 
val accuracy_score 0.6512900703674745 
 
+ ===== Epoch 9/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.705512902011042 val_loss 0.8951152510941027 
val accuracy_score 0.6379984362783424 
 
+ ===== Epoch 10/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.6624965523340687 val_loss 0.7750344887375833 
val accuracy_score 0.6755277560594214 
 
+ ===== Epoch 11/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.6240309913706334 val_loss 0.7654146775603295 
val accuracy_score 0.6598905394839718 
 
+ ===== Epoch 12/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.6113331799181352 val_loss 0.7705073870718478 
val accuracy_score 0.6919468334636435 
 
+ ===== Epoch 13/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.5532120930852357 val_loss 0.7781909413635733 
val accuracy_score 0.6864738076622361 
 
+ ===== Epoch 14/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.5308381848453734 val_loss 0.7409298069775104 
val accuracy_score 0.6989835809225958 
 
+ ===== Epoch 15/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.47359034733742666 val_loss 0.7530571214854719 
val accuracy_score 0.6888193901485535 
 
+ ===== Epoch 16/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.44228048585586666 val_loss 0.7594362266361714 
val accuracy_score 0.6919468334636435 
 
+ ===== Epoch 17/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.4125137208781629 val_loss 0.8064902033656838 
val accuracy_score 0.7099296325254105 
 
+ ===== Epoch 18/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.3754276170493652 val_loss 0.8236003257334233 
val accuracy_score 0.7310398749022674 
 
+ ===== Epoch 19/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.3606451975632897 val_loss 0.7303829330950974 
val accuracy_score 0.7412040656763096 
 
+ ===== Epoch 20/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.3566103551698768 val_loss 0.7755232183262705 
val accuracy_score 0.7326035965598123 
 
+ ===== Epoch 21/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.3186632983432793 val_loss 0.747872245311737 
val accuracy_score 0.7458952306489445 
 
+ ===== Epoch 22/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.29735732453395114 val_loss 0.8413765344768761 
val accuracy_score 0.7146207974980453 
 
+ ===== Epoch 23/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.29792130511739984 val_loss 0.7462646454572677 
val accuracy_score 0.7404222048475372 
 
+ ===== Epoch 24/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.2656846600844994 val_loss 0.7385421253740787 
val accuracy_score 0.746677091477717 
 
+ ===== Epoch 25/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.2502294268389666 val_loss 0.7989237792789937 
val accuracy_score 0.7474589523064894 
 
+ ===== Epoch 26/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.2584931003566114 val_loss 0.7495335325598714 
val accuracy_score 0.7513682564503519 
 
+ ===== Epoch 27/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.24269279712660713 val_loss 0.7283386636525391 
val accuracy_score 0.763096168881939 
 
+ ===== Epoch 28/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.21776151758913662 val_loss 0.762834333628416 
val accuracy_score 0.7670054730258014 
 
+ ===== Epoch 29/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.20629182733271423 val_loss 0.7544992204755545 
val accuracy_score 0.7724784988272088 
 
+ ===== Epoch 30/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.21181096664126622 val_loss 0.7680088046938179 
val accuracy_score 0.766223612197029 
 
+ ===== Epoch 31/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.20481038026465395 val_loss 0.7134311769157649 
val accuracy_score 0.7857701329163409 
 
+ ===== Epoch 32/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.2171812288648222 val_loss 0.7300512731075288 
val accuracy_score 0.7732603596559813 
 
+ ===== Epoch 33/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.1835400804368234 val_loss 0.807387161999941 
val accuracy_score 0.7740422204847537 
 
+ ===== Epoch 34/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.18047587471719115 val_loss 0.7946529582142831 
val accuracy_score 0.7498045347928068 
 
+ ===== Epoch 35/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.18484718755165247 val_loss 0.7539854541420938 
val accuracy_score 0.7826426896012509 
 
+ ===== Epoch 36/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.1657760814295053 val_loss 0.757728199660778 
val accuracy_score 0.7732603596559813 
 
+ ===== Epoch 37/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.18065831961792814 val_loss 0.7991533294320106 
val accuracy_score 0.7584050039093041 
 
+ ===== Epoch 38/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.22318770658942114 val_loss 0.7628744550049308 
val accuracy_score 0.784206411258796 
 
+ ===== Epoch 39/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.1885209768223835 val_loss 0.8012989528477191 
val accuracy_score 0.7701329163408913 
 
+ ===== Epoch 40/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.15170562237223476 val_loss 0.7838983695954085 
val accuracy_score 0.7701329163408913 
 
+ ===== Epoch 41/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.16467571675037007 val_loss 0.7836336754262446 
val accuracy_score 0.7763878029710711 
 
+ ===== Epoch 42/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.1790526195990372 val_loss 0.7795793108642103 
val accuracy_score 0.7787333854573886 
 
+ ===== Epoch 43/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.1610348045663989 val_loss 0.7530464347451926 
val accuracy_score 0.784206411258796 
 
+ ===== Epoch 44/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.1470140052947754 val_loss 0.773567172512412 
val accuracy_score 0.7849882720875684 
 
+ ===== Epoch 45/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.15096986244045058 val_loss 0.8085085835307837 
val accuracy_score 0.7771696637998436 
 
+ ===== Epoch 46/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.1449260465014055 val_loss 0.8009914431720974 
val accuracy_score 0.7756059421422987 
 
+ ===== Epoch 47/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.1556706982115226 val_loss 0.799178710207343 
val accuracy_score 0.7740422204847537 
 
+ ===== Epoch 48/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.1466179146020678 val_loss 0.8009882006794213 
val accuracy_score 0.7763878029710711 
 
+ ===== Epoch 49/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.1559488419790446 val_loss 0.8110296070575714 
val accuracy_score 0.7709147771696638 
 
+ ===== Epoch 50/50 ===== +


HBox(children=(FloatProgress(value=0.0, max=161.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


train_loss 0.14609237668908528 val_loss 0.7981901068240403 
val accuracy_score 0.7740422204847537 
 
