In [1]:
import pyro
from pyro.optim import Adam  # type: ignore
from pyro.infer import SVI, Trace_ELBO, autoguide
from pyro.infer import MCMC, NUTS, HMC, Predictive
from pyro.nn import PyroModule, PyroSample
import pyro.distributions as dist
import os
import sys
import math
import numpy as np
import pandas as pd
from typing import List, Optional, Tuple

import torch
import torch.nn as nn


def set_seed(seed: int = 0):
    np.random.seed(seed)
    torch.manual_seed(seed)
    pyro.set_rng_seed(seed)


def get_MNIST_dataset():
    """
    Load and preprocess the MNIST dataset, storing train and test data as a single tensor file.
    Returns:
        train_images: Tensor of shape (60000, 784)
        train_labels: Tensor of shape (60000,)
        test_images: Tensor of shape (10000, 784)
        test_labels: Tensor of shape (10000,)
    """
    from torchvision import datasets, transforms

    dir = os.path.join(os.getcwd(), 'MNIST')
    # dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'MNIST')
    os.makedirs(dir, exist_ok=True)
    dataset_path = os.path.join(dir, 'mnist_data.pt')

    if not os.path.exists(dataset_path):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.view(-1))  # Flatten to 784-D
        ])

        train_data = datasets.MNIST(root=dir, train=True, download=True, transform=transform)
        test_data = datasets.MNIST(root=dir, train=False, download=True, transform=transform)

        print('Processing MNIST dataset...', end=' ')
        train_images = torch.stack([img for img, _ in train_data])  # (60000, 784)
        train_labels = torch.tensor([label for _, label in train_data])  # (60000,)
        test_images = torch.stack([img for img, _ in test_data])  # (10000, 784)
        test_labels = torch.tensor([label for _, label in test_data])  # (10000,)

        torch.save({
            'train_images': train_images,
            'train_labels': train_labels,
            'test_images': test_images,
            'test_labels': test_labels
        }, dataset_path)
    else:
        data = torch.load(dataset_path)
        train_images = data['train_images']
        train_labels = data['train_labels']
        test_images = data['test_images']
        test_labels = data['test_labels']

    print('Done.')
    return train_images, train_labels, test_images, test_labels


def get_digits_dataset(test_size=0.2, random_state=42):
    """
    Load and preprocess the Digits dataset (8x8 images, 64 features).
    Splits into train/test sets.

    Returns:
        train_images: Tensor of shape (N_train, 64), float32 normalized [0, 1]
        train_labels: Tensor of shape (N_train,), int64
        test_images: Tensor of shape (N_test, 64), float32 normalized [0, 1]
        test_labels: Tensor of shape (N_test,), int64
    """
    from sklearn.datasets import load_digits
    from sklearn.model_selection import train_test_split
    digits = load_digits()
    images = torch.tensor(digits.images).float()  # type: ignore # shape (1797, 8, 8)
    labels = torch.tensor(digits.target).long()   # type: ignore # shape (1797,)

    # normalize 0~16 to 0~1
    images /= 16.0

    # flatten
    images = images.view(-1, 64)

    train_images, test_images, train_labels, test_labels = train_test_split(
        images, labels, test_size=test_size, random_state=random_state, stratify=labels)

    return train_images, train_labels, test_images, test_labels


