# 🐬 Happy Whale V6 

### 🐋 Trick List

* **GeM**
* **ArcFace**
* **Label Smoothing**
* **Focal Loss**
* **CutMix and MixUp**
* **TTA**

### 🐋 In Happy Whale V6

* **Detic Crop**
* **KFolds Stratified**
* **U2Net**
* **EMA**

### 🐋 In Happy Whale V7

* **Focal Loss to Asymmetric Loss**
* **Metric Loss with Margin**
* **AdamW to AdamP**

### 🐋 Log

**2022/3/25**

* model = SegConvWhaleNet
* backbone = 'convnext_small'
* lr = 3e-4
* img_size = 384
* batch_size = 32
* weight_decay = 2e-5
* Epochs = 25

**2022/3/22**


I noticed that the original minimum lr was too small (about 1e-9), so I changed it to 1.5e-6.

In [None]:
! pip install timm

In [None]:
!git clone https://github.com/NathanUA/U-2-Net.git

In [None]:
import os
import sys
import math
import random
import warnings
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import notebook
from IPython import display
import matplotlib.pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F 
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torch.optim.lr_scheduler import OneCycleLR
from sklearn.model_selection import StratifiedKFold
sys.path.append('./U-2-Net')
warnings.filterwarnings('ignore')

In [None]:
import timm
from timm.optim.optim_factory import create_optimizer
from model import U2NET

**Print Style**

In [None]:
class clr:
    """ANSI escape sequences used to highlight console messages."""
    # Start marker: bold (1m) followed by bright cyan (96m).
    S = '\033[1m\033[96m'
    # End marker: reset all text attributes.
    E = '\033[0m'

**Set Seeds**

In [None]:
def set_seeds(seed = 2022 + 3 + 22):
    """Seed every random number generator the notebook relies on.

    Parameters
    ----------
    seed : int
        Value passed to the Python, NumPy and PyTorch (CPU and CUDA) seeders
        and exported as ``PYTHONHASHSEED``.
    """
    # Export the hash seed for the interpreter environment.
    os.environ['PYTHONHASHSEED'] = str(seed)

    # The individual generators are independent, so seed them in one sweep.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # NOTE: benchmark mode is switched on here; it favours speed and may pick
    # different cuDNN kernels between runs.
    torch.backends.cudnn.benchmark = True
    
# Seed all generators once, using the default seed.
set_seeds()
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

**Config**

In [None]:
# Short aliases -> model names handed to timm.create_model() as the backbone.
cfg = { 'Effnet_B0': 'tf_efficientnet_b0',
        'Effnet_B4': 'tf_efficientnet_b4',
       'EffnetV2_s': 'tf_efficientnetv2_s_in21k',
       'EffnetV2_m': 'tf_efficientnetv2_m_in21k',
       'ConvNext_s': 'convnext_small',
       'ConvNext_b': 'convnext_base_384_in22ft1k',
      }

In [None]:
# Global hyper-parameters for this run.
LR = 3e-4
img_size = 56 # side length used by the augmentation pipelines; the logged runs used 384
BATCH_SIZE = 64
NUM_WORKERS = 2
# Initial class count; it is recomputed from the CSV once the data is loaded.
NUM_CLASSES = 15587

In [None]:
# Directory that holds the training images; joined with each row's `image` file name.
TRAIN_IMG_DIR = r'../input/happy-whale-and-dolphin/train_images'

**Read CSV and Make Labels**

In [None]:
# train_df_path = r'../input/whale2-cropped-dataset/train2.csv'
# train_df = pd.read_csv(train_df_path)
# train_df['label'] = train_df.groupby('individual_id').ngroup()
# train_df['sp_label'] = train_df.groupby('species').ngroup()
# train_df.head()

**Stratified KFold**

In [None]:
# N_SPLITS = 10

# skfolds = StratifiedKFold(n_splits=N_SPLITS)

# for n_fold, (train_data, valid_data) in enumerate(skfolds.split(X=train_df, y=train_df['individual_id'])):
# train_df.loc[train_data, 'kfold:'+str(n_fold)] = 'train'
# train_df.loc[valid_data, 'kfold:'+str(n_fold)] = 'valid'
    
# train_df.to_csv('train.csv', index=False)

# 🐬 Dataset

In [None]:
class HappyWhaleDataset(Dataset):
    """Image dataset driven by a DataFrame with `image`, `sp_label` and `box` columns.

    Parameters
    ----------
    df : pd.DataFrame
        One row per sample. `image` is the file name inside `img_dir`,
        `sp_label` is the integer target that is returned, and `box` is either
        the string ``'nan'`` or a 4-item ``[xmin, ymin, xmax, ymax]`` sequence.
    img_dir : str
        Directory containing the image files.
    transform : callable, optional
        Called as ``transform(image=ndarray)['image']`` (albumentations style).
    crop_p : float
        Probability of cropping the image to its box before the transform.
    """

    def __init__(self, df:pd.DataFrame, img_dir:str, transform=None, crop_p=0.6):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform
        self.crop_p = crop_p

    def __getitem__(self, idx):
        """Return ``(image, label)`` tensors for the sample at position `idx`."""
        # Positional (.iloc) access: a DataLoader hands out positions in
        # [0, len), which only coincide with index labels when the frame has a
        # fresh 0..n-1 index. Label-based lookup would raise KeyError or fetch
        # the wrong row for a filtered frame that was not re-indexed.
        img_name = self.df['image'].iloc[idx]
        # label = self.df['label'].iloc[idx]  # individual-id target (disabled)
        label = self.df['sp_label'].iloc[idx]

        img_path = os.path.join(self.img_dir, img_name)
        # Context manager so the underlying file handle is closed right away
        # instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))

        if self.crop_p > 0 and random.uniform(0,1) < self.crop_p:
            box = self.df['box'].iloc[idx]
            # 'nan' is the sentinel stored for rows without a usable box.
            if box != 'nan':
                xmin, ymin, xmax, ymax = box
                image = image[ymin:ymax, xmin:xmax, :]

        if self.transform is not None:
            image = self.transform(image=image)['image']

        # HWC -> CHW, the layout expected by the convolutional backbones.
        image = torch.tensor(image.transpose(2, 0, 1))
        label = torch.tensor(label)

        return image, label

    def __len__(self):
        """Number of rows in the backing DataFrame."""
        return len(self.df)

# 🐬 Data Augmentation

**1.**

In [None]:
from albumentations import Compose,OneOf,SmallestMaxSize, Resize,Normalize,HorizontalFlip,ShiftScaleRotate,CLAHE,RandomBrightnessContrast,Emboss,HueSaturationValue,RandomResizedCrop,MotionBlur,MedianBlur,GaussianBlur,GaussNoise,CoarseDropout,Sharpen

