In [6]:
# agriguard_pytorch.py
# PyTorch version of AGRIGUARD training pipeline
# Compatible with torch 2.x + CUDA (works with your torch 2.7.1+cu118)

import os
import math
import argparse
from pathlib import Path
from glob import glob
import random
import numpy as np
import pandas as pd
from PIL import Image

from sklearn.model_selection import train_test_split
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, balanced_accuracy_score, classification_report,
                             confusion_matrix, roc_auc_score)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as T
from torchvision import models

# ---------------------------
# User config (defaults)
# ---------------------------
def get_args(argv=None):
    """Parse the training configuration from command-line style arguments.

    Args:
        argv: Optional list of argument strings.  ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with one attribute per option declared below.

    Unrecognised arguments are reported and ignored instead of aborting the
    process.  This keeps the script usable inside Jupyter/IPython, where the
    kernel places its own connection-file argument (``-f ...`` or
    ``--f=...``) in ``sys.argv``.  Prefix abbreviations are disabled so that
    such a ``--f=...`` argument is not mistaken for ``--focal_alpha`` /
    ``--focal_gamma``.
    """
    p = argparse.ArgumentParser(allow_abbrev=False)
    p.add_argument("--data_dir", default="data", help="root data folder; expects train/<class> folders and optionally val/")
    p.add_argument("--img_size", type=int, default=224)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--epochs_head", type=int, default=8)
    p.add_argument("--epochs_finetune", type=int, default=12)
    p.add_argument("--lr_head", type=float, default=1e-3)
    p.add_argument("--lr_finetune", type=float, default=5e-5)
    p.add_argument("--unfreeze_last_n", type=int, default=30)
    p.add_argument("--use_focal", action="store_true")
    p.add_argument("--focal_alpha", type=float, default=0.75)
    p.add_argument("--focal_gamma", type=float, default=2.0)
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--num_workers", type=int, default=4)
    p.add_argument("--save_path", default="mobilenetv3_best.pt")
    p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")

    parsed, unknown = p.parse_known_args(argv)
    if unknown:
        # Make ignored arguments visible so a mistyped option is not silently lost.
        print("Ignoring unrecognised arguments:", unknown)
    return parsed

args = get_args()

# Seed the torch, NumPy and stdlib RNGs (in that order) from the same value.
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(args.seed)

# Device every model/tensor in this script is moved to.
DEVICE = torch.device(args.device)
print("Using device:", DEVICE)
print("Torch:", torch.__version__)

# ---------------------------
# Dataset & helpers
# ---------------------------
def discover_images(folder):
    """Collect image files laid out as ``folder/<class_name>/<image>``.

    Args:
        folder: Directory whose immediate, non-hidden sub-directories are
            treated as classes.

    Returns:
        Tuple ``(image_paths, labels, classes)`` where ``classes`` is the
        sorted list of class folder names, ``image_paths`` lists every
        ``.jpg`` / ``.jpeg`` / ``.png`` file (case-insensitive) found directly
        inside a class folder, and ``labels[i]`` is the index into
        ``classes`` for ``image_paths[i]``.

    Files are visited in sorted order so that the result, and therefore any
    seeded split built from it, does not depend on filesystem enumeration
    order.  Entries whose name starts with "." are skipped.
    """
    extensions = ('.jpg', '.jpeg', '.png')
    classes = sorted(
        d for d in os.listdir(folder)
        if os.path.isdir(os.path.join(folder, d)) and not d.startswith(".")
    )

    image_paths, labels = [], []
    for idx, cls in enumerate(classes):
        class_dir = os.path.join(folder, cls)
        # Plain directory listing: unlike a glob pattern, characters such as
        # "[" or "?" in a class folder name are taken literally.
        for fname in sorted(os.listdir(class_dir)):
            if fname.startswith(".") or not fname.lower().endswith(extensions):
                continue
            path = os.path.join(class_dir, fname)
            if os.path.isfile(path):
                image_paths.append(path)
                labels.append(idx)
    return image_paths, labels, classes

