# Imports 

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms,models
from torchvision.transforms import RandAugment
import matplotlib.pyplot as plt
import numpy as np
import random
from torch.optim import Adam
import torch.optim as optim
import mlflow
import mlflow.pytorch
from sklearn.metrics import classification_report, confusion_matrix
import copy
import time
from torch.utils.data import DataLoader, random_split, Subset



# Device: use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


In [None]:
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs so runs are repeatable.

    When CUDA is available this also seeds the current GPU generator and
    switches cuDNN to deterministic kernels with autotuning disabled.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)



# Load and Transform Data

In [None]:


CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD  = [0.2470, 0.2435, 0.2616]


class Cutout:
    """Erase ``n_holes`` random square patches of an image tensor (set them to 0).

    Defined at module level rather than inside ``get_transforms`` so the
    transform can be pickled when DataLoader workers are started with the
    ``spawn`` method (a class defined inside a function cannot be pickled).

    Args:
        n_holes: Number of patches erased per image.
        length: Side length of each square patch in pixels; patches centred
            near the border are clipped to the image.
    """

    def __init__(self, n_holes=1, length=16):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        """Return a copy of ``img`` (C x H x W tensor) with the patches zeroed."""
        h, w = img.shape[-2], img.shape[-1]
        mask = torch.ones((h, w), dtype=img.dtype, device=img.device)
        half = self.length // 2
        for _ in range(self.n_holes):
            # torch's RNG is seeded per DataLoader worker, so workers draw different holes.
            y = int(torch.randint(h, (1,)))
            x = int(torch.randint(w, (1,)))
            y1, y2 = max(y - half, 0), min(y + half, h)
            x1, x2 = max(x - half, 0), min(x + half, w)
            mask[y1:y2, x1:x2] = 0.0
        return img * mask.expand_as(img)


def get_transforms(use_randaugment=False, use_cutout=False, cutout_length=16):
    """Return ``(train_transform, val_test_transform)`` for CIFAR-10 images.

    Both pipelines resize to 256 and crop to 224x224 (random crop for
    training, centre crop for evaluation), then convert to a tensor and
    normalize with the CIFAR-10 channel statistics. Training additionally
    applies a random horizontal flip and, optionally, RandAugment and Cutout.

    Args:
        use_randaugment: Add ``transforms.RandAugment()`` (applied to the PIL
            image, before ``ToTensor``).
        use_cutout: Add one ``Cutout`` hole after normalization.
        cutout_length: Side length in pixels of the Cutout hole.
    """
    # --- Training transforms (with augmentation) ---
    train_tf_list = [
        transforms.Resize(256),
        transforms.RandomCrop(224),           # crop to 224x224
        transforms.RandomHorizontalFlip(),
    ]
    if use_randaugment:
        train_tf_list.append(transforms.RandAugment())
    train_tf_list.append(transforms.ToTensor())
    train_tf_list.append(transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD))
    if use_cutout:
        # Applied after Normalize (as in the reference Cutout implementation), so
        # erased pixels equal the dataset mean (0) instead of black pixels that
        # normalization would shift to -mean/std.
        train_tf_list.append(Cutout(n_holes=1, length=cutout_length))
    train_transform = transforms.Compose(train_tf_list)

    # --- Validation/Test transforms (no augmentation) ---
    val_test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ])

    return train_transform, val_test_transform

