# In [None]:
# %%
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import pathlib
import random
import numpy as np
import json
# Seed the torch, NumPy and Python `random` generators (in that order) with 0
# so that repeated runs of the notebook draw the same random numbers.
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(0)

# %%
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import pathlib
import random


import torch.nn.functional as F

# Evaluate the model with VI and calculate entropy
def VI_dropout_predict_with_entropy(model, inputs, n_samples=10):
    # model.train()  # Enable dropout during prediction
    outputs = torch.stack([model(inputs) for _ in range(n_samples)])
    mean_output = outputs.mean(dim=0)
    uncertainty = outputs.var(dim=0)
    # Calculate entropy
    probs = F.softmax(mean_output, dim=1)  # Convert logits to probabilities
    entropy = -torch.sum(probs * torch.log(probs + 1e-12), dim=1)  # Add epsilon to avoid log(0)
    # Getting the max probability
    max_probs, _ = torch.max(probs, dim=1)
    return mean_output, uncertainty, entropy, max_probs

class FlipoutLinear(nn.Module):
    """Linear layer with a factorized Gaussian weight posterior, sampled with Flipout.

    Weights follow N(weight_mu, softplus(weight_rho)**2) and biases follow
    N(bias_mu, softplus(bias_rho)**2). Every forward pass draws one weight
    perturbation and decorrelates it across the rows of the batch with fresh
    random +/-1 sign vectors (Flipout). Noise is sampled regardless of
    train/eval mode.

    Args:
        in_features: size of each input row.
        out_features: size of each output row.
        prior_sigma: std of the zero-mean Gaussian prior used by ``kl_loss``.
    """

    def __init__(self, in_features, out_features, prior_sigma=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_sigma = prior_sigma

        self.weight_mu = nn.Parameter(torch.zeros(out_features, in_features).normal_(0, 0.01))
        # rho around -5 gives an initial std of softplus(-5) ~= 0.0067.
        self.weight_rho = nn.Parameter(torch.zeros(out_features, in_features).normal_(-5, 0.01))
        self.bias_mu = nn.Parameter(torch.zeros(out_features).normal_(0, 0.01))
        self.bias_rho = nn.Parameter(torch.zeros(out_features).normal_(-5, 0.01))

        # Sign vectors used by the most recent forward pass (kept for
        # inspection). Non-persistent: their shape follows the last batch
        # size, so saving them would make the state_dict unloadable into a
        # freshly constructed layer.
        self.register_buffer('input_sign', None, persistent=False)
        self.register_buffer('output_sign', None, persistent=False)

    def get_random_signs(self, batch_size, shape, device):
        """Return a (batch_size, *shape) float tensor of independent +/-1 values."""
        return 2 * torch.bernoulli(torch.ones(batch_size, *shape, device=device) * 0.5) - 1

    def forward(self, x):
        """Return one stochastic sample of the layer output for a (batch, in_features) input."""
        batch_size = x.size(0)
        # Flipout needs new signs on every pass; caching them while the batch
        # size is unchanged would tie each row position to one sign pattern.
        self.input_sign = self.get_random_signs(batch_size, (self.in_features,), x.device)
        self.output_sign = self.get_random_signs(batch_size, (self.out_features,), x.device)

        weight_sigma = F.softplus(self.weight_rho)
        bias_sigma = F.softplus(self.bias_rho)

        mean_output = F.linear(x, self.weight_mu, self.bias_mu)

        # One weight perturbation shared by the batch, made row-specific by
        # the input/output sign flips.
        delta_w = torch.randn_like(self.weight_mu) * weight_sigma
        perturbation = F.linear(x * self.input_sign, delta_w) * self.output_sign

        # Bias noise is a single draw shared by every row.
        bias_perturbation = bias_sigma * torch.randn_like(self.bias_mu)

        return mean_output + perturbation + bias_perturbation

    def kl_loss(self):
        """Return KL(q || N(0, prior_sigma**2)) summed over all entries, per weight.

        Uses the closed form for two univariate Gaussians,
            0.5 * (s**2/p**2 + m**2/p**2 - 1 - log(s**2/p**2)),
        summed over weight and bias entries. The earlier expression used
        log1p(s**2/p**2) and dropped the -log(s**2) term, so it was minimized
        by s -> 0 and could be negative. The total is divided by the number
        of weight entries (bias entries not counted), matching the original
        normalization.
        """
        def gaussian_kl(mu, sigma):
            var_ratio = (sigma / self.prior_sigma) ** 2
            return 0.5 * torch.sum(
                var_ratio + (mu / self.prior_sigma) ** 2 - 1 - torch.log(var_ratio)
            )

        kl_weight = gaussian_kl(self.weight_mu, F.softplus(self.weight_rho))
        kl_bias = gaussian_kl(self.bias_mu, F.softplus(self.bias_rho))
        return (kl_weight + kl_bias) / self.weight_mu.numel()

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # A layer whose sign vectors were persistent buffers wrote them into
        # its checkpoint. They are resampled every pass, so drop them rather
        # than reporting them as unexpected keys.
        for name in ("input_sign", "output_sign"):
            state_dict.pop(prefix + name, None)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

class BayesianNNFlipout(nn.Module):
    """CNN classifier for 3-channel images with a Bayesian (Flipout) head.

    Three conv(3x3, padding 1) + ReLU + 2x2 max-pool stages produce 64
    feature maps. fc1 expects 64 * 4 * 4 inputs, so the input's spatial size
    must pool down to 4x4 positions (e.g. 32x32 images). The flattened
    features go through two FlipoutLinear layers (1024 -> 512 -> 10), with
    dropout applied before each.

    Args:
        prior_sigma: prior std passed to both FlipoutLinear layers
            (default 0.1, the value previously hard-coded).
        dropout_p: dropout probability (default 0.25, previously hard-coded).
    """

    def __init__(self, prior_sigma=0.1, dropout_p=0.25):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(dropout_p)
        self.fc1 = FlipoutLinear(64 * 4 * 4, 512, prior_sigma=prior_sigma)
        self.fc2 = FlipoutLinear(512, 10, prior_sigma=prior_sigma)

    def forward(self, x):
        """Return (N, 10) class logits for an (N, 3, H, W) batch."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        # Keep the batch dimension explicit: view(-1, 1024) would silently
        # regroup a wrongly sized input into a different number of rows,
        # whereas flatten(1) preserves N and lets fc1 reject a bad width.
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

    def kl_loss(self):
        """Return the sum of the two Flipout layers' KL terms."""
        return self.fc1.kl_loss() + self.fc2.kl_loss()

def train_epoch(model, train_loader, optimizer, device, epoch, total_epochs=100):
    """Train ``model`` for one epoch and return epoch-level metrics.

    Each step minimizes ``cross_entropy + beta * model.kl_loss()``. The KL
    weight beta grows linearly with ``epoch`` until 20% of ``total_epochs``
    and is capped at 0.1 afterwards. A progress line is printed every 100
    batches.

    Returns:
        tuple: (loss averaged over batches, fraction of training examples
        classified correctly).
    """
    model.train()
    kl_weight = 0.1 * min(1.0, epoch / (total_epochs * 0.2))

    running_loss = 0
    hits = 0
    seen = 0
    for step, (xb, yb) in enumerate(train_loader):
        xb = xb.to(device)
        yb = yb.to(device)
        optimizer.zero_grad()

        logits = model(xb)
        objective = F.cross_entropy(logits, yb) + kl_weight * model.kl_loss()
        objective.backward()
        optimizer.step()

        step_loss = objective.item()
        running_loss += step_loss
        hits += (logits.argmax(dim=1) == yb).sum().item()
        seen += yb.size(0)

        if step % 100 == 0:
            print(f'Train Batch: {step}/{len(train_loader)}, '
                  f'Loss: {step_loss:.4f}, '
                  f'Acc: {100. * hits/seen:.2f}%')

    return running_loss / len(train_loader), hits / seen

def evaluate(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader`` with one forward pass per batch.

    The model is put in eval mode, which disables nn.Dropout. Layers that
    sample noise regardless of mode (such as FlipoutLinear in this file)
    still make the result stochastic.

    Returns:
        tuple: (cross-entropy averaged per example, accuracy as a fraction),
        both over the examples the loader actually yields. Both are NaN if
        the loader yields nothing.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            # Sum per batch (not mean) so a smaller final batch is weighted
            # by its true size in the overall average.
            loss_sum += F.cross_entropy(output, target, reduction='sum').item()
            correct += output.argmax(dim=1).eq(target).sum().item()
            # Count examples actually seen; len(dataset) can differ when the
            # loader uses a sampler or drop_last.
            seen += target.size(0)

    if seen:
        test_loss = loss_sum / seen
        accuracy = correct / seen
    else:
        # Metrics are undefined for an empty loader; report NaN instead of
        # raising ZeroDivisionError.
        test_loss = float('nan')
        accuracy = float('nan')

    print(f'Test set: Average loss: {test_loss:.4f}, '
          f'Accuracy: {100. * accuracy:.2f}%')

    return test_loss, accuracy

# %%

# Data preparation: SVHN train/test splits. ToTensor scales pixels to [0, 1];
# Normalize with mean 0.5 / std 0.5 per channel then maps them to [-1, 1].
_svhn_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transform = transforms.Compose([transforms.ToTensor(), _svhn_normalize])

_svhn_root = './data'
trainset = torchvision.datasets.SVHN(root=_svhn_root, split='train', download=True, transform=transform)
testset = torchvision.datasets.SVHN(root=_svhn_root, split='test', download=True, transform=transform)

# Unshuffled so the test batches arrive in a fixed order.
testloader = DataLoader(testset, batch_size=64, shuffle=False)

# %%

# Per-class training-set sizes whose models are evaluated below.
sample_sizes = [1, 5, 10, 50, 100, 2000, 4000]

# Default compute device: the CUDA GPU when present, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# %%
# Evaluate the saved VI (Flipout) checkpoint for every training-set size and
# write entropy / max-probability statistics of its MC predictions to JSON.

# Choose the evaluation device once: CUDA, then Apple MPS, then CPU.
# (Re-selecting only MPS/CPU inside the loop ignored an available CUDA GPU.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")


def _hp_value(lines, index):
    """Return the value text on line ``index`` of a best-params file.

    Assumes the line looks like ``<key>: <value>`` with an optional trailing
    comma — TODO confirm against the script that writes these files.
    """
    return lines[index].split(":", 1)[1].strip().strip(",")


for size in sample_sizes:
    hp_size = size

    hp_path = f"../src/results/VI_HP_size_{hp_size}/best_params_sample_{hp_size}.txt"

    # Open the file and extract hyperparameters
    with open(hp_path, "r") as f:
        lines = f.readlines()

    # Parse the hyperparameters by fixed line position. prior_sigma and lr are
    # real-valued (int("0.001") raises ValueError); num_layers is a count.
    batch_size = int(_hp_value(lines, 2))
    prior_sigma = float(_hp_value(lines, 3))
    lr = float(_hp_value(lines, 4))
    epochs = int(_hp_value(lines, 5))
    num_layers = int(float(_hp_value(lines, 6)))
    best_trial = int(_hp_value(lines, 10))

    print(f"Best hyperparameters for {size} samples: {batch_size}, {prior_sigma}, {lr}, {num_layers}, {epochs}, {best_trial}")

    hyperparameters = {
        'lr': lr,
        'num_epochs': epochs,
        'batch_size': batch_size,
        'prior_sigma': prior_sigma,
        'num_layers': num_layers,
    }

    print(f"\nLoading model with {size} samples per class...")

    # Subset the training set to the first `size` samples of each class.
    # When the dataset exposes a `labels` sequence (torchvision's SVHN does),
    # read it directly instead of decoding and normalizing every image.
    train_labels = getattr(trainset, "labels", None)
    if train_labels is None:
        train_labels = (label for _, label in trainset)
    indices = []
    class_counts = {i: 0 for i in range(10)}
    for idx, label in enumerate(train_labels):
        label = int(label)
        if class_counts[label] < size:
            indices.append(idx)
            class_counts[label] += 1
        if all(count >= size for count in class_counts.values()):
            break

    subset = Subset(trainset, indices)
    trainloader = DataLoader(subset, batch_size=hyperparameters['batch_size'], shuffle=True)

    model = BayesianNNFlipout().to(device)
    # NOTE(review): trainloader, optimizer and scheduler are not used by this
    # evaluation-only loop; kept in case later cells rely on them.
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

    # Load the best model. map_location lets a checkpoint saved on another
    # device (e.g. a CUDA GPU) load onto the current one.
    model.load_state_dict(torch.load(f"../src/models/svhn_model_{size}_samples_VI.pth", map_location=device))

    model.eval()
    correct = 0
    total = 0
    all_entropies = []
    all_max_probs = []
    all_uncertainties = []

    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            # 10 stochastic forward passes per test batch.
            mean_output, uncertainty, entropy, max_probs = VI_dropout_predict_with_entropy(model, inputs, n_samples=10)
            _, predicted = torch.max(mean_output, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_entropies.extend(entropy.cpu().numpy())  # Collect entropies
            all_max_probs.extend(max_probs.cpu().numpy())  # Collect max probabilities
            all_uncertainties.append(uncertainty.cpu().numpy())

    accuracy = 100 * correct / total

    # Calculate statistics for entropy and max probabilities
    mean_entropy = np.mean(all_entropies)
    std_entropy = np.std(all_entropies)
    mean_max_prob = np.mean(all_max_probs)
    std_max_prob = np.std(all_max_probs)
    # Report accuracy (a percentage) alongside the uncertainty statistics.
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Mean Entropy: {mean_entropy:.4f}")
    print(f"Standard Deviation of Entropy: {std_entropy:.4f}")
    print(f"Mean Max Probability: {mean_max_prob:.4f}")
    print(f"Standard Deviation of Max Probability: {std_max_prob:.4f}")

    # Save the results
    results_path = f"./results/VI/svhn_results_{size}_samples_VI.json"
    pathlib.Path("./results/VI/").mkdir(parents=True, exist_ok=True)
    with open(results_path, "w") as f:
        json.dump({
            "mean_entropy": f"{mean_entropy:.4f}",
            "std_entropy": f"{std_entropy:.4f}",
            "mean_max_prob": f"{mean_max_prob:.4f}",
            "std_max_prob": f"{std_max_prob:.4f}",
            "accuracy": f"{accuracy:.4f}",
        }, f)
    print(f"Results saved at {results_path}")
