In [None]:
# efficientnet_phase2.py
import os
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import random
import torch.nn.functional as F
import torchvision.models as models

In [None]:
# Seed every RNG this notebook draws from (PyTorch, NumPy, Python `random`) so runs repeat.
SEED = 42
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
    # Prefer deterministic cuDNN kernels over autotuned (possibly non-deterministic) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
# --- CONFIG ---
# Dataset root. Set the CP_DATASET_DIR environment variable to override the default local path.
BASE_PATH = os.environ.get(
    "CP_DATASET_DIR", r"C:\Users\ADITYA DAS\Desktop\Machine Learning\CP_DATASET"
)
CLASSES = ["BLIGHT", "BLAST", "BROWNSPOT", "HEALTHY"]  # list index == integer class label
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 60  # upper bound; the training loop may stop earlier via early stopping
LEARNING_RATE = 1e-5  # small LR because the whole network is fine-tuned

In [None]:
# Check for GPU: use CUDA when present, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

In [None]:
# --- Custom Dataset ---
class PlantDiseaseDataset(Dataset):
    """Image-classification dataset backed by a list of image file paths.

    Args:
        filepaths: sequence of image file paths.
        labels: sequence of integer class indices aligned with ``filepaths``.
        transform: optional callable applied to the PIL image (e.g. a torchvision ``Compose``).
        augment: if True, apply ``grid_mask`` to the image before ``transform``.

    ``__getitem__`` returns ``(image, label)`` where ``label`` is a ``torch.long`` scalar tensor.
    """

    def __init__(self, filepaths, labels, transform=None, augment=False):
        self.filepaths = filepaths
        self.labels = labels
        self.transform = transform
        self.augment = augment

    def __len__(self):
        return len(self.filepaths)

    def __getitem__(self, idx):
        img_path = self.filepaths[idx]
        # The context manager closes the file handle deterministically; convert() loads the pixels.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        label = self.labels[idx]

        if self.augment:
            # GridMask operates on numpy arrays. Cast back to uint8 before rebuilding the PIL
            # image: Image.fromarray cannot build an RGB image from a float array.
            masked = grid_mask(np.asarray(image))
            image = Image.fromarray(np.asarray(masked).astype(np.uint8))

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(label, dtype=torch.long)

In [None]:
# --- Data Augmentation Functions ---
def color_jitter_transform():
    """Build the photometric ColorJitter used as a training-time augmentation."""
    jitter_kwargs = dict(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05)
    return transforms.ColorJitter(**jitter_kwargs)

def grid_mask(img, d_min=50, d_max=100, ratio=0.5):
    """Zero out a regular grid of square patches (GridMask-style augmentation).

    Args:
        img: numpy array of shape (H, W) or (H, W, C).
        d_min, d_max: inclusive range for the grid period ``d`` in pixels (drawn via ``random.randint``).
        ratio: side length of each masked square as a fraction of ``d``.

    Returns:
        A new array with the same shape and dtype as ``img``; masked pixels are 0.
    """
    h, w = img.shape[:2]
    d = random.randint(d_min, d_max)
    l = int(d * ratio)

    # Build the mask in img's dtype so the product keeps that dtype. A float32 mask would
    # promote uint8 images to float32, which Image.fromarray cannot turn back into RGB.
    mask = np.ones((h, w), dtype=img.dtype)
    for i in range(0, h, d):
        for j in range(0, w, d):
            mask[i:min(i + l, h), j:min(j + l, w)] = 0

    if img.ndim == 3:
        mask = mask[..., np.newaxis]  # broadcast the 2-D mask across channels
    return img * mask