def get_dataloaders(batch_size=128, use_randaugment=False, use_cutout=False, num_workers=4,
                    val_split=0.1, seed=42):
    """Build CIFAR-10 train / validation / test DataLoaders.

    The CIFAR-10 training split is randomly partitioned into train and
    validation subsets. Two dataset objects over the same files are used, so
    training samples get the augmenting transform while validation samples
    get the deterministic val/test transform.

    Args:
        batch_size: Batch size for all three loaders.
        use_randaugment: Forwarded to ``get_transforms``.
        use_cutout: Forwarded to ``get_transforms``.
        num_workers: Worker processes per loader.
        val_split: Fraction of the training split held out for validation;
            must be in [0, 1).
        seed: Seed for the train/val split. A dedicated generator makes the
            split identical across calls regardless of global RNG state.

    Returns:
        ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: if ``val_split`` is outside [0, 1).
    """
    if not 0 <= val_split < 1:
        raise ValueError(f"val_split must be in [0, 1), got {val_split}")

    train_tf, val_test_tf = get_transforms(use_randaugment, use_cutout)

    # Full train split with augmentation, and the same files with the clean transform
    full_train = datasets.CIFAR10("./data", train=True, download=True, transform=train_tf)
    full_val = datasets.CIFAR10("./data", train=True, download=False, transform=val_test_tf)
    test_set = datasets.CIFAR10("./data", train=False, download=True, transform=val_test_tf)

    n = len(full_train)
    val_size = int(n * val_split)
    train_part, val_part = random_split(
        list(range(n)), [n - val_size, val_size],
        generator=torch.Generator().manual_seed(seed),
    )
    # .indices are plain lists of sample ids into the CIFAR-10 train split
    train_set = Subset(full_train, train_part.indices)
    val_set = Subset(full_val, val_part.indices)

    # Page-locked host memory speeds up host->GPU copies; irrelevant on CPU.
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers,
                         pin_memory=torch.cuda.is_available())
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    val_loader   = DataLoader(val_set, shuffle=False, **loader_kwargs)
    test_loader  = DataLoader(test_set, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader


# Model

In [None]:
# Pretrained ResNet-18 whose classification head is replaced by a freshly
# initialised linear layer with one output per CIFAR-10 class.
num_classes = 10  # CIFAR-10
model_base = models.resnet18(pretrained=True)
model_base.fc = nn.Linear(in_features=model_base.fc.in_features, out_features=num_classes)


# Model Training

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean per-sample loss, accuracy)``.

    With an ``optimizer`` the model is put in train mode and updated after
    every batch; without one it is put in eval mode with gradients disabled.

    Raises:
        ValueError: if ``loader`` yields no samples (would divide by zero).
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # Weight by batch size so a smaller final batch is averaged correctly.
            loss_sum += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    if total == 0:
        raise ValueError("DataLoader yielded no samples")
    return loss_sum / total, correct / total