class PlantDataset(Dataset):
    """Image dataset for binary healthy-vs-diseased classification.

    Each row of ``df`` must provide a ``path`` column (image file path) and a
    ``binary_label`` column (1 = healthy, 0 = diseased).  Indexing returns
    ``(image_tensor, label_tensor, path)`` where the label is a float32
    scalar tensor.
    """

    def __init__(self, df, img_size=224, training=False):
        """Store the frame and build the per-sample transform pipeline.

        Args:
            df: DataFrame with ``path`` and ``binary_label`` columns; its
                index is reset so positional lookups match ``__len__``.
            img_size: Side length of the square output image.
            training: When True, random flips / rotation / crop are added.
        """
        self.df = df.reset_index(drop=True)
        self.img_size = img_size
        self.training = training

        # The transforms run per sample on the PIL image inside __getitem__.
        size = (self.img_size, self.img_size)
        # Mean/std values commonly used with ImageNet-pretrained torchvision models.
        normalize = T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
        if training:
            self.transform = T.Compose([
                T.Resize(size),
                T.RandomHorizontalFlip(),
                T.RandomVerticalFlip(),
                T.RandomRotation(8),
                T.RandomResizedCrop(self.img_size, scale=(0.9, 1.0)),
                T.ToTensor(),
                normalize,
            ])
        else:
            self.transform = T.Compose([
                T.Resize(size),
                T.ToTensor(),
                normalize,
            ])

    def __len__(self):
        """Return the number of rows (samples) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, transform and return sample ``idx`` as (image, label, path)."""
        row = self.df.iloc[idx]
        path = row['path']
        label = row['binary_label']  # 1 = healthy, 0 = diseased

        # Open inside a context manager so the file handle is closed promptly;
        # convert() returns an independent image that outlives the handle.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        img = self.transform(img)
        return img, torch.tensor(label, dtype=torch.float32), path

# ---------------------------
# Build dataframe (train/test split)
# ---------------------------
# Locate the training images; every class lives in its own sub-folder.
train_root = os.path.join(args.data_dir, "train")
if not os.path.exists(train_root):
    raise FileNotFoundError(f"Expected {train_root} to exist and contain class subfolders")

image_paths, labels, classes = discover_images(train_root)
print(f"Found classes: {classes}")
print("Total images:", len(image_paths))

# One row per image: file path, class index, class name and the binary target.
df = pd.DataFrame({"path": image_paths, "class_idx": labels})
df['class_name'] = df['class_idx'].map(dict(enumerate(classes)))
# A class counts as healthy when its folder name contains "healthy" (any case).
df['is_healthy'] = df['class_name'].str.contains('healthy', case=False, na=False)
df['binary_label'] = df['is_healthy'].astype(int)

total = len(df)
healthy_count = df['binary_label'].sum()
diseased_count = total - healthy_count
print(f"Healthy: {healthy_count}, Diseased: {diseased_count}")

# Stratified 70 / 15 / 15 split: hold out 30 %, then halve the hold-out
# into validation and test, stratifying on the binary label both times.
train_df, temp_df = train_test_split(
    df,
    test_size=0.30,
    stratify=df['binary_label'],
    random_state=args.seed,
)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    stratify=temp_df['binary_label'],
    random_state=args.seed,
)
print("Split sizes -> Train:", len(train_df), "Val:", len(val_df), "Test:", len(test_df))

# ---------------------------
# Dataloaders
# ---------------------------
# Only the training dataset gets the random augmentations.
train_ds, val_ds, test_ds = (
    PlantDataset(frame, img_size=args.img_size, training=is_train)
    for frame, is_train in ((train_df, True), (val_df, False), (test_df, False))
)