def cutmix(images, labels, alpha=1.0, num_classes=None):
    """Apply per-sample CutMix to a batch.

    Each sample i receives a rectangle pasted from a randomly permuted partner sample.
    Its target becomes a mix of both one-hot labels, weighted by the pasted pixel area.

    Args:
        images: tensor indexed as (batch, channels, H, W); shape[2] and shape[3] are read as H and W.
        labels: integer class-index tensor of shape (batch,).
        alpha: Beta(alpha, alpha) parameter for each sample's mixing ratio.
        num_classes: number of classes for one-hot encoding; defaults to ``len(CLASSES)``.

    Returns:
        ``(mixed_images, mixed_labels)``. ``mixed_images`` has the same shape as ``images``.
        ``mixed_labels`` is a float tensor of shape (batch, num_classes) whose rows sum to 1.
    """
    if num_classes is None:
        num_classes = len(CLASSES)

    batch_size = images.shape[0]
    img_h, img_w = images.shape[2], images.shape[3]

    lam = np.random.beta(alpha, alpha, size=batch_size)
    rand_idx = torch.randperm(batch_size)

    mixed_images = images.clone()
    # Soft targets need a float (batch, num_classes) buffer. A clone of `labels` would be a
    # long (batch,) tensor and cannot hold a one-hot row.
    onehot = F.one_hot(labels, num_classes=num_classes).float()
    mixed_labels = torch.empty_like(onehot)

    for i in range(batch_size):
        j = int(rand_idx[i])

        # Box whose area is about (1 - lam) of the image, centred at a uniform random point.
        cut_rat = np.sqrt(1.0 - lam[i])
        cut_w = img_w * cut_rat
        cut_h = img_h * cut_rat
        cx = np.random.uniform(0, img_w)
        cy = np.random.uniform(0, img_h)

        x1 = int(np.clip(int(cx - cut_w / 2), 0, img_w))
        y1 = int(np.clip(int(cy - cut_h / 2), 0, img_h))
        x2 = int(np.clip(int(cx + cut_w / 2), 0, img_w))
        y2 = int(np.clip(int(cy + cut_h / 2), 0, img_h))

        mixed_images[i, :, y1:y2, x1:x2] = images[j, :, y1:y2, x1:x2]

        # Re-derive lambda from the clipped box so the label weights match the pasted area.
        lam_adjusted = 1.0 - ((x2 - x1) * (y2 - y1)) / (img_w * img_h)
        mixed_labels[i] = lam_adjusted * onehot[i] + (1.0 - lam_adjusted) * onehot[j]

    return mixed_images, mixed_labels

In [None]:
# --- Load filepaths & labels ---
IMAGE_EXTENSIONS = ("*.jpg", "*.jpeg", "*.png")

all_filepaths, all_labels = [], []
for idx, class_name in enumerate(CLASSES):
    aug_path = os.path.join(BASE_PATH, class_name, "augmented")
    # Sort the paths: glob order depends on the filesystem, and the seeded train_test_split
    # below is only reproducible if its input order is fixed.
    files = sorted(
        path
        for pattern in IMAGE_EXTENSIONS
        for path in glob.glob(os.path.join(aug_path, pattern))
    )
    if not files:
        print(f"⚠️ No images found for class '{class_name}' in {aug_path}")
    all_filepaths.extend(files)
    all_labels.extend([idx] * len(files))

print(f"✅ Total images found: {len(all_filepaths)}")

In [None]:
# --- Split data ---
# Stratified 80/20 split, so train and validation keep the same class proportions.
(train_filepaths, val_filepaths,
 train_labels, val_labels) = train_test_split(
    all_filepaths,
    all_labels,
    stratify=all_labels,
    test_size=0.2,
    random_state=SEED,
)

print(f"✅ Train samples: {len(train_filepaths)} | Val samples: {len(val_filepaths)}")