def train_model_mlflow(model, train_loader, val_loader, epochs=5, lr=1e-3, device='cuda',
                       experiment_name="resnet_experiments", optimizer_type="SGD"):
    """Train ``model`` with cross-entropy and track the run in MLflow.

    Only parameters with ``requires_grad=True`` are optimized, so freezing
    layers beforehand selects what gets fine-tuned. Per-epoch train/val loss
    and accuracy are printed, logged as MLflow metrics and returned. The final
    model is logged to MLflow as ``final_model``.

    Args:
        model: The network to train (moved to ``device`` in place).
        train_loader: Training batches.
        val_loader: Validation batches.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate.
        device: Device to train on.
        experiment_name: MLflow experiment the run is recorded under.
        optimizer_type: ``"SGD"`` (momentum 0.9) or ``"Adam"``, case-insensitive;
            both use weight decay 5e-4.

    Returns:
        ``(model, history)``, where ``history`` maps ``'train_loss'``,
        ``'val_loss'``, ``'train_acc'`` and ``'val_acc'`` to per-epoch lists.

    Raises:
        ValueError: for an unsupported ``optimizer_type``. This is checked before
            any MLflow call, so no failed run is left behind.
    """
    opt_key = optimizer_type.lower()
    if opt_key not in ("sgd", "adam"):
        raise ValueError("Unsupported optimizer type")

    mlflow.set_experiment(experiment_name)
    with mlflow.start_run(run_name=f"training_run_{int(time.time())}"):
        model.to(device)
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        mlflow.log_params({
            "epochs": epochs,
            "lr": lr,
            "optimizer": optimizer_type,
            # Distinguishes the freezing strategies when comparing runs.
            "trainable_params": sum(p.numel() for p in trainable_params),
        })

        criterion = nn.CrossEntropyLoss()
        if opt_key == "sgd":
            optimizer = optim.SGD(trainable_params, lr=lr, momentum=0.9, weight_decay=5e-4)
        else:
            optimizer = optim.Adam(trainable_params, lr=lr, weight_decay=5e-4)

        history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
        start_time = time.time()

        for epoch in range(epochs):
            train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
            val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['train_acc'].append(train_acc)
            history['val_acc'].append(val_acc)

            print(f"Epoch [{epoch+1}/{epochs}] Train Loss: {train_loss:.4f}, "
                  f"Train Acc: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

            mlflow.log_metrics({"train_loss": train_loss, "train_acc": train_acc,
                                "val_loss": val_loss, "val_acc": val_acc}, step=epoch)

        total_time = time.time() - start_time
        mlflow.log_metric("training_time_sec", total_time)
        mlflow.pytorch.log_model(model, "final_model")
        print(f"Model saved to MLflow. Training time: {total_time:.2f} sec")

    return model, history


In [None]:
# Build the loaders once, before any training cell; every run below reuses them.
train_loader, val_loader, test_loader = get_dataloaders(
    val_split=0.1,
    num_workers=4,
    use_cutout=False,
    use_randaugment=True,
    batch_size=128,
)


# Full fine-tuning (all layers unfrozen)

In [None]:
# Full fine-tuning: start from a private copy of the base model and let every parameter train.
model_unfrozen = copy.deepcopy(model_base)
model_unfrozen.requires_grad_(True)

trained_model_unfrozen, history_unfrozen = train_model_mlflow(
    model_unfrozen,
    train_loader,
    val_loader,
    epochs=20,
    lr=1e-3,
    device=device,
    experiment_name="resnet_cifar10",
)

torch.save(trained_model_unfrozen.state_dict(), "resnet_unfrozen.pth")


# Train only the final FC layer (backbone frozen)

In [None]:
# Linear probe: freeze the whole copied network, then re-enable only the classifier head.
model_partial = copy.deepcopy(model_base)
model_partial.requires_grad_(False)
model_partial.fc.requires_grad_(True)

trained_model_fc, history_fc = train_model_mlflow(
    model_partial,
    train_loader,
    val_loader,
    epochs=20,
    lr=1e-3,
    device=device,
    experiment_name="resnet_cifar10",
)

torch.save(trained_model_fc.state_dict(), "resnet_partial_freeze.pth")


# Unfreeze layer4 and the FC layer

In [None]:
import torch
import torch.nn as nn
import copy

# -----------------------------------------------------
# Partial fine-tuning: a fresh copy of model_base with only
# layer4 and the fc head trainable (everything else frozen)
# -----------------------------------------------------
model_partial = copy.deepcopy(model_base)
model_partial.requires_grad_(False)
model_partial.layer4.requires_grad_(True)
model_partial.fc.requires_grad_(True)

# -----------------------------------------------------
# Train with MLflow logging (same LR as the other two runs)
# -----------------------------------------------------
trained_model_partial, history_partial = train_model_mlflow(
    model_partial,
    train_loader,
    val_loader,
    epochs=20,
    lr=1e-3,
    device=device,
    experiment_name="resnet_cifar10_partial_unfreeze",
)

# -----------------------------------------------------
# Persist the fine-tuned weights
# -----------------------------------------------------
torch.save(trained_model_partial.state_dict(), "resnet_partial_unfreeze.pth")

print("✔ Partial-unfreeze training complete and model saved.")


In [None]:
# Training curves for the three fine-tuning strategies.
# history_partial comes from the layer4 + FC run and history_fc from the
# FC-only run, so each history is paired with its matching legend label here.
curve_sets = [
    ("Unfrozen", history_unfrozen),
    ("Layer4 + FC", history_partial),
    ("FC Only", history_fc),
]

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(14, 6))
for tag, hist in curve_sets:
    # 1-based epochs, matching the "Epoch [k/N]" lines printed during training
    epoch_axis = range(1, len(hist['train_loss']) + 1)
    ax_loss.plot(epoch_axis, hist['train_loss'], label=f'Train Loss ({tag})')
    ax_loss.plot(epoch_axis, hist['val_loss'], label=f'Val Loss ({tag})')
    ax_acc.plot(epoch_axis, hist['train_acc'], label=f'Train Acc ({tag})')
    ax_acc.plot(epoch_axis, hist['val_acc'], label=f'Val Acc ({tag})')

# ---------------------------
# Loss Curves
# ---------------------------
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Loss Curves')
ax_loss.legend()

# ---------------------------
# Accuracy Curves
# ---------------------------
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.set_title('Accuracy Curves')
ax_acc.legend()

plt.tight_layout()
plt.show()


In [None]:
import torch
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, accuracy_score
import seaborn as sns
import matplotlib.pyplot as plt

