# Imports

In [None]:
# Use this for static type checking of notebook in a terminal
# pip install -U nbqa
# nbqa mypy your_notebook.ipynb

In [None]:
!export CUDA_LAUNCH_BLOCKING=1 # for tracing issues
!pip install timm
!pip install torchinfo
!pip install --upgrade torchmetrics

import gc
import os
import random
from pathlib import Path
from typing import Callable
from typing import Dict
from typing import Optional
from typing import Union

import numpy as np
import pandas as pd
import wandb
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True # fixes a weird issue

from sklearn.preprocessing import LabelEncoder
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torchmetrics
from torchinfo import summary
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint

import timm
from timm.data.transforms_factory import create_transform
from timm.optim import create_optimizer_v2

# Configuration

In [None]:
class CFG:
    """Notebook-wide hyperparameters and switches, read as class attributes."""

    # --- Reproducibility ---
    SEED = 42

    # --- Data ---
    N_FOLDS = 5
    NUM_WORKERS = 6  # worker processes per DataLoader

    # --- Backbone / embedding ---
    MODEL_NAME = "tf_efficientnet_b4"
    PRETRAINED = True
    IMAGE_SIZE = 380
    EMBEDDING_SIZE = 512
    DROPOUT = 0.0

    # --- ArcFace head: logit scale (S) and additive angular margin (M, radians) ---
    S = 10
    M = 0.1

    # --- Optimisation ---
    OPTIMIZER = "adamW"
    MODEL_PATH = "model.ckpt"  # checkpoint file name
    BATCH_SIZE = 32  # effective batch = BATCH_SIZE * ACCUMULATE_GRAD_BATCHES
    ACCUMULATE_GRAD_BATCHES = 1  # 1 => optimizer step after every batch
    NUM_EPOCHS = 30
    LR = 1e-3
    WEIGHT_DECAY = 1e-6

    # --- Debugging ---
    DEBUG = False  # shrink the run when debugging

In [None]:
# Seed every RNG the notebook relies on. pl.seed_everything overlaps with the explicit
# calls; they are kept deliberately as a belt-and-braces measure.
def fix_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU + all GPUs), request deterministic cuDNN, then seed Lightning."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    pl.seed_everything(seed)


fix_seed(CFG.SEED)

# Logging

In [None]:
# Track training statistics with Weights & Biases. The API key is read from Kaggle
# Secrets (secret name "wandb") so it never appears in the notebook itself.
from kaggle_secrets import UserSecretsClient

wandb.login(key=UserSecretsClient().get_secret("wandb"))

# Dataset

In [None]:
# Dataset PATHs
# Inputs: the enhanced training set (images + data.csv with fold assignments) and the
# original competition data (test images + sample submission).
BASE_PATH = '../input/happywhale-enhanced-luke'
DATA_PATH = '../input/happy-whale-and-dolphin'
CHECKPOINTS_PATH = '../input/hwmodelcheckpoint/tf_efficientnet_b4_380.ckpt'  # loaded when resuming training
TRAIN_DIR = f"{BASE_PATH}/train_images"
TRAIN_DIR2 = f"{BASE_PATH}/imgextra (1)"  # second image folder of the enhanced dataset
TEST_DIR = f"{DATA_PATH}/test_images"
TRAIN_CSV_PATH = f"{BASE_PATH}/data.csv"
# The sample submission ships with the competition data, which is also where the test
# images live and where test_df is read from.
TEST_CSV_PATH = f"{DATA_PATH}/sample_submission.csv"

# Outputs written by this notebook
OUTPUT_DIR = '/kaggle/working'
TRAIN_CSV_OUTPUT_PATH = f"{OUTPUT_DIR}/train.csv"
TEST_CSV_OUTPUT_PATH = f"{OUTPUT_DIR}/test.csv"
ENCODER_CLASSES_PATH = f"{OUTPUT_DIR}/encoder_classes.npy"
CHECKPOINTS_DIR = f"{OUTPUT_DIR}/checkpoints"
SUBMISSION_CSV_PATH = f"{OUTPUT_DIR}/submission.csv"

