In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the read-only Kaggle input tree and echo the full path of every file found.
for folder, _subdirs, file_names in os.walk('/kaggle/input'):
    full_paths = (os.path.join(folder, file_name) for file_name in file_names)
    for full_path in full_paths:
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# ############################################################################
#
# Context-Aware Adaptive Quantization (CAAQ) for Multi-Crop Disease Diagnosis
#
# Kaggle GPU Implementation with Quantization Aware Training (QAT) - Final
#
# Author: Gemini
# Date: August 16, 2025
#
# ############################################################################


# ############################################################################
# STEP 0: ENVIRONMENT SETUP
# ############################################################################
# This notebook implements the complete pipeline for the CAAQ research framework
# using a more robust Quantization Aware Training (QAT) approach.
# ############################################################################

import os
import zipfile
import time
import copy
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization
import torch.backends.quantized

# Select the backend that executes quantized kernels for the rest of the notebook.
# The same backend name ('qnnpack') is passed to get_default_qat_qconfig() when the
# QAT model is prepared, so the two settings are meant to stay in sync.
# NOTE(review): the original author chose 'qnnpack' over 'fbgemm' for compatibility;
# whether it is available depends on the PyTorch build — if this assignment raises,
# check torch.backends.quantized.supported_engines.
torch.backends.quantized.engine = 'qnnpack'

from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, ConcatDataset, Subset
from tqdm import tqdm
import matplotlib.pyplot as plt

# Report the runtime and pick the compute device: first GPU when CUDA is usable,
# otherwise the CPU.
cuda_ready = torch.cuda.is_available()
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {cuda_ready}")
device = torch.device("cuda:0") if cuda_ready else torch.device("cpu")
print(f"Using device: {device}")

# For reproducibility: seed both the PyTorch and the NumPy global generators.
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(42)


# ############################################################################
# STEP 1: DATASET DOWNLOAD AND PREPARATION (KAGGLE)
# ############################################################################
KAGGLE_INPUT_DIR = "/kaggle/input"

# Root folders of the two Kaggle datasets. The PlantVillage archive nests its
# content under a repeated "New Plant Diseases Dataset(Augmented)" directory.
plantdoc_base_path = os.path.join(KAGGLE_INPUT_DIR, 'plantdoc-dataset')
plantvillage_base_path = os.path.join(
    KAGGLE_INPUT_DIR,
    'new-plant-diseases-dataset/New Plant Diseases Dataset(Augmented)/New Plant Diseases Dataset(Augmented)',
)

# PlantDoc ships 'train' and 'test' folders; the 'test' folder is used for validation.
plantdoc_train_path, plantdoc_valid_path = (
    os.path.join(plantdoc_base_path, split_name) for split_name in ('train', 'test')
)
plantvillage_path = plantvillage_base_path

print("Using dataset paths:")
for path_label, path_value in (
    ("PlantDoc Train", plantdoc_train_path),
    ("PlantDoc Valid/Test", plantdoc_valid_path),
    ("PlantVillage Train", plantvillage_path),
):
    print(f"{path_label}: {path_value}")


# --- 1.2: Define Data Augmentation and Transforms ---
# Per-channel mean / std applied after ToTensor() in both pipelines (the values
# conventionally used with ImageNet-pretrained torchvision models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random geometric and photometric augmentation, with an
# occasional (p=0.3) Gaussian blur, followed by tensor conversion + normalization.
_train_steps = [
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    transforms.RandomApply(
        [transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5))],
        p=0.3,
    ),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]

# Validation pipeline: deterministic resize + center crop, same normalization.
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]

data_transforms = dict(
    train=transforms.Compose(_train_steps),
    val=transforms.Compose(_val_steps),
)

