In [None]:
# -*- coding: utf-8 -*-
"""
Plant Disease Detection with Balanced ViT + Augmentation
"""

# ===========================================================================
# Step 1: Install Dependencies
# ===========================================================================
!pip install transformers torchvision opencv-python-headless imbalanced-learn --quiet

# ===========================================================================
# Step 2: Dataset Preparation & Balancing
# ===========================================================================
import os
import cv2
import numpy as np
from PIL import Image
from torchvision import transforms, datasets
from imblearn.over_sampling import RandomOverSampler
import torch
from torch.utils.data import WeightedRandomSampler, DataLoader

# Dataset root; TRAIN_DIR / VALID_DIR are its two split sub-directories,
# each consumed below as an ImageFolder root (one sub-folder per class).
BASE_PATH = "/kaggle/input/new-plant-diseases-dataset/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)"
TRAIN_DIR, VALID_DIR = (os.path.join(BASE_PATH, split) for split in ("train", "valid"))

# Calculate class weights
def get_class_weights(dataset):
    """Return per-class inverse-frequency weights for *dataset*.

    Args:
        dataset: Object exposing ``targets`` (a sequence of integer class
            indices) and, optionally, ``classes`` (as ``ImageFolder`` does).

    Returns:
        1-D float tensor ``w`` with ``w[c] == 1 / count(c)``. When
        ``dataset.classes`` is available the tensor has exactly
        ``len(dataset.classes)`` entries, so classes with no samples
        (including trailing ones, which ``bincount`` alone would drop) still
        get a slot. Empty classes get weight 1.0 instead of ``inf``.
    """
    # Explicit long dtype: an empty Python list would otherwise become a
    # float tensor, which torch.bincount rejects.
    targets = torch.as_tensor(dataset.targets, dtype=torch.long)
    classes = getattr(dataset, "classes", None)
    minlength = len(classes) if classes is not None else 0
    class_counts = torch.bincount(targets, minlength=minlength)
    # Clamp so an empty class yields a finite weight rather than 1/0 = inf.
    class_weights = 1. / class_counts.clamp(min=1).float()
    return class_weights

# ===========================================================================
# Step 3: Enhanced Data Augmentation
# ===========================================================================
class GaborFilterTransform:
    """Replace an image with its min-max-normalised Gabor filter response.

    The input is converted to grayscale and convolved with a single real
    Gabor kernel (phase offset 0). The response is stretched to the full
    0-255 range and returned as a 3-channel PIL image with R == G == B, so
    it can flow through the rest of a torchvision pipeline.

    NOTE(review): the output discards all colour information, so colour
    cues in the original image are not available to downstream consumers.

    Args:
        ksize: Side length of the square Gabor kernel, in pixels.
        sigma: Standard deviation of the Gaussian envelope.
        theta: Orientation of the kernel, in radians.
        lambd: Wavelength of the sinusoidal carrier, in pixels.
        gamma: Spatial aspect ratio of the envelope.
    """

    def __init__(self, ksize=31, sigma=4.0, theta=np.pi/4, lambd=10.0, gamma=0.5):
        self.ksize = ksize
        self.sigma = sigma
        self.theta = theta
        self.lambd = lambd
        self.gamma = gamma

    def __call__(self, img):
        """Filter *img* and return the response as an RGB PIL image.

        *img* may be a PIL image of any mode; a NumPy array is passed through
        unchanged and presumably must already be HxWx3 RGB — TODO confirm.
        """
        # cvtColor(RGB2GRAY) requires exactly three channels; normalise PIL
        # modes such as 'L', 'P' or 'RGBA' instead of crashing on them.
        if isinstance(img, Image.Image) and img.mode != "RGB":
            img = img.convert("RGB")
        img_np = np.array(img)
        gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
        kernel = cv2.getGaborKernel(
            (self.ksize, self.ksize),
            self.sigma,
            self.theta,
            self.lambd,
            self.gamma,
            0, cv2.CV_32F
        )
        # Filter in float32 so negative/large responses are not clipped
        # before the min-max stretch back to uint8.
        filtered = cv2.filter2D(gray, cv2.CV_32F, kernel)
        filtered = cv2.normalize(filtered, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
        return Image.fromarray(cv2.cvtColor(filtered, cv2.COLOR_GRAY2RGB))

    def __repr__(self):
        # Shown when a transforms.Compose containing this transform is printed.
        return (
            f"{type(self).__name__}(ksize={self.ksize}, sigma={self.sigma}, "
            f"theta={self.theta}, lambd={self.lambd}, gamma={self.gamma})"
        )

# Augmentation transforms
#
# The google/vit-*-in21k checkpoints were pre-trained on pixels rescaled to
# [-1, 1] (per-channel mean = std = 0.5), not ImageNet statistics; match
# that so the pretrained patch embedding sees the input range it expects.
VIT_MEAN = [0.5, 0.5, 0.5]
VIT_STD = [0.5, 0.5, 0.5]

train_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    # Gabor runs first, on the resized but otherwise untouched image, exactly
    # as in val_transform. This has three effects:
    # (1) the filter sees the same input distribution in train and val;
    # (2) it no longer responds to the black corners RandomRotation pads in;
    # (3) its min-max stretch no longer (largely) cancels the jitter below.
    GaborFilterTransform(theta=np.pi/4),
    # Photometric jitter on the Gabor map. Saturation has no effect on the
    # R == G == B output; brightness and contrast do.
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    # NOTE: a no-op while the Gabor step is active (its output is already
    # gray); kept so the pipeline still behaves if Gabor is removed.
    transforms.RandomGrayscale(p=0.1),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(30),
    transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2.0)),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(VIT_MEAN, VIT_STD),
])

