In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from bayesian_torch.layers import LinearFlipout
from bayesian_torch.models.dnn_to_bnn import get_kl_loss 
import psutil

import random
import os
import time
import pickle
import matplotlib.pyplot as plt
import geopandas as gpd
from tqdm import tqdm
from sklearn.metrics import roc_curve, auc, confusion_matrix, ConfusionMatrixDisplay, classification_report
from imblearn.over_sampling import ADASYN
from utilities import plot_prediction_area_curves, get_pa_intersection

In [2]:
class BNNMineralProspectivity(nn.Module):
    """Bayesian MLP (``input_dim -> 32 -> 32 -> 1``) built from ``LinearFlipout`` layers.

    Both hidden layers are followed by a ReLU. The output layer produces one
    raw logit per sample; no sigmoid is applied inside the model.
    """

    def __init__(self, input_dim: int):
        super().__init__()
        self.input_dim = input_dim
        self.fc1 = LinearFlipout(input_dim, 32)
        self.relu1 = nn.ReLU()
        self.fc2 = LinearFlipout(32, 32)
        self.relu2 = nn.ReLU()
        self.fc_out = LinearFlipout(32, 1)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one stochastic forward pass.

        Args:
            x: Input batch fed to the first layer.

        Returns:
            ``(logits, total_kl)`` where ``total_kl`` is the sum of the KL
            terms returned by the three Bayesian layers, placed on ``x.device``.
        """
        kl_accumulator = torch.tensor(0.0, device=x.device)

        hidden = x
        hidden_stages = ((self.fc1, self.relu1), (self.fc2, self.relu2))
        for bayes_layer, activation in hidden_stages:
            # Each Bayesian layer returns its output together with its KL term.
            hidden, stage_kl = bayes_layer(hidden)
            kl_accumulator += stage_kl.to(x.device)
            hidden = activation(hidden)

        logits, head_kl = self.fc_out(hidden)
        kl_accumulator += head_kl.to(x.device)

        return logits, kl_accumulator

In [3]:
def test_outputs(
    model: BNNMineralProspectivity,
    test_loader: DataLoader,
    device: torch.device,
    num_mc_samples: int = 50 
):
    model.eval() 
    all_true_labels_list, all_mean_probs_list, all_pred_variances_list = [], [], []
    with torch.no_grad(): 
        for x_batch, y_batch in tqdm(test_loader, desc="Generating Test Outputs (MC Sampling)"):
            x_batch = x_batch.to(device) 
            
            mc_logits_samples = []
            for _ in range(num_mc_samples):
                # Model's forward pass returns (logits, kld)
                # Only need logits for prdictions.
                logits_sample, _ = model(x_batch) 
                mc_logits_samples.append(logits_sample)
            
            mc_logits_stacked = torch.stack(mc_logits_samples)
            mc_probs_stacked = torch.sigmoid(mc_logits_stacked)

            mean_probs = mc_probs_stacked.mean(dim=0).squeeze()
            pred_variance = mc_probs_stacked.var(dim=0).squeeze()
            
            if mean_probs.ndim == 0: mean_probs = mean_probs.unsqueeze(0)
            if pred_variance.ndim == 0: pred_variance = pred_variance.unsqueeze(0)

            all_true_labels_list.append(y_batch.cpu().numpy().flatten())
            all_mean_probs_list.append(mean_probs.cpu().numpy().flatten()) 
            all_pred_variances_list.append(pred_variance.cpu().numpy().flatten())

    all_true_labels = np.concatenate(all_true_labels_list)
    all_mean_probs = np.concatenate(all_mean_probs_list)
    all_pred_variances = np.concatenate(all_pred_variances_list)
    all_pred_labels = (all_mean_probs > 0.5).astype(int)
    
    return all_true_labels, all_mean_probs, all_pred_labels, all_pred_variances


In [4]:
def mineral_classification_bnn_train(
    train_features_path: str = './data/dataset_train.pt',
    train_labels_path: str = './data/mineral_train.pt',
    test_features_path: str = './data/dataset_test.pt',
    test_labels_path: str = './data/mineral_test.pt',
    output_dir_base_name: str = "./bnn_mineral_outputs", 
    num_epochs: int = 50,
    batch_size: int = 32,
    learning_rate: float = 1e-3,
    kl_weight_scale: float = 1.0, 
    random_state: int = 42, 
    print_every_epoch: int = 1, 
    num_mc_eval_epoch: int = 5
):
    """Train ``BNNMineralProspectivity`` on ADASYN-resampled data and time each epoch.

    The training set is oversampled with ADASYN (the test set is left as is),
    then the model is optimised with Adam on
    ``BCEWithLogits + kl_weight * KL`` where
    ``kl_weight = kl_weight_scale / len(resampled training set)``.
    Wall-clock time per epoch and the process RSS (sampled after every batch)
    are recorded.

    Args:
        train_features_path: ``torch.save``-d tensor of training features.
        train_labels_path: ``torch.save``-d tensor of training labels.
        test_features_path: ``torch.save``-d tensor of test features.
        test_labels_path: ``torch.save``-d tensor of test labels.
        output_dir_base_name: Directory created (if missing) for run artefacts.
        num_epochs: Number of passes over the resampled training set.
        batch_size: Batch size for both loaders.
        learning_rate: Adam learning rate.
        kl_weight_scale: Numerator of the KL weight (see above).
        random_state: Seed for ``torch``, ``numpy``, ``random`` and ADASYN.
        print_every_epoch: Print the loss/error line every this many epochs
            (the last epoch is always printed).
        num_mc_eval_epoch: Accepted for interface compatibility but currently
            unused: no per-epoch test evaluation is performed.

    Returns:
        Dict with the trained ``model``, the ``test_loader``, the ``device``,
        per-epoch lists (``epochs_list``, ``train_losses_epoch``,
        ``train_errors_epoch``, ``epoch_durations``), ``test_errors_epoch``
        (always an empty list, see ``num_mc_eval_epoch``), ``output_dir``,
        ``RANDOM_SEED``, ``num_epochs_trained`` and ``peak_ram_mb``.
    """
    RANDOM_SEED = random_state
    torch.manual_seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)
    random.seed(RANDOM_SEED)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(RANDOM_SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    output_dir = output_dir_base_name
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    def _load_as_float32(path: str) -> np.ndarray:
        # map_location="cpu" keeps .numpy() working even when the tensor was
        # saved from a GPU.
        return torch.load(path, map_location="cpu").numpy().astype(np.float32)

    X_train_np = _load_as_float32(train_features_path)
    y_train_np = _load_as_float32(train_labels_path).ravel()
    X_test_np = _load_as_float32(test_features_path)
    y_test_np = _load_as_float32(test_labels_path).ravel()

    # Oversample the training set only; the test distribution stays untouched.
    adasyn = ADASYN(random_state=RANDOM_SEED)
    X_train_np, y_train_np = adasyn.fit_resample(X_train_np, y_train_np)

    train_dataset = TensorDataset(torch.from_numpy(X_train_np), torch.from_numpy(y_train_np).unsqueeze(1))
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataset = TensorDataset(torch.from_numpy(X_test_np), torch.from_numpy(y_test_np).unsqueeze(1))
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Model, Optimizer, Criterion
    input_dim = X_train_np.shape[1]
    model = BNNMineralProspectivity(input_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion_nll = nn.BCEWithLogitsLoss()
    # The NLL is a per-sample mean, so the KL term is scaled by the
    # (resampled) dataset size to match.
    kl_weight = kl_weight_scale / len(train_dataset)

    #### TRAINING
    train_losses_epoch, train_errors_epoch, test_errors_epoch, epochs_list = [], [], [], []
    epoch_durations_list = []

    # For memory tracking
    process = psutil.Process(os.getpid())
    overall_peak_ram_mb = 0

    print("\nStarting BNN training...")
    overall_training_start_time = time.time()
    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        model.train()

        # Reset the peak tracker for each epoch
        peak_ram_in_epoch_mb = 0

        epoch_total_loss, correct_train, total_train = 0, 0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for x_batch, y_batch in pbar:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            logits, kl_div_from_model = model(x_batch)
            nll_loss = criterion_nll(logits, y_batch)
            elbo_loss = nll_loss + kl_weight * kl_div_from_model
            elbo_loss.backward()
            optimizer.step()

            # Read the scalar once: each .item() forces a device sync.
            batch_loss = elbo_loss.item()
            epoch_total_loss += batch_loss

            # Accuracy bookkeeping needs no autograd graph.
            with torch.no_grad():
                predicted_train = (torch.sigmoid(logits) > 0.5).float()
                correct_train += (predicted_train == y_batch).sum().item()
            total_train += y_batch.size(0)
            pbar.set_postfix({"L": batch_loss})

            # Update both epoch and overall peaks after each batch
            current_ram_mb = process.memory_info().rss / (1024 ** 2)
            peak_ram_in_epoch_mb = max(peak_ram_in_epoch_mb, current_ram_mb)
            overall_peak_ram_mb = max(overall_peak_ram_mb, current_ram_mb)

        avg_epoch_loss = epoch_total_loss / len(train_loader)
        train_accuracy = correct_train / total_train if total_train > 0 else 0
        train_error = (1 - train_accuracy) * 100
        train_losses_epoch.append(avg_epoch_loss)
        train_errors_epoch.append(train_error)
        epochs_list.append(epoch + 1)

        if (epoch + 1) % print_every_epoch == 0 or epoch == num_epochs - 1:
            print(f"Epoch {epoch+1} - Avg Loss: {avg_epoch_loss:.4f}, Train Err: {train_error:.2f}%")

        current_epoch_duration = time.time() - epoch_start_time
        epoch_durations_list.append(current_epoch_duration)

        # peak RAM for the epoch
        print(f"Epoch {epoch+1} completed in {current_epoch_duration:.2f} seconds. Peak RAM in Epoch: {peak_ram_in_epoch_mb:.2f} MB")

    overall_training_duration = time.time() - overall_training_start_time
    print(f"\nTotal training time: {overall_training_duration:.2f} seconds.")

    return {
        "model": model,
        "test_loader": test_loader,
        "device": device,
        "epochs_list": epochs_list,
        "train_losses_epoch": train_losses_epoch,
        "train_errors_epoch": train_errors_epoch,
        "test_errors_epoch": test_errors_epoch,
        "output_dir": output_dir,
        "RANDOM_SEED": RANDOM_SEED,
        "num_epochs_trained": num_epochs,
        "epoch_durations": epoch_durations_list,
        "peak_ram_mb": overall_peak_ram_mb
    }

In [None]:
if __name__ == "__main__":
    # Locations of the input tensors and of this run's artefacts.
    base_data_dir = './data' 
    run_output_dir_base = "./bnn_mineral_outputs_run" 

    # Full training run; every hyper-parameter is spelled out explicitly.
    training_results = mineral_classification_bnn_train(
        train_features_path=os.path.join(base_data_dir, 'dataset_train.pt'),
        train_labels_path=os.path.join(base_data_dir, 'mineral_train.pt'),
        test_features_path=os.path.join(base_data_dir, 'dataset_test.pt'),
        test_labels_path=os.path.join(base_data_dir, 'mineral_test.pt'),
        output_dir_base_name=run_output_dir_base,
        num_epochs=100,
        batch_size=32,
        learning_rate=1e-3,
        kl_weight_scale=1,
        random_state=42,
        print_every_epoch=1,
        num_mc_eval_epoch=1,
    )

    model_trained = training_results["model"]
    epoch_durations_list = training_results["epoch_durations"]
    peak_ram_mb = training_results["peak_ram_mb"]

    print("\n--- Final Summary ---")

    # Mean and spread of the per-epoch wall-clock time (skipped if no epoch ran).
    if len(epoch_durations_list) > 0:
        avg_time = np.mean(epoch_durations_list)
        std_time = np.std(epoch_durations_list)
        print(f"BNN - Average time per epoch: {avg_time:.2f} ± {std_time:.2f} seconds")

    # Highest process RSS observed during training.
    print(f"BNN - Overall Peak RAM: {peak_ram_mb:.2f} MB")

    # Count the elements of every parameter tensor that receives gradients.
    total_params = sum(
        param.numel()
        for param in model_trained.parameters()
        if param.requires_grad
    )
    print(f"BNN - Total trainable parameters: {total_params:,}")

In [6]:
# Unpack the training results into the notebook-level names used by the
# evaluation and plotting cells below.
(
    model_trained,
    test_loader_final,
    device_final,
    epochs_list_final,
    train_losses_final,
    train_errors_final,
    test_errors_final,
    output_dir_final,
    RANDOM_SEED_final,
    num_epochs_completed,
    epoch_durations_list,
) = (
    training_results[result_key]
    for result_key in (
        "model",
        "test_loader",
        "device",
        "epochs_list",
        "train_losses_epoch",
        "train_errors_epoch",
        "test_errors_epoch",
        "output_dir",
        "RANDOM_SEED",
        "num_epochs_trained",
        "epoch_durations",
    )
)

In [7]:
# Number of Monte Carlo forward passes per batch passed to test_outputs() in
# the final (clean and noisy) evaluation cells.
num_mc_for_detailed_eval = 100

In [None]:
# Error-rate curves per epoch. Test errors are drawn only for epochs that have
# a non-NaN value; if there are none, only the training curve is shown.
outfile = os.path.join(output_dir_final, 'bnn_error_rate_vs_epoch.png')
mask = ~np.isnan(test_errors_final)

plt.figure(figsize=(8, 5))
plt.plot(epochs_list_final, train_errors_final, '-', label='Training Error')
if np.any(mask):
    plt.plot(
        np.array(epochs_list_final)[mask],
        np.array(test_errors_final)[mask],
        '-',
        label='Test Error',
    )

plt.title('BNN Error Rate vs Epoch')
plt.xlabel('Epoch')
plt.ylabel('Error Rate (%)')
plt.grid(True)
plt.legend()

plt.tight_layout()
plt.savefig(outfile)
plt.show()

In [9]:
# Persist only the trained weights (the state_dict, not the whole module)
# inside this run's output directory.
model_save_path_final = os.path.join(output_dir_final, 'bnn_mpm_adasyn.pth')
torch.save(model_trained.state_dict(), model_save_path_final)

In [10]:
# Monte Carlo evaluation on the test loader, then cache every returned array
# as .npy so later cells can reload them without re-running inference.
true_labels, mean_probs, pred_labels, pred_variances = test_outputs(
    model_trained,
    test_loader_final,
    device_final,
    num_mc_samples=num_mc_for_detailed_eval,
)

for _npy_name, _npy_array in (
    ('bnn_true_labels.npy', true_labels),
    ('bnn_mean_probs.npy', mean_probs),
    ('bnn_pred_labels.npy', pred_labels),
    ('bnn_pred_variances.npy', pred_variances),
):
    np.save(os.path.join(output_dir_final, _npy_name), _npy_array)

Generating Test Outputs (MC Sampling): 100%|██████████| 10300/10300 [10:35<00:00, 16.21it/s]


In [11]:
# Misclassification rate in percent; defined as 0.0 when there are no labels.
if len(true_labels) > 0:
    final_test_error_rate = (1.0 - np.sum(pred_labels == true_labels) / len(true_labels)) * 100
else:
    final_test_error_rate = 0.0

# ROC curve from the mean predictive probabilities; the AUC falls back to 0.0
# when the curve has fewer than two points.
fpr, tpr, _ = roc_curve(true_labels, mean_probs)
if len(fpr) > 1 and len(tpr) > 1:
    roc_auc_score = auc(fpr, tpr)
else:
    roc_auc_score = 0.0

In [None]:
import glob

# Reload the cached evaluation arrays from disk so the cells below can run
# without repeating the (slow) Monte Carlo inference.
# The lexicographically last matching *directory* is used; fail with a clear
# message instead of a bare IndexError when no run directory exists yet.
_run_dirs = sorted(
    path for path in glob.glob("./bnn_mineral_outputs_run*") if os.path.isdir(path)
)
if not _run_dirs:
    raise FileNotFoundError(
        "No directory matching './bnn_mineral_outputs_run*' was found; "
        "run the training and test-output cells first."
    )
output_dir_final = _run_dirs[-1]

true_labels = np.load(os.path.join(output_dir_final, 'bnn_true_labels.npy'))
mean_probs = np.load(os.path.join(output_dir_final, 'bnn_mean_probs.npy'))
pred_labels = np.load(os.path.join(output_dir_final, 'bnn_pred_labels.npy'))
pred_variances = np.load(os.path.join(output_dir_final, 'bnn_pred_variances.npy'))

print(f"True labels shape: {true_labels.shape}")
print(f"Mean probabilities shape: {mean_probs.shape}")
print(f"Predicted labels shape: {pred_labels.shape}")
print(f"Predicted variances shape: {pred_variances.shape}")

True labels shape: (329588,)
Mean probabilities shape: (329588,)
Predicted labels shape: (329588,)
Predicted variances shape: (329588,)


In [None]:
# ROC curve of the mean predictive probabilities, saved as a PDF.
font_size = 14

fpr, tpr, roc_thresholds = roc_curve(true_labels, mean_probs)
roc_auc_score = auc(fpr, tpr)

plt.figure(figsize=(8,6))

# Model curve plus the chance-level diagonal for reference.
plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'BNN ROC (AUC = {roc_auc_score:.2f})')
plt.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--')

plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False positive rate', fontsize=font_size, labelpad=10)
plt.ylabel('True positive rate', fontsize=font_size, labelpad=10)
plt.tick_params(axis='both', which='major', labelsize=font_size)
plt.legend(loc="lower right", fontsize=font_size)

plt.tight_layout()

# The output directory may not exist when this cell runs on reloaded data.
os.makedirs(output_dir_final, exist_ok=True)
plt.savefig(os.path.join(output_dir_final, 'bnn_roc_curve_final.pdf'), dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Confusion matrix of the thresholded predictions, saved as a PDF.
cm = confusion_matrix(true_labels, pred_labels)
# Named cm_display (not `display`) so IPython's built-in display() is not
# shadowed for the rest of the notebook session.
cm_display = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Barren', 'Mineral'])
fig_cm, ax_cm = plt.subplots(figsize=(8, 6))
cm_display.plot(ax=ax_cm, cmap=plt.cm.Blues)

font_size = 14

ax_cm.set_xlabel('Predicted label', fontsize=font_size, labelpad=10)
ax_cm.set_ylabel('True label', fontsize=font_size, labelpad=10)

ax_cm.tick_params(axis='both', which='major', labelsize=font_size)

# Match the colorbar tick size to the axes, when a colorbar was drawn.
cbar = ax_cm.images[0].colorbar
if cbar is not None:
    cbar.ax.tick_params(labelsize=font_size)

# The per-cell count annotations are text artists on the axes.
for text in ax_cm.texts:
    text.set_fontsize(font_size)

plt.tight_layout()
os.makedirs(output_dir_final, exist_ok=True)
plt.savefig(os.path.join(output_dir_final, 'bnn_confusion_matrix_final.pdf'), dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Text report of per-class metrics for the thresholded predictions
# (pred_labels vs. true_labels).
# NOTE(review): zero_division=0 presumably reports 0 rather than warning when
# a class has no predictions — confirm against the scikit-learn docs.
print(classification_report(true_labels, pred_labels, target_names=['Class 0', 'Class 1'], zero_division=0))

In [None]:
from sklearn.metrics import roc_curve, auc, recall_score, confusion_matrix

# Uncertainty metrics
def g_mean(y_true, y_pred):
    """Geometric mean of sensitivity and specificity for binary 0/1 labels.

    Args:
        y_true: Ground-truth labels (0 or 1).
        y_pred: Predicted hard labels (0 or 1).

    Returns:
        ``sqrt(sensitivity * specificity)``; a rate whose denominator is zero
        (class absent from ``y_true``) is taken as 0.
    """
    # labels=[0, 1] forces a 2x2 matrix even when a (bootstrap) sample holds
    # only one class; without it the 4-way unpack below raises ValueError.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    sens = tp / (tp + fn) if (tp + fn) > 0 else 0
    spec = tn / (tn + fp) if (tn + fp) > 0 else 0
    return np.sqrt(sens * spec)

def bootstrap_ci(y_true, y_probs, metric_func, n_bootstrap=1000, seed=42):
    """Percentile-bootstrap estimate of a metric and its 95% interval.

    Draws ``n_bootstrap`` resamples (with replacement, same size as the input)
    and evaluates ``metric_func(y_true[idx], y_probs[idx])`` on each.

    Args:
        y_true: Sequence of ground-truth labels.
        y_probs: Sequence of predicted scores, aligned with ``y_true``.
        metric_func: Callable ``(y_true, y_probs) -> float``.
        n_bootstrap: Number of resamples.
        seed: Seed of the private random stream used for resampling.

    Returns:
        ``(mean, lower, upper)``: the mean of the bootstrap values and their
        2.5th / 97.5th percentiles.

    Raises:
        ValueError: If ``y_true`` and ``y_probs`` differ in length.
    """
    # Accept plain lists as well as arrays (fancy indexing needs ndarrays).
    y_true = np.asarray(y_true)
    y_probs = np.asarray(y_probs)
    if len(y_true) != len(y_probs):
        raise ValueError(
            f"y_true and y_probs must have the same length, "
            f"got {len(y_true)} and {len(y_probs)}"
        )

    # A private RandomState yields the same draws as np.random.seed(seed)
    # followed by np.random.choice, without resetting the global NumPy RNG.
    rng = np.random.RandomState(seed)
    n = len(y_true)
    vals = []
    for _ in range(n_bootstrap):
        idx = rng.choice(n, n, replace=True)
        vals.append(metric_func(y_true[idx], y_probs[idx]))
    vals = np.array(vals)
    mean_val = np.mean(vals)
    lower = np.percentile(vals, 2.5)
    upper = np.percentile(vals, 97.5)
    return mean_val, lower, upper

def roc_auc_func(y_true, y_probs):
    """Area under the ROC curve of scores ``y_probs`` against labels ``y_true``."""
    curve = roc_curve(y_true, y_probs)
    false_pos_rate, true_pos_rate = curve[0], curve[1]
    return auc(false_pos_rate, true_pos_rate)

def recall_func(y_true, y_probs):
    """Recall of the hard predictions obtained by thresholding ``y_probs`` at 0.5."""
    hard_predictions = (y_probs > 0.5).astype(int)
    return recall_score(y_true, hard_predictions, zero_division=0)

def gms_func(y_true, y_probs):
    """G-mean of the hard predictions obtained by thresholding ``y_probs`` at 0.5."""
    hard_predictions = (y_probs > 0.5).astype(int)
    return g_mean(y_true, hard_predictions)

# Bootstrap mean and 95% percentile interval for each headline metric.
roc_mean, roc_low, roc_up = bootstrap_ci(true_labels, mean_probs, metric_func=roc_auc_func)
rec_mean, rec_low, rec_up = bootstrap_ci(true_labels, mean_probs, metric_func=recall_func)
gms_mean, gms_low, gms_up = bootstrap_ci(true_labels, mean_probs, metric_func=gms_func)

# One line per metric; labels are pre-padded so the colons line up.
print("\n".join(
    f"{label} : {center:.3f} (95% CI: {lower:.3f} – {upper:.3f})"
    for label, (center, lower, upper) in (
        ("ROC-AUC", (roc_mean, roc_low, roc_up)),
        ("Recall ", (rec_mean, rec_low, rec_up)),
        ("G-Mean ", (gms_mean, gms_low, gms_up)),
    )
))

### Evaluation on noisy data

In [None]:
# Robustness check: re-evaluate the trained model on the test features
# perturbed with zero-mean Gaussian noise (std 0.1) and compare ROC AUC with
# the clean-data result.
DATA_DIR = './data'
RANDOM_STATE = RANDOM_SEED_final
BATCH_SIZE = 32 
NUM_MC_SAMPLES_EVAL = num_mc_for_detailed_eval 

# Cast to float32 exactly as the training function does: the model's weights
# are float32, so float64 features would fail inside the forward pass.
# map_location="cpu" keeps .numpy() working for tensors saved from a GPU.
X_test_np = torch.load(
    os.path.join(DATA_DIR, 'dataset_test.pt'), map_location="cpu"
).numpy().astype(np.float32)
y_test_np = torch.load(
    os.path.join(DATA_DIR, 'mineral_test.pt'), map_location="cpu"
).numpy().ravel()

np.random.seed(RANDOM_STATE) 
noise = np.random.normal(0, 0.1, X_test_np.shape).astype(np.float32)
X_test_noisy = X_test_np + noise

noisy_dataset = TensorDataset(torch.from_numpy(X_test_noisy), torch.from_numpy(y_test_np))
noisy_loader = DataLoader(noisy_dataset, batch_size=BATCH_SIZE, shuffle=False)

noisy_labels, noisy_probs, _, _ = test_outputs(
    model_trained, noisy_loader, device_final, num_mc_samples=NUM_MC_SAMPLES_EVAL)

# Baseline from the clean-data predictions computed earlier in the notebook.
fpr_clean, tpr_clean, _ = roc_curve(true_labels, mean_probs)
roc_auc_clean = auc(fpr_clean, tpr_clean)

fpr_noisy, tpr_noisy, _ = roc_curve(noisy_labels, noisy_probs)
roc_auc_noisy = auc(fpr_noisy, tpr_noisy)

print(f"Baseline ROC AUC on clean data: {roc_auc_clean:.4f}")
print(f"ROC AUC on noisy data:        {roc_auc_noisy:.4f}")
print(f"Performance Drop (AUC Drop):  {roc_auc_clean - roc_auc_noisy:.4f}")