In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input tree and print the joined directory/file path of every file found.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for file_path in (os.path.join(dirname, name) for name in filenames):
        print(file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import json
import seaborn as sns
import matplotlib.pyplot as plt
import cv2
import torch
from torch import nn
import random
import torchvision
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import sklearn
from sklearn.model_selection import GroupKFold, StratifiedKFold
from torch.cuda.amp import autocast, GradScaler
!pip install timm
import timm
import time

from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

from albumentations.pytorch import ToTensorV2

In [None]:
# Root directory of the competition data; the CSV, the label map and the images are resolved against it.
Base_Dir = '../input/cassava-leaf-disease-classification/'

# Experiment configuration read by the cells below.
CFG = dict(
    fold_num=5,                         # number of stratified folds
    seed=48,                            # value handed to seed_everything
    model_arch='tf_efficientnet_b4_ns', # timm architecture name
    img_size=512,                       # side length used by the crop/resize transforms
    epochs=10,                          # training epochs per fold
    train_bs=16,                        # training batch size
    valid_bs=32,                        # validation batch size
    lr=1e-4,                            # Adam learning rate
    min_lr=1e-6,                        # eta_min of the cosine scheduler
    weight_decay=1e-6,                  # Adam weight decay
    num_workers=4,                      # DataLoader worker processes
    accum_iter=2,                       # gradient accumulation: optimizer steps every `accum_iter` batches,
                                        # giving an effectively larger batch size
    verbose_step=1,                     # refresh the progress-bar description every N batches
    device='cuda:0',                    # device the model and batches are moved to
    T_0=10,                             # T_0 of CosineAnnealingWarmRestarts
)

In [None]:
'''
Helper Functions
'''

def seed_everything(seed):
    '''
        Seed every random number generator used by the notebook and put cuDNN
        into deterministic mode.

        Args:
            seed: integer seed applied to `random`, NumPy and PyTorch (CPU and CUDA).
    '''
    random.seed(seed)
    # Only affects Python processes started after this point (e.g. DataLoader workers);
    # the hash seed of the running interpreter is fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark mode lets cuDNN auto-tune and pick (possibly different) kernels per run,
    # which works against the deterministic setting above, so it must be off.
    torch.backends.cudnn.benchmark = False

def get_img(path):
    '''
        Read the image at `path` with OpenCV and return it with the channel
        axis reversed (BGR -> RGB).

        Args:
            path: location of the image file.

        Returns:
            The decoded image array with its last axis reversed.

        Raises:
            FileNotFoundError: if OpenCV could not read the file (missing,
                unreadable or not a decodable image).
    '''
    im_bgr = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising; fail loudly
    # with the offending path rather than with a TypeError on the slice below.
    if im_bgr is None:
        raise FileNotFoundError('Could not read image: {}'.format(path))
    im_rgb = im_bgr[:, :, ::-1]
    return im_rgb

In [None]:
# Load the label-number -> disease-name lookup that ships with the dataset.
with open(os.path.join(Base_Dir, 'label_num_to_disease_map.json')) as file:
    # JSON object keys are strings; convert them to int so the mapping can be applied to the `label` column.
    map_classes = {int(k): v for k, v in json.load(file).items()}

In [None]:
# Read the training table and attach the human-readable class name for each numeric label.
df_train = pd.read_csv(os.path.join(Base_Dir, 'train.csv')).assign(
    class_name=lambda frame: frame['label'].map(map_classes)
)
df_train

In [None]:
# Horizontal bar chart of how many training rows fall into each class.
fig, ax = plt.subplots(figsize=(10, 3))
sns.countplot(y='class_name', data=df_train, ax=ax)

In [None]:
class CassavaDataset(Dataset):
    """Dataset that loads images listed in a DataFrame.

    Each row of the frame must provide an `image_id` column (file name relative
    to `data_root`) and, when labels are requested, a `label` column.

    Args:
        df: table describing the samples; it is copied with a fresh 0..n-1 index.
        data_root: directory that is joined with `image_id` to form the image path.
        transforms: optional callable invoked as `transforms(image=img)`; the
            value under the `'image'` key of its result replaces the image.
        output_label: when true, `__getitem__` also returns the row's label.
    """

    def __init__(self, df, data_root, transforms=None, output_label=True):
        super().__init__()
        self.data_root = data_root
        self.transforms = transforms
        self.output_label = output_label
        # Re-index so that integer positions from the sampler are valid row labels.
        self.df = df.reset_index(drop=True).copy()

    def __len__(self):
        """Number of rows (samples) in the underlying table."""
        return self.df.shape[0]

    def __getitem__(self, index: int):
        """Return `(image, label)` or just `image`, depending on `output_label`."""
        # Look the label up before touching the disk, only when labels are enabled.
        label = self.df['label'][index] if self.output_label else None

        image_id = self.df.loc[index]['image_id']
        image = get_img(f"{self.data_root}/{image_id}")

        if self.transforms:
            image = self.transforms(image=image)['image']

        return (image, label) if self.output_label == True else image

In [None]:
def get_train_transforms():
    """Build the albumentations pipeline applied to training images.

    The pipeline takes a random resized crop of `CFG['img_size']` square, applies
    random transpose/flip/shift-scale-rotate and colour jitter, normalises with
    the ImageNet channel statistics, masks out random regions and finally
    converts the image to a tensor.

    Returns:
        An albumentations `Compose` object.
    """
    return Compose([
        RandomResizedCrop(CFG['img_size'],CFG['img_size']),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.5),
        # NOTE(review): shift limits of 0.2 look very small if they are interpreted on the
        # 0-255 pixel scale - confirm this jitter strength is intended.
        HueSaturationValue(hue_shift_limit=0.2,sat_shift_limit=0.2,val_shift_limit=0.2,p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1,0.1),contrast_limit=(-0.1,0.1),p=0.5),
        # ImageNet statistics; the third std value is 0.225 (it was mistyped as 0.255).
        Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225],max_pixel_value=255.0,p=1.0),
        CoarseDropout(p=0.5),
        Cutout(p=0.5),
        ToTensorV2(p=1.0),
    ],p=1.0)

