Reference: [Plant Pathology with Lightning ⚡- By Jirka](https://www.kaggle.com/jirkaborovec/plant-pathology-with-lightning)


In [None]:
!pip install git+https://github.com/PytorchLightning/lightning-flash.git@master -q
!pip install git+https://github.com/PytorchLightning/metrics.git@master -q
!pip install timm -q

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from sklearn.preprocessing import MultiLabelBinarizer
import matplotlib.pyplot as plt

import flash
from flash.vision import ImageClassificationData, ImageClassifier
import torchmetrics


from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.metrics import FBeta
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger


import torch
import torchmetrics
import torchvision
from torch import nn
from torch.nn import functional as F

import os
from glob import glob
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader, Dataset

from PIL import Image

from torchvision import transforms

In [None]:
# Load the competition metadata. The raw space-separated label string is kept
# in `label_org`; `labels` becomes a list of individual label tokens per row.
df = pd.read_csv('/kaggle/input/plant-pathology-2021-fgvc8/train.csv')
df['label_org'] = df['labels'].values
df['labels'] = df['labels'].str.split()

In [None]:
import itertools
import seaborn as sns

# Flatten every row's label string into one list of label tokens and plot how
# often each label occurs. Whitespace splitting matches the `str.split()` used
# when `df.labels` was built, so repeated spaces never yield empty tokens.
labels_all = list(itertools.chain.from_iterable(lbs.split() for lbs in df['label_org']))

# Passing only `y=` already lays the counts out along the y axis; the previous
# `orient='v'` contradicted that and was ignored by seaborn.
ax = sns.countplot(y=sorted(labels_all))
ax.grid()

In [None]:
# Collect the distinct label names across all rows (a row may carry several).
# `num_classes` is read later by the data module and the model.
labels = set()
for row_labels in tqdm(df.labels):
    labels.update(row_labels)
num_classes = len(labels)
labels

In [None]:
# Learn the label vocabulary; `fit` returns the binarizer itself, so the call
# can be chained. Sparse output keeps the one-hot matrix compact.
mlb = MultiLabelBinarizer(sparse_output=True).fit(df.labels)

In [None]:
def create_ohe(df, mlb):
    """Return ``df`` with one 0/1 indicator column appended per class of ``mlb``.

    Args:
        df: DataFrame whose ``labels`` column holds a list of label strings per row.
        mlb: a fitted binarizer exposing ``transform`` and ``classes_`` (e.g. a
            ``MultiLabelBinarizer``). Both sparse and dense ``transform``
            outputs are accepted.

    Returns:
        A new DataFrame with the original columns followed by the indicator
        columns (named after ``mlb.classes_``), indexed exactly like ``df``.
    """
    from scipy import sparse

    encoded = mlb.transform(df.labels)
    # Build the indicator frame on df's own index. Previously it got a fresh
    # RangeIndex, so an index-based join against a df with any other index
    # silently dropped or misaligned rows.
    if sparse.issparse(encoded):
        ohe = pd.DataFrame.sparse.from_spmatrix(encoded, index=df.index, columns=mlb.classes_)
    else:
        ohe = pd.DataFrame(encoded, index=df.index, columns=mlb.classes_)
    return pd.concat([df, ohe], axis=1)

In [None]:
# Append the one-hot label columns, then shuffle every row once with a fixed
# seed and renumber the index from 0.
df = (
    create_ohe(df, mlb)
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)
df.head()

In [None]:
# The first 90% of the (already shuffled) rows form the training split and the
# rest the validation split. Each part is reshuffled with the same seed and
# re-indexed 0..n-1.
split = 0.9
frac = int(split * len(df))  # number of rows that go to the training split

train_data, val_data = (
    part.sample(frac=1, random_state=42).reset_index(drop=True)
    for part in (df.iloc[:frac], df.iloc[frac:])
)

In [None]:
# Square network input side length in pixels (the transform pipelines produce
# 224 x 224 crops).
IMAGE_SIZE: int = 224

In [None]:
from torchvision import transforms as T

# Standard ImageNet per-channel mean/std, shared by both pipelines.
# Alternative dataset-specific statistics that were tried (currently unused):
#   mean [0.431, 0.498, 0.313], std [0.237, 0.239, 0.227]
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)

