# 🐋🐬 Lightning⚡Flash 🦈 BackFin & CNN & ArcFace

Let's train [`timm`](https://github.com/rwightman/pytorch-image-models) models with [PyTorch Lightning Flash](https://github.com/PyTorchLightning/lightning-flash)!

## Sources
- [[Pytorch] ArcFace + GeM Pooling Starter](https://www.kaggle.com/debarshichanda/pytorch-arcface-gem-pooling-starter)
- [FAISS Pytorch Inference](https://www.kaggle.com/debarshichanda/faiss-pytorch-inference)
- [backfintfrecords](https://www.kaggle.com/datasets/jpbremer/backfintfrecords) dataset

## Nice Lightning `Trainer` Flags to try
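- Run a learning rate finder algorithm: `auto_lr_find=True` (it runs when you call `trainer.tune(...)`)
- Automatically try to find the largest batch size that fits into memory: `auto_scale_batch_size=True` (also runs in `trainer.tune(...)`)
- Quickly check whether everything runs fine: `fast_dev_run=True`
- Train on multiple GPUs: `gpus=2` (with multiple GPUs, also set `strategy="ddp"`; older Lightning versions used `accelerator="ddp"`)
- Train with half precision: `precision=16`
- Use Stochastic Weight Averaging: `stochastic_weight_avg=True` (deprecated in newer Lightning versions in favor of the `StochasticWeightAveraging` callback)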

# Installations

In [None]:
!pip install -q torch==1.10.1+cu102 torchvision==0.11.2+cu102 torchaudio==0.10.1 -f https://download.pytorch.org/whl/torch_stable.html
!pip install -q faiss-gpu 'lightning-flash[image]'
!pip install -q happywhale -f ../input/-pytorch-lightning-happywhale-pkg
!pip install -q -U timm segmentation-models-pytorch
!pip install -q Pillow==9.0.1
!pip uninstall -y torchtext  # segmentation-models-pytorch
!pip list | grep torch
!pip list | grep lightning
!pip list | grep happywhale

# Imports

In [None]:
import math
from typing import Callable, Dict, Optional, Tuple, Union
from pathlib import Path

import faiss
import flash
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from flash.image import ImageClassificationData, ImageClassifier
from PIL import Image
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import normalize, LabelEncoder

# Paths & Settings

In [None]:
# Read-only Kaggle inputs vs. the writable working directory for artifacts.
INPUT_DIR = Path("../input")
OUTPUT_DIR = Path("/kaggle/working")

# Converted backfin dataset: image folders plus the competition CSVs.
DATA_ROOT_DIR = INPUT_DIR.joinpath("convert-backfintfrecords", "happy-whale-and-dolphin-backfin")
TRAIN_DIR, TEST_DIR = DATA_ROOT_DIR / "train_images", DATA_ROOT_DIR / "test_images"
TRAIN_CSV_PATH = DATA_ROOT_DIR.joinpath("train.csv")
SAMPLE_SUBMISSION_CSV_PATH = DATA_ROOT_DIR.joinpath("sample_submission.csv")

# Number of StratifiedKFold folds written to the "kfold" column.
N_SPLITS = 5

# Files and folders this notebook writes.
ENCODER_CLASSES_PATH = OUTPUT_DIR.joinpath("encoder_classes.npy")  # LabelEncoder.classes_
TEST_CSV_PATH = OUTPUT_DIR.joinpath("test.csv")
TRAIN_CSV_ENCODED_FOLDED_PATH = OUTPUT_DIR.joinpath("train_encoded_folded.csv")
CHECKPOINTS_DIR = OUTPUT_DIR.joinpath("checkpoints")
SUBMISSION_CSV_PATH = OUTPUT_DIR.joinpath("submission.csv")

# Prepare DataFrames

In [None]:
def get_image_path(id: str, dir: Union[str, Path]) -> str:
    """Return the path of image file ``id`` inside directory ``dir`` as a string.

    Args:
        id: image file name (e.g. the ``image`` column of the CSVs).
        dir: directory containing the image; a ``Path`` or a plain ``str``.

    Returns:
        The joined path, as ``str``.
    """
    # Path(dir) lets callers pass a plain string directory as well as a Path;
    # the previous `dir / id` raised TypeError for str.
    return str(Path(dir) / id)

## Train DataFrame

In [None]:
train_df = pd.read_csv(TRAIN_CSV_PATH)

train_df["image_path"] = train_df["image"].apply(get_image_path, dir=TRAIN_DIR)

# Encode the string individual ids as integer class indices and save the
# classes so inference can map integer targets back to the original ids.
encoder = LabelEncoder()
train_df["individual_id"] = encoder.fit_transform(train_df["individual_id"])
np.save(ENCODER_CLASSES_PATH, encoder.classes_)

# Assign each row to a stratified fold.
# - Pre-filling with an int sentinel keeps "kfold" an integer column; creating
#   it via partial .loc assignments made it float (NaN for unassigned rows).
# - skf.split yields positional indices, which equal the row labels here only
#   because read_csv gives train_df a default RangeIndex.
# NOTE: scikit-learn warns when some individuals have fewer than N_SPLITS images.
train_df["kfold"] = -1
skf = StratifiedKFold(n_splits=N_SPLITS)
for fold, (_, val_) in enumerate(skf.split(X=train_df, y=train_df["individual_id"])):
    train_df.loc[val_, "kfold"] = fold

train_df.to_csv(TRAIN_CSV_ENCODED_FOLDED_PATH, index=False)
train_df.head()

## Test DataFrame

In [None]:
# The sample submission lists every test image, so it serves as the test
# index; its placeholder "predictions" column is not needed.
test_df = pd.read_csv(SAMPLE_SUBMISSION_CSV_PATH).drop(columns=["predictions"])
test_df["image_path"] = test_df["image"].map(lambda name: get_image_path(name, dir=TEST_DIR))

# Placeholder label: the test frame is loaded with the same target column as
# train/val, but its value is never used for predictions.
test_df["individual_id"] = 0
test_df.to_csv(TEST_CSV_PATH, index=False)
test_df.head()

# ArcMargin Loss

In [None]:
from happywhale.utils.arc_margin_product import ArcMarginProduct

class ArcLoss(nn.Module):
    """Cross-entropy over margin-adjusted logits from ``ArcMarginProduct``.

    The model's output vectors (``features``) are mapped to ``num_classes``
    logits by ``ArcMarginProduct``, which also receives the targets, and the
    result is scored with ``F.cross_entropy``. This lets the module be passed
    to Flash as ``loss_fn``.

    Args:
        embedding_size: size of each feature vector (``in_features``).
        num_classes: number of identities (``out_features``).
        arc_s: forwarded as ``s`` (presumably the logit scale).
        arc_m: forwarded as ``m`` (presumably the angular margin).
        arc_easy_margin: forwarded as ``easy_margin``.
        arc_ls_eps: forwarded as ``ls_eps`` (presumably label-smoothing epsilon).
    """

    def __init__(
        self,
        embedding_size: int,
        num_classes: int,
        arc_s: float,
        arc_m: float,
        arc_easy_margin: bool,
        arc_ls_eps: float,
    ) -> None:
        super().__init__()
        margin_kwargs = dict(
            s=arc_s,
            m=arc_m,
            easy_margin=arc_easy_margin,
            ls_eps=arc_ls_eps,
        )
        # Keep the attribute name "arc": renaming it would change the
        # parameter names in any saved state_dict.
        self.arc = ArcMarginProduct(
            in_features=embedding_size,
            out_features=num_classes,
            **margin_kwargs,
        )

    def forward(self, features: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the scalar cross-entropy of the margin logits against ``targets``."""
        return F.cross_entropy(self.arc(features, targets), targets)

# Training

In [None]:
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import StochasticWeightAveraging


def train(
    train_csv_encoded_folded: str = str(TRAIN_CSV_ENCODED_FOLDED_PATH),
    val_split: float = 0.05,
    image_size: int = 384,
    batch_size: int = 32,
    num_workers: int = 2,
    model_name: str = "convnext_small",
    pretrained: bool = True,
    embedding_size: int = 512,
    arc_s: float = 30.0,
    arc_m: float = 0.5,
    arc_easy_margin: bool = False,
    arc_ls_eps: float = 0.0,
    learning_rate: float = 3e-4,
    weight_decay: float = 1e-6,
    project: str = "kaggle-happywhale-flash",
    checkpoints_dir: str = str(CHECKPOINTS_DIR),
    gpus: int = torch.cuda.device_count(),
    max_epochs: int = 10,
    precision: int = 16,
    swa_epoch_start: Optional[float] = None,
    seed: Optional[int] = None,
    log_dir: str = "logs/",
):
    """Fine-tune an image backbone as an embedding model trained with ArcLoss.

    The classifier head outputs ``embedding_size`` values instead of class
    logits; ``ArcLoss`` turns those embeddings into ``num_classes`` logits.

    Args:
        train_csv_encoded_folded: CSV with ``image_path`` and integer-encoded
            ``individual_id`` columns. Its ``kfold`` column is not used here.
        val_split: fraction of rows Flash holds out for validation.
        image_size: square input resolution.
        batch_size, num_workers: dataloader settings.
        model_name: backbone name passed to ``ImageClassifier``.
        pretrained: load pretrained backbone weights.
        embedding_size: size of the embedding produced by the model head.
        arc_s, arc_m, arc_easy_margin, arc_ls_eps: forwarded to ``ArcLoss``.
        learning_rate, weight_decay: AdamW settings.
        project: currently unused; kept for interface compatibility.
        checkpoints_dir: where ``<model_name>_<image_size>.ckpt`` is written,
            with ``ModelCheckpoint`` monitoring ``val_arcloss``.
        gpus: number of GPUs. The default is evaluated once, when this
            function is defined.
        max_epochs, precision: Trainer settings.
        swa_epoch_start: if given, enable Stochastic Weight Averaging from this
            epoch (fraction or index). ``None`` (default) disables SWA.
        seed: if given, passed to ``pl.seed_everything`` for reproducibility.
        log_dir: ``CSVLogger`` save directory.

    Returns:
        ``(model, trainer, logger)``.
    """
    if seed is not None:
        pl.seed_everything(seed)

    train_val_df = pd.read_csv(train_csv_encoded_folded)

    datamodule = ImageClassificationData.from_data_frame(
        input_field="image_path",
        target_fields="individual_id",
        train_data_frame=train_val_df,
        transform_kwargs={
            "image_size": (image_size, image_size),
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
        },
        batch_size=batch_size,
        num_workers=num_workers,
        val_split=val_split,
    )

    arc_loss = ArcLoss(
        embedding_size=embedding_size,
        num_classes=datamodule.num_classes,
        arc_s=arc_s,
        arc_m=arc_m,
        arc_ls_eps=arc_ls_eps,
        arc_easy_margin=arc_easy_margin,
    )

    # num_classes=embedding_size: the "classifier" output is the embedding.
    model = ImageClassifier(
        num_classes=embedding_size,
        backbone=model_name,
        pretrained=pretrained,
        loss_fn=arc_loss,
        optimizer=("AdamW", {"lr": learning_rate, "weight_decay": weight_decay}),
        learning_rate=learning_rate,
        metrics=[],
    )

    callbacks = [
        ModelCheckpoint(
            checkpoints_dir,
            filename=f"{model_name}_{image_size}",
            # Assumes Flash logs the ArcLoss validation loss under this key.
            monitor="val_arcloss",
        )
    ]
    # Previously an SWA callback was built but never passed to the Trainer;
    # it is now opt-in so the default behavior (no SWA) is unchanged.
    if swa_epoch_start is not None:
        callbacks.append(StochasticWeightAveraging(swa_epoch_start=swa_epoch_start))

    logger = CSVLogger(save_dir=log_dir)

    trainer = flash.Trainer(
        logger=logger,
        callbacks=callbacks,
        gpus=gpus,
        max_epochs=max_epochs,
        precision=precision,
    )

    # "no_freeze": the backbone is trained from the first epoch.
    trainer.finetune(model, datamodule=datamodule, strategy="no_freeze")
    return model, trainer, logger

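Candidate backbones and their ImageNet scores are listed in the timm benchmark results: [results-imagenet.csv](https://github.com/rwightman/pytorch-image-models/blob/master/results/results-imagenet.csv)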

In [None]:
# Settings shared by the training and inference cells. The checkpoint file
# name is built from MODEL_NAME and IMAGE_SIZE, so inference must use the same values.
MODEL_NAME = "tf_efficientnet_b4_ns"  # backbone name passed to ImageClassifier
IMAGE_SIZE = 380
BATCH_SIZE = 32

model, trainer, log = train(
    model_name=MODEL_NAME,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    max_epochs=5,
)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sn

# Load the CSVLogger output for this run and index it by epoch.
metrics = (
    pd.read_csv(Path(log.log_dir) / "metrics.csv")
    .drop(columns=["step"])
    .set_index("epoch")
)
# Preview only the columns that actually contain values.
display(metrics.dropna(axis=1, how="all").head())

# One line per logged metric against epoch.
g = sn.relplot(data=metrics, kind="line")
plt.gcf().set_size_inches(12, 4)
plt.grid()

# Inference

In [None]:
from happywhale.inference.infer import (
    load_eval_module,
    load_encoder,
    # get_embeddings,
    create_and_search_index,
    create_val_targets_df,
    create_distances_df,
    get_best_threshold,
    create_predictions_df,
)

def load_dataloaders(
    train_csv_encoded_folded: str,
    test_csv: str,
    val_fold: float,
    image_size: int,
    batch_size: int,
    num_workers: int,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Build the dataloaders used to extract train (gallery), val and test embeddings.

    Rows whose ``kfold`` equals ``val_fold`` form the validation set; all other
    rows form the train gallery.

    The train gallery loader is the *validation* loader of a second datamodule
    over ``train_df``. Flash's training loader applies training-time handling
    (random augmentation, shuffling, and possibly dropping the last partial
    batch), which is unwanted when indexing every train image. The validation
    path is the same one already used for ``val_dl``.
    NOTE(review): this relies on Flash's val loader using eval transforms and
    keeping every sample — confirm for the installed Flash version.

    Returns:
        ``(train_dl, val_dl, test_dl)``.
    """
    train_val_df = pd.read_csv(train_csv_encoded_folded)
    is_val = train_val_df.kfold == val_fold
    train_df = train_val_df[~is_val].reset_index(drop=True)
    val_df = train_val_df[is_val].reset_index(drop=True)
    test_df = pd.read_csv(test_csv)

    def _build(**data_frames: pd.DataFrame) -> ImageClassificationData:
        # Same fields/transforms for every datamodule; a fresh kwargs dict per call.
        datamodule = ImageClassificationData.from_data_frame(
            input_field="image_path",
            target_fields="individual_id",
            transform_kwargs={
                "image_size": (image_size, image_size),
                "mean": [0.485, 0.456, 0.406],
                "std": [0.229, 0.224, 0.225],
            },
            batch_size=batch_size,
            num_workers=num_workers,
            **data_frames,
        )
        datamodule.setup()
        return datamodule

    datamodule = _build(
        train_data_frame=train_df,
        val_data_frame=val_df,
        test_data_frame=test_df,
    )
    # train_df is also passed as the train frame so target handling is
    # inferred from the same data as in `datamodule`.
    gallery_datamodule = _build(train_data_frame=train_df, val_data_frame=train_df)

    train_dl = gallery_datamodule.val_dataloader()
    val_dl = datamodule.val_dataloader()
    test_dl = datamodule.test_dataloader()

    return train_dl, val_dl, test_dl


@torch.inference_mode()
def get_embeddings(
    module: ImageClassifier, dataloader: DataLoader, encoder: LabelEncoder, stage: str
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run ``module`` over ``dataloader`` and collect normalized embeddings.

    Args:
        module: model whose forward output is used as the embedding; inputs
            are moved to ``module.device``.
        dataloader: yields dict batches with ``"input"``, ``"target"`` and
            ``"metadata"`` keys; each metadata item must have a ``"filepath"``.
        encoder: fitted ``LabelEncoder`` used to decode integer targets.
        stage: label shown in the progress bar.

    Returns:
        ``(image_names, embeddings, targets)``: file base names, row-wise
        L2-normalized embeddings, and decoded targets, in loader order.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    image_names = []
    embedding_chunks = []
    target_chunks = []

    for batch in tqdm(dataloader, desc=f"Creating {stage} embeddings"):
        image_names.extend(Path(item["filepath"]).name for item in batch["metadata"])
        embeddings = module(batch["input"].to(module.device))
        embedding_chunks.append(embeddings.cpu().numpy())
        # Targets are only bookkeeping; no need to copy them to the GPU and back.
        target_chunks.append(batch["target"].cpu().numpy())

    if not embedding_chunks:
        raise ValueError(f"The {stage} dataloader yielded no batches.")

    all_image_names = np.asarray(image_names)
    all_embeddings = normalize(np.vstack(embedding_chunks), axis=1, norm="l2")
    all_targets = encoder.inverse_transform(np.concatenate(target_chunks))

    return all_image_names, all_embeddings, all_targets

In [None]:
from happywhale.settings import IDS_WITHOUT_BACKFIN_PATH, PUBLIC_SUBMISSION_CSV_PATH

def infer(
    checkpoint_path: Union[str, Path],
    train_csv_encoded_folded: str = str(TRAIN_CSV_ENCODED_FOLDED_PATH),
    test_csv: str = str(TEST_CSV_PATH),
    val_fold: float = 0.0,
    image_size: int = 256,
    batch_size: int = 64,
    num_workers: int = 2,
    embedding_size: int = 512,
    k: int = 50,
    device: Optional[torch.device] = None,
) -> None:
    """Create the submission CSV from a trained checkpoint via nearest-neighbor search.

    Steps:
      1. Embed the train, val (``kfold == val_fold``) and test images.
      2. Search the val embeddings against a train-only index and pick a
         threshold with ``get_best_threshold``.
      3. Search the test embeddings against a train + val index and build
         predictions with that threshold.
      4. Patch the predictions from the public submission (see
         ``_fill_missing_predictions``) and write ``SUBMISSION_CSV_PATH``.

    Args:
        checkpoint_path: Lightning checkpoint of an ``ImageClassifier``.
        image_size: must match the size used for training.
        embedding_size: must match the model's output size (the index dimension).
        k: number of neighbors retrieved per query.
        device: device for the model; defaults to CUDA when available, else CPU.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    module = load_eval_module(checkpoint_path, device, lit_module_cls=ImageClassifier)

    train_dl, val_dl, test_dl = load_dataloaders(
        train_csv_encoded_folded=train_csv_encoded_folded,
        test_csv=test_csv,
        val_fold=val_fold,
        image_size=image_size,
        batch_size=batch_size,
        num_workers=num_workers,
    )

    encoder = load_encoder(ENCODER_CLASSES_PATH)

    train_image_names, train_embeddings, train_targets = get_embeddings(module, train_dl, encoder, stage="train")
    val_image_names, val_embeddings, val_targets = get_embeddings(module, val_dl, encoder, stage="val")
    # Test targets are the dummy labels from test.csv and are ignored.
    test_image_names, test_embeddings, _ = get_embeddings(module, test_dl, encoder, stage="test")

    # --- Threshold search on the held-out fold (train-only index) ---
    D, I = create_and_search_index(embedding_size, train_embeddings, val_embeddings, k)  # noqa: E741
    print("Created index with train_embeddings")

    val_targets_df = create_val_targets_df(train_targets, val_image_names, val_targets)
    print(f"val_targets_df=\n{val_targets_df.head()}")

    val_df = create_distances_df(val_image_names, train_targets, D, I, "val")
    print(f"val_df=\n{val_df.head()}")

    best_th, best_cv = get_best_threshold(val_targets_df, val_df, adjust_th=True)
    # These were previously computed but never reported.
    print(f"best_th={best_th}, best_cv={best_cv}")
    print(f"val_targets_df=\n{val_targets_df.describe()}")

    # --- Test predictions against an index of train + val embeddings ---
    gallery_embeddings = np.concatenate([train_embeddings, val_embeddings])
    gallery_targets = np.concatenate([train_targets, val_targets])
    print("Updated train_embeddings and train_targets with val data")

    D, I = create_and_search_index(embedding_size, gallery_embeddings, test_embeddings, k)  # noqa: E741
    print("Created index with train_embeddings")

    test_df = create_distances_df(test_image_names, gallery_targets, D, I, "test")
    print(f"test_df=\n{test_df.head()}")

    predictions = create_predictions_df(test_df, best_th)
    print(f"predictions.head()={predictions.head()}")

    predictions = _fill_missing_predictions(predictions)
    predictions.to_csv(SUBMISSION_CSV_PATH, index=False)


def _fill_missing_predictions(predictions: pd.DataFrame) -> pd.DataFrame:
    """Replace or add rows from the public submission where our model has no good answer.

    Approach from https://www.kaggle.com/code/jpbremer/backfins-arcface-tpu-effnet/notebook:
    images listed in ``IDS_WITHOUT_BACKFIN_PATH``, and images missing from
    ``predictions``, take their rows from ``PUBLIC_SUBMISSION_CSV_PATH``.

    Returns:
        De-duplicated ``image``/``predictions`` rows.
    """
    public_predictions = pd.read_csv(PUBLIC_SUBMISSION_CSV_PATH)
    # NOTE(security): allow_pickle=True unpickles the file — only load trusted data.
    ids_without_backfin = np.load(IDS_WITHOUT_BACKFIN_PATH, allow_pickle=True)

    public_images = public_predictions["image"]
    merged = pd.concat(
        [
            predictions[~predictions["image"].isin(ids_without_backfin)],
            public_predictions[public_images.isin(ids_without_backfin)],
            # Public rows for images we produced no prediction for at all.
            public_predictions[~public_images.isin(predictions["image"])],
        ]
    )
    return merged[["image", "predictions"]].drop_duplicates()

In [None]:
# train() saved "<model_name>_<image_size>.ckpt" under CHECKPOINTS_DIR;
# reuse the same image size so inputs match the trained model.
infer(
    checkpoint_path=CHECKPOINTS_DIR / f"{MODEL_NAME}_{IMAGE_SIZE}.ckpt",
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
)

In [None]:
!head submission.csv