def get_NonLinear_dataset(
    n_samples=10000,
    n_features=10,
    n_classes=10,
    n_clusters_per_class=3,
    n_informative=5,
    class_sep=2.0,
    nonlinear_strength=0.6,
):
    from sklearn.datasets import make_classification
    from sklearn.model_selection import train_test_split

    assert n_informative < n_features
    assert (n_classes * n_clusters_per_class) <= 2**n_informative

    class_sep = class_sep
    nonlinear_strength = nonlinear_strength

    # Generate data with sklearn (numpy)
    X_np, y_np = make_classification(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=n_informative,
        n_redundant=0,
        n_classes=n_classes,
        n_clusters_per_class=n_clusters_per_class,
        class_sep=class_sep,
        random_state=42
    )

    # Convert to torch tensors (float for X, long for y)
    X = torch.tensor(X_np, dtype=torch.float32)
    y = torch.tensor(y_np, dtype=torch.long)

    # Compute theta and radius (torch)
    theta = torch.atan2(X[:, 1], X[:, 0])
    radius = torch.sqrt(X[:, 0] ** 2 + X[:, 1] ** 2)

    # Nonlinear transformation class-wise
    for i in range(n_classes):
        mask = (y == i)
        freq = (i + 1) * 1.5

        # Index mask for all rows where y==i
        idx = mask.nonzero(as_tuple=True)[0]

        X[idx, 0] = X[idx, 0] + nonlinear_strength * torch.cos(freq * theta[idx]) * radius[idx]
        X[idx, 1] = X[idx, 1] + nonlinear_strength * torch.sin(freq * theta[idx]) * radius[idx]
        X[idx, 2] = X[idx, 2] + nonlinear_strength * (torch.sin(freq * theta[idx]) + 0.1 * radius[idx] ** 2)

    # Split train/test using sklearn but convert back to tensors
    from sklearn.model_selection import train_test_split

    train_X_np, test_X_np, train_y_np, test_y_np = train_test_split(
        X.numpy(), y.numpy(), test_size=0.2, stratify=y.numpy()
    )

    train_X = torch.tensor(train_X_np, dtype=torch.float32)
    test_X = torch.tensor(test_X_np, dtype=torch.float32)
    train_y = torch.tensor(train_y_np, dtype=torch.long)
    test_y = torch.tensor(test_y_np, dtype=torch.long)

    return train_X, train_y, test_X, test_y


def get_stats(task_type: str, pred_y, pred_uncertainty, true_y):
    from sklearn.metrics import (
        accuracy_score, confusion_matrix, precision_recall_fscore_support,
        mean_squared_error, mean_absolute_error, r2_score, explained_variance_score
    )
    pred_y = pred_y.cpu().numpy()
    uncertainty = pred_uncertainty.cpu().numpy()
    true_y = true_y.cpu().numpy()

    print("\n===== Test Statistics =====")
    print(f"Task Type: {task_type}")

    if task_type == "classification":
        # Core metrics
        acc = accuracy_score(true_y, pred_y)
        precision, recall, f1, support = precision_recall_fscore_support(true_y, pred_y, average='weighted', zero_division=0)
        cm = confusion_matrix(true_y, pred_y)

        print(f"Accuracy: {acc:.4f}")
        print(f"Precision (weighted): {precision:.4f}")
        print(f"Recall (weighted):    {recall:.4f}")
        print(f"F1 Score (weighted):  {f1:.4f}")
        print("Confusion Matrix (rows: true, cols: pred):")
        print(cm)

        # Uncertainty
        print(f"Mean Uncertainty (normalized entropy): {uncertainty.mean():.4f}")
        print(f"Min/Max Uncertainty: {uncertainty.min():.4f} / {uncertainty.max():.4f}")
        print(f"Uncertainty StdDev:  {np.std(uncertainty):.4f}")

        # Per-class stats
        classes = np.unique(true_y)
        for cls in classes:
            idx = (true_y == cls)
            cls_acc = accuracy_score(true_y[idx], pred_y[idx])
            cls_unc = uncertainty[idx].mean()
            print(f"Class {cls}: Accuracy={cls_acc:.4f}, Mean Uncertainty={cls_unc:.4f}, Support={idx.sum()}")

        # Misclassification stats
        mis_idx = (true_y != pred_y)
        mis_rate = mis_idx.mean()
        mis_unc = uncertainty[mis_idx].mean() if mis_idx.any() else 0.0
        print(f"Misclassification Rate: {mis_rate:.4f}")
        print(f"Mean Uncertainty on Misclassified: {mis_unc:.4f}")

    elif task_type == "regression":
        pred_y = pred_y.squeeze()
        true_y = true_y.squeeze()
        uncertainty = uncertainty.squeeze()

        mse = mean_squared_error(true_y, pred_y)
        rmse = np.sqrt(mse)
        mae = mean_absolute_error(true_y, pred_y)
        r2 = r2_score(true_y, pred_y)
        evs = explained_variance_score(true_y, pred_y)

        print(f"MSE:  {mse:.6f}")
        print(f"RMSE: {rmse:.6f}")
        print(f"MAE:  {mae:.6f}")
        print(f"R² Score: {r2:.4f}")
        print(f"Explained Variance Score: {evs:.4f}")

        # Uncertainty
        print(f"Mean Predictive Std (uncertainty): {uncertainty.mean():.4f}")
        print(f"Min/Max Uncertainty: {uncertainty.min():.4f} / {uncertainty.max():.4f}")
        print(f"Uncertainty StdDev:  {uncertainty.std():.4f}")

        # Residuals
        residuals = pred_y - true_y
        abs_residuals = np.abs(residuals)
        print(f"Mean Absolute Residual: {abs_residuals.mean():.4f}")
        print(f"Residual StdDev:        {residuals.std():.4f}")

        # Correlation between uncertainty and residuals
        corr = np.corrcoef(uncertainty, abs_residuals)[0, 1]
        print(f"Correlation (|residual| vs uncertainty): {corr:.4f}")

    else:
        raise ValueError(f"Unknown task_type: {task_type}")


