# Set up imports

In [None]:
import os
# Make the repository root the working directory so that `src` imports and
# relative paths such as "./models/..." resolve. If "./notebooks" is not visible
# from the current directory, step up one level.
# NOTE(review): assumes the kernel starts either at the repo root or inside the
# `notebooks/` folder — confirm. `%cd` is an IPython magic, so this only runs in Jupyter/IPython.
if not os.path.exists("./notebooks"):
    %cd ..

import numpy as np
from src.data_processing import load_mean_std

import torch
from torchvision import transforms
import src.model
from src.training import monte_carlo_predictions


# 0. Set Device 

In [None]:
# Use a CUDA device when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# 1. Define Monte Carlo Sampling and Predictive Entropy

In [None]:
def monte_carlo_dropout(model, test_loader, samples = 20, eps=1e-16, predict_fn=None):
    """Run repeated stochastic forward passes and compute the predictive entropy.

    Calls ``predict_fn(model, test_loader)`` ``samples`` times and stacks the
    per-pass results along a new leading axis.

    Args:
        model: Network forwarded unchanged to ``predict_fn``.
        test_loader: Data loader forwarded unchanged to ``predict_fn``.
        samples: Number of forward passes; must be >= 1.
        eps: Small constant added inside the log so zero probabilities do not
            produce ``-inf`` / NaN.
        predict_fn: Callable ``(model, loader) -> array``. Defaults to
            ``monte_carlo_predictions`` (imported from ``src.training``); can be
            injected, e.g. for testing.

    Returns:
        ``(predictions, entropy)``. ``predictions`` has shape
        ``(samples, *shape_of_one_pass)``. ``entropy`` is the Shannon entropy of
        the mean over passes, summed over the last axis.

    Raises:
        ValueError: if ``samples < 1``.
    """
    if samples < 1:
        # np.stack on an empty list would otherwise fail with a less helpful message.
        raise ValueError(f"samples must be >= 1, got {samples}")
    if predict_fn is None:
        predict_fn = monte_carlo_predictions

    # NOTE(review): the entropy below assumes each pass returns class probabilities
    # along the last axis (e.g. softmax outputs) — confirm in src.training.
    predictions = np.stack([predict_fn(model, test_loader) for _ in range(samples)], axis=0)
    mean_predictions = predictions.mean(axis=0)
    entropy = -np.sum(mean_predictions * np.log(mean_predictions + eps), axis=-1)
    return predictions, entropy

## 2. Define Model Loading

