# Explore 🔎 data...

Let's see what annotations and images we have :)

In [None]:
# List the contents of the competition dataset directory.
! ls -l /kaggle/input/herbarium-2022-fgvc9

### Loading the train and test metadata

In [None]:
import os
import json
import pandas as pd
import seaborn as sn
import matplotlib.pyplot as plt
from pprint import pprint

sn.set()

PATH_DATASET = "/kaggle/input/herbarium-2022-fgvc9"


def _read_metadata(file_name):
    """Parse one of the competition's JSON metadata files located in PATH_DATASET."""
    with open(os.path.join(PATH_DATASET, file_name)) as fp:
        return json.load(fp)


train_data = _read_metadata("train_metadata.json")
test_data = _read_metadata("test_metadata.json")

# Top-level sections of the training metadata, and the number of test entries.
pprint(train_data.keys())
pprint(len(test_data))

### Brief visualisations

In [None]:
train_annotations = pd.DataFrame(train_data['annotations'])
display(train_annotations.head(3))

# One histogram per id column, stacked vertically (layout=(3, 1)), with a log y-axis.
# `hist` returns a 2-D array of axes shaped like `layout`, so iterate over all of
# them: `axs[0]` is only the first row, and the other panels got the log scale
# only indirectly through `sharey=True`.
axs = train_annotations[["genus_id", "institution_id", "category_id"]].hist(
    bins=100, sharey=True, figsize=(8, 8), grid=True, layout=(3, 1)
)
for ax in axs.ravel():
    ax.set_yscale('log')

In [None]:
# Category metadata keyed by category_id, so it can be joined onto annotations.
train_categories = pd.DataFrame(
    train_data["categories"]
).set_index("category_id")
display(train_categories.head())

In [None]:
# Genus metadata keyed by genus_id.
train_genera = pd.DataFrame(
    train_data["genera"]
).set_index("genus_id")
display(train_genera.head())

In [None]:
# Institution metadata keyed by institution_id.
train_institutions = pd.DataFrame(
    train_data["institutions"]
).set_index("institution_id")
display(train_institutions.head())

In [None]:
# Image metadata keyed by image_id.
train_images = pd.DataFrame(
    train_data["images"]
).set_index("image_id")
display(train_images.head())

In [None]:
# Pairwise genus distances in long form: one row per (genus_id_x, genus_id_y) pair.
train_distances = pd.DataFrame(train_data["distances"])
display(train_distances.head())

# Reshape to a genus_y x genus_x matrix; pairs absent from the data become NaN.
heat = train_distances.pivot(index="genus_id_y", columns="genus_id_x", values="distance")
fig, ax = plt.subplots(figsize=(18, 18))
_ = sn.heatmap(heat, ax=ax)

### Fused annotations

In [None]:
# Left-join the image, category and institution lookup tables onto the
# annotations; a left join keeps exactly one row per annotation, in order.
df_train = (
    train_annotations
    .merge(train_images, how="left", left_on="image_id", right_index=True)
    .merge(train_categories, how="left", left_on="category_id", right_index=True)
    .merge(train_institutions, how="left", left_on="institution_id", right_index=True)
)

display(df_train.head())
print(f"training images: {len(df_train)}")

## Sample images 

In [None]:
# shuffle
# Preview a random selection of training images. `sample` returns a NEW frame
# (the previous bare `df_train.sample(frac=1)` discarded it, so nothing was
# shuffled). `df_train` itself keeps its original order for later cells; the
# fixed seed makes the preview reproducible.
N_PREVIEW = 10
df_preview = df_train.sample(n=min(N_PREVIEW, len(df_train)), random_state=42)

fig, axarr = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
for ax, (_, row) in zip(axarr.flat, df_preview.iterrows()):
    img_path = os.path.join(PATH_DATASET, "train_images", row["file_name"])
    ax.imshow(plt.imread(img_path))
    ax.set_title(f"category {row['category_id']}", fontsize=8)
    ax.axis("off")
fig.tight_layout()

In [None]:
import glob
import numpy as np
from tqdm.auto import tqdm
from joblib import Parallel, delayed