# -----------------------------
# Function to evaluate metrics
# -----------------------------
def evaluate_model(model, data_loader, device):
    """Print accuracy and per-class / averaged P-R-F1, and plot a confusion matrix.

    Class names, and therefore the label set ``0 .. K-1``, are read from the
    loader's dataset, unwrapping a ``torch.utils.data.Subset`` if needed.

    Returns:
        ``(accuracy, precision, recall, f1)``. The last three are per-class
        arrays ordered by label index.
    """
    dataset = data_loader.dataset
    if isinstance(dataset, torch.utils.data.Subset):
        dataset = dataset.dataset
    class_names = dataset.classes
    # A fixed label set keeps the metric arrays and the confusion matrix aligned
    # with class_names, even if a class is missing from both y_true and y_pred.
    label_ids = list(range(len(class_names)))

    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in data_loader:
            outputs = model(images.to(device))
            y_pred.extend(outputs.argmax(dim=1).cpu().numpy())
            y_true.extend(labels.cpu().numpy())

    # Overall accuracy
    acc = accuracy_score(y_true, y_pred)
    print(f"\nTest Accuracy: {acc:.4f}")

    # Per-class metrics; zero_division=0 gives sklearn's default value without the warning
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=label_ids, average=None, zero_division=0
    )

    print("\nPer-class Metrics:")
    for i, class_name in enumerate(class_names):
        print(f"{class_name}: Precision={precision[i]:.3f}, Recall={recall[i]:.3f}, F1={f1[i]:.3f}")

    # Macro and Weighted F1 (only the F1 component is reported)
    _, _, f1_macro, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=label_ids, average='macro', zero_division=0
    )
    _, _, f1_weighted, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=label_ids, average='weighted', zero_division=0
    )
    print(f"\nMacro F1: {f1_macro:.3f}, Weighted F1: {f1_weighted:.3f}")

    # Confusion matrix, rows/cols in label_ids order to match the tick labels
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)
    plt.figure(figsize=(10,8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()

    return acc, precision, recall, f1

# -----------------------------
# Test-set evaluation of each fine-tuning variant
# -----------------------------
# Each call returns (accuracy, precision, recall, f1), not only the accuracy.
print("Last FC layer unfreeze:")
acc_resnet_fc = evaluate_model(model=trained_model_fc, data_loader=test_loader, device=device)

print("Partial Layer4 + FC unfreeze:")
acc_resnet_partial = evaluate_model(model=trained_model_partial, data_loader=test_loader, device=device)

print("Fully unfrozen Fine-tuned ResNet:")
acc_resnet_unfrozen = evaluate_model(model=trained_model_unfrozen, data_loader=test_loader, device=device)



In [None]:
import matplotlib.pyplot as plt
import torch

# NOTE(review): this rebinds the list constants used by get_transforms to
# (3, 1, 1) tensors; any later call to get_transforms receives these tensors.
CIFAR10_MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(3,1,1)
CIFAR10_STD  = torch.tensor([0.2470, 0.2435, 0.2616]).view(3,1,1)


def unnormalize(img):
    """Map a normalized 3 x H x W CPU tensor back to [0, 1] for display."""
    return torch.clamp(img * CIFAR10_STD + CIFAR10_MEAN, 0, 1)


def get_class_names(loader):
    """Return the class-name list for ``loader``'s dataset.

    Looks at ``dataset.classes``, then at ``dataset.dataset.classes`` (e.g. a
    ``Subset``), and otherwise falls back to ``["0", ..., "9"]``.
    """
    ds = loader.dataset

    if hasattr(ds, "classes"):
        return ds.classes

    if hasattr(ds, "dataset") and hasattr(ds.dataset, "classes"):
        return ds.dataset.classes

    return [str(i) for i in range(10)]


def show_misclassifications(model, data_loader, device, max_images=20):
    """Plot up to ``max_images`` misclassified samples with predicted and true labels.

    Images are shown in a 5-column grid with as many rows as needed, so any
    ``max_images`` works (a fixed 4 x 5 grid fails beyond 20 images).
    """
    class_names = get_class_names(data_loader)

    model.eval()
    wrong_images, wrong_preds, wrong_labels = [], [], []

    with torch.no_grad():
        for imgs, labels in data_loader:
            if len(wrong_images) >= max_images:
                break
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(dim=1)

            wrong_idx = (preds != labels).nonzero(as_tuple=True)[0]
            # Take only as many as are still needed from this batch.
            remaining = max_images - len(wrong_images)
            for idx in wrong_idx[:remaining].tolist():
                wrong_images.append(imgs[idx].cpu())
                wrong_preds.append(preds[idx].item())
                wrong_labels.append(labels[idx].item())

    if len(wrong_images) == 0:
        print("❗No misclassified images found.")
        return

    cols = 5
    rows = (len(wrong_images) + cols - 1) // cols
    plt.figure(figsize=(3 * cols, 2 * rows))
    for i, (image, pred, true) in enumerate(zip(wrong_images, wrong_preds, wrong_labels)):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(unnormalize(image).permute(1, 2, 0).numpy())
        plt.title(f"Pred: {class_names[pred]}\nTrue: {class_names[true]}",
                  fontsize=9)
        plt.axis("off")

    plt.tight_layout()
    plt.show()


# Misclassification Report

In [None]:

# train_loader wraps a Subset, hence the extra .dataset hop to reach .classes.
# (show_misclassifications resolves class names itself from the loader it is given.)
class_names = train_loader.dataset.dataset.classes
show_misclassifications(model=trained_model_unfrozen, data_loader=test_loader, device=device)


In [None]:
import math
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# Undo the dataset normalization so images can be displayed.
def unnormalize(tensor, mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)):
    """Return ``tensor * std + mean`` per channel, clamped to [0, 1].

    Expects a C x H x W tensor with one ``mean``/``std`` entry per channel.
    A new tensor is returned and the input is left unmodified. (In-place
    ``mul_``/``add_`` would alter the caller's tensor whenever ``.cpu()``
    returns the same storage, i.e. for CPU tensors.) Tuple defaults avoid
    shared mutable default arguments.
    """
    mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    return (tensor * std_t + mean_t).clamp(0, 1)

def apply_gradcam_correct_vs_wrong(model, data_loader, class_names, device, max_images=20, alpha=0.3):
    """Show Grad-CAM overlays for up to ``max_images`` correct and ``max_images`` wrong predictions.

    Each heatmap is computed for the *predicted* class at ``model.layer4[-1]``,
    so ``model`` is assumed to be a ResNet-style network exposing ``layer4``.
    Images are unnormalized for display and the heatmap is blended on top with
    opacity ``alpha``.

    Args:
        model: Classifier exposing ``layer4``.
        data_loader: Yields ``(normalized image batch, label batch)``.
        class_names: Maps label index to display name.
        device: Device the model lives on.
        max_images: Cap for each group (correct / misclassified).
        alpha: Heatmap opacity.
    """
    model.eval()
    target_layer = model.layer4[-1]  # last block of layer4
    cam = GradCAM(model=model, target_layers=[target_layer])

    correct_images, correct_cams, correct_preds = [], [], []
    wrong_images, wrong_cams, wrong_preds, wrong_labels = [], [], [], []

    def both_full():
        return len(correct_images) >= max_images and len(wrong_images) >= max_images

    for images, labels in data_loader:
        if both_full():
            break
        images = images.to(device)
        # One batched, gradient-free forward pass decides which images are still
        # needed; Grad-CAM (which runs its own forward/backward) is then invoked
        # only for those, not for every image in the loader.
        with torch.no_grad():
            batch_preds = model(images).argmax(dim=1).tolist()
        batch_labels = labels.tolist()

        for i, (pred, true_label) in enumerate(zip(batch_preds, batch_labels)):
            if both_full():
                break
            is_correct = pred == true_label
            if len(correct_images if is_correct else wrong_images) >= max_images:
                continue  # this group is already full

            img = images[i].unsqueeze(0)  # 1xCxHxW
            grayscale_cam = cam(input_tensor=img, targets=[ClassifierOutputTarget(pred)])[0]

            # Resize the CAM to this image's spatial size rather than a fixed 224x224.
            cam_tensor = torch.as_tensor(grayscale_cam)[None, None]  # 1x1xhxw
            cam_resized = F.interpolate(cam_tensor, size=tuple(img.shape[-2:]),
                                        mode='bilinear', align_corners=False)[0, 0].numpy()

            # Unnormalize for plotting
            img_np = unnormalize(images[i].cpu()).permute(1, 2, 0).numpy()

            if is_correct:
                correct_images.append(img_np)
                correct_cams.append(cam_resized)
                correct_preds.append(pred)
            else:
                wrong_images.append(img_np)
                wrong_cams.append(cam_resized)
                wrong_preds.append(pred)
                wrong_labels.append(true_label)

    def plot_grid(images_list, cams_list, preds_list, labels_list=None, title=""):
        """Plot images with CAM overlays in a 5-column grid; add true labels if given."""
        cols = 5
        rows = math.ceil(len(images_list) / cols)
        # squeeze=False keeps `axes` 2-D even for a single row.
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)
        axes = axes.flatten()
        for ax in axes:
            ax.axis('off')  # also hides the unused trailing cells
        for idx in range(len(images_list)):
            axes[idx].imshow(images_list[idx])
            axes[idx].imshow(cams_list[idx], cmap='jet', alpha=alpha)
            if labels_list is not None:
                axes[idx].set_title(f"Pred: {class_names[preds_list[idx]]}\nTrue: {class_names[labels_list[idx]]}", fontsize=9)
            else:
                axes[idx].set_title(f"Pred: {class_names[preds_list[idx]]}", fontsize=9)
        plt.suptitle(title, fontsize=14)
        plt.tight_layout()
        plt.show()

    # Plot correct predictions
    if correct_images:
        plot_grid(correct_images, correct_cams, correct_preds, title="Grad-CAM: Correctly Classified Images")

    # Plot misclassified predictions with true labels
    if wrong_images:
        plot_grid(wrong_images, wrong_cams, wrong_preds, labels_list=wrong_labels,
                  title="Grad-CAM: Misclassified Images")

