In [1]:
!pip install timm
!pip install --upgrade wandb




In [2]:
import os
import gc
import cv2
import math
import copy
import time
import random

In [3]:
# For data manipulation
import numpy as np
import pandas as pd

# Pytorch Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp


In [4]:
# Utils
import joblib
from tqdm import tqdm
from collections import defaultdict

# Sklearn Imports
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold

# For Image Models
import timm

In [5]:
pip install albumentations

Note: you may need to restart the kernel to use updated packages.


In [6]:
# Albumentations for augmentations
import albumentations as A
from albumentations.pytorch import ToTensorV2

# For colored terminal text
from colorama import Fore, Back, Style
# ANSI colour codes used to highlight checkpoint messages in the training log.
b_, sr_ = Fore.BLUE, Style.RESET_ALL

import warnings
# Suppress every Python warning so the training log stays readable.
# NOTE(review): this also hides potentially useful warnings (deprecations, data issues).
warnings.filterwarnings("ignore")

# Make CUDA kernel launches synchronous so errors are reported at the failing call.
# NOTE(review): this is a debugging aid and slows GPU execution.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

## Weights & Biases

In [7]:
import wandb

# Log in to Weights & Biases with the API key stored in ./key.txt.
# `anony` is None after a successful login and "must" otherwise; it is meant to be
# passed to wandb.init(anonymous=...) so a failed login falls back to anonymous mode.
try:
    with open("./key.txt") as f:
        # wandb.login expects the key as a single string; readlines() would pass a
        # list whose item still carries the trailing newline.
        wandb_key = f.read().strip()
    wandb.login(key=wandb_key)
    anony = None
except Exception:  # missing/unreadable key file or rejected key (not KeyboardInterrupt)
    anony = "must"
    print('If you want to use your W&B account, save your W&B access token in ./key.txt. \nGet your W&B access token from here: https://wandb.ai/authorize')