def _color_means(img_path):
    img = plt.imread(img_path)
    means = {i: np.mean(img[..., i]) / 255.0 for i in range(3)}
    std = {i: np.std(img[..., i]) / 255.0 for i in range(3)}
    return means, std

# Number of training images used to estimate the colour statistics (cost is linear in it).
N_COLOR_SAMPLES = 15000

# `glob` returns paths in filesystem order, so sort first for a reproducible list.
# Then draw a seeded random subset rather than the first N paths, which would
# only cover whichever sub-directories happen to be listed first.
images = sorted(glob.glob(os.path.join(PATH_DATASET, "train_images", "*", "*", "*.jpg")))
sample_rng = np.random.default_rng(42)
images_sample = list(
    sample_rng.choice(images, size=min(N_COLOR_SAMPLES, len(images)), replace=False)
)
clr_mean_std = Parallel(n_jobs=os.cpu_count())(
    delayed(_color_means)(fn) for fn in tqdm(images_sample)
)

In [None]:
# Summarise the per-image statistics (columns are channels 0, 1, 2). The "mean"
# row of each summary is the dataset-level value later used for normalisation.
mean_summary = pd.DataFrame([means for means, _ in clr_mean_std]).describe()
display(mean_summary)
std_summary = pd.DataFrame([stds for _, stds in clr_mean_std]).describe()
display(std_summary)

img_color_mean = list(mean_summary.loc["mean"])
img_color_std = list(std_summary.loc["mean"])
print(img_color_mean, img_color_std)

# Training with Lightning⚡Flash

**Follow the example:** https://lightning-flash.readthedocs.io/en/stable/reference/image_classification.html


**You will need to adjust the image size to match the chosen base model:**

| **Base model** | resolution |
|----------------|------------|
| EfficientNetB0 | 224        |
| EfficientNetB1 | 240        |
| EfficientNetB2 | 260        |
| EfficientNetB3 | 300        |
| EfficientNetB4 | 380        |
| EfficientNetB5 | 456        |
| EfficientNetB6 | 528        |
| EfficientNetB7 | 600        |

In [None]:
# Install Lightning Flash with image extras, plus effdet and icevision.
# NOTE(review): versions are unpinned, so a later re-run may pick up incompatible releases.
!pip install -q effdet "icevision[all]" 'lightning-flash[image]'
# !pip install -q "pytorch-lightning==1.4.*"
# Remove wandb (presumably so it does not hook into the Lightning run — confirm).
!pip uninstall -y wandb

In [None]:
# Download the same packages as wheels into `frozen_packages/`
# (presumably for an offline re-install, e.g. a no-internet run — confirm).
!pip download -q effdet "icevision[all]" 'lightning-flash[image]' --dest frozen_packages --prefer-binary
# Drop the torch wheel(s); presumably the target environment already provides torch.
!rm frozen_packages/torch-*
!ls -l frozen_packages

In [None]:
import torch

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

## 1. Create the DataModule 🗄️

In [None]:
from dataclasses import dataclass
from torchvision import transforms as T
from typing import Tuple, Callable
from flash.core.data.io.input_transform import InputTransform

@dataclass
class ImageClassificationInputTransform(InputTransform):
    """Per-sample transforms for the herbarium images.

    Every stage converts the image to a tensor, resizes it to ``image_size``
    and normalises it with the notebook-level ``img_color_mean`` /
    ``img_color_std``. The training stage then adds a random horizontal flip
    and a small random affine jitter. Targets are converted with
    ``torch.as_tensor``.
    """

    image_size: Tuple[int, int] = (224, 224)

    def _base_pipeline(self):
        # Steps shared by every stage, in order: tensor -> resize -> normalise.
        return [
            T.ToTensor(),
            T.Resize(self.image_size),
            T.Normalize(img_color_mean, img_color_std),
        ]

    def input_per_sample_transform(self):
        return T.Compose(self._base_pipeline())

    def train_input_per_sample_transform(self):
        # Augmentations run after normalisation. Other candidates tried:
        # ColorJitter, RandomAutocontrast, RandomPerspective.
        augmentations = [
            T.RandomHorizontalFlip(),
            T.RandomAffine(degrees=10, scale=(0.9, 1.1), translate=(0.1, 0.1)),
        ]
        return T.Compose(self._base_pipeline() + augmentations)

    def target_per_sample_transform(self) -> Callable:
        return torch.as_tensor