def get_valid_transforms():
    """Build the deterministic albumentations pipeline applied to validation images.

    Images are centre-cropped to `CFG['img_size']` square, resized to that same
    size, normalised with the ImageNet channel statistics and converted to a tensor.

    Returns:
        An albumentations `Compose` object.
    """
    return Compose([
        CenterCrop(CFG['img_size'],CFG['img_size'],p=1.0),
        Resize(CFG['img_size'],CFG['img_size']),
        # ImageNet statistics; the third std value is 0.225 (it was mistyped as 0.255).
        Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225],max_pixel_value=255.0,p=1.0),
        ToTensorV2(p=1.0),
    ],p=1)

In [None]:
class CassavaImageClassifier(nn.Module):
    """timm backbone whose `classifier` layer is replaced by a fresh linear head.

    Args:
        model_arch: architecture name passed to `timm.create_model`.
        n_class: number of output units of the new linear head.
        pretrained: forwarded to `timm.create_model`.
    """

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_arch, pretrained=pretrained)
        # Keep the backbone's feature width, swap only the final layer.
        head_in = backbone.classifier.in_features
        backbone.classifier = nn.Linear(head_in, n_class)
        self.model = backbone

    def forward(self, x):
        """Run `x` through the wrapped model and return its output unchanged."""
        return self.model(x)