In [None]:
# --- Transforms (including EfficientNet specific preprocessing) ---
# ImageNet normalisation statistics, shared by the train and val pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize to IMG_SIZE so images of differing sizes can be stacked into a batch.
train_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    color_jitter_transform(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

val_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [None]:
# --- Datasets and DataLoaders ---
# On Windows, DataLoader workers are spawned processes that re-import __main__. Classes and
# functions defined in notebook cells (PlantDiseaseDataset, grid_mask) are commonly not
# importable there, so load in-process on Windows.
NUM_WORKERS = 0 if os.name == "nt" else 4
PIN_MEMORY = torch.cuda.is_available()  # pinned host memory only helps CUDA transfers

train_dataset = PlantDiseaseDataset(train_filepaths, train_labels, transform=train_transform, augment=True)
val_dataset = PlantDiseaseDataset(val_filepaths, val_labels, transform=val_transform, augment=False)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

In [None]:
# --- EfficientNetB0 Model for Phase 2 ---
class EfficientNetB0_Phase2(nn.Module):
    """EfficientNet-B0 with a custom MLP classification head, for Phase 2 fine-tuning.

    The head's layer layout must stay identical to Phase 1's so the Phase 1 checkpoint
    loads with matching state-dict keys. No parameters are frozen here.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Start from ImageNet-pretrained weights.
        self.backbone = models.efficientnet_b0(weights='IMAGENET1K_V1')

        in_features = self.backbone.classifier[1].in_features
        head_layers = [
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        ]
        self.classifier_head = nn.Sequential(*head_layers)
        # Install the same head object as the backbone's classifier, so self.backbone(x)
        # ends in this head. It is therefore registered under both `classifier_head` and
        # `backbone.classifier` in the state dict.
        self.backbone.classifier = self.classifier_head

    def forward(self, x):
        return self.backbone(x)

model = EfficientNetB0_Phase2(len(CLASSES)).to(DEVICE)

In [None]:
# Load the state dict from Phase 1
LOAD_PATH = r"C:\Users\ADITYA DAS\Desktop\Machine Learning\CP_MODEL\EfficientNetB0_Phase1_CutMix_GridMask.pth"
if not os.path.exists(LOAD_PATH):
    # Raise instead of calling exit(): in a Jupyter kernel, exit() terminates the kernel and
    # discards its state, whereas an exception just stops "Run All" at this cell.
    raise FileNotFoundError(
        f"❌ Error: Phase 1 model not found at {LOAD_PATH}. Please run Phase 1 training first."
    )
model.load_state_dict(torch.load(LOAD_PATH, map_location=DEVICE))
print(f"✅ Loaded model from: {LOAD_PATH}")

# Now, unfreeze all layers in the entire model
for param in model.parameters():
    param.requires_grad = True

# --- Loss Function and Optimizer ---
criterion = nn.CrossEntropyLoss(label_smoothing=0.05) # Lower label smoothing for fine-tuning
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE) # Lower learning rate

In [None]:
# --- Learning Rate Logger ---
class LearningRateLogger:
    """Print the learning rate of every optimizer param group at the end of an epoch."""

    def __init__(self, optimizer):
        self.optimizer = optimizer

    def on_epoch_end(self, epoch):
        """Log the current LR(s). ``epoch`` is 0-based and printed 1-based."""
        for lr in (group['lr'] for group in self.optimizer.param_groups):
            print(f"📉 Learning rate at epoch {epoch+1}: {lr:.6f}")

lr_logger = LearningRateLogger(optimizer)

# --- Compute class weights ---
# The training labels are already in memory from the split. Iterating train_loader here
# would decode and augment every training image just to read the labels back.
train_labels_for_weights = list(train_labels)

class_weights_array = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.arange(len(CLASSES)),
    y=np.asarray(train_labels_for_weights),
)
class_weights_tensor = torch.tensor(class_weights_array, dtype=torch.float).to(DEVICE)
# NOTE(review): class_weights_tensor is not passed to any loss in this notebook (criterion and
# the CutMix KL loss are both unweighted) — confirm whether weighting was intended.
print("✅ Computed class weights:", class_weights_array)

In [None]:
# --- Training Loop ---
SAVE_PATH = r"C:\Users\ADITYA DAS\Desktop\Machine Learning\CP_MODEL\EfficientNetB0_Phase2_CutMix_GridMask.pth"
PATIENCE = 4  # epochs without validation-accuracy improvement before stopping early

best_val_accuracy = 0.0
patience_counter = 0
saved_best = False  # becomes True once this run has written a checkpoint to SAVE_PATH

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0

    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")
    for inputs, labels in train_loader_tqdm:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        # Apply CutMix
        inputs, mixed_labels = cutmix(inputs, labels)

        optimizer.zero_grad()
        outputs = model(inputs)

        if mixed_labels.dim() > 1 and mixed_labels.shape[1] > 1:
            # Soft targets: KL divergence between the mixed label distribution and the model's
            # softmax. This path uses neither criterion's label smoothing nor class weights.
            log_softmax_outputs = F.log_softmax(outputs, dim=1)
            loss = F.kl_div(log_softmax_outputs, mixed_labels, reduction='batchmean')
        else:
            loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        # Measured against the un-mixed labels, so this is only an approximation under CutMix.
        predicted = outputs.argmax(dim=1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

        train_loader_tqdm.set_postfix(loss=running_loss/total_train, acc=correct_train/total_train)

    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = correct_train / total_train
    print(f"Epoch {epoch+1} Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

    # Validation
    model.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0
    y_true_val, y_pred_val = [], []

    with torch.no_grad():
        val_loader_tqdm = tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]")
        for inputs, labels in val_loader_tqdm:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            val_loss += loss.item() * inputs.size(0)
            predicted = outputs.argmax(dim=1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

            y_true_val.extend(labels.cpu().numpy())
            y_pred_val.extend(predicted.cpu().numpy())

            val_loader_tqdm.set_postfix(loss=val_loss/total_val, acc=correct_val/total_val)

    epoch_val_loss = val_loss / len(val_dataset)
    epoch_val_acc = correct_val / total_val
    print(f"Epoch {epoch+1} Val Loss: {epoch_val_loss:.4f} Acc: {epoch_val_acc:.4f}")

    lr_logger.on_epoch_end(epoch)

    # Early Stopping (checkpoint on every improvement in validation accuracy)
    if epoch_val_acc > best_val_accuracy:
        best_val_accuracy = epoch_val_acc
        patience_counter = 0
        torch.save(model.state_dict(), SAVE_PATH)
        saved_best = True
        print(f"✅ Model saved at: {SAVE_PATH} (Best validation accuracy: {best_val_accuracy:.4f})")
    else:
        patience_counter += 1
        print(f"Patience: {patience_counter}/{PATIENCE}")
        if patience_counter >= PATIENCE:
            print("Early stopping triggered.")
            break

# Restore the best checkpoint. Otherwise `model` holds the last epoch's weights (up to
# PATIENCE epochs past the best), and any later evaluation would not reflect the saved model.
if saved_best:
    model.load_state_dict(torch.load(SAVE_PATH, map_location=DEVICE))
    print(f"✅ Restored best model from: {SAVE_PATH} (val acc {best_val_accuracy:.4f})")

In [None]:
# --- Evaluation ---
print("\n📊 Final Evaluation on Validation Set:")
model.eval()
y_true_final, y_pred_final = [], []
with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model(inputs.to(DEVICE))
        y_pred_final.extend(outputs.argmax(dim=1).cpu().numpy())
        y_true_final.extend(labels.cpu().numpy())

# Pin the label set so the report and matrix always have one row per entry of CLASSES. If a
# class never appeared in y_true or y_pred, target_names would otherwise mismatch and raise.
class_indices = list(range(len(CLASSES)))
print(classification_report(y_true_final, y_pred_final, labels=class_indices, target_names=CLASSES))

cm = confusion_matrix(y_true_final, y_pred_final, labels=class_indices)
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=CLASSES, yticklabels=CLASSES, ax=ax)
ax.set_title("Confusion Matrix")
ax.set_xlabel("Predicted Label")
ax.set_ylabel("True Label")
fig.tight_layout()
plt.show()