In [None]:
# Albumentations pipelines keyed by usage. Every pipeline first scales the
# shorter image side to `img_size` and ends with a Normalize step that uses
# the same mean/std triple.
transforms = {
    # Training: random crop plus colour, blur, noise, flip, affine and dropout augmentations.
    'train': Compose([
        SmallestMaxSize(max_size=img_size),
        RandomResizedCrop(img_size, img_size, scale=(0.75, 1.0)),
        RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=0.5),
        OneOf([Sharpen(), Emboss(), CLAHE()], p=0.5),
        HueSaturationValue(),
        OneOf([MotionBlur(), MedianBlur(), GaussianBlur()], p=0.5),
        GaussNoise(),
        HorizontalFlip(p=0.5),
        ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=30, p=0.5),
        CoarseDropout(max_holes=2, max_height=(img_size//10), max_width=(img_size//10), min_holes=1, min_height=(img_size//20), min_width=(img_size//20), p=0.3),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),

    # Validation: no random operations, only a resize to a square and normalisation.
    'valid': Compose([
        SmallestMaxSize(max_size=img_size),
        Resize(img_size, img_size),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),

    # Test-time augmentation: a lighter random crop / flip / rotate pipeline.
    'TTA': Compose([
        SmallestMaxSize(max_size=img_size),
        RandomResizedCrop(img_size, img_size, scale=(0.8, 1.0)),
        HorizontalFlip(p=0.5),
        ShiftScaleRotate(rotate_limit=15, p=0.4),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
}

**2.MixUp**

In [None]:
def mixup_data(input, label, alpha=0.2):
    """Blend each sample of a batch with a randomly chosen partner (MixUp).

    Parameters
    ----------
    input : torch.Tensor
        Batch of inputs; the first dimension is the batch dimension.
    label : torch.Tensor
        Targets aligned with `input` along the first dimension.
    alpha : float
        Both shape parameters of the Beta distribution the mixing weight is
        drawn from. Must be > 0 (``np.random.beta`` rejects other values).

    Returns
    -------
    (mixed_input, label, permuted_label, lam)
        `lam` is a Python float in [0.5, 1]; it weights the original sample
        and ``1 - lam`` weights the partner.
    """
    p = np.random.beta(alpha, alpha)
    # Keep the original sample dominant so `label` stays the primary target.
    lam = float(max(p, 1.0 - p))

    batch_size = input.size(0)
    # Put the permutation on the tensor's own device instead of relying on a
    # module-level `device` global that may not match (or may not exist).
    index = torch.randperm(batch_size).to(input.device)

    mixed_input = lam * input + (1 - lam) * input[index]

    return mixed_input, label, label[index], lam

**3.CutMix**

In [None]:
def rand_bbox(size, lam):
    """Draw a random box whose area fraction is roughly ``1 - lam``.

    `size` is indexed as ``size[2]`` (called W here) and ``size[3]`` (called
    H here). A centre is sampled uniformly, a box with side lengths
    ``int(W * sqrt(1 - lam))`` x ``int(H * sqrt(1 - lam))`` is placed around
    it, and the corners are clipped to ``[0, W]`` / ``[0, H]``.

    Returns
    -------
    (bbx1, bby1, bbx2, bby2)
        Clipped lower/upper bounds along the W axis and the H axis.
    """
    extent_w, extent_h = size[2], size[3]
    side_ratio = np.sqrt(1. - lam)

    # Half side lengths of the unclipped box.
    half_w = int(extent_w * side_ratio) // 2
    half_h = int(extent_h * side_ratio) // 2

    # Centre: W coordinate is drawn first, then H (the draw order matters for
    # reproducibility under a fixed NumPy seed).
    centre_w = np.random.randint(extent_w)
    centre_h = np.random.randint(extent_h)

    lo_w = np.clip(centre_w - half_w, 0, extent_w)
    lo_h = np.clip(centre_h - half_h, 0, extent_h)
    hi_w = np.clip(centre_w + half_w, 0, extent_w)
    hi_h = np.clip(centre_h + half_h, 0, extent_h)

    return lo_w, lo_h, hi_w, hi_h

In [None]:
def cutmix_data(input, label, beta=1.0):
    """Paste a random rectangle from shuffled samples into each sample (CutMix).

    NOTE: `input` is modified in place; pass a clone if the caller still needs
    the unmixed batch.

    Parameters
    ----------
    input : torch.Tensor
        4-D batch; the box is cut along dimensions 2 and 3.
    label : torch.Tensor
        Targets aligned with `input` along the first dimension. They are
        blended arithmetically, so this presumably expects one-hot / soft
        label vectors rather than class indices -- confirm with the caller.
    beta : float
        Both shape parameters of the Beta distribution used to draw the
        initial mixing ratio.

    Returns
    -------
    (input, label)
        The patched batch and the area-weighted blend of the two label sets.
    """
    lam = np.random.beta(beta, beta)
    batch_size = input.size(0)
    # Follow the input tensor's device rather than a module-level global.
    rand_index = torch.randperm(batch_size).to(input.device)

    label_a = label
    # Index only the batch dimension so labels of any rank are accepted.
    label_b = label[rand_index]

    bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
    input[:, :, bbx1:bbx2, bby1:bby2] = input[rand_index, :, bbx1:bbx2, bby1:bby2]

    # Recompute the ratio from the clipped box so it matches the pasted area.
    lam = float(1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2])))
    label = lam * label_a + (1 - lam) * label_b

    return input, label

# 🐬 Model

**GeM Pooling**

In [None]:
class GeM(nn.Module):
    """Generalized-mean pooling over the last two dimensions.

    Computes ``mean(clamp(x, eps) ** p) ** (1 / p)`` over the full spatial
    window, with `p` stored as a learnable parameter.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Learnable pooling exponent.
        self.p = nn.Parameter(torch.tensor(p, dtype=torch.float))

    def forward(self, inputs):
        """Pool `inputs` down to a 1x1 spatial map."""
        # Kernel covers the whole feature map, so the output is 1x1.
        window = (inputs.shape[-2], inputs.shape[-1])
        # Clamp first so the power is taken of strictly positive values.
        powered = torch.pow(torch.clamp(inputs, min=self.eps), self.p)
        averaged = F.avg_pool2d(powered, window)
        return torch.pow(averaged, 1./self.p)

**Cos Product**

In [None]:
class CosProduct(nn.Module):
    """Cosine-similarity head: one learnable prototype vector per class."""

    def __init__(self, embedding_size, num_classes):
        super().__init__()
        # One row per class, Xavier-uniform initialised.
        self.weight = nn.Parameter(torch.FloatTensor(num_classes, embedding_size))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        """Return the cosine between every row of `x` and every class row."""
        # Both operands are L2-normalised along dim 1, so the linear map
        # yields cosines.
        unit_x = F.normalize(x, dim=1)
        unit_w = F.normalize(self.weight, dim=1)
        return F.linear(unit_x, unit_w)

**1.Efficientnet backbone**

In [None]:
class EffWhaleNet(nn.Module):
    """timm backbone -> GeM pooling -> linear embedding -> cosine class head.

    Parameters
    ----------
    backbone : str
        Model name passed to ``timm.create_model``. The model must expose a
        ``classifier`` layer and a ``global_pool`` attribute; both are
        replaced by ``nn.Identity`` so the backbone returns its feature map.
    num_classes : int
        Number of rows in the cosine head.
    embedding_size : int
        Width of the embedding fed to the cosine head.
    backbone_pretrained : bool
        Forwarded to ``timm.create_model(pretrained=...)``.
    """

    def __init__(self, backbone, num_classes, embedding_size=512, backbone_pretrained=True):
        super(EffWhaleNet, self).__init__()
        self.pretrained = backbone_pretrained
        self.backbone = timm.create_model(backbone, pretrained=self.pretrained)
        # Read the feature width before the classifier is stripped.
        backbone_features = self.backbone.classifier.in_features
        self.backbone.global_pool = nn.Identity()
        self.backbone.classifier = nn.Identity()
        self.GeM = GeM()
        self.Embedding = nn.Linear(backbone_features, embedding_size)
        self.CosProduct = CosProduct(embedding_size, num_classes)

    def forward(self, x, only_feature=False):
        """Return ``(cosine_logits, embedding)`` or only the embedding.

        When `only_feature` is true the class head is skipped entirely, so
        feature extraction does not pay for a matmul against every class row.
        """
        x = self.backbone(x)
        x = self.GeM(x).flatten(1)
        feature = self.Embedding(x)
        if only_feature:
            return feature

        outputs = self.CosProduct(feature)
        return outputs, feature

**2.EffnetV2 backbone + U2Net (input channels = 4)**

In [None]:
class SegEffWhaleNet(nn.Module):
    """Two-pass classifier: full image plus U2Net-masked image, shared backbone.

    A frozen U2NET predicts a map that is min-max normalised per sample and
    thresholded into a binary mask. The image and the masked image both go
    through the same backbone, GeM pooling and embedding; the two embeddings
    are concatenated, merged by a linear layer and scored by the cosine head.

    Parameters
    ----------
    backbone : str
        Model name for ``timm.create_model``; must expose ``classifier`` and
        ``global_pool`` attributes.
    num_classes : int
        Number of classes of the cosine head.
    embedding_size : int
        Width of each embedding and of the merged feature.
    unet_path : str or None
        Path to a U2NET state dict. ``None`` keeps U2NET's own initialisation.
    backbone_pretrained : bool
        Forwarded to ``timm.create_model(pretrained=...)``.
    """

    def __init__(self, backbone, num_classes, embedding_size=512, unet_path=None, backbone_pretrained=True):
        super(SegEffWhaleNet, self).__init__()
        self.THRESHOLD = 0.3
        self.Unet = U2NET()

        # Only load weights when a path is given: the default unet_path=None
        # would otherwise crash inside torch.load. map_location='cpu' lets a
        # GPU-saved checkpoint load on CPU-only hosts; load_state_dict copies
        # the values onto whatever device the module lives on.
        if unet_path is not None:
            self.Unet.load_state_dict(torch.load(unet_path, map_location='cpu'))

        # The segmentation net is used as a fixed mask generator.
        for param in self.Unet.parameters():
            param.requires_grad = False

        self.backbone = timm.create_model(backbone, pretrained=backbone_pretrained)
        backbone_features = self.backbone.classifier.in_features
        self.backbone.global_pool = nn.Identity()
        self.backbone.classifier = nn.Identity()
        self.GeM = GeM()
        self.Embedding = nn.Sequential(nn.Linear(backbone_features, embedding_size), nn.Dropout(p=0.2))
        self.FeatMerge = nn.Linear(embedding_size*2, embedding_size)
        self.CosProduct = CosProduct(embedding_size, num_classes)

    def forward(self, x):
        """Return ``(cosine_logits, merged_feature)`` for a batch of images."""
        pred = self.unet_pred(x)
        pred = self.normPRED(pred)
        # Binary mask: 1 where the normalised prediction exceeds the threshold.
        mask = torch.where(pred > self.THRESHOLD, 1, 0)
        s = x * mask

        # Pass 1: the unmasked image.
        x = self.backbone(x)
        x = self.GeM(x).flatten(1)
        feat_x = self.Embedding(x)

        # Pass 2: the masked image through the same (shared) layers.
        s = self.backbone(s)
        s = self.GeM(s).flatten(1)
        feat_s = self.Embedding(s)

        feature = torch.cat([feat_x, feat_s], dim=1)
        feature = self.FeatMerge(feature)

        outputs = self.CosProduct(feature)

        return outputs, feature

    def unet_pred(self, x):
        """Run the frozen U2NET in eval mode and return its first output, detached."""
        # Forced to eval on every call because model.train() on the parent
        # would otherwise flip the sub-module back to training mode.
        self.Unet.eval()
        with torch.no_grad():
            d1,d2,d3,d4,d5,d6,d7 = self.Unet(x)
            pred = d1.detach()

            del d1,d2,d3,d4,d5,d6,d7

        return pred

    def normPRED(self, d):
        """Min-max normalise `d` to [0, 1] independently for every sample.

        A constant map gives 0/0 = NaN; NaN fails the ``> THRESHOLD``
        comparison in forward(), so such a sample ends up with an all-zero mask.
        """
        batch_size = d.size()[0]

        dmin = d.reshape(batch_size, -1).min(dim=1, keepdim=True).values
        dmin = dmin.reshape(batch_size, 1, 1, 1)

        dmax = d.reshape(batch_size, -1).max(dim=1, keepdim=True).values
        dmax = dmax.reshape(batch_size, 1, 1, 1)

        return (d - dmin) / (dmax - dmin)

**3.Convnext backbone + U2Net**

In [None]:
class SegConvWhaleNet(nn.Module):
    """ConvNeXt-style variant of the two-pass (image + U2Net-masked image) classifier.

    The backbone's ``head.global_pool`` is replaced by GeM and its ``head.fc``
    by ``nn.Identity``, so calling the backbone returns pooled features
    directly. The image and the masked image share the backbone and the
    embedding layer; their embeddings are concatenated, merged by a linear
    layer and scored by the cosine head.

    Parameters
    ----------
    backbone : str
        Model name for ``timm.create_model``; must expose ``get_classifier()``
        and a ``head`` with ``global_pool`` / ``fc`` attributes.
    num_classes : int
        Number of classes of the cosine head.
    embedding_size : int
        Width of each embedding and of the merged feature.
    unet_path : str or None
        Path to a U2NET state dict. ``None`` keeps U2NET's own initialisation.
    backbone_pretrained : bool
        Forwarded to ``timm.create_model(pretrained=...)``.
    """

    def __init__(self, backbone, num_classes, embedding_size=512, unet_path=None, backbone_pretrained=True):
        super(SegConvWhaleNet, self).__init__()
        self.THRESHOLD = 0.3
        self.Unet = U2NET()

        if unet_path is not None:
            # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only
            # hosts; load_state_dict copies the values onto the module.
            self.Unet.load_state_dict(torch.load(unet_path, map_location='cpu'))

        # The segmentation net is used as a fixed mask generator.
        for param in self.Unet.parameters():
            param.requires_grad = False

        self.backbone = timm.create_model(backbone, pretrained=backbone_pretrained)
        # Read the feature width before the classifier is replaced.
        backbone_features = self.backbone.get_classifier().in_features
        self.backbone.head.global_pool = GeM()
        self.backbone.head.fc = nn.Identity()
        self.Embedding = nn.Linear(backbone_features, embedding_size)
        self.Feat_Merge = nn.Linear(embedding_size * 2, embedding_size)
        self.CosProduct = CosProduct(embedding_size, num_classes)

    def forward(self, x):
        """Return ``(cosine_logits, merged_feature)`` for a batch of images."""
        pred = self.unet_pred(x)
        pred = self.normPRED(pred)
        # Binary mask: 1 where the normalised prediction exceeds the threshold.
        mask = torch.where(pred > self.THRESHOLD, 1, 0)
        s = x * mask

        # Pass 1: the unmasked image.
        x = self.backbone(x)
        feat_x = self.Embedding(x)

        # Pass 2: the masked image through the same (shared) layers.
        s = self.backbone(s)
        feat_s = self.Embedding(s)

        feature = torch.cat([feat_x, feat_s], dim=1)
        feature = self.Feat_Merge(feature)

        outputs = self.CosProduct(feature)

        return outputs, feature

    def unet_pred(self, x):
        """Run the frozen U2NET in eval mode and return its first output, detached."""
        # Forced to eval on every call because model.train() on the parent
        # would otherwise flip the sub-module back to training mode.
        self.Unet.eval()
        with torch.no_grad():
            d1,d2,d3,d4,d5,d6,d7 = self.Unet(x)
            pred = d1.detach()

            del d1,d2,d3,d4,d5,d6,d7

        return pred

    def normPRED(self, d):
        """Min-max normalise `d` to [0, 1] independently for every sample.

        A constant map gives 0/0 = NaN; NaN fails the ``> THRESHOLD``
        comparison in forward(), so such a sample ends up with an all-zero mask.
        """
        batch_size = d.size()[0]

        dmin = d.reshape(batch_size, -1).min(dim=1, keepdim=True).values
        dmin = dmin.reshape(batch_size, 1, 1, 1)

        dmax = d.reshape(batch_size, -1).max(dim=1, keepdim=True).values
        dmax = dmax.reshape(batch_size, 1, 1, 1)

        return (d - dmin) / (dmax - dmin)

In [None]:
class SegConvWhaleNetV2(nn.Module):
    """SegConvWhaleNet variant whose embedding layer is followed by dropout (p=0.2).

    The backbone's ``head.global_pool`` is replaced by GeM and its ``head.fc``
    by ``nn.Identity``. The image and its U2Net-masked copy share the backbone
    and the embedding; the two embeddings are concatenated, merged by
    ``FeatMerge`` and scored by the cosine head.

    Parameters
    ----------
    backbone : str
        Model name for ``timm.create_model``; must expose ``get_classifier()``
        and a ``head`` with ``global_pool`` / ``fc`` attributes.
    num_classes : int
        Number of classes of the cosine head.
    embedding_size : int
        Width of each embedding and of the merged feature.
    unet_path : str or None
        Path to a U2NET state dict. ``None`` keeps U2NET's own initialisation.
    backbone_pretrained : bool
        Forwarded to ``timm.create_model(pretrained=...)``.
    """

    def __init__(self, backbone, num_classes, embedding_size=512, unet_path=None, backbone_pretrained=True):
        super(SegConvWhaleNetV2, self).__init__()
        self.THRESHOLD = 0.3
        self.Unet = U2NET()

        if unet_path is not None:
            # map_location='cpu' lets a GPU-saved checkpoint load on CPU-only
            # hosts; load_state_dict copies the values onto the module.
            self.Unet.load_state_dict(torch.load(unet_path, map_location='cpu'))

        # The segmentation net is used as a fixed mask generator.
        for param in self.Unet.parameters():
            param.requires_grad = False

        self.backbone = timm.create_model(backbone, pretrained=backbone_pretrained)
        # Read the feature width before the classifier is replaced.
        backbone_features = self.backbone.get_classifier().in_features
        self.backbone.head.global_pool = GeM()
        self.backbone.head.fc = nn.Identity()
        self.Embedding = nn.Sequential(nn.Linear(backbone_features, embedding_size), nn.Dropout(p=0.2))
        self.FeatMerge = nn.Linear(embedding_size * 2, embedding_size)
        self.CosProduct = CosProduct(embedding_size, num_classes)

    def forward(self, x):
        """Return ``(cosine_logits, merged_feature)`` for a batch of images."""
        pred = self.unet_pred(x)
        pred = self.normPRED(pred)
        # Binary mask: 1 where the normalised prediction exceeds the threshold.
        mask = torch.where(pred > self.THRESHOLD, 1, 0)
        s = x * mask

        # Pass 1: the unmasked image.
        x = self.backbone(x)
        feat_x = self.Embedding(x)

        # Pass 2: the masked image through the same (shared) layers.
        s = self.backbone(s)
        feat_s = self.Embedding(s)

        feature = torch.cat([feat_x, feat_s], dim=1)
        # Must use the attribute name created in __init__ (`FeatMerge`);
        # referring to `Feat_Merge` here raised AttributeError on every call.
        feature = self.FeatMerge(feature)

        outputs = self.CosProduct(feature)

        return outputs, feature

    def unet_pred(self, x):
        """Run the frozen U2NET in eval mode and return its first output, detached."""
        # Forced to eval on every call because model.train() on the parent
        # would otherwise flip the sub-module back to training mode.
        self.Unet.eval()
        with torch.no_grad():
            d1,d2,d3,d4,d5,d6,d7 = self.Unet(x)
            pred = d1.detach()

            del d1,d2,d3,d4,d5,d6,d7

        return pred

    def normPRED(self, d):
        """Min-max normalise `d` to [0, 1] independently for every sample.

        A constant map gives 0/0 = NaN; NaN fails the ``> THRESHOLD``
        comparison in forward(), so such a sample ends up with an all-zero mask.
        """
        batch_size = d.size()[0]

        dmin = d.reshape(batch_size, -1).min(dim=1, keepdim=True).values
        dmin = dmin.reshape(batch_size, 1, 1, 1)

        dmax = d.reshape(batch_size, -1).max(dim=1, keepdim=True).values
        dmax = dmax.reshape(batch_size, 1, 1, 1)

        return (d - dmin) / (dmax - dmin)

# 🐬 EMA

**EMA inside timm** https://fastai.github.io/timmdocs/training_modelEMA#Internals-of-Model-EMA-inside-timm

In [None]:
class ModelEmaV2(nn.Module):
    """Exponential moving average of a model's ``state_dict`` entries.

    Parameters
    ----------
    model : nn.Module
        Model to shadow. A deep copy is stored in ``self.module`` and switched
        to eval mode; the original is never modified.
    decay : float
        Weight of the running average in each update:
        ``ema = decay * ema + (1 - decay) * model``.
    device : optional
        If given, the copy is moved there and incoming values are moved to it
        before averaging.
    """

    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEmaV2, self).__init__()
        # Imported locally: the notebook's import cell does not provide
        # `deepcopy`, so the bare name raised NameError on construction.
        from copy import deepcopy

        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        """Overwrite each EMA state entry with ``update_fn(ema_value, model_value)``."""
        with torch.no_grad():
            # Both state dicts come from copies of the same architecture, so
            # their values are paired positionally.
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        """Blend the current values of `model` into the running average."""
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        """Copy the current values of `model` verbatim into the EMA copy."""
        self._update(model, update_fn=lambda e, m: m)

# 🐬 Loss Function

**1.CrossEntropyLoss with Label Smoothing**

In [None]:
class labelsmoothing_CELoss(nn.Module):
    """Cross-entropy against label-smoothed one-hot targets.

    Parameters
    ----------
    label_smoothing : float
        Mass in [0, 1) spread uniformly over all classes; the true class keeps
        ``1 - label_smoothing`` plus its uniform share.
    reduction : {'none', 'mean', 'sum'}
        'none' returns one loss per sample, the others reduce over the batch.
    """

    def __init__(self, label_smoothing=0.1, reduction='none'):
        super(labelsmoothing_CELoss, self).__init__()
        assert 0 <= label_smoothing < 1
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, x, label):
        """Compute the loss.

        `x` holds logits with the class axis last (softmax is taken over
        dim 1); `label` holds integer class indices.

        Raises
        ------
        ValueError
            If ``self.reduction`` is not one of 'none', 'mean', 'sum'.
        """
        num_classes = x.shape[-1]
        logprobs = F.log_softmax(x, dim=1)
        target = F.one_hot(label, num_classes)
        target = target * (1 - self.label_smoothing) + self.label_smoothing / num_classes
        per_sample = (-1 * target * logprobs).sum(dim=1).reshape(-1)

        # Compare strings by value: `is` only works when both operands happen
        # to be the same interned object and is a SyntaxWarning on Python 3.8+.
        if self.reduction == 'none':
            return per_sample
        if self.reduction == 'mean':
            return per_sample.mean()
        if self.reduction == 'sum':
            return per_sample.sum()
        # Previously an unknown reduction fell through and returned None.
        raise ValueError(f"unsupported reduction: {self.reduction!r}")

**2.ArcLoss**

ArcFace https://arxiv.org/pdf/1801.07698v1.pdf

In [None]:
class ArcLoss(nn.Module):
    """Additive angular margin (ArcFace-style) loss on cosine logits.

    Parameters
    ----------
    m : float
        Angular margin in radians added to the target-class angle.
    s : float
        Scale applied to the margin-adjusted cosines before cross-entropy.
    easy_margin : bool
        If true, the margin is applied only where ``cosine > 0``; otherwise it
        is applied where ``cosine > cos(pi - m)`` and ``cosine - mm`` is used
        elsewhere.
    """

    def __init__(self, m=0.5, s=30, easy_margin=False):
        super(ArcLoss, self).__init__()
        self.m = m
        self.s = s
        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)
        self.easy_margin = easy_margin
        # Threshold beyond which theta + m would exceed pi.
        self.th = math.cos(math.pi - m)
        # Penalty subtracted from the cosine in that fallback region.
        self.mm = math.sin(math.pi - m) * m

    def forward(self, cosine, labels):
        """Return the per-sample loss for cosine logits and integer labels."""
        num_classes = cosine.size()[-1]
        # Rounding can push |cosine| marginally above 1, which made
        # 1 - cosine^2 negative and the square root NaN; clamp before sqrt.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # Margin-adjusted value for the target class, plain cosine elsewhere.
        one_hot = F.one_hot(labels, num_classes)
        outputs = (one_hot * phi + (1.0 - one_hot) * cosine) * self.s

        # Per-sample label-smoothed cross-entropy with its default settings.
        loss = labelsmoothing_CELoss()(outputs, labels)

        return loss

**3.Focal Loss**

In [None]:
class FocalLoss(nn.Module):
    """Focal-style loss on scaled logits with optional label smoothing.

    Parameters
    ----------
    alpha : float
        Constant factor applied to every loss term.
    gamma : float
        Exponent of the modulating factor.
    reduction : {'none', 'mean', 'sum'}
        'none' returns one loss per sample.
    label_smoothing : float
        Smoothing mass spread uniformly over the classes (0 disables it).
    """

    def __init__(self, alpha = 0.25, gamma = 2.0, reduction='none', label_smoothing=0.1):
        super(FocalLoss, self).__init__()
        # Fixed scale applied to the incoming logits before softmax.
        self.s = 20
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction
        self.label_smoothing = label_smoothing

    def forward(self, x, label):
        """Compute the loss for logits `x` and integer class labels.

        Raises
        ------
        ValueError
            If ``self.reduction`` is not one of 'none', 'mean', 'sum'.
        """
        num_classes = x.size()[-1]
        x_s = x * self.s
        probs = F.softmax(x_s, dim=1)
        logprobs = F.log_softmax(x_s, dim=1)
        one_hot = F.one_hot(label, num_classes)
        # Modulating base: (1 - p) for the target class, p for the others.
        # It is built from the hard one-hot before smoothing is applied.
        pow_down = one_hot * (1 - probs) + (1.0 - one_hot) * probs
        if self.label_smoothing > 0:
            one_hot = one_hot * (1 - self.label_smoothing) + self.label_smoothing / num_classes
        CELoss = -1 * logprobs * one_hot
        loss = self.alpha * torch.pow(pow_down, self.gamma) * CELoss

        loss = loss.sum(dim=1)

        # Compare strings by value, not identity (`is` is unreliable here).
        if self.reduction == 'none':
            return loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        # Previously an unknown reduction fell through and returned None.
        raise ValueError(f"unsupported reduction: {self.reduction!r}")

**4.Asymmetric Loss**

In [None]:
class AsymmetricLossSingleLabel(nn.Module):
    """Asymmetric loss for single-label classification on scaled logits.

    Parameters
    ----------
    gamma_pos : float
        Focusing exponent applied to the target class.
    gamma_neg : float
        Focusing exponent applied to all non-target classes.
    eps : float
        Label-smoothing mass (0 disables smoothing).
    reduction : str
        'mean' or 'sum' reduce over the batch; any other value returns the
        per-sample losses.
    """

    def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='none'):
        super(AsymmetricLossSingleLabel, self).__init__()
        # Fixed scale applied to the incoming logits before log-softmax.
        self.s = 20
        self.eps = eps
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        self.targets_classes = []  # prevent gpu repeated memory allocation
        self.gamma_pos = gamma_pos
        self.gamma_neg = gamma_neg
        self.reduction = reduction

    def forward(self, inputs, target):
        """Compute the loss.

        Parameters
        ----------
        inputs : input logits, class axis last
        target : integer class indices (converted to one-hot internally)
        """

        num_classes = inputs.size()[-1]
        log_preds = self.logsoftmax(inputs*self.s)
        self.targets_classes = F.one_hot(target, num_classes)

        # ASL weights
        targets = self.targets_classes
        anti_targets = 1 - targets
        xs_pos = torch.exp(log_preds)
        xs_neg = 1 - xs_pos
        xs_pos = xs_pos * targets
        xs_neg = xs_neg * anti_targets
        # Base is (1 - p) on the target class and p elsewhere; the exponent
        # switches between gamma_pos and gamma_neg accordingly.
        asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
                                 self.gamma_pos * targets + self.gamma_neg * anti_targets)
        log_preds = log_preds * asymmetric_w

        if self.eps > 0:  # label smoothing
            self.targets_classes = self.targets_classes * (1 - self.eps) + self.eps / num_classes

        # loss calculation
        loss = - self.targets_classes.mul(log_preds)

        loss = loss.sum(dim=-1)
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            # 'sum' used to be ignored silently and returned per-sample
            # losses; honour it like the other losses in this file do.
            loss = loss.sum()

        return loss

**5.MetricLoss**

In [None]:
class MetricLoss(nn.Module):
    """Pairwise similarity loss between the embeddings of a batch.

    The batch-by-batch cosine matrix of the embeddings is pushed towards a
    target matrix that is 1 for same-label pairs and 0 otherwise; same-label
    pairs have `margin` subtracted from their similarity first.

    Parameters
    ----------
    num_classes : int
        Size of the one-hot encoding used to build the target matrix.
    margin : float
        Amount subtracted from the similarity of same-label pairs.
    reduction : {'none', 'mean', 'sum'}
        'none' returns one value per sample (row sums of the squared error).
    """

    def __init__(self, num_classes, margin=0.2, reduction='none'):
        super(MetricLoss, self).__init__()
        self.num_classes = num_classes
        self.reduction = reduction
        self.margin = margin

    def forward(self, embed, label):
        """Compute the loss for embeddings `embed` and integer labels.

        Raises
        ------
        ValueError
            If ``self.reduction`` is not one of 'none', 'mean', 'sum'.
        """
        embed = F.normalize(embed, dim=1)
        onehot = F.one_hot(label, self.num_classes).float()
        onehot = F.normalize(onehot, dim=1)
        # target[i, j] is 1 when samples i and j share a label, else 0.
        target = F.linear(onehot, onehot)
        # correl[i, j] is the cosine similarity of embeddings i and j.
        correl = F.linear(embed, embed)
        correl = correl - target * self.margin
        # Negative similarities are treated as 0.
        correl = torch.clamp(correl, min=0.0)

        diff = torch.pow((correl - target), 2)
        loss = diff.sum(dim=1)

        # Compare strings by value, not identity (`is` is unreliable here).
        if self.reduction == 'none':
            return loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        # Previously an unknown reduction fell through and returned None.
        raise ValueError(f"unsupported reduction: {self.reduction!r}")

**Mixed Loss**

In [None]:
class MixedLoss(nn.Module):
    """Weighted sum of three loss terms.

    Parameters
    ----------
    ArcLoss, AsyLoss : callable
        Each is called as ``loss(x, label)``.
    MetricLoss : callable
        Called as ``loss(embed, label)``.
    w1, w2, w3 : float
        Weights of the three terms, in the order above.
    reduction : {'none', 'mean', 'sum'}
        Applied to the weighted sum; 'none' returns it unchanged.
    """

    def __init__(self, ArcLoss, AsyLoss, MetricLoss, w1 = 1.0, w2 = 1.0, w3 = 1.0, reduction='none'):
        super(MixedLoss, self).__init__()
        self.ArcLoss = ArcLoss
        self.AsyLoss = AsyLoss
        self.MetricLoss = MetricLoss
        self.w1 = w1
        self.w2 = w2
        self.w3 = w3
        self.reduction = reduction

    def forward(self, x, embed, label):
        """Combine the three terms for logits `x`, embeddings `embed` and `label`.

        Raises
        ------
        ValueError
            If ``self.reduction`` is not one of 'none', 'mean', 'sum'.
        """
        loss = self.w1 * self.ArcLoss(x, label) + self.w2 * self.AsyLoss(x, label) + self.w3 * self.MetricLoss(embed, label)

        # Compare strings by value, not identity (`is` is unreliable here).
        if self.reduction == 'none':
            return loss
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        # Previously an unknown reduction fell through and returned None.
        raise ValueError(f"unsupported reduction: {self.reduction!r}")
        

# 🐬 Animator

**Plot loss and acc**

In [None]:
class Animator:
    """Two-panel live plot: even-indexed curves on the left axis, odd on the right.

    With the default legend this places the two loss curves on the left panel
    and the two accuracy curves on the right one.
    """

    def __init__(self, legend=None, xlim=None, ylim=[[0,None],[0,1]], fmts=('b-', 'y-', 'r-', 'g-'), figsize=(5, 5)):
        self.w, self.h = figsize
        # One row, two panels; the figure is twice as wide as a single panel.
        self.fig, self.axes = plt.subplots(1, 2, figsize=(self.w * 2, self.h))
        self.legend = ['train_loss','train_acc','valid_loss','valid_acc'] if legend is None else legend
        # Per-curve x / y histories, allocated on the first add() call.
        self.X = None
        self.Y = None
        self.xlim = xlim
        self.ylim = ylim
        self.fmts = fmts

    def log(self):
        """Return the recorded y-histories (None before the first add())."""
        return self.Y

    def add(self, x, y):
        """Append one point per curve and redraw both panels.

        `y` may be a scalar or a sequence with one value per curve; a scalar
        `x` is shared by all curves.
        """
        if not hasattr(y, "__len__"):
            y = [y]
        n_curves = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n_curves

        # Lazily create one history list per curve.
        if self.X is None:
            self.X = [[] for _ in range(n_curves)]
        if self.Y is None:
            self.Y = [[] for _ in range(n_curves)]

        for idx, (x_val, y_val) in enumerate(zip(x, y)):
            self.X[idx].append(x_val)
            self.Y[idx].append(y_val)

        # Redraw from scratch.
        for panel in (self.axes[0], self.axes[1]):
            panel.cla()

        for idx, (xs, ys, fmt) in enumerate(zip(self.X, self.Y, self.fmts)):
            panel = self.axes[idx % 2]
            panel.plot(xs, ys, fmt, label=self.legend[idx])
            panel.legend(loc='best')
            panel.set_xlabel('Epoch')
            panel.grid(True)

        if self.xlim is not None:
            self.axes[0].set_xlim(self.xlim)
            self.axes[1].set_xlim(self.xlim)

        if self.ylim is not None:
            self.axes[0].set_ylim(self.ylim[0])
            self.axes[1].set_ylim(self.ylim[1])

        display.display(self.fig) # add display.clear_output(wait=True) to clear output

# 🐬 Trainer

**Accumulator used to accumulate loss, acc, num of elements**

In [None]:
class Accumulator:
    """Keeps `n` running float totals that are updated together."""

    def __init__(self, n):
        self.data = [0.0 for _ in range(n)]

    def add(self, *args):
        """Add each argument (converted to float) to the total in the same slot."""
        updated = []
        # zip() stops at the shorter sequence, exactly like the totals list
        # being rebuilt from the pairs.
        for total, increment in zip(self.data, args):
            updated.append(total + float(increment))
        self.data = updated

    def reset(self):
        """Zero every total while keeping the current number of slots."""
        self.data = [0.0 for _ in self.data]

    def __getitem__(self, idx):
        return self.data[idx]

In [None]:
class Trainer:
    """Training / validation loop with optional EMA weights and MixUp.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(input)``; must return ``(output, feature)``.
    train_dataloader : iterable with ``__len__``
        Yields ``(input, label)`` batches.
    optimizer : torch optimizer
        Stepped once per batch; also driven by a OneCycleLR schedule.
    criterion : callable
        Called as ``criterion(output, feature, label)``; should return
        per-sample losses (they are summed for logging and averaged for the
        backward pass).
    valid_dataloader : iterable with ``__len__``, optional
        When given, every epoch is followed by a validation pass and the best
        weights (by top-1 accuracy) are saved.
    use_EMA : bool
        Maintain an exponential moving average of the weights and use it for
        validation and checkpointing.
    """

    def __init__(self, model, train_dataloader, optimizer, criterion, valid_dataloader=None, use_EMA=True):
        self.best_acc = 0.0
        self.model = model
        self.use_EMA = use_EMA
        if self.use_EMA:
            # ModelEmaV2 exposes the averaged copy as `.module`, which train()
            # and evaluate() read (the older ModelEma class stores it under
            # `.ema`). The module-level `device` replaces a hard-coded 'cuda'.
            self.model_ema = timm.utils.ModelEmaV2(self.model, device=device)
        self.model.to(device)
        self.criterion = criterion
        self.optimizer = optimizer
        self.lr_scheduler = None
        self.train_dataloader = train_dataloader
        self.train_steps = len(train_dataloader)
        self.valid_dataloader = valid_dataloader
        if self.valid_dataloader is not None:
            self.valid_steps = len(valid_dataloader)


    def train(self, num_epochs, cut_mix_p=0.9, cut_mix_start=2, cut_mix_end = 28, max_lr=3e-4):
        """Run `num_epochs` epochs of training (plus validation when available).

        Parameters
        ----------
        num_epochs : int
            Number of epochs; also fixes the length of the one-cycle schedule.
        cut_mix_p : float
            Per-batch probability of applying MixUp inside the active window.
        cut_mix_start, cut_mix_end : int
            MixUp is active for epochs in ``[cut_mix_start, cut_mix_end)``.
        max_lr : float
            Peak learning rate of the one-cycle schedule.
        """
        # Schedule the optimizer this trainer owns, not a same-named global.
        self.lr_scheduler = OneCycleLR(self.optimizer, max_lr=max_lr, epochs=num_epochs, steps_per_epoch=self.train_steps,
                                       pct_start=0.3, div_factor=10, final_div_factor=10)

        animator = Animator(xlim=[0,num_epochs])
        for epoch in range(num_epochs):
            print(clr.S + f"Epoch : {epoch}" + clr.E)
            train_loss, train_acc = self.train_one_epoch(epoch, cut_mix_p, cut_mix_start, cut_mix_end)
            # Decide from the trainer's own state instead of a global `valid_set`.
            if self.valid_dataloader is not None:
                valid_loss, valid_acc = self.evaluate(epoch)
                animator.add(epoch, (train_loss, train_acc, valid_loss, valid_acc))
                if valid_acc > self.best_acc:
                    self.best_acc = valid_acc
                    print(clr.S + f'Found Higher Acc = {self.best_acc:.4f}. Saving State...' + clr.E)
                    if self.use_EMA:
                        torch.save(self.model_ema.module.state_dict(), f'HappyWhaleNet{epoch}.pth')
                    else:
                        torch.save(self.model.state_dict(), f'HappyWhaleNet{epoch}.pth')

        print(clr.S + "Train finished!" + clr.E)


    def train_one_epoch(self, epoch, cut_mix_p, cut_mix_start, cut_mix_end):
        """Train for one epoch; return ``(mean loss, top-1 accuracy)`` per sample."""
        print(clr.S + "--- Train ---" + clr.E)
        self.model.train()
        # Slots: summed loss, top-1 hits, top-5 hits, number of samples.
        metric = Accumulator(4)
        train_tqdm = notebook.tqdm(self.train_dataloader, total=self.train_steps)
        for i, data in enumerate(train_tqdm):
            input, label = data
            input = input.to(device)
            label = label.to(device).reshape(-1)

            if epoch >= cut_mix_start and random.uniform(0, 1) < cut_mix_p and epoch < cut_mix_end:
                # MixUp branch: the loss is the lam-weighted blend of the
                # losses against both label sets.
                input, label_a, label_b, lam = mixup_data(input, label, alpha=0.2)
                output, feature = self.model(input)
                loss = lam * self.criterion(output, feature, label_a) + (1 - lam) * self.criterion(output, feature, label_b)
            else:
                output, feature = self.model(input)
                loss = self.criterion(output, feature, label)

            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()
            # The one-cycle schedule advances once per batch.
            self.lr_scheduler.step()

            if self.use_EMA:
                self.model_ema.update(self.model)

            # Accuracy is always measured against the unmixed labels.
            metric.add(float(loss.sum()), self.compute_acc(output, label), self.compute_Top5_acc(output, label), label.numel())
            train_tqdm.set_postfix(loss = metric[0]/metric[3], acc = metric[1]/metric[3], top5acc = metric[2]/metric[3])


        return metric[0] / metric[3], metric[1] / metric[3]


    def evaluate(self, epoch):
        """Validate once; return ``(mean loss, top-1 accuracy)`` per sample.

        Uses the EMA weights when `use_EMA` is set, otherwise the live model.
        """
        print(clr.S + "--- Valid ---" + clr.E)
        self.model.eval()
        metric = Accumulator(4)
        valid_tqdm = notebook.tqdm(self.valid_dataloader, total=self.valid_steps)
        with torch.no_grad():
            for i, data in enumerate(valid_tqdm):
                input, label = data
                input = input.to(device)
                label = label.to(device)
                label = label.reshape(-1)

                if self.use_EMA:
                    output, feature = self.model_ema.module(input)
                else:
                    output, feature = self.model(input)

                loss = self.criterion(output, feature, label)
                metric.add(float(loss.sum()), self.compute_acc(output, label), self.compute_Top5_acc(output, label), label.numel())
                valid_tqdm.set_postfix(loss = metric[0]/metric[3], acc = metric[1]/metric[3], top5acc = metric[2]/metric[3])

        return metric[0] / metric[3], metric[1] / metric[3]


    def compute_acc(self, pred, label):
        """Return the number (not the ratio) of rows whose argmax equals `label`."""
        pred = pred.argmax(axis=1)
        cmp = pred.type(label.dtype) == label

        return float(cmp.type(label.dtype).sum())


    def compute_Top5_acc(self, pred, label):
        """Return the number of rows whose label is among the 5 highest scores."""
        top5 = torch.topk(pred, dim=1, k=5)
        # Broadcast each label against its row of five candidate indices.
        cmp = top5.indices == label.reshape(-1,1)

        return float(cmp.type(label.dtype).sum())

# 🐬 Prepare training

In [None]:
def DataProvider(train_set, valid_set=None, batch_size=32, num_workers=2):
    """Wrap the datasets in DataLoaders.

    Returns a dict with a shuffled 'train' loader and, when `valid_set` is
    given, an unshuffled 'valid' loader. Both share `batch_size` and
    `num_workers`.
    """
    shared = dict(batch_size=batch_size, num_workers=num_workers)

    loaders = {'train': torch.utils.data.DataLoader(train_set, shuffle=True, **shared)}
    if valid_set is None:
        return loaders

    loaders['valid'] = torch.utils.data.DataLoader(valid_set, shuffle=False, **shared)
    return loaders

In [None]:
def str2int(s):
    """Parse a whitespace-separated bounding-box string into four ints.

    Args:
        s: Text such as ``"xmin ymin xmax ymax"``. Any non-string value
            (for example the float NaN pandas uses for missing cells) is
            treated as "no box".

    Returns:
        list[int] | str: ``[xmin, ymin, xmax, ymax]``, or the sentinel string
        ``'nan'`` when ``s`` is not a string.

    Raises:
        ValueError: If the string does not hold exactly four integer fields.
    """
    if not isinstance(s, str):
        return 'nan'
    # split() without an argument tolerates repeated spaces, tabs and
    # leading/trailing whitespace, which split(' ') turned into empty fields.
    xmin, ymin, xmax, ymax = s.split()

    return [int(xmin), int(ymin), int(xmax), int(ymax)]

In [None]:
# Load the training table and turn each 'box' string into a list of four ints
# via str2int; non-string entries become the 'nan' sentinel string.
df = pd.read_csv('../input/kfolds-cropped-whale-train-csv/kfolds_cropped_whale_train.csv')
df['box'] = df['box'].apply(str2int)

# One class per distinct individual_id (groupby yields one group per id).
NUM_CLASSES = len(df.groupby('individual_id'))

print(f'Dataset size : {len(df)}')
print(f'Number of individuals : {NUM_CLASSES}')

In [None]:
# The 'kfold:0' column tags every row as 'train' or 'valid' for this fold.
fold_tag = df['kfold:0']
train_df = df[fold_tag == 'train'].reset_index(drop=True)
valid_df = df[fold_tag == 'valid'].reset_index(drop=True)

# NOTE(review): only the first 5000 training rows are kept; this looks like a
# quick-experiment cap, confirm before launching a full run.
train_df = train_df.iloc[:5000]

# Both datasets share the image directory and crop probability; only the
# frame and the transform key differ.
train_set, valid_set = (
    HappyWhaleDataset(frame, TRAIN_IMG_DIR, transform=transforms[split], crop_p=1.0)
    for frame, split in ((train_df, 'train'), (valid_df, 'valid'))
)

print(f'train_set size : {len(train_set)}')
print(f'valid_set size : {len(valid_set)}')

In [None]:
# Build the three loss terms that MixedLoss combines.
arcloss = ArcLoss()
asyloss = AsymmetricLossSingleLabel()
metricloss = MetricLoss(num_classes=NUM_CLASSES)

# All three weights are 1; presumably w1/w2/w3 pair with ArcLoss/AsyLoss/MetricLoss
# in that order. TODO confirm against the MixedLoss definition.
mixedloss = MixedLoss(ArcLoss=arcloss, AsyLoss=asyloss, MetricLoss=metricloss, w1 = 1, w2 = 1, w3 = 1)

In [None]:
Pretrained_U2Net = r'../input/pretrained-u2net/u2net.pth'
# Alternative model that takes the U2Net checkpoint above (kept for reference):
# model = SegEffWhaleNetV2(backbone=cfg['Effnet_B0'], num_classes=NUM_CLASSES, embedding_size=512, unet_path=Pretrained_U2Net, backbone_pretrained=True)
model = EffWhaleNet(
    backbone=cfg['Effnet_B0'],
    num_classes=NUM_CLASSES,
    embedding_size=512,
    backbone_pretrained=True,
)

# Count only the parameters that will receive gradients.
n_parameters = 0
for weight in model.parameters():
    if weight.requires_grad:
        n_parameters += weight.numel()
# Floor division by 1e6 means the figure is reported in whole millions.
print(f'number of params : {n_parameters//1e6}M')

In [None]:
# Wrap both splits in DataLoaders.
provider = DataProvider(
    train_set,
    valid_set,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# Only parameters that still require gradients are handed to the optimizer.
optimizer = timm.optim.AdamP(
    [weight for weight in model.parameters() if weight.requires_grad],
    lr=3e-4,
    weight_decay=2e-5,
)

whale_trainer = Trainer(
    model,
    provider['train'],
    optimizer,
    criterion=mixedloss,
    valid_dataloader=provider['valid'],
    use_EMA=False,
)

In [None]:
# Train for 10 epochs. cut_mix_p is 0 here, which presumably disables CutMix for
# the whole run regardless of the start/end window. TODO confirm in Trainer.train.
whale_trainer.train(num_epochs = 10, cut_mix_p = 0, cut_mix_start = 0, cut_mix_end = 10)

In [None]:
# Freeze the backbone so its weights no longer receive gradients.
for backbone_weight in model.backbone.parameters():
    backbone_weight.requires_grad = False

# Recount the parameters that remain trainable after freezing.
n_parameters = 0
for weight in model.parameters():
    if weight.requires_grad:
        n_parameters += weight.numel()
# Floor division by 1e6 means the figure is reported in whole millions.
print(f'number of params : {n_parameters//1e6}M')

In [None]:
# Group 1: backbone weights that are still trainable (stepped at LR / 10).
param_group_1 = list(filter(lambda weight: weight.requires_grad, model.backbone.parameters()))

# Group 2: every weight of the pooling, embedding and cosine-product layers,
# stepped at the optimizer-wide learning rate.
param_group_2 = [
    weight
    for head in (model.GeM, model.Embedding, model.CosProduct)
    for weight in head.parameters()
]

optimizer = timm.optim.AdamW(
    [
        {'params': param_group_1, 'lr': LR / 10},
        {'params': param_group_2},
    ],
    lr=LR,
    weight_decay=2e-5,
)

**1. EffV1, weight_decay=1e-6, epoch=4, acc=0.066 (overfitting)**

**2. EffV1, weight_decay=1e-5, epoch=15**

# 🐬 Analysis

# 🐬 Test

In [None]:
# Lookup table from integer label to individual_id, one row per distinct pair.
individual_label_map = df[['individual_id', 'label']].drop_duplicates()
# Add an extra row so that label NUM_CLASSES decodes to 'new_individual'.
# DataFrame.append was removed in pandas 2.0, so the row is added with pd.concat.
new_individual_row = pd.DataFrame([{'individual_id': 'new_individual', 'label': NUM_CLASSES}])
individual_label_map = pd.concat([individual_label_map, new_individual_row], ignore_index=True)
# Sort by label and renumber the index so the row position equals the label
# (assuming labels are the contiguous integers 0..NUM_CLASSES).
individual_label_map = individual_label_map.sort_values(by='label', ignore_index=True)
individual_label_map.head()

In [None]:
class WhaleTestDataset(Dataset):
    """Inference-time dataset that yields one channel-first image tensor per row.

    Args:
        path: Directory that contains the image files.
        df: Table with an ``image`` column of file names; a ``box`` column is
            read as well when ``crop`` is true.
        crop: When true, crop each image to its ``box`` entry unless that
            entry is the ``'nan'`` sentinel string.
        transform: Optional callable invoked as ``transform(image=array)`` whose
            result is indexed with ``'image'``.
    """

    def __init__(self, path:str, df:pd.DataFrame, crop=False, transform=None):
        self.path = path
        self.df = df
        self.crop = crop
        self.transform = transform

    def __len__(self):
        """Return the number of rows (images) in the table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, optionally crop and transform the image at row ``idx``.

        Returns:
            torch.Tensor: The image with its axes reordered from
            (height, width, channel) to (channel, height, width).
        """
        img_path = os.path.join(self.path, self.df['image'][idx])
        img = Image.open(img_path).convert('RGB')
        img = np.array(img)

        if self.crop:
            box = self.df['box'][idx]
            # The 'nan' string marks rows without a bounding box; leave those uncropped.
            if box != 'nan':
                xmin, ymin, xmax, ymax = box
                # The crop must be applied to `img`, the array used below; the
                # previous code assigned to an undefined `image` name.
                img = img[ymin:ymax, xmin:xmax, :]

        if self.transform is not None:
            img = self.transform(image=img)['image']

        # HWC -> CHW before converting to a tensor.
        img = img.transpose(2, 0, 1)
        img = torch.tensor(img)

        return img

In [None]:
class Test:
    """Run test-time-augmented inference and write a top-5 submission file.

    Args:
        sample_df_path: CSV whose rows list the test images; its
            ``predictions`` column is overwritten by ``test``.
        image_path: Directory holding the test images.
        model: Network called as ``model(img)`` and returning ``(output, _)``.
        num_classes: Length of the score vector accumulated per image.
        transform: Transform forwarded to the underlying dataset.
        TTA_num: Number of augmented passes averaged per image.
        threshold: Score given to the extra "new individual" class.
    """

    def __init__(self, sample_df_path:str, image_path:str, model, num_classes, transform=None, TTA_num=4, threshold=0.3):
        self.sample_df = pd.read_csv(sample_df_path)
        self.dataset = WhaleTestDatasetV2(path=image_path, df=self.sample_df, transform=transform)
        self.size = len(self.sample_df)
        self.model = model
        self.model.to(device)
        self.model.eval()
        self.num_classes = num_classes
        self.threshold = threshold
        self.TTA_num = TTA_num

    def test(self, save_path='sample_submission.csv'):
        """Predict every row of the sample table and save it to ``save_path``."""
        for i in range(self.size):
            pred = self.TTA(idx=i)
            top5_label, top5_value = self.get_Top5(pred)
            # individual_label_map is indexed by label, so the top-5 indices
            # select the matching individual ids.
            individuals = individual_label_map['individual_id'][np.array(top5_label.cpu())].tolist()
            individuals = ' '.join(individuals)
            # Single-step .at assignment writes into the frame itself; the
            # chained df[col][i] form may write to a temporary copy instead.
            self.sample_df.at[i, 'predictions'] = individuals
            print(f'\r{i}/{self.size}|{top5_value.detach().cpu()}|{individuals}', end=' ')

        self.sample_df.to_csv(save_path, index=False)

    def TTA(self, idx):
        """Average the model output for image ``idx`` over ``TTA_num`` augmentations.

        Returns:
            torch.Tensor: 1-D tensor of ``num_classes`` averaged scores.
        """
        pred = torch.zeros(self.num_classes).to(device)
        # Inference only: without no_grad every pass would extend an autograd
        # graph through the running sum and hold on to its activations.
        with torch.no_grad():
            for tta_round in range(self.TTA_num):
                img = self.dataset[idx]
                # TTA_transform is keyed by the pass number as a string.
                img = TTA_transform[str(tta_round)](image=img)['image']
                img = torch.tensor(img).unsqueeze(0)
                img = img.to(device)
                output, _ = self.model(img)
                pred += output.squeeze(0)
        pred /= self.TTA_num

        return pred

    def get_Top5(self, pred):
        """Return the five best labels after adding the threshold pseudo-class.

        The threshold is appended as one extra score at the end of ``pred``
        (index ``num_classes`` when ``pred`` has ``num_classes`` entries), so
        that index ranks above every class scoring below the threshold.

        Returns:
            tuple: ``(top5_label, top5_value)`` tensors of indices and scores.
        """
        # Create the threshold on pred's own device and dtype so the
        # concatenation does not depend on the module-level `device`.
        threshold_score = torch.tensor([self.threshold], dtype=pred.dtype, device=pred.device)
        pred = torch.cat([pred, threshold_score])
        top5 = torch.topk(pred, k=5)
        top5_value = top5.values
        top5_label = top5.indices

        return top5_label, top5_value

**1. Model_1 score : 0.62**

In [None]:
model_path = '../input/pretrained-happywhale-segconvs/SegConvS_HappyWhaleNet18_384.pth'

# Rebuild the architecture without downloading backbone weights; every weight
# comes from the checkpoint loaded below.
model = SegConvWhaleNet(
    backbone='convnext_small',
    num_classes=NUM_CLASSES,
    embedding_size=512,
    backbone_pretrained=False,
)

# NOTE(review): torch.load is called without map_location, so a checkpoint saved
# from GPU tensors may fail to load on a CPU-only machine. Confirm the target.
checkpoint = torch.load(model_path)
model.load_state_dict(checkpoint)

In [None]:
# Locations of the cropped test images and of the sample submission table.
image_path = '../input/whale2-cropped-dataset/cropped_test_images/cropped_test_images'
sample_df_path = '../input/happy-whale-and-dolphin/sample_submission.csv'

# TTA_num=5 and threshold=0.35 override the Test defaults (4 and 0.3).
test_tool = Test(sample_df_path, image_path, model, num_classes = NUM_CLASSES, transform=transforms['TTA'], TTA_num=5, threshold=0.35)

In [None]:
# Predict every test image and write the result to the default
# 'sample_submission.csv' path.
test_tool.test()