def plot_3d_classification(X, true_y, pred_y, uncertainty_y=None, title="NN Probabilistic Prediction"):
    import plotly.express as px
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    from sklearn.decomposition import PCA
    from sklearn.manifold import TSNE
    import umap

    import warnings
    warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")

    X = X.cpu().numpy()
    pred_y = pred_y.cpu().numpy()
    true_y = true_y.cpu().numpy()
    if uncertainty_y is not None:
        uncertainty = uncertainty_y.cpu().numpy()
        # mu = np.mean(uncertainty)
        # sigma = np.std(uncertainty)
        # uncertainty = (uncertainty - mu) / sigma * 0.3 + 0.5 # 90% data between 0 and 1

    # Dimensionality reduction if input is high-dimensional
    print(f'Performing Dimensionality Reduction(t-SNE/uMAP) Before Plotting({X.shape[1]}D -> 3D)...', end=' ')
    if X.shape[1] > 3:
        # algo_pca = PCA(n_components=3)
        # X_3d = algo_pca.fit_transform(X)
        # algo_tsne = TSNE(n_components=3, max_iter=1000, init='pca', random_state=42)
        # X_3d = algo_tsne.fit_transform(X)
        algo_umap = umap.UMAP(n_components=3, n_neighbors=40, min_dist=0.1)
        X_3d = np.array(algo_umap.fit_transform(X))
    else:
        # Already 3 or fewer dimensions
        X_3d = X if X.shape[1] == 3 else np.pad(X, ((0, 0), (0, 3 - X.shape[1])), mode='constant')
    print('Done.')

    classes = np.unique(pred_y)
    colors = px.colors.qualitative.Set1
    # fig.add_trace(go.Scatter(y=uncertainty))

    target = {
        'true_y': [true_y, False],
        'pred_y': [pred_y, True],
    }

    fig = make_subplots(
        rows=1, cols=2,
        specs=[[{'type': 'scene'}, {'type': 'scene'}]],
        subplot_titles=("True_y clusters (t-SNE/uMAP)", "Pred_y clusters (t-SNE/uMAP) with Uncertainty Measure")
    )

    for col, (target_key, target_value) in enumerate(target.items(), start=1):  # col = 1, 2
        for i, cls in enumerate(classes):
            idx = target_value[0] == cls
            fig.add_trace(go.Scatter3d(
                x=X_3d[idx, 0], y=X_3d[idx, 1], z=X_3d[idx, 2],
                mode='markers', name=f'{target_key}: Class {cls}',
                marker=dict(  # uncertain = decision boundary = more visible
                    size=4 if (uncertainty is None) or (not target_value[1]) else 3 + 10 * uncertainty[idx],
                    color=colors[i % len(colors)],
                    opacity=1.0,)), row=1, col=col)

    fig.update_layout(height=700, width=1400)
    fig.update_scenes(dict(xaxis_title="X1", yaxis_title="X2", zaxis_title="X3",), row=1, col=1)
    fig.update_scenes(dict(xaxis_title="X1", yaxis_title="X2", zaxis_title="X3",), row=1, col=2)

    fig.show()


