# 🕵 Explore the provided data

In [None]:
!ls -l /kaggle/input/happy-whale-and-dolphin

PATH_DATASET = "/kaggle/input/happy-whale-and-dolphin"

## Browsing the metadata

In [None]:
import os
import pandas as pd
import seaborn as sn
import matplotlib.pyplot as plt

# Apply seaborn's default theme to all subsequent matplotlib figures.
sn.set()

# Training metadata; columns used below: `image` (file name inside train_images/),
# `species` and `individual_id`.
df_train = pd.read_csv(os.path.join(PATH_DATASET, "train.csv"))
display(df_train.head())
print(f"Dataset size: {len(df_train)}")
# nunique() ignores missing ids, which should not count as an individual anyway.
print(f"Unique ids: {df_train['individual_id'].nunique()}")

Let's see how many species we have in the dataset, counted both per image and per unique individual...

In [None]:
# Number of images per species, most frequent first.
counts_imgs = df_train["species"].value_counts()
# Number of distinct individuals per species. groupby+nunique counts an individual under
# every species it is labelled with, whereas drop_duplicates("individual_id") would credit
# it only to the species of its first row. Reindex so both series share the same order.
counts_inds = (
    df_train.groupby("species")["individual_id"].nunique()
    .reindex(counts_imgs.index)
)

ax = pd.concat(
    {"per Images": counts_imgs, "per Individuals": counts_inds}, axis=1
).plot.barh(grid=True, figsize=(7, 10))
ax.set_xscale('log')
ax.set(xlabel="count (log scale)", ylabel="species", title="Images and individuals per species")

And compare them with the number of unique individuals...

**Note:** the counts are on a log scale.

In [None]:
import numpy as np
from pprint import pprint

# For every species: number of images per individual, sorted descending (value_counts default).
species_individuals = {
    name: dfg["individual_id"].value_counts()
    for name, dfg in df_train.groupby("species")
}

# One column per species, one row per individual rank within that species;
# each cell holds log(#images of that individual).
# Species with fewer individuals are padded with NaN rather than 0: 0 equals log(1), which
# would make "no individual at this rank" indistinguishable from "individual with a single
# image". seaborn's heatmap masks NaN cells automatically.
si = pd.DataFrame({
    name: pd.Series(np.log(counts.to_numpy(dtype=float)))
    for name, counts in species_individuals.items()
})

In [None]:
import seaborn as sn

# Rows: species; columns: individual rank within the species (first 500 ranks);
# colour: log(#images of that individual).
fig, ax = plt.subplots(figsize=(10, 8))
sn.heatmap(si.iloc[:500].T, cmap="BuGn", ax=ax, cbar_kws={"label": "log(#images)"})
ax.set(
    xlabel="individual rank within species",
    ylabel="species",
    title="Images per individual (log), top 500 individuals per species",
)

And see the top 50 individuals by number of images.

In [None]:
# The 50 individuals with the most images; ascending order puts the largest bar at the top.
top_individuals = df_train["individual_id"].value_counts(ascending=True).tail(50)
ax = top_individuals.plot.barh(figsize=(3, 8), grid=True)
ax.set(xlabel="#images", ylabel="individual_id", title="Top 50 individuals")

## Browse some images

In [None]:
# Number of example images shown per species (one grid row per species).
N_IMAGES_PER_SPECIES = 5

species_groups = df_train.groupby("species")
nb_species = species_groups.ngroups
# squeeze=False keeps axarr 2-D even when there is only a single species row.
fig, axarr = plt.subplots(
    ncols=N_IMAGES_PER_SPECIES,
    nrows=nb_species,
    figsize=(12, nb_species * 2),
    squeeze=False,
)
# Hide every axis up front so slots left empty (species with fewer images) stay blank.
for cell_ax in axarr.flat:
    cell_ax.set_axis_off()

for i, (name, dfg) in enumerate(species_groups):
    axarr[i, 0].set_title(name)
    for j, (_, row) in enumerate(dfg.head(N_IMAGES_PER_SPECIES).iterrows()):
        im_path = os.path.join(PATH_DATASET, "train_images", row["image"])
        axarr[i, j].imshow(plt.imread(im_path))