# --- 1.3: Load and Combine Datasets ---
def _load_with_global_labels(root, transform, global_classes, global_class_to_idx):
    """Load an ImageFolder and re-label its samples against a shared class index.

    Args:
        root: Directory containing one sub-folder per class.
        transform: Transform pipeline applied to every loaded image.
        global_classes: Sorted list of every class name across all datasets.
        global_class_to_idx: Mapping from class name to its shared integer label.

    Returns:
        The ImageFolder with ``samples``, ``imgs``, ``targets``, ``classes`` and
        ``class_to_idx`` all expressed in the shared label space.

    Raises:
        KeyError: If the folder contains a class missing from ``global_class_to_idx``.
    """
    dataset = datasets.ImageFolder(root, transform=transform)
    # Translate through the dataset's own class list instead of re-parsing file
    # paths: ImageFolder scans class folders recursively, so an image's direct
    # parent directory is not guaranteed to be its class folder.
    local_classes = dataset.classes
    remapped_samples = [
        (image_path, global_class_to_idx[local_classes[local_idx]])
        for image_path, local_idx in dataset.samples
    ]
    dataset.samples = remapped_samples
    dataset.imgs = remapped_samples  # ImageFolder exposes `imgs` as an alias of `samples`
    dataset.targets = [label for _, label in remapped_samples]
    # Keep `classes` consistent with `class_to_idx`, so classes[label] stays valid.
    dataset.classes = list(global_classes)
    dataset.class_to_idx = global_class_to_idx
    return dataset


plantvillage_train_path = os.path.join(plantvillage_path, 'train')

# First pass without transforms: only used to discover the class folders of each split.
plantdoc_train_full = datasets.ImageFolder(plantdoc_train_path)
plantdoc_valid_full = datasets.ImageFolder(plantdoc_valid_path)
plantvillage_train_full = datasets.ImageFolder(plantvillage_train_path)

# The shared label space is the union of class-folder names over all three splits.
# NOTE(review): classes are merged purely by folder name; if the two datasets name
# the same disease differently they become separate classes — confirm that is intended.
all_classes = sorted(set(
    plantdoc_train_full.classes + plantvillage_train_full.classes + plantdoc_valid_full.classes
))
class_to_idx = {cls_name: i for i, cls_name in enumerate(all_classes)}
num_classes = len(all_classes)
print(f"Total unique classes (from all splits): {num_classes}")

plantdoc_train_dataset = _load_with_global_labels(
    plantdoc_train_path, data_transforms['train'], all_classes, class_to_idx)
plantvillage_train_dataset = _load_with_global_labels(
    plantvillage_train_path, data_transforms['train'], all_classes, class_to_idx)

combined_train_dataset = ConcatDataset([plantdoc_train_dataset, plantvillage_train_dataset])

val_dataset = _load_with_global_labels(
    plantdoc_valid_path, data_transforms['val'], all_classes, class_to_idx)


# --- 1.4: Create DataLoaders ---
SUBSET_SIZE = 72000  # number of training images to sample; set to None/0 to use all
BATCH_SIZE = 32
NUM_WORKERS = 2

if SUBSET_SIZE:
    # Clamp the request: sampling without replacement raises ValueError when more
    # items are requested than the dataset contains.
    subset_size = min(SUBSET_SIZE, len(combined_train_dataset))
    train_indices = np.random.choice(len(combined_train_dataset), subset_size, replace=False)
    train_subset = Subset(combined_train_dataset, train_indices)
else:
    train_subset = combined_train_dataset

train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=True)

print(f"Combined training images: {len(combined_train_dataset)} (Using subset of {len(train_subset)})")
print(f"Validation images: {len(val_dataset)}")


