# In [None]:
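# Uncertainty Analysis: Compare BNN, Deep Ensembles, and Evidential DL
# NOTE: this cell evaluates MC Dropout (a BNN approximation) and Deep Ensembles
# only; no Evidential DL model is loaded or evaluated here.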
import torch
import numpy as np
import matplotlib.pyplot as plt
from models.mc_dropout import MCDropoutCNN
from models.deep_ensemble import ensemble_predict
from utils.metrics import compute_entropy
from data.load_data import load_mnist

# Load the trained MC Dropout model. The Deep Ensemble is used only through
# ensemble_predict() below, so nothing else is loaded here.
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine. The predictions below are converted with .numpy(), which
# requires CPU tensors anyway.
mc_model = MCDropoutCNN()
mc_model.load_state_dict(torch.load("mc_dropout.pth", map_location="cpu"))
mc_model.eval()

# Load the MNIST test split.
# The batch must hold at least as many images as the plotting loop shows (5).
# With batch_size=1 the batch presumably contains a single image, so images[1]
# raised IndexError. A larger batch also gives the ECE computation below more
# than one sample to bin.
_, test_loader = load_mnist(batch_size=64)

# Take the first batch of test data.
images, labels = next(iter(test_loader))

# Monte Carlo Dropout prediction: run several stochastic forward passes and
# aggregate the softmax outputs.
# enable_dropout() is defined in models.mc_dropout. Presumably it re-activates
# the dropout layers after eval() (TODO: confirm there).
mc_model.enable_dropout()
num_samples = 30
# Inference only: no_grad avoids building an autograd graph for each of the
# num_samples passes.
with torch.no_grad():
    mc_predictions = (
        torch.stack(
            [torch.softmax(mc_model(images), dim=1) for _ in range(num_samples)]
        )
        .cpu()
        .numpy()
    )  # shape: (num_samples, batch, n_classes)
mc_mean = mc_predictions.mean(axis=0)  # mean predictive distribution per image
mc_variance = mc_predictions.var(axis=0)  # per-class spread across the passes

# Deep Ensemble prediction. ensemble_predict() comes from models.deep_ensemble
# and returns a (mean, variance) pair. The exact shapes and types are defined
# there (TODO: confirm).
# Wrapped in no_grad to match the MC Dropout inference above. This is harmless
# if ensemble_predict already disables gradients itself.
with torch.no_grad():
    ensemble_mean, ensemble_variance = ensemble_predict(images)

# Plot softmax probabilities.
# Top row: the first few test images with their true labels.
# Bottom row: the MC Dropout mean softmax distribution for each image, with
# its entropy in the title.
n_show = min(5, len(images))  # never index past the end of the batch
n_classes = mc_mean.shape[1]  # class count taken from the softmax output width
# squeeze=False keeps `axes` 2-D, so axes[row, col] works even if n_show == 1.
fig, axes = plt.subplots(2, n_show, figsize=(2 * n_show, 5), squeeze=False)
for i in range(n_show):
    axes[0, i].imshow(images[i].squeeze(), cmap="gray")
    axes[0, i].set_title(f"True: {labels[i].item()}")
    axes[0, i].axis("off")

    # Uncertainty measure: entropy of the MC-averaged predictive distribution.
    entropy = compute_entropy(mc_mean[i])

    axes[1, i].bar(range(n_classes), mc_mean[i])
    axes[1, i].set_xticks(range(n_classes))
    axes[1, i].set_ylim(0, 1)  # same scale for every panel; probabilities are in [0, 1]
    axes[1, i].set_title(f"Entropy: {entropy:.3f}")

plt.tight_layout()
plt.show()

# Calibration comparison.
# Compute the Expected Calibration Error of each method's mean predictive
# distribution against the same ground-truth labels (lower ECE means better
# calibration). Both values are computed first, MC Dropout then Deep
# Ensemble, and then reported.
from utils.metrics import expected_calibration_error

ece_mc, ece_ensemble = (
    expected_calibration_error(mc_mean, labels),
    expected_calibration_error(ensemble_mean, labels),
)

print(f"Expected Calibration Error (MC Dropout): {ece_mc:.4f}")
print(f"Expected Calibration Error (Deep Ensemble): {ece_ensemble:.4f}")