In [None]:
def get_image_path(id: str, dir: Union[str, Path]) -> str:
    """
    Join an image file name onto its directory with a forward slash.

    Args:
        id: image file name (e.g. the ``image`` column of the metadata CSV).
        dir: directory containing the image; callers pass plain strings, ``Path`` also works.

    Returns:
        ``f"{dir}/{id}"`` as a string.
    """
    # Parameter names shadow builtins but are kept: callers pass `dir=` by keyword.
    return f"{dir}/{id}"

In [None]:
# Load the training metadata and attach an image path to every row.
train_df = pd.read_csv(TRAIN_CSV_PATH)

# Number of distinct individuals (computed before encoding; dropna=False keeps this
# identical to len(unique())).
N_CLASSES = train_df["individual_id"].nunique(dropna=False)

train_df["image_path"] = train_df["image"].apply(get_image_path, dir=TRAIN_DIR)

# Integer encoding for individual ids so they can serve as class targets.
encoder = LabelEncoder()
train_df["individual_id"] = encoder.fit_transform(train_df["individual_id"])

# Some images live in the dataset's second folder instead of TRAIN_DIR; redirect them.
# `.at` is a scalar-only accessor (passing a boolean mask relies on a fallback that newer
# pandas rejects), so select the affected rows with `.loc` in one vectorised step.
extra_images = set(os.listdir(TRAIN_DIR2))
in_extra_dir = train_df["image"].isin(extra_images)
train_df.loc[in_extra_dir, "image_path"] = (
    train_df.loc[in_extra_dir, "image"].apply(get_image_path, dir=TRAIN_DIR2)
)

np.save(ENCODER_CLASSES_PATH, encoder.classes_)  # save label encoder classes for use in other notebooks

train_df.to_csv(TRAIN_CSV_OUTPUT_PATH, index=False)
train_df.head()

In [None]:
# Build test_df for making predictions. Targets are unknown, so a dummy id keeps the
# columns identical to train.csv (image, image_path, individual_id) for the Dataset class.
test_df = (
    pd.read_csv(f"{DATA_PATH}/sample_submission.csv")
    .drop(columns=["predictions"])
)
test_df["image_path"] = test_df["image"].apply(get_image_path, dir=TEST_DIR)
test_df["individual_id"] = 0  # dummy id

test_df.to_csv(TEST_CSV_OUTPUT_PATH, index=False)
test_df.head()

# Dataset Classes

In [None]:
class HappyWhaleDataset(Dataset):
    """
    Map-style dataset over an image-metadata DataFrame.

    Expects the columns ``image`` (file name), ``image_path`` (path to open) and
    ``individual_id`` (integer target). Each item is a dict with the image name, the
    RGB image (after ``transform`` if one is given) and the target as a ``torch.long`` tensor.
    """
    def __init__(self, df: pd.DataFrame, transform: Optional[Callable] = None):
        self.df = df
        self.transform = transform

        self.image_names = self.df["image"].values
        self.image_paths = self.df["image_path"].values
        self.targets = self.df["individual_id"].values

    def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
        image_name = self.image_names[index]
        image_path = self.image_paths[index]

        # Force 3-channel RGB: grayscale, RGBA or palette images would otherwise reach the
        # transform/backbone with a different channel count (the pretrained backbone
        # presumably expects 3). The `with` block closes the file PIL keeps open for lazy
        # loading; convert() reads the pixels before the file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        target = torch.tensor(self.targets[index], dtype=torch.long)

        return {"image_name": image_name, "image": image, "target": target}

    def __len__(self) -> int:
        return len(self.df)

