# Kaggle: Plant Pathology 2021 - FGVC8

In [None]:
# Install the training dependencies quietly (`-q`), then list the installed
# torch-related packages and print the GPU status for the record.
! pip install pytorch-lightning torchmetrics -q
! pip list | grep torch
! nvidia-smi

## Data exploration

Checking what data we have available and what the label distribution looks like...

In [None]:
# just to see where the data is located and what the dataset folder contains
! ls /kaggle/input/plant-pathology-2021-fgvc8 -l

Looking at the training dataset table: which columns it has and how the data is represented...

In [None]:
%matplotlib inline

import os
import json
import pandas as pd
from pprint import pprint

# Root folder of the competition data; also reused later to locate the images.
base_path = "/kaggle/input/plant-pathology-2021-fgvc8"

# The training table sits directly in the dataset root.
path_csv = os.path.join(base_path, "train.csv")
train_data = pd.read_csv(path_csv)

# Show the first five rows to get a feel for the columns.
print(train_data.head(5))

We can see that each image can have multiple labels, so let's check what the most common label count is...

*The target classes, a space delimited list of all diseases found in the image.
Unhealthy leaves with too many diseases to classify visually will have the complex class, and may also have a subset of the diseases identified.*

In [None]:
import numpy as np

# Number of labels attached to each image (labels are space-delimited).
train_data['nb_classes'] = [len(lbs.split(" ")) for lbs in train_data['labels']]

# Histogram: labels-per-image -> number of images.
# `enumerate` keys every bin that `np.bincount` returns, whereas zipping with
# `range(10)` would silently drop the bins for 10 or more labels.
lb_hist = dict(enumerate(np.bincount(train_data['nb_classes'])))
pprint(lb_hist)

Browse the label distribution, unrolling all labels in the dataset, so if an image has two labels, both are counted in this statistic...

In [None]:
import itertools
import seaborn as sns

# Flatten the per-image label lists into one long list of label occurrences,
# so an image with several labels contributes once to each of them.
labels_per_image = (lbs.split(" ") for lbs in train_data['labels'])
labels_all = list(itertools.chain.from_iterable(labels_per_image))

# Count plot of the label occurrences, with the labels placed on the y axis.
ax = sns.countplot(y=labels_all, orient='v')
ax.grid()

Get some statistics on the label combinations...

In [None]:
from collections import Counter

labels_unique = set(labels_all)
print(f"unique labels: {labels_unique}")

# Sort the labels inside each row so that the same set of labels always maps
# to the same string, regardless of the order used in the CSV.
train_data['labels_sorted'] = [" ".join(sorted(lbs.split(" "))) for lbs in train_data['labels']]

# Number of images per label combination (kept as a plain dict).
labels_combine = dict(Counter(train_data['labels_sorted']))

# One "<combination>: <count>" line per combination, sorted alphabetically.
show_counts = '\n'.join(sorted(f'\t{k}: {v}' for k, v in labels_combine.items()))
print("unique combinations: \n" + show_counts)
print(f"total: {sum(labels_combine.values())}")

And add a visualisation of each case, showing a few examples per label combination...

In [None]:
import matplotlib.pyplot as plt

# Grid of example images: one row per label combination, `nb_samples` columns.
nb_samples = 6
n, m = len(np.unique(train_data['labels_sorted'])), nb_samples
# `squeeze=False` keeps `axarr` two-dimensional even when there is only one row.
fig, axarr = plt.subplots(nrows=n, ncols=m, figsize=(m * 2, n * 2), squeeze=False)
for ilb, (lb, df_) in enumerate(train_data.groupby('labels_sorted')):
    # A combination may have fewer than `nb_samples` images; show what exists
    # and leave the remaining cells of the row empty instead of failing.
    img_names = list(df_['image'])[:m]
    for i, img_name in enumerate(img_names):
        img = plt.imread(os.path.join(base_path, f"train_images/{img_name}"))
        axarr[ilb, i].imshow(img)
    # Title the first cell of the row with the combination and its image count.
    axarr[ilb, 0].set_title(f"{lb} #{len(df_)}")
# Hide ticks and frames on every cell; `plt.axis('off')` would only affect the
# current (last) axes.
for cell_ax in axarr.ravel():
    cell_ax.axis('off')

## Dataset & DataModule

Creating a standard PyTorch dataset to define how the data is loaded and how it is represented.
We define the sample pair as:
 - RGB image
 - one-hot labels encoding

A DataModule standardizes the training, val, test splits, data preparation and transforms.
The main advantage is consistent data splits, data preparation and transforms across models.

In [None]:
import os
import torch

from kaggle_plantpatho.data import PlantPathologyDataset

dataset = PlantPathologyDataset()

# quick view: the first nine samples on a 3x3 grid, each titled with its label
fig = plt.figure(figsize=(9, 6))
for position in range(1, 10):
    img, lb = dataset[position - 1]
    # subplot positions are 1-based; ticks are removed since these are images
    ax = fig.add_subplot(3, 3, position, xticks=[], yticks=[])
    ax.imshow(img)
    ax.set_title(lb)

Let us also add a simplified version, where we keep only the `complex` label for multi-label cases and the true label for all others...