# Grad-CAM

In [None]:
# test_loader wraps the CIFAR10 dataset directly, so .classes needs no unwrapping.
class_names = test_loader.dataset.classes
apply_gradcam_correct_vs_wrong(model=trained_model_unfrozen, data_loader=test_loader,
                               class_names=class_names, device=device)


# Adversarial Sanity Check


In [None]:
import torch
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
import matplotlib.pyplot as plt
import numpy as np

# -----------------------------
# Function to compute metrics on original and FGSM adversarial examples
# -----------------------------
def adversarial_metrics(model, data_loader, device, epsilon=0.03,
                        mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616)):
    """Compare weighted metrics on clean inputs vs. FGSM-perturbed inputs.

    The loaders in this notebook yield *normalized* tensors ``(x - mean) / std``,
    so the attack is carried out in that space:

    * ``epsilon`` is an L-inf budget in [0, 1] pixel units, converted to a
      per-channel step of ``epsilon / std``;
    * adversarial images are clamped per channel to the normalized range of
      valid pixels, ``[(0 - mean) / std, (1 - mean) / std]``. Clamping
      normalized tensors to [0, 1] would zero every below-mean value and
      measure clipping damage rather than the attack.

    Returns:
        ``((acc, precision, recall, f1) on clean inputs, same on adversarial)``,
        with precision/recall/F1 weighted by class support.
    """
    model.eval()
    criterion = torch.nn.CrossEntropyLoss()
    y_true, y_pred_original, y_pred_adv = [], [], []

    # Per-channel constants shaped (1, C, 1, 1) to broadcast over NCHW batches.
    mean_t = torch.as_tensor(mean, dtype=torch.float32, device=device).view(1, -1, 1, 1)
    std_t = torch.as_tensor(std, dtype=torch.float32, device=device).view(1, -1, 1, 1)
    step = epsilon / std_t            # pixel-space epsilon in normalized units
    lower = (0.0 - mean_t) / std_t    # normalized value of a 0 pixel
    upper = (1.0 - mean_t) / std_t    # normalized value of a 1 pixel

    for images, labels in data_loader:
        images = images.to(device).requires_grad_(True)
        labels = labels.to(device)

        # --- Forward on original ---
        outputs = model(images)
        preds_orig = outputs.argmax(dim=1)

        # --- FGSM attack ---
        loss = criterion(outputs, labels)
        # Gradient w.r.t. the input only; model parameter .grad buffers stay
        # untouched and the graph is freed immediately afterwards.
        (input_grad,) = torch.autograd.grad(loss, images)

        # --- Forward on adversarial ---
        with torch.no_grad():
            adv_images = torch.minimum(torch.maximum(images + step * input_grad.sign(), lower), upper)
            preds_adv = model(adv_images).argmax(dim=1)

        # --- Collect labels and predictions ---
        y_true.extend(labels.cpu().numpy())
        y_pred_original.extend(preds_orig.cpu().numpy())
        y_pred_adv.extend(preds_adv.cpu().numpy())

    # Metrics calculation; zero_division=0 matches sklearn's default value without warnings
    def calc_metrics(y_true, y_pred):
        acc = accuracy_score(y_true, y_pred)
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='weighted', zero_division=0)
        return acc, precision, recall, f1

    return calc_metrics(y_true, y_pred_original), calc_metrics(y_true, y_pred_adv)

