# Explore 🕵 provided data

In [None]:
!ls -l /kaggle/input/happy-whale-and-dolphin

PATH_DATASET = "/kaggle/input/happy-whale-and-dolphin"

## Browsing the metadata

In [None]:
import os
import pandas as pd
import seaborn as sn
import matplotlib.pyplot as plt

# Apply seaborn's default theme to the figures drawn below.
sn.set()

# Load the training metadata table from the dataset folder and preview it.
path_train_csv = os.path.join(PATH_DATASET, "train.csv")
df_train = pd.read_csv(path_train_csv)
display(df_train.head())

# Report the number of rows and the number of distinct individual ids.
unique_ids = df_train["individual_id"].unique()
print(f"Dataset size: {len(df_train)}")
print(f"Unique ids: {len(unique_ids)}")

Let's see how many species we have in the database...

In [None]:
# Number of images per species (every metadata row counts once).
counts_imgs = df_train["species"].value_counts()

# Number of individuals per species: keep one row per individual id, then count.
df_individuals = df_train.drop_duplicates("individual_id")
counts_inds = df_individuals["species"].value_counts()

# Put both counts side by side and draw them as horizontal bars on a log x-axis.
df_counts = pd.concat({"per Images": counts_imgs, "per Individuals": counts_inds}, axis=1)
ax = df_counts.plot.barh(grid=True, figsize=(7, 10))
ax.set_xscale('log')

And compare them with unique individuals...

**Note:** the counts are shown on a log scale.

In [None]:
import numpy as np
from pprint import pprint

# For every species: how many images each of its individuals has.
species_individuals = {
    species: group["individual_id"].value_counts()
    for species, group in df_train.groupby("species")
}

# Largest number of individuals found in any single species; used as the common column length.
si_max = max(len(counts) for counts in species_individuals.values())

# One column per species holding log(image count) per individual, right-padded
# with zeros up to `si_max`.
# NOTE: log(1) is also 0, so an individual with a single image is
# indistinguishable from padding in this table.
si = pd.DataFrame({
    species: list(np.log(counts)) + [0] * (si_max - len(counts))
    for species, counts in species_individuals.items()
})

In [None]:
import seaborn as sn

# Heatmap of the `si` table restricted to its first 500 rows and transposed:
# one row per species, one column per individual position. Cell values are
# log(image count); zero therefore means either a single image or padding.
fig = plt.figure(figsize=(10, 8))
ax = sn.heatmap(
    si[:500].T,
    cmap="BuGn",
    ax=fig.gca(),
    # Label the colour bar so the reader knows the values are log-transformed.
    cbar_kws={"label": "log(images per individual)"},
)
ax.set_xlabel("individual (position in per-species value_counts order)")
ax.set_ylabel("species")
ax.set_title("Images per individual, first 500 individuals of each species")

# Baseline: embedding with Lightning⚡Flash

Follow the example: https://lightning-flash.readthedocs.io/en/stable/reference/image_embedder.html

In [None]:
# Install VISSL, FairScale and Lightning Flash (with its image extras), quietly.
!pip install -q vissl fairscale 'lightning-flash[image]'
# Temporary fix until it is merged to master & released: upgrade Flash from the master-branch archive.
!pip install -q -U "https://github.com/PyTorchLightning/lightning-flash/archive/refs/heads/master.zip"
# Remove wandb from the environment. NOTE(review): the reason is not stated in this
# notebook — presumably to avoid its logging integration; confirm.
!pip uninstall -y wandb

In [None]:
# Download the same packages (preferring binary distributions) into ./frozen_packages.
!pip download -q vissl fairscale 'lightning-flash[image]' --dest frozen_packages --prefer-binary
# Build a wheel of the Flash master-branch archive into the same folder.
!pip wheel -q "https://github.com/PyTorchLightning/lightning-flash/archive/refs/heads/master.zip" --wheel-dir frozen_packages
# Delete the downloaded torch files from the folder. NOTE(review): presumably because
# torch is already available in the target environment — confirm.
!rm frozen_packages/torch-*
# Show what ended up in the folder.
!ls -l frozen_packages

In [None]:
import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageEmbedder

## 1. Create the Dataset 🗄️

In [None]:
from PIL import Image
from torch.utils.data import Dataset


class HappyWhaleDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from a metadata table.

    Labels are the integer positions of each ``individual_id`` in the sorted
    list of unique ids found in ``df``. They are therefore only comparable
    between instances built from tables containing the same set of ids.

    Args:
        df: table with at least the columns ``"image"`` (file name) and
            ``"individual_id"`` (target identity).
        path_folder: folder that the file names in ``df["image"]`` are joined to.
        transform: optional callable applied to the loaded RGB image.
    """

    def __init__(self, df: pd.DataFrame, path_folder: str, transform = None):
        self.df = df
        self.transform = transform

        self.image_names = self.df["image"].values
        self.image_paths = [os.path.join(path_folder, n) for n in self.image_names]
        self.targets = list(self.df["individual_id"])
        # Sorted so the id -> label mapping is deterministic for a given set of ids.
        self.uq_targets = sorted(set(self.targets))
        lut = {target: label for label, target in enumerate(self.uq_targets)}
        self.labels = [lut[ind] for ind in self.targets]

    def __getitem__(self, idx: int) -> tuple:
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB and passed through ``transform`` when one
        was given; the label is a scalar ``torch.long`` tensor.
        """
        # `convert` returns an independent in-memory image, so the underlying
        # file can be closed as soon as the `with` block ends.
        with Image.open(self.image_paths[idx]) as img_file:
            img = img_file.convert('RGB')

        # Compare against None rather than relying on truthiness, so a falsy
        # but valid callable would still be applied.
        if self.transform is not None:
            img = self.transform(img)

        lb = torch.tensor(self.labels[idx], dtype=torch.long)
        return img, lb

    def __len__(self) -> int:
        """Return the number of rows in the metadata table."""
        return len(self.df)

