<br>
<h2 style = "font-size:60px; font-family:Garamond ; font-weight : normal; background-color: #f6f5f5 ; color : #fe346e; text-align: center; border-radius: 100px 100px;">[PyTorch] ArcFace Starter</h2>
<br>

![](https://media.istockphoto.com/illustrations/the-whale-is-blowing-illustration-id164494826?k=20&m=164494826&s=612x612&w=0&h=SGm8bwFqE7-h_ekqaXOVfIUIpKN8aW2AAMcFSbvpwYg=)

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Install Required Libraries</h1></span>

In [1]:
!pip install timm
!pip install --upgrade wandb
!pip install torch-lr-finder
!pip install -U pytorch_warmup

Collecting timm
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
     |████████████████████████████████| 431 kB 3.2 MB/s            
Installing collected packages: timm
Successfully installed timm-0.5.4
Collecting wandb
  Downloading wandb-0.12.11-py2.py3-none-any.whl (1.7 MB)
     |████████████████████████████████| 1.7 MB 3.1 MB/s            
Collecting setproctitle
  Downloading setproctitle-1.2.2-cp37-cp37m-manylinux1_x86_64.whl (36 kB)
Installing collected packages: setproctitle, wandb
  Attempting uninstall: wandb
    Found existing installation: wandb 0.12.9
    Uninstalling wandb-0.12.9:
      Successfully uninstalled wandb-0.12.9
Successfully installed setproctitle-1.2.2 wandb-0.12.11
Collecting torch-lr-finder
  Downloading torch_lr_finder-0.2.1-py3-none-any.whl (11 kB)
Installing collected packages: torch-lr-finder
Successfully installed torch-lr-finder-0.2.1
Collecting pytorch_warmup
  Downloading pytorch_warmup-0.0.4-py3-none-any.whl (6.5 kB)
Installing collected packages

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Import Required Libraries 📚</h1></span>

In [2]:
import os
import gc
import cv2
import math
import copy
import time
import random

# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

# Utils
import joblib
from tqdm import tqdm
from collections import defaultdict

# Sklearn Imports
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold

# For Image Models
import timm
import timm.optim

#Finding the best Learning rate
from torch_lr_finder import LRFinder, TrainDataLoaderIter, ValDataLoaderIter
import pytorch_warmup as warmup

# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# For colored terminal text
from colorama import Fore, Back, Style
b_ = Fore.BLUE
sr_ = Style.RESET_ALL

import warnings
warnings.filterwarnings("ignore")

# For descriptive error messages
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [3]:
from pprint import pprint
model_names = timm.list_models(pretrained=True)
pprint(model_names)

['adv_inception_v3',
 'bat_resnext26ts',
 'beit_base_patch16_224',
 'beit_base_patch16_224_in22k',
 'beit_base_patch16_384',
 'beit_large_patch16_224',
 'beit_large_patch16_224_in22k',
 'beit_large_patch16_384',
 'beit_large_patch16_512',
 'botnet26t_256',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_tiny',
 'convit_base',
 'convit_small',
 'convit_tiny',
 'convmixer_768_32',
 'convmixer_1024_20_ks9_p14',
 'convmixer_1536_20',
 'convnext_base',
 'convnext_base_384_in22ft1k',
 'convnext_base_in22ft1k',
 'convnext_base_in22k',
 'convnext_large',
 'convnext_large_384_in22ft1k',
 'convnext_large_in22ft1k',
 'convnext_large_in22k',
 'convnext_small',
 'convnext_tiny',
 'convnext_xlarge_384_in22ft1k',
 'convnext_xlarge_in22ft1k',
 'convnext_xlarge_in22k',
 'crossvit_9_240',
 'crossv

<img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases" />

<span style="color: #000508; font-family: Segoe UI; font-size: 1.2em; font-weight: 300;"> Weights & Biases (W&B) is a set of machine learning tools that helps you build better models faster. <strong>Kaggle competitions require fast-paced model development and evaluation</strong>. There are a lot of components: exploring the training data, training different models, combining trained models in different combinations (ensembling), and so on.</span>

> <span style="color: #000508; font-family: Segoe UI; font-size: 1.2em; font-weight: 300;">⏳ Lots of components = Lots of places to go wrong = Lots of time spent debugging</span>

<span style="color: #000508; font-family: Segoe UI; font-size: 1.2em; font-weight: 300;">W&B can be useful in Kaggle competitions thanks to its lightweight and interoperable tools:</span>

* <span style="color: #000508; font-family: Segoe UI; font-size: 1.2em; font-weight: 300;">Quickly track experiments,<br></span>
* <span style="color: #000508; font-family: Segoe UI; font-size: 1.2em; font-weight: 300;">Version and iterate on datasets, <br></span>
* <span style="color: #000508; font-family: Segoe UI; font-size: 1.2em; font-weight: 300;">Evaluate model performance,<br></span>
* <span style="color: #000508; font-family: Segoe UI; font-size: 1.2em; font-weight: 300;">Reproduce models,<br></span>
* <span style="color: #000508; font-family: Segoe UI; font-size: 1.2em; font-weight: 300;">Visualize results and spot regressions,<br></span>
* <span style="color: #000508; font-family: Segoe UI; font-size: 1.2em; font-weight: 300;">Share findings with colleagues.</span>

<span style="color: #000508; font-family: Segoe UI; font-size: 1.2em; font-weight: 300;">To learn more about Weights and Biases check out this <strong><a href="https://www.kaggle.com/ayuraj/experiment-tracking-with-weights-and-biases">kernel</a></strong>.</span>

In [4]:
import wandb

try:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    api_key = user_secrets.get_secret("wandb")
    wandb.login(key=api_key)
    anony = None
except Exception:
    # Fall back to an anonymous W&B run when the secret is unavailable
    anony = "must"
    print('If you want to use your W&B account, go to Add-ons -> Secrets and provide your W&B access token. Use the Label name "wandb". \nGet your W&B access token from here: https://wandb.ai/authorize')

[34m[1mwandb[0m: W&B API key is configured (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Training Configuration ⚙️</h1></span>

In [5]:
CONFIG = {"seed": 2022,
          "epochs": 35,
          "img_size": 512,
          "model_name": "tf_efficientnetv2_m",
          "num_classes": 53,
          "embedding_size": 512,
          "train_batch_size": 4,
          "valid_batch_size": 4,
          "learning_rate": 1e-2,
          "scheduler": 'CosineAnnealingLR',
          "min_lr": 8e-5,
          "T_max": 500,
          "weight_decay": 1e-7,
          "n_fold": 5,
          "n_accumulate": 1,
          "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
          # ArcFace Hyperparameters
          "s": 30.0, 
          "m": 0.50,
          "ls_eps": 0.0,
          "easy_margin": False
          }

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Set Seed for Reproducibility</h1></span>

In [6]:
def set_seed(seed=42):
    '''Sets the seed of the entire notebook so results are the same every time we run.
    This is for REPRODUCIBILITY.'''
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    
set_seed(CONFIG['seed'])

In [7]:
ROOT_DIR = '../input/whale2-cropped-dataset'
TRAIN_DIR = '../input/whale2-cropped-dataset/cropped_train_images/cropped_train_images'
TEST_DIR = '../input/happy-whale-and-dolphin/test_images'

In [8]:
def get_train_file_path(id):
    return f"{TRAIN_DIR}/{id}"

# <h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Read the Data 📖</h1>

In [9]:
df = pd.read_csv(f"{ROOT_DIR}/train2.csv")
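# Fix misspelled and duplicate species names; assign back instead of relying on in-place replace
df['species'] = df['species'].replace({"globis": "short_finned_pilot_whale",
                                       "pilot_whale": "short_finned_pilot_whale",
                                       "kiler_whale": "killer_whale",
                                       "bottlenose_dolpin": "bottlenose_dolphin"})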

df['file_path'] = df['image'].apply(get_train_file_path)

df.head()

| | image | species | individual_id | box | file_path |
|---|---|---|---|---|---|
| 0 | 00021adfb725ed.jpg | melon_headed_whale | cadddb1636b9 | 2 116 802 665 | ../input/whale2-cropped-dataset/cropped_train_... |
| 1 | 000562241d384d.jpg | humpback_whale | 1a71fbb72250 | 588 597 3504 1477 | ../input/whale2-cropped-dataset/cropped_train_... |
| 2 | 0007c33415ce37.jpg | false_killer_whale | 60008f293a2b | 0 453 3183 1589 | ../input/whale2-cropped-dataset/cropped_train_... |
| 3 | 0007d9bca26a99.jpg | bottlenose_dolphin | 4b00fe572063 | 1 91 2636 1495 | ../input/whale2-cropped-dataset/cropped_train_... |
| 4 | 00087baf5cef7a.jpg | humpback_whale | 8e5253662392 | 1139 1590 3574 1913 | ../input/whale2-cropped-dataset/cropped_train_... |


In [10]:
id_counts = df["individual_id"].value_counts() 
to_remove = id_counts[id_counts <= 74].index
df_train = df[~df.individual_id.isin(to_remove)].copy()  # copy so later column assignments do not write to a view of df

In [11]:
df_train["individual_id"].value_counts()

37c7aba965a5    400
114207cab555    168
a6e325d8e924    155
19fbb960f07d    154
c995c043c353    153
f195c38bcf17    146
ffbb4e585ff2    145
ce6e37904aa4    145
281504409737    143
bc1eb2241633    141
9e89f8e28807    141
b9907151f66e    140
938b7e931166    135
c27db73f0e3b    135
4b8534134eb8    131
956562ff2888    131
600ab1de92d9    129
e69d5f9f8d1e    126
02da0e68dccd    122
208b91b1ca2b    122
180c0ab04dcd    122
1191a41ee0f4    121
136b6c84830f    119
778419da2957    112
6a3af6e0c55c    111
5bf17305f073    108
c2705f9e75c8    103
15d96d5d42c2     99
4a67e64bd3b7     99
dd8c756c9cb7     98
7362d7a01d00     94
812be36c2aef     93
5f48c2296a0e     90
9f3613b5c45b     90
b54c1f8df53f     89
2fdb3a09dc9c     88
10e758eb503a     87
a43daee90cbc     86
48936da899c3     84
8bc942512479     82
695bb814ce56     81
77410a623426     81
be330f0c495c     80
7485701415cd     80
2e0b381d3467     80
0e4660baf3f1     79
9ab8c57f10bc     79
bf412253bbc2     77
322a18725969     77
4b234d0d53c1     75


In [12]:
encoder = LabelEncoder()
df_train['individual_id'] = encoder.fit_transform(df_train['individual_id'])

with open("le.pkl", "wb") as fp:
    joblib.dump(encoder, fp)

In [13]:
print(len(np.unique(df_train["individual_id"])))

53


In [14]:
len(df_train)

6051

In [15]:
index_id = list(range(len(df_train)))  # derive the length instead of hard-coding 6051

In [16]:
df_train.index = index_id
print(df_train)

                   image             species  individual_id  \
0     000a8f2d5c316a.jpg  bottlenose_dolphin             41   
1     001001f099519f.jpg         minke_whale              9   
2     00103cbe9d25ce.jpg           fin_whale              8   
3     00144776eb476d.jpg  bottlenose_dolphin             41   
4     00177f3c614d1e.jpg  bottlenose_dolphin             29   
...                  ...                 ...            ...   
6046  ffc5eb215d5539.jpg  bottlenose_dolphin             38   
6047  ffc71880c3066b.jpg         minke_whale              9   
6048  ffcc55db24e4d0.jpg  bottlenose_dolphin             52   
6049  ffcee8fa2578c1.jpg  bottlenose_dolphin             38   
6050  ffe24f955ff264.jpg  bottlenose_dolphin             32   

                      box                                          file_path  
0       259 443 2053 1068  ../input/whale2-cropped-dataset/cropped_train_...  
1      788 1036 2101 1373  ../input/whale2-cropped-dataset/cropped_train_...  
2     

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Create Folds</h1></span>

In [17]:
skf = StratifiedKFold(n_splits=CONFIG['n_fold'])

for fold, (_, val_) in enumerate(skf.split(X=df_train, y=df_train.individual_id)):
    df_train.loc[val_, "kfold"] = fold

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Dataset Class</h1></span>

In [18]:
class HappyWhaleDataset(Dataset):
    def __init__(self, df_train, transforms=None):
        self.df = df_train
        self.file_names = df_train['file_path'].values
        self.labels = df_train['individual_id'].values
        self.transforms = transforms
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        img_path = self.file_names[index]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        label = self.labels[index]
        
        if self.transforms:
            img = self.transforms(image=img)["image"]
            
        return {
            'image': img,
            'label': torch.tensor(label, dtype=torch.long)
        }

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Augmentations</h1></span>

In [19]:
data_transforms = {
    "train": A.Compose([
        A.Resize(CONFIG['img_size'], CONFIG['img_size']),
#         A.HorizontalFlip(p=0.5),
#         A.VerticalFlip(p=0.5),
        A.Blur(p=0.2),
#         A.ToGray(p=0.5),
#         A.AdvancedBlur(p=0.5),
#         A.MultiplicativeNoise(p=0.5),
        A.Affine(p=0.5),
        A.Perspective(p=0.2),
        A.GaussNoise(p=0.3),
        A.Rotate(limit=10, p=0.5),
        A.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225], 
                max_pixel_value=255.0, 
                p=1.0
            ),
        ToTensorV2()], p=1.),
    
    "valid": A.Compose([
        A.Resize(CONFIG['img_size'], CONFIG['img_size']),
        A.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225], 
                max_pixel_value=255.0, 
                p=1.0
            ),
        ToTensorV2()], p=1.)
}
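
To make the `A.Normalize` step concrete: with `max_pixel_value=255.0`, each channel is first scaled to [0, 1], then shifted by the channel mean and divided by the channel standard deviation. The sketch below redoes that arithmetic in plain Python for a single pixel. The helper name `normalize_pixel` is hypothetical and is not part of Albumentations.

```python
# ImageNet statistics used by the transforms above
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]
MAX_PIXEL_VALUE = 255.0


def normalize_pixel(rgb):
    """Hypothetical helper: apply (x / max_pixel_value - mean) / std per channel."""
    return [(v / MAX_PIXEL_VALUE - m) / s for v, m, s in zip(rgb, MEAN, STD)]


white = normalize_pixel([255, 255, 255])  # brightest possible pixel
black = normalize_pixel([0, 0, 0])        # darkest possible pixel
print([round(v, 3) for v in white])
print([round(v, 3) for v in black])
```

After normalization the inputs are roughly zero-centred, which is the range the ImageNet-pretrained backbone was trained on.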

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">GeM Pooling</h1></span>

<span style="color: #000508; font-family: Segoe UI; font-size: 1.5em; font-weight: 300;">Code taken from <a href="https://amaarora.github.io/2020/08/30/gempool.html">GeM Pooling Explained</a></span>

![](https://i.imgur.com/thTgYWG.jpg)

In [20]:
class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = nn.Parameter(torch.ones(1)*p)
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)
        
    def gem(self, x, p=3, eps=1e-6):
        return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1./p)
        
    def __repr__(self):
        return self.__class__.__name__ + \
                '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + \
                ', ' + 'eps=' + str(self.eps) + ')'
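
As a rough illustration of what the exponent `p` does, the sketch below computes the same generalized mean over a flat list of activations in plain Python (no PyTorch). The function name `gem_pool` and the sample values are made up for this example.

```python
def gem_pool(values, p=3.0, eps=1e-6):
    """Illustrative generalized mean over one channel's activations."""
    clamped = [max(v, eps) for v in values]  # mirrors x.clamp(min=eps)
    return (sum(v ** p for v in clamped) / len(clamped)) ** (1.0 / p)


activations = [0.1, 0.2, 0.3, 2.0]  # one flattened H x W feature map

avg = gem_pool(activations, p=1.0)     # p = 1 reduces to average pooling
gem3 = gem_pool(activations, p=3.0)    # the default p used by GeM above
gem50 = gem_pool(activations, p=50.0)  # a large p approaches max pooling
print(avg, gem3, gem50)
```

The pooled value moves from the mean toward the maximum as `p` grows. Because `p` is an `nn.Parameter` in the class above, the network can learn where on that spectrum to sit.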

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">ArcFace</h1></span>

<span style="color: #000508; font-family: Segoe UI; font-size: 1.5em; font-weight: 300;">Code taken from <a href="https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/blob/master/src/modeling/metric_learning.py">Landmark2019-1st-and-3rd-Place-Solution</a></span>

In [21]:
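class ArcMarginProduct(nn.Module):
    r"""Implementation of the large-margin arc distance (ArcFace).
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            s: norm of input feature (scale applied to the output logits)
            m: additive angular margin, giving cos(theta + m)
            easy_margin: apply the margin only where cos(theta) > 0
            ls_eps: label smoothing factor
        """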
    def __init__(self, in_features, out_features, s=30.0, 
                 m=0.50, easy_margin=False, ls_eps=0.0):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Clamp guards against tiny negative values from floating-point rounding
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------
        # Build the one-hot matrix on the same device as the logits
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # ------------- torch.where(out_i = x_i if condition_i else y_i) -------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output
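
To see what the margin does to a single target-class logit, here is a scalar sketch of the same branch logic as `ArcMarginProduct.forward`, written with the standard `math` module only. The function name `target_logit` and the 40-degree example angle are assumptions made for illustration.

```python
import math


def target_logit(cosine, s=30.0, m=0.50, easy_margin=False):
    """Illustrative scalar version of the target-class branch of ArcMarginProduct."""
    sine = math.sqrt(max(0.0, 1.0 - cosine ** 2))
    phi = cosine * math.cos(m) - sine * math.sin(m)  # equals cos(theta + m)
    if easy_margin:
        phi = phi if cosine > 0 else cosine
    else:
        th = math.cos(math.pi - m)       # beyond this, theta + m would exceed pi
        mm = math.sin(math.pi - m) * m   # linear penalty used in that region
        phi = phi if cosine > th else cosine - mm
    return s * phi


theta = math.radians(40)                   # angle between embedding and class weight
without_margin = 30.0 * math.cos(theta)    # plain scaled cosine logit
with_margin = target_logit(math.cos(theta))
fallback = target_logit(-0.95)             # cosine below th, so the linear branch is used
print(round(without_margin, 3), round(with_margin, 3), round(fallback, 3))
```

The margin lowers the logit of the correct class, so the network has to pull embeddings of the same individual closer to their class centre to reach the same loss.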

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Create Model</h1></span>

In [22]:
model_name = CONFIG['model_name']
model = timm.create_model(model_name, pretrained=True)
print(model);

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnetv2_m-cc09e0cd.pth


EfficientNet(
  (conv_stem): Conv2dSame(3, 24, kernel_size=(3, 3), stride=(2, 2), bias=False)
  (bn1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (act1): SiLU(inplace=True)
  (blocks): Sequential(
    (0): Sequential(
      (0): ConvBnAct(
        (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (act1): SiLU(inplace=True)
      )
      (1): ConvBnAct(
        (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (act1): SiLU(inplace=True)
      )
      (2): ConvBnAct(
        (conv): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (act1): SiLU(inp

In [23]:
# model_2 = "tf_efficientnet_b0"
# model_eff = timm.create_model(model_2, pretrained=True)
# print(model_eff.classifier)
# print(model_eff.global_pool)
# print(model_eff.pooling)

In [24]:
class HappyWhaleModel(nn.Module):
    def __init__(self, model_name, embedding_size, pretrained=True):
        super(HappyWhaleModel, self).__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Identity()
        self.model.global_pool = GeM()
#         self.pooling = GeM()
        self.embedding = nn.Linear(in_features, embedding_size)
        self.fc = ArcMarginProduct(embedding_size, 
                                   CONFIG["num_classes"],
                                   s=CONFIG["s"], 
                                   m=CONFIG["m"], 
                                   easy_margin=CONFIG["easy_margin"], 
                                   ls_eps=CONFIG["ls_eps"])

    def forward(self, images, labels):
        features = self.model(images)
        pooled_features = features.flatten(1)
        embedding = self.embedding(pooled_features)
        output = self.fc(embedding, labels)
        return output,embedding
    
model = HappyWhaleModel(CONFIG['model_name'], CONFIG['embedding_size'])

In [25]:
model.load_state_dict(torch.load("../input/256-53/Loss7.8851_epoch40.bin", map_location=CONFIG['device']))

<All keys matched successfully>

In [26]:
# model.fc = ArcMarginProduct(model.fc.in_features, 
#                                    CONFIG["num_classes"],
#                                    s=CONFIG["s"], 
#                                    m=CONFIG["m"], 
#                                    easy_margin=CONFIG["easy_margin"], 
#                                    ls_eps=CONFIG["ls_eps"])
model.to(CONFIG['device']);

In [27]:
# model = timm.create_model(CONFIG['model_name'], pretrained=True)
# # print(model)

In [28]:
# print(model)

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Loss Function</h1></span>

In [29]:
def criterion(outputs, labels):
    return nn.CrossEntropyLoss()(outputs, labels)

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Training Function</h1></span>

In [30]:
def train_one_epoch(model, optimizer, scheduler, warmup_scheduler, dataloader, device, epoch):
    model.train()
    
    dataset_size = 0
    running_loss = 0.0
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:
        images = data['image'].to(device, dtype=torch.float)
        labels = data['label'].to(device, dtype=torch.long)
        
        batch_size = images.size(0)
        
        outputs,_ = model(images, labels)
        loss = criterion(outputs, labels)
        loss = loss / CONFIG['n_accumulate']
            
        loss.backward()
    
        if (step + 1) % CONFIG['n_accumulate'] == 0:
            optimizer.step()

            # zero the parameter gradients
            optimizer.zero_grad()

            if scheduler is not None:
                scheduler.step()
                
            warmup_scheduler.dampen()
                
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        
        epoch_loss = running_loss / dataset_size
        
        bar.set_postfix(Epoch=epoch, Train_Loss=epoch_loss,
                        LR=optimizer.param_groups[0]['lr'])
    gc.collect()
    
    return epoch_loss
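
The gradient-accumulation condition `(step + 1) % CONFIG['n_accumulate'] == 0` decides which batches trigger `optimizer.step()`. The small sketch below lists those batch indices for a given loader length. The helper `optimizer_step_indices` is hypothetical and only mirrors the modulo check, it does not touch a real optimizer.

```python
def optimizer_step_indices(num_batches, n_accumulate):
    """Return the 0-based batch indices after which an optimizer step would fire."""
    return [step for step in range(num_batches) if (step + 1) % n_accumulate == 0]


every_batch = optimizer_step_indices(10, 1)   # n_accumulate = 1, as in CONFIG
every_fourth = optimizer_step_indices(10, 4)  # effective batch size is 4x larger
print(every_batch)
print(every_fourth)
```

With 10 batches and `n_accumulate=4`, the last two batches never trigger a step inside the epoch. Their gradients stay in `.grad` until the next step fires, because `zero_grad()` is only called right after `optimizer.step()`.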

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Validation Function</h1></span>

In [31]:
@torch.inference_mode()
def valid_one_epoch(model, dataloader, device, epoch):
    model.eval()
    
    dataset_size = 0
    running_loss = 0.0
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:        
        images = data['image'].to(device, dtype=torch.float)
        labels = data['label'].to(device, dtype=torch.long)
        
        batch_size = images.size(0)

        outputs,_ = model(images, labels)
        loss = criterion(outputs, labels)
        
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        
        epoch_loss = running_loss / dataset_size
        
        # NOTE: `optimizer` is not an argument here; it is read from the notebook's global scope
        bar.set_postfix(Epoch=epoch, Valid_Loss=epoch_loss,
                        LR=optimizer.param_groups[0]['lr'])
    
    gc.collect()
    
    return epoch_loss

# <span><h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Run Training</h1></span>

In [32]:
def run_training(model, optimizer, scheduler, warmup_scheduler, device, num_epochs):
    # To automatically log gradients
    wandb.watch(model, log_freq=100)
    
    if torch.cuda.is_available():
        print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))
    
    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_epoch_loss = np.inf
    history = defaultdict(list)
    
    for epoch in range(1, num_epochs + 1): 
        gc.collect()
        train_epoch_loss = train_one_epoch(model, optimizer, scheduler, warmup_scheduler, 
                                           dataloader=train_loader, 
                                           device=CONFIG['device'], epoch=epoch)
        
        val_epoch_loss = valid_one_epoch(model, valid_loader, device=CONFIG['device'], 
                                         epoch=epoch)
    
        history['Train Loss'].append(train_epoch_loss)
        history['Valid Loss'].append(val_epoch_loss)
        
        # Log both metrics in one call so they share the same W&B step
        wandb.log({"Train Loss": train_epoch_loss, "Valid Loss": val_epoch_loss})
        
        # deep copy the model
        if val_epoch_loss <= best_epoch_loss:
            print(f"{b_}Validation Loss Improved ({best_epoch_loss} ---> {val_epoch_loss})")
            best_epoch_loss = val_epoch_loss
            # Remove earlier checkpoints so only the best one is kept on disk
            for item in os.listdir("./"):
                if item.endswith(".bin"):
                    os.remove(os.path.join("./", item))
#             run.summary["Best Loss"] = best_epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            PATH = "Loss{:.4f}_epoch{:.0f}.bin".format(best_epoch_loss, epoch)
            torch.save(model.state_dict(), PATH)
            # Save a model file from the current directory
            print(f"Model Saved{sr_}")
            
        print()
    
    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best Loss: {:.4f}".format(best_epoch_loss))
    
    # load best model weights
    model.load_state_dict(best_model_wts)
    
    return model, history

In [33]:
def fetch_scheduler(optimizer):
    if CONFIG['scheduler'] == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,T_max=CONFIG['T_max'], 
                                                   eta_min=CONFIG['min_lr'])
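    elif CONFIG['scheduler'] == 'CosineAnnealingWarmRestarts':
        # Requires a "T_0" entry in CONFIG, which the configuration cell does not define
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer,T_0=CONFIG['T_0'], 
                                                             eta_min=CONFIG['min_lr'])
    elif CONFIG['scheduler'] is None:
        return None
    else:
        raise ValueError(f"Unknown scheduler: {CONFIG['scheduler']}")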
        
    return scheduler
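
Because `scheduler.step()` is called once per batch in `train_one_epoch`, `T_max=500` is measured in batches, not epochs, so the learning rate completes a full cosine cycle every 1000 batches. The sketch below evaluates the closed-form cosine annealing expression with the values from `CONFIG`. It is a plain-Python approximation written for illustration, not a call into PyTorch, and `steps_per_epoch = 1210` is taken from the batch count shown in the training log.

```python
import math


def cosine_lr(step, base_lr=1e-2, min_lr=8e-5, t_max=500):
    """Closed-form cosine annealing value after `step` scheduler steps."""
    return min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * step / t_max)) / 2


steps_per_epoch = 1210  # assumption: batch count reported by the progress bar
for epoch in (1, 2, 3):
    print(epoch, f"{cosine_lr(epoch * steps_per_epoch):.3g}")
```

The three values printed are approximately 0.00627, 0.000694 and 0.00164, which line up with the `LR` shown at the end of epochs 1 to 3 in the training output. This is why the logged learning rate appears to jump around from epoch to epoch rather than decay steadily.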

In [34]:
def prepare_loaders(df, fold):
    df_train = df[df.kfold != fold].reset_index(drop=True)
    df_valid = df[df.kfold == fold].reset_index(drop=True)
    
    train_dataset = HappyWhaleDataset(df_train, transforms=data_transforms["train"])
    valid_dataset = HappyWhaleDataset(df_valid, transforms=data_transforms["valid"])

    train_loader = DataLoader(train_dataset, batch_size=CONFIG['train_batch_size'], 
                              num_workers=2, shuffle=True, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset, batch_size=CONFIG['valid_batch_size'], 
                              num_workers=2, shuffle=False, pin_memory=True)
    
    return train_loader, valid_loader

<span style="color: #000508; font-family: Segoe UI; font-size: 1.5em; font-weight: 300;">Prepare Dataloaders</span>

In [35]:
train_loader, valid_loader = prepare_loaders(df_train, fold=0)

<span style="color: #000508; font-family: Segoe UI; font-size: 1.5em; font-weight: 300;">Define Optimizer and best learning rate finder</span>

In [36]:
# optimizer = optim.SGD(model.parameters(), lr=1e-02, 
#                        weight_decay=CONFIG['weight_decay'])
optimizer = timm.optim.SGDP(model.parameters(), lr=CONFIG['learning_rate'])

In [37]:
# class CustomTrainIter(TrainDataLoaderIter):
#     def inputs_labels_from_batch(self, batch_data):
#         images = batch_data["image"]
#         labels = batch_data["label"]
#         return (images, labels), labels

# class ModelWrapper(nn.Module):
#     def __init__(self, model):
#         super(ModelWrapper, self).__init__()
#         self.model = model

#     def forward(self, data):
#         # Unpack data to the format you need
#         img, labels = data
#         return self.model(img, labels)
    
# model_wrap = ModelWrapper(model)
# model_wrap.to(CONFIG['device'])
    
# custom_loader = CustomTrainIter(train_loader)

# lr_finder = LRFinder(model_wrap, optimizer, criterion, device=CONFIG['device'])
# lr_finder.range_test(custom_loader, start_lr = 0.000001, end_lr=0.01, num_iter=200)
# lr_finder.plot() # to inspect the loss-learning rate graph
# lr_finder.reset()

## Add scheduler and Warm-up epochs

In [38]:
scheduler = fetch_scheduler(optimizer)
warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period=1)
warmup_scheduler.last_step = -1

<span style="color: #000508; font-family: Segoe UI; font-size: 1.5em; font-weight: 300;">Start Training</span>

In [39]:
run = wandb.init(project='HappyWhale', 
                 config=CONFIG,
                 job_type='Train',
                 tags=['arcface', 'gem-pooling', 'tf_effv2_m', '128'],
                 anonymous=anony)  # None when the W&B login above succeeded, "must" otherwise

[34m[1mwandb[0m: Currently logged in as: [33madnanpen[0m (use `wandb login --relogin` to force relogin)


In [40]:
model, history = run_training(model, optimizer, scheduler, warmup_scheduler,
                              device=CONFIG['device'],
                              num_epochs=CONFIG['epochs'])

[INFO] Using GPU: Tesla P100-PCIE-16GB



100%|██████████| 1210/1210 [14:56<00:00,  1.35it/s, Epoch=1, LR=0.00627, Train_Loss=8.87]
100%|██████████| 303/303 [00:46<00:00,  6.52it/s, Epoch=1, LR=0.00627, Valid_Loss=8.19]


Validation Loss Improved (inf ---> 8.191935478607354)
Model Saved



100%|██████████| 1210/1210 [14:47<00:00,  1.36it/s, Epoch=2, LR=0.000694, Train_Loss=8.64]
100%|██████████| 303/303 [00:45<00:00,  6.62it/s, Epoch=2, LR=0.000694, Valid_Loss=8.21]





100%|██████████| 1210/1210 [14:47<00:00,  1.36it/s, Epoch=3, LR=0.00164, Train_Loss=8.52] 
100%|██████████| 303/303 [00:46<00:00,  6.53it/s, Epoch=3, LR=0.00164, Valid_Loss=7.89]


Validation Loss Improved (8.191935478607354 ---> 7.894992993352435)
Model Saved



100%|██████████| 1210/1210 [14:49<00:00,  1.36it/s, Epoch=4, LR=0.0077, Train_Loss=8.45] 
100%|██████████| 303/303 [00:46<00:00,  6.53it/s, Epoch=4, LR=0.0077, Valid_Loss=8.07]





100%|██████████| 1210/1210 [14:51<00:00,  1.36it/s, Epoch=5, LR=0.00976, Train_Loss=8.41]
100%|██████████| 303/303 [00:45<00:00,  6.62it/s, Epoch=5, LR=0.00976, Valid_Loss=7.61]


Validation Loss Improved (7.894992993352435 ---> 7.6087705307416345)
Model Saved



100%|██████████| 1210/1210 [14:51<00:00,  1.36it/s, Epoch=6, LR=0.00473, Train_Loss=8.33]
100%|██████████| 303/303 [00:46<00:00,  6.57it/s, Epoch=6, LR=0.00473, Valid_Loss=8.3] 





100%|██████████| 1210/1210 [14:49<00:00,  1.36it/s, Epoch=7, LR=0.000168, Train_Loss=8.29]
100%|██████████| 303/303 [00:46<00:00,  6.56it/s, Epoch=7, LR=0.000168, Valid_Loss=7.97]





100%|██████████| 1210/1210 [14:46<00:00,  1.36it/s, Epoch=8, LR=0.00293, Train_Loss=8.23] 
100%|██████████| 303/303 [00:46<00:00,  6.59it/s, Epoch=8, LR=0.00293, Valid_Loss=7.88]





100%|██████████| 1210/1210 [14:44<00:00,  1.37it/s, Epoch=9, LR=0.00886, Train_Loss=8.2] 
100%|██████████| 303/303 [00:45<00:00,  6.63it/s, Epoch=9, LR=0.00886, Valid_Loss=8.19]





100%|██████████| 1210/1210 [14:43<00:00,  1.37it/s, Epoch=10, LR=0.00905, Train_Loss=8.18]
100%|██████████| 303/303 [00:45<00:00,  6.65it/s, Epoch=10, LR=0.00905, Valid_Loss=7.8] 





100%|██████████| 1210/1210 [14:42<00:00,  1.37it/s, Epoch=11, LR=0.00321, Train_Loss=8.15]
100%|██████████| 303/303 [00:45<00:00,  6.64it/s, Epoch=11, LR=0.00321, Valid_Loss=7.74]





100%|██████████| 1210/1210 [14:44<00:00,  1.37it/s, Epoch=12, LR=0.000119, Train_Loss=8.1] 
100%|██████████| 303/303 [00:45<00:00,  6.61it/s, Epoch=12, LR=0.000119, Valid_Loss=7.5] 


Validation Loss Improved (7.6087705307416345 ---> 7.504015690231008)
Model Saved



100%|██████████| 1210/1210 [14:43<00:00,  1.37it/s, Epoch=13, LR=0.00442, Train_Loss=8.07] 
100%|██████████| 303/303 [00:46<00:00,  6.59it/s, Epoch=13, LR=0.00442, Valid_Loss=7.67]





100%|██████████| 1210/1210 [14:41<00:00,  1.37it/s, Epoch=14, LR=0.00965, Train_Loss=8.08]
100%|██████████| 303/303 [00:45<00:00,  6.67it/s, Epoch=14, LR=0.00965, Valid_Loss=7.92]





100%|██████████| 1210/1210 [14:43<00:00,  1.37it/s, Epoch=15, LR=0.00796, Train_Loss=8.04]
100%|██████████| 303/303 [00:46<00:00,  6.53it/s, Epoch=15, LR=0.00796, Valid_Loss=7.84]





100%|██████████| 1210/1210 [14:45<00:00,  1.37it/s, Epoch=16, LR=0.00188, Train_Loss=7.99]
100%|██████████| 303/303 [00:46<00:00,  6.55it/s, Epoch=16, LR=0.00188, Valid_Loss=7.46]


Validation Loss Improved (7.504015690231008 ---> 7.457063239040107)
Model Saved



100%|██████████| 1210/1210 [14:46<00:00,  1.37it/s, Epoch=17, LR=0.000552, Train_Loss=7.95]
100%|██████████| 303/303 [00:45<00:00,  6.61it/s, Epoch=17, LR=0.000552, Valid_Loss=7.4] 


Validation Loss Improved (7.457063239040107 ---> 7.395989204220886)
Model Saved



100%|██████████| 1210/1210 [14:50<00:00,  1.36it/s, Epoch=18, LR=0.00597, Train_Loss=7.97] 
100%|██████████| 303/303 [00:45<00:00,  6.61it/s, Epoch=18, LR=0.00597, Valid_Loss=7.6] 





100%|██████████| 1210/1210 [14:47<00:00,  1.36it/s, Epoch=19, LR=0.00999, Train_Loss=7.94]
100%|██████████| 303/303 [00:45<00:00,  6.68it/s, Epoch=19, LR=0.00999, Valid_Loss=8.26]





100%|██████████| 1210/1210 [14:53<00:00,  1.35it/s, Epoch=20, LR=0.00657, Train_Loss=7.93]
100%|██████████| 303/303 [00:45<00:00,  6.60it/s, Epoch=20, LR=0.00657, Valid_Loss=7.89]





100%|██████████| 1210/1210 [14:50<00:00,  1.36it/s, Epoch=21, LR=0.000852, Train_Loss=7.91]
100%|██████████| 303/303 [00:45<00:00,  6.59it/s, Epoch=21, LR=0.000852, Valid_Loss=7.55]





100%|██████████| 1210/1210 [14:48<00:00,  1.36it/s, Epoch=22, LR=0.00142, Train_Loss=7.89] 
100%|██████████| 303/303 [00:46<00:00,  6.49it/s, Epoch=22, LR=0.00142, Valid_Loss=7.49]





100%|██████████| 1210/1210 [14:45<00:00,  1.37it/s, Epoch=24, LR=0.00984, Train_Loss=7.89]
100%|██████████| 303/303 [00:45<00:00,  6.62it/s, Epoch=24, LR=0.00984, Valid_Loss=7.5] 





100%|██████████| 1210/1210 [14:40<00:00,  1.37it/s, Epoch=25, LR=0.00504, Train_Loss=7.86]
100%|██████████| 303/303 [00:45<00:00,  6.63it/s, Epoch=25, LR=0.00504, Valid_Loss=7.68]





100%|██████████| 1210/1210 [14:45<00:00,  1.37it/s, Epoch=26, LR=0.000236, Train_Loss=7.82]
100%|██████████| 303/303 [00:45<00:00,  6.65it/s, Epoch=26, LR=0.000236, Valid_Loss=7.62]





100%|██████████| 1210/1210 [14:48<00:00,  1.36it/s, Epoch=27, LR=0.00265, Train_Loss=7.82] 
100%|██████████| 303/303 [00:46<00:00,  6.57it/s, Epoch=27, LR=0.00265, Valid_Loss=7.82]





100%|██████████| 1210/1210 [14:51<00:00,  1.36it/s, Epoch=28, LR=0.00866, Train_Loss=7.83]
100%|██████████| 303/303 [00:45<00:00,  6.60it/s, Epoch=28, LR=0.00866, Valid_Loss=8.35]





100%|██████████| 1210/1210 [14:50<00:00,  1.36it/s, Epoch=29, LR=0.00923, Train_Loss=7.85]
100%|██████████| 303/303 [00:46<00:00,  6.54it/s, Epoch=29, LR=0.00923, Valid_Loss=7.67]





100%|██████████| 1210/1210 [14:52<00:00,  1.36it/s, Epoch=30, LR=0.00351, Train_Loss=7.8] 
100%|██████████| 303/303 [00:46<00:00,  6.50it/s, Epoch=30, LR=0.00351, Valid_Loss=7.89]





100%|██████████| 1210/1210 [14:54<00:00,  1.35it/s, Epoch=31, LR=8.98e-5, Train_Loss=7.78] 
100%|██████████| 303/303 [00:46<00:00,  6.56it/s, Epoch=31, LR=8.98e-5, Valid_Loss=7.39]


Validation Loss Improved (7.395989204220886 ---> 7.390798481705758)
Model Saved



100%|██████████| 1210/1210 [14:52<00:00,  1.36it/s, Epoch=32, LR=0.00411, Train_Loss=7.77] 
100%|██████████| 303/303 [00:47<00:00,  6.44it/s, Epoch=32, LR=0.00411, Valid_Loss=7.5] 





100%|██████████| 1210/1210 [14:53<00:00,  1.35it/s, Epoch=33, LR=0.00953, Train_Loss=7.79]
100%|██████████| 303/303 [00:45<00:00,  6.60it/s, Epoch=33, LR=0.00953, Valid_Loss=7.86]





100%|██████████| 1210/1210 [14:51<00:00,  1.36it/s, Epoch=34, LR=0.0082, Train_Loss=7.78] 
100%|██████████| 303/303 [00:46<00:00,  6.58it/s, Epoch=34, LR=0.0082, Valid_Loss=7.51]





 25%|██▌       | 303/1210 [03:43<11:09,  1.35it/s, Epoch=35, LR=0.000395, Train_Loss=7.75]


KeyboardInterrupt: 
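
The `Validation Loss Improved ... Model Saved` messages in the log above follow a keep-the-best pattern. Here is a minimal sketch of that bookkeeping, assuming the check is `valid_loss <= best_loss`; the real `run_training` is defined earlier in the notebook and may differ. `epochs_that_save` is a hypothetical helper, and the losses are the rounded validation losses of the first five epochs from the log.

```python
import math


def epochs_that_save(valid_losses):
    """Return the 1-based epochs at which a new best validation loss is reached.

    Hypothetical helper: mirrors a "save only when validation improves" rule.
    """
    best = math.inf  # matches the "inf ---> ..." shown for the first epoch
    saved = []
    for epoch, loss in enumerate(valid_losses, start=1):
        if loss <= best:
            best = loss
            saved.append(epoch)  # this is where a checkpoint would be written
    return saved, best


# Rounded Valid_Loss values for epochs 1-5, copied from the progress bars above.
first_five = [8.19, 8.21, 7.89, 8.07, 7.61]
saved, best = epochs_that_save(first_five)
print(saved, best)  # [1, 3, 5] 7.61
```

Epochs 1, 3 and 5 are exactly the ones followed by a `Model Saved` line in the log, which is consistent with this rule.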

In [None]:
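import glob  # os.remove() does not expand wildcards, so expand the pattern first
for txt_path in glob.glob("*.txt"): os.remove(txt_path)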

In [None]:
run.finish()
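
Because the training run above was stopped with a `KeyboardInterrupt`, cleanup such as `run.finish()` only happens if the cell is executed by hand. One option, sketched here under assumptions, is to wrap training in `try`/`finally` so the run is always closed. `DummyRun` and `train` are hypothetical stand-ins for the W&B run object and `run_training`; only the control flow is the point.

```python
class DummyRun:
    """Hypothetical stand-in for the object returned by wandb.init()."""

    def __init__(self):
        self.finished = False

    def finish(self):
        self.finished = True


def train(num_epochs, interrupt_at=None):
    """Hypothetical training loop that can simulate a manual interrupt."""
    completed = 0
    for epoch in range(1, num_epochs + 1):
        if epoch == interrupt_at:
            raise KeyboardInterrupt  # simulates stopping the cell by hand
        completed = epoch
    return completed


run_stub = DummyRun()
try:
    train(num_epochs=5, interrupt_at=3)
except KeyboardInterrupt:
    print("Training interrupted; the last saved checkpoint is kept.")
finally:
    run_stub.finish()  # always runs, interrupted or not
```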

# <h1 style = "font-family: garamond; font-size: 40px; font-style: normal; letter-spcaing: 3px; background-color: #f6f5f5; color :#fe346e; border-radius: 100px 100px; text-align:center">Visualizations</h1>

<span style="color: #000508; font-family: Segoe UI; font-size: 1.5em; font-weight: 300;"><a href="https://wandb.ai/dchanda/HappyWhale/runs/3j25um1k">View the Complete Dashboard Here ⮕</a></span>

![](https://i.imgur.com/zD3rD0W.jpg)
