# **An efficiency analysis of ConvNeXt-guided distillation for edge-based plant disease detection**

# Phase 1: Environment Setup and Data Preprocessing
---
**Framework:** PyTorch | Batch Size: 16 | Input Dimensions: 224x224  


In [1]:
from collections import Counter

def dataset_summary(name, ds):
    print(f"\n{name} Dataset Summary")
    print("-" * 30)

    try:
        print(f"Total images: {len(ds)}")
    except TypeError:
        print("Could not determine dataset length")
        return

    # If ImageFolder-style dataset
    if hasattr(ds, "classes"):
        print(f"Number of classes: {len(ds.classes)}")
        print("Classes:", ds.classes)

    # Try to get labels
    labels = None
    if hasattr(ds, "targets"):
        labels = ds.targets
    elif hasattr(ds, "samples"):
        labels = [label for _, label in ds.samples]

    if labels is not None:
        class_counts = Counter(labels)
        print("\nImages per class:")
        for cls, count in class_counts.items():
            class_name = ds.classes[cls] if hasattr(ds, "classes") else cls
            print(f"  {class_name}: {count}")
    else:
        print("Could not extract class-wise counts")


# ---- Call this for your datasets ----
if 'dataset' in globals():
    dataset_summary("Training", dataset)

if 'val_dataset' in globals():
    dataset_summary("Validation", val_dataset)

if 'test_dataset' in globals():
    dataset_summary("Testing", test_dataset)


In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_dir = "plant_dataset"

batch_size = 16
epochs = 5
num_classes = len(os.listdir(data_dir))

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

dataset = datasets.ImageFolder(data_dir, transform=transform)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

def evaluate(model, loader):
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            p = torch.argmax(out, dim=1)
            preds.extend(p.cpu().numpy())
            labels.extend(y.cpu().numpy())
    return (
        accuracy_score(labels, preds),
        precision_score(labels, preds, average="macro"),
        recall_score(labels, preds, average="macro"),
        f1_score(labels, preds, average="macro")
    )


In [3]:
import torch
import torchvision

print("Torch:", torch.__version__)
print("Torchvision:", torchvision.__version__)


Torch: 2.9.1+cpu
Torchvision: 0.24.1+cpu


# Phase 2: Dataset Partitioning and Loader Initialization
---
**Data Splits:** Train, Validation, Test | Batch Size: 16 | Transformation: Resize (224x224)

In [4]:
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

data_dir = "plant_dataset"

train_dir = os.path.join(data_dir, "Train", "Train")
val_dir = os.path.join(data_dir, "Validation", "Validation")
test_dir = os.path.join(data_dir, "Test", "Test")

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