def apply_PCA(train_X, test_X, explained_var_threshold=0.98):
    from sklearn.decomposition import PCA
    # Convert to numpy if needed
    if isinstance(train_X, torch.Tensor):
        train_X_np = train_X.numpy()
    else:
        train_X_np = train_X
    if isinstance(test_X, torch.Tensor):
        test_X_np = test_X.numpy()
    else:
        test_X_np = test_X

    # First run PCA without limiting components to get full variance profile
    pca_full = PCA()
    pca_full.fit(train_X_np)
    cumulative_variance = np.cumsum(pca_full.explained_variance_ratio_)
    n_components = np.searchsorted(cumulative_variance, explained_var_threshold) + 1

    print(f"[PCA] Selected {n_components} components to preserve {explained_var_threshold*100:.1f}% variance.")

    # # Optional: Print cumulative variance for reference
    # print("[PCA] Cumulative explained variance (first 20 components):")
    # for i in range(min(20, len(cumulative_variance))):
    #     print(f"  Component {i+1}: {cumulative_variance[i]:.4f}")

    # Now run PCA with the selected number of components
    pca = PCA(n_components=n_components)
    train_X_pca = pca.fit_transform(train_X_np)
    test_X_pca = pca.transform(test_X_np)

    from sklearn.preprocessing import StandardScaler
    scaler = StandardScaler()
    train_X_pca_norm = scaler.fit_transform(train_X_pca)
    test_X_pca_norm = scaler.transform(test_X_pca)

    return (
        torch.tensor(train_X_pca_norm, dtype=torch.float32),
        torch.tensor(test_X_pca_norm, dtype=torch.float32)
    )