# Compute for all models.
# Named adv_models (not `models`) so the torchvision `models` module imported
# at the top of the notebook is not shadowed by this dict.
adv_models = {
    "Fully Unfrozen": trained_model_unfrozen,
    "Partial Layer4+FC": trained_model_partial,
    "Last FC Only": trained_model_fc
}

metrics_results = {}

for name, model in adv_models.items():
    orig_metrics, adv_metrics = adversarial_metrics(model, test_loader, device, epsilon=0.03)
    metrics_results[name] = {"original": orig_metrics, "adversarial": adv_metrics}
    print(f"\n{name}:\nOriginal: {orig_metrics}\nAdversarial: {adv_metrics}")

# Plot comparison graphs: one panel per metric, clean vs. FGSM bars for each model
metrics_names = ['Accuracy', 'Precision', 'Recall', 'F1']
x = np.arange(len(adv_models))
width = 0.35

plt.figure(figsize=(16,6))
for i, metric in enumerate(metrics_names):
    plt.subplot(1,4,i+1)
    orig_vals = [metrics_results[m]['original'][i] for m in adv_models]
    adv_vals = [metrics_results[m]['adversarial'][i] for m in adv_models]

    plt.bar(x - width/2, orig_vals, width, label='Original')
    plt.bar(x + width/2, adv_vals, width, label='Adversarial')
    plt.xticks(x, list(adv_models), rotation=15)
    plt.ylabel(metric)
    plt.ylim(0,1)
    plt.title(f"{metric} Comparison")
    plt.legend()