train_dataset = datasets.ImageFolder(train_dir, transform=transform)
val_dataset = datasets.ImageFolder(val_dir, transform=transform)
test_dataset = datasets.ImageFolder(test_dir, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

num_classes = len(train_dataset.classes)

print("Classes:", train_dataset.classes)
print("Number of classes:", num_classes)


Classes: ['Healthy', 'Powdery', 'Rust']
Number of classes: 3


# Phase 3: Model Training and Performance Evaluation
---
**Architecture:** MobileNetV2 (Pre-trained)  
**Optimizer:** AdamW | **Loss Function:** Cross-Entropy

In [None]:
from tqdm import tqdm
from torchvision.models import mobilenet_v2

mobilenet = mobilenet_v2(weights="IMAGENET1K_V1")
mobilenet.classifier[1] = nn.Linear(
    mobilenet.classifier[1].in_features, num_classes
)
mobilenet.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(mobilenet.parameters(), lr=1e-4)

for epoch in range(epochs):
    mobilenet.train()
    running_loss = 0.0

    progress_bar = tqdm(
        train_loader,
        desc=f"Epoch [{epoch+1}/{epochs}]",
        leave=True
    )

    for x, y in progress_bar:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        out = mobilenet(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    avg_loss = running_loss / len(train_loader)

    acc, prec, rec, f1 = evaluate(mobilenet, val_loader)

    print(
        f"Epoch [{epoch+1}/{epochs}] Completed | "
        f"Avg Loss: {avg_loss:.4f} | "
        f"Val Acc: {acc:.4f} | "
        f"Prec: {prec:.4f} | "
        f"Rec: {rec:.4f} | "
        f"F1: {f1:.4f}"
    )

acc_m, prec_m, rec_m, f1_m = evaluate(mobilenet, test_loader)

print("\nFinal MobileNetV2 Test Results")
print("Accuracy:", acc_m)
print("Precision:", prec_m)
print("Recall:", rec_m)
print("F1-score:", f1_m)


# Phase 3.1: Computational Complexity and Efficiency Analysis
---
**Metric:** FLOPs Calculation | Library: fvcore | Input Tensor: (1, 3, 224, 224)

In [18]:
from fvcore.nn import FlopCountAnalysis
import torch

mobilenet.eval()

dummy = torch.randn(1, 3, 224, 224).to(device)

flops_mobilenet = FlopCountAnalysis(mobilenet, dummy)
total_flops_mobilenet = flops_mobilenet.total()

print(f"MobileNetV2 FLOPs: {total_flops_mobilenet / 1e9:.3f} GFLOPs")


Unsupported operator aten::hardtanh_ encountered 35 time(s)
Unsupported operator aten::add encountered 10 time(s)


MobileNetV2 FLOPs: 0.313 GFLOPs


# Phase 3.2: Model Deployment Profiling and Efficiency Metrics
---
**Metric:** Inference Latency | Resource: Parameters & FLOPs | Storage: Model Size (MB)

In [38]:
import time
import os
import torch
import pandas as pd
from fvcore.nn import FlopCountAnalysis

# ------------------------------
# Count TOTAL parameters
# ------------------------------
def count_all_parameters(model):
    return sum(p.numel() for p in model.parameters())


# ------------------------------
# Measure inference time
# ------------------------------
def measure_inference_time(model, device, runs=50):
    model.eval()
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    with torch.no_grad():
        for _ in range(10):  # warm-up
            _ = model(dummy_input)
        start = time.time()
        for _ in range(runs):
            _ = model(dummy_input)
        end = time.time()
    return (end - start) / runs


# ------------------------------
# Parameter count
# ------------------------------
mobilenet_params = count_all_parameters(mobilenet)

# ------------------------------
# Inference time
# ------------------------------
mobilenet_inference_time = measure_inference_time(mobilenet, device)

# ------------------------------
# Model size
# ------------------------------
torch.save(mobilenet.state_dict(), "mobilenetv2.pth")
mobilenet_model_size_mb = os.path.getsize("mobilenetv2.pth") / (1024 * 1024)

# ------------------------------
# FLOPs using fvcore (GFLOPs)
# ------------------------------
dummy = torch.randn(1, 3, 224, 224).to(device)
total_flops_mobilenet = FlopCountAnalysis(mobilenet, dummy).total() / 1e9  # GFLOPs

# ------------------------------
# Efficiency metrics
# ------------------------------
accuracy_per_million_params = acc_m / (mobilenet_params / 1e6)
accuracy_per_gflop = acc_m / total_flops_mobilenet

# ------------------------------
# Summary table
# ------------------------------
mobilenet_summary = pd.DataFrame([{
    "Model": "MobileNetV2 (Baseline)",
    "Accuracy": acc_m,
    "Precision": prec_m,
    "Recall": rec_m,
    "F1-Score": f1_m,
    "Parameters (M)": mobilenet_params / 1e6,
    "FLOPs (GFLOPs)": total_flops_mobilenet,
    "Inference Time (s)": mobilenet_inference_time,
    "Model Size (MB)": mobilenet_model_size_mb,
    "Accuracy / Million Params": accuracy_per_million_params,
    "Accuracy / GFLOP": accuracy_per_gflop
}])

# ------------------------------
# Save results
# ------------------------------
mobilenet_summary.to_csv("mobilenet_efficiency_metrics.csv", index=False)

mobilenet_summary


Unsupported operator aten::hardtanh_ encountered 35 time(s)
Unsupported operator aten::add encountered 10 time(s)


Unnamed: 0,Model,Accuracy,Precision,Recall,F1-Score,Parameters (M),FLOPs (GFLOPs),Inference Time (s),Model Size (MB),Accuracy / Million Params,Accuracy / GFLOP
0,MobileNetV2 (Baseline),0.98,0.980644,0.98,0.979854,2.227715,0.312917,0.010406,8.728404,0.439913,3.13182


In [14]:
"""print("Train samples:", len(train_dataset))
print("Validation samples:", len(val_dataset))
print("Test samples:", len(test_dataset))"""


'print("Train samples:", len(train_dataset))\nprint("Validation samples:", len(val_dataset))\nprint("Test samples:", len(test_dataset))'

In [6]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


Using device: cpu


# Phase 4: Advanced ConvNeXtV2 Model Implementation 
---
**Architecture:** ConvNeXt-Tiny (V2-Inspired) | Optimization: AdamW + Cosine Annealing | 
**Regularization:** GRN & Label Smoothing

In [6]:
import torch
import torch.nn as nn
from torchvision.models import convnext_tiny
from tqdm import tqdm

# ---------- Global Response Normalization (ConvNeXt V2 idea) ----------
class GRN(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
        self.eps = eps

    def forward(self, x):
        gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)
        nx = gx / (gx.mean(dim=1, keepdim=True) + self.eps)
        return self.gamma * (x * nx) + self.beta + x


# ---------- ConvNeXt V2–Inspired Teacher Model ----------
class ConvNeXtV2Inspired(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = convnext_tiny(weights="IMAGENET1K_V1")
        in_features = self.backbone.classifier[2].in_features

        self.backbone.classifier = nn.Sequential(
            nn.Flatten(),
            nn.LayerNorm(in_features),
            nn.Unflatten(1, (in_features, 1, 1)),
            GRN(in_features),
            nn.Flatten(),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        return self.backbone(x)


# ---------- Model, Loss, Optimizer, Scheduler ----------
model = ConvNeXtV2Inspired(num_classes).to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-4
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=epochs
)


# ---------- Training Loop ----------
for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    progress_bar = tqdm(
        train_loader,
        desc=f"Epoch [{epoch+1}/{epochs}]",
        leave=True
    )

    for x, y in progress_bar:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    scheduler.step()
    avg_loss = running_loss / len(train_loader)

    acc, prec, rec, f1 = evaluate(model, val_loader)

    print(
        f"Epoch {epoch+1} Completed | "
        f"Loss: {avg_loss:.4f} | "
        f"Val Acc: {acc:.4f} | "
        f"Prec: {prec:.4f} | "
        f"Rec: {rec:.4f} | "
        f"F1: {f1:.4f}"
    )


# ---------- Final Test Evaluation ----------
acc_c, prec_c, rec_c, f1_c = evaluate(model, test_loader)

print("\nConvNeXt V2-Inspired Teacher Test Results")
print("Accuracy:", acc_c)
print("Precision:", prec_c)
print("Recall:", rec_c)
print("F1-score:", f1_c)


Epoch [1/5]: 100%|██████████| 83/83 [04:15<00:00,  3.08s/it, loss=0.323]


Epoch 1 Completed | Loss: 0.4283 | Val Acc: 1.0000 | Prec: 1.0000 | Rec: 1.0000 | F1: 1.0000


Epoch [2/5]: 100%|██████████| 83/83 [03:41<00:00,  2.67s/it, loss=0.297]


Epoch 2 Completed | Loss: 0.3041 | Val Acc: 1.0000 | Prec: 1.0000 | Rec: 1.0000 | F1: 1.0000


Epoch [3/5]: 100%|██████████| 83/83 [03:54<00:00,  2.82s/it, loss=0.292]


Epoch 3 Completed | Loss: 0.2983 | Val Acc: 1.0000 | Prec: 1.0000 | Rec: 1.0000 | F1: 1.0000


Epoch [4/5]: 100%|██████████| 83/83 [03:50<00:00,  2.77s/it, loss=0.292]


Epoch 4 Completed | Loss: 0.2947 | Val Acc: 1.0000 | Prec: 1.0000 | Rec: 1.0000 | F1: 1.0000


Epoch [5/5]: 100%|██████████| 83/83 [03:41<00:00,  2.66s/it, loss=0.292]


Epoch 5 Completed | Loss: 0.2931 | Val Acc: 1.0000 | Prec: 1.0000 | Rec: 1.0000 | F1: 1.0000

ConvNeXt V2-Inspired Teacher Test Results
Accuracy: 0.9733333333333334
Precision: 0.9741876310272537
Recall: 0.9733333333333333
F1-score: 0.9733188165920482


Save the model 


In [7]:
# ---------- Save Model ----------
save_path = "convnext_v2_teacher.pth"

torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),  # optional
    "num_classes": num_classes
}, save_path)