# Class-balanced sampling for training: each sample is drawn (with
# replacement) with a weight equal to the inverse of its class frequency.
# NOTE(review): the BCE loss configured later in this script also applies a
# pos_weight; confirm that combining both imbalance corrections is intended.
counts = train_df['binary_label'].value_counts().to_dict()
print("Train class counts:", counts)
class_sample_count = np.array([counts.get(label, 0) for label in (0, 1)])
weights = 1.0 / class_sample_count[train_df['binary_label'].values]
sampler = WeightedRandomSampler(weights=weights, num_samples=len(weights), replacement=True)

# Options shared by all three loaders.
_loader_opts = dict(batch_size=args.batch_size, num_workers=args.num_workers, pin_memory=True)

train_loader = DataLoader(train_ds, sampler=sampler, **_loader_opts)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_opts)

# ---------------------------
# Model, loss, optimizer
# ---------------------------
def build_model(pretrained=True):
    """Create a MobileNetV3-Small with a single-logit binary head.

    Args:
        pretrained: Load the ``IMAGENET1K_V1`` torchvision weights when True;
            otherwise the backbone is randomly initialised.

    Returns:
        The torchvision model with ``model.classifier`` replaced by
        Dropout -> Linear(in_features, 256) -> ReLU -> Dropout -> Linear(256, 1).
        The output is a raw logit suitable for ``BCEWithLogitsLoss``.
    """
    weights = models.MobileNet_V3_Small_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.mobilenet_v3_small(weights=weights)

    # Width of the pooled feature vector that feeds the stock classifier.
    in_features = model.classifier[0].in_features if hasattr(model, "classifier") else 576

    # torchvision's MobileNetV3 forward pass already applies global average
    # pooling and flattening before calling ``self.classifier``, so the head
    # receives a (batch, in_features) tensor.  A further AdaptiveAvgPool2d on
    # that 2-D tensor would pool over the wrong axes (or be rejected), and
    # the first Linear layer would never see in_features inputs; hence the
    # head starts directly with dropout.
    model.classifier = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(in_features, 256),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(256, 1),  # logits for BCEWithLogitsLoss
    )
    return model

model = build_model(pretrained=True)
model.to(DEVICE)
print(model)

# Head-only phase: a parameter is trainable exactly when its name contains
# "classifier"; everything else (the backbone) starts frozen.
for name, param in model.named_parameters():
    param.requires_grad = "classifier" in name

# pos_weight for BCEWithLogitsLoss = (#negative / #positive) in the training
# split; a zero positive count is replaced by 1 to avoid dividing by zero.
# NOTE(review): training batches are already drawn through a class-balancing
# WeightedRandomSampler; confirm that adding pos_weight on top is intended.
neg = counts.get(0, 0)
pos = counts.get(1, 0) or 1
pos_weight = torch.tensor([neg / pos], dtype=torch.float32, device=DEVICE)
print("pos_weight for BCEWithLogitsLoss:", pos_weight.item())

# Loss
def focal_loss_with_logits(logits, targets, alpha=0.75, gamma=2.0, reduction="mean"):
    """Binary focal loss computed from raw logits.

    Per element: ``alpha_t * (1 - p_t) ** gamma * BCE(logit, target)`` where
    ``p_t`` is the predicted probability of the true class and ``alpha_t`` is
    ``alpha`` for targets equal to 1 and ``1 - alpha`` for targets equal to 0.

    Args:
        logits: Raw model outputs (no sigmoid applied).
        targets: 0/1 targets with the same number of elements as ``logits``;
            they are reshaped to the shape of ``logits``.
        alpha: Weight of the positive class.
        gamma: Focusing exponent; larger values down-weight easy examples.
        reduction: ``"mean"``, ``"sum"`` or ``"none"`` (per-element loss with
            the shape of ``logits``).

    Returns:
        A scalar tensor for ``"mean"`` / ``"sum"``, otherwise the
        per-element loss tensor.

    Raises:
        ValueError: If ``reduction`` is not one of the supported names.
    """
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"reduction must be 'mean', 'sum' or 'none', got {reduction!r}")

    # Align targets with the logits instead of forcing (-1, 1): avoids
    # accidental (B, 1) x (B,) broadcasting when the logits are 1-D.
    targets = targets.reshape(logits.shape)
    if not targets.is_floating_point():
        targets = targets.to(logits.dtype)

    prob = torch.sigmoid(logits)
    p_t = prob * targets + (1 - prob) * (1 - targets)
    alpha_factor = targets * alpha + (1 - targets) * (1 - alpha)
    focal_weight = alpha_factor * (1 - p_t) ** gamma
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    loss = focal_weight * bce

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss

