In [None]:
!pip install blitz-bayesian-pytorch --quiet

In [None]:

# Mount Google Drive and make the project's modules (dataset.py, baseline_model.py) importable.
from google.colab import drive
drive.mount('/content/drive')

import sys
project_path = "/content/drive/MyDrive/IUQ_Bioacoustics"
# Guard so re-running this cell does not keep appending duplicate sys.path entries.
if project_path not in sys.path:
    sys.path.append(project_path)


In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from blitz.modules import BayesianLinear
from blitz.utils import variational_estimator
from blitz.losses import kl_divergence_from_nn
from dataset import create_datasets
from baseline_model import MarineMammalBNN
import numpy as np

# Seed the global RNGs so DataLoader shuffling, model initialisation in the next cell
# and stochastic weight sampling are repeatable when the notebook is run top to bottom.
# (CUDA kernels may still introduce some nondeterminism.)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

root_dir = f"{project_path}/data/preprocessed"
train_dataset, test_dataset, class_to_idx = create_datasets(
    root_dir=root_dir,
    test_size=0.2,
    min_samples=100,
    random_state=SEED,
)

# A dedicated generator makes the shuffle order independent of other RNG consumers.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
test_loader = DataLoader(test_dataset, batch_size=1)

class_weights = train_dataset.get_class_weights()
NUM_CLASSES = len(class_to_idx)
# CrossEntropyLoss needs exactly one weight per class; fail here with a clear message
# rather than deep inside the first loss computation.
assert len(class_weights) == NUM_CLASSES, (
    f"expected {NUM_CLASSES} class weights, got {len(class_weights)}"
)


In [None]:

# Build the model, the class-weighted data-fit loss and the optimizer on the available device.
device = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4

model = MarineMammalBNN(num_classes=NUM_CLASSES).to(device)
# CrossEntropyLoss rejects a weight tensor whose dtype differs from the logits (float32 by
# default). Cast explicitly in case get_class_weights() yields float64 (e.g. built from NumPy).
# torch.as_tensor also accepts a list/ndarray.
criterion = nn.CrossEntropyLoss(
    weight=torch.as_tensor(class_weights, dtype=torch.float32, device=device)
)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# NOTE(review): debugging check. `model.kl_loss` is not used by the ELBO below, which
# calls blitz's kl_divergence_from_nn instead; consider removing.
print("kl_loss exists:", hasattr(model, "kl_loss"))


In [None]:

def elbo_loss(model, x, y, criterion, kl_weight):
    """Single-sample ELBO-style training objective for one mini-batch.

    Args:
        model: Bayesian network, evaluated once on ``x``.
        x: input batch.
        y: target labels for ``x``.
        criterion: data-fit loss, applied as ``criterion(logits, y)``.
        kl_weight: multiplier applied to the KL complexity term.

    Returns:
        ``criterion(model(x), y) + kl_weight * kl_divergence_from_nn(model)``.
    """
    # Forward pass first: blitz's Bayesian layers store the log-prior / log-posterior of
    # the weights they sample during forward(), and kl_divergence_from_nn reads those values.
    logits = model(x)
    data_term = criterion(logits, y)
    complexity_term = kl_divergence_from_nn(model)
    return data_term + kl_weight * complexity_term


In [None]:

# Train with the ELBO objective: class-weighted cross-entropy plus a scaled KL term.
EPOCHS = 20
KL_WEIGHT = 0.01  # multiplier on the KL term relative to the per-batch cross-entropy

model.train()
for epoch in range(EPOCHS):
    total_loss = 0.0
    n_batches = 0
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = elbo_loss(model, x, y, criterion, KL_WEIGHT)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    if n_batches == 0:
        raise ValueError("train_loader yielded no batches; check the training dataset")
    # Report the mean per-batch loss: a raw sum grows with the number of batches,
    # which makes the number hard to compare across dataset sizes or batch sizes.
    mean_loss = total_loss / n_batches
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {mean_loss:.4f}")


In [None]:
# Persist only the learned parameters (state_dict) next to the project sources.
checkpoint_path = f"{project_path}/marine_bnn.pt"
torch.save(model.state_dict(), checkpoint_path)

In [None]:

def predict_with_uncertainty(model, x, n_samples=30, device=None):
    """Monte Carlo estimate of a stochastic classifier's predictive distribution.

    Runs ``n_samples`` forward passes on ``x`` in eval mode and applies softmax
    over dim=1 to each output. Returns the element-wise mean and standard
    deviation across the passes. The spread is only non-zero if the model's
    forward pass is stochastic (e.g. Bayesian layers that sample weights per call).

    Args:
        model: ``nn.Module`` to evaluate. Its train/eval mode is restored on return.
        x: input batch; moved to ``device`` before the forward passes.
        n_samples: number of forward passes; must be >= 1. With a single pass
            the (unbiased) std is NaN.
        device: device to run on. Defaults to the device of the model's first
            parameter, or ``x.device`` if the model has no parameters.

    Returns:
        ``(mean, std)`` tensors, each shaped like one softmax output, computed
        without autograd tracking.

    Raises:
        ValueError: if ``n_samples`` < 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else x.device

    was_training = model.training
    model.eval()
    try:
        x = x.to(device)
        # No autograd graph is needed for inference; avoids building one per MC pass.
        with torch.no_grad():
            probs = torch.stack([model(x).softmax(dim=1) for _ in range(n_samples)])
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)
    return probs.mean(dim=0), probs.std(dim=0)


In [None]:

# Evaluate on the held-out set. For every test sample, record correctness, the confidence
# of the top class, the entropy of the mean predictive distribution, and the largest
# per-class MC standard deviation.
correct_list = []
confidence_list = []
entropy_list = []
std_list = []

# Inference only: no autograd graph is needed for any of the forward passes.
with torch.no_grad():
    for x, y in test_loader:
        mean_pred, std_pred = predict_with_uncertainty(model, x.to(device), n_samples=30)
        # Flatten labels to (batch,) and compare on the same device as the predictions.
        y = y.to(mean_pred.device).reshape(-1)

        # All reductions run per-sample over the class axis (dim=1), so the metrics are
        # correct for any test batch size, not only batch_size=1.
        confidence, pred_label = mean_pred.max(dim=1)
        entropy = -torch.sum(mean_pred * torch.log(mean_pred + 1e-8), dim=1)
        std = std_pred.max(dim=1).values

        correct_list.extend((pred_label == y).long().tolist())
        confidence_list.extend(confidence.tolist())
        entropy_list.extend(entropy.tolist())
        std_list.extend(std.tolist())

if not correct_list:
    raise ValueError("test_loader yielded no samples; cannot compute evaluation metrics")

mean_accuracy = np.mean(correct_list)
mean_confidence = np.mean(confidence_list)
mean_entropy = np.mean(entropy_list)
mean_std = np.mean(std_list)

print(f"Mean Accuracy: {mean_accuracy:.4f}")
print(f"Mean Predictive Confidence: {mean_confidence:.4f}")
print(f"Mean Predictive Entropy: {mean_entropy:.4f}")
print(f"Mean Predictive Std: {mean_std:.4f}")