print(f"Model saved to {save_path}")

Model saved to convnext_v2_teacher.pth


In [4]:
import torch
import torch.nn as nn
from torchvision.models import convnext_tiny

# ---------- Recreate GRN ----------
class GRN(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
        self.eps = eps

    def forward(self, x):
        gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)
        nx = gx / (gx.mean(dim=1, keepdim=True) + self.eps)
        return self.gamma * (x * nx) + self.beta + x


# ---------- Recreate Model Architecture ----------
class ConvNeXtV2Inspired(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = convnext_tiny(weights=None)  # IMPORTANT
        in_features = self.backbone.classifier[2].in_features

        self.backbone.classifier = nn.Sequential(
            nn.Flatten(),
            nn.LayerNorm(in_features),
            nn.Unflatten(1, (in_features, 1, 1)),
            GRN(in_features),
            nn.Flatten(),
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        return self.backbone(x)


# ---------- Load Saved Model ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint = torch.load("convnext_v2_teacher.pth", map_location=device)

num_classes = checkpoint["num_classes"]

model = ConvNeXtV2Inspired(num_classes).to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()  # VERY IMPORTANT for inference

print("Model loaded successfully!")

Model loaded successfully!


In [24]:
print(type(teacher))
print(teacher)


<class 'str'>
Hybrid KD MobileNetV2


In [21]:
locals().keys()


dict_keys(['__name__', '__doc__', '__package__', '__loader__', '__spec__', '__builtin__', '__builtins__', '_ih', '_oh', '_dh', 'In', 'Out', 'get_ipython', 'exit', 'quit', 'open', '_', '__', '___', '__vsc_ipynb_file__', '_i', '_ii', '_iii', '_i1', 'torch', 'torchvision', '_i2', '_i3', 'optuna', 'nn', 'F', 'mobilenet_v2', 'tqdm', 'DistillationLoss', 'objective', 'study', '_i4', 'os', 'optim', 'datasets', 'transforms', 'DataLoader', 'accuracy_score', 'precision_score', 'recall_score', 'f1_score', 'device', 'data_dir', 'batch_size', 'epochs', 'num_classes', 'transform', 'dataset', 'loader', 'evaluate', '_i5', 'train_dir', 'val_dir', 'test_dir', 'train_dataset', 'val_dataset', 'test_dataset', 'train_loader', 'val_loader', 'test_loader', '_i6', '_i7', '_i8', '_i9', '_i10', '_i11', '_i12', '_i13', 'teacher', '_i14', '_i15', '_i16', '_i17', 'convnext_tiny', 'GRN', 'ConvNeXtV2Inspired', 'model', 'criterion', 'optimizer', 'scheduler', 'epoch', 'running_loss', 'progress_bar', 'x', 'y', 'out', 'lo

# Phase 4.2: Teacher Model Profiling and Efficiency Metrics
---
**Metrics:** Parametric Density | Performance: Inference Latency | Storage: Binary Model Size (MB)

In [34]:
import time
import os
import torch
import pandas as pd
from fvcore.nn import FlopCountAnalysis

# ------------------------------
# Parameter counting (TOTAL params)
# ------------------------------
def count_all_parameters(model):
    return sum(p.numel() for p in model.parameters())


# ------------------------------
# Inference time measurement
# ------------------------------
def measure_inference_time(model, device, runs=50):
    model.eval()
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    with torch.no_grad():
        for _ in range(10):  # warm-up
            _ = model(dummy_input)
        start = time.time()
        for _ in range(runs):
            _ = model(dummy_input)
        end = time.time()
    return (end - start) / runs


# ------------------------------
# TOTAL parameter count
# ------------------------------
convnext_params = count_all_parameters(model)

# ------------------------------
# Inference time
# ------------------------------
convnext_inference_time = measure_inference_time(model, device)

# ------------------------------
# Model size on disk
# ------------------------------
torch.save(model.state_dict(), "convnext_v2_inspired.pth")
convnext_model_size_mb = os.path.getsize("convnext_v2_inspired.pth") / (1024 * 1024)

# ------------------------------
# FLOPs using fvcore (CORRECT METHOD)
# ------------------------------
dummy = torch.randn(1, 3, 224, 224).to(device)
total_flops_convnext = FlopCountAnalysis(model, dummy).total() / 1e9  # GFLOPs

# ------------------------------
# Efficiency metrics
# ------------------------------
accuracy_per_million_params = acc_c / (convnext_params / 1e6)
accuracy_per_gflop = acc_c / total_flops_convnext

# ------------------------------
# Summary table
# ------------------------------
convnext_summary = pd.DataFrame([{
    "Model": "ConvNeXt V2-Inspired Teacher",
    "Accuracy": acc_c,
    "Precision": prec_c,
    "Recall": rec_c,
    "F1-Score": f1_c,
    "Parameters (M)": convnext_params / 1e6,
    "FLOPs (GFLOPs)": total_flops_convnext,
    "Inference Time (s)": convnext_inference_time,
    "Model Size (MB)": convnext_model_size_mb,
    "Accuracy / Million Params": accuracy_per_million_params,
    "Accuracy / GFLOP": accuracy_per_gflop
}])

# ------------------------------
# Save results
# ------------------------------
convnext_summary.to_csv("convnext_v2_inspired_efficiency_metrics.csv", index=False)

convnext_summary


Unsupported operator aten::gelu encountered 18 time(s)
Unsupported operator aten::mul encountered 20 time(s)
Unsupported operator aten::add_ encountered 18 time(s)
Unsupported operator aten::unflatten encountered 1 time(s)
Unsupported operator aten::linalg_vector_norm encountered 1 time(s)
Unsupported operator aten::mean encountered 1 time(s)
Unsupported operator aten::add encountered 3 time(s)
Unsupported operator aten::div encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
backbone.features.1.0.stochastic_depth, backbone.features.1.1.stochastic_depth, backbone.features.1.2.stochastic_depth, backbone.features.3.0.stochastic_depth, backbone.features.3.1.stochastic_depth, backbone.features.3.2.stoch

Unnamed: 0,Model,Accuracy,Precision,Recall,F1-Score,Parameters (M),FLOPs (GFLOPs),Inference Time (s),Model Size (MB),Accuracy / Million Params,Accuracy / GFLOP
0,ConvNeXt V2-Inspired Teacher,0.973333,0.974188,0.973333,0.973319,27.823971,4.469672,0.036494,106.207631,0.034982,0.217764


# Phase 5:Phase 5: Knowledge Distillation (KD) Training Strategy
---
**Teacher:** ConvNeXt-Tiny (V2-Inspired) | Student: MobileNetV2 | KD Parameters: $\alpha=0.7$, $T=4.0$

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

# ------------------------------
# Knowledge Distillation Loss
# ------------------------------
class DistillationLoss(nn.Module):
    def __init__(self, alpha=0.7, temperature=4.0):
        super().__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.ce = nn.CrossEntropyLoss()

    def forward(self, student_logits, teacher_logits, labels):
        hard_loss = self.ce(student_logits, labels)

        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)

        kd_loss = F.kl_div(
            soft_student,
            soft_teacher,
            reduction="batchmean"
        ) * (self.temperature ** 2)

        return self.alpha * hard_loss + (1 - self.alpha) * kd_loss


# ------------------------------
# Student Model (MobileNetV2)
# ------------------------------
from torchvision.models import mobilenet_v2

student = mobilenet_v2(weights="IMAGENET1K_V1")
student.classifier[1] = nn.Linear(
    student.classifier[1].in_features, num_classes
)
student.to(device)


# ------------------------------
# Teacher Model (Frozen)
# ------------------------------
teacher = model   # your ConvNeXt V2–Inspired teacher
teacher.eval()

for p in teacher.parameters():
    p.requires_grad = False


# ------------------------------
# Optimizer & Loss
# ------------------------------
optimizer = torch.optim.AdamW(
    student.parameters(),
    lr=1e-4,
    weight_decay=1e-4
)

criterion = DistillationLoss(
    alpha=0.7,
    temperature=4.0
)


# ------------------------------
# Distillation Training Loop
# ------------------------------
epochs_kd = 5

for epoch in range(epochs_kd):
    student.train()
    running_loss = 0.0

    progress_bar = tqdm(
        train_loader,
        desc=f"KD Epoch [{epoch+1}/{epochs_kd}]",
        leave=True
    )

    for x, y in progress_bar:
        x, y = x.to(device), y.to(device)

        with torch.no_grad():
            teacher_logits = teacher(x)

        student_logits = student(x)

        loss = criterion(student_logits, teacher_logits, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    avg_loss = running_loss / len(train_loader)

    acc, prec, rec, f1 = evaluate(student, val_loader)

    print(
        f"KD Epoch {epoch+1} Completed | "
        f"Loss: {avg_loss:.4f} | "
        f"Val Acc: {acc:.4f} | "
        f"Prec: {prec:.4f} | "
        f"Rec: {rec:.4f} | "
        f"F1: {f1:.4f}"
    )
  

# ------------------------------
# Final Test Evaluation
# ------------------------------
acc_kd, prec_kd, rec_kd, f1_kd = evaluate(student, test_loader)

print("\nHybrid Distilled MobileNetV2 Test Results")
print("Accuracy:", acc_kd)
print("Precision:", prec_kd)
print("Recall:", rec_kd)
print("F1-score:", f1_kd)


KD Epoch [1/5]: 100%|██████████| 83/83 [02:51<00:00,  2.06s/it, loss=0.0891]


KD Epoch 1 Completed | Loss: 0.2486 | Val Acc: 1.0000 | Prec: 1.0000 | Rec: 1.0000 | F1: 1.0000


KD Epoch [2/5]: 100%|██████████| 83/83 [02:49<00:00,  2.04s/it, loss=0.0718]


KD Epoch 2 Completed | Loss: 0.1130 | Val Acc: 1.0000 | Prec: 1.0000 | Rec: 1.0000 | F1: 1.0000


KD Epoch [3/5]: 100%|██████████| 83/83 [02:54<00:00,  2.10s/it, loss=0.263] 


KD Epoch 3 Completed | Loss: 0.1153 | Val Acc: 1.0000 | Prec: 1.0000 | Rec: 1.0000 | F1: 1.0000


KD Epoch [4/5]: 100%|██████████| 83/83 [03:15<00:00,  2.35s/it, loss=0.111] 


KD Epoch 4 Completed | Loss: 0.0970 | Val Acc: 1.0000 | Prec: 1.0000 | Rec: 1.0000 | F1: 1.0000


KD Epoch [5/5]: 100%|██████████| 83/83 [02:59<00:00,  2.16s/it, loss=0.0522]


KD Epoch 5 Completed | Loss: 0.0805 | Val Acc: 1.0000 | Prec: 1.0000 | Rec: 1.0000 | F1: 1.0000

Hybrid Distilled MobileNetV2 Test Results
Accuracy: 0.9733333333333334
Precision: 0.9741876310272537
Recall: 0.9733333333333333
F1-score: 0.9733188165920482


In [10]:
# ------------------------------
# Save Distilled Student Model
# ------------------------------

save_path = "mobilenetv2_distilled.pth"

torch.save({
    "model_state_dict": student.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),  # optional (for resuming training)
    "num_classes": num_classes
}, save_path)

print(f"Distilled model saved to {save_path}")

Distilled model saved to mobilenetv2_distilled.pth


In [5]:
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v2

# ------------------------------
# Device
# ------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ------------------------------
# Load Checkpoint
# ------------------------------
checkpoint = torch.load("mobilenetv2_distilled.pth", map_location=device)
num_classes = checkpoint["num_classes"]

# ------------------------------
# Recreate Student Architecture
# ------------------------------
student = mobilenet_v2(weights=None)  # IMPORTANT: do NOT load ImageNet weights
student.classifier[1] = nn.Linear(
    student.classifier[1].in_features,
    num_classes
)

student.load_state_dict(checkpoint["model_state_dict"])
student.to(device)
student.eval()  # IMPORTANT for inference

print("Distilled MobileNetV2 loaded successfully!")

Distilled MobileNetV2 loaded successfully!


In [None]:
torch.save(student.state_dict(), "hybrid_kd_mobilenetv2.pth")


# Phase 5.1: Final Model Profiling and Efficiency Analysis (Knowledge Distillation)
**Architecture:** Hybrid KD MobileNetV2 | 
**Metrics:** Parameter Density & GFLOPs | 
**Efficiency:** Accuracy per Compute Unit

In [12]:
import time
import os
import torch
import pandas as pd

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def measure_inference_time(model, device, runs=50):
    model.eval()
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    with torch.no_grad():
        for _ in range(10):
            _ = model(dummy_input)
        start = time.time()
        for _ in range(runs):
            _ = model(dummy_input)
        end = time.time()
    return (end - start) / runs


# ---------- Parameter Count ----------
hybrid_params = count_parameters(student)

# ---------- Inference Time ----------
hybrid_inference_time = measure_inference_time(student, device)

# ---------- Model Size ----------
torch.save(student.state_dict(), "hybrid_kd_mobilenetv2.pth")
hybrid_model_size_mb = os.path.getsize("hybrid_kd_mobilenetv2.pth") / (1024 * 1024)

# ---------- FLOPs ----------
from fvcore.nn import FlopCountAnalysis

dummy = torch.randn(1, 3, 224, 224).to(device)
hybrid_flops = FlopCountAnalysis(student, dummy).total() / 1e9  # GFLOPs

# ---------- Efficiency Metric ----------
accuracy_per_million_params = acc_kd / (hybrid_params / 1e6)
accuracy_per_gflop = acc_kd / hybrid_flops


# ---------- Summary Table ----------
hybrid_summary = pd.DataFrame([{
    "Model": "Hybrid KD MobileNetV2 (Final)",
    "Accuracy": acc_kd,
    "Precision": prec_kd,
    "Recall": rec_kd,
    "F1-Score": f1_kd,
    "Parameters (M)": hybrid_params / 1e6,
    "FLOPs (GFLOPs)": hybrid_flops,
    "Inference Time (s)": hybrid_inference_time,
    "Model Size (MB)": hybrid_model_size_mb,
    "Accuracy / Million Params": accuracy_per_million_params,
    "Accuracy / GFLOP": accuracy_per_gflop
}])

hybrid_summary.to_csv("hybrid_model_efficiency_metrics.csv", index=False)

hybrid_summary


Unsupported operator aten::hardtanh_ encountered 35 time(s)
Unsupported operator aten::add encountered 10 time(s)


Unnamed: 0,Model,Accuracy,Precision,Recall,F1-Score,Parameters (M),FLOPs (GFLOPs),Inference Time (s),Model Size (MB),Accuracy / Million Params,Accuracy / GFLOP
0,Hybrid KD MobileNetV2 (Final),0.973333,0.974188,0.973333,0.973319,2.227715,0.312917,0.01137,8.731517,0.43692,3.110515


# Phase 6: Comprehensive Comparative Analysis and Model Benchmarking
---
**Evaluation Scope:** Baseline vs. Teacher vs. Hybrid Student|
**Hardware:** CUDA/CPU Synchronized Timing| 
**Metrics:** Accuracy/Complexity Trade-off|

In [40]:
import time
import os
import torch
import pandas as pd
from fvcore.nn import FlopCountAnalysis

# ------------------------------
# Utility: count TOTAL parameters
# ------------------------------
def count_all_parameters(model):
    return sum(p.numel() for p in model.parameters())


# ------------------------------
# Utility: CPU/GPU safe inference timing
# ------------------------------
def measure_inference_time(model, device, runs=30):
    model.eval()
    dummy = torch.randn(1, 3, 224, 224).to(device)

    with torch.no_grad():
        # warm-up
        for _ in range(10):
            _ = model(dummy)

        if device.type == "cuda":
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            for _ in range(runs):
                _ = model(dummy)
            end.record()

            torch.cuda.synchronize()
            return start.elapsed_time(end) / (runs * 1000)  # seconds

        else:
            start = time.time()
            for _ in range(runs):
                _ = model(dummy)
            end = time.time()
            return (end - start) / runs


# ------------------------------
# Dummy input for FLOPs
# ------------------------------
dummy = torch.randn(1, 3, 224, 224).to(device)

# ==============================
# MobileNetV2 (Baseline)
# ==============================
mobilenet_params = count_all_parameters(mobilenet)
mobilenet_flops = FlopCountAnalysis(mobilenet, dummy).total() / 1e9  # GFLOPs
mobilenet_time = measure_inference_time(mobilenet, device)
torch.save(mobilenet.state_dict(), "mobilenetv2.pth")
mobilenet_size = os.path.getsize("mobilenetv2.pth") / (1024 * 1024)

# ==============================
# ConvNeXt V2-Inspired Teacher
# ==============================
convnext_params = count_all_parameters(model)
convnext_flops = FlopCountAnalysis(model, dummy).total() / 1e9  # GFLOPs
convnext_time = measure_inference_time(model, device)
torch.save(model.state_dict(), "convnext_teacher.pth")
convnext_size = os.path.getsize("convnext_teacher.pth") / (1024 * 1024)

# ==============================
# Hybrid KD MobileNetV2 (Student)
# ==============================
hybrid_params = count_all_parameters(student)
hybrid_flops = FlopCountAnalysis(student, dummy).total() / 1e9  # GFLOPs
hybrid_time = measure_inference_time(student, device)
torch.save(student.state_dict(), "hybrid_student.pth")
hybrid_size = os.path.getsize("hybrid_student.pth") / (1024 * 1024)

# ==============================
# Final comparison table
# ==============================
comparison_table = pd.DataFrame([
    {
        "Model": "MobileNetV2 (Baseline)",
        "Accuracy": acc_m,
        "Precision": prec_m,
        "Recall": rec_m,
        "F1-Score": f1_m,
        "Parameters (M)": mobilenet_params / 1e6,
        "FLOPs (GFLOPs)": mobilenet_flops,
        "Inference Time (s)": mobilenet_time,
        "Model Size (MB)": mobilenet_size,
        "Accuracy / GFLOP": acc_m / mobilenet_flops
    },
    {
        "Model": "ConvNeXt V2-Inspired (Teacher)",
        "Accuracy": acc_c,
        "Precision": prec_c,
        "Recall": rec_c,
        "F1-Score": f1_c,
        "Parameters (M)": convnext_params / 1e6,
        "FLOPs (GFLOPs)": convnext_flops,
        "Inference Time (s)": convnext_time,
        "Model Size (MB)": convnext_size,
        "Accuracy / GFLOP": acc_c / convnext_flops
    },
    {
        "Model": "Hybrid KD MobileNetV2 (Final)",
        "Accuracy": acc_kd,
        "Precision": prec_kd,
        "Recall": rec_kd,
        "F1-Score": f1_kd,
        "Parameters (M)": hybrid_params / 1e6,
        "FLOPs (GFLOPs)": hybrid_flops,
        "Inference Time (s)": hybrid_time,
        "Model Size (MB)": hybrid_size,
        "Accuracy / GFLOP": acc_kd / hybrid_flops
    }
])

comparison_table.to_csv("final_3_model_comparison.csv", index=False)
comparison_table


Unsupported operator aten::hardtanh_ encountered 35 time(s)
Unsupported operator aten::add encountered 10 time(s)
Unsupported operator aten::gelu encountered 18 time(s)
Unsupported operator aten::mul encountered 20 time(s)
Unsupported operator aten::add_ encountered 18 time(s)
Unsupported operator aten::unflatten encountered 1 time(s)
Unsupported operator aten::linalg_vector_norm encountered 1 time(s)
Unsupported operator aten::mean encountered 1 time(s)
Unsupported operator aten::add encountered 3 time(s)
Unsupported operator aten::div encountered 1 time(s)
The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
backbone.features.1.0.stochastic_depth, backbone.features.1.1.stochastic_depth, backbone.features.1.2.stochastic_

Unnamed: 0,Model,Accuracy,Precision,Recall,F1-Score,Parameters (M),FLOPs (GFLOPs),Inference Time (s),Model Size (MB),Accuracy / GFLOP
0,MobileNetV2 (Baseline),0.98,0.980644,0.98,0.979854,2.227715,0.312917,0.010507,8.728404,3.13182
1,ConvNeXt V2-Inspired (Teacher),0.973333,0.974188,0.973333,0.973319,27.823971,4.469672,0.036628,106.206906,0.217764
2,Hybrid KD MobileNetV2 (Final),0.98,0.980644,0.98,0.979854,2.227715,0.312917,0.009694,8.72932,3.13182


# Trying Olama Bc

class_names = train_dataset.classes

In [6]:
import torch.nn.functional as F

In [7]:
class_names = train_dataset.classes

In [8]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

In [9]:
import torch
from torchvision import datasets, transforms
import os

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Load training dataset
train_dataset = datasets.ImageFolder(
    root="plant_dataset/Train/Train",
    transform=transform
)

# Load test dataset
test_dataset = datasets.ImageFolder(
    root="plant_dataset/Test/Test",
    transform=transform
)

class_names = train_dataset.classes

print("Classes:", class_names)
print("Number of classes:", len(class_names))
print("Checkpoint classes:", num_classes)

Classes: ['Healthy', 'Powdery', 'Rust']
Number of classes: 3
Checkpoint classes: 3


In [10]:
from torch.utils.data import DataLoader

test_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=True
)

In [11]:
import torch.nn.functional as F

def predict_image(image_tensor):
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        outputs = student(image_tensor)
        probs = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probs, 1)

    disease_name = class_names[predicted.item()]
    return disease_name, confidence.item()

In [15]:
image_tensor, true_label = next(iter(test_loader))

disease, confidence = predict_image(image_tensor)

print("Predicted Disease:", disease)
print("Confidence:", confidence)
print("True Label:", class_names[true_label.item()])

Predicted Disease: Healthy
Confidence: 0.9522280097007751
True Label: Healthy


In [16]:
def get_severity_hint(confidence):
    if confidence > 0.90:
        return "High confidence prediction. Likely clear visible symptoms."
    elif confidence > 0.75:
        return "Moderate confidence prediction. Symptoms may be developing."
    else:
        return "Low confidence prediction. Consider rechecking the plant."

In [19]:
severity_hint = get_severity_hint(confidence)

report = generate_plant_report(disease, confidence, severity_hint)

print("\n--- AI Agricultural Report ---\n")
print(report)


--- AI Agricultural Report ---

 Confirmation: Your plant appears healthy with a high model confidence of 0.95. Continue providing it with adequate sunlight, water, and nutrients as needed.

Preventive Care Steps: Ensure proper watering schedule (avoid overwatering), provide appropriate lighting, and fertilize when necessary using a balanced fertilizer.

Monitoring Recommendations: Keep an eye on the plant for any signs of pests or disease. Check the overall health regularly by examining leaves, stems, and roots.

Recheck: Schedule a recheck after two weeks to assess if there are any changes in the plant's condition. If issues arise before then, do not hesitate to examine the plant again.


In [18]:
from ollama_helper import generate_plant_report

In [15]:
import torch.nn.functional as F

def predict_image(image_tensor):
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        outputs = student(image_tensor)
        probs = F.softmax(outputs, dim=1)
        confidence, predicted = torch.max(probs, 1)

    disease_name = class_names[predicted.item()]
    
    return disease_name, confidence.item()

In [17]:
import os
from PIL import Image

# Pick class folder manually
target_class = "Powdery"

test_path = "plant_dataset/Test/Test"

# Get first image from Powdery folder
image_file = os.listdir(os.path.join(test_path, target_class))[0]

image_path = os.path.join(test_path, target_class, image_file)

print("Testing image:", image_path)

image = Image.open(image_path).convert("RGB")
image_tensor = transform(image).unsqueeze(0)

disease, confidence = predict_image(image_tensor)

print("Predicted:", disease)
print("Confidence:", confidence)

Testing image: plant_dataset/Test/Test\Powdery\80bc7d353e163e85.jpg
Predicted: Powdery
Confidence: 0.9335793256759644


In [23]:
severity_hint = get_severity_hint(confidence)

report = generate_plant_report(disease, confidence, severity_hint)

print("\n--- AI Agricultural Report ---\n")
print(report)


--- AI Agricultural Report ---

 Title: Comprehensive Plant Care Guide for Farmers

1. Preventive Care Tips:
   - Regularly inspect your plants for any signs of disease or pests. Early detection can prevent the spread and minimize damage.
   - Prune dead, diseased, or damaged parts of the plant to promote growth and improve overall health.
   - Rotate crops periodically to reduce the build-up of soil-borne diseases and pests.

2. Best Irrigation Practices:
   - Water at the base of the plant to avoid moisture build-up on leaves, reducing risk of fungal diseases.
   - Use drip irrigation or soaker hoses for efficient water use and consistent moisture levels.
   - Avoid overwatering as it can lead to root rot and other diseases.
   - Consider using rainwater harvesting systems for sustainable irrigation practices.

3. Fertilizer Recommendations:
   - Use a balanced fertilizer with equal amounts of nitrogen, phosphorus, and potassium (N-P-K ratio: 10-10-10).
   - Test your soil regularly