In [None]:
# Train on a seeded random half of the annotations (for simplicity / speed).
# df_train is in annotation order (left joins preserve it), so the previous head
# slice `df_train[:len(df_train) // 2]` can leave out whole categories when the
# annotations are grouped by category (file paths suggest they are — TODO confirm).
df_train_half = df_train.sample(frac=0.5, random_state=42)

datamodule = ImageClassificationData.from_data_frame(
    input_field="file_name",
    target_fields="category_id",
    train_data_frame=df_train_half,
    train_images_root=os.path.join(PATH_DATASET, "train_images"),
    train_transform=ImageClassificationInputTransform,
    batch_size=128,
    transform_kwargs={"image_size": (224, 224)},
    num_workers=3,
)

## 2. Build the task ⚙️

In [None]:
# EfficientNet-B0 backbone with pretrained weights and AdamW at lr=1e-3;
# the number of output classes is taken from the training datamodule.
model = ImageClassifier(
    num_classes=datamodule.num_classes,
    backbone="efficientnet_b0",
    pretrained=True,
    optimizer="AdamW",
    learning_rate=1e-3,
)

## 3. Finetune the model 🛠️

In [None]:
from pytorch_lightning.loggers import CSVLogger
# from pytorch_lightning.callbacks import StochasticWeightAveraging

# Trainer Args
# 1 when a CUDA device is available, otherwise 0 (CPU-only run).
GPUS = int(torch.cuda.is_available())

logger = CSVLogger(save_dir='logs/')

trainer = flash.Trainer(
    logger=logger,
    gpus=GPUS,
    # 16-bit precision only when running on GPU; CPU stays at 32-bit.
    precision=(16 if GPUS else 32),
    max_epochs=3,
    # Gradients are accumulated over 32 batches per optimizer step.
    accumulate_grad_batches=32,
)

In [None]:
CHECKPOINT_PATH = "image_classification_model.pt"

# Fine-tune with Flash's "freeze" strategy, then write the trained model to disk.
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
trainer.save_checkpoint(CHECKPOINT_PATH)

In [None]:
# Metrics written by the CSVLogger, indexed by epoch. The step column is dropped
# so that only metric columns are plotted.
metrics = (
    pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
    .drop(columns="step")
    .set_index("epoch")
)
display(metrics.dropna(axis=1, how="all").head())
g = sn.relplot(data=metrics, kind="line")
plt.gcf().set_size_inches(15, 5)

## Inference 🎉

In [None]:
# One row per test image, indexed by image_id (the id written to the submission).
test_images = (
    pd.DataFrame(test_data)
    .set_index("image_id")
)
display(test_images.head())
print(f"inference for {len(test_images)} images")

In [None]:
# Prediction-only datamodule over the test images. The custom
# ImageClassificationInputTransform used for training is passed explicitly, so
# test images get the same resize and dataset-statistics normalisation.
# Without it, Flash falls back to its default predict transform, whose
# normalisation differs from the one used in training.
datamodule = ImageClassificationData.from_data_frame(
    input_field="file_name",
    predict_data_frame=test_images,
    # for a quick smoke test use e.g. test_images[:len(test_images) // 100]
    predict_images_root=os.path.join(PATH_DATASET, "test_images"),
    predict_transform=ImageClassificationInputTransform,
    batch_size=16,
    transform_kwargs={"image_size": (224, 224)},
    num_workers=2,
)

In [None]:
# trainer.predict yields one list of labels per batch; flatten them into a single
# list in test-image order.
predictions = [
    label
    for batch_labels in trainer.predict(model, datamodule=datamodule, output="labels")
    for label in batch_labels
]

In [None]:
# Pair each test image_id with its predicted label and write the submission CSV.
submission = (
    pd.DataFrame({"id": test_images.index, "Predicted": predictions})
    .set_index("id")
)
submission.to_csv("submission.csv")

# Print the first lines of the written submission file as a sanity check.
! head submission.csv