# Kaggle: [Cassava Leaf Disease Classification](https://www.kaggle.com/c/cassava-leaf-disease-classification/overview)


## Setup environment 

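- connect Google Drive, where the dataset archive is stored
- copy the archive to the local disk and extract it
- install PyTorch Lightning (together with Lightning Bolts and Lightning Flash)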

In [None]:
from google.colab import drive

# Mount Google Drive so the files stored there become reachable under /content/gdrive.
gdrive_mount_point = "/content/gdrive"
drive.mount(gdrive_mount_point)

In [None]:
# Copy the dataset archive from the mounted Google Drive into the current
# working directory on the local drive.
! cp /content/gdrive/MyDrive/Data/cassava-leaf-disease-classification.zip .

In [None]:
# Extract the dataset archive into the current directory (-q: quiet mode),
# then list the directory contents to check the result.
! unzip -q cassava-leaf-disease-classification.zip
! ls -l

In [None]:
# Install pinned (pre-release) versions of PyTorch Lightning, Lightning Bolts
# and Lightning Flash, plus a pinned torchtext (-q: quiet output).
! pip install "pytorch-lightning==1.2.0rc0" "lightning-bolts==0.3.2rc1" "lightning-flash==0.2.2rc2" "torchtext==0.5" -q

# NOTE(review): the two disabled lines below kill the current process; presumably
# they were used to force a runtime restart after the install - confirm before enabling.
# import os
# os.kill(os.getpid(), 9)
# Show which installed packages have "torch" in their name.
! pip list | grep torch

In [None]:
# Print the status of the GPU(s) visible to this runtime.
! nvidia-smi

## Data exploration

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

import json
from pprint import pprint

import pandas as pd

# Load the training annotations and preview the first rows.
path_csv = "/content/train.csv"
train_data = pd.read_csv(path_csv)
print(train_data.head())

# Load the label-number -> disease-name mapping. The context manager closes the
# file handle deterministically instead of leaving it open until garbage collection.
with open("/content/label_num_to_disease_map.json", encoding="utf-8") as fp:
    label_mapping = json.load(fp)
# JSON object keys are always strings; convert them to integers.
label_mapping = {int(k): v for k, v in label_mapping.items()}
pprint(label_mapping)

In [None]:
import numpy as np
import seaborn as sns

# Number of samples per label id. `enumerate` pairs every bincount bin with its
# index, so the histogram is never silently truncated by a hard-coded upper bound.
lb_hist = dict(enumerate(np.bincount(train_data["label"])))
pprint(lb_hist)

# Count plot of the labels, with the numeric ids replaced by the names from `label_mapping`.
ax = sns.countplot(y=train_data["label"].map(label_mapping), orient="v")
ax.grid()

In [None]:
import matplotlib.pyplot as plt

# Preview grid: one column per label, a few example images per column.
n_samples = 4  # number of example images shown per label
n_labels = train_data["label"].nunique()
# squeeze=False keeps `axarr` two-dimensional even when there is a single label.
fig, axarr = plt.subplots(nrows=n_samples, ncols=n_labels, figsize=(16, 10), squeeze=False)
# groupby iterates the labels in sorted order; the column is the position of the
# label in that order, so label values do not have to be 0..n-1.
for col, (lb, df_) in enumerate(train_data.groupby("label")):
    # head() yields at most n_samples names, so labels with fewer images do not fail.
    for row, img_name in enumerate(df_["image_id"].head(n_samples)):
        img = plt.imread(f"/content/train_images/{img_name}")
        axis = axarr[row, col]
        axis.imshow(img)
        axis.set_title(f"label: {lb} & image: {img_name}")
        # pixel coordinates carry no information here, hide the ticks
        axis.set_xticks([])
        axis.set_yticks([])
fig.tight_layout()

## Dataset adjustment

In [None]:
import os
import shutil

import pandas as pd
import tqdm

path_csv = "/content/train.csv"
data = pd.read_csv(path_csv)
# shuffle data; the fixed seed keeps the split reproducible
data = data.sample(frac=1, random_state=42).reset_index(drop=True)

# first 80% of the shuffled rows go to train, the rest to valid
# (`frac` is the row index of the split point)
frac = int(0.8 * len(data))
train = data[:frac]
valid = data[frac:]

# creating train and valid folders, each with one sub-folder per label
for folder, df in [("train", train), ("valid", valid)]:
    folder = os.path.join("/content/dataset", folder)
    os.makedirs(folder, exist_ok=True)
    # create the label sub-folders once, instead of checking for every image
    for lb in df["label"].unique():
        os.makedirs(os.path.join(folder, str(lb)), exist_ok=True)
    # triage images per class / label; `total` lets tqdm show real progress,
    # since a plain iterator has no length
    for img_name, lb in tqdm.tqdm(zip(df["image_id"], df["label"]), total=len(df)):
        shutil.copy(
            os.path.join("/content/train_images", img_name),
            os.path.join(folder, str(lb), img_name),
        )

# List the contents of both split folders to check the created layout.
! ls -l /content/dataset/train
! ls -l /content/dataset/valid

## Flash finetuning

In [None]:
import multiprocessing as mproc

import flash
import torch
from flash.core.data import download_data
from flash.core.finetuning import FreezeUnfreeze
from flash.vision import ImageClassificationData, ImageClassifier

# 2. Load the data: both splits are read from their folders on the local disk,
#    using one loader worker per available CPU core.
datamodule_kwargs = dict(
    train_folder="/content/dataset/train/",
    valid_folder="/content/dataset/valid/",
    batch_size=128,
    num_workers=mproc.cpu_count(),
)
datamodule = ImageClassificationData.from_folders(**datamodule_kwargs)

# 3. Build the model: a resnet34 backbone trained with Adam; the number of
#    classes is taken from the datamodule.
model_kwargs = dict(
    backbone="resnet34",
    optimizer=torch.optim.Adam,
    num_classes=datamodule.num_classes,
)
model = ImageClassifier(**model_kwargs)

In [None]:
# 4. Create the trainer: at most 3 epochs on a single GPU with 16-bit precision;
#    val_check_interval=0.5 requests a validation check every half epoch.
trainer_kwargs = {
    "gpus": 1,
    "max_epochs": 3,
    "precision": 16,
    "val_check_interval": 0.5,
    "progress_bar_refresh_rate": 1,
}
trainer = flash.Trainer(**trainer_kwargs)

# 5. Train the model using the FreezeUnfreeze finetuning strategy with unfreeze_epoch=1.
finetune_strategy = FreezeUnfreeze(unfreeze_epoch=1)
trainer.finetune(model, datamodule=datamodule, strategy=finetune_strategy)

# 6. Save it!
trainer.save_checkpoint("image_classification_model.pt")

In [None]:
# Start TensorBoard inside the notebook, reading the logs from lightning_logs/.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/