# 🐋🐬 PyTorch ⚡ BackFin ConvNeXt ArcFace

Let's train [`timm`](https://github.com/rwightman/pytorch-image-models) models with [PyTorch Lightning](https://www.pytorchlightning.ai/)!

## Sources
- [[Pytorch] ArcFace + GeM Pooling Starter](https://www.kaggle.com/debarshichanda/pytorch-arcface-gem-pooling-starter)
- [FAISS Pytorch Inference](https://www.kaggle.com/debarshichanda/faiss-pytorch-inference)
- [backfintfrecords](https://www.kaggle.com/datasets/jpbremer/backfintfrecords) dataset

## Nice Lightning `Trainer` Flags to try
- Run a learning rate finder: `auto_lr_find=True` (takes effect when you call `trainer.tune(...)`)
- Automatically find the largest batch size that fits into memory: `auto_scale_batch_size=True` (also applied by `trainer.tune(...)`)
- Quickly check whether everything runs fine: `fast_dev_run=True`
- Train on multiple GPUs: `gpus=2` (with multiple GPUs, also set `strategy="ddp"`; older Lightning releases used `accelerator="ddp"`)
- Train with half precision: `precision=16`
- Use Stochastic Weight Averaging: add the `StochasticWeightAveraging` callback (older Lightning releases: `stochastic_weight_avg=True`)

## Public LB scores
- V02: 0.378 (`image_size=256`, `"tf_efficientnet_b2"`, `batch_size=128`, `learning_rate=3e-4`)
- V10: 0.439 (`image_size=512`, `"tf_efficientnet_b0"`, `batch_size=64`, `learning_rate=3e-4`)
- V12: 0.498 (`image_size=512`, `"tf_efficientnet_b2"`, `batch_size=32`, `learning_rate=3e-4`)
- V14: 0.567 (`image_size=512`, `"tf_efficientnet_b4"`, `batch_size=16`, `learning_rate=3e-4`)
- V21: 0.656 (backfin cropped data, `image_size=384`, `"tf_efficientnet_b4"`, `batch_size=32`, `learning_rate=3e-4`, `scheduler=OneCycleLR`)
- V22: 0.701 (backfin cropped data, `image_size=384`, `"convnext_small"`, `batch_size=32`, `learning_rate=3e-4`, `scheduler=OneCycleLR`)

# Installations (timm + FAISS)

In [None]:
!pip install -q timm faiss-gpu
!pip install -q happywhale -U -f ../input/-pytorch-lightning-happywhale-pkg
!pip install -q Pillow==9.0.0
!pip uninstall -y torchtext
!pip list | grep torch
!pip list | grep lightning
!pip list | grep happywhale

# Imports

In [None]:
import math
from typing import Callable, Dict, Optional, Tuple
from pathlib import Path

import faiss
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from timm.data.transforms_factory import create_transform
from timm.optim import create_optimizer_v2
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import normalize, LabelEncoder

# Paths & Settings

In [None]:
# Kaggle layout: read-only inputs live in ../input, writable outputs in /kaggle/working.
INPUT_DIR = Path("../input")
OUTPUT_DIR = Path("/kaggle/working")

# Input data (backfin-cropped images plus the competition CSVs).
DATA_ROOT_DIR = INPUT_DIR.joinpath("convert-backfintfrecords", "happy-whale-and-dolphin-backfin")
TRAIN_DIR = DATA_ROOT_DIR.joinpath("train_images")
TEST_DIR = DATA_ROOT_DIR.joinpath("test_images")
TRAIN_CSV_PATH = DATA_ROOT_DIR.joinpath("train.csv")
SAMPLE_SUBMISSION_CSV_PATH = DATA_ROOT_DIR.joinpath("sample_submission.csv")

# Number of stratified cross-validation folds written to the `kfold` column.
N_SPLITS = 5

# Artifacts produced by this notebook.
ENCODER_CLASSES_PATH = OUTPUT_DIR.joinpath("encoder_classes.npy")
TEST_CSV_PATH = OUTPUT_DIR.joinpath("test.csv")
TRAIN_CSV_ENCODED_FOLDED_PATH = OUTPUT_DIR.joinpath("train_encoded_folded.csv")
CHECKPOINTS_DIR = OUTPUT_DIR.joinpath("checkpoints")
SUBMISSION_CSV_PATH = OUTPUT_DIR.joinpath("submission.csv")

# Prepare DataFrames

In [None]:
def get_image_path(id: str, dir: Path) -> str:
    """Return the path of image file ``id`` inside directory ``dir`` as a string.

    ``dir`` may be a ``pathlib.Path`` or a plain string. The parameter names shadow
    builtins but are kept because callers pass ``dir=`` by keyword.
    """
    # Path(dir) makes plain-string directories work too; `str / str` would raise TypeError.
    return str(Path(dir) / id)

## Train DataFrame

In [None]:
train_df = pd.read_csv(TRAIN_CSV_PATH)
train_df["image_path"] = train_df["image"].apply(get_image_path, dir=TRAIN_DIR)

# Replace string ids with integer labels; persist the classes so predictions can be decoded.
encoder = LabelEncoder()
train_df["individual_id"] = encoder.fit_transform(train_df["individual_id"])
np.save(ENCODER_CLASSES_PATH, encoder.classes_)

# Stratify folds by individual id; every row lands in exactly one validation fold.
skf = StratifiedKFold(n_splits=N_SPLITS)
fold_of_row = np.zeros(len(train_df), dtype="int8")
for fold, (_, val_) in enumerate(skf.split(X=train_df, y=train_df.individual_id)):
    fold_of_row[val_] = fold
train_df["kfold"] = fold_of_row

train_df.to_csv(TRAIN_CSV_ENCODED_FOLDED_PATH, index=False)
display(train_df.head())

## Test DataFrame

In [None]:
# The sample submission lists every test image, so it doubles as the test manifest.
test_df = pd.read_csv(SAMPLE_SUBMISSION_CSV_PATH).drop(columns=["predictions"])
test_df["image_path"] = test_df["image"].apply(get_image_path, dir=TEST_DIR)

# Placeholder label: the dataset expects an `individual_id`, but test ids are unknown.
test_df["individual_id"] = 0
test_df.to_csv(TEST_CSV_PATH, index=False)
test_df.head()

# Lightning DataModule

In [None]:
from happywhale.utils.dataset import HappyWhaleDataset

class LitDataModule(pl.LightningDataModule):
    """DataModule serving HappyWhaleDataset splits for training, validation and test.

    Validation rows are chosen either by fold (``val_fold`` is compared against the
    ``kfold`` column of the train CSV) or, when ``val_fold`` is None, as the last
    ``val_split`` fraction of the train CSV's rows.

    Args:
        train_csv_encoded_folded: Train CSV with encoded ``individual_id`` and ``kfold`` columns.
        test_csv: Test CSV (carrying a placeholder ``individual_id`` column).
        image_size: Side length the timm transform resizes images to.
        batch_size: Batch size used by every dataloader.
        num_workers: Worker processes per dataloader.
        val_split: Fraction of rows held out for validation when ``val_fold`` is None.
        val_fold: Fold id held out for validation; takes precedence over ``val_split``.
    """

    def __init__(
        self,
        train_csv_encoded_folded: str,
        test_csv: str,
        image_size: int,
        batch_size: int,
        num_workers: int,
        val_split: float = 0.1,
        val_fold: Optional[float] = None,
    ):
        super().__init__()

        # Stores every __init__ argument under self.hparams.
        self.save_hyperparameters()

        self.train_df = pd.read_csv(train_csv_encoded_folded)
        self.test_df = pd.read_csv(test_csv)

        # One transform for all splits. is_training is not passed, so timm builds its
        # default non-training pipeline; crop_pct=1.0 keeps the whole resized image.
        self.transform = create_transform(
            input_size=(self.hparams.image_size, self.hparams.image_size),
            crop_pct=1.0,
        )
        # Number of distinct label values across train and (placeholder) test labels.
        self.num_classes = len(set(
            list(self.train_df["individual_id"].values) + list(self.test_df["individual_id"].values)
        ))

    def setup(self, stage: Optional[str] = None):
        """Create the datasets needed for ``stage`` (all of them when ``stage`` is None)."""
        # "validate" needs the val dataset too (e.g. trainer.validate()).
        if stage in ("fit", "validate", None):
            train_df, val_df = self._split_train_val()
            self.train_dataset = HappyWhaleDataset(train_df, transform=self.transform)
            self.val_dataset = HappyWhaleDataset(val_df, transform=self.transform)

        if stage == "test" or stage is None:
            self.test_dataset = HappyWhaleDataset(self.test_df, transform=self.transform)

    def _split_train_val(self) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Split ``self.train_df`` into (train, val) frames, each re-indexed from 0."""
        df = self.train_df
        if self.hparams.val_fold is not None:
            is_val = df.kfold == self.hparams.val_fold
            return df[~is_val].reset_index(drop=True), df[is_val].reset_index(drop=True)

        # Hold out the last `val_split` fraction of rows. Slice at an explicit position:
        # `df[:-n]` / `df[-n:]` would give an EMPTY train frame and the WHOLE frame as
        # validation when n == 0 (val_split == 0 or a very small CSV).
        nb_val = int(self.hparams.val_split * len(df))
        split_at = len(df) - nb_val
        return df.iloc[:split_at].reset_index(drop=True), df.iloc[split_at:].reset_index(drop=True)

    def train_dataloader(self) -> DataLoader:
        return self._dataloader(self.train_dataset, train=True)

    def val_dataloader(self) -> DataLoader:
        return self._dataloader(self.val_dataset)

    def test_dataloader(self) -> DataLoader:
        return self._dataloader(self.test_dataset)

    def _dataloader(self, dataset: HappyWhaleDataset, train: bool = False) -> DataLoader:
        """Build a DataLoader; only training loaders shuffle and drop the last partial batch."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=train,
            num_workers=self.hparams.num_workers,
            pin_memory=True,
            drop_last=train,
        )

# Lightning Module

In [None]:
from happywhale.utils.arc_margin_product import ArcMarginProduct

class LitModule(pl.LightningModule):
    """ArcFace metric-learning model: timm backbone, linear embedding layer, ArcMarginProduct head.

    ``forward`` returns the output of ``self.embedding`` (no normalization is applied here).
    The ArcFace head is used only in ``_step``, where its logits feed a cross-entropy loss.
    """

    def __init__(
        self,
        model_name: str,
        pretrained: bool,
        drop_rate: float,
        embedding_size: int,
        num_classes: int,
        arc_s: float,
        arc_m: float,
        arc_easy_margin: bool,
        arc_ls_eps: float,
        optimizer: str,
        learning_rate: float,
        weight_decay: float,
        len_train_dl: int,
        batch_size: int,
        epochs: int
    ):
        super().__init__()

        # Every constructor argument is recorded in self.hparams (and in checkpoints).
        self.save_hyperparameters()

        self.model = timm.create_model(model_name, pretrained=pretrained, drop_rate=drop_rate)
        # The classifier's input width must be read before the classifier is removed below.
        backbone_width = self.model.get_classifier().in_features
        self.embedding = nn.Linear(backbone_width, embedding_size)
        # Strip the classification head; keep average global pooling.
        self.model.reset_classifier(num_classes=0, global_pool="avg")

        self.arc = ArcMarginProduct(
            in_features=embedding_size,
            out_features=num_classes,
            s=arc_s,
            m=arc_m,
            easy_margin=arc_easy_margin,
            ls_eps=arc_ls_eps,
        )

        self.loss_fn = F.cross_entropy

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Map a batch of images to embedding vectors."""
        return self.embedding(self.model(images))

    def configure_optimizers(self):
        """timm optimizer chosen by ``hparams.optimizer`` plus a per-step OneCycleLR schedule."""
        hp = self.hparams
        opt = create_optimizer_v2(
            self.parameters(),
            opt=hp.optimizer,
            lr=hp.learning_rate,
            weight_decay=hp.weight_decay,
        )
        one_cycle = torch.optim.lr_scheduler.OneCycleLR(
            opt,
            max_lr=hp.learning_rate,
            epochs=hp.epochs,
            steps_per_epoch=hp.len_train_dl,
        )
        return [opt], [{"scheduler": one_cycle, "interval": "step"}]

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        return self._step(batch, "train")

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        return self._step(batch, "val")

    def _step(self, batch: Dict[str, torch.Tensor], step: str) -> torch.Tensor:
        """Shared train/val logic: ArcFace logits -> cross-entropy, logged as ``{step}_loss``."""
        targets = batch["target"]
        logits = self.arc(self(batch["image"]), targets, self.device)
        loss = self.loss_fn(logits, targets)
        self.log(f"{step}_loss", loss)
        return loss

# Training

In [None]:
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import StochasticWeightAveraging


def train(
    train_csv_encoded_folded: str = str(TRAIN_CSV_ENCODED_FOLDED_PATH),
    test_csv: str = str(TEST_CSV_PATH),
    val_split: float = 0.05,
    image_size: int = 256,
    batch_size: int = 64,
    num_workers: int = 4,
    model_name: str = "tf_efficientnet_b0",
    pretrained: bool = True,
    drop_rate: float = 0.0,
    embedding_size: int = 512,
    arc_s: float = 30.0,
    arc_m: float = 0.5,
    arc_easy_margin: bool = False,
    arc_ls_eps: float = 0.0,
    optimizer: str = "adamw",
    learning_rate: float = 3e-4,
    weight_decay: float = 1e-6,
    checkpoints_dir: str = str(CHECKPOINTS_DIR),
    accumulate_grad_batches: int = 1,
    auto_scale_batch_size: bool = True,
    gpus: int = 1,
    max_epochs: int = 10,
    precision: int = 16,
    val_fold: Optional[int] = None,
    use_swa: bool = False,
):
    """Train a LitModule with ArcFace loss and return ``(model, trainer)``.

    Validation uses the last ``val_split`` fraction of the train CSV, unless ``val_fold`` is
    given, in which case that fold is held out instead. ``infer`` tunes its distance
    threshold on a fold (``val_fold=0.0`` by default); pass the same fold here so that fold
    is not also used for training.

    The checkpoint with the best ``val_loss`` is saved in ``checkpoints_dir`` under the name
    ``{model_name}_{image_size}``. ``use_swa`` adds a ``StochasticWeightAveraging`` callback
    starting at 60% of training. ``auto_scale_batch_size`` is accepted for backward
    compatibility but is not used.
    """
    # pl.seed_everything(42)

    datamodule = LitDataModule(
        train_csv_encoded_folded=train_csv_encoded_folded,
        test_csv=test_csv,
        val_split=val_split,
        val_fold=val_fold,
        image_size=image_size,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    datamodule.setup()
    print(datamodule.num_classes)

    # OneCycleLR is stepped once per *optimizer* step ("interval": "step"). With gradient
    # accumulation that happens once every `accumulate_grad_batches` batches (ceil counts a
    # trailing partial group), so sizing the schedule by raw batches would leave the cycle
    # unfinished and end training at a high learning rate.
    steps_per_epoch = math.ceil(len(datamodule.train_dataloader()) / accumulate_grad_batches)

    model = LitModule(
        model_name=model_name,
        pretrained=pretrained,
        drop_rate=drop_rate,
        embedding_size=embedding_size,
        num_classes=datamodule.num_classes,
        arc_s=arc_s,
        arc_m=arc_m,
        arc_easy_margin=arc_easy_margin,
        arc_ls_eps=arc_ls_eps,
        optimizer=optimizer,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        len_train_dl=steps_per_epoch,
        batch_size=batch_size,
        epochs=max_epochs
    )

    model_checkpoint = ModelCheckpoint(
        checkpoints_dir,
        filename=f"{model_name}_{image_size}",
        monitor="val_loss",
    )
    callbacks = [model_checkpoint]
    if use_swa:
        callbacks.append(StochasticWeightAveraging(swa_epoch_start=0.6))

    logger = CSVLogger(save_dir='logs/')

    trainer = pl.Trainer(
        accumulate_grad_batches=accumulate_grad_batches,
        benchmark=True,
        logger=logger,
        callbacks=callbacks,
        gpus=gpus,
        max_epochs=max_epochs,
        precision=precision,
    )

    trainer.fit(model, datamodule=datamodule)

    return model, trainer

Candidate backbones and their ImageNet accuracy are listed in the [timm ImageNet results table](https://github.com/rwightman/pytorch-image-models/blob/master/results/results-imagenet.csv).

In [None]:
# Backbone, input resolution and batch size for this run.
MODEL_NAME = "convnext_tiny"
IMAGE_SIZE = 224
BATCH_SIZE = 48

model, trainer = train(
    model_name=MODEL_NAME,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sn

# Per-epoch view of everything the CSVLogger recorded during training.
metrics = (
    pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
    .drop(columns="step")
    .set_index("epoch")
)
# Hide columns that are entirely empty in the preview table.
display(metrics.dropna(axis=1, how="all").head())

g = sn.relplot(data=metrics, kind="line")
plt.gcf().set_size_inches(12, 4)
plt.grid()

# Inference

In [None]:
from happywhale.inference.infer import (
    load_eval_module,
    load_encoder,
    # get_embeddings,
    create_and_search_index,
    create_val_targets_df,
    create_distances_df,
    get_best_threshold,
    create_predictions_df,
)

def load_dataloaders(
    train_csv_encoded_folded: str,
    test_csv: str,
    val_fold: float,
    image_size: int,
    batch_size: int,
    num_workers: int,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build ordered (train, val, test) dataloaders for embedding extraction.

    None of the three loaders shuffles, and all keep their final partial batch. The train
    loader deliberately bypasses ``datamodule.train_dataloader()``. That loader sets
    ``shuffle=True`` and ``drop_last=True``, which would silently leave up to
    ``batch_size - 1`` training images out of the nearest-neighbour database.
    """
    datamodule = LitDataModule(
        train_csv_encoded_folded=train_csv_encoded_folded,
        test_csv=test_csv,
        val_fold=val_fold,
        image_size=image_size,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    datamodule.setup()

    # train=False -> no shuffling and no dropped last batch: every train image gets embedded.
    train_dl = datamodule._dataloader(datamodule.train_dataset, train=False)
    val_dl = datamodule.val_dataloader()
    test_dl = datamodule.test_dataloader()

    return train_dl, val_dl, test_dl


@torch.inference_mode()
def get_embeddings(
    module: pl.LightningModule, dataloader: DataLoader, encoder: LabelEncoder, stage: str
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run ``module`` over ``dataloader`` and collect image names, embeddings and targets.

    Args:
        module: Model whose forward returns one embedding row per image.
        dataloader: Yields dicts with ``"image_name"``, ``"image"`` and ``"target"`` entries.
        encoder: Fitted LabelEncoder used to map integer targets back to individual ids.
        stage: Label shown in the progress-bar description only.

    Returns:
        ``(image_names, embeddings, targets)``. Embeddings are L2-normalized per row and
        targets are decoded with ``encoder.inverse_transform``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    all_image_names = []
    all_embeddings = []
    all_targets = []

    for batch in tqdm(dataloader, desc=f"Creating {stage} embeddings"):
        embeddings = module(batch["image"].to(module.device))

        all_image_names.append(batch["image_name"])
        # Cast so half-precision outputs still produce a float32 array.
        all_embeddings.append(embeddings.float().cpu().numpy())
        # Targets are only needed on the host; skip the round trip through the device.
        all_targets.append(batch["target"].cpu().numpy())

    if not all_embeddings:
        raise ValueError(f"The {stage} dataloader yielded no batches")

    all_image_names = np.concatenate(all_image_names)
    all_embeddings = np.vstack(all_embeddings)
    all_targets = np.concatenate(all_targets)

    all_embeddings = normalize(all_embeddings, axis=1, norm="l2")
    all_targets = encoder.inverse_transform(all_targets)

    return all_image_names, all_embeddings, all_targets

In [None]:
from happywhale.settings import IDS_WITHOUT_BACKFIN_PATH, PUBLIC_SUBMISSION_CSV_PATH

def infer(
    checkpoint_path: str,
    train_csv_encoded_folded: str = str(TRAIN_CSV_ENCODED_FOLDED_PATH),
    test_csv: str = str(TEST_CSV_PATH),
    val_fold: float = 0.0,
    image_size: int = 256,
    batch_size: int = 64,
    num_workers: int = 2,
    k: int = 50,
    device: str = "cuda",
):
    """Embed train/val/test images, tune a distance threshold on val and write the submission.

    The steps are:
      1. Search val embeddings against train embeddings and pick a threshold with
         ``get_best_threshold``.
      2. Add the val embeddings and targets to the database and search the test
         embeddings against it.
      3. Build test predictions. Images listed in ``IDS_WITHOUT_BACKFIN_PATH``, and images
         missing from our predictions, take their rows from the public submission CSV.

    The result is written to ``SUBMISSION_CSV_PATH``.

    NOTE(review): the tuned threshold is only unbiased if fold ``val_fold`` was held out
    while training the checkpoint. ``train()`` holds out the last ``val_split`` rows unless
    given a ``val_fold``; confirm the two agree.

    Args:
        checkpoint_path: Lightning checkpoint to load (the notebook passes a ``Path``).
        val_fold: Fold used as the threshold-tuning set.
        k: Number of nearest neighbours retrieved per query.
        device: Device the model is loaded onto.
    """
    module = load_eval_module(checkpoint_path, torch.device(device), lit_module_cls=LitModule)

    train_dl, val_dl, test_dl = load_dataloaders(
        train_csv_encoded_folded=train_csv_encoded_folded,
        test_csv=test_csv,
        val_fold=val_fold,
        image_size=image_size,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    encoder = load_encoder(ENCODER_CLASSES_PATH)
    train_image_names, train_embeddings, train_targets = get_embeddings(module, train_dl, encoder, stage="train")
    val_image_names, val_embeddings, val_targets = get_embeddings(module, val_dl, encoder, stage="val")
    # Test targets are placeholder ids and are not needed.
    test_image_names, test_embeddings, _ = get_embeddings(module, test_dl, encoder, stage="test")

    D, I = create_and_search_index(module.hparams.embedding_size, train_embeddings, val_embeddings, k)  # noqa: E741
    print("Created index with train_embeddings")

    val_targets_df = create_val_targets_df(train_targets, val_image_names, val_targets)
    print(f"val_targets_df=\n{val_targets_df.head()}")

    val_df = create_distances_df(val_image_names, train_targets, D, I, "val")
    print(f"val_df=\n{val_df.head()}")

    best_th, best_cv = get_best_threshold(val_targets_df, val_df, adjust_th=True)
    print(f"best_th={best_th}, best_cv={best_cv}")
    print("val_targets_df:")
    display(val_targets_df.describe())

    # Final database: train + val, so test queries can also match val individuals.
    train_embeddings = np.concatenate([train_embeddings, val_embeddings])
    train_targets = np.concatenate([train_targets, val_targets])
    print("Updated train_embeddings and train_targets with val data")

    D, I = create_and_search_index(module.hparams.embedding_size, train_embeddings, test_embeddings, k)  # noqa: E741
    print("Created index with train+val embeddings")

    test_df = create_distances_df(test_image_names, train_targets, D, I, "test")
    print(f"test_df=\n{test_df.head()}")

    predictions = create_predictions_df(test_df, best_th)
    print(f"predictions.head()={predictions.head()}")

    # Fix missing predictions
    # From https://www.kaggle.com/code/jpbremer/backfins-arcface-tpu-effnet/notebook
    public_predictions = pd.read_csv(PUBLIC_SUBMISSION_CSV_PATH)
    ids_without_backfin = np.load(IDS_WITHOUT_BACKFIN_PATH, allow_pickle=True)

    # Images present in the public submission but absent from ours.
    ids2 = public_predictions["image"][~public_predictions["image"].isin(predictions["image"])]
    predictions = pd.concat(
        [
            predictions[~(predictions["image"].isin(ids_without_backfin))],
            public_predictions[public_predictions["image"].isin(ids_without_backfin)],
            public_predictions[public_predictions["image"].isin(ids2)],
        ]
    )
    predictions = predictions[["image","predictions"]].drop_duplicates()
    predictions.to_csv(SUBMISSION_CSV_PATH, index=False)

In [None]:
# Checkpoint saved by train()'s ModelCheckpoint (filename f"{model_name}_{image_size}").
checkpoint_path = CHECKPOINTS_DIR / f"{MODEL_NAME}_{IMAGE_SIZE}.ckpt"
infer(checkpoint_path=checkpoint_path, image_size=IMAGE_SIZE, batch_size=BATCH_SIZE)

In [None]:
!head submission.csv