In [1]:
import sys

# Offline copies of third-party code shipped as Kaggle datasets; they are made
# importable by appending them to the module search path.
# (alternative backbone source: '../input/efficientnet-pytorch-07/efficientnet_pytorch-0.7.0')
package_paths = [
    '../input/pytorch-image-models/pytorch-image-models-master',
    '../input/image-fmix/FMix-master',
]
sys.path.extend(package_paths)
    
# from fmix import sample_mask, make_low_freq_image, binarise_mask

In [33]:
from glob import glob
from sklearn.model_selection import GroupKFold, StratifiedKFold
import cv2
from skimage import io
import torch
from torch import nn
import os
from datetime import datetime
import time
import random
import cv2
import torchvision
from torchvision import transforms
import pandas as pd
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.cuda.amp import autocast, GradScaler
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F

import timm

import sklearn
import warnings
import joblib
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics
import warnings
import cv2
import pydicom
#from efficientnet_pytorch import EfficientNet
from scipy.ndimage.interpolation import zoom
from catalyst.data.sampler import BalanceClassSampler

import fastai
from fastai.vision import *
from fastai.layers import AdaptiveConcatPool2d, Flatten, Mish

In [34]:
# Competition metadata: the training label table and the submission template.
train = pd.read_csv(
    '../input/cassava-leaf-disease-classification/train.csv'
)
submission = pd.read_csv(
    '../input/cassava-leaf-disease-classification/sample_submission.csv'
)

In [38]:
train.head(n=5)  # peek at the first rows of the label table

Unnamed: 0,image_id,label
0,1000015157.jpg,0
1,1000201771.jpg,3
2,100042118.jpg,1
3,1000723321.jpg,1
4,1000812911.jpg,3


In [52]:
# Out-of-fold predictions from a previous run.
# NOTE(review): pd.read_pickle can execute arbitrary code — only load trusted files.
oof_df = pd.read_pickle('oof_df.pickle')
# An image is flagged as noisy when it has per-image log loss > 1 AND its
# 'euclidean' value lies above the 95th percentile of that column.
_euclidean_cutoff = np.quantile(oof_df['euclidean'], .95)
_high_loss_ids = set(oof_df.loc[oof_df['log_loss'] > 1, 'image_id'])
_far_ids = set(oof_df.loc[oof_df['euclidean'] > _euclidean_cutoff, 'image_id'])
noisy_images = list(_high_loss_ids & _far_ids)

In [53]:
# Remove the flagged images and renumber rows 0..n-1 (the fold indices used by
# the training loop are positional). `drop=True` discards the old index instead
# of inserting it as an extra column — the previous positional `False` was
# bound to `level`, not `drop`, and re-running the cell kept adding columns.
train = train[~train['image_id'].isin(noisy_images)].reset_index(drop=True)

In [54]:
#CassvaImgClassifier
# Initial configuration. A later cell rebinds CFG before training, but
# CassavaDataset's default FMix 'shape' reads CFG['img_size'] from this dict
# when that class is defined.
CFG = dict(
    fold_num=5,
    seed=719,
    model_arch='vit_base_patch16_384',  # alternative: 'tf_efficientnet_b4_ns'
    img_size=512,
    epochs=10,
    train_bs=4,  # was 16
    valid_bs=4,  # was 32
    T_0=10,
    lr=1e-4,
    min_lr=1e-6,
    weight_decay=1e-6,
    num_workers=0,  # was 4
    accum_iter=2,  # gradient accumulation steps -> effectively larger batch size
    verbose_step=1,
    device='cuda:0',
    freeze_bn_epochs=5,
)

# Helper Functions

In [55]:
def seed_everything(seed):
    """Seed every RNG used by the pipeline and configure cuDNN for reproducibility.

    Seeds Python's ``random``, NumPy's legacy global RNG and torch (CPU and the
    current CUDA device), records the seed in ``PYTHONHASHSEED``, and asks cuDNN
    for deterministic kernels.

    Args:
        seed: integer seed.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker
    # subprocesses); the current process's str hashing is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True makes cuDNN pick algorithms by timing them at runtime,
    # which can vary between runs and defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False
    
def get_img(path):
    """Read an image file and return it in RGB channel order.

    OpenCV decodes to BGR; the channel axis is reversed here and copied into a
    C-contiguous array, so the result can go straight to ``torch.from_numpy`` or
    the default DataLoader collate (both reject negative-stride views) even
    when no transforms are applied.

    Args:
        path: path of the image file.

    Returns:
        ``np.ndarray`` of shape ``(H, W, 3)`` in RGB order.

    Raises:
        FileNotFoundError: if OpenCV cannot read or decode ``path``
            (``cv2.imread`` signals this by returning ``None``).
    """
    im_bgr = cv2.imread(path)
    if im_bgr is None:
        raise FileNotFoundError(f'cv2.imread could not read image: {path}')
    return im_bgr[:, :, ::-1].copy()


# Dataset

In [56]:
def rand_bbox(size, lam):
    """Sample a random CutMix box covering roughly ``1 - lam`` of the image.

    The box side lengths are ``size * sqrt(1 - lam)``; its centre is uniform
    over the image and the corners are clipped to the image bounds, so the
    actual area can be smaller near the edges.

    Args:
        size: ``(W, H)`` pair; ``W`` bounds the first returned coordinate pair
            and ``H`` the second.
        lam: mixing ratio in ``[0, 1]``; ``lam == 1`` yields an empty box.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)`` with ``0 <= bbx1 <= bbx2 <= W`` and
        ``0 <= bby1 <= bby2 <= H``.
    """
    W = size[0]
    H = size[1]
    cut_rat = np.sqrt(1. - lam)
    # Built-in int: the np.int alias was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    # uniform centre
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2


class CassavaDataset(Dataset):
    """Cassava leaf images listed in a DataFrame, with optional FMix / CutMix.

    Each item is read from ``"{data_root}/{image_id}"`` via ``get_img`` and, if
    given, passed through ``transforms(image=img)['image']`` (albumentations-style
    call). FMix and CutMix each fire independently with probability 0.5; each
    mixes in a second, randomly drawn image from the same DataFrame and blends
    the target by the fraction of pixels kept from the original image.

    Args:
        df: DataFrame with an ``image_id`` column (and ``label`` when
            ``output_label`` is True). Stored as a copy with a fresh 0..n-1 index.
        data_root: directory joined with ``image_id`` using ``/``.
        transforms: optional callable used as ``transforms(image=...)['image']``.
        output_label: if True, ``__getitem__`` returns ``(img, target)``;
            otherwise only ``img``.
        one_hot_label: if True, labels become one-hot float rows of width
            ``df['label'].max() + 1``.
        do_fmix: enable FMix. NOTE(review): this path calls
            ``make_low_freq_image`` and ``binarise_mask``, but the ``fmix``
            import near the top of the notebook is commented out, so it raises
            NameError as written.
        fmix_params: FMix settings. The default dict is evaluated once, when
            this class is defined, so its ``shape`` uses ``CFG['img_size']`` at
            that moment. Mutable defaults are shared between instances; they
            are only read here, never mutated.
        do_cutmix: enable CutMix.
        cutmix_params: CutMix settings (``alpha`` of the Beta distribution).

    NOTE(review): both mixing branches read ``target``, which only exists when
    ``output_label`` is True, and blend labels linearly. That is meaningful for
    one-hot targets; with integer labels (``one_hot_label=False``) it produces a
    fractional class index, which the training loop's ``.long()`` cast truncates.
    """
    def __init__(self, df, data_root, 
                 transforms=None, 
                 output_label=True, 
                 one_hot_label=False,
                 do_fmix=False, 
                 fmix_params={
                     'alpha': 1., 
                     'decay_power': 3., 
                     'shape': (CFG['img_size'], CFG['img_size']),
                     'max_soft': True, 
                     'reformulate': False
                 },
                 do_cutmix=False,
                 cutmix_params={
                     'alpha': 1,
                 }
                ):
        
        super().__init__()
        # Fresh 0..n-1 index so .loc[index] and .iloc[random index] agree.
        self.df = df.reset_index(drop=True).copy()
        self.transforms = transforms
        self.data_root = data_root
        self.do_fmix = do_fmix
        self.fmix_params = fmix_params
        self.do_cutmix = do_cutmix
        self.cutmix_params = cutmix_params
        
        self.output_label = output_label
        self.one_hot_label = one_hot_label
        
        if output_label == True:
            self.labels = self.df['label'].values
            
            if one_hot_label is True:
                # Row i of the identity matrix is the one-hot vector for class i.
                self.labels = np.eye(self.df['label'].max()+1)[self.labels]
            
    def __len__(self):
        return self.df.shape[0]
    
    def __getitem__(self, index: int):
        """Return ``(img, target)`` (or only ``img``) for row ``index``, possibly mixed."""
        
        # get labels
        if self.output_label:
            target = self.labels[index]
          
        img  = get_img("{}/{}".format(self.data_root, self.df.loc[index]['image_id']))

        if self.transforms:
            img = self.transforms(image=img)['image']
        
        # FMix with probability 0.5.
        if self.do_fmix and np.random.uniform(0., 1., size=1)[0] > 0.5:
            with torch.no_grad():
                #lam, mask = sample_mask(**self.fmix_params)
                
                # Mixing ratio drawn from Beta(alpha, alpha), clipped to [0.6, 0.7].
                lam = np.clip(np.random.beta(self.fmix_params['alpha'], self.fmix_params['alpha']),0.6,0.7)
                
                # Make mask, get mean / std
                mask = make_low_freq_image(self.fmix_params['decay_power'], self.fmix_params['shape'])
                mask = binarise_mask(mask, lam, self.fmix_params['shape'], self.fmix_params['max_soft'])
    
                # Partner image drawn uniformly from the whole dataset (may be the same row).
                fmix_ix = np.random.choice(self.df.index, size=1)[0]
                fmix_img  = get_img("{}/{}".format(self.data_root, self.df.iloc[fmix_ix]['image_id']))

                if self.transforms:
                    fmix_img = self.transforms(image=fmix_img)['image']

                mask_torch = torch.from_numpy(mask)
                
                # mix image: mask selects the original, (1 - mask) the partner
                img = mask_torch*img+(1.-mask_torch)*fmix_img

                #print(mask.shape)

                #assert self.output_label==True and self.one_hot_label==True

                # mix target by the mask's mean value over an img_size x img_size image
                rate = mask.sum()/CFG['img_size']/CFG['img_size']
                target = rate*target + (1.-rate)*self.labels[fmix_ix]
                #print(target, mask, img)
                #assert False
        
        # CutMix with probability 0.5.
        if self.do_cutmix and np.random.uniform(0., 1., size=1)[0] > 0.5:
            #print(img.sum(), img.shape)
            with torch.no_grad():
                cmix_ix = np.random.choice(self.df.index, size=1)[0]
                cmix_img  = get_img("{}/{}".format(self.data_root, self.df.iloc[cmix_ix]['image_id']))
                if self.transforms:
                    cmix_img = self.transforms(image=cmix_img)['image']
                    
                # Mixing ratio drawn from Beta(alpha, alpha), clipped to [0.3, 0.4].
                lam = np.clip(np.random.beta(self.cutmix_params['alpha'], self.cutmix_params['alpha']),0.3,0.4)
                bbx1, bby1, bbx2, bby2 = rand_bbox((CFG['img_size'], CFG['img_size']), lam)

                # Slicing on axes 1 and 2 assumes a channel-first image, i.e.
                # transforms ending in ToTensorV2 — TODO confirm for other transforms.
                img[:, bbx1:bbx2, bby1:bby2] = cmix_img[:, bbx1:bbx2, bby1:bby2]

                # rate = fraction of pixels still coming from the original image.
                rate = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (CFG['img_size'] * CFG['img_size']))
                target = rate*target + (1.-rate)*self.labels[cmix_ix]
                
            #print('-', img.sum())
            #print(target)
            #assert False
                            
        # do label smoothing
        #print(type(img), type(target))
        if self.output_label == True:
            return img, target
        else:
            return img

# Define Train/Validation Image Augmentations

In [57]:
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, CenterCrop, Resize
)

from albumentations.pytorch import ToTensorV2

def get_train_transforms():
    """Albumentations pipeline applied to training images.

    Random resized crop to ``CFG['img_size']`` square, random transpose /
    flips / shift-scale-rotate, ImageNet normalization, random dropout
    patches, then conversion to a channel-first torch tensor. The dropout
    steps run after Normalize, so their fill values land in normalized space.
    """
    side = CFG['img_size']
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    pipeline = [
        RandomResizedCrop(side, side),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.5),
        # Colour jitter, currently disabled:
        # HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        # RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        Normalize(mean=imagenet_mean, std=imagenet_std, max_pixel_value=255.0, p=1.0),
        CoarseDropout(p=0.5),
        Cutout(p=0.5),
        ToTensorV2(p=1.0),
    ]
    return Compose(pipeline, p=1.)
  
        
def get_valid_transforms():
    """Deterministic albumentations pipeline for validation images.

    Center crop to ``CFG['img_size']`` square, ImageNet normalization, then
    conversion to a channel-first torch tensor.
    """
    side = CFG['img_size']
    pipeline = [
        CenterCrop(side, side, p=1.),
        # NOTE(review): the crop already yields side x side, so this Resize
        # is effectively an identity step.
        Resize(side, side),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(pipeline, p=1.)

# Model

In [58]:
# class Model(nn.Module):
#     def __init__(self, arch='resnext50_32x4d_ssl', n=6, pre=True):
#         super().__init__()
#         m = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', arch)
#         self.enc = nn.Sequential(*list(m.children())[:-2])       
#         nc = list(m.children())[-1].in_features
#         self.head = nn.Sequential(AdaptiveConcatPool2d(),Flatten(),nn.Linear(2*nc,512),
#                             Mish(),nn.BatchNorm1d(512), nn.Dropout(0.5),nn.Linear(512,n))
        
#     def forward(self, *x):
#         shape = x[0].shape
#         n = len(x)
#         x = torch.stack(x,1).view(-1,shape[1],shape[2],shape[3])
#         #x: bs*N x 3 x 128 x 128
#         x = self.enc(x)
#         #x: bs*N x C x 4 x 4
#         shape = x.shape
#         #concatenate the output for tiles into a single map
#         x = x.view(-1,n,shape[1],shape[2],shape[3]).permute(0,2,1,3,4).contiguous()\
#           .view(-1,shape[1],shape[2]*n,shape[3])
#         #x: bs x C x N*4 x 4
#         x = self.head(x)
#         #x: bs x n
#         return x
    

# class CassvaImgClassifier(nn.Module):
#     def __init__(self, model_arch, n_class, pretrained=False):
#         super().__init__()
#         self.model = timm.create_model(model_arch, pretrained=pretrained)
#         n_features = self.model.classifier.in_features
#         self.enc = nn.Sequential(*list(self.model.children())[:-2])
# #         self.model.classifier = nn.Linear(n_features, n_class)
#         self.head = nn.Sequential(AdaptiveConcatPool2d(),
#                                               Flatten(),
#                                               nn.Linear(2*n_features,512),
#                                               Mish(),
#                                               nn.BatchNorm1d(512), 
#                                               nn.Dropout(0.5),
#                                               nn.Linear(512,n_class))
            
#         '''
#         self.model.classifier = nn.Sequential(
#             nn.Dropout(0.3),
#             #nn.Linear(n_features, hidden_size,bias=True), nn.ELU(),
#             nn.Linear(n_features, n_class, bias=True)
#         )
#         '''
#     def forward(self, x):
    
#         shape = x.shape
#         n = 1
#         x = x.view(-1,shape[1],shape[2],shape[3])
#         #x: bs*N x 3 x 128 x 128
#         x = self.enc(x)
#         #x: bs*N x C x 4 x 4
#         shape = x.shape
#         #concatenate the output for tiles into a single map
#         x = x.view(-1,n,shape[1],shape[2],shape[3]).permute(0,2,1,3,4).contiguous().view(-1,shape[1],shape[2]*n,shape[3])
#         #x: bs x C x N*4 x 4
#         x = self.head(x)
#         #x: bs x n
#         return x
class CassvaImgClassifier(nn.Module):
    """timm backbone whose ``classifier`` head is replaced by a new linear layer.

    Intended for timm architectures that expose their head as ``.classifier``
    (the main loop uses it with ``'tf_efficientnet_b4_ns'``).

    Args:
        model_arch: timm model name passed to ``timm.create_model``.
        n_class: number of output logits.
        pretrained: whether timm loads pretrained backbone weights.
    """

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_arch, pretrained=pretrained)
        in_dim = backbone.classifier.in_features
        # Alternative head tried earlier: Dropout(0.3) followed by the Linear layer.
        backbone.classifier = nn.Linear(in_dim, n_class)
        self.model = backbone

    def forward(self, x):
        return self.model(x)
    
class CustomViT(nn.Module):
    """timm model whose ``head`` layer is replaced by a new linear layer.

    Intended for timm Vision Transformers, which expose their head as ``.head``.

    Args:
        model_arch: timm model name passed to ``timm.create_model``.
        num_classes: number of output logits.
        pretrained: whether timm loads pretrained backbone weights.
    """

    def __init__(self, model_arch, num_classes, pretrained=False):
        super().__init__()
        vit = timm.create_model(model_arch, pretrained=pretrained)
        # Alternative head tried earlier: Dropout(0.3) followed by the Linear layer.
        vit.head = nn.Linear(vit.head.in_features, num_classes)
        self.model = vit

    def forward(self, x):
        return self.model(x)
    
# ====================================================
# ResNeXt model
# ====================================================
class CustomResNext(nn.Module):
    """timm model whose ``fc`` layer is replaced by a new linear layer.

    Intended for ResNet/ResNeXt-style timm models (e.g. ``'resnext50_32x4d'``),
    which expose their head as ``.fc``.

    Args:
        model_arch: timm model name passed to ``timm.create_model``.
        num_classes: number of output logits.
        pretrained: whether timm loads pretrained backbone weights.
    """

    def __init__(self, model_arch, num_classes, pretrained=False):
        super().__init__()
        net = timm.create_model(model_arch, pretrained=pretrained)
        net.fc = nn.Linear(net.fc.in_features, num_classes)
        self.model = net

    def forward(self, x):
        return self.model(x)

In [59]:
# # CustomViT
# CFG = {
#     'fold_num': 5,
#     'seed': 719,
#     'model_arch': 'vit_base_patch16_384',#'tf_efficientnet_b4_ns',
#     'img_size': 384,
#     'epochs': 10,
#     'train_bs': 4,#16,
#     'valid_bs': 4,#32,
#     'T_0': 10,
#     'lr': 1e-4,
#     'min_lr': 1e-6,
#     'weight_decay':1e-6,
#     'num_workers': 0, #4
#     'accum_iter': 2, # suppoprt to do batch accumulation for backprop with effectively larger batch size
#     'verbose_step': 1,
#     'device': 'cuda:0',
#     'freeze_bn_epochs':5,
# }

# #CassvaImgClassifier
# CFG = {
#     'fold_num': 5,
#     'seed': 719,
#     'model_arch': 'vit_base_patch16_384',#'tf_efficientnet_b4_ns',
#     'img_size': 512,
#     'epochs': 10,
#     'train_bs': 4,#16,
#     'valid_bs': 4,#32,
#     'T_0': 10,
#     'lr': 1e-4,
#     'min_lr': 1e-6,
#     'weight_decay':1e-6,
#     'num_workers': 0, #4
#     'accum_iter': 2, # suppoprt to do batch accumulation for backprop with effectively larger batch size
#     'verbose_step': 1,
#     'device': 'cuda:0',
#     'freeze_bn_epochs':5,
# }

# #CustomResNext
# CFG = {
#     'fold_num': 5,
#     'seed': 719,
#     'model_arch': 'resnext50_32x4d',#'tf_efficientnet_b4_ns',
#     'img_size': 512,
#     'epochs': 10,
#     'train_bs': 4,#16,
#     'valid_bs': 4,#32,
#     'T_0': 10,
#     'lr': 1e-4,
#     'min_lr': 1e-6,
#     'weight_decay':1e-6,
#     'num_workers': 0, #4
#     'accum_iter': 2, # suppoprt to do batch accumulation for backprop with effectively larger batch size
#     'verbose_step': 1,
#     'device': 'cuda:0',
#     'freeze_bn_epochs':5,
# }

# Training APIs

In [60]:
def prepare_dataloader(df, trn_idx, val_idx, data_root='../input/cassava-leaf-disease-classification/train_images/'):
    """Build the training and validation DataLoaders for one CV fold.

    Args:
        df: full training DataFrame (``image_id``, ``label``).
        trn_idx: index labels (selected with ``df.loc``) of the training rows.
        val_idx: index labels (selected with ``df.loc``) of the validation rows.
        data_root: directory containing the image files.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles, uses
        ``CFG['train_bs']`` and the train-time augmentations; the validation
        loader keeps order and uses ``CFG['valid_bs']``.
    """
    train_part = df.loc[trn_idx, :].reset_index(drop=True)
    valid_part = df.loc[val_idx, :].reset_index(drop=True)

    train_ds = CassavaDataset(
        train_part,
        data_root,
        transforms=get_train_transforms(),
        output_label=True,
        one_hot_label=False,
        do_fmix=False,
        do_cutmix=False,
    )
    valid_ds = CassavaDataset(valid_part, data_root, transforms=get_valid_transforms(), output_label=True)

    loader_cls = torch.utils.data.DataLoader
    train_loader = loader_cls(
        train_ds,
        batch_size=CFG['train_bs'],
        shuffle=True,
        drop_last=False,
        pin_memory=False,
        num_workers=CFG['num_workers'],
        # sampler=BalanceClassSampler(labels=train_part['label'].values, mode="downsampling")
    )
    val_loader = loader_cls(
        valid_ds,
        batch_size=CFG['valid_bs'],
        shuffle=False,
        pin_memory=False,
        num_workers=CFG['num_workers'],
    )
    return train_loader, val_loader

def train_one_epoch(epoch, model, loss_fn, optimizer, train_loader, device, scheduler=None, schd_batch_update=False, scaler=None):
    """Train ``model`` for one pass over ``train_loader`` with AMP and gradient accumulation.

    The optimizer steps every ``CFG['accum_iter']`` mini-batches, and once more
    on the final (possibly partial) group.

    Args:
        epoch: epoch index, used only in the progress-bar description.
        model: module to train; switched to ``train()`` mode.
        loss_fn: callable ``loss_fn(logits, labels)`` returning a scalar tensor.
        optimizer: optimizer, stepped through ``scaler``.
        train_loader: iterable of ``(imgs, labels)`` batches; labels are cast
            to ``long``.
        device: device each batch is moved to.
        scheduler: optional LR scheduler.
        schd_batch_update: if True, step ``scheduler`` after every optimizer
            step; otherwise once at the end of the epoch.
        scaler: ``GradScaler`` for mixed precision. When omitted, the
            module-level ``scaler`` is used (earlier versions always read that
            global); if none exists, a fresh ``GradScaler`` is created.
    """
    if scaler is None:
        scaler = globals().get('scaler')
        if scaler is None:
            scaler = GradScaler()

    model.train()
    running_loss = None
    n_steps = len(train_loader)

    pbar = tqdm(enumerate(train_loader), total=n_steps)
    for step, (imgs, image_labels) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()

        # Only the forward pass and loss run under autocast; PyTorch recommends
        # calling backward() outside the autocast region.
        with autocast():
            image_preds = model(imgs)
            loss = loss_fn(image_preds, image_labels)

        # NOTE(review): loss is not divided by accum_iter, so gradients are
        # summed (not averaged) over each accumulation group.
        scaler.scale(loss).backward()

        # Exponential moving average of the loss, for display only.
        if running_loss is None:
            running_loss = loss.item()
        else:
            running_loss = running_loss * .99 + loss.item() * .01

        is_last_step = (step + 1) == n_steps
        if (step + 1) % CFG['accum_iter'] == 0 or is_last_step:
            # may unscale_ here if desired (e.g., to allow clipping unscaled gradients)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if scheduler is not None and schd_batch_update:
                scheduler.step()

        if (step + 1) % CFG['verbose_step'] == 0 or is_last_step:
            pbar.set_description(f'epoch {epoch} loss: {running_loss:.4f}')

    if scheduler is not None and not schd_batch_update:
        scheduler.step()
        
def valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False):
    """Evaluate ``model`` on ``val_loader`` and print argmax accuracy.

    The function does not disable autograd itself; the main loop wraps the
    call in ``torch.no_grad()``.

    Args:
        epoch: epoch index, used only in the progress-bar description.
        model: module to evaluate; switched to ``eval()`` mode.
        loss_fn: callable ``loss_fn(logits, labels)`` returning a batch-mean loss.
        val_loader: iterable of ``(imgs, labels)`` batches.
        device: device each batch is moved to.
        scheduler: optional scheduler stepped once at the end.
        schd_loss_update: if True, call ``scheduler.step(mean_loss)`` (e.g. for
            ReduceLROnPlateau); otherwise ``scheduler.step()``.
    """
    model.eval()

    total_loss = 0
    n_seen = 0
    pred_batches = []
    target_batches = []
    n_steps = len(val_loader)

    progress = tqdm(enumerate(val_loader), total=n_steps)
    for step, (imgs, image_labels) in progress:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()

        logits = model(imgs)
        pred_batches.append(torch.argmax(logits, 1).detach().cpu().numpy())
        target_batches.append(image_labels.detach().cpu().numpy())

        # loss_fn returns a batch mean; weight it by batch size for an exact
        # dataset-level mean.
        batch_size = image_labels.shape[0]
        total_loss += loss_fn(logits, image_labels).item() * batch_size
        n_seen += batch_size

        if (step + 1) % CFG['verbose_step'] == 0 or step + 1 == n_steps:
            progress.set_description(f'epoch {epoch} loss: {total_loss/n_seen:.4f}')

    preds = np.concatenate(pred_batches)
    targets = np.concatenate(target_batches)
    print('validation multi-class accuracy = {:.4f}'.format((preds == targets).mean()))

    if scheduler is None:
        return
    if schd_loss_update:
        scheduler.step(total_loss / n_seen)
    else:
        scheduler.step()

In [61]:
# reference: https://www.kaggle.com/c/siim-isic-melanoma-classification/discussion/173733
class MyCrossEntropyLoss(_WeightedLoss):
    """Cross-entropy against soft or one-hot target distributions.

    Per sample: ``-sum_c weight[c] * targets[c] * log_softmax(inputs)[c]``.

    Args:
        weight: optional per-class weight tensor, broadcast over the batch.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            per-sample losses unreduced.
    """

    def __init__(self, weight=None, reduction='mean'):
        # _WeightedLoss stores `weight` (as a buffer) and `reduction`.
        super().__init__(weight=weight, reduction=reduction)

    def forward(self, inputs, targets):
        log_probs = F.log_softmax(inputs, -1)
        if self.weight is not None:
            log_probs = log_probs * self.weight.unsqueeze(0)

        per_sample = -(targets * log_probs).sum(-1)

        if self.reduction == 'sum':
            return per_sample.sum()
        if self.reduction == 'mean':
            return per_sample.mean()
        return per_sample
    
class LabelSmoothingCrossEntropy(nn.Module):
    """Negative log-likelihood with uniform label smoothing.

    ``loss = (1 - s) * NLL(target) + s * mean_c(-log p_c)``, averaged over the batch.
    """

    def __init__(self, smoothing=0.1):
        """
        Args:
            smoothing: weight ``s`` of the uniform component; must be < 1.0.
        """
        super().__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x, target):
        log_probs = F.log_softmax(x, dim=-1)
        # Log-probability of the true class for each sample.
        picked = log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        uniform_term = log_probs.mean(dim=-1)
        loss = -(self.confidence * picked + self.smoothing * uniform_term)
        return loss.mean()
    
    
def bi_tempered_logistic_loss(activations,
        labels,
        t1,
        t2,
        label_smoothing=0.0,
        num_iters=5,
        reduction = 'mean'):

    """Bi-Tempered Logistic Loss.
    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
      labels: A tensor with shape and dtype as activations (onehot), 
        or a long tensor of one dimension less than activations (pytorch standard)
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing parameter between [0, 1). Default 0.0.
      num_iters: Number of iterations to run the method. Default 5.
      reduction: ``'none'`` | ``'mean'`` | ``'sum'``. Default ``'mean'``.
        ``'none'``: No reduction is applied, return shape is shape of
        activations without the last dimension.
        ``'mean'``: Loss is averaged over minibatch. Returns a 0-d tensor.
        ``'sum'``: Loss is summed over minibatch. Returns a 0-d tensor.
    Returns:
      A loss tensor.
    Raises:
      ValueError: if ``reduction`` is not ``'none'``, ``'mean'`` or ``'sum'``
        (such values previously made the function silently return ``None``).

    NOTE(review): this function calls ``tempered_softmax`` and ``log_t``, which
    are not defined or imported anywhere in the visible notebook; they must be
    provided before the loss can be computed.
    """
    # Validate up front so a bad value fails before any tensor work.
    if reduction not in ('none', 'mean', 'sum'):
        raise ValueError(
            "reduction must be 'none', 'mean' or 'sum', got {!r}".format(reduction))

    if len(labels.shape) < len(activations.shape):  # integer class labels, not one-hot
        labels_onehot = torch.zeros_like(activations)
        # Scatter along the last (class) dimension, as documented above; this
        # matches dim=1 for the usual (batch, num_classes) input.
        labels_onehot.scatter_(-1, labels[..., None], 1)
    else:
        labels_onehot = labels

    if label_smoothing > 0:
        num_classes = labels_onehot.shape[-1]
        # True class gets 1 - label_smoothing; every other class gets
        # label_smoothing / (num_classes - 1).
        labels_onehot = ( 1 - label_smoothing * num_classes / (num_classes - 1) ) \
                * labels_onehot + \
                label_smoothing / (num_classes - 1)

    probabilities = tempered_softmax(activations, t2, num_iters)

    loss_values = labels_onehot * log_t(labels_onehot + 1e-10, t1) \
            - labels_onehot * log_t(probabilities, t1) \
            - labels_onehot.pow(2.0 - t1) / (2.0 - t1) \
            + probabilities.pow(2.0 - t1) / (2.0 - t1)
    loss_values = loss_values.sum(dim = -1) #sum over classes

    if reduction == 'none':
        return loss_values
    if reduction == 'sum':
        return loss_values.sum()
    return loss_values.mean()

In [62]:
# Configuration used by the training loop below (rebinds the earlier CFG).
CFG = dict(
    fold_num=5,
    seed=719,
    model_arch='tf_efficientnet_b4_ns',
    img_size=512,
    epochs=10,
    train_bs=4,  # was 16
    valid_bs=4,  # was 32
    T_0=10,
    lr=1e-4,
    min_lr=1e-6,
    weight_decay=1e-6,
    num_workers=0,  # was 4
    accum_iter=2,  # gradient accumulation steps -> effectively larger batch size
    verbose_step=1,
    device='cuda:0',
    freeze_bn_epochs=5,
)

# Main Loop

In [63]:
# Freeze normalization-layer statistics
def freeze_batchnorm_stats(net, norm_types=(nn.BatchNorm2d, nn.LayerNorm)):
    """Switch every normalization layer of ``net`` to eval mode.

    For BatchNorm layers, eval mode means normalizing with the stored running
    statistics and no longer updating them; their affine weights still train.
    LayerNorm keeps no running statistics, so eval mode does not change it.

    Call this *after* ``net.train()``: ``train()`` resets the mode of every
    submodule, including the ones frozen here (``train_one_epoch`` starts with
    ``model.train()``).

    The previous version wrapped the loop in ``except ValuError`` (a typo), so
    any exception raised inside would have surfaced as a NameError; nothing in
    the loop is expected to raise, so errors now propagate unchanged.

    Args:
        net: any ``nn.Module``; all entries of ``net.modules()`` are checked.
        norm_types: module class or tuple of classes to freeze. Defaults to
            ``nn.BatchNorm2d`` and ``nn.LayerNorm``, as before.
    """
    for module in net.modules():
        if isinstance(module, norm_types):
            module.eval()

if __name__ == '__main__':
    # for training only, need nightly build pytorch
    # Cross-validated training: for each StratifiedKFold split of `train`, build
    # the loaders, train a fresh CassvaImgClassifier for CFG['epochs'] epochs,
    # validate after every epoch and save the weights to disk.

    seed_everything(CFG['seed'])
    
    # split() yields positional (trn_idx, val_idx) arrays. prepare_dataloader
    # selects rows with .loc, which matches those positions only because
    # `train` was given a 0..n-1 index by reset_index earlier.
    folds = StratifiedKFold(n_splits=CFG['fold_num'], shuffle=True, random_state=CFG['seed']).split(np.arange(train.shape[0]), train.label.values)
    
    for fold, (trn_idx, val_idx) in enumerate(folds):
        # Every fold is trained in turn; there is no early `break` after fold 0.

        print('Training with {} started'.format(fold))

        print(len(trn_idx), len(val_idx))
        train_loader, val_loader = prepare_dataloader(train, trn_idx, val_idx, data_root='../input/cassava-leaf-disease-classification/train_images/')

        device = torch.device(CFG['device'])
        
        model = CassvaImgClassifier(CFG['model_arch'], train.label.nunique(), pretrained=True).to(device)
        # NOTE: train_one_epoch reads this object through the global name
        # `scaler` (it is not passed as an argument), so keep this name.
        scaler = GradScaler()   
        optimizer = torch.optim.Adam(model.parameters(), lr=CFG['lr'], weight_decay=CFG['weight_decay'])
        #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.1, step_size=CFG['epochs']-1)
        # Stepped once per epoch inside train_one_epoch (schd_batch_update=False).
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CFG['T_0'], T_mult=1, eta_min=CFG['min_lr'], last_epoch=-1)
        #scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer=optimizer, pct_start=0.1, div_factor=25, 
        #                                                max_lr=CFG['lr'], epochs=CFG['epochs'], steps_per_epoch=len(train_loader))
        
        # Label-smoothed CE for training; plain CE for the reported validation loss.
        loss_tr = LabelSmoothingCrossEntropy().to(device) #MyCrossEntropyLoss().to(device)
        loss_fn = nn.CrossEntropyLoss().to(device)
        
        for epoch in range(CFG['epochs']):
#             if epoch < CFG['freeze_bn_epochs']:
#                 freeze_batchnorm_stats(model)  
            # NOTE(review): if the freeze above is re-enabled, train_one_epoch
            # begins with model.train(), which unfreezes those layers again.
            
            train_one_epoch(epoch, model, loss_tr, optimizer, train_loader, device, scheduler=scheduler, schd_batch_update=False)

            # scheduler=None so the LR schedule is not stepped twice per epoch.
            with torch.no_grad():
                valid_one_epoch(epoch, model, loss_fn, val_loader, device, scheduler=None, schd_loss_update=False)

            # Same file name every epoch: the file ends up holding the last
            # epoch's weights, not necessarily the best-validating ones.
            torch.save(model.state_dict(),'{}_fold_{}'.format(CFG['model_arch'], fold))
            
#         torch.save(model.cnn_model.state_dict(),'{}/cnn_model_fold_{}_{}'.format(CFG['model_path'], fold, CFG['tag']))
        # Release per-fold objects and cached GPU memory before the next fold.
        del model, optimizer, train_loader, val_loader, scaler, scheduler
        torch.cuda.empty_cache()

Training with 0 started
16261 4066


epoch 0 loss: 0.6884: 100%|██████████| 4066/4066 [11:52<00:00,  5.71it/s]
epoch 0 loss: 0.3718: 100%|██████████| 1017/1017 [01:15<00:00, 13.56it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.8992


epoch 1 loss: 0.6499: 100%|██████████| 4066/4066 [11:51<00:00,  5.72it/s]
epoch 1 loss: 0.3673: 100%|██████████| 1017/1017 [01:12<00:00, 14.12it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.8965


epoch 2 loss: 0.6567: 100%|██████████| 4066/4066 [11:43<00:00,  5.78it/s]
epoch 2 loss: 0.2768: 100%|██████████| 1017/1017 [01:12<00:00, 14.01it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9203


epoch 3 loss: 0.6666: 100%|██████████| 4066/4066 [11:41<00:00,  5.79it/s]
epoch 3 loss: 0.3043: 100%|██████████| 1017/1017 [01:12<00:00, 14.06it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9213


epoch 4 loss: 0.5576: 100%|██████████| 4066/4066 [11:42<00:00,  5.79it/s]
epoch 4 loss: 0.3065: 100%|██████████| 1017/1017 [01:12<00:00, 14.07it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9142


epoch 5 loss: 0.5690: 100%|██████████| 4066/4066 [11:42<00:00,  5.78it/s]
epoch 5 loss: 0.3543: 100%|██████████| 1017/1017 [01:12<00:00, 14.06it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9110


epoch 6 loss: 0.5462: 100%|██████████| 4066/4066 [11:45<00:00,  5.76it/s]
epoch 6 loss: 0.2577: 100%|██████████| 1017/1017 [01:12<00:00, 14.04it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9331


epoch 7 loss: 0.5296: 100%|██████████| 4066/4066 [11:44<00:00,  5.77it/s]
epoch 7 loss: 0.2732: 100%|██████████| 1017/1017 [01:13<00:00, 13.93it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9297


epoch 8 loss: 0.5067: 100%|██████████| 4066/4066 [11:46<00:00,  5.75it/s]
epoch 8 loss: 0.2504: 100%|██████████| 1017/1017 [01:13<00:00, 13.92it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9341


epoch 9 loss: 0.5171: 100%|██████████| 4066/4066 [11:46<00:00,  5.75it/s]
epoch 0 loss: 0.7122: 100%|██████████| 4066/4066 [11:48<00:00,  5.74it/s]
epoch 0 loss: 0.3263: 100%|██████████| 1017/1017 [01:12<00:00, 13.99it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9137


epoch 1 loss: 0.6263: 100%|██████████| 4066/4066 [11:52<00:00,  5.71it/s]
epoch 1 loss: 0.3169: 100%|██████████| 1017/1017 [01:12<00:00, 13.94it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9107


epoch 2 loss: 0.5855: 100%|██████████| 4066/4066 [11:50<00:00,  5.72it/s]
epoch 2 loss: 0.4226: 100%|██████████| 1017/1017 [01:13<00:00, 13.91it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9260


epoch 3 loss: 0.5483: 100%|██████████| 4066/4066 [11:50<00:00,  5.72it/s]
epoch 3 loss: 0.4731: 100%|██████████| 1017/1017 [01:14<00:00, 13.68it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9292


epoch 4 loss: 0.5686: 100%|██████████| 4066/4066 [11:52<00:00,  5.71it/s]
epoch 4 loss: 0.2946: 100%|██████████| 1017/1017 [01:13<00:00, 13.91it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9284


epoch 5 loss: 0.5552: 100%|██████████| 4066/4066 [11:53<00:00,  5.70it/s]
epoch 5 loss: 0.3553: 100%|██████████| 1017/1017 [01:14<00:00, 13.74it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9341


epoch 6 loss: 0.5426: 100%|██████████| 4066/4066 [11:53<00:00,  5.70it/s]
epoch 6 loss: 0.2924: 100%|██████████| 1017/1017 [01:13<00:00, 13.83it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9380


epoch 7 loss: 0.5474: 100%|██████████| 4066/4066 [11:53<00:00,  5.70it/s]
epoch 7 loss: 0.2659: 100%|██████████| 1017/1017 [01:14<00:00, 13.74it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9420


epoch 8 loss: 0.5512: 100%|██████████| 4066/4066 [11:53<00:00,  5.70it/s]
epoch 8 loss: 0.2576: 100%|██████████| 1017/1017 [01:14<00:00, 13.72it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9410


epoch 9 loss: 0.5287: 100%|██████████| 4066/4066 [11:53<00:00,  5.70it/s]
epoch 9 loss: 0.2622: 100%|██████████| 1017/1017 [01:13<00:00, 13.91it/s]


validation multi-class accuracy = 0.9407
Training with 2 started
16262 4065


epoch 0 loss: 0.6665: 100%|██████████| 4066/4066 [11:58<00:00,  5.66it/s]
epoch 0 loss: 0.3472: 100%|██████████| 1017/1017 [01:14<00:00, 13.70it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9016


epoch 1 loss: 0.6148: 100%|██████████| 4066/4066 [11:55<00:00,  5.68it/s]
epoch 1 loss: 0.3111: 100%|██████████| 1017/1017 [01:13<00:00, 13.89it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9178


epoch 2 loss: 0.6257: 100%|██████████| 4066/4066 [11:57<00:00,  5.67it/s]
epoch 2 loss: 0.2830: 100%|██████████| 1017/1017 [01:13<00:00, 13.92it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9257


epoch 3 loss: 0.6356: 100%|██████████| 4066/4066 [11:57<00:00,  5.66it/s]
epoch 3 loss: 0.2936: 100%|██████████| 1017/1017 [01:13<00:00, 13.87it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9176


epoch 4 loss: 0.5770: 100%|██████████| 4066/4066 [11:57<00:00,  5.67it/s]
epoch 4 loss: 0.3623: 100%|██████████| 1017/1017 [01:13<00:00, 13.88it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.8937


epoch 5 loss: 0.5577: 100%|██████████| 4066/4066 [12:01<00:00,  5.64it/s]
epoch 5 loss: 0.2836: 100%|██████████| 1017/1017 [01:13<00:00, 13.82it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9245


epoch 6 loss: 0.5475: 100%|██████████| 4066/4066 [11:59<00:00,  5.65it/s]
epoch 6 loss: 0.2696: 100%|██████████| 1017/1017 [01:13<00:00, 13.88it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9257


epoch 7 loss: 0.5312: 100%|██████████| 4066/4066 [11:59<00:00,  5.65it/s]
epoch 7 loss: 0.2485: 100%|██████████| 1017/1017 [01:13<00:00, 13.88it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9383


epoch 8 loss: 0.5147: 100%|██████████| 4066/4066 [12:01<00:00,  5.64it/s]
epoch 8 loss: 0.2765: 100%|██████████| 1017/1017 [01:13<00:00, 13.84it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9262


epoch 9 loss: 0.5371: 100%|██████████| 4066/4066 [12:01<00:00,  5.64it/s]
epoch 9 loss: 0.2501: 100%|██████████| 1017/1017 [01:13<00:00, 13.90it/s]


validation multi-class accuracy = 0.9348
Training with 3 started
16262 4065


epoch 0 loss: 0.2816: 100%|██████████| 1017/1017 [01:13<00:00, 13.82it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9173


epoch 1 loss: 0.6057: 100%|██████████| 4066/4066 [11:57<00:00,  5.66it/s]
epoch 1 loss: 0.2977: 100%|██████████| 1017/1017 [01:13<00:00, 13.83it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9272


epoch 2 loss: 0.6261: 100%|██████████| 4066/4066 [12:01<00:00,  5.63it/s]
epoch 2 loss: 0.2581: 100%|██████████| 1017/1017 [01:13<00:00, 13.87it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9326


epoch 3 loss: 0.5868: 100%|██████████| 4066/4066 [12:00<00:00,  5.64it/s]
epoch 4 loss: 0.5644: 100%|██████████| 4066/4066 [12:03<00:00,  5.62it/s]
epoch 4 loss: 0.2782: 100%|██████████| 1017/1017 [01:14<00:00, 13.73it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9296


epoch 5 loss: 0.5313: 100%|██████████| 4066/4066 [12:05<00:00,  5.61it/s]
epoch 5 loss: 0.2680: 100%|██████████| 1017/1017 [01:13<00:00, 13.89it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9341


epoch 6 loss: 0.5482: 100%|██████████| 4066/4066 [12:03<00:00,  5.62it/s]
epoch 7 loss: 0.5147: 100%|██████████| 4066/4066 [12:03<00:00,  5.62it/s]
epoch 7 loss: 0.2685: 100%|██████████| 1017/1017 [01:13<00:00, 13.75it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9407


epoch 8 loss: 0.5169: 100%|██████████| 4066/4066 [12:04<00:00,  5.62it/s]
epoch 8 loss: 0.2604: 100%|██████████| 1017/1017 [01:13<00:00, 13.75it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9392


epoch 9 loss: 0.5246: 100%|██████████| 4066/4066 [12:03<00:00,  5.62it/s]
epoch 9 loss: 0.2561: 100%|██████████| 1017/1017 [01:12<00:00, 13.97it/s]


validation multi-class accuracy = 0.9383
Training with 4 started
16262 4065


epoch 0 loss: 0.6738: 100%|██████████| 4066/4066 [12:05<00:00,  5.60it/s]
epoch 0 loss: 0.3174: 100%|██████████| 1017/1017 [01:13<00:00, 13.75it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9026


epoch 1 loss: 0.6275: 100%|██████████| 4066/4066 [11:59<00:00,  5.65it/s]
epoch 1 loss: 0.2993: 100%|██████████| 1017/1017 [01:13<00:00, 13.79it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9127


epoch 2 loss: 0.6080: 100%|██████████| 4066/4066 [11:56<00:00,  5.67it/s]
epoch 2 loss: 0.3250: 100%|██████████| 1017/1017 [01:13<00:00, 13.78it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9186


epoch 3 loss: 0.5898: 100%|██████████| 4066/4066 [11:57<00:00,  5.67it/s]
epoch 4 loss: 0.5540: 100%|██████████| 4066/4066 [11:59<00:00,  5.65it/s]
epoch 4 loss: 0.2881: 100%|██████████| 1017/1017 [01:13<00:00, 13.92it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9257


epoch 5 loss: 0.5512: 100%|██████████| 4066/4066 [11:54<00:00,  5.69it/s]
epoch 6 loss: 0.5230: 100%|██████████| 4066/4066 [11:57<00:00,  5.67it/s]
epoch 7 loss: 0.5122: 100%|██████████| 4066/4066 [12:08<00:00,  5.58it/s]
epoch 7 loss: 0.2797: 100%|██████████| 1017/1017 [01:13<00:00, 13.86it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9257


epoch 8 loss: 0.5106: 100%|██████████| 4066/4066 [12:03<00:00,  5.62it/s]
epoch 8 loss: 0.2764: 100%|██████████| 1017/1017 [01:12<00:00, 14.11it/s]
  0%|          | 0/4066 [00:00<?, ?it/s]

validation multi-class accuracy = 0.9279


epoch 9 loss: 0.2875: 100%|██████████| 1017/1017 [01:12<00:00, 13.94it/s]


validation multi-class accuracy = 0.9250


In [67]:
train = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')  # full train.csv again (noisy images included)

# Configuration for the out-of-fold inference in this cell.
# Keys not read by the inference code in this cell: epochs, train_bs, lr,
# accum_iter, verbose_step, weights.
CFG = dict(
    fold_num=5,                        # number of StratifiedKFold splits / per-fold checkpoints
    seed=719,                          # passed to seed_everything
    model_arch='tf_efficientnet_b4_ns',  # timm architecture name; also the checkpoint file prefix
    img_size=512,                      # square crop size fed to the network
    epochs=10,
    train_bs=32,
    valid_bs=32,                       # batch size for the validation DataLoader
    lr=1e-4,
    num_workers=0,                     # DataLoader worker processes
    accum_iter=1,                      # gradient-accumulation steps (supports a larger effective batch size in training)
    verbose_step=1,
    device='cuda:0',
    tta=3,                             # number of test-time-augmentation passes per fold
    weights=[1, 1, 1, 1, 1],
)

def inference_one_epoch(model, data_loader, device):
    """Run `model` once over `data_loader` and return softmax probabilities.

    Args:
        model: classifier returning raw logits; it is switched to eval mode here.
        data_loader: iterable of image batches only (no labels), e.g. a DataLoader
            over a dataset built with ``output_label=False``.
        device: torch device each batch is moved to before the forward pass.

    Returns:
        np.ndarray with the softmax (over dim 1) of every batch, concatenated
        along axis 0 in loader order.

    Raises:
        ValueError: if `data_loader` yields no batches.
    """
    model.eval()

    batch_probs = []
    # Gradients are never needed here; disabling them locally avoids building an
    # autograd graph even if a caller forgets to wrap this call in torch.no_grad().
    with torch.no_grad():
        for imgs in tqdm(data_loader, total=len(data_loader)):
            imgs = imgs.to(device).float()
            logits = model(imgs)
            batch_probs.append(torch.softmax(logits, 1).detach().cpu().numpy())

    if not batch_probs:
        raise ValueError('inference_one_epoch: data_loader yielded no batches')
    return np.concatenate(batch_probs, axis=0)

def get_inference_transforms():
    """Build the albumentations pipeline applied on every TTA pass at inference.

    All stages before Normalize are random, so repeated passes over the same
    image give different views whose predictions the caller averages.
    NOTE(review): RandomResizedCrop gets no `scale` argument, so the library's
    default crop range applies — confirm that range is intended for TTA.
    """
    side = CFG['img_size']
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    random_views = [
        RandomResizedCrop(side, side),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    ]
    # Deterministic tail: scale pixels to the normalisation stats and convert to a tensor.
    to_model_input = [
        Normalize(mean=imagenet_mean, std=imagenet_std, max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(random_views + to_model_input, p=1.)

if __name__ == '__main__':
    # Out-of-fold (OOF) inference: for each fold, reload the checkpoint saved as
    # '{model_arch}_fold_{fold}', predict that fold's validation rows with
    # CFG['tta'] augmented passes, report loss/accuracy and pickle soft labels.
    #
    # NOTE(review): `train` is the full train.csv re-read at the top of this cell,
    # while the training log ("16262 4065" rows per fold) shows the models were
    # trained on the denoised frame. Splitting a different frame yields different
    # folds, so some "validation" rows here may have been seen in training by the
    # fold's model — confirm the split matches the training cell before trusting
    # these OOF scores / soft labels.

    seed_everything(CFG['seed'])

    folds = StratifiedKFold(n_splits=CFG['fold_num']).split(np.arange(train.shape[0]), train.label.values)
    device = torch.device(CFG['device'])
    n_classes = train.label.nunique()

    for fold, (_trn_idx, val_idx) in enumerate(folds):
        print('Inference fold {} started'.format(fold))

        valid_ = train.loc[val_idx, :].reset_index(drop=True)
        valid_ds = CassavaDataset(valid_, '../input/cassava-leaf-disease-classification/train_images/', transforms=get_inference_transforms(), output_label=False)

        val_loader = torch.utils.data.DataLoader(
            valid_ds,
            batch_size=CFG['valid_bs'],
            num_workers=CFG['num_workers'],
            shuffle=False,  # keep prediction rows aligned with valid_ for the OOF frame
            pin_memory=False,
        )

        model = CassvaImgClassifier(CFG['model_arch'], n_classes).to(device)
        # map_location lets a checkpoint saved on another device load onto `device`.
        model.load_state_dict(torch.load('{}_fold_{}'.format(CFG['model_arch'], fold), map_location=device))

        # One probability matrix per TTA pass (the transforms are random).
        with torch.no_grad():
            tta_preds = [inference_one_epoch(model, val_loader, device) for _ in range(CFG['tta'])]
        # Plain average over passes, so every row is still a distribution summing
        # to 1 (pre-scaling each pass by 1/tta before averaging would shrink rows to 1/tta).
        val_preds = np.mean(tta_preds, axis=0)

        print('fold {} validation loss = {:.5f}'.format(fold, log_loss(valid_.label.values, val_preds)))
        print('fold {} validation accuracy = {:.5f}'.format(fold, (valid_.label.values == np.argmax(val_preds, axis=1)).mean()))

        # Column names are 1-based (soft_label_1..K) although labels are 0-based;
        # kept that way for compatibility with existing pickles.
        soft_cols = [f'soft_label_{i}' for i in range(1, val_preds.shape[1] + 1)]
        oof_ = pd.concat([valid_, pd.DataFrame(val_preds, columns=soft_cols)], axis=1)
        oof_.to_pickle(f"{CFG['model_arch']}_oof{fold}.pkl")

        del model
        torch.cuda.empty_cache()

Inference fold 0 started


100%|██████████| 134/134 [01:18<00:00,  1.72it/s]
100%|██████████| 134/134 [01:15<00:00,  1.78it/s]
100%|██████████| 134/134 [01:14<00:00,  1.79it/s]


fold 0 validation loss = 0.35060
fold 0 validation accuracy = 0.91565
Inference fold 1 started


100%|██████████| 134/134 [01:15<00:00,  1.78it/s]
100%|██████████| 134/134 [01:16<00:00,  1.75it/s]
100%|██████████| 134/134 [01:15<00:00,  1.77it/s]


fold 1 validation loss = 0.35959
fold 1 validation accuracy = 0.90958
Inference fold 2 started


100%|██████████| 134/134 [01:19<00:00,  1.68it/s]
100%|██████████| 134/134 [01:16<00:00,  1.75it/s]
100%|██████████| 134/134 [01:15<00:00,  1.77it/s]


fold 2 validation loss = 0.31200
fold 2 validation accuracy = 0.92218
Inference fold 3 started


100%|██████████| 134/134 [01:15<00:00,  1.77it/s]
100%|██████████| 134/134 [01:16<00:00,  1.75it/s]
100%|██████████| 134/134 [01:18<00:00,  1.70it/s]


fold 3 validation loss = 0.35316
fold 3 validation accuracy = 0.91447
Inference fold 4 started


100%|██████████| 134/134 [01:18<00:00,  1.71it/s]
100%|██████████| 134/134 [01:15<00:00,  1.77it/s]
100%|██████████| 134/134 [01:17<00:00,  1.73it/s]

fold 4 validation loss = 0.36121
fold 4 validation accuracy = 0.90816





In [64]:
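# Experiment log (validation accuracy; "from vN" = the version each run builds on)
# v1: baseline                                                         accuracy = 0.8857
# v2 (from v1): switch to LabelSmoothingCrossEntropy                   accuracy = 0.8874
# v3 (from v2): add freeze_batchnorm_stats                             accuracy = 0.8874 (no change)
# v4 (from v2): remove some augmentations                              accuracy = 0.8937
# v5 (from v4): use ViT                                                accuracy = 0.8837
# v6 (from v4): use ResNeXt                                            accuracy = 0.8734
# v7 (from v4): classify only classes 0, 1, 2, 4 (class 3 is over-represented)  accuracy = 0.85
# v8 (from v4): remove noisy images identified from OOF predictions    accuracy = 0.93
#     NOTE: v8 is measured on a validation set with the noisy images removed, so it is not
#     directly comparable to v1-v7; on the full train set its OOF accuracy is about 0.91.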


SyntaxError: invalid syntax (<ipython-input-64-53e7ad4baecf>, line 1)

# Inference part is here: https://www.kaggle.com/khyeh0719/pytorch-efficientnet-baseline-inference-tta