class BNN(PyroModule):
    def __init__(self, in_dim=1, out_dim=1, hid_dims=[5, 5], hid_layers=2, prior_scale=5., task_type="classification"):
        super().__init__()
        self.task_type = task_type  # "regression" or "classification"
        self.activation = nn.Tanh()  # nn.ReLU()
        assert in_dim > 0 and out_dim > 0 and all(hid_dims) > 0 and hid_layers > 0  # make sure the dimensions are valid
        self.layer_sizes = [in_dim] + hid_layers * hid_dims + [out_dim]
        layer_list = [PyroModule[nn.Linear](self.layer_sizes[idx - 1], self.layer_sizes[idx]) for idx in range(1, len(self.layer_sizes))]  # type: ignore
        self.layers = PyroModule[torch.nn.ModuleList](layer_list)  # type: ignore

        for i, layer in enumerate(self.layers):  # type: ignore
            in_f = self.layer_sizes[i]
            out_f = self.layer_sizes[i+1]
            layer.weight = PyroSample(dist.Normal(0., prior_scale * np.sqrt(2 / in_f)).expand([out_f, in_f]).to_event(2))
            layer.bias = PyroSample(dist.Normal(0., prior_scale).expand([out_f]).to_event(1))

        # normalization layers (e.g. BatchNorm) are usually avoided in BNNs
        # 0. uncertainty and variability in weights are already regularizing
        # 1. break the i.i.d. assumption inside the plate.
        # 2. introduce non-local statistics (mean/var across batch), which conflicts with Bayesian weight sampling
        # 3. not compatible with Pyro's plate() mechanism

    def model(self, x, y=None):  # sampling likelihood as forward pass
        # Priors: pyro.sample without obs
        # Likelihood: pyro.sample with obs
        # How it works: Potential Energy: U(θ)=−logp(data∣θ)-logp(θ)
        #   1. the forward pass is the sampling process of the likelihood p(data|θ) = p(y∣X,θ). After Prior p(θ) is sampled, the likelihood can be sampled too
        #   2. then the potential, kinetic, diffussion terms of this Markov Chain can be calculated over latent space of θ to model true Posterior p(θ∣data)
        #   3. the forward pass will be sampled many times over random variables θ as Markov Chain drifting in latent space of θ during training as its track converge to the true posterior
        #   4. then during inference, the forward pass will be evaluated many times over true posterior of theta, to get mean and uncertainty measure as outputs
        # y == None: inference mode, else: training mode

        # [batch_size, features] -> column vector
        out = x.view(-1, self.layer_sizes[0])  # Reshape input to (batch_size, in_dim)
        for layer in self.layers[:-1]:  # type: ignore
            out = self.activation(layer(out))
        out = self.layers[-1](out)  # type: ignore (batch_size, out_dim)

        # even the input features are dependent, and model can capture the dependencies (e.g. LSTM),
        # as long as the outputs of such model has no explicit temporal structure, they should still be treated as i.i.d samples
        with pyro.plate("data", out.shape[0]):  # samples in batch is i.i.d (enable parallelism)
            # epistemic uncertainty (modeled by the Bayesian weights/bias), due to finite data
            # aleatoric uncertainty (modeled by the response noise), due to inherent noise in observations (avoid overfitting the noise)
            # during training, conditional likelihood is use in MCMC/VI
            # during inference, we unconditionally sample from likelihood directly (special treatment for classification)
            if self.task_type == "regression":
                sigma = pyro.sample("sigma", dist.Uniform(0., 1.))  # infer the response noise Gamma(.5, 1)
                pyro.sample("obs", dist.Normal(out.squeeze(-1), 1. / sigma.sqrt()), obs=y)
                return
            elif self.task_type == "classification":
                pyro.sample("obs", dist.Categorical(logits=out), obs=y)  # Bayesian NN outputs are logits for numerical stability
                # classification likelihood is Categorical, we cannot drawing samples from it to compute mean/std
                # instead, we need to return NN output logits and operate on it directly, note that this does not lose the likelihood information
                # after returning it, we can refer to it again as "_RETURN" in predictive sampling
                return out

    def fit_VI(self, x, y, num_steps=1000, lr=1e-2):
        self.model_type = 'VI'
        pyro.clear_param_store()
        # self.guide = autoguide.AutoMultivariateNormal(self.model) # Structured VI
        self.guide = autoguide.AutoDiagonalNormal(self.model)  # Mean-Field VI
        optimizer = Adam({"lr": lr})
        svi = SVI(self.model, self.guide, optimizer, loss=Trace_ELBO())

        self.loss_history = []
        for step in range(num_steps):
            loss = svi.step(x, y)
            self.loss_history.append(loss)
            if step % 100 == 0:
                print(f"[{step:04d}] ELBO loss: {loss:.4f}")
        return self.guide, self.loss_history

    def fit_MCMC(self, x, y, num_samples=1000):
        self.model_type = 'MCMC'
        jit_compile = False
        pyro.clear_param_store()
        print("Running MCMC...", end=' ')
        kernel = NUTS(self.model, jit_compile=jit_compile)
        # kernel = HMC(self.model, step_size=0.1, num_steps=10, jit_compile=jit_compile)
        self.mcmc = MCMC(kernel, num_samples=num_samples, warmup_steps=int(0.3*num_samples))
        self.mcmc.run(x, y)
        # MCMC trajectories used as lossless True Posterior, no resampling/interpolation needed, simply feed random track into
        self.posterior_samples = self.mcmc.get_samples()
        print("Done.")
        return self.posterior_samples

    def predict(self, x, num_samples=100):
        # Monte Carlo estimate of predictive mean and std from posterior
        if self.model_type == 'MCMC':
            guide = None
            posterior_samples = self.posterior_samples
            parallel = False
        elif self.model_type == 'VI':
            guide = self.guide
            posterior_samples = None
            parallel = False
        else:
            assert False, f"Unsupported model type:{self.model_type}: Use 'MCMC' or 'VI'."

        if self.task_type == 'regression':
            return_sites = ["obs"]
        elif self.task_type == 'classification':
            return_sites = ["_RETURN"]
        else:
            assert False, f"Unsupported task type:{self.task_type}: Use 'regression' or 'classification'."

        predictive = Predictive(self.model, posterior_samples=posterior_samples, guide=guide, num_samples=num_samples, return_sites=return_sites, parallel=parallel)

        if self.task_type == "regression":
            # Predictive mean and std over real-valued output
            samples = predictive(x)["obs"].float()  # [num_samples, batch_size]
            pred_mean = samples.mean(0)             # [batch_size]
            pred_std = samples.std(0)               # [batch_size]
            pred_value = pred_mean
            pred_uncertainty = pred_std
        elif self.task_type == "classification":
            # Predictive mean and std over logits(prob) over multiple classes
            logits = predictive(x)["_RETURN"]       # [num_samples, batch_size, num_classes]
            probs = torch.softmax(logits, dim=-1)   # [num_samples, batch_size, num_classes]
            probs_mean = probs.mean(0)              # [batch_size, num_classes]
            probs_std = probs.std(0)                # [batch_size, num_classes]
            pred_value = probs_mean.argmax(dim=-1)  # [batch_size]

            # for classification, consider these as uncertainty measure:
            #   1.Predictive Entropy (Total uncertainty)
            entropy = -(probs_mean * probs_mean.clamp(min=1e-8).log()).sum(dim=-1)  # [batch_size]
            # #   2.Mutual_Info(BALD) (Model uncertainty)
            # expected_entropy = -(probs * probs.clamp(min=1e-8).log()).sum(dim=-1).mean(0)  # [batch_size]
            # mutual_info = entropy - expected_entropy  # [batch_size]

            pred_uncertainty = entropy/math.log(probs_mean.size(-1))  # entropy in [0,log_e(num_classes)]

        return pred_value, pred_uncertainty