In [None]:
# Build the dataset from the first 60% of the metadata rows. No transform is
# passed, so each item's image is the RGB image as loaded.
dataset = HappyWhaleDataset(
    # df=df_train,
    # ToDo: use full dataset
    df=df_train[:int(len(df_train) * 0.6)],
    path_folder=f"{PATH_DATASET}/train_images",
)

# Preview the first ten samples in a 2x5 grid, filled row by row.
fig, axarr = plt.subplots(nrows=2, ncols=5, figsize=(14, 4))
for i, ax in enumerate(axarr.ravel()):
    img, ind = dataset[i]
    ax.imshow(img)
    # The label is a 0-dim tensor; convert it so the title reads "12"
    # instead of "tensor(12)".
    ax.set_title(int(ind))
fig.tight_layout()

In [None]:
# Wrap the dataset in a Flash data module; only a training dataset is passed.
datamodule = ImageClassificationData.from_datasets(
    train_dataset=dataset,
    batch_size=64,
    # Hard-coded worker count. NOTE(review): confirm it suits the CPU count of the runtime.
    num_workers=6,
)

## 2. Build the task ⚙️

In [None]:
# Configure the image embedder: a "resnet" backbone trained with the
# "barlow_twins" strategy, using the "simclr_head" head and the
# "barlow_twins_transform" pretraining transform.
embedder = ImageEmbedder(
    backbone="resnet",
    training_strategy="barlow_twins",
    head="simclr_head",
    pretraining_transform="barlow_twins_transform",
    # Options forwarded to the training strategy. NOTE(review): presumably the size of
    # the produced embedding vector — confirm against the Flash documentation.
    training_strategy_kwargs={"latent_embedding_dim": 256},
    # Options forwarded to the pretraining transform. NOTE(review): presumably the crop
    # size in pixels — confirm against the Flash documentation.
    pretraining_transform_kwargs={"size_crops": [196]},
)

## 3. Finetune the model 🛠️

In [None]:
from pytorch_lightning.loggers import CSVLogger
# from pytorch_lightning.callbacks import StochasticWeightAveraging

# Trainer Args
GPUS = int(torch.cuda.is_available())  # auto-detected: 1 when CUDA is available, otherwise 0

# swa = StochasticWeightAveraging(swa_epoch_start=0.6)
# Log training metrics as CSV under logs/.
logger = CSVLogger(save_dir='logs/')

trainer = flash.Trainer(
    max_epochs=5,
    # gradient_clip_val=0.01,
    gpus=GPUS,
    # 16-bit precision only when a GPU was detected; 32-bit otherwise.
    precision=16 if GPUS else 32,
    logger=logger,
)

In [None]:
# Train the embedder on the training data module.
trainer.fit(embedder, datamodule=datamodule)

# Save a checkpoint of the trainer state to a file in the working directory.
trainer.save_checkpoint("image_embedder_model.pt")

In [None]:
# Read back the metrics file from the logger's directory, drop the step
# counter and index the remaining columns by epoch.
path_metrics = f'{trainer.logger.log_dir}/metrics.csv'
metrics = (
    pd.read_csv(path_metrics)
    .drop(columns=["step"])
    .set_index("epoch")
)

# Preview the table without columns that are entirely empty.
metrics_preview = metrics.dropna(axis=1, how="all")
display(metrics_preview.head())

# Line plot of every metric over epochs, widened for readability.
g = sn.relplot(data=metrics, kind="line")
fig_metrics = plt.gcf()
fig_metrics.set_size_inches(15, 5)

## Run predictions 🎉

In [None]:
import glob

# Collect the test images. `glob.glob` returns matches in arbitrary order, so
# the list is sorted to make the `imgs[:5]` subset below the same on every run.
imgs = sorted(glob.glob(f"{PATH_DATASET}/test_images/*.jpg"))

# Data module holding only the first five test images for prediction.
datamodule = ImageClassificationData.from_files(
    predict_files=imgs[:5],
    batch_size=12
)

In [None]:
# Clear the embedder's input transform before predicting.
# NOTE(review): presumably so the training-time pretraining transform is not
# applied at inference — confirm against the Flash documentation.
embedder.input_transform = None
# Run prediction with the embedder on the data module's predict files.
embeddings = trainer.predict(embedder, datamodule=datamodule)

# list of embeddings for images sent to the predict function
pprint(embeddings)