In [None]:
from kaggle_plantpatho.data import PlantPathologySimpleDataset

dataset = PlantPathologySimpleDataset()

# quick view: the first samples of the simplified dataset on a 3x3 grid
n_rows, n_cols = 3, 3
fig = plt.figure(figsize=(9, 6))
for idx in range(n_rows * n_cols):
    img, lb = dataset[idx]
    # subplot positions are 1-based, hence `idx + 1`
    ax = fig.add_subplot(n_rows, n_cols, idx + 1, xticks=[], yticks=[])
    ax.imshow(img)
    ax.set_title(f"label: {lb}")

The DataModule takes care of creating the training and validation datasets with a given split and feeding them to the respective data loaders...

In [None]:
import pytorch_lightning as pl


from kaggle_plantpatho.data import PlantPathologyDM

dm = PlantPathologyDM()
dm.setup()
print(dm.num_classes)

# quick view: label histogram and a few images from the first training batch
fig = plt.figure(figsize=(12, 4))
for imgs, lbs in dm.train_dataloader():
    # `enumerate` keys every bin returned by `np.bincount`; zipping with
    # `range(10)` would silently truncate the histogram beyond 10 classes.
    lb_hist = dict(enumerate(np.bincount(lbs)))
    print(f'batch labels: {lb_hist}')
    print(f'image size: {imgs[0].shape}')
    # Show at most five samples; the batch may hold fewer than that.
    for i in range(min(5, len(imgs))):
        ax = fig.add_subplot(1, 5, i + 1, xticks=[], yticks=[])
        # Move the leading axis to the end before plotting
        # (presumably channels-first -> channels-last; confirm against the transforms).
        ax.imshow(np.rollaxis(imgs[i].numpy(), 0, 3))
        ax.set_title(f"label: {lbs[i]}")
    # only the first batch is needed for the preview
    break

In [None]:
dm2 = PlantPathologyDM(full=True)
dm2.setup()

# quick view of the first training batch from the `full=True` data module
fig = plt.figure(figsize=(16, 4))
for imgs, lbs in dm2.train_dataloader():
    # sum of the label tensor over the first (batch) axis
    print(f'batch labels: {torch.sum(lbs, axis=0)}')
    print(f'image size: {imgs[0].shape}')
    for col in range(1, 6):
        sample, target = imgs[col - 1], lbs[col - 1]
        ax = fig.add_subplot(1, 5, col, xticks=[], yticks=[])
        # move the leading axis to the end before handing the array to imshow
        ax.imshow(np.rollaxis(sample.numpy(), 0, 3))
        ax.set_title(target)
    # the preview only needs the first batch
    break

## CNN Model

We start with some standard CNN models taken from torchvision.
Then we define a Lightning module including the training and validation steps and configure the optimizer/scheduler.

In [None]:
from kaggle_plantpatho.models import LitResnet, LitPlantPathology

# Network sized to the number of classes reported by the data module.
# For the available architectures see: https://pytorch.org/vision/stable/models.html
net = LitResnet(
    arch='resnet50',
    num_classes=dm.num_classes,
)
# Wrap the network in the Lightning module that is passed to the trainer.
model = LitPlantPathology(model=net)

## Training

We use PyTorch Lightning, which allows us to drop all the boilerplate code and reduce the whole training to creating and calling a Trainer...

In [None]:
# CSV logger rooted at `logs/` and named after the model architecture; the
# `metrics.csv` it produces is read back for plotting after training.
logger = pl.loggers.CSVLogger(save_dir='logs/', name=model.arch)

# ==============================

# Training configuration: one GPU, 16-bit precision, at most 10 epochs, with
# gradient accumulation and several validation checks per epoch.
# NOTE(review): `gpus`, `progress_bar_refresh_rate` and `weights_summary` look like
# arguments of older pytorch-lightning releases, while the install cell does not pin
# a version — confirm that the installed release still accepts them.
trainer = pl.Trainer(
    # fast_dev_run=True,
    gpus=1,
    # callbacks=[cb_ckpt],
    logger=logger,
    max_epochs=10,
    precision=16,
    accumulate_grad_batches=8,
    val_check_interval=0.25,
    progress_bar_refresh_rate=1,
    weights_summary='top',
)

# ==============================

# Fit the model on the splits provided by the data module.
# trainer.tune(model, datamodule=dm)
trainer.fit(model=model, datamodule=dm)

Quick visualization of the training process...

In [None]:
# Load the metrics written by the CSV logger during training.
metrics = pd.read_csv(f'{trainer.logger.log_dir}/metrics.csv')
print(metrics.head())

# Average every logged value within each epoch; `as_index=False` keeps the
# `epoch` key as a regular column of the aggregated table.
agg_col = "epoch"
df_metrics = metrics.groupby(agg_col, as_index=False).mean()

# Plot against the actual epoch values instead of the row position, so the
# x axis really shows what its label says.
df_metrics.plot(x=agg_col, y=['train_loss', 'valid_loss'], grid=True, legend=True, xlabel=agg_col)
df_metrics.plot(x=agg_col, y=['valid_f1', 'valid_acc', 'train_acc'], grid=True, legend=True, xlabel=agg_col)