In [None]:
class LitDataModule(pl.LightningDataModule):
    """
    Lightning data module providing the training, validation and test dataloaders.

    Rows of the training CSV whose ``folds`` column equals ``val_fold`` form the
    validation split; all remaining rows are used for training. Evaluation loaders
    (validation/test) use twice the training batch size and never shuffle.
    """
    def __init__(
        self,
        train_csv: str,
        test_csv: str,
        val_fold: float,
        image_size: int,
        batch_size: int,
        num_workers: int,
    ):
        super().__init__()

        # Makes every constructor argument available as self.hparams.<name>.
        self.save_hyperparameters()

        self.train_df = pd.read_csv(train_csv)
        self.test_df = pd.read_csv(test_csv)

        image_hw = (self.hparams.image_size, self.hparams.image_size)
        # Training transform. timm's create_transform defaults to is_training=False (no
        # augmentation), so the training pipeline has to request augmentation explicitly.
        self.transform_train = create_transform(
            input_size=image_hw,
            crop_pct=1.0,
            is_training=True,
        )
        # Evaluation transform: augmentation disabled.
        self.transform_eval = create_transform(
            input_size=image_hw,
            crop_pct=1.0,
            is_training=False,
        )

    def setup(self, stage: Optional[str] = None):
        """
        Build the train/val/test datasets.

        ``stage`` is accepted (and ignored) because Lightning invokes ``setup(stage=...)``;
        calling ``setup()`` with no argument also works.
        """
        is_val = self.train_df.folds == self.hparams.val_fold
        train_df = self.train_df[~is_val].reset_index(drop=True)
        val_df = self.train_df[is_val].reset_index(drop=True)

        self.train_dataset = HappyWhaleDataset(train_df, transform=self.transform_train)
        self.val_dataset = HappyWhaleDataset(val_df, transform=self.transform_eval)
        self.test_dataset = HappyWhaleDataset(self.test_df, transform=self.transform_eval)

    def train_dataloader(self) -> DataLoader:
        return self._dataloader(self.train_dataset, train=True)

    def val_dataloader(self) -> DataLoader:
        return self._dataloader(self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        return self._dataloader(self.test_dataset)

    def _dataloader(self, dataset: HappyWhaleDataset, train: bool = False) -> DataLoader:
        """Shared DataLoader factory; ``train`` toggles shuffling, drop_last and batch size."""
        # Evaluation keeps no activations for backprop, so it can afford a doubled batch.
        batch_size = self.hparams.batch_size if train else self.hparams.batch_size * 2

        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=train,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            drop_last=train,  # avoid a ragged final batch during training only
        )

# Loss Function

In [None]:
class SoftMax(nn.Module):
    """
    Plain linear classification head (no normalisation, no margin, no bias).

    A drop-in alternative to ``ArcFace``: ``forward`` takes the same
    ``(input, label, device)`` arguments, ignores ``label`` and ``device``, and returns
    the raw logits ``input @ W.T`` for use with a softmax cross-entropy loss.
    """

    def __init__(self, num_features: int, num_classes: int):
        super().__init__()

        self.num_features = num_features
        self.n_classes = num_classes
        # One Xavier-initialised weight row per class.
        self.W = nn.Parameter(torch.empty(num_classes, num_features, dtype=torch.float32))
        nn.init.xavier_uniform_(self.W)

    def forward(self, input: torch.Tensor, label: torch.Tensor, device: str = "cuda") -> torch.Tensor:
        # `label` and `device` exist only to mirror ArcFace.forward's signature.
        return F.linear(input, self.W)