if __name__ == '__main__':
    SEED = 123456
    set_seed(SEED)

In [None]:
# BNN-MCMC-NUTS with no PCA on non-Linear-10D dataset
# even though it faithfully model the true posterior
# the computational cost is just unbearable
# after warm, if we are lucky and MCMC track stays in high density region of latent space, model is quickly fit

task_type = "classification"  # "regression" or "classification"
hid_dims = [30, 20]
prior_scale = 5.0
num_train = 20  # samples/steps
num_test = 100  # samples

print('Getting Dataset...', end=' ')
# train_X, train_y, test_X, test_y = get_MNIST_dataset()
# train_X, train_y, test_X, test_y = get_digits_dataset()
train_X, train_y, test_X, test_y = get_NonLinear_dataset(
    n_samples=50000,
    n_features=10,
    n_classes=10,
    n_clusters_per_class=3,
    n_informative=5,
    class_sep=2.0,
    nonlinear_strength=0.6,
    )
print(f'{train_X.shape[1]}D')
# train_X, test_X = apply_PCA(train_X, test_X)
bnn = BNN(in_dim=train_X.shape[1], out_dim=len(torch.unique(train_y)), hid_dims=hid_dims, hid_layers=len(hid_dims), prior_scale=prior_scale, task_type=task_type)
bnn.fit_MCMC(train_X, train_y, num_samples=num_train)
# bnn.fit_VI(train_X, train_y, num_steps=num_train, lr=1e-3)
print('Training Done.')
pred_y, pred_uncertainty = bnn.predict(test_X, num_samples=min(num_train, num_test))
print('Prediction Done.')
get_stats(task_type, pred_y, pred_uncertainty, test_y)
plot_3d_classification(test_X, test_y, pred_y, pred_uncertainty)