# Training: downscale, then random perspective / resized crop / flips for
# augmentation, ending in a normalized IMAGE_SIZE x IMAGE_SIZE tensor.
TRAIN_TRANSFORM = T.Compose([
    T.Resize(512),
    T.RandomPerspective(),
    T.RandomResizedCrop(IMAGE_SIZE),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.ToTensor(),
    T.Normalize(NORM_MEAN, NORM_STD),
])

# Validation: deterministic resize + center crop. IMAGE_SIZE / 0.875 keeps the
# original 256 -> 224 resize-to-crop ratio.
VALID_TRANSFORM = T.Compose([
    T.Resize(int(IMAGE_SIZE / 0.875)),
    T.CenterCrop(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize(NORM_MEAN, NORM_STD),
])

In [None]:
class PlantDataset(Dataset):
    """Map-style dataset returning ``(image, labels)`` for each DataFrame row.

    Rows are addressed by position (``0 .. len(data) - 1``), which is what a
    ``DataLoader`` passes in, so the frame does not need a default index.

    Args:
        data: DataFrame with an ``image`` column of file names plus one 0/1
            column per class.
        transformation: callable applied to the loaded PIL image. A falsy
            value returns the image untransformed.
        folder: split prefix; images are read from ``<root>/<folder>_images/``.
        root: dataset root directory (defaults to the Kaggle input path).
        label_columns: names of the 0/1 label columns. ``None`` keeps the
            original convention of using every column from position 3 onward
            (i.e. after ``image``, ``labels`` and ``label_org``).
    """

    def __init__(self, data, transformation, folder='train',
                 root='/kaggle/input/plant-pathology-2021-fgvc8', label_columns=None):
        self.data = data
        self.transform = transformation
        self.folder = folder
        self.root = root
        self.label_columns = label_columns

    def __len__(self):
        return len(self.data)

    def _label_vector(self, idx):
        """Return the label indicators of row ``idx`` (by position) as an int array."""
        if self.label_columns is None:
            values = self.data.iloc[idx, 3:]
        else:
            values = self.data.iloc[idx][list(self.label_columns)]
        return values.to_numpy().astype(int)

    def __getitem__(self, idx):
        # Positional lookup for the file name too, consistent with the labels
        # (previously label-based .loc was mixed with positional .iloc).
        file_name = self.data['image'].iloc[idx]
        path = os.path.join(self.root, f'{self.folder}_images', file_name)
        # PIL opens files lazily and holds the handle until pixels are read.
        # convert() inside the `with` loads the data so the handle is closed
        # right away; RGB also guarantees 3 channels for a 3-channel Normalize.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, self._label_vector(idx)

In [None]:
# Wrap both splits: augmenting pipeline for training, deterministic one for validation.
train_dataset, val_dataset = (
    PlantDataset(train_data, TRAIN_TRANSFORM),
    PlantDataset(val_data, VALID_TRANSFORM),
)

In [None]:
import multiprocessing as mproc
import pytorch_lightning as pl

class PlantPathologyDM(pl.LightningDataModule):
    """LightningDataModule serving pre-built train/validation datasets.

    Args:
        train_dataset: dataset for ``train_dataloader`` (shuffled).
        val_dataset: dataset for ``val_dataloader`` (kept in order).
        batch_size: batch size for both loaders.
        num_workers: DataLoader worker processes; ``None`` means
            ``multiprocessing.cpu_count()``.
        num_classes: value exposed by the ``num_classes`` property. ``None``
            keeps the previous behaviour of reading the module-level
            ``num_classes`` global.
    """

    def __init__(
        self,
        train_dataset: Dataset = None,
        val_dataset: Dataset = None,
        batch_size: int = 64,
        num_workers: int = None,
        num_classes: int = None,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers if num_workers is not None else mproc.cpu_count()
        self.train_dataset = train_dataset
        self.valid_dataset = val_dataset
        self._n_classes = num_classes

    def prepare_data(self):
        # Nothing to download or write: the datasets are built by the caller.
        pass

    @property
    def num_classes(self) -> int:
        """Number of label classes (explicit argument, else the module global)."""
        if self._n_classes is not None:
            return self._n_classes
        # Backward-compatible fallback: the notebook-level `num_classes` global.
        return num_classes

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader over ``dataset`` with this module's batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.valid_dataset, shuffle=False)

    def test_dataloader(self):
        # No test split is wired up here.
        return None


In [None]:
# Data module with the default batch size and one worker per CPU.
dm = PlantPathologyDM(train_dataset=train_dataset, val_dataset=val_dataset)

In [None]:
# # quick view
# fig = plt.figure(figsize=(3, 7))
# for imgs, lbs in dm.train_dataloader():
#     print(f'batch labels: {torch.sum(lbs, axis=0)}')
#     print(f'image size: {imgs[0].shape}')
#     for i in range(3):
#         ax = fig.add_subplot(3, 1, i + 1, xticks=[], yticks=[])
#         # print(np.rollaxis(imgs[i].numpy(), 0, 3).shape)
#         ax.imshow(np.rollaxis(imgs[i].numpy(), 0, 3))
#         ax.set_title(lbs[i])
#     break

In [None]:
class PLModel(pl.LightningModule):
    """Multi-label classifier wrapping a Flash ``ImageClassifier``.

    ``forward`` returns per-class sigmoid probabilities. The loss is computed on
    the raw scores so that ``BCEWithLogitsLoss`` applies its sigmoid exactly
    once (previously the probabilities were fed to it, i.e. a double sigmoid).

    Args:
        backbone: backbone name forwarded to ``ImageClassifier`` (e.g. ``'xception'``).
        num_classes: number of labels (output units).
        lr: AdamW learning rate.
    """

    def __init__(self, backbone, num_classes, lr: float = 1e-4):
        super().__init__()
        # Honour the requested backbone; it used to be hard-coded to 'xception'.
        self.model = ImageClassifier(num_classes, backbone=backbone)

        self.num_classes = num_classes
        self.train_accuracy = torchmetrics.Accuracy()
        self.val_accuracy = torchmetrics.Accuracy()
        self.val_f1_score = torchmetrics.F1(self.num_classes)
        self.learn_rate = lr
        self.loss = nn.BCEWithLogitsLoss()

    def _logits(self, x):
        """Return the wrapped classifier's raw scores for ``x``."""
        # NOTE: assumes ImageClassifier's forward returns unnormalized logits
        # (no activation); calling the module (not .forward) keeps hooks active.
        return self.model(x)

    def forward(self, x):
        """Return per-class probabilities (sigmoid of the logits)."""
        return torch.sigmoid(self._logits(x))

    def compute_loss(self, y_hat, y):
        """Binary cross-entropy between logits ``y_hat`` and 0/1 targets ``y``.

        ``y_hat`` must be raw logits, not probabilities: ``BCEWithLogitsLoss``
        applies the sigmoid itself.
        """
        # Match the target dtype to the predictions; `y.to(float)` produced
        # float64 targets against float32 logits.
        return self.loss(y_hat, y.to(dtype=y_hat.dtype))

    def _step(self, batch):
        """Shared train/val work: return ``(loss, probabilities, targets)``."""
        x, y = batch
        logits = self._logits(x)
        loss = self.compute_loss(logits, y)
        return loss, torch.sigmoid(logits), y

    def training_step(self, batch, batch_idx):
        loss, probs, y = self._step(batch)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_accuracy(probs, y), prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, probs, y = self._step(batch)
        self.log("valid_loss", loss, prog_bar=False)
        self.log("valid_acc", self.val_accuracy(probs, y), prog_bar=True)
        self.log("valid_f1", self.val_f1_score(probs, y), prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learn_rate)
        # Cosine-anneal the LR down to 0 with a period of trainer.max_epochs.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_epochs, eta_min=0
        )
        return [optimizer], [scheduler]

In [None]:
# Xception-backed model sized to the data module's label count.
pl_model = PLModel(backbone='xception', num_classes=dm.num_classes)

In [None]:
# Write metrics as CSV under logs/xception_v0/.
logger = CSVLogger(name='xception_v0', save_dir='logs/')

In [None]:
# Single-GPU run for 10 epochs.
trainer = pl.Trainer(
    max_epochs=10,
    gpus=1,
    accumulate_grad_batches=8,  # one optimizer step every 8 batches
    val_check_interval=0.25,    # validate four times per training epoch
    progress_bar_refresh_rate=1,
    logger=logger,
)

trainer.fit(pl_model, datamodule=dm)

In [None]:
# !nvidia-smi