# Deterministic counterpart of train_transform. It uses the same 256 px
# resize before Gabor, then a centre crop to the 224 px that RandomCrop
# yields in training. The Gabor kernel size and wavelength are in pixels,
# so features must be computed at the same scale in both pipelines.
val_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    GaborFilterTransform(theta=np.pi/4),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(VIT_MEAN, VIT_STD),
])

# ===========================================================================
# Step 4: Load and Balance Dataset
# ===========================================================================
# Build both ImageFolder splits (train first), each with its own pipeline.
train_dataset, val_dataset = (
    datasets.ImageFolder(root, transform=tfm)
    for root, tfm in ((TRAIN_DIR, train_transform), (VALID_DIR, val_transform))
)

# Give each training sample the inverse frequency of its class. Every class
# then carries the same total weight, and the sampler draws classes with
# equal expected frequency (with replacement, one epoch = len(dataset) draws).
class_weights = get_class_weights(train_dataset)
samples_weights = class_weights[torch.as_tensor(train_dataset.targets)]
sampler = WeightedRandomSampler(
    weights=samples_weights,
    num_samples=samples_weights.numel(),
    replacement=True,
)

# ===========================================================================
# Step 5: Model Setup with Hyperparameters
# ===========================================================================
from transformers import ViTForImageClassification

# Hyperparameters
BATCH_SIZE = 32
LR = 3e-5
EPOCHS = 10
PATIENCE = 3  # epochs without val improvement before early stopping

# ImageFolder derives label indices from each split's sub-folder names
# independently. If the splits disagree, validation labels would silently
# refer to different classes than the ones the model was trained on.
if val_dataset.class_to_idx != train_dataset.class_to_idx:
    raise ValueError("Train and validation splits have different class folders")

# Derive the label space from the data rather than hard-coding it, so the
# classifier head always matches the dataset's class folders.
NUM_CLASSES = len(train_dataset.classes)

# Pinned host memory speeds up host-to-GPU copies; pointless on CPU.
_PIN_MEMORY = torch.cuda.is_available()

# Dataloaders (the training loader draws via the class-balancing sampler,
# so it must not also shuffle).
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=4,
    pin_memory=_PIN_MEMORY,
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4,
    pin_memory=_PIN_MEMORY,
)

# Initialize ViT with a fresh classification head sized to the dataset.
# id2label/label2id are stored in the model config for readable predictions.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=NUM_CLASSES,
    id2label={idx: name for name, idx in train_dataset.class_to_idx.items()},
    label2id=dict(train_dataset.class_to_idx),
    ignore_mismatched_sizes=True,
)

# ===========================================================================
# Step 6: Training with Optimization
# ===========================================================================
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import balanced_accuracy_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
# train_loader draws through a WeightedRandomSampler with inverse-frequency
# weights, so batches are already class-balanced. Weighting the loss by
# inverse frequency as well would correct the imbalance twice and let rare
# classes dominate the gradient, so the loss is left unweighted.
criterion = torch.nn.CrossEntropyLoss()
# Balanced accuracy is maximised, hence mode 'max'.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=1)


def _train_one_epoch(model, loader, optimizer, criterion, device, desc):
    """Run one optimisation pass over *loader* and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(loader, desc=desc):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images).logits, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / max(len(loader), 1)


def _evaluate(model, loader, device):
    """Return the balanced accuracy of *model*'s argmax predictions on *loader*."""
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for images, labels in loader:
            logits = model(images.to(device)).logits
            preds.extend(logits.argmax(1).cpu().numpy())
            # Labels are only needed on the host for sklearn; skip the GPU round-trip.
            targets.extend(labels.numpy())
    return balanced_accuracy_score(targets, preds)


best_acc = 0
no_improve = 0

for epoch in range(EPOCHS):
    train_loss = _train_one_epoch(
        model, train_loader, optimizer, criterion, device,
        desc=f"Epoch {epoch+1}/{EPOCHS}",
    )
    balanced_acc = _evaluate(model, val_loader, device)
    print(f"Epoch {epoch+1} | Train Loss: {train_loss:.4f} | Val Acc: {balanced_acc:.2%}")

    # Early stopping & checkpoint of the best-so-far weights
    if balanced_acc > best_acc:
        best_acc = balanced_acc
        no_improve = 0
        torch.save(model.state_dict(), 'best_model.pth')
    else:
        no_improve += 1
        if no_improve >= PATIENCE:
            print(f"Early stopping at epoch {epoch+1}")
            break

    scheduler.step(balanced_acc)

# ===========================================================================
# Step 7: Save Final Model
# ===========================================================================
# The in-memory weights belong to the LAST epoch run, which after early
# stopping is PATIENCE epochs past the best one. Restore the best checkpoint
# (written by the training loop whenever val balanced accuracy improved).
if os.path.exists('best_model.pth'):
    model.load_state_dict(torch.load('best_model.pth', map_location=device))

torch.save({
    'model_state_dict': model.state_dict(),
    'class_to_idx': train_dataset.class_to_idx,
    # Inference needs the deterministic pipeline; train_transform applies
    # random flips/rotations/crops/jitter.
    # NOTE: this pickles a torchvision Compose (including the custom
    # GaborFilterTransform). Load the checkpoint with
    # torch.load(..., weights_only=False), with that class importable.
    'transform': val_transform,
    'best_balanced_acc': best_acc,
}, 'plant_disease_vit.pth')
