<a href="https://colab.research.google.com/github/HRashidLiaquat/lessons-learned/blob/Transformer-based-crop-disease-detection-system/Agri.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Project Workflow**


1.   Import Important Libraries
2.   Get Data Ready in Kaggle
3.   Preparing data
4.   Loading Training Images
5.   Data Loaders
6.   Build a Training Model (Transfer Learning)
7.   Model Training
8.   Model Testing (Training Loop)
9.   Model Evaluation
10.  Testing with New Data Point
11.  Save Model








**Import Important libraries**

In [None]:
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from PIL import Image
import zipfile
import random
import matplotlib.pyplot as plt



**Mount GDrive**

In [None]:
# Colab-only: mount Drive at /content/drive so files stored there can be read
# through ordinary filesystem paths.
from google.colab import drive
drive.mount('/content/drive')

**Unzip the data**

In [None]:
# Where the zipped dataset sits on the mounted Drive, and where to unpack it.
zip_path = Path("/content/drive/MyDrive/Colab Notebooks/agriarchive.zip")
extract_path = Path("/content/data")

# Quick visibility check; a missing archive still fails below when opened.
print("ZIP exists:", zip_path.exists())

# Create the target directory (and any parents); fine if it already exists.
extract_path.mkdir(parents=True, exist_ok=True)

# Unpack every member of the archive into the target directory.
with zipfile.ZipFile(zip_path, mode='r') as archive:
    archive.extractall(path=extract_path)

**Get the dataset ready**

In [None]:
# Root folder of the extracted dataset.
datset_path = Path('/content/data/PlantVillage')

# Report whether the extraction produced the expected root folder.
print("Dataset found!" if datset_path.exists() else "Dataset not found!")

In [None]:
# List the immediate children of the dataset root (skipped silently when the
# root folder is missing).
if datset_path.exists():
  datsetfolderlist = list(datset_path.iterdir())
  print("See all folder in my dataset main folder")
  # One entry per line; each entry is a full path.
  for allfolder in datsetfolderlist:
      print(allfolder)

**Check the total number of plant classes**

In [None]:

# Training images live in the "train" sub-folder of the dataset root.
# Join with a *relative* component: joining a Path with an absolute path
# discards the left-hand side, so the root would be ignored entirely.
traning_path = datset_path / 'train'
print(traning_path)

# Each sub-directory of the training folder is one class; sort the names so
# the order is deterministic.
train_classes = sorted(item.name for item in traning_path.iterdir() if item.is_dir())
num_classes = len(train_classes)
print(f"Total plant classes: {num_classes}")

**Display Raw data**

In [None]:
# Load the training images untouched (tensor conversion only) for a first look.
rawdataset = datasets.ImageFolder(traning_path, transform=transforms.ToTensor())

# 2 x 4 grid of randomly chosen samples, each titled with its class name.
plt.figure(figsize=(20, 8))
for slot in range(1, 9):
    sample_idx = random.randrange(len(rawdataset))
    image, class_idx = rawdataset[sample_idx]
    plt.subplot(2, 4, slot)
    # Tensors are channel-first; imshow wants channel-last.
    plt.imshow(image.permute(1, 2, 0))
    plt.title(rawdataset.classes[class_idx])
    plt.axis('off')

**Data Preprocessing (Normalization and augmentation)**

In [None]:
# Side length in pixels; the transforms below resize images to IMG_SIZE x IMG_SIZE.
IMG_SIZE = 224
# Presumably the mini-batch size for the data loaders; it is not used in this
# part of the notebook — TODO confirm against the DataLoader cells.
BATCH_SIZE = 16

**Data Augmentation**

In [None]:
# Augmentation-only pipeline: random flip, rotation of up to 20 degrees, and
# mild brightness/contrast/saturation jitter.
aug_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
])

In [None]:
# Show the augmentation pipeline as the cell output.
aug_transforms

**Without augmentation**