# Baseline: species classification with Lightning⚡Flash

This baseline follows the official Flash image-classification example: https://lightning-flash.readthedocs.io/en/stable/reference/image_classification.html

In [None]:
!pip install -q effdet "icevision[all]" 'lightning-flash[image]'
# !pip install -q "pytorch-lightning==1.4.*"
!pip uninstall -y wandb

In [None]:
!pip download -q effdet "icevision[all]" 'lightning-flash[image]' --dest frozen_packages --prefer-binary
!rm frozen_packages/torch-*
!ls -l frozen_packages

In [None]:
import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

## 1. Create the DataModule 🗄️

In [None]:
from pytorch_lightning import seed_everything

SEED = 42
# Seeds Python, NumPy and torch RNGs so the random val_split and training are reproducible.
seed_everything(SEED)

# For simplicity train on half of the data. Sample it at random (seeded) instead of taking
# the first rows, so the subset does not depend on the row order of train.csv.
df_train_half = df_train.sample(frac=0.5, random_state=SEED)

datamodule = ImageClassificationData.from_data_frame(
    input_field="image",          # column with the image file name
    target_fields="species",      # column with the class label
    train_data_frame=df_train_half,
    train_images_root=os.path.join(PATH_DATASET, "train_images"),
    batch_size=64,
    transform_kwargs={"image_size": (300, 300)},
    val_split=0.1,                # hold out 10% of df_train_half for validation
    num_workers=2,
)

## 2. Build the task ⚙️

In [None]:
# torchmetrics 0.7 renamed `F1` to `F1Score` (the old name was later removed); support both.
# NOTE(review): torchmetrics >= 0.11 additionally requires task="multiclass".
try:
    from torchmetrics import F1Score
except ImportError:  # older torchmetrics
    from torchmetrics import F1 as F1Score

model = ImageClassifier(
    backbone="efficientnet_b3",
    labels=datamodule.labels,
    # Macro average: with exactly one label per image, the default micro-averaged F1 equals
    # plain accuracy and is dominated by the most frequent species.
    metrics=F1Score(num_classes=len(datamodule.labels), average="macro"),
    pretrained=True,
    optimizer="AdamW",
    learning_rate=0.005,
)

## 3. Finetune the model 🛠️

In [None]:
from pytorch_lightning.loggers import CSVLogger

# Number of GPUs to train on: 1 when CUDA is available, otherwise 0 (CPU only).
GPUS = 1 if torch.cuda.is_available() else 0

# Writes a metrics.csv into the logger's log_dir under logs/ so it can be read back later.
logger = CSVLogger(save_dir='logs/')

trainer_kwargs = dict(
    max_epochs=5,
    gpus=GPUS,
    precision=16 if GPUS else 32,  # 16-bit mixed precision only when training on GPU
    logger=logger,
)
trainer = flash.Trainer(**trainer_kwargs)

In [None]:
# Flash fine-tuning strategy as (name, epoch): "freeze_unfreeze" presumably trains only the
# head while the backbone is frozen, then unfreezes the backbone at epoch 2 — see the Flash
# finetuning docs. ("no_freeze" is the alternative that trains all layers from the start.)
FINETUNE_STRATEGY = ("freeze_unfreeze", 2)
CHECKPOINT_PATH = "image_classification_model.pt"

trainer.finetune(model, datamodule=datamodule, strategy=FINETUNE_STRATEGY)
trainer.save_checkpoint(CHECKPOINT_PATH)

In [None]:
# Per-epoch training/validation metrics written by the CSVLogger during training.
metrics_raw = pd.read_csv(os.path.join(trainer.logger.log_dir, "metrics.csv"))
# Index by epoch and drop columns with no values, so the plot and the table show the same data.
metrics = (
    metrics_raw
    .drop(columns="step")
    .set_index("epoch")
    .dropna(axis=1, how="all")
)
display(metrics.head())

g = sn.relplot(data=metrics, kind="line", height=4, aspect=3)
# grid(True), not grid(): a bare grid() toggles visibility and would hide the grid that
# seaborn's default (darkgrid) theme already draws.
g.ax.grid(True)