In [None]:
def prepare_dataloader(df, trn_idx, val_idx, data_root='../input/cassava-leaf-disease-classification/train_images/'):
    """Create the training and validation DataLoaders for one fold.

    Args:
        df: full table of samples; rows are selected from this frame (previously
            the global `df_train` was read and this argument was ignored).
        trn_idx: positional row indices of the training split.
        val_idx: positional row indices of the validation split.
        data_root: directory containing the image files.

    Returns:
        Tuple `(train_loader, val_loader)`. The training loader shuffles and uses
        `CFG['train_bs']`; the validation loader keeps order and uses `CFG['valid_bs']`.
    """
    # The fold indices are positions produced by a K-fold splitter, so select by
    # position; for a default 0..n-1 index this is identical to label selection.
    train_ = df.iloc[trn_idx, :].reset_index(drop=True)
    valid_ = df.iloc[val_idx, :].reset_index(drop=True)

    train_ds = CassavaDataset(train_, data_root, transforms=get_train_transforms(), output_label=True)
    valid_ds = CassavaDataset(valid_, data_root, transforms=get_valid_transforms(), output_label=True)

    train_loader = DataLoader(
        train_ds,
        batch_size=CFG['train_bs'],
        pin_memory=False,
        drop_last=False,
        shuffle=True,
        num_workers=CFG['num_workers'],
    )

    val_loader = DataLoader(
        valid_ds,
        batch_size=CFG['valid_bs'],
        num_workers=CFG['num_workers'],
        shuffle=False,
        pin_memory=False,
    )
    return train_loader, val_loader

In [None]:
 def train_one_epoch(epoch, model, loss_fn, optimizer, train_loader, device, scheduler=None, schd_batch_update=False):
    """Train `model` for one pass over `train_loader` with mixed precision.

    Gradients are accumulated over `CFG['accum_iter']` batches (and flushed on the
    last batch) before the optimizer steps. The function relies on a module-level
    `scaler` (a GradScaler) and on the module-level `CFG` dict.

    Args:
        epoch: epoch number, used only in the progress-bar text.
        model: network to train; switched to train mode here.
        loss_fn: callable `loss_fn(predictions, labels)` returning a scalar loss.
        optimizer: optimizer stepped through the gradient scaler.
        train_loader: iterable of `(images, labels)` batches with a length.
        device: device the batches are moved to.
        scheduler: optional LR scheduler.
        schd_batch_update: if true the scheduler steps after every optimizer
            step, otherwise once at the end of the epoch.
    """
    model.train()

    running_loss = None
    num_batches = len(train_loader)

    pbar = tqdm(enumerate(train_loader), total=num_batches)
    for step, (imgs, image_labels) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()

        # Only the forward pass and the loss belong under autocast; running the
        # backward pass and the optimizer step inside the context is not recommended.
        with autocast():
            image_preds = model(imgs)
            loss = loss_fn(image_preds, image_labels)

        # NOTE(review): the loss is not divided by accum_iter, so accumulated
        # gradients are summed rather than averaged - confirm this is intended.
        scaler.scale(loss).backward()

        # Exponentially smoothed loss shown in the progress bar.
        batch_loss = loss.item()
        if running_loss is None:
            running_loss = batch_loss
        else:
            running_loss = running_loss * .99 + batch_loss * .01

        is_last_batch = (step + 1) == num_batches

        if ((step + 1) % CFG['accum_iter'] == 0) or is_last_batch:
            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if scheduler is not None and schd_batch_update:
                scheduler.step()

        if ((step + 1) % CFG['verbose_step'] == 0) or is_last_batch:
            pbar.set_description(f'epoch {epoch} loss: {running_loss:.4f}')

    if scheduler is not None and not schd_batch_update:
        scheduler.step()
        
def valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False):
    """Evaluate `model` on `val_loader` and print the multi-class accuracy.

    Gradient tracking is disabled inside the function, so it is safe to call with
    or without an enclosing `torch.no_grad()`. Relies on the module-level `CFG`
    dict for the progress-bar refresh interval.

    Args:
        epoch: epoch number, used only in the progress-bar text.
        model: network to evaluate; switched to eval mode here.
        loss_fn: callable `loss_fn(predictions, labels)` returning a scalar loss.
        val_loader: iterable of `(images, labels)` batches with a length.
        device: device the batches are moved to.
        scheduler: optional LR scheduler stepped once after evaluation.
        schd_loss_update: if true the scheduler is stepped with the mean
            validation loss, otherwise without arguments.
    """
    model.eval()

    loss_sum = 0
    sample_num = 0
    image_preds_all = []
    image_targets_all = []

    pbar = tqdm(enumerate(val_loader), total=len(val_loader))
    # Do not depend on the caller to switch autograd off: evaluation never needs a graph.
    with torch.no_grad():
        for step, (imgs, image_labels) in pbar:
            imgs = imgs.to(device).float()
            image_labels = image_labels.to(device).long()

            image_preds = model(imgs)
            image_preds_all += [torch.argmax(image_preds, 1).detach().cpu().numpy()]
            image_targets_all += [image_labels.detach().cpu().numpy()]

            loss = loss_fn(image_preds, image_labels)

            # Weight each batch loss by its size so the running value is a per-sample mean.
            loss_sum += loss.item()*image_labels.shape[0]
            sample_num += image_labels.shape[0]

            if ((step + 1) % CFG['verbose_step'] == 0) or ((step + 1) == len(val_loader)):
                description = f'epoch {epoch} loss: {loss_sum/sample_num:.4f}'
                pbar.set_description(description)

    image_preds_all = np.concatenate(image_preds_all)
    image_targets_all = np.concatenate(image_targets_all)
    print('validation multi-class accuracy = {:.4f}'.format((image_preds_all==image_targets_all).mean()))

    if scheduler is not None:
        if schd_loss_update:
            scheduler.step(loss_sum/sample_num)
        else:
            scheduler.step()

In [None]:
# Training driver: builds stratified folds over df_train and trains/validates the
# classifier, saving a checkpoint after every epoch. Only the first fold is run.
if __name__ == '__main__':

    seed_everything(CFG['seed'])

    # Stratify on the label column; no shuffle argument is passed to the splitter.
    folds = StratifiedKFold(n_splits = CFG['fold_num']).split(np.arange(df_train.shape[0]),df_train.label.values)

    for fold, (trn_idx, val_idx) in enumerate(folds):

        # Stop after fold 0, so a single fold is trained.
        if fold>0:
            break

        print('Training with fold {} started'.format(fold))
        print(len(trn_idx),len(val_idx))
        train_loader, val_loader = prepare_dataloader(df_train, trn_idx, val_idx, data_root='../input/cassava-leaf-disease-classification/train_images/')

        device = torch.device(CFG['device'])
        # One output unit per distinct label value found in the training table.
        model = CassavaImageClassifier(CFG['model_arch'],df_train.label.nunique(),pretrained=True).to(device)
        # `scaler` must keep this name: train_one_epoch reads it as a module-level global.
        scaler = GradScaler()
        optimizer = torch.optim.Adam(model.parameters(),lr=CFG['lr'], weight_decay=CFG['weight_decay'])
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,T_0=CFG['T_0'],T_mult=1, eta_min=CFG['min_lr'], last_epoch=-1)

        # Two separate but identically configured loss objects, one per phase.
        loss_tr = nn.CrossEntropyLoss().to(device)
        loss_fn = nn.CrossEntropyLoss().to(device)

        for epoch in range(CFG['epochs']):
            # With schd_batch_update=False the scheduler steps once at the end of the training epoch.
            train_one_epoch(epoch, model, loss_tr, optimizer, train_loader, device, scheduler=scheduler, schd_batch_update=False)
            with torch.no_grad():
                # scheduler=None: validation does not touch the learning-rate schedule.
                valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False)

            # Checkpoint named '<arch>_fold_<fold>_<epoch>' (no file extension), written relative to the working directory.
            torch.save(model.state_dict(), '{}_fold_{}_{}'.format(CFG['model_arch'], fold, epoch))

        # Drop references to the per-fold objects and release cached GPU memory.
        del model, optimizer, train_loader, val_loader, scaler, scheduler
        torch.cuda.empty_cache()