Getting Dataset... 10D
Running MCMC... 

Sample: 100%|██████████| 26/26 [01:30,  3.50s/it, step size=4.23e-03, acc. prob=0.638]


Done.
Training Done.
Prediction Done.

===== Test Statistics =====
Task Type: classification
Accuracy: 0.6322
Precision (weighted): 0.6313
Recall (weighted):    0.6322
F1 Score (weighted):  0.6293
Confusion Matrix (rows: true, cols: pred):
[[822   5  29  22   7  56  13  11   1  34]
 [  4 674   8 101  49  16  28  51  45  25]
 [158  15 576  47  16  59  72  21  23  12]
 [ 41  92  41 659  23  26  14  33  35  38]
 [ 15  57  17  24 640  78  51  59  49  12]
 [ 74  54  40  30 103 578  32  23  24  42]
 [ 14  18  30  23  47  60 554  13 102 138]
 [ 21  89  45  20 121  28  13 539  75  50]
 [ 13  34  10  51  19  12  40  37 765  14]
 [ 75  19  12  25  17  22 152 134  30 515]]
Mean Uncertainty (normalized entropy): 0.7091
Min/Max Uncertainty: 0.1090 / 0.9793
Uncertainty StdDev:  0.1260
Class 0: Accuracy=0.8220, Mean Uncertainty=0.6445, Support=1000
Class 1: Accuracy=0.6733, Mean Uncertainty=0.7172, Support=1001
Class 2: Accuracy=0.5766, Mean Uncertainty=0.7232, Support=999
Class 3: Accuracy=0.6577, M

In [None]:
# BNN-VI-Mean-Field with no PCA on non-Linear-10D dataset

task_type = "classification"  # "regression" or "classification"
hid_dims = [30, 20]
prior_scale = 5.0
num_train = 500  # samples/steps
num_test = 100  # samples

print('Getting Dataset...', end=' ')
# train_X, train_y, test_X, test_y = get_MNIST_dataset()
# train_X, train_y, test_X, test_y = get_digits_dataset()
train_X, train_y, test_X, test_y = get_NonLinear_dataset(
    n_samples=50000,
    n_features=10,
    n_classes=10,
    n_clusters_per_class=3,
    n_informative=5,
    class_sep=2.0,
    nonlinear_strength=0.6,
    )
print(f'{train_X.shape[1]}D')
# train_X, test_X = apply_PCA(train_X, test_X)
bnn = BNN(in_dim=train_X.shape[1], out_dim=len(torch.unique(train_y)), hid_dims=hid_dims, hid_layers=len(hid_dims), prior_scale=prior_scale, task_type=task_type)
# bnn.fit_MCMC(train_X, train_y, num_samples=num_train)
bnn.fit_VI(train_X, train_y, num_steps=num_train, lr=1e-3)
print('Training Done.')
pred_y, pred_uncertainty = bnn.predict(test_X, num_samples=min(num_train, num_test))
print('Prediction Done.')
get_stats(task_type, pred_y, pred_uncertainty, test_y)
plot_3d_classification(test_X, test_y, pred_y, pred_uncertainty)

Getting Dataset... 10D
[0000] ELBO loss: 154609.8545
[0100] ELBO loss: 96359.7597
[0200] ELBO loss: 90556.3290
[0300] ELBO loss: 87349.4766
[0400] ELBO loss: 78035.4368
Training Done.
Prediction Done.