[34m[1mwandb[0m: Currently logged in as: [33mcaracao[0m ([33mcaracao11111[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Training Config

In [8]:
# Central hyper-parameter dictionary; it is also logged via wandb.init(config=CONFIG).
CONFIG = dict(
    seed=2022,
    epochs=4,
    img_size=448,
    model_name="tf_efficientnet_b0_ns",
    # Output classes of the ArcFace head; should cover every encoded individual_id.
    num_classes=15587,
    embedding_size=512,
    train_batch_size=32,
    valid_batch_size=64,
    learning_rate=1e-4,
    scheduler='CosineAnnealingLR',
    min_lr=1e-6,
    # NOTE(review): the scheduler is stepped after every optimizer step (~1275 per
    # epoch in the training log), so T_max=500 makes the cosine LR rise and fall
    # several times per epoch -- confirm this is intended.
    T_max=500,
    weight_decay=1e-6,
    n_fold=5,
    n_accumulate=1,
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    # ArcFace hyper-parameters
    s=30.0,             # logit scale
    m=0.50,             # additive angular margin (radians)
    ls_eps=0.0,         # label-smoothing epsilon
    easy_margin=False,
)


In [9]:
## Set Seed for Reproducibility

def set_seed(seed=42):
    '''Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), forces deterministic cuDNN kernels and exports
    PYTHONHASHSEED. This is for REPRODUCIBILITY.

    Args:
        seed: integer seed applied to every generator.
    '''
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one (silently ignored without CUDA).
    torch.cuda.manual_seed_all(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed.
    # NOTE: this only affects Python processes started afterwards; the running
    # kernel's hash seed was fixed when it started.
    os.environ['PYTHONHASHSEED'] = str(seed)
    
set_seed(seed=CONFIG['seed'])  # seed all RNGs before any data splitting or model init

In [41]:
# train.csv is read from ROOT_DIR; the cropped image folders are relative to the working directory.
ROOT_DIR = '../HappyWhale'
TRAIN_DIR, TEST_DIR = './cropped_train_images', './cropped_test_images'

In [42]:
def get_train_file_path(id):
    """Return the path of training image ``id`` (an image file name) under TRAIN_DIR."""
    return "{}/{}".format(TRAIN_DIR, id)

## Read the Data

In [43]:
# Load the training metadata and attach the full path of each (cropped) image.
df = (
    pd.read_csv(f"{ROOT_DIR}/train.csv")
      .assign(file_path=lambda frame: frame['image'].apply(get_train_file_path))
)
df.head()

Unnamed: 0,image,species,individual_id,file_path
0,00021adfb725ed.jpg,melon_headed_whale,cadddb1636b9,./cropped_train_images/00021adfb725ed.jpg
1,000562241d384d.jpg,humpback_whale,1a71fbb72250,./cropped_train_images/000562241d384d.jpg
2,0007c33415ce37.jpg,false_killer_whale,60008f293a2b,./cropped_train_images/0007c33415ce37.jpg
3,0007d9bca26a99.jpg,bottlenose_dolphin,4b00fe572063,./cropped_train_images/0007d9bca26a99.jpg
4,00087baf5cef7a.jpg,humpback_whale,8e5253662392,./cropped_train_images/00087baf5cef7a.jpg


In [44]:
# Map raw individual_id values to contiguous integer labels for the ArcFace head
# and persist the fitted encoder to le.pkl.
# Guard: re-running this cell on already-encoded ids would fit the encoder on the
# integer labels and overwrite le.pkl with an encoder that no longer knows the raw ids.
assert not pd.api.types.is_integer_dtype(df['individual_id']), \
    "individual_id is already encoded - re-run the 'Read the Data' cell first"
encoder = LabelEncoder()
df['individual_id'] = encoder.fit_transform(df['individual_id'])
# The ArcFace head is sized from CONFIG, so every encoded label must fit inside it.
assert len(encoder.classes_) <= CONFIG['num_classes'], (
    f"CONFIG['num_classes']={CONFIG['num_classes']} but found "
    f"{len(encoder.classes_)} individuals"
)

with open("le.pkl", "wb") as fp:
    joblib.dump(encoder, fp)

## Create Folds


In [45]:
# Assign every row to one of CONFIG['n_fold'] validation folds, stratified on individual_id.
# NOTE(review): StratifiedKFold warns when an individual has fewer rows than n_splits;
# such warnings are hidden by the global warnings filter set earlier.
# val_ holds positional indices, so this relies on df keeping its default RangeIndex.
skf = StratifiedKFold(n_splits=CONFIG['n_fold'])
fold_splits = skf.split(X=df, y=df.individual_id)
for fold, (_, val_) in enumerate(fold_splits):
    df.loc[val_, "kfold"] = fold

## Dataset Class

In [46]:
from HappyWhaleDataset import HappyWhaleDataset

In [47]:
## Augmentations

In [48]:
data_transforms = dict(
    # Training: resize, random geometric and colour jitter, then ImageNet normalisation.
    train=A.Compose(
        [
            A.Resize(CONFIG['img_size'], CONFIG['img_size']),
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.15,
                               rotate_limit=60, p=0.5),
            # NOTE(review): albumentations' defaults for these limits are around +-20/30,
            # so +-0.2 looks like an almost invisible shift -- confirm the intended scale.
            A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2,
                                 val_shift_limit=0.2, p=0.5),
            A.RandomBrightnessContrast(brightness_limit=(-0.1, 0.1),
                                       contrast_limit=(-0.1, 0.1), p=0.5),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
                        max_pixel_value=255.0, p=1.0),
            ToTensorV2(),
        ],
        p=1.0,
    ),
    # Validation: deterministic resize and normalisation only.
    valid=A.Compose(
        [
            A.Resize(CONFIG['img_size'], CONFIG['img_size']),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
                        max_pixel_value=255.0, p=1.0),
            ToTensorV2(),
        ],
        p=1.0,
    ),
)

In [49]:
## GeM Pooling

In [50]:
class GeM(nn.Module):
    """Generalised-mean (GeM) pooling with a learnable exponent ``p``.

    Computes ``mean(clamp(x, eps) ** p) ** (1 / p)`` over the last two
    dimensions, which are kept as size-1 axes.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        # Learnable exponent: p=1 is average pooling, larger p moves towards max pooling.
        self.p = nn.Parameter(p * torch.ones(1))
        self.eps = eps

    def forward(self, x):
        return self.gem(x, p=self.p, eps=self.eps)

    def gem(self, x, p=3, eps=1e-6):
        # The pooling window spans the whole spatial extent -> one value per channel.
        window = (x.size(-2), x.size(-1))
        powered = x.clamp(min=eps).pow(p)
        return F.avg_pool2d(powered, window).pow(1.0 / p)

    def __repr__(self):
        p_value = self.p.data.tolist()[0]
        return f"{type(self).__name__}(p={p_value:.4f}, eps={self.eps})"

In [51]:
## ArcFace

In [52]:
class ArcMarginProduct(nn.Module):
    r"""ArcFace head: additive angular margin logits ``s * cos(theta + m)``.

    The cosine between each L2-normalised input and each L2-normalised row of
    ``weight`` (one row per class) is computed. The target class receives the
    margin-penalised ``cos(theta + m)`` while every other class keeps
    ``cos(theta)``; the result is multiplied by ``s``.

    Args:
        in_features: size of each input sample (embedding dimension)
        out_features: size of each output sample (number of classes)
        s: norm of input feature (logit scale)
        m: additive angular margin, in radians
        easy_margin: if truthy, apply the margin only where cos(theta) > 0
        ls_eps: label-smoothing epsilon applied to the one-hot target mask
    """
    def __init__(self, in_features, out_features, s=30.0,
                 m=0.50, easy_margin=False, ls_eps=0.0):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        # One weight row per class; rows are normalised in forward().
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # For theta >= pi - m, theta + m passes pi where cos stops decreasing, so the
        # linear fallback cos(theta) - mm is used below this cosine threshold.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Return scaled ArcFace logits of shape (batch, out_features).

        Args:
            input: embeddings, assumed (batch, in_features).
            label: integer class index per sample; flattened to (batch, 1).
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Clamp 1 - cos^2 before the sqrt: rounding can push |cos| slightly above 1
        # (NaN in the forward pass), and a small floor keeps the sqrt gradient finite
        # when an embedding is exactly aligned with a class weight.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(min=1e-7))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------
        # zeros_like follows cosine's device and dtype instead of a global CONFIG device.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # Target class gets phi, every other class keeps the plain cosine.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

In [53]:
## Create Model

In [54]:
class HappyWhaleModel(nn.Module):
    """timm backbone -> GeM pooling -> linear embedding -> ArcFace head.

    ``forward`` returns ArcFace logits (used with the cross-entropy criterion);
    ``extract`` returns the embedding alone, without the ArcFace head.
    """
    def __init__(self, model_name, embedding_size, pretrained=True):
        super(HappyWhaleModel, self).__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)
        # Replace timm's pooling and classifier with identities so the backbone
        # returns its feature map, which is pooled by GeM instead.
        # NOTE(review): `.classifier` exists on EfficientNet-style timm models; other
        # model families name their head differently.
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Identity()
        self.model.global_pool = nn.Identity()
        self.pooling = GeM()
        self.embedding = nn.Linear(in_features, embedding_size)
        self.fc = ArcMarginProduct(embedding_size,
                                   CONFIG["num_classes"],
                                   s=CONFIG["s"],
                                   m=CONFIG["m"],
                                   easy_margin=CONFIG["easy_margin"],
                                   ls_eps=CONFIG["ls_eps"])

    def forward(self, images, labels):
        """Return ArcFace logits for ``images`` given their integer ``labels``."""
        return self.fc(self.extract(images), labels)

    def extract(self, images):
        """Return the embedding (after GeM pooling and the linear projection)."""
        features = self.model(images)
        pooled_features = self.pooling(features).flatten(1)
        return self.embedding(pooled_features)

    
# Build the model and move it to the configured device (Module.to returns the module itself).
model = HappyWhaleModel(CONFIG['model_name'], CONFIG['embedding_size']).to(CONFIG['device'])

In [55]:
## Loss Function

In [56]:
def criterion(outputs, labels):
    """Mean cross-entropy between ArcFace logits ``outputs`` and integer ``labels``."""
    return F.cross_entropy(outputs, labels)

In [57]:
## Training Function

In [58]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    """Run one training pass over ``dataloader`` and return the mean loss.

    Gradients are accumulated over ``CONFIG['n_accumulate']`` batches before each
    optimizer step; ``scheduler`` (if not None) is stepped after every optimizer step.

    Args:
        model: model whose ``forward(images, labels)`` returns logits.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler or None.
        dataloader: yields dicts with 'image' and 'label' tensors.
        device: device the batches are moved to.
        epoch: epoch number, shown in the progress bar only.

    Returns:
        Sample-weighted mean (unscaled) training loss; NaN if the loader is empty.
    """
    model.train()
    n_accumulate = CONFIG['n_accumulate']
    # Start from clean gradients so a partial accumulation left over from an
    # earlier call is not folded into this epoch's first update.
    optimizer.zero_grad()

    dataset_size = 0
    running_loss = 0.0
    epoch_loss = float('nan')

    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:
        images = data['image'].to(device, dtype=torch.float)
        labels = data['label'].to(device, dtype=torch.long)

        batch_size = images.size(0)

        outputs = model(images, labels)
        loss = criterion(outputs, labels)
        # Scale only the backward pass for accumulation; the unscaled loss is what
        # gets logged, so the reported value does not shrink when n_accumulate > 1.
        (loss / n_accumulate).backward()

        if (step + 1) % n_accumulate == 0:
            optimizer.step()

            # zero the parameter gradients
            optimizer.zero_grad()

            if scheduler is not None:
                scheduler.step()

        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size

        bar.set_postfix(Epoch=epoch, Train_Loss=epoch_loss,
                        LR=optimizer.param_groups[0]['lr'])
    gc.collect()

    return epoch_loss

In [59]:
## Validation Function

In [60]:
@torch.inference_mode()
def valid_one_epoch(model, dataloader, device, epoch, optimizer=None):
    """Evaluate ``model`` on ``dataloader`` and return the mean loss.

    Args:
        model: model whose ``forward(images, labels)`` returns logits.
        dataloader: yields dicts with 'image' and 'label' tensors.
        device: device the batches are moved to.
        epoch: epoch number, shown in the progress bar only.
        optimizer: optional; when given, its current learning rate is shown in the
            progress bar. The function never updates it.

    Returns:
        Sample-weighted mean validation loss; NaN if the loader is empty.
    """
    model.eval()

    dataset_size = 0
    running_loss = 0.0
    epoch_loss = float('nan')

    bar = tqdm(dataloader, total=len(dataloader))
    for data in bar:
        images = data['image'].to(device, dtype=torch.float)
        labels = data['label'].to(device, dtype=torch.long)

        batch_size = images.size(0)

        outputs = model(images, labels)
        loss = criterion(outputs, labels)

        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size

        postfix = dict(Epoch=epoch, Valid_Loss=epoch_loss)
        if optimizer is not None:
            postfix['LR'] = optimizer.param_groups[0]['lr']
        bar.set_postfix(**postfix)

    gc.collect()

    return epoch_loss

In [61]:
## Run Training

In [62]:
def run_training(model, optimizer, scheduler, device, num_epochs):
    """Train and validate for ``num_epochs`` epochs, keeping the best weights.

    Uses the notebook-level ``train_loader`` / ``valid_loader`` and logs to the
    active W&B run. Whenever the validation loss improves (or ties), the weights
    are saved to ``Loss{loss:.4f}_epoch{epoch}.bin``. At the end the best weights
    are loaded back into ``model``.

    Args:
        model: model to train (updated in place).
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler or None; stepped inside ``train_one_epoch``.
        device: device the batches are moved to.
        num_epochs: number of epochs to run.

    Returns:
        (model, history): the model with the best weights loaded, and a dict mapping
        'Train Loss' / 'Valid Loss' to per-epoch values.
    """
    # To automatically log gradients
    wandb.watch(model, log_freq=100)

    if torch.cuda.is_available():
        print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))

    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_epoch_loss = np.inf
    history = defaultdict(list)

    for epoch in range(1, num_epochs + 1):
        gc.collect()
        train_epoch_loss = train_one_epoch(model, optimizer, scheduler,
                                           dataloader=train_loader,
                                           device=device, epoch=epoch)

        val_epoch_loss = valid_one_epoch(model, valid_loader, device=device,
                                         epoch=epoch)

        history['Train Loss'].append(train_epoch_loss)
        history['Valid Loss'].append(val_epoch_loss)

        # Log both metrics in one call so they share the same W&B step.
        wandb.log({"Train Loss": train_epoch_loss, "Valid Loss": val_epoch_loss})

        # deep copy the model
        if val_epoch_loss <= best_epoch_loss:
            print(f"{b_}Validation Loss Improved ({best_epoch_loss} ---> {val_epoch_loss})")
            best_epoch_loss = val_epoch_loss
            wandb.run.summary["Best Loss"] = best_epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            PATH = "Loss{:.4f}_epoch{:.0f}.bin".format(best_epoch_loss, epoch)
            # Save a model file to the current directory
            torch.save(model.state_dict(), PATH)
            print(f"Model Saved{sr_}")

        print()

    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best Loss: {:.4f}".format(best_epoch_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model, history

In [63]:
def fetch_scheduler(optimizer):
    """Build the LR scheduler named by ``CONFIG['scheduler']`` for ``optimizer``.

    Supported values: 'CosineAnnealingLR' (uses CONFIG['T_max']),
    'CosineAnnealingWarmRestarts' (uses CONFIG['T_0']) and None (no scheduler).
    Both cosine schedulers use CONFIG['min_lr'] as ``eta_min``.

    Raises:
        ValueError: for any other scheduler name.
        KeyError: if a required CONFIG entry is missing.
    """
    name = CONFIG['scheduler']
    if name is None:
        return None
    if name == 'CosineAnnealingLR':
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['T_max'],
                                              eta_min=CONFIG['min_lr'])
    if name == 'CosineAnnealingWarmRestarts':
        # NOTE: CONFIG must define 'T_0' before this scheduler is selected.
        return lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CONFIG['T_0'],
                                                        eta_min=CONFIG['min_lr'])
    raise ValueError(f"Unsupported scheduler: {name!r}")

In [64]:
def prepare_loaders(df, fold, num_workers=2):
    """Split ``df`` on its ``kfold`` column and build the train/valid DataLoaders.

    Rows whose ``kfold`` equals ``fold`` form the validation set; all others train.

    Args:
        df: dataframe with a ``kfold`` column, as produced by the fold-creation cell.
        fold: fold index held out for validation.
        num_workers: number of DataLoader worker processes for each loader.

    Returns:
        (train_loader, valid_loader)
    """
    df_train = df[df.kfold != fold].reset_index(drop=True)
    df_valid = df[df.kfold == fold].reset_index(drop=True)

    train_dataset = HappyWhaleDataset(df_train, transforms=data_transforms["train"])
    valid_dataset = HappyWhaleDataset(df_valid, transforms=data_transforms["valid"])

    # Pinned host memory only speeds up host->GPU copies; skip it on CPU-only machines.
    pin_memory = torch.cuda.is_available()

    # drop_last avoids a small trailing training batch.
    train_loader = DataLoader(train_dataset, batch_size=CONFIG['train_batch_size'],
                              num_workers=num_workers, shuffle=True,
                              pin_memory=pin_memory, drop_last=True)
    valid_loader = DataLoader(valid_dataset, batch_size=CONFIG['valid_batch_size'],
                              num_workers=num_workers, shuffle=False,
                              pin_memory=pin_memory)

    return train_loader, valid_loader


In [65]:
## Prepare Dataloaders

In [66]:
train_loader, valid_loader = prepare_loaders(df=df, fold=0)  # hold out fold 0 for validation

In [67]:
## Define Optimizer and Scheduler

In [68]:
# Adam (weight_decay is applied as an L2 penalty) plus the scheduler chosen in CONFIG.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
)
scheduler = fetch_scheduler(optimizer)

In [69]:
## Start Training

In [70]:
# Start the W&B run. `anony` comes from the login cell: None after a successful
# key login, "must" when login failed and the run has to be anonymous.
run = wandb.init(project='HappyWhale',
                 config=CONFIG,
                 job_type='Train',
                 tags=['arcface', 'gem-pooling', 'effnet-b0-ns', '448'],
                 anonymous=anony)

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

In [71]:
# Train for CONFIG['epochs'] epochs; `model` comes back with the best-validation weights.
model, history = run_training(
    model,
    optimizer,
    scheduler,
    device=CONFIG['device'],
    num_epochs=CONFIG['epochs'],
)

  5%| | 67/1275 [35:32<10:10:26, 30.32s/it, Epoch=1, LR=9.57e-5, Train_Loss=24.2wandb: Network error (ReadTimeout), entering retry loop.
 57%|▌| 726/1275 [6:06:37<4:32:29, 29.78s/it, Epoch=1, LR=4.31e-5, Train_Loss=23wandb: Network error (ReadTimeout), entering retry loop.
 88%|▉| 1116/1275 [9:21:50<1:19:27, 29.98s/it, Epoch=1, LR=8.74e-5, Train_Loss=2wandb: Network error (ReadTimeout), entering retry loop.
 92%|▉| 1179/1275 [9:52:47<47:23, 29.62s/it, Epoch=1, LR=7.19e-5, Train_Loss=22.wandb: Network error (ReadTimeout), entering retry loop.
100%|█| 1275/1275 [10:45:50<00:00, 30.39s/it, Epoch=1, LR=4.28e-5, Train_Loss=22
100%|███| 160/160 [40:22<00:00, 15.14s/it, Epoch=1, LR=4.28e-5, Valid_Loss=19.1]


[34mValidation Loss Improved (inf ---> 19.110681076176874)
Model Saved[0m



 10%| | 124/1275 [1:11:02<11:09:27, 34.90s/it, Epoch=2, LR=1.06e-5, Train_Loss=1wandb: Network error (ReadTimeout), entering retry loop.
 17%|▏| 220/1275 [2:06:00<10:06:48, 34.51s/it, Epoch=2, LR=1.02e-6, Train_Loss=1wandb: Network error (ReadTimeout), entering retry loop.
 59%|▌| 751/1275 [6:57:18<4:20:34, 29.84s/it, Epoch=2, LR=9.93e-5, Train_Loss=18wandb: Network error (ReadTimeout), entering retry loop.
100%|█| 1275/1275 [11:27:29<00:00, 32.35s/it, Epoch=2, LR=3.42e-6, Train_Loss=17
100%|███| 160/160 [41:39<00:00, 15.62s/it, Epoch=2, LR=3.42e-6, Valid_Loss=17.3]


[34mValidation Loss Improved (19.110681076176874 ---> 17.2952968441939)
Model Saved[0m



 75%|▊| 959/1275 [7:56:31<2:38:28, 30.09s/it, Epoch=3, LR=1.08e-6, Train_Loss=15wandb: Network error (ReadTimeout), entering retry loop.
 98%|▉| 1245/1275 [10:16:42<14:28, 28.96s/it, Epoch=3, LR=6.43e-5, Train_Loss=14wandb: Network error (ReadTimeout), entering retry loop.
100%|█| 1275/1275 [10:31:33<00:00, 29.72s/it, Epoch=3, LR=7.3e-5, Train_Loss=14.
 84%|███▍| 135/160 [29:22<05:26, 13.06s/it, Epoch=3, LR=7.3e-5, Valid_Loss=14.3]wandb: Network error (ReadTimeout), entering retry loop.
100%|██████| 160/160 [34:51<00:00, 13.07s/it, Epoch=3, LR=7.3e-5, Valid_Loss=15]


[34mValidation Loss Improved (17.2952968441939 ---> 14.995058596560476)
Model Saved[0m



 25%|▎| 321/1275 [2:36:20<7:41:58, 29.06s/it, Epoch=4, LR=8.06e-5, Train_Loss=14wandb: Network error (ReadTimeout), entering retry loop.
 72%|▋| 922/1275 [8:04:36<3:17:39, 33.60s/it, Epoch=4, LR=4.96e-5, Train_Loss=14wandb: Network error (ReadTimeout), entering retry loop.
100%|█| 1275/1275 [11:01:01<00:00, 31.11s/it, Epoch=4, LR=9.05e-5, Train_Loss=14
100%|███| 160/160 [36:05<00:00, 13.53s/it, Epoch=4, LR=9.05e-5, Valid_Loss=14.8]


[34mValidation Loss Improved (14.995058596560476 ---> 14.799370113054009)
Model Saved[0m

Training complete in 46h 19m 25s
Best Loss: 14.7994


In [72]:
wandb.finish()  # close the active W&B run (the one returned by wandb.init above)

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Train Loss,█▄▁▁
Valid Loss,█▅▁▁

0,1
Best Loss,14.79937
Train Loss,14.79783
Valid Loss,14.79937