In [None]:
class ArcFace(nn.Module):
    """
    Additive angular margin (ArcFace) classification head.

    Embeddings and class weights are L2-normalised, so ``F.linear`` yields cosine
    similarities. With ``label`` given, the margin ``m`` is added to the angle of each
    sample's true class and every logit is multiplied by the scale ``s``; the result is
    meant for cross-entropy. With ``label=None`` the unscaled cosine similarities are
    returned unchanged.

    Args:
        num_features: embedding dimension.
        num_classes: number of classes.
        s: logit scale applied when labels are given.
        m: additive angular margin (added to an angle in radians).
    """

    def __init__(self, num_features: int, num_classes: int, s: float, m: float):
        super().__init__()

        self.num_features = num_features
        self.n_classes = num_classes
        self.s = s
        self.m = m
        self.W = nn.Parameter(torch.empty(num_classes, num_features, dtype=torch.float32))
        nn.init.xavier_uniform_(self.W)

    def forward(self, input: torch.Tensor, label: torch.Tensor, device: str = "cuda") -> torch.Tensor:
        # `device` is accepted for API symmetry with callers but is not used.
        cosine = F.linear(F.normalize(input), F.normalize(self.W))
        if label is None:
            return cosine

        # Keep acos's argument strictly inside (-1, 1): rounding can push cosines slightly
        # past +-1 (acos -> NaN), and at exactly +-1 acos has an infinite derivative.
        angle = torch.acos(cosine.clamp(-1.0 + 1e-7, 1.0 - 1e-7))
        margin_cosine = torch.cos(angle + self.m)

        # 1 at each sample's true class, 0 elsewhere; only true-class logits get the margin.
        is_target = torch.zeros_like(cosine).scatter_(1, label.view(-1, 1).long(), 1)
        blended = cosine * (1 - is_target) + margin_cosine * is_target

        return blended * self.s

# Setup Model

In [None]:
# summary of model architecture
# model = timm.create_model(model_name='tf_efficientnet_b4')
# print(summary(model))