===== Test Statistics =====
Task Type: classification
Accuracy: 0.4836
Precision (weighted): 0.4923
Recall (weighted):    0.4836
F1 Score (weighted):  0.4648
Confusion Matrix (rows: true, cols: pred):
[[873   8   1  11   7  26  24   2  11  37]
 [ 12 688  25  46  61  11  40  16  84  18]
 [223  74 225  12  63  53 108  11 211  19]
 [111 269   5 382  27  12  34  45  59  58]
 [  7 142   2  13 411 101 121 130  54  21]
 [228  48 106  60 154 210  74  12  42  66]
 [ 23  29   5   9  21  35 637  14 123 103]
 [ 18 143   3  30 119  52  20 474  66  76]
 [ 21 104   4  15  44   1  49  69 559 129]
 [138  34   2   4  13  42 146 173  72 377]]
Mean Uncertainty (normalized entropy): 0.8137
Min/Max Uncertainty: 0.3280 / 0.9862
Uncertainty StdDev:  0.1077
Class 0: Accuracy=0.8730, Mean Uncertainty=0.6332, Supp

In [None]:
# BNN-VI-Mean-Field with PCA on non-Linear-10D dataset

task_type = "classification"  # "regression" or "classification"
hid_dims = [30, 20]
prior_scale = 5.0
num_train = 500  # samples/steps
num_test = 100  # samples

print('Getting Dataset...', end=' ')
# train_X, train_y, test_X, test_y = get_MNIST_dataset()
# train_X, train_y, test_X, test_y = get_digits_dataset()
train_X, train_y, test_X, test_y = get_NonLinear_dataset(
    n_samples=50000,
    n_features=10,
    n_classes=10,
    n_clusters_per_class=3,
    n_informative=5,
    class_sep=2.0,
    nonlinear_strength=0.6,
    )
print(f'{train_X.shape[1]}D')
train_X, test_X = apply_PCA(train_X, test_X, explained_var_threshold=0.9)
bnn = BNN(in_dim=train_X.shape[1], out_dim=len(torch.unique(train_y)), hid_dims=hid_dims, hid_layers=len(hid_dims), prior_scale=prior_scale, task_type=task_type)
# bnn.fit_MCMC(train_X, train_y, num_samples=num_train)
bnn.fit_VI(train_X, train_y, num_steps=num_train, lr=1e-3)
print('Training Done.')
pred_y, pred_uncertainty = bnn.predict(test_X, num_samples=min(num_train, num_test))
print('Prediction Done.')
get_stats(task_type, pred_y, pred_uncertainty, test_y)
plot_3d_classification(test_X, test_y, pred_y, pred_uncertainty)

Getting Dataset... 10D
[PCA] Selected 7 components to preserve 90.0% variance.
[0000] ELBO loss: 185837.6703
[0100] ELBO loss: 104508.9062
[0200] ELBO loss: 99957.4846
[0300] ELBO loss: 94363.6365
[0400] ELBO loss: 84365.8256
Training Done.
Prediction Done.

===== Test Statistics =====
Task Type: classification
Accuracy: 0.4604
Precision (weighted): 0.4683
Recall (weighted):    0.4604
F1 Score (weighted):  0.4425
Confusion Matrix (rows: true, cols: pred):
[[854   1  35  10  14   2  15   1  55  13]
 [  5 504  76 122 118   7  51  18  66  34]
 [190   7 479 103  38  37  61  16  54  14]
 [ 49  75  67 607  23   7  20  13 110  31]
 [ 33 249  46   1 425   9 113  63  50  13]
 [118  45 222  71 161 179  87  48  28  41]
 [ 15  23  12  16  23  34 498  15 118 245]
 [ 51 217 104  28 193  12  39 187  17 153]
 [ 63 113 117  22  50   2  76  11 537   4]
 [146  86  90  24  11   7 188  48  67 334]]
Mean Uncertainty (normalized entropy): 0.8592
Min/Max Uncertainty: 0.3393 / 0.9906
Uncertainty StdDev:  0.092