In [8]:
! pip install numpy pandas seaborn matplotlib tqdm
! pip install datasets "transformers[torch]" scikit-learn



# Rice Leaf Disease Detection - Model Evaluation
This notebook evaluates multiple transformer-based models for rice leaf disease classification.

In [9]:

import os
import json
import torch
import time
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoModelForImageClassification, AutoProcessor
from sklearn.metrics import classification_report, confusion_matrix
from google.colab import drive


## Function: Model Evaluation
This function loads a pre-trained model and evaluates it on the test dataset.

In [10]:

def evaluate_model(model_name, dataset, labels, batch_size=16, device="cuda" if torch.cuda.is_available() else "cpu"):
    """Loads a model and evaluates it on the dataset in batches."""
    print(f"Evaluating {model_name}...")

    model = AutoModelForImageClassification.from_pretrained(model_name).to(device)
    processor = AutoProcessor.from_pretrained(model_name)

    y_true, y_pred = [], []

    start_time = time.time()

    for example in tqdm(dataset, desc=f"Testing {model_name}"):
        image, label = example["image"], example["label"]
        inputs = processor(images=image, return_tensors="pt").to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            pred_label = torch.argmax(outputs.logits, dim=-1).cpu().item()

        y_true.append(label)
        y_pred.append(pred_label)

    elapsed_time = time.time() - start_time
    print(f"Model {model_name} evaluation completed in {elapsed_time:.2f} seconds.")
    return y_true, y_pred, elapsed_time


## Function: Generate Report
This function generates and saves a classification report and confusion matrix.

In [11]:

def generate_report(y_true, y_pred, labels, model_name, output_dir, elapsed_time):
    """Generates and saves classification report and confusion matrix."""
    model_safe_name = model_name.split("/")[-1]
    model_safe_name = model_safe_name.split("_")[0] + "-tl"

    report = classification_report(y_true, y_pred, target_names=labels, output_dict=True)
    cm = confusion_matrix(y_true, y_pred)
    cm_normalized = cm.astype("float") / cm.sum(axis=1, keepdims=True)

    report["evaluation_time_sec"] = elapsed_time

    # Save JSON report
    report_path = os.path.join(output_dir, f"{model_safe_name}/report.json")
    with open(report_path, "w") as f:
        json.dump(report, f, indent=4)

    # Save Excel report
    report_df = pd.DataFrame(report).transpose()
    excel_path = os.path.join(output_dir, f"{model_safe_name}/report.xlsx")

    with pd.ExcelWriter(excel_path) as writer:
        report_df.to_excel(writer, sheet_name="Classification Report")
        pd.DataFrame(cm, index=labels, columns=labels).to_excel(writer, sheet_name="Confusion Matrix")
        pd.DataFrame(cm_normalized, index=labels, columns=labels).to_excel(writer, sheet_name="Normalized Confusion Matrix")

    # Save confusion matrix plot
    def save_cm_plot(matrix, title, filename, fmt="d"):
        plt.figure(figsize=(10, 8))
        sns.heatmap(matrix, annot=True, fmt=fmt, cmap="Blues", xticklabels=labels, yticklabels=labels)
        plt.xlabel("Predicted Label")
        plt.ylabel("True Label")
        plt.title(title, pad=20)
        plt.xticks(rotation=30)
        plt.yticks(rotation=30)
        plt.savefig(os.path.join(output_dir, filename), bbox_inches="tight", pad_inches=0.3)
        plt.close()

    save_cm_plot(cm, f"{model_safe_name} Confusion Matrix", f"{model_safe_name}/confusion_matrix.png")
    save_cm_plot(cm_normalized, f"{model_safe_name} Confusion Matrix", f"{model_safe_name}/normalized_confusion_matrix.png", fmt=".2f")