plt.tight_layout()
plt.show()


# Calibration

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# -----------------------------
# Expected Calibration Error of a model over a loader
# -----------------------------
def compute_ece(model, data_loader, device, n_bins=15):
    """
    Expected Calibration Error (ECE) of ``model`` on ``data_loader``.

    Confidence is the top softmax probability. Samples are grouped into
    ``n_bins`` equal-width bins over (0, 1], and ECE is the sample-weighted sum
    of |mean confidence - accuracy| over the non-empty bins.

    Returns:
        ``(ece, confidences, predictions, labels)``; the last three are NumPy
        arrays with one entry per sample.
    """
    model.eval()
    conf_acc, pred_acc, label_acc = [], [], []

    with torch.no_grad():
        for batch_images, batch_labels in data_loader:
            probabilities = F.softmax(model(batch_images.to(device)), dim=1)
            top_conf, top_pred = probabilities.max(dim=1)
            conf_acc.extend(top_conf.cpu().numpy())
            pred_acc.extend(top_pred.cpu().numpy())
            label_acc.extend(batch_labels.numpy())

    confidences, predictions, labels_list = map(np.array, (conf_acc, pred_acc, label_acc))

    edges = np.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for lower, upper in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lower) & (confidences <= upper)
        if not np.any(in_bin):
            continue
        bin_accuracy = (predictions[in_bin] == labels_list[in_bin]).mean()
        bin_confidence = confidences[in_bin].mean()
        ece += np.abs(bin_confidence - bin_accuracy) * in_bin.mean()

    return ece, confidences, predictions, labels_list

# -----------------------------
# Plot reliability diagram
# -----------------------------
def plot_reliability_diagram(confidences, predictions, labels, n_bins=15, title="Reliability Diagram"):
    """Plot per-bin accuracy against mean confidence, with the diagonal as reference.

    Only non-empty bins are plotted. Empty bins would otherwise appear as
    (0, 0) points that pull the curve to the origin and distort the shaded
    calibration gap.
    """
    confidences = np.asarray(confidences)
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)

    bin_edges = np.linspace(0, 1, n_bins + 1)
    acc_bins, conf_bins = [], []
    for lower, upper in zip(bin_edges[:-1], bin_edges[1:]):
        in_bin = (confidences > lower) & (confidences <= upper)
        if not np.any(in_bin):
            continue
        acc_bins.append((predictions[in_bin] == labels[in_bin]).mean())
        conf_bins.append(confidences[in_bin].mean())

    plt.figure(figsize=(6,6))
    plt.plot([0,1],[0,1], linestyle='--', color='gray', label='Perfect calibration')
    plt.plot(conf_bins, acc_bins, marker='o', label='Model')
    # Shaded area between the curve and the diagonal = calibration gap
    plt.fill_between(conf_bins, acc_bins, conf_bins, alpha=0.2, color='blue')
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.xlabel('Confidence')
    plt.ylabel('Accuracy')
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.show()

# -----------------------------
# Compute & plot for all models
# -----------------------------
# Named calib_models (not `models`) so the torchvision `models` module imported
# at the top of the notebook stays usable after this cell runs.
calib_models = {
    "Fully Unfrozen": trained_model_unfrozen,
    "Partial Layer4+FC": trained_model_partial,
    "Last FC Only": trained_model_fc
}

for name, model in calib_models.items():
    ece, confs, preds, labels = compute_ece(model, test_loader, device, n_bins=15)
    print(f"{name} ECE: {ece:.4f}")
    plot_reliability_diagram(confs, preds, labels, n_bins=15, title=f"{name} Reliability Diagram")


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from sklearn.calibration import calibration_curve

# -----------------------------
# Expected Calibration Error from precomputed probabilities
# -----------------------------
def compute_ece(probs, labels, n_bins=15):
    """
    Expected Calibration Error (ECE) from class probabilities.

    probs: [N, C] tensor of predicted probabilities
    labels: [N] tensor of true class indices
    n_bins: number of equal-width confidence bins over (0, 1]

    Returns the sum over non-empty bins of
    (fraction of samples in bin) * |bin accuracy - bin mean confidence|.
    """
    top_conf, top_pred = probs.max(dim=1)
    conf = top_conf.cpu().numpy()
    pred = top_pred.cpu().numpy()
    truth = labels.cpu().numpy()
    n_total = len(truth)

    edges = np.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        sel = (conf > lo) & (conf <= hi)
        n_sel = np.sum(sel)
        if n_sel == 0:
            continue
        gap = np.abs(np.mean(pred[sel] == truth[sel]) - np.mean(conf[sel]))
        ece += n_sel / n_total * gap
    return ece