# ############################################################################
# STEP 2: TRAIN THE FP32 "TEACHER" MODEL
# ############################################################################
def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=15):
    """Train ``model`` and return it loaded with its best-validation weights.

    Each epoch runs a 'train' phase followed by a 'val' phase. The scheduler is
    stepped once per epoch, after the training phase. Inputs and labels are moved
    to the module-level ``device`` before the forward pass.

    Args:
        model: Network to optimize; expected to already live on ``device``.
        dataloaders: Mapping with 'train' and 'val' keys to iterables yielding
            ``(inputs, labels)`` batches and exposing a sized ``.dataset``.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Learning-rate scheduler; should wrap ``optimizer``.
        num_epochs: Number of train/val epochs to run.

    Returns:
        The same ``model`` object, with the weights from the epoch that reached
        the highest validation accuracy (the initial weights if none beat 0).
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()
            # Accumulate as plain Python numbers so the epoch statistics are well
            # defined even when a loader yields no batches.
            running_loss, running_corrects = 0.0, 0
            for inputs, labels in tqdm(dataloaders[phase], desc=phase):
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                # Gradients are only tracked while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if is_train:
                        loss.backward()
                        optimizer.step()
                # The criterion's value is weighted by batch size so the epoch loss
                # is a per-sample average.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()
            if is_train:
                scheduler.step()
            num_samples = len(dataloaders[phase].dataset)
            epoch_loss = running_loss / num_samples
            epoch_acc = running_corrects / num_samples
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            # Snapshot the weights whenever validation accuracy improves.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')
    model.load_state_dict(best_model_wts)
    return model

# Use the quantizable MobileNetV3 variant from torchvision (built here as a plain
# FP32 model via quantize=False) so the same architecture can be reused for QAT.
teacher_model_fp32 = models.quantization.mobilenet_v3_large(weights='IMAGENET1K_V2', quantize=False)
# Replace the final classifier layer so the output size matches this task's classes.
teacher_model_fp32.classifier[3] = nn.Linear(teacher_model_fp32.classifier[3].in_features, num_classes)

teacher_model_fp32 = teacher_model_fp32.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(teacher_model_fp32.parameters(), lr=0.001, weight_decay=1e-4)
# Single source of truth for the teacher's training length: the cosine schedule's
# horizon (T_max) must equal the number of epochs actually run, otherwise the
# learning rate never anneals down to eta_min. Raise this value to train longer.
NUM_EPOCHS = 5
exp_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=1e-6)

print("Starting Teacher Model Training...")
dataloaders_dict = {'train': train_loader, 'val': val_loader}
teacher_model_fp32 = train_model(teacher_model_fp32, dataloaders_dict, criterion, optimizer,
                                 exp_lr_scheduler, num_epochs=NUM_EPOCHS)
torch.save(teacher_model_fp32.state_dict(), 'teacher_model_fp32.pth')
print("Teacher model trained and saved.")


# ############################################################################
# STEP 3: QUANTIZATION AWARE TRAINING (QAT)
# ############################################################################
def create_qat_model(fp32_model_path):
    """Build a QAT-prepared MobileNetV3 initialised from a trained FP32 checkpoint.

    Args:
        fp32_model_path: Path to a state_dict saved from the FP32 teacher model
            (same architecture, classifier head sized to ``num_classes``).

    Returns:
        A CPU model in training mode with modules fused and fake-quantization
        modules inserted by ``prepare_qat``, ready for fine-tuning.
    """
    # Fresh, randomly initialised quantizable architecture; weights come from the checkpoint.
    model_qat = models.quantization.mobilenet_v3_large(weights=None, quantize=False)
    model_qat.classifier[3] = nn.Linear(model_qat.classifier[3].in_features, num_classes)

    # The model is on the CPU at this point, so load the checkpoint onto the CPU too.
    # Without map_location a checkpoint saved from a GPU model cannot be restored in
    # a CPU-only session.
    model_qat.load_state_dict(torch.load(fp32_model_path, map_location='cpu'))

    # Must be in training mode before fusing/preparing for QAT.
    model_qat.train()

    # Fuse modules for QAT
    model_qat.fuse_model()

    # Use the QAT configuration for the 'qnnpack' backend, matching the engine
    # selected at the top of the notebook.
    model_qat.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')

    # Prepare the model, in place, for Quantization Aware Training.
    torch.quantization.prepare_qat(model_qat, inplace=True)

    return model_qat

print("\nPreparing model for Quantization Aware Training...")
qat_model = create_qat_model('teacher_model_fp32.pth')
qat_model = qat_model.to(device)  # fine-tuning runs on the training device

# Fine-tune the QAT model for a few epochs with a smaller learning rate.
QAT_EPOCHS = 3
optimizer_qat = optim.SGD(qat_model.parameters(), lr=0.0001)
# The scheduler must wrap the optimizer that is actually stepping. Reusing the
# teacher's scheduler would only change the learning rate of the teacher's (now
# unused) optimizer and leave this one unscheduled.
qat_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_qat, T_max=QAT_EPOCHS, eta_min=1e-6)
# The training function is the same, we just use the QAT model.
print("Starting QAT fine-tuning...")
qat_model = train_model(qat_model, dataloaders_dict, criterion, optimizer_qat,
                        qat_lr_scheduler, num_epochs=QAT_EPOCHS)

# Convert the QAT model to a fully quantized INT8 model.
print("\nConverting QAT model to INT8...")
qat_model.to('cpu')  # conversion is done on the CPU
qat_model.eval()
quantized_model_int8 = torch.quantization.convert(qat_model)
print("QAT INT8 model created successfully.")


# ############################################################################
# STEP 4: INFERENCE AND EVALUATION
# ############################################################################
def evaluate_model_performance(model, dataloader, model_name="Model"):
    """Measure a model's accuracy, CPU inference latency and serialized size.

    The model is moved to the CPU and put in eval mode. Its state_dict is written
    to a temporary ``<model_name>.pth`` file in the current directory to measure
    its size; the file is removed afterwards, even if measuring fails.

    Args:
        model: Module to evaluate.
        dataloader: Iterable yielding ``(images, labels)`` CPU batches.
        model_name: Label used in the printed report and the result dict; spaces
            are replaced by underscores for the temporary file name.

    Returns:
        Dict with keys 'Model', 'Accuracy (%)', 'Size (MB)' and 'Latency (ms/img)'.
        Accuracy and latency are NaN when the dataloader yields no samples.
    """
    # Always move the model to 'cpu' for evaluation.
    model.to('cpu').eval()

    model_file = f'{model_name.replace(" ", "_")}.pth'
    try:
        torch.save(model.state_dict(), model_file)
        size_mb = os.path.getsize(model_file) / (1024 * 1024)
    finally:
        # Never leave the scratch checkpoint behind, even when saving fails midway.
        if os.path.exists(model_file):
            os.remove(model_file)

    correct, total = 0, 0
    total_latency = 0.0
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc=f"Evaluating {model_name}"):
            # perf_counter is the monotonic, high-resolution clock for short intervals;
            # only the forward pass is timed, not data loading.
            start_time = time.perf_counter()
            outputs = model(images)
            total_latency += time.perf_counter() - start_time
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total:
        accuracy = 100 * correct / total
        # Normalise by the images actually processed rather than the dataset length.
        avg_latency_ms = (total_latency / total) * 1000
    else:
        accuracy = avg_latency_ms = float('nan')

    print(f"--- {model_name} ---")
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Model Size: {size_mb:.2f} MB")
    print(f"Avg. Latency: {avg_latency_ms:.3f} ms/image\n")

    return {'Model': model_name, 'Accuracy (%)': accuracy, 'Size (MB)': size_mb, 'Latency (ms/img)': avg_latency_ms}

# Rebuild the FP32 teacher on the CPU from its saved checkpoint so it is measured
# under the same conditions as the quantized model.
teacher_model_fp32_eval = models.quantization.mobilenet_v3_large(quantize=False)
teacher_head_inputs = teacher_model_fp32_eval.classifier[3].in_features
teacher_model_fp32_eval.classifier[3] = nn.Linear(teacher_head_inputs, num_classes)
teacher_checkpoint = torch.load('teacher_model_fp32.pth', map_location='cpu')
teacher_model_fp32_eval.load_state_dict(teacher_checkpoint)

# Evaluate the FP32 teacher first, then the final INT8 QAT model, on the validation set.
results = [
    evaluate_model_performance(candidate, val_loader, label)
    for candidate, label in (
        (teacher_model_fp32_eval, "FP32 Teacher"),
        (quantized_model_int8, "INT8 QAT Model"),
    )
]


# ############################################################################
# STEP 5: RESULTS AND CONCLUSION
# ############################################################################
# Tabulate the per-model metrics collected during evaluation.
results_df = pd.DataFrame(results)
print("--- Comparative Results ---")
print(results_df.to_string(index=False))

print("\n--- Analysis of Results ---")
# Pull each model's accuracy out of the table by its 'Model' label.
model_column = results_df['Model']
accuracy_column = results_df['Accuracy (%)']
fp32_acc = accuracy_column[model_column == 'FP32 Teacher'].values[0]
qat_acc = accuracy_column[model_column == 'INT8 QAT Model'].values[0]

# Positive when the quantized model is less accurate than the FP32 teacher.
accuracy_drop = fp32_acc - qat_acc

summary_lines = (
    f"The FP32 Teacher model achieved {fp32_acc:.2f}% accuracy.",
    f"The INT8 QAT model achieved {qat_acc:.2f}% accuracy.",
    f"The accuracy drop after Quantization Aware Training was only {accuracy_drop:.2f}%.",
    "\nConclusion: Quantization Aware Training successfully created a compressed INT8 model",
    "while preserving a high level of accuracy, overcoming the limitations of PTSQ.",
)
for summary_line in summary_lines:
    print(summary_line)

