# EFFICIENT NET B4 MODEL

The EfficientNet-B4 model uses the `tf_efficientnet_b4_ns` architecture. This is a variant of EfficientNet that balances network depth, width, and input resolution, and whose pretrained weights were trained with Noisy Student (NS) semi-supervised learning. The model is trained with 5-fold stratified cross-validation, so every training image is used for validation exactly once. The training batch size is 16 and the validation batch size is 32. Training runs for 5 epochs per fold, balancing training time and performance. The learning rate starts at 1e-4 and follows a cosine annealing schedule with warm restarts: the restart period is T_0 = 10 epochs and the minimum learning rate is 1e-6. Because each fold trains for only 5 epochs, the learning rate completes about half of one cosine cycle. To reduce overfitting, weight decay is set to 1e-6. Gradients are accumulated over 2 iterations, giving an effective batch size of 32.
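Input images are cropped and resized to 512x512 pixels, providing a balance between detail and computational efficiency. During training, augmentation is applied on the fly, so the model sees a different random variant of each image in every epoch; this reduces overfitting to the exact appearance of individual training images. The augmentations are:
- random resized crops and transposes
- horizontal and vertical flips
- shift-scale-rotate transforms
- hue/saturation/value and brightness/contrast changes
- coarse dropout

Augmentation alone does not rebalance the classes. Each class is still sampled in proportion to its frequency, because no class-balancing sampler is used. Together, these techniques help the model generalize to unseen data.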


## IMPORTS

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os


In [None]:
from glob import glob
from sklearn.model_selection import GroupKFold, StratifiedKFold
import cv2
from skimage import io
import torch
from torch import nn
import os
from datetime import datetime
import time
import random
import cv2
import torchvision
from torchvision import transforms
import pandas as pd
import numpy as np
from tqdm import tqdm

import matplotlib.pyplot as plt
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.cuda.amp import autocast, GradScaler
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F

import timm

import sklearn
import warnings
import joblib
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics
import warnings
import cv2
import pydicom
#from efficientnet_pytorch import EfficientNet
from scipy.ndimage.interpolation import zoom

## Hyperparameter Declaration

In [None]:
# Central experiment configuration shared by the dataset, dataloaders and training loop.
CFG = dict(
    fold_num=5,                          # number of StratifiedKFold splits
    seed=719,                            # global seed passed to seed_everything
    model_arch='tf_efficientnet_b4_ns',  # timm model name
    img_size=512,                        # side length of the square network input
    epochs=5,                            # epochs per fold
    train_bs=16,
    valid_bs=32,
    T_0=10,                              # CosineAnnealingWarmRestarts restart period (scheduler steps)
    lr=1e-4,
    min_lr=1e-6,                         # eta_min of the cosine schedule
    weight_decay=1e-6,
    num_workers=4,
    accum_iter=2,                        # batches accumulated per optimizer step (larger effective batch)
    verbose_step=1,                      # progress-bar refresh interval, in batches
    device='cuda:0',
)

## Data Inspection

In [None]:
# Load the competition's training manifest; later cells read its `image_id` and `label` columns.
train = pd.read_csv('../input/cassava-leaf-disease-classification/train.csv')
train.head()

In [None]:
# Number of training images per class label (shows the class imbalance).
train.label.value_counts()

In [None]:
# Sample submission file, loaded here for inspection only.
submission = pd.read_csv('../input/cassava-leaf-disease-classification/sample_submission.csv')
submission.head()

## Helper Functions

In [None]:
def seed_everything(seed):
    """Seed every random number generator used in this notebook for reproducibility.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device),
    exports ``PYTHONHASHSEED`` and configures cuDNN for deterministic kernels.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Only affects child processes started afterwards; the running
    # interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all visible GPUs, not only the current one (safe without CUDA).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and choose kernels that vary between
    # runs, which defeats deterministic=True above.
    torch.backends.cudnn.benchmark = False
    
def get_img(path):
    """Read an image file and return it as an RGB array.

    Args:
        path: Filesystem path of the image.

    Returns:
        ``numpy.ndarray`` of shape (H, W, 3) in RGB channel order.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``path`` (missing or unreadable file).
    """
    im_bgr = cv2.imread(path)
    # cv2.imread reports failure by returning None instead of raising; without
    # this check the error surfaces later as an opaque TypeError.
    if im_bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    # cvtColor returns a contiguous array, unlike the [:, :, ::-1] view whose
    # negative strides torch.from_numpy cannot handle.
    return cv2.cvtColor(im_bgr, cv2.COLOR_BGR2RGB)

# Sanity check: load one training image and display it.
img = get_img('../input/cassava-leaf-disease-classification/train_images/1000015157.jpg')
plt.imshow(img)
plt.show()

In [None]:
def rand_bbox(size, lam):
    """Sample a random CutMix box whose side lengths are ``size * sqrt(1 - lam)``.

    The box centre is drawn uniformly inside the image and the box is clipped
    to the image bounds, so boxes near the border can be smaller than requested.

    Args:
        size: Pair of image side lengths; ``size[0]`` bounds ``bbx1/bbx2`` and
            ``size[1]`` bounds ``bby1/bby2``. NOTE(review): CassavaDataset
            indexes ``img[:, bbx1:bbx2, bby1:bby2]``, so ``size[0]`` presumably
            bounds the row axis of a (C, H, W) tensor; harmless for square images.
        lam: Mixing coefficient in [0, 1]; larger values give smaller boxes.

    Returns:
        Tuple ``(bbx1, bby1, bbx2, bby2)`` with ``0 <= bbx1 <= bbx2 <= size[0]``
        and ``0 <= bby1 <= bby2 <= size[1]``.
    """
    W = size[0]
    H = size[1]
    cut_rat = np.sqrt(1. - lam)
    # Builtin int(): the np.int alias was removed in NumPy 1.24.
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)

    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

class CassavaDataset(Dataset):
    """Map-style dataset that reads Cassava leaf images listed in a DataFrame.

    Each item is the (optionally transformed) image for one row, plus its
    label when ``output_label`` is truthy. Optional FMix / CutMix augmentation
    blends the image with a second, randomly chosen row and blends the targets
    in proportion to the mixed area.

    Args:
        df: DataFrame with an ``image_id`` column (file name under
            ``data_root``) and, when ``output_label`` is truthy, an integer
            ``label`` column.
        data_root: Directory joined to ``image_id`` as ``"{data_root}/{image_id}"``.
        transforms: Optional albumentations-style callable, invoked as
            ``transforms(image=img)['image']``.
        output_label: If truthy, ``__getitem__`` returns ``(img, target)``;
            otherwise only ``img``.
        one_hot_label: If truthy, labels are stored as one-hot float rows
            (``np.eye(max_label + 1)[labels]``) instead of integer ids.
        do_fmix: Apply FMix to an item with probability 0.5.
        fmix_params: FMix settings; ``None`` selects the defaults ``alpha=1``,
            ``decay_power=3``, ``shape=(CFG['img_size'], CFG['img_size'])``,
            ``max_soft=True``, ``reformulate=False``.
        do_cutmix: Apply CutMix to an item with probability 0.5.
        cutmix_params: CutMix settings; ``None`` selects ``{'alpha': 1}``.

    Note:
        Mixing produces fractional targets, so it is only meaningful with
        ``one_hot_label=True`` and a soft-target loss; with integer labels the
        blended target is a float that a later ``.long()`` cast truncates.
    """

    def __init__(self, df, data_root,
                 transforms=None,
                 output_label=True,
                 one_hot_label=False,
                 do_fmix=False,
                 fmix_params=None,
                 do_cutmix=False,
                 cutmix_params=None,
                 ):
        super().__init__()
        # Build defaults per instance (and copy caller dicts) instead of using
        # mutable default arguments that every instance would share.
        if fmix_params is None:
            fmix_params = {
                'alpha': 1.,
                'decay_power': 3.,
                'shape': (CFG['img_size'], CFG['img_size']),
                'max_soft': True,
                'reformulate': False,
            }
        if cutmix_params is None:
            cutmix_params = {'alpha': 1}

        # A fresh 0..n-1 index makes label-based and positional lookups agree.
        self.df = df.reset_index(drop=True).copy()
        self.transforms = transforms
        self.data_root = data_root
        self.do_fmix = do_fmix
        self.fmix_params = dict(fmix_params)
        self.do_cutmix = do_cutmix
        self.cutmix_params = dict(cutmix_params)

        self.output_label = output_label
        self.one_hot_label = one_hot_label

        if self.output_label:
            self.labels = self.df['label'].values
            if self.one_hot_label:
                self.labels = np.eye(self.df['label'].max() + 1)[self.labels]

    def __len__(self):
        return self.df.shape[0]

    def _load(self, index):
        """Read row ``index``'s image and apply ``self.transforms`` if set."""
        img = get_img("{}/{}".format(self.data_root, self.df.at[index, 'image_id']))
        if self.transforms:
            img = self.transforms(image=img)['image']
        return img

    def _random_index(self):
        """Pick a uniformly random row index to mix with."""
        return np.random.choice(self.df.index, size=1)[0]

    def __getitem__(self, index: int):
        target = self.labels[index] if self.output_label else None
        img = self._load(index)

        if self.do_fmix and np.random.uniform(0., 1., size=1)[0] > 0.5:
            with torch.no_grad():
                lam = np.clip(np.random.beta(self.fmix_params['alpha'], self.fmix_params['alpha']), 0.6, 0.7)

                # NOTE(review): make_low_freq_image / binarise_mask are not
                # defined or imported in this notebook (presumably from the FMix
                # reference code); do_fmix=True raises NameError until they are.
                mask = make_low_freq_image(self.fmix_params['decay_power'], self.fmix_params['shape'])
                mask = binarise_mask(mask, lam, self.fmix_params['shape'], self.fmix_params['max_soft'])

                fmix_ix = self._random_index()
                fmix_img = self._load(fmix_ix)

                mask_torch = torch.from_numpy(mask)
                img = mask_torch * img + (1. - mask_torch) * fmix_img

                if self.output_label:
                    # Fraction of pixels kept from the original image.
                    rate = mask.sum() / CFG['img_size'] / CFG['img_size']
                    target = rate * target + (1. - rate) * self.labels[fmix_ix]

        if self.do_cutmix and np.random.uniform(0., 1., size=1)[0] > 0.5:
            with torch.no_grad():
                cmix_ix = self._random_index()
                cmix_img = self._load(cmix_ix)

                lam = np.clip(np.random.beta(self.cutmix_params['alpha'], self.cutmix_params['alpha']), 0.3, 0.4)
                bbx1, bby1, bbx2, bby2 = rand_bbox((CFG['img_size'], CFG['img_size']), lam)

                # Paste the box from the second image; assumes a channel-first
                # tensor as produced by ToTensorV2 — confirm for other transforms.
                img[:, bbx1:bbx2, bby1:bby2] = cmix_img[:, bbx1:bbx2, bby1:bby2]

                if self.output_label:
                    # Fraction of the image area NOT covered by the pasted box.
                    rate = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (CFG['img_size'] * CFG['img_size']))
                    target = rate * target + (1. - rate) * self.labels[cmix_ix]

        if self.output_label:
            return img, target
        return img

## Transformation and Image Augmentation Functions

In [None]:
from albumentations import *

from albumentations.pytorch import ToTensorV2

def get_train_transforms():
    """Return the albumentations pipeline applied to training images.

    Steps, in order:
    - random resized crop to a ``CFG['img_size']`` square
    - transpose / horizontal / vertical flips
    - shift-scale-rotate
    - hue-saturation-value and brightness-contrast jitter
    - ImageNet mean/std normalisation
    - coarse dropout
    - conversion to a torch tensor with ``ToTensorV2``
    """
    side = CFG['img_size']
    # NOTE(review): albumentations treats HueSaturationValue limits as absolute
    # shifts (its defaults are 20/30/20), so 0.2 presumably yields almost no
    # colour jitter — confirm this is intended.
    steps = [
        RandomResizedCrop(side, side),
        Transpose(p=0.5),
        HorizontalFlip(p=0.5),
        VerticalFlip(p=0.5),
        ShiftScaleRotate(p=0.5),
        HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5),
        RandomBrightnessContrast(brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        CoarseDropout(p=0.5),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)
  
        
def get_valid_transforms():
    """Return the deterministic albumentations pipeline for validation images.

    Centre-crops a ``CFG['img_size']`` square, resizes it to that same size,
    applies ImageNet mean/std normalisation and converts it to a torch tensor.
    """
    side = CFG['img_size']
    # NOTE(review): unlike training (RandomResizedCrop over the whole image),
    # this keeps only the central region. The Resize is a no-op because the
    # crop already has the target size — confirm this is intended.
    steps = [
        CenterCrop(side, side, p=1.),
        Resize(side, side),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0, p=1.0),
        ToTensorV2(p=1.0),
    ]
    return Compose(steps, p=1.)

## Model Class and Functions

Image Classifier class for the EfficientNet B4 Model

In [None]:
class CassvaImgClassifier(nn.Module):
    """timm backbone whose ``classifier`` head is replaced to emit ``n_class`` logits.

    Args:
        model_arch: timm model name passed to ``timm.create_model``. The model
            must expose a ``classifier`` layer with ``in_features``, as the
            EfficientNet architecture configured in CFG does.
        n_class: Number of output logits.
        pretrained: Whether timm loads its pretrained weights.
    """

    def __init__(self, model_arch, n_class, pretrained=False):
        super().__init__()
        backbone = timm.create_model(model_arch, pretrained=pretrained)
        # Swap the stock head for a freshly initialised linear layer of the right width.
        backbone.classifier = nn.Linear(backbone.classifier.in_features, n_class)
        self.model = backbone

    def forward(self, x):
        """Return raw (unnormalised) class logits for the batch ``x``."""
        return self.model(x)

In [None]:
def prepare_dataloader(df, trn_idx, val_idx, data_root='../input/cassava-leaf-disease-classification/train_images/'):
    """Build the training and validation DataLoaders for one cross-validation fold.

    Args:
        df: Full training DataFrame with ``image_id`` and ``label`` columns.
        trn_idx: Positional row indices of the training split, as yielded by
            ``StratifiedKFold.split``.
        val_idx: Positional row indices of the validation split.
        data_root: Directory containing the image files.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and uses
        ``CFG['train_bs']`` with the training augmentations. The validation
        loader keeps order and uses ``CFG['valid_bs']`` with the validation
        transforms.
    """
    # iloc, not loc: fold indices are positions, which match index labels only
    # when df happens to have a default RangeIndex.
    train_ = df.iloc[trn_idx].reset_index(drop=True)
    valid_ = df.iloc[val_idx].reset_index(drop=True)

    train_ds = CassavaDataset(train_, data_root, transforms=get_train_transforms(), output_label=True, one_hot_label=False, do_fmix=False, do_cutmix=False)
    valid_ds = CassavaDataset(valid_, data_root, transforms=get_valid_transforms(), output_label=True)

    # No class-balancing sampler is applied. To add one (e.g. catalyst's
    # BalanceClassSampler), pass sampler=... and set shuffle=False.
    train_loader = torch.utils.data.DataLoader(
        train_ds,
        batch_size=CFG['train_bs'],
        pin_memory=False,
        drop_last=False,
        shuffle=True,
        num_workers=CFG['num_workers'],
    )
    val_loader = torch.utils.data.DataLoader(
        valid_ds,
        batch_size=CFG['valid_bs'],
        num_workers=CFG['num_workers'],
        shuffle=False,
        pin_memory=False,
    )
    return train_loader, val_loader

import torch
from torch.cuda.amp import autocast
from tqdm import tqdm

def train_one_epoch(
    epoch,
    model,
    loss_fn,
    optimizer,
    train_loader,
    device,
    scaler,
    CFG,
    scheduler=None,
    schd_batch_update=False
):
    """Run one training epoch with mixed precision and gradient accumulation.

    Gradients from ``CFG['accum_iter']`` consecutive batches are accumulated
    before each optimizer step. The last (possibly partial) group is always
    stepped, so no gradients carry over into the next epoch.

    Args:
        epoch: Epoch number, used only in the progress-bar text.
        model: Network to train; switched to ``train()`` mode.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar loss.
        optimizer: Optimizer stepped through ``scaler``.
        train_loader: Iterable of ``(imgs, labels)`` batches supporting ``len()``.
        device: Device the batches are moved to.
        scaler: ``GradScaler`` used for loss scaling.
        CFG: Config dict; reads ``accum_iter`` and ``verbose_step``.
        scheduler: Optional LR scheduler.
        schd_batch_update: If True, step ``scheduler`` after every optimizer
            step; otherwise step it once at the end of the epoch.

    Returns:
        ``(avg_loss, accuracy)`` over all samples seen. Both are 0.0 for an
        empty loader.
    """
    model.train()
    accum_iter = CFG['accum_iter']
    num_batches = len(train_loader)
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    pbar = tqdm(enumerate(train_loader), total=num_batches)
    for step, (imgs, image_labels) in pbar:
        imgs = imgs.to(device).float()
        image_labels = image_labels.to(device).long()

        with autocast():
            image_preds = model(imgs)
            loss = loss_fn(image_preds, image_labels)

        # Average over the accumulation window so the summed gradient equals
        # that of one batch accum_iter times larger (not accum_iter times it).
        scaler.scale(loss / accum_iter).backward()

        batch_size = image_labels.shape[0]
        running_loss += loss.item() * batch_size
        total_samples += batch_size
        correct_predictions += (torch.argmax(image_preds, dim=1) == image_labels).sum().item()

        is_last = (step + 1) == num_batches
        if (step + 1) % accum_iter == 0 or is_last:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            if scheduler is not None and schd_batch_update:
                scheduler.step()

        if (step + 1) % CFG['verbose_step'] == 0 or is_last:
            pbar.set_description(f'epoch {epoch} loss: {running_loss / total_samples:.4f}')

    if scheduler is not None and not schd_batch_update:
        scheduler.step()

    if total_samples == 0:
        return 0.0, 0.0
    return running_loss / total_samples, correct_predictions / total_samples

def valid_one_epoch(
    epoch,
    model,
    loss_fn,
    val_loader,
    device,
    CFG,
    scheduler=None,
    schd_loss_update=False
):
    """Evaluate ``model`` on ``val_loader`` and report loss and accuracy.

    Args:
        epoch: Epoch number, used only in the progress-bar text.
        model: Network to evaluate; switched to ``eval()`` mode.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar loss.
        val_loader: Iterable of ``(imgs, labels)`` batches supporting ``len()``.
        device: Device the batches are moved to.
        CFG: Config dict; reads ``verbose_step``.
        scheduler: Optional scheduler stepped once after evaluation.
        schd_loss_update: If True, call ``scheduler.step(avg_loss)`` (as for
            ReduceLROnPlateau); otherwise ``scheduler.step()``.

    Returns:
        ``(avg_loss, accuracy)`` over the whole validation set. Both are 0.0
        for an empty loader.
    """
    model.eval()
    loss_sum = 0.0
    total_samples = 0
    image_preds_all = []
    image_targets_all = []

    pbar = tqdm(enumerate(val_loader), total=len(val_loader))
    # No gradients are needed anywhere in evaluation, so disable autograd for the whole loop.
    with torch.no_grad():
        for step, (imgs, image_labels) in pbar:
            imgs = imgs.to(device).float()
            image_labels = image_labels.to(device).long()

            image_preds = model(imgs)
            loss = loss_fn(image_preds, image_labels)

            image_preds_all.append(torch.argmax(image_preds, dim=1).cpu().numpy())
            image_targets_all.append(image_labels.cpu().numpy())

            loss_sum += loss.item() * image_labels.shape[0]
            total_samples += image_labels.shape[0]

            if (step + 1) % CFG['verbose_step'] == 0 or (step + 1) == len(val_loader):
                pbar.set_description(f'epoch {epoch} loss: {loss_sum / total_samples:.4f}')

    if total_samples == 0:
        accuracy, avg_loss = 0.0, 0.0
    else:
        accuracy = (np.concatenate(image_preds_all) == np.concatenate(image_targets_all)).mean()
        avg_loss = loss_sum / total_samples

    print(f'Validation multi-class accuracy = {accuracy:.4f}')

    if scheduler is not None:
        if schd_loss_update:
            scheduler.step(avg_loss)
        else:
            scheduler.step()

    return avg_loss, accuracy


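Custom cross-entropy loss that accepts soft (probability-vector) targets, such as one-hot labels or labels mixed by FMix/CutMix. It is currently unused; training uses `nn.CrossEntropyLoss`.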

In [None]:
# reference: https://www.kaggle.com/c/siim-isic-melanoma-classification/discussion/173733
class MyCrossEntropyLoss(_WeightedLoss):
    """Cross-entropy that accepts soft (probability-vector) targets.

    Computes ``-(targets * log_softmax(inputs)).sum(-1)`` per sample,
    optionally scaling each class term by ``weight``.

    Args:
        weight: Optional 1-D tensor of per-class weights, broadcast over the batch.
        reduction: ``'mean'`` averages over samples, ``'sum'`` sums them; any
            other value (e.g. ``'none'``) returns the per-sample losses.

    Note:
        With ``weight`` set, ``'mean'`` divides by the number of samples, not
        by the summed target weights as ``nn.CrossEntropyLoss`` does.
    """

    def __init__(self, weight=None, reduction='mean'):
        # The base class stores `weight` and `reduction` as attributes read in forward().
        super().__init__(weight=weight, reduction=reduction)

    def forward(self, inputs, targets):
        log_probs = F.log_softmax(inputs, -1)
        if self.weight is not None:
            log_probs = log_probs * self.weight.unsqueeze(0)

        per_sample = -(targets * log_probs).sum(-1)

        reducers = {'sum': torch.sum, 'mean': torch.mean}
        reduce_fn = reducers.get(self.reduction)
        return per_sample if reduce_fn is None else reduce_fn(per_sample)

## Model Training

In [None]:
from torch.cuda.amp import GradScaler
# Module-level mixed-precision scaler; each fold below replaces it with a fresh GradScaler.
scaler = GradScaler()

# Per-epoch metrics, appended for every epoch of every fold in training order
# (fold 0's epochs, then fold 1's, ...); after a full run each list holds
# fold_num * epochs values.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

if __name__ == '__main__':
    seed_everything(CFG['seed'])
    
    # Stratified on the label column so each fold keeps similar class proportions.
    folds = StratifiedKFold(n_splits=CFG['fold_num'], shuffle=True, random_state=CFG['seed']).split(np.arange(train.shape[0]), train.label.values)
    
    for fold, (trn_idx, val_idx) in enumerate(folds):
        # Every fold is trained in turn, each with a freshly created model.
        print('Training with {} started'.format(fold))

        print(len(trn_idx), len(val_idx))
        train_loader, val_loader = prepare_dataloader(train, trn_idx, val_idx, data_root='../input/cassava-leaf-disease-classification/train_images/')

        # `device` remains defined after this loop and is reused by later cells.
        device = torch.device(CFG['device'])
        
        model = CassvaImgClassifier(CFG['model_arch'], train.label.nunique(), pretrained=True).to(device)
        scaler = GradScaler()   
        optimizer = torch.optim.Adam(model.parameters(), lr=CFG['lr'], weight_decay=CFG['weight_decay'])
        # Stepped once per epoch (schd_batch_update=False below), so T_0 counts
        # epochs; with CFG['epochs'] < CFG['T_0'] the LR never reaches min_lr.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CFG['T_0'], T_mult=1, eta_min=CFG['min_lr'], last_epoch=-1)
        
        # Integer-label cross-entropy for both training and validation; the
        # soft-target MyCrossEntropyLoss alternative is left commented out.
        loss_tr = nn.CrossEntropyLoss().to(device) #MyCrossEntropyLoss().to(device)
        loss_fn = nn.CrossEntropyLoss().to(device)
        
        # Training loop with metric collection
        for epoch in range(CFG['epochs']):
            train_loss, train_acc = train_one_epoch(epoch,model,loss_tr,optimizer,train_loader,device,scaler=scaler,CFG=CFG,scheduler=scheduler,schd_batch_update=False
            )
            train_losses.append(train_loss)
            train_accuracies.append(train_acc)

            # scheduler=None: the LR schedule is advanced only in train_one_epoch.
            with torch.no_grad():
                val_loss, val_acc = valid_one_epoch(
                    epoch,
                    model,
                    loss_fn,
                    val_loader,
                    device,
                    CFG=CFG,
                    scheduler=None,
                    schd_loss_update=False
                )
                val_losses.append(val_loss)
                val_accuracies.append(val_acc)
            # Checkpoint after every epoch as '<arch>_fold_<fold>_<epoch>' in the current working directory.
            torch.save(model.state_dict(), '{}_fold_{}_{}'.format(CFG['model_arch'], fold, epoch))


        # Drop this fold's objects and release cached GPU memory before the next fold.
        del model, optimizer, train_loader, val_loader, scaler, scheduler #clean memory
        torch.cuda.empty_cache()

## Model Results and Visualization

In [None]:
# Plotting the metrics
import matplotlib.pyplot as plt

# The metric lists hold one value per epoch for every fold, concatenated in
# training order. Use the shortest list so a run interrupted between the
# training and validation of an epoch still plots instead of raising a
# length-mismatch error.
num_epochs = min(len(train_losses), len(val_losses), len(train_accuracies), len(val_accuracies))
epochs = range(num_epochs)


def _plot_train_val(train_values, val_values, train_label, val_label, ylabel, title):
    """Plot one training/validation metric pair, marking where each new fold starts."""
    plt.figure(figsize=(12, 6))
    plt.plot(epochs, train_values[:num_epochs], 'b-', label=train_label)
    plt.plot(epochs, val_values[:num_epochs], 'r-', label=val_label)
    # Each fold contributes CFG['epochs'] points; dashed lines separate the folds.
    for fold_start in range(CFG['epochs'], num_epochs, CFG['epochs']):
        plt.axvline(fold_start - 0.5, color='gray', linestyle='--', linewidth=0.8)
    plt.xlabel('Epoch (all folds, concatenated)')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.show()


_plot_train_val(train_losses, val_losses, 'Training Loss', 'Validation Loss', 'Loss', 'Training and Validation Loss')
_plot_train_val(train_accuracies, val_accuracies, 'Training Accuracy', 'Validation Accuracy', 'Accuracy', 'Training and Validation Accuracy')


## Load Model from Saved State

This step verifies that the saved checkpoint loads correctly, so the same weights can be used in the inference notebook.

In [None]:
# Load the last fold's final-epoch checkpoint, using the '<arch>_fold_<fold>_<epoch>'
# naming from the training cell (assumes training ran with /kaggle/working as the
# working directory). map_location lets this run on machines without a GPU.
state_dict = torch.load(
    '/kaggle/working/{}_fold_{}_{}'.format(CFG['model_arch'], CFG['fold_num'] - 1, CFG['epochs'] - 1),
    map_location=torch.device('cpu'),
)

In [None]:
# Instantiate the corresponding model (must match the architecture of the state dict)
b4_model = CassvaImgClassifier(CFG['model_arch'], train.label.nunique(), pretrained=True).to(device)  # or your custom model class

# Load the state dictionary into the model
b4_model.load_state_dict(state_dict)

Print a summary of the model to confirm that it loaded properly.

In [None]:
# Install torchsummary, whose summary() is used in the next cell (IPython shell magic).
!pip install torchsummary

In [None]:
from torchsummary import summary
import torch
import torchvision.models as models  # Or import your model class

# Layer-by-layer summary; torchsummary runs a forward pass on dummy input of the
# configured image size, on the same device type the model was moved to.
summary(b4_model, (3, CFG['img_size'], CFG['img_size']), device=device.type)