In [None]:
# Deterministic evaluation pipeline: resize, convert to tensor, then normalize
# with the standard ImageNet channel statistics. No augmentation.
test_transform = transforms.Compose([
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

**Data Preprocessing**



In [None]:
# Bare preprocessing: resize to a square of IMG_SIZE and convert to a tensor.
prepro_transforms = transforms.Compose([
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

In [None]:
# Show the preprocessing pipeline as the cell output.
prepro_transforms

**Data Normalization**

In [None]:
# Per-channel normalization using the standard ImageNet mean and std.
normalization_transform = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

In [None]:
# Show the normalization transform as the cell output.
normalization_transform

In [None]:
# Full training pipeline: resize, random augmentation (flip, rotation, colour
# jitter), tensor conversion, then ImageNet mean/std normalization.
train_trasform = transforms.Compose([
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

**Show Data after Preprocessing**

In [None]:
# Load the training images through the actual training pipeline (resize,
# augmentation, ToTensor, ImageNet normalization). With a bare ToTensor() the
# images were neither resized nor normalized, so applying denorm() to them
# before plotting shifted and washed out the colours.
processed_datase = datasets.ImageFolder(traning_path, transform=train_trasform)
def denorm(x):
    """Undo the ImageNet mean/std normalization so an image can be displayed.

    Args:
        x: Image tensor with 3 channels on the third-from-last axis, e.g.
            ``(3, H, W)`` or ``(N, 3, H, W)``, normalized with the ImageNet
            statistics used by the transforms in this notebook.

    Returns:
        A tensor of the same shape with values clamped to ``[0, 1]``.
    """
    # Create the statistics on x's device so GPU tensors work as well as CPU ones.
    mean = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(3, 1, 1)
    # Inverse of (x - mean) / std; clamp because augmentation and rounding can
    # push values slightly outside the displayable range.
    return torch.clamp(x * std + mean, 0, 1)


In [None]:
# 2 x 4 grid of random samples from processed_datase, de-normalized for display.
plt.figure(figsize=(20, 8))
for i in range(8):
    idx = random.randrange(len(processed_datase))
    img, label = processed_datase[idx]

    # Map the values back into [0, 1] so imshow renders sensible colours.
    img = denorm(img)

    plt.subplot(2, 4, i + 1)
    # Channel-first tensor -> channel-last layout expected by imshow.
    plt.imshow(img.permute(1, 2, 0))
    plt.title(processed_datase.classes[label])
    plt.axis('off')

plt.suptitle('Show After Preprocessing (224x224)')
plt.tight_layout()
plt.show()

**Load Dataset**

In [None]:
# ImageFolder treats every immediate sub-directory of its root as a class.
# The dataset root holds the split folders (train/, val/), so pointing
# ImageFolder at the root would report the split names as "classes". Point it
# at the training split, whose sub-directories are the real plant classes.
dataset = datasets.ImageFolder(datset_path / 'train', transform=train_trasform)
print("Total images:", len(dataset))
print("Classes:", dataset.classes)

In [None]:
# Drive-hosted locations of the two splits.
# NOTE(review): a later cell reassigns both names to paths under
# /content/data, so these values look superseded — confirm before relying on them.
train_dir = "/content/drive/MyDrive/PlantVillage/train"
val_path   = "/content/drive/MyDrive/PlantVillage/val"

In [None]:
# Show the two split paths as the cell output (a tuple).
train_dir ,val_path

In [None]:
# Split folders inside the locally extracted dataset.
train_dir = '/content/data/PlantVillage/train'
val_path = '/content/data/PlantVillage/val'

# Training images get the augmenting pipeline; validation images get the
# deterministic one.
train_dataset = datasets.ImageFolder(root=train_dir, transform=train_trasform)
val_dataset = datasets.ImageFolder(root=val_path, transform=test_transform)

# Summary of what was loaded.
print("Total train images:", len(train_dataset))
print("Train classes:", train_dataset.classes)
print("Total val images:", len(val_dataset))