# Robust Test-Time Adaptation (TTA) Full Benchmark
## USA (ASDID) -> India (MH-SoyaHealthVision)

This notebook implements the full benchmark suite for domain generalization in soybean disease classification. It compares direct transfer, Adaptive Batch Normalization (AdaBN), and Test-time Entropy Minimization (TENT).

### Benchmark Grid:
- **Seeds**: `[21, 42, 73]`
- **Models**: `convnext_tiny`, `resnet50`
- **Methods**: `Direct_Transfer`, `AdaBN`, `TENT_Online` (Robust TTA)

### Metrics:
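- **Macro F1**: Unweighted mean of the per-class F1 scores, so every class counts equally regardless of its frequency.
- **Rust Recall**: Fraction of true Rust samples that are detected. Critical for safety, since a missed Rust case is a False Negative.
- **Frogeye Precision**: Fraction of samples predicted as Frogeye that are actually Frogeye.
- **Rust as Healthy FN**: Number of Rust samples misclassified as Healthy in each run (the safety-hazard count).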

In [None]:
%load_ext autoreload
%autoreload 2

import torch
import os
import sys
import numpy as np
from tqdm.notebook import tqdm

# Add the parent of the working directory (the project root) to sys.path so the `src` package is importable
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from src.dataset import get_dataloaders
from src.model import get_model, apply_adabn, configure_for_tta
from src.tta import TTAOptimizer
from src.utils import (
    set_seed, 
    calculate_metrics, 
    log_experiment, 
    plot_confusion_matrix
)

## 1. Global Configuration

In [None]:
# Set to True for a fast smoke run (one seed, one model); False runs the full grid.
QUICK_CHECK = True

# Seeds and architectures swept by the benchmark.
if QUICK_CHECK:
    SEEDS = [42]
    MODELS = ['resnet50']
else:
    SEEDS = [21, 42, 73]
    MODELS = ['convnext_tiny', 'resnet50']
METHODS = ['Direct_Transfer', 'AdaBN', 'TENT_Online']

# Device preference order: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# Dataset roots and output locations, relative to the notebook's working directory.
DATA_ASDID = "../data/ASDID"
DATA_MH = "../data/MH-SoyaHealthVision/Soyabean_Leaf_Image_Dataset"
CLASS_NAMES = ['Healthy', 'Rust', 'Frogeye']
LOG_DIR = "../results/experiment_logs"
CM_DIR = os.path.join(LOG_DIR, "confusion_matrices")

# Creating CM_DIR also creates LOG_DIR, because CM_DIR is nested inside it.
os.makedirs(CM_DIR, exist_ok=True)
print(f"Device: {DEVICE}")
mode_label = 'Quick Check' if QUICK_CHECK else 'Full Benchmark'
print(f"Mode: {mode_label}")

## 2. Helper Components

In [None]:
def train_baseline(model, train_loader, val_loader, device, epochs=5):
    """Trains the baseline model on the source dataset (ASDID).

    Runs `epochs` passes of Adam (lr=1e-4) with cross-entropy loss over
    `train_loader`, evaluates on `val_loader` after every epoch, and finally
    restores the weights of the epoch with the highest validation
    ``metrics['F1']``.

    Args:
        model: Network to train; its parameters are updated in place.
        train_loader: Iterable of ``(inputs, labels)`` batches used for training.
        val_loader: Iterable of ``(inputs, labels)`` batches used for model
            selection.
        device: Device each batch is moved to before the forward pass.
        epochs: Number of training epochs.

    Returns:
        The same `model` object, loaded with the best validation checkpoint
        (left untouched if no epoch was run).
    """
    import copy  # function-local so the cell does not depend on another import cell

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    # Start below any achievable score so the first epoch is always captured.
    best_val_f1 = float("-inf")
    best_model_state = None

    for epoch in range(epochs):
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        metrics = calculate_metrics(all_labels, all_preds, CLASS_NAMES)
        if metrics['F1'] > best_val_f1:
            best_val_f1 = metrics['F1']
            # state_dict() hands back references to the live tensors, and a
            # shallow dict copy keeps those references, so later optimizer
            # steps would silently overwrite the "best" snapshot. Deep-copy
            # to freeze the weights as they are right now.
            best_model_state = copy.deepcopy(model.state_dict())

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model

## 3. Execution Grid

In [None]:
# Benchmark sweep: for every (architecture, seed) pair, train one source model
# on ASDID, then score each adaptation method on the MH test split.
for model_name in MODELS:
    for seed in SEEDS:
        run_id_base = f"{model_name}_Seed{seed}"
        banner = '=' * 60
        print(f"\n{banner}\nStarting Grid: {run_id_base}\n{banner}")
        set_seed(seed)

        # Source (ASDID) and target (MH) loaders are both built with this run's seed.
        asdid_loaders = get_dataloaders("ASDID", DATA_ASDID, seed=seed)
        train_loader_asdid, val_loader_asdid, test_loader_asdid, _, _, _, _ = asdid_loaders
        mh_loaders = get_dataloaders("MH", DATA_MH, seed=seed)
        train_loader_mh, _, test_loader_mh, _, _, _, _ = mh_loaders

        # Step 1: train the source model once; each method below starts from its weights.
        print("Training source model on ASDID...")
        source_epochs = 1 if QUICK_CHECK else 5
        base_model = get_model(model_name, num_classes=3).to(DEVICE)
        base_model = train_baseline(
            base_model,
            train_loader_asdid,
            val_loader_asdid,
            DEVICE,
            epochs=source_epochs,
        )

        for method in METHODS:
            print(f"\n>>> Evaluating Method: {method}")
            # Build a fresh network and load the source weights so that one
            # method's adaptation cannot leak into the next.
            model = get_model(model_name, num_classes=3).to(DEVICE)
            model.load_state_dict(base_model.state_dict())

            if method == 'TENT_Online':
                model = configure_for_tta(model, method='tent')
                tta_opt = TTAOptimizer(model, lr=1e-3, steps=1)
                # The MH train loader stands in for an online target stream.
                model = tta_opt.run_adaptation(train_loader_mh, DEVICE)
                adaptation = "TENT"
            elif method == 'AdaBN':
                model = apply_adabn(model, train_loader_mh, DEVICE)
                adaptation = "AdaBN"
            else:
                # Direct transfer: evaluate the source model as-is.
                adaptation = "None"

            # Step 2: collect predictions on the target-domain test split.
            model.eval()
            all_preds, all_labels = [], []
            with torch.no_grad():
                for inputs, labels in test_loader_mh:
                    inputs = inputs.to(DEVICE)
                    labels = labels.to(DEVICE)
                    outputs = model(inputs)
                    preds = torch.max(outputs, 1).indices
                    all_preds.extend(preds.cpu().numpy())
                    all_labels.extend(labels.cpu().numpy())

            metrics = calculate_metrics(all_labels, all_preds, CLASS_NAMES)

            # Step 3: save the confusion matrix and log the run.
            run_id = f"{run_id_base}_{method}"
            cm_path = os.path.join(CM_DIR, f"{run_id}_cm.png")
            plot_confusion_matrix(
                metrics['Confusion_Matrix'],
                CLASS_NAMES,
                title=f"CM: {run_id}",
                save_path=cm_path,
            )

            log_experiment(
                run_id,
                seed,
                model_name,
                "ASDID",
                "MH",
                metrics,
                adaptation=adaptation,
                log_dir=LOG_DIR,
            )

            print(f"Result [{method}]: Accuracy: {metrics['Accuracy']:.4f}, F1: {metrics['F1']:.4f}")
            print(f"Safety Check [Rust Recall]: {metrics['Rust_Recall']:.4f}, Rust->Healthy FN: {metrics['Rust_as_Healthy_FN']}")