In [None]:
class LitModule(pl.LightningModule):
    """
    Lightning module: timm backbone -> linear embedding layer -> ArcFace head.

    ``forward`` returns embeddings. In training/validation the ArcFace head turns them
    into margin-adjusted, scaled logits for cross-entropy; ``predict_step`` calls the head
    without labels and returns the top-5 (values, indices) per image.
    """
    def __init__(
        self,
        model_name: str,
        pretrained: bool,
        drop_rate: float,
        embedding_size: int,
        num_classes: int,
        arc_s: float,
        arc_m: float,
        optimizer: str,
        learning_rate: float,
        weight_decay: float,
        len_train_dl: int,
        epochs: int
    ):
        super().__init__()

        # Records all constructor arguments in self.hparams (and in saved checkpoints).
        self.save_hyperparameters()

        # Pretrained feature extractor from timm.
        self.model = timm.create_model(model_name, pretrained=pretrained, drop_rate=drop_rate)

        # Projection to the embedding space. Its input width is read from the backbone's
        # classifier *before* that classifier is removed on the next line.
        self.embedding = nn.Linear(self.model.get_classifier().in_features, embedding_size)
        # Remove the classifier so the backbone returns average-pooled features.
        self.model.reset_classifier(num_classes=0, global_pool="avg")

        self.arc = ArcFace(
            num_features=embedding_size,
            num_classes=num_classes,
            s=arc_s,
            m=arc_m,
        )
        # For a plain linear-head baseline, build SoftMax(num_features=embedding_size,
        # num_classes=num_classes) and use it in place of self.arc in the steps.

        self.loss_fn = F.cross_entropy
        # One metric instance per stage so training and validation state never mix.
        # NOTE(review): F1Score(num_classes=...) uses torchmetrics' default averaging; if that
        # is micro-averaging it coincides with accuracy for single-label data — confirm intent.
        self.train_acc = torchmetrics.Accuracy()
        self.train_top_5_acc = torchmetrics.Accuracy(top_k=5)
        self.train_f1 = torchmetrics.F1Score(num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy()
        self.val_top_5_acc = torchmetrics.Accuracy(top_k=5)
        self.val_f1 = torchmetrics.F1Score(num_classes=num_classes)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to embeddings of width ``embedding_size``."""
        features = self.model(images)
        return self.embedding(features)

    def configure_optimizers(self):
        """Optimizer chosen by name through timm, plus a OneCycle LR schedule stepped per optimizer step."""
        optimizer = create_optimizer_v2(
            self.parameters(),
            opt=self.hparams.optimizer,
            lr=self.hparams.learning_rate,
            weight_decay=self.hparams.weight_decay,
        )

        # OneCycleLR must know the total number of steps up front: len_train_dl * epochs.
        # NOTE(review): with accumulate_grad_batches > 1 there are fewer optimizer steps than
        # batches, so the cycle would not complete — divide len_train_dl accordingly.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            self.hparams.learning_rate,
            steps_per_epoch=self.hparams.len_train_dl,
            epochs=self.hparams.epochs,
        )

        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def _shared_step(self, batch: Dict[str, torch.Tensor], stage: str) -> torch.Tensor:
        """Forward + ArcFace + loss, then update and log the ``{stage}_*`` metrics."""
        images, targets = batch["image"], batch["target"]

        embeddings = self(images)
        outputs = self.arc(embeddings, targets, self.device)
        loss = self.loss_fn(outputs, targets)

        # Log with the actual batch size: validation batches are twice the training batch
        # size and the final validation batch may be partial, so a fixed number would
        # mis-weight the epoch-level averages.
        batch_size = images.size(0)
        metric_names = ("acc", "top_5_acc", "f1")
        for name in metric_names:
            getattr(self, f"{stage}_{name}")(outputs, targets)

        self.log(f"{stage}_loss", loss, batch_size=batch_size)
        for name in metric_names:
            self.log(f"{stage}_{name}", getattr(self, f"{stage}_{name}"), batch_size=batch_size)

        return loss

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        return self._shared_step(batch, "train")

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        return self._shared_step(batch, "val")

    def predict_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Return ``torch.topk`` (values, indices) of the 5 highest cosine similarities per image."""
        images = batch["image"]

        embeddings = self(images)
        # Without labels the ArcFace head applies neither margin nor scale: it returns the
        # cosine similarity between each embedding and every class weight.
        similarities = self.arc(embeddings, label=None, device=self.device)

        return torch.topk(similarities, 5)

# Setup Training

In [None]:
def train(
    train_csv: str = str(TRAIN_CSV_OUTPUT_PATH),
    test_csv: str = str(TEST_CSV_OUTPUT_PATH),
    val_fold: float = 2.0,
    image_size: int = 380,
    batch_size: int = 32,
    num_workers: int = 4,
    model_name: str = "tf_efficientnet_b4",
    log_name: str = "EffnetA",
    pretrained: bool = True,
    drop_rate: float = 0.0,
    embedding_size: int = 512,
    num_classes: int = 15587,
    arc_s: float = 30.0,
    arc_m: float = 0.5,
    optimizer: str = "adam",
    learning_rate: float = 3e-4,
    weight_decay: float = 1e-6,
    checkpoints_dir: str = str(CHECKPOINTS_DIR),
    accumulate_grad_batches: int = 1,
    auto_lr_find: bool = False,
    auto_scale_batch_size: bool = False,
    fast_dev_run: bool = False,
    gpus: int = 1,
    max_epochs: int = 10,
    precision: int = 16,
    stochastic_weight_avg: bool = True,
    continued_training: bool = False,
    checkpoint_monitor: str = "val_loss",
):
    """
    Set up and run a training session.

    Builds the data module, the Lightning module (fresh, or restored from
    ``CHECKPOINTS_PATH`` when ``continued_training`` is True), a W&B logger and a
    ``pl.Trainer``, then fits the model. With ``CFG.DEBUG`` set, the run is shortened to
    2 epochs on 10% of the train/val batches, and the LR schedule is sized to match.

    The best checkpoint according to ``checkpoint_monitor`` (lower is better) is written
    to ``checkpoints_dir`` as ``f"{model_name}_{image_size}"``. ``"val_loss"`` is
    accumulated over the whole validation epoch; ``"train_loss"`` is logged per step, so
    monitoring it picks checkpoints by whichever batch happened to come last.
    """
    torch.cuda.empty_cache()
    gc.collect()

    # Anomaly detection adds overhead; make sure it is switched off.
    torch.autograd.set_detect_anomaly(False)

    # Single source of truth for run length, shared by the trainer and the LR schedule.
    num_epochs = 2 if CFG.DEBUG else max_epochs
    batch_fraction = 0.1 if CFG.DEBUG else 1.0

    wandb_logger = WandbLogger(project="HappyWhale", name=log_name)

    datamodule = LitDataModule(
        train_csv=train_csv,
        test_csv=test_csv,
        val_fold=val_fold,
        image_size=image_size,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    # Datasets are built up front because the LR scheduler needs the number of training batches.
    datamodule.setup()
    len_train_dl = len(datamodule.train_dataloader())

    if continued_training:
        module = LitModule.load_from_checkpoint(CHECKPOINTS_PATH)
    else:
        module = LitModule(
            model_name=model_name,
            pretrained=pretrained,
            drop_rate=drop_rate,
            embedding_size=embedding_size,
            num_classes=num_classes,
            arc_s=arc_s,
            arc_m=arc_m,
            optimizer=optimizer,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            len_train_dl=len_train_dl,
            epochs=num_epochs,
        )

    model_checkpoint = ModelCheckpoint(
        checkpoints_dir,
        filename=f"{model_name}_{image_size}",
        monitor=checkpoint_monitor,
        mode="min",
    )

    trainer = pl.Trainer(
        accumulate_grad_batches=accumulate_grad_batches,
        auto_lr_find=auto_lr_find,
        auto_scale_batch_size=auto_scale_batch_size,
        # NOTE(review): cuDNN autotuning (benchmark=True) may select different kernels from
        # run to run, which works against deterministic=True; set benchmark=False if
        # bit-exact reproducibility matters more than speed.
        benchmark=True,
        callbacks=[model_checkpoint],
        deterministic=True,
        fast_dev_run=fast_dev_run,
        gpus=gpus,
        max_epochs=num_epochs,
        precision=precision,
        stochastic_weight_avg=stochastic_weight_avg,
        limit_train_batches=batch_fraction,
        limit_val_batches=batch_fraction,
        logger=wandb_logger,
    )

    try:
        # Runs the tuners enabled above (auto_lr_find / auto_scale_batch_size).
        trainer.tune(module, datamodule=datamodule)

        # actual fitting
        trainer.fit(module, datamodule=datamodule)
        wandb_logger.finalize("success")
    finally:
        # Release GPU/host memory even if training fails.
        torch.cuda.empty_cache()
        gc.collect()

# Run Training

In [None]:
# Run training with the hyperparameters from CFG. Every CFG knob that train() accepts is
# passed explicitly so that editing CFG actually changes the run.
train(model_name=CFG.MODEL_NAME,
      image_size=CFG.IMAGE_SIZE,
      batch_size=CFG.BATCH_SIZE,
      num_workers=CFG.NUM_WORKERS,
      embedding_size=CFG.EMBEDDING_SIZE,
      arc_s=CFG.S,
      arc_m=CFG.M,
      num_classes=N_CLASSES,
      learning_rate=CFG.LR,
      weight_decay=CFG.WEIGHT_DECAY,
      optimizer=CFG.OPTIMIZER,
      accumulate_grad_batches=CFG.ACCUMULATE_GRAD_BATCHES,
      max_epochs=CFG.NUM_EPOCHS,
      pretrained=CFG.PRETRAINED,
      drop_rate=CFG.DROPOUT,
      # Run name records the optimizer actually used (CFG.OPTIMIZER is "adamW").
      log_name="Train-CE-LR001-S10-M1-AdamW",
     )

wandb.finish()

In [None]:
# download the checkpoints if it's a live notebook
# from shutil import make_archive
# directory = "./checkpoints"
# make_archive('chkpnts', 'zip', directory)

# from IPython.display import FileLink
# FileLink(r'chkpnts.zip')