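# Plant Disease Recognition

Trains a convolutional classifier (Ivy with the PyTorch backend) to recognise the 38 leaf-disease classes of the PlantVillage `raw/color` images.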

In [1]:
# Setting up the project
# !git clone -n --depth 1 --filter tree:0 https://github.com/spMohanty/PlantVillage-Dataset.git ./dataset
# !cd ./dataset && git sparse-checkout set raw/color && git checkout
# !pip install ivy

In [2]:
import os
import json
import glob
import multiprocessing as mp
from typing import TypeAlias

import ivy
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import cv2
import numpy as np
from tqdm.notebook import tqdm

In [3]:
# Run every ivy op on the PyTorch backend.
ivy.set_backend("torch")
# Prefer the first GPU when one is visible; otherwise fall back to the CPU.
if ivy.gpu_is_available():
    ivy.set_default_device("gpu:0")
else:
    ivy.set_default_device("cpu")

In [4]:
# Display the device ivy will use by default for newly created arrays.
ivy.default_device()

'gpu:0'

In [5]:
# Constants
SEED = 8753  # seeds the train/test split and torch's global RNG

# Seed torch's global RNG so DataLoader shuffling is reproducible across runs;
# ivy's torch backend presumably draws weight init / dropout from the same
# generator — TODO confirm.
torch.manual_seed(SEED)

## Utils

In [6]:
# Utils

class AverageCalculator:
    """Running, count-weighted mean of scalar values (e.g. per-batch losses)."""

    def __init__(self):
        self.reset()

    def update(self, num, count=1):
        """Add ``num`` as the mean value of ``count`` samples.

        Args:
            num: Value to accumulate (e.g. a batch's mean loss).
            count: Weight of ``num`` (e.g. the batch size).
        """
        self.count += count
        self.sum += num * count

    def avg(self):
        """Return the weighted mean so far, or NaN if nothing was accumulated."""
        if self.count == 0:
            # An empty loader would otherwise crash with ZeroDivisionError.
            return float("nan")
        return self.sum / self.count

    def reset(self):
        """Discard everything accumulated so far."""
        self.sum = self.count = 0.0

## Creating Dataset and DataLoader

In [7]:
Data: TypeAlias = tuple[ivy.Array, ivy.Array]

class PlantVillageDataset(Dataset):
    """Represents the PlantVillage Dataset"""

    IMG_SHAPE = (256, 256, 3)

    def __init__(self, dataset_path: str | None, seed: int | None = None):
        """
        Args:
            dataset_path: Path to the local PlantVillage repo
        """

        if dataset_path is None:
            return

        disease_folders_path = os.path.join(dataset_path, "raw/color")
        disease_folders = glob.glob(os.path.join(disease_folders_path, "*"))
        self.label_names = [os.path.basename(x) for x in disease_folders]

        images = []
        labels = []
        for i, label_text in enumerate(self.label_names):
            imgs_path = glob.glob(os.path.join(disease_folders_path, label_text, "*"))
            images += imgs_path
            labels += [i] * len(imgs_path)

        self.images = np.array(images)
        self.labels = torch.tensor(labels)
        self.labels = F.one_hot(self.labels)

        assert self.labels.shape[1] == 38, "Wrong one-hot on labels"
        assert len(self.images) == len(self.labels), \
            "image array and label array do not have equal sizes"


    def split(
        self,
        ratio: float,
        shuffle: bool = True,
        seed: int | None = None
    ) -> tuple['PlantVillageDataset', 'PlantVillageDataset']:
        """Split the dataset into two PlantVillage datasets."""

        assert 0 < ratio < 1, "ratio must be between 0 and 1"

        if shuffle:
            self.shuffle(seed)

        split1_size = int(len(self.images) * ratio)

        ds1 = PlantVillageDataset(None)
        ds2 = PlantVillageDataset(None)
        ds1.label_names = self.label_names
        ds2.label_names = self.label_names
        ds1.images = self.images[:split1_size]
        ds1.labels = self.labels[:split1_size]
        ds2.images = self.images[split1_size:]
        ds2.labels = self.labels[split1_size:]

        return ds1, ds2


    def shuffle(self, seed: int | None = None) -> None:
        rng = np.random.default_rng(seed)
        idxs = rng.permutation(len(self.images))
        self.images = self.images[idxs]
        self.labels = self.labels[idxs]


    def __len__(self):
        return len(self.images)


    def __getitem__(self, i: int) -> Data:
        img_path = self.images[i]
        img = cv2.imread(img_path)
        assert img.shape == self.IMG_SHAPE, f"Wrong image shape {img.shape}"
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = torch.from_numpy(img).to(torch.float32)
        img = 2 * (img / 255) - 1

        return img, self.labels[i]


In [8]:
# Index every image under ./dataset/raw/color and show how many were found.
dataset = PlantVillageDataset(dataset_path="./dataset")
len(dataset)

54305

In [9]:
# 90% of the (seeded, shuffled) samples for training, the remaining 10% for testing.
train_ds, test_ds = dataset.split(ratio=0.9, seed=SEED)
len(train_ds), len(test_ds)

(48874, 5431)

In [10]:
# Testing the data retrieval: images should be float32 HWC tensors scaled to
# [-1, 1] and labels one-hot rows containing exactly one 1.
# Each entry: (dataset, pixel to print, extra print args — a blank line after train).
for split_ds, pixel, extra in ((train_ds, (4, 77, 0), ("\n",)),
                               (test_ds, (175, 255, 1), ())):
    img, label = split_ds[5]
    print(img.shape, img[pixel])
    print(label.shape, label.dtype, *extra)
    assert img.dtype == torch.float32
    assert -1.0 <= img.min().item() and img.max().item() <= 1.0
    assert label.sum().item() == 1

torch.Size([256, 256, 3]) tensor(0.6549)
torch.Size([38]) torch.int64 

torch.Size([256, 256, 3]) tensor(0.1059)
torch.Size([38]) torch.int64


In [15]:
# Creating DataLoaders
BATCH_SIZE = 4
# Cap the worker count at the number of CPUs so small machines are not
# oversubscribed (PyTorch warns when num_workers exceeds it).
NUM_WORKERS = min(4, mp.cpu_count())

train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                      num_workers=NUM_WORKERS)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                     num_workers=NUM_WORKERS)

## Model Building

In [16]:
class PlantDiseaseRecogniser(ivy.Module):
    """Plant Disease Recognition model for 256x256 images.

    Five 3x3 conv blocks (Conv2D -> BatchNorm2D -> LeakyReLU), the first four
    followed by 2x2 max-pooling, then a fully connected head
    (Linear -> LeakyReLU -> Dropout -> Linear) producing ``num_classes`` scores.
    """

    def __init__(self, num_classes: int):
        """
        Args:
            num_classes: Number of output classes of the final Linear layer.
        """
        img_size = 256
        kernel_size = [3, 3]
        channels = [3, 8, 32, 64, 128, 128]
        n_pools = len(channels) - 2  # every conv block except the last is pooled

        layers = []
        for idx, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            layers += [
                ivy.Conv2D(c_in, c_out, kernel_size, 1, "same"),
                ivy.BatchNorm2D(c_out),
                ivy.LeakyReLU(0.2),
            ]
            if idx < n_pools:
                layers.append(ivy.MaxPool2D((2, 2), 2, 0))
        self.cnn = ivy.Sequential(*layers)

        # "same" convs keep the spatial size and each stride-2 pool halves it:
        # 256 -> 16 after four pools, so 128 * 16 * 16 = 32768 features.
        flat_features = channels[-1] * (img_size // 2 ** n_pools) ** 2

        self.fc = ivy.Sequential(
            ivy.Linear(flat_features, 1000),
            ivy.LeakyReLU(0.2),
            ivy.Dropout(0.2),
            ivy.Linear(1000, num_classes),
        )

        self.loss_func = ivy.CrossEntropyLoss()
        self.optimizer = ivy.Adam()

        # Sub-modules are assigned before ivy.Module.__init__, following ivy's
        # documented pattern; keep this call last.
        super().__init__()

    def train_model(
        self,
        train_dl: DataLoader,
        test_dl: DataLoader,
        ckpt_path: str,
        learning_rate: float | None = None,
        epochs: int | None = None
    ) -> None:
        """
        Train the model on PlantVillageDataset, resuming a previous run if found.

        After every epoch this writes ``last.pt`` (latest model),
        ``best.pt`` (when the test loss improves) and ``training.json``
        (hyper-parameters, last finished epoch and per-epoch losses) into
        ``ckpt_path``. Re-running with the same ``ckpt_path`` continues from
        ``last_epoch + 1``.

        Args:
            train_dl: DataLoader of (images, one-hot labels) batches for training.
            test_dl: DataLoader of (images, one-hot labels) batches for testing.
            ckpt_path: Directory path where checkpoints will be saved.
            learning_rate: Learning rate; overrides the stored one when resuming.
            epochs: Total number of epochs (not additional ones); overrides the
                stored value when resuming.

        Raises:
            ValueError: if no previous run JSON exists and ``learning_rate`` or
                ``epochs`` is missing.
        """
        os.makedirs(ckpt_path, exist_ok=True)
        best_ckpt = os.path.join(ckpt_path, "best.pt")
        last_ckpt = os.path.join(ckpt_path, "last.pt")
        run_json_path = os.path.join(ckpt_path, "training.json")

        if os.path.exists(run_json_path):
            with open(run_json_path, 'r') as f:
                run = json.load(f)
        else:
            if learning_rate is None or epochs is None:
                raise ValueError("learning_rate and epochs must be "
                                 "set if previous run json does not exist")

            run = {'lr': 0, "epochs": 0, "last_epoch": 0,  # epochs start from 1
                   "train_losses": [], "test_losses": []}

        if learning_rate is not None:
            run["lr"] = learning_rate
        if epochs is not None:
            run["epochs"] = epochs

        # Restore the checkpoint *before* applying the learning rate:
        # load_model overwrites this object's attributes (possibly including
        # the optimizer), which would discard an earlier assignment.
        if os.path.exists(last_ckpt):
            self.load_model(last_ckpt)
        self.optimizer._lr = run["lr"]

        test_losses = run["test_losses"]
        least_loss = min(test_losses) if test_losses else float("inf")
        print(f"Lowest test loss yet: {least_loss}")

        device = ivy.default_device()
        for epoch in range(run["last_epoch"] + 1, run["epochs"] + 1):
            print(f"\nEpoch {epoch}/{run['epochs']}:")

            run["train_losses"].append(self._train_epoch(train_dl, device))
            print("Training Loss:", run["train_losses"][-1])

            run["test_losses"].append(self._evaluate(test_dl, device))
            print("Testing Loss:", run["test_losses"][-1])

            if run["test_losses"][-1] < least_loss:
                self.save(best_ckpt)
                least_loss = run["test_losses"][-1]
                print("New best model saved!")

            self.save(last_ckpt)

            run["last_epoch"] = epoch
            with open(run_json_path, 'w') as f:
                json.dump(run, f, indent=4)

        print("\nTraining complete!")

    def _train_epoch(self, train_dl: DataLoader, device) -> float:
        """Run one optimisation pass over ``train_dl``; return the sample-weighted mean loss."""
        self.train(True)
        calc = AverageCalculator()
        for imgs, labels in tqdm(train_dl, desc="train"):
            imgs = ivy.array(imgs, device=device)
            labels = ivy.array(labels, device=device)
            # Differentiate the loss w.r.t. the model variables only (xs index 2).
            loss, grads = ivy.execute_with_gradients(
                lambda xs: self.calculate_loss(*xs),
                (imgs, labels, self.v),
                xs_grad_idxs=[[2]],
            )
            self.v = self.optimizer.step(self.v, grads)
            calc.update(loss.item(), imgs.shape[0])
        return calc.avg()

    def _evaluate(self, test_dl: DataLoader, device) -> float:
        """Compute the sample-weighted mean loss over ``test_dl`` without updating weights."""
        # NOTE(review): gradients may still be tracked here by the torch
        # backend; consider a no-grad context — TODO confirm with ivy.
        self.train(False)
        calc = AverageCalculator()
        for imgs, labels in tqdm(test_dl, desc="test"):
            imgs = ivy.array(imgs, device=device)
            labels = ivy.array(labels, device=device)
            loss = self.calculate_loss(imgs, labels)
            calc.update(loss.item(), imgs.shape[0])
        return calc.avg()

    def _forward(self, X: ivy.Array, training: bool = False) -> ivy.Array:
        """
        Args:
            X: Batch of images; assumed (N, 256, 256, 3) as produced by the
                dataset — TODO confirm against ivy's default data format.
            training: If True return class probabilities (needed by the loss);
                otherwise return the predicted class index per sample.

        Returns:
            ``(N, num_classes)`` softmax probabilities, or ``(N,)`` class indices.
        """
        logits = self.fc(self.cnn(X).flatten(start_dim=1))
        probs = logits.softmax(axis=1)
        if training:
            return probs
        # argmax must come after softmax: softmax on a 1-D index tensor fails.
        return probs.argmax(axis=1)

    def calculate_loss(
        self,
        inputs: ivy.Array,
        targets: ivy.Array,
        variables: ivy.Container = None,
    ) -> ivy.Array:
        """Mean cross-entropy between one-hot ``targets`` and predicted probabilities.

        Args:
            inputs: Batch of images.
            targets: One-hot labels matching ``inputs``.
            variables: Optional variables to run the model with instead of ``self.v``.
        """
        # training=True selects the probability output of _forward.
        preds = self(inputs, training=True, v=variables)
        losses = self.loss_func(targets, preds)
        return losses.mean()

    def load_model(self, model_path: str):
        """Replace this instance's state with the model saved at ``model_path``.

        NOTE(review): ``ivy.Module.load`` presumably unpickles the file — only
        load checkpoints you trust.
        """
        self.__dict__.update(self.load(model_path).__dict__)

## Model Training

In [17]:
# Size the output layer from the data instead of hard-coding the class count.
model = PlantDiseaseRecogniser(num_classes=len(dataset.label_names))

In [None]:
# Resumes from ./model/training.json and ./model/last.pt when they exist.
model.train_model(
    train_dl,
    test_dl,
    ckpt_path="./model",
    learning_rate=1e-3,
    epochs=10,
)