In [None]:
def load_model(model, name, models_dir="./models", target_device=None):
    """Load the checkpoint ``<models_dir>/<name>.pth`` into ``model`` in place and move it to a device.

    The state dict is first materialised on the CPU, so checkpoints saved on a GPU
    also load on CPU-only machines. The model is then moved to ``target_device``,
    and that device is stored as ``model.device``.

    Args:
        model: ``torch.nn.Module`` whose architecture matches the checkpoint.
        name: Checkpoint file stem, e.g. ``"DropoutCNN"``.
        models_dir: Directory containing the ``.pth`` files (default ``"./models"``).
        target_device: Device to move the model to; defaults to the notebook-level ``device``.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        RuntimeError: if the checkpoint's keys/shapes do not match ``model``.
    """
    if target_device is None:
        target_device = device
    model_path = os.path.join(models_dir, f"{name}.pth")
    # weights_only=True restricts unpickling to tensors and plain containers.
    state_dict = torch.load(model_path, weights_only=True, map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    # NOTE(review): `model.device` is presumably read by the project's model classes
    # or training helpers — confirm before removing.
    model.device = target_device
    model.to(target_device)

# 3. Load Models


In [None]:
from src.dataset import prepare_dataset_loaders
from src.config import DATASET_DIR 
# Per-channel normalisation statistics stored alongside the dataset.
mean, std = load_mean_std(f"{DATASET_DIR}/scaling_params.json")

# Stand-alone network used for Monte Carlo dropout.
dropout_model = src.model.DropoutCNN()
load_model(dropout_model, "DropoutCNN")

# Ensemble members; the checkpoint names suggest different weight initialisers.
model_names = [
    "OriginalSizeCNN-UNIFORM-RELU",
    "OriginalSizeCNN-HE-RELU",
    "OriginalSizeCNN-XAVIER-RELU",
]


def _load_ensemble_member(checkpoint_name):
    """Build a fresh OriginalSizeCNN, load ``checkpoint_name`` into it, and return it."""
    member = src.model.OriginalSizeCNN()
    load_model(member, checkpoint_name)
    return member


ensemble_models = [_load_ensemble_member(checkpoint_name) for checkpoint_name in model_names]

# Wrap the members, then load the ensemble's own checkpoint.
# NOTE(review): the meaning of the positional `2` is not visible here — confirm against EnsembleCNN.
ensemble_model = src.model.EnsembleCNN(ensemble_models, 2)
load_model(ensemble_model, "EnsembleCNN")

## 4. Calculate Entropy for models

In [None]:
# Evaluation settings.
batch_size = 10  # Test-loader batch size (hard-coded; not read from config).
MC_SAMPLES = 10  # Stochastic forward passes per model.
SEED = 0  # Seeds PyTorch's and NumPy's global RNGs so Restart & Run All repeats the numbers.

torch.manual_seed(SEED)
np.random.seed(SEED)

# Normalise images with the statistics loaded in the previous section.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Only the test split is used below; train/val are unpacked to match the helper's return value.
train_loader, val_loader, test_loader = prepare_dataset_loaders(transform, batch_size)

dropout_predictions, dropout_entropy = monte_carlo_dropout(dropout_model, test_loader, samples=MC_SAMPLES)
dropout_mean_predictions = np.mean(dropout_predictions, axis=0)
dropout_variance_predictions = np.var(dropout_predictions, axis=0)

# NOTE(review): the ensemble goes through the same MC sampler; if it has no active
# stochastic layers, its per-sample variance will be ~0.
ensemble_predictions, ensemble_entropy = monte_carlo_dropout(ensemble_model, test_loader, samples=MC_SAMPLES)
ensemble_mean_predictions = np.mean(ensemble_predictions, axis=0)
ensemble_variance_predictions = np.var(ensemble_predictions, axis=0)


print("Dropout Mean Predictions:", dropout_mean_predictions)
# Variance across MC passes is the uncertainty estimate (previously the mean was printed here by mistake).
print("Dropout Uncertainty Predictions:", dropout_variance_predictions)
print("Entropy for Dropout model:", dropout_entropy)
print("Entropy for Ensemble model:", ensemble_entropy)
print("Ensemble Mean Predictions:", ensemble_mean_predictions)




## 5. Print Statistics

In [None]:
from scipy.stats import pearsonr



# Predictive entropy per model (summed over the last axis; higher means less confident).
print("Entropy for Dropout model:", dropout_entropy)
print("Entropy for Ensemble model:", ensemble_entropy)

# Element-wise agreement between the two models' mean predictions.
# Flattening arrays of different shapes (but equal size) would silently pair
# unrelated entries, so require identical shapes first.
assert dropout_mean_predictions.shape == ensemble_mean_predictions.shape, (
    "Cannot compare predictions with different shapes: "
    f"{dropout_mean_predictions.shape} vs {ensemble_mean_predictions.shape}"
)
mc_mean_flat = dropout_mean_predictions.flatten()
ensemble_mean_flat = ensemble_mean_predictions.flatten()
correlation, _ = pearsonr(mc_mean_flat, ensemble_mean_flat)
print(f"Prediction agreement (Pearson correlation): {correlation}")


# 6. Plot Prediction Distributions (Dropout and Ensemble Models)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns


# Predicted Probabilities
def plot_distribution(values, title, xlabel):
    """Draw a 30-bin histogram with KDE overlay of ``values`` on a new 10x5 figure.

    Args:
        values: Array passed straight to ``seaborn.histplot``. A 2-D array is
            treated by seaborn as wide-form data (one series per column).
        title: Figure title.
        xlabel: Label for the x axis; the y axis is always labelled "Frequency".

    Returns:
        The matplotlib ``Axes`` the histogram was drawn on.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.histplot(values, bins=30, kde=True, color='blue', ax=ax)
    ax.set(title=title, xlabel=xlabel, ylabel="Frequency")
    return ax


plot_distribution(
    dropout_mean_predictions,
    "Distribution of Predicted Probabilities for Dropout Model",
    "Predicted Probability",
)
plot_distribution(
    dropout_variance_predictions,
    "Distribution of Variances for Dropout Model",
    "Variance",
)
plot_distribution(
    ensemble_mean_predictions,
    "Distribution of Predicted Probabilities for Ensemble Model",
    "Predicted Probability",
)
# Render all figures and avoid echoing the last call's return value as cell output.
plt.show()