# -----------------------------
# Temperature scaling wrapper
# -----------------------------
class ModelWithTemperature(nn.Module):
    """Wrap a classifier and divide its logits by a learned scalar temperature.

    For a positive temperature the argmax is unchanged, so accuracy is
    unaffected; only the confidences are rescaled. ``temperature`` starts at 1
    (identity) and is fitted by ``set_temperature`` on held-out data.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.temperature = nn.Parameter(torch.ones(1))  # initialized as 1

    def forward(self, x):
        """Return the wrapped model's logits divided by the temperature."""
        return self.temperature_scale(self.model(x))

    def temperature_scale(self, logits):
        """Divide ``logits`` by the scalar temperature (broadcast over all entries)."""
        return logits / self.temperature

    def set_temperature(self, valid_loader, device):
        """Fit the temperature by minimizing NLL on ``valid_loader``; return ``self``.

        Logits are computed once with gradients disabled, so only the scalar
        temperature is optimized. LBFGS uses a strong-Wolfe line search: with a
        fixed step of ``lr=0.01`` and no line search, each iteration moves only
        1% of the quasi-Newton step, so 50 iterations can stop well short of
        the NLL minimum.
        """
        self.to(device)
        nll_criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.LBFGS([self.temperature], lr=1.0, max_iter=50,
                                      line_search_fn="strong_wolfe")

        # Collect logits and labels once, without gradients through the model
        logits_list = []
        labels_list = []
        self.model.eval()
        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.to(device), labels.to(device)
                logits_list.append(self.model(images))
                labels_list.append(labels)

        logits_all = torch.cat(logits_list)
        labels_all = torch.cat(labels_list)

        # Named `closure` rather than `eval` to avoid shadowing the builtin.
        def closure():
            optimizer.zero_grad()
            loss = nll_criterion(self.temperature_scale(logits_all), labels_all)
            loss.backward()
            return loss

        optimizer.step(closure)
        print(f"Optimal temperature: {self.temperature.item():.3f}")
        return self

# -----------------------------
# Calibration evaluation: ECE + per-class reliability curves
# -----------------------------
def evaluate_model_with_calibration(model, data_loader, device, n_bins=15):
    """Print the ECE of ``model`` on ``data_loader`` and plot per-class reliability curves.

    Softmax probabilities are gathered for the whole loader. The ECE uses the
    top-class confidence; the plot shows one sklearn ``calibration_curve`` per
    class (one-vs-rest).

    Returns:
        The ECE value.
    """
    model.eval()
    prob_batches, label_batches = [], []

    with torch.no_grad():
        for images, labels in data_loader:
            logits = model(images.to(device))
            prob_batches.append(F.softmax(logits, dim=1))
            label_batches.append(labels.to(device))

    all_probs = torch.cat(prob_batches)
    all_labels = torch.cat(label_batches)

    ece = compute_ece(all_probs, all_labels, n_bins=n_bins)
    print(f"ECE: {ece:.4f}")

    labels_np = all_labels.cpu().numpy()
    probs_np = all_probs.cpu().numpy()

    # Reliability diagram: one curve per class
    plt.figure(figsize=(6,6))
    for class_idx in range(all_probs.shape[1]):
        frac_pos, mean_conf = calibration_curve(
            labels_np == class_idx,
            probs_np[:, class_idx],
            n_bins=n_bins
        )
        plt.plot(mean_conf, frac_pos, marker='o', label=f"Class {class_idx}")
    plt.plot([0,1],[0,1],'k--')
    plt.xlabel("Mean predicted probability")
    plt.ylabel("Fraction of positives")
    plt.title("Reliability Diagram")
    plt.legend(bbox_to_anchor=(1.05,1), loc='upper left')
    plt.show()

    return ece

# -----------------------------
# Fit a temperature on the validation split for each model, then measure ECE on test
# -----------------------------
models_dict = {
    "FC Only": trained_model_fc,
    "Layer4 + FC": trained_model_partial,
    "Fully Unfrozen": trained_model_unfrozen
}

ece_results = {}

for name, model in models_dict.items():
    print(f"\nProcessing model: {name}")
    # set_temperature returns the wrapper itself, so construction and fitting chain
    temp_model = ModelWithTemperature(model).set_temperature(val_loader, device)
    ece_results[name] = ece = evaluate_model_with_calibration(temp_model, test_loader, device)

# -----------------------------
# Bar chart of post-scaling ECEs
# -----------------------------
plt.figure(figsize=(8,5))
plt.bar(list(ece_results.keys()), list(ece_results.values()), color=['skyblue','orange','green'])
plt.ylabel("ECE")
plt.title("ECE Comparison Across Models After Temperature Scaling")
plt.show()