# We'll choose criterion at runtime
# Select the loss once, up front.  With --use_focal the criterion stays None
# and the training loop calls focal_loss_with_logits directly.
if not args.use_focal:
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)  # handles imbalance
    print("Using BCEWithLogitsLoss with pos_weight")
else:
    criterion = None
    print("Using focal loss (custom)")

# Optimizers / schedulers
def make_optimizer(model, lr):
    """Build an Adam optimizer over the currently trainable parameters.

    Args:
        model: Module whose parameters are inspected.
        lr: Learning rate passed to Adam.

    Returns:
        ``torch.optim.Adam`` holding only parameters whose ``requires_grad``
        flag is True at the time of the call.
    """
    trainable = []
    for param in model.parameters():
        if param.requires_grad:
            trainable.append(param)
    return torch.optim.Adam(trainable, lr=lr)

# Scheduler: linear warmup then cosine decay (per-epoch)
def make_scheduler(optimizer, total_epochs, warmup_epochs):
    """Create a per-epoch LR schedule: linear warmup, then cosine decay.

    Args:
        optimizer: Optimizer whose learning rate is scaled.
        total_epochs: Number of scheduler steps over which the schedule runs.
        warmup_epochs: Number of initial epochs with a linearly rising factor.

    Returns:
        ``torch.optim.lr_scheduler.LambdaLR`` whose multiplicative factor is
        ``(epoch + 1) / warmup_epochs`` during warmup and afterwards follows
        half a cosine from 1 down to 0 at ``total_epochs``.  Stepping beyond
        ``total_epochs`` keeps the factor at 0.
    """
    def lr_lambda(epoch):
        if epoch < warmup_epochs:
            return float(epoch + 1) / float(max(1, warmup_epochs))
        # Fraction of the decay phase completed, clamped to 1 so that extra
        # steps cannot move onto the rising half of the cosine.
        span = max(1, total_epochs - warmup_epochs)
        t = min(1.0, (epoch - warmup_epochs) / span)
        return 0.5 * (1.0 + math.cos(math.pi * t))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# ---------------------------
# Training / Evaluation loops
# ---------------------------
# Gradient scaler used by the training loop for mixed precision; it is
# constructed disabled whenever DEVICE is not a CUDA device.
# NOTE(review): recent PyTorch releases mark torch.cuda.amp.GradScaler as
# deprecated in favour of torch.amp.GradScaler("cuda", ...) - confirm against
# the installed version before switching.
scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE.type == "cuda"))