In [12]:
def save_dataset_info(dataset, output_dir):
    info_path = os.path.join(output_dir, "dataset_info.json")
    # Get class distribution
    labels = dataset.features["label"].names
    label_counts = {label: 0 for label in labels}

    for example in dataset:
        label_counts[labels[example["label"]]] += 1

    dataset_info = {
        "num_samples": len(dataset),
        "num_classes": len(labels),
        "class_distribution": label_counts,
    }

    # Save dataset info as JSON
    with open(info_path, "w") as f:
        json.dump(dataset_info, f, indent=4)

    # Save class distribution plot
    plt.figure(figsize=(10, 6))
    sns.barplot(x=list(label_counts.keys()), y=list(label_counts.values()), palette="viridis")
    plt.xticks(rotation=30)
    plt.xlabel("Classes")
    plt.ylabel("Count")
    plt.title("Class Distribution in Test Dataset")
    plt.savefig(os.path.join(output_dir, "class_distribution.png"), bbox_inches="tight", pad_inches=0.3)
    plt.close()

    print("✅ Dataset info saved.")

## Main Function
This function loads the dataset, evaluates models, and saves reports to Google Drive.

In [13]:

def main():
    """Mounts Google Drive, loads dataset, evaluates models, and saves reports."""
    drive.mount("/content/drive")

    models_path = [
        "cvmil/resnet-50_rice-leaf-disease-augmented_tl",
        "cvmil/vit-base-patch16-224_rice-leaf-disease-augmented_tl",
        "cvmil/swin-base-patch4-window7-224_rice-leaf-disease-augmented_tl",
        "cvmil/deit-base-patch16-224_rice-leaf-disease-augmented_tl",
        "cvmil/beit-base-patch16-224_rice-leaf-disease-augmented_tl",
        "cvmil/dinov2-base_rice-leaf-disease-augmented_tl",
    ]

    dataset = load_dataset("cvmil/rice-leaf-disease-augmented", split="test")
    labels = dataset.features["label"].names

    output_dir = "/content/drive/Shareddrives/CS198-Drones/final_test_tl/"
    os.makedirs(output_dir, exist_ok=True)

    save_dataset_info(dataset, output_dir)

    for model_name in models_path:
        try:
            y_true, y_pred, elapsed_time = evaluate_model(model_name, dataset, labels)
            generate_report(y_true, y_pred, labels, model_name, output_dir, elapsed_time)
        except Exception as e:
            print(f"⚠️ Error processing {model_name}: {e}")

    print("✅ Evaluation completed. Reports saved to Google Drive.")

if __name__ == "__main__":
    main()


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).



Passing `palette` without assigning `hue` is deprecated and will be removed in v0.14.0. Assign the `x` variable to `hue` and set `legend=False` for the same effect.

  sns.barplot(x=list(label_counts.keys()), y=list(label_counts.values()), palette="viridis")


✅ Dataset info saved.
Evaluating cvmil/resnet-50_rice-leaf-disease-augmented_tl...


Testing cvmil/resnet-50_rice-leaf-disease-augmented_tl: 100%|██████████| 2000/2000 [01:26<00:00, 22.99it/s]


Model cvmil/resnet-50_rice-leaf-disease-augmented_tl evaluation completed in 86.98 seconds.
Evaluating cvmil/vit-base-patch16-224_rice-leaf-disease-augmented_tl...


Testing cvmil/vit-base-patch16-224_rice-leaf-disease-augmented_tl: 100%|██████████| 2000/2000 [01:16<00:00, 26.06it/s]


Model cvmil/vit-base-patch16-224_rice-leaf-disease-augmented_tl evaluation completed in 76.75 seconds.
Evaluating cvmil/swin-base-patch4-window7-224_rice-leaf-disease-augmented_tl...


config.json:   0%|          | 0.00/1.41k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/348M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

Testing cvmil/swin-base-patch4-window7-224_rice-leaf-disease-augmented_tl:  12%|█▏        | 248/2000 [00:16<01:53, 15.38it/s]


KeyboardInterrupt: 