def train_one_epoch(model, loader, optimizer, epoch, total_epochs):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: Network producing one logit per sample.
        loader: Iterable yielding ``(images, labels, paths)`` batches.
        optimizer: Optimizer updated once per batch through the module-level
            gradient ``scaler``.
        epoch: Accepted for call-site compatibility; not used here.
        total_epochs: Accepted for call-site compatibility; not used here.

    Returns:
        Sample-weighted mean training loss as a float, or ``nan`` when the
        loader yields no samples.
    """
    model.train()
    use_amp = DEVICE.type == "cuda"
    running_loss = 0.0
    n = 0
    for imgs, labels, _ in loader:
        imgs = imgs.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True).unsqueeze(1)  # shape (B,1)
        optimizer.zero_grad()
        # torch.autocast is the current spelling of torch.cuda.amp.autocast;
        # it is a no-op when mixed precision is disabled.
        with torch.autocast(device_type="cuda", enabled=use_amp):
            logits = model(imgs)
            if args.use_focal:
                loss = focal_loss_with_logits(logits, labels, alpha=args.focal_alpha, gamma=args.focal_gamma)
            else:
                loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        batch_size = imgs.size(0)
        running_loss += loss.item() * batch_size
        n += batch_size
    if n == 0:
        # Empty loader: report "no value" instead of raising ZeroDivisionError.
        return float("nan")
    return running_loss / n

@torch.no_grad()
def evaluate(model, loader):
    """Score ``model`` on ``loader`` and compute binary classification metrics.

    Args:
        model: Network producing one logit per sample.
        loader: Iterable yielding ``(images, labels, paths)`` batches.

    Returns:
        Dict with ``accuracy``, ``balanced_acc``, ``precision``, ``recall``,
        ``f1``, ``auc`` (``nan`` when it cannot be computed),
        ``confusion_matrix`` (always 2x2, rows/columns ordered 0 then 1),
        the arrays ``y_true`` / ``y_pred`` / ``y_prob`` and the list
        ``paths``.  Predictions use a 0.5 probability threshold.
    """
    model.eval()
    use_amp = DEVICE.type == "cuda"
    all_labels = []
    all_probs = []
    all_paths = []
    for imgs, labels, paths in loader:
        imgs = imgs.to(DEVICE, non_blocking=True)
        with torch.autocast(device_type="cuda", enabled=use_amp):
            logits = model(imgs)
        # Sigmoid in float32 so half-precision logits do not quantise the
        # probabilities used for thresholding and AUC.
        probs = torch.sigmoid(logits.float()).cpu().numpy().ravel()
        # Labels are only needed on the host; no device round trip.
        all_labels.extend(labels.cpu().numpy().ravel().tolist())
        all_probs.extend(probs.tolist())
        all_paths.extend(paths)

    # metrics
    y_true = np.array(all_labels)
    y_prob = np.array(all_probs)
    y_pred = (y_prob > 0.5).astype(int)
    acc = accuracy_score(y_true, y_pred)
    bal = balanced_accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    try:
        auc = roc_auc_score(y_true, y_prob)
    except ValueError:
        # Raised e.g. when only one class is present in y_true.
        auc = float('nan')
    # Fix the label set so the matrix stays 2x2 even if one class is absent.
    cm = confusion_matrix(y_true.astype(int), y_pred, labels=[0, 1])
    return {
        'accuracy': acc, 'balanced_acc': bal, 'precision': prec,
        'recall': rec, 'f1': f1, 'auc': auc, 'confusion_matrix': cm,
        'y_true': y_true, 'y_pred': y_pred, 'y_prob': y_prob, 'paths': all_paths
    }

# ---------------------------
# Full training orchestration
# ---------------------------
# Best validation F1 seen so far (None until the first epoch has been scored).
best_metric = None

# HEAD training
optimizer = make_optimizer(model, lr=args.lr_head)
# This optimizer/scheduler pair lives only for the head phase, so the
# warmup + cosine schedule spans exactly the head epochs.
scheduler = make_scheduler(optimizer, total_epochs=args.epochs_head, warmup_epochs=2)

print("Starting head training...")
for epoch in range(args.epochs_head):
    loss = train_one_epoch(model, train_loader, optimizer, epoch, args.epochs_head + args.epochs_finetune)
    scheduler.step()
    val_stats = evaluate(model, val_loader)
    print(f"Epoch {epoch+1}/{args.epochs_head} - train_loss: {loss:.4f}, val_f1: {val_stats['f1']:.4f}, val_acc: {val_stats['accuracy']:.4f}")
    # Save best by validation F1; the first scored epoch always counts so a
    # checkpoint exists even if F1 never rises above 0.
    if best_metric is None or val_stats['f1'] > best_metric:
        best_metric = val_stats['f1']
        torch.save({'model_state_dict': model.state_dict(), 'args': vars(args)}, args.save_path)
        print("Saved best model ->", args.save_path)

# Fine-tune: re-enable gradients for the tail of the backbone.
# Note that N counts parameter tensors (each weight / bias separately) in
# named_parameters() order, not whole layers.  Classifier parameters are left
# untouched.
print("Unfreezing last", args.unfreeze_last_n, "layers of backbone for fine-tuning...")
backbone_layers = [n for n, _ in model.named_parameters() if "classifier" not in n]
# Clamp N to [0, number of backbone tensors].
count_unf = max(0, min(args.unfreeze_last_n, len(backbone_layers)))
unfrozen_names = set(backbone_layers[len(backbone_layers) - count_unf:])
for name, p in model.named_parameters():
    if "classifier" not in name:
        # Backbone tensor: trainable only if it is among the last N.
        p.requires_grad = name in unfrozen_names

print("Trainable params after unfreeze:")
total_p = 0
trainable = 0
for p in model.parameters():
    total_p += p.numel()
    if p.requires_grad:
        trainable += p.numel()
print(f"{trainable} / {total_p} trainable parameters")

# Recreate optimizer with lower LR so it picks up the newly unfrozen params.
optimizer = make_optimizer(model, lr=args.lr_finetune)
# The new scheduler starts counting from step 0, so its horizon is the number
# of fine-tuning epochs; otherwise the cosine decay would never complete.
scheduler = make_scheduler(optimizer, total_epochs=args.epochs_finetune, warmup_epochs=2)

print("Starting fine-tuning...")
for epoch in range(args.epochs_head, args.epochs_head + args.epochs_finetune):
    loss = train_one_epoch(model, train_loader, optimizer, epoch, args.epochs_head + args.epochs_finetune)
    scheduler.step()
    val_stats = evaluate(model, val_loader)
    print(f"Fine Epoch {epoch+1}/{args.epochs_head+args.epochs_finetune} - train_loss: {loss:.4f}, val_f1: {val_stats['f1']:.4f}, val_acc: {val_stats['accuracy']:.4f}")
    # Keep the checkpoint with the best validation F1 across both phases.
    if best_metric is None or val_stats['f1'] > best_metric:
        best_metric = val_stats['f1']
        torch.save({'model_state_dict': model.state_dict(), 'args': vars(args)}, args.save_path)
        print("Saved best model ->", args.save_path)

# Persist the weights as they are after the last epoch.
final_path = "mobilenetv3_final.pt"
torch.save({'model_state_dict': model.state_dict(), 'args': vars(args)}, final_path)
print("Saved final model ->", final_path)

# Score the held-out test split.
# NOTE(review): this evaluates the last-epoch weights, not the best-F1
# checkpoint written to args.save_path - confirm which one should be reported.
test_stats = evaluate(model, test_loader)
print("=== TEST RESULTS ===")
for _title, _key in (
    ("Accuracy:", 'accuracy'),
    ("Balanced Acc:", 'balanced_acc'),
    ("Precision:", 'precision'),
    ("Recall:", 'recall'),
    ("F1:", 'f1'),
    ("AUC:", 'auc'),
):
    print(_title, test_stats[_key])
print("Confusion matrix:\n", test_stats['confusion_matrix'])

# Write one CSV row per test image: path, true label, prediction, probability.
pred_df = pd.DataFrame({
    column: test_stats[key]
    for column, key in (
        ('path', 'paths'),
        ('true', 'y_true'),
        ('pred', 'y_pred'),
        ('prob', 'y_prob'),
    )
})
pred_df.to_csv("pytorch_test_predictions.csv", index=False)
print("Saved predictions -> pytorch_test_predictions.csv")


AttributeError: partially initialized module 'torch._inductor' from 'C:\Users\Satyam\AppData\Roaming\Python\Python313\site-packages\torch\_inductor\__init__.py' has no attribute 'custom_graph_pass' (most likely due to a circular import)