In [None]:
import os
import argparse
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split

from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torchvision.transforms as transforms
from torchvision.datasets import MNIST

# Default compute device: the first CUDA GPU when one is visible, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
def set_seed(seed: int) -> None:
    """Seed every RNG the notebook uses and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's legacy global RNG, PyTorch's CPU RNG and
    all CUDA device RNGs (in that order). It also disables cuDNN autotuning so
    that repeated runs pick the same kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


class ConvNN(nn.Module):
    """Small two-convolution CNN classifier with a 10-way linear output.

    Pipeline: conv1 -> ReLU -> conv2 -> ReLU -> 2x2 max-pool -> dropout(0.25)
    -> flatten -> fc1 -> ReLU -> dropout(0.5) -> fc2 (10 outputs).
    Both convolutions are unpadded with stride 1 and take single-channel input.
    """

    def __init__(
        self,
        num_filters: int = 32,
        kernel_size: int = 4,
        dense_layer: int = 128,
        img_rows: int = 28,
        img_cols: int = 28,
    ) -> None:
        super().__init__()
        # Registration order is kept fixed so parameter initialisation consumes
        # the global RNG in the same order (and state_dict keys stay stable).
        self.conv1 = nn.Conv2d(1, num_filters, kernel_size, stride=1)
        self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size, stride=1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # Each unpadded conv removes (kernel_size - 1) pixels per side; the
        # 2x2 max-pool then halves (floor) the remaining extent.
        pooled_h, pooled_w = (
            (side - 2 * (kernel_size - 1)) // 2 for side in (img_rows, img_cols)
        )
        self.fc1 = nn.Linear(num_filters * pooled_h * pooled_w, dense_layer)
        self.fc2 = nn.Linear(dense_layer, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, 1, img_rows, img_cols) image batch to (batch, 10) logits."""
        conv_out = F.relu(self.conv2(F.relu(self.conv1(x))))
        pooled = self.dropout1(F.max_pool2d(conv_out, 2))
        hidden = F.relu(self.fc1(torch.flatten(pooled, 1)))
        return self.fc2(self.dropout2(hidden))


class ConvNNBackbone(nn.Module):
    """CNN feature extractor that supplies basis functions to the Bayesian heads.

    Layout: conv1 -> ReLU -> conv2 -> ReLU -> 2x2 max-pool -> dropout(0.25)
    -> flatten -> fc1 -> ReLU -> dropout(0.5) -> fc2.

    ``forward`` returns the dropout-applied fc1 activations (width
    ``dense_layer``). When ``use_fc2_features`` is True it returns the fc2 outputs
    (width ``num_classes``) instead. ``feature_dim`` records the width of
    whichever output ``forward`` produces.
    """

    def __init__(self,
                 num_filters: int = 32,
                 kernel_size: int = 4,
                 img_rows: int = 28,
                 img_cols: int = 28,
                 dense_layer: int = 128,
                 use_fc2_features: bool = False,
                 num_classes: int = 10) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, num_filters, kernel_size, stride=1)
        self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size, stride=1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # Two unpadded convs shrink each side by 2 * (kernel_size - 1); the
        # 2x2 max-pool then halves (floor) what is left.
        pooled_h = (img_rows - 2 * kernel_size + 2) // 2
        pooled_w = (img_cols - 2 * kernel_size + 2) // 2
        self.fc1 = nn.Linear(num_filters * pooled_h * pooled_w, dense_layer)
        self.fc2 = nn.Linear(dense_layer, num_classes)
        self.use_fc2_features = use_fc2_features
        self.feature_dim = num_classes if use_fc2_features else dense_layer

    def _fc1_activations(self, x: torch.Tensor) -> torch.Tensor:
        """Shared trunk: conv/pool/dropout1 stack followed by ReLU(fc1), no final dropout."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        return F.relu(self.fc1(x))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return fc2 outputs if ``use_fc2_features`` else dropout-applied fc1 activations."""
        x = self.dropout2(self._fc1_activations(x))
        return self.fc2(x) if self.use_fc2_features else x

    def get_fc1_features(self, x: torch.Tensor) -> torch.Tensor:
        """fc1 activations (width ``dense_layer``), regardless of ``use_fc2_features``.

        Unlike ``forward``, dropout2 is not applied here.
        """
        return self._fc1_activations(x)

    def get_logits(self, x: torch.Tensor) -> torch.Tensor:
        """Full network output (width ``num_classes``), regardless of ``use_fc2_features``."""
        return self.fc2(self.dropout2(self._fc1_activations(x)))

    def freeze(self) -> None:
        """Disable gradients for every parameter."""
        for p in self.parameters():
            p.requires_grad = False

    def unfreeze(self) -> None:
        """Enable gradients for every parameter."""
        for p in self.parameters():
            p.requires_grad = True

    def freeze_except_fc2(self) -> None:
        """Freeze everything except the fc2 layer (its weight and bias stay trainable)."""
        for name, p in self.named_parameters():
            p.requires_grad = 'fc2' in name

class HierarchicalBayesModel(nn.Module):
    """Base class for hierarchical parametrized basis-function regression models.

    A CNN backbone (``ConvNNBackbone``) supplies basis functions phi(x).
    Subclasses put a Bayesian linear head on top and implement ``forward``,
    ``predict_with_uncertainty`` and ``compute_predictive_variance``.

    ``nn.Module`` does not use ``ABCMeta``, so the ``@abstractmethod`` markers
    below are not enforced at instantiation. The base implementations therefore
    raise ``NotImplementedError`` rather than silently returning None.
    """
    def __init__(
        self,
        feature_extractor: Optional[ConvNNBackbone],
        num_filters: int = 32,
        kernel_size: int = 4,
        dense_layer: int = 128,
        img_rows: int = 28,
        img_cols: int = 28,
        num_outputs: int = 10,
        use_fc2: bool = False,
    ) -> None:
        super().__init__()
        if feature_extractor is not None:
            self.feature_extractor = feature_extractor
        else:
            self.feature_extractor = ConvNNBackbone(
                num_filters=num_filters,
                kernel_size=kernel_size,
                dense_layer=dense_layer,
                img_rows=img_rows,
                img_cols=img_cols,
                use_fc2_features=use_fc2
            )
        self.num_outputs = num_outputs
        self.use_fc2 = use_fc2

    def get_basis_func(self, x: torch.Tensor) -> torch.Tensor:
        """Return the basis-function matrix phi(x) with one row per input.

        With ``use_fc2`` the backbone's fc2 outputs are used; otherwise its fc1
        activations (without the final dropout). The explicit backbone accessors
        are called so that the backbone's own ``use_fc2_features`` flag cannot
        override this model's ``use_fc2`` choice when an external backbone
        is supplied.
        """
        if self.use_fc2:
            return self.feature_extractor.get_logits(x)
        return self.feature_extractor.get_fc1_features(x)

    def freeze_basis_functions(self) -> None:
        """Disable gradients for all backbone parameters."""
        self.feature_extractor.freeze()

    def unfreeze_basis_functions(self) -> None:
        """Enable gradients for all backbone parameters."""
        self.feature_extractor.unfreeze()

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass returning predictions. Must be implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} must implement forward()")

    @abstractmethod
    def predict_with_uncertainty(
        self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
        """Return predictions and predictive variances."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement predict_with_uncertainty()"
        )

    @abstractmethod
    def compute_predictive_variance(self, x: np.ndarray) -> np.ndarray:
        """Return one predictive variance per sample of the array ``x``."""
        raise NotImplementedError(
            f"{type(self).__name__} must implement compute_predictive_variance()"
        )



class Analytical_HB(HierarchicalBayesModel):
    """Exact (closed-form) Bayesian linear-regression head on CNN basis functions.

    Model: Y = Phi W + E, with prior N(0, prior_variance * I) on each column of
    W and i.i.d. noise N(0, likelihood_variance). The posterior is Gaussian:
    a (k, d) mean ``posterior_mean`` and a (k, k) covariance ``posterior_cov``
    shared by all d output columns. Both stay None until ``compute_posterior``
    or ``compute_posterior_from_loader`` has run.
    """
    def __init__(
        self,
        feature_extractor: Optional[ConvNNBackbone],
        num_filters: int = 32,
        kernel_size: int = 4,
        dense_layer: int = 128,
        img_rows: int = 28,
        img_cols: int = 28,
        num_outputs: int = 10,
        likelihood_variance: float = 1.0,
        prior_variance: float = 1.0,
        jitter: float = 1e-6,
    ) -> None:
        super().__init__(
            feature_extractor=feature_extractor,
            num_filters=num_filters,
            kernel_size=kernel_size,
            dense_layer=dense_layer,
            img_rows=img_rows,
            img_cols=img_cols,
            num_outputs=num_outputs,
        )
        self.likelihood_variance = likelihood_variance
        self.prior_variance = prior_variance
        self.jitter = jitter
        self.posterior_cov: Optional[torch.Tensor] = None
        self.posterior_mean: Optional[torch.Tensor] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predictive mean using current posterior (falls back to raw basis if posterior absent)."""
        phi = self.get_basis_func(x)
        if self.posterior_mean is None:
            return phi
        return phi @ self.posterior_mean

    def _per_sample_variance(
        self, phi: torch.Tensor, include_likelihood_noise: bool
    ) -> torch.Tensor:
        """Per-row predictive variance phi_n^T S phi_n (+ noise); shape (n,).

        Before a posterior exists the epistemic part is taken as zero, so the
        result is the likelihood noise alone (or zeros without noise).
        """
        if self.posterior_mean is None or self.posterior_cov is None:
            var = torch.zeros(phi.size(0), device=phi.device, dtype=phi.dtype)
        else:
            # Row-wise quadratic form; avoids materialising the (n, n) phi S phi^T.
            var = torch.sum((phi @ self.posterior_cov) * phi, dim=1)
        if include_likelihood_noise:
            var = var + self.likelihood_variance
        return var

    def predict_with_uncertainty(
        self, x: torch.Tensor, include_likelihood_noise: bool = True
) -> tuple[torch.Tensor, torch.Tensor]:
        """Predictive mean and variance using the closed-form posterior.

        Returns ``(mean, var)``. ``var`` has shape (n, num_outputs) and is the
        same across outputs, because prior and noise are isotropic. Without a
        posterior, ``mean`` is the raw basis phi(x).
        """
        phi = self.get_basis_func(x)
        per_sample = self._per_sample_variance(phi, include_likelihood_noise)
        pred_var = per_sample.unsqueeze(1).expand(-1, self.num_outputs)
        if self.posterior_mean is None or self.posterior_cov is None:
            return phi, pred_var
        return phi @ self.posterior_mean, pred_var

    def _set_posterior(self, phiT_phi: torch.Tensor, phiT_y: torch.Tensor) -> None:
        """Set posterior from sufficient statistics Phi^T Phi (k, k) and Phi^T Y (k, d).

        S = (Phi^T Phi / sigma^2 + (1/s^2 + jitter) I)^{-1},  M = S Phi^T Y / sigma^2.
        """
        sigma_inv = 1.0 / self.likelihood_variance
        prior_inv = 1.0 / self.prior_variance
        eye = torch.eye(phiT_phi.shape[0], device=phiT_phi.device, dtype=phiT_phi.dtype)
        s_inv = sigma_inv * phiT_phi + prior_inv * eye
        s_inv = s_inv + self.jitter * eye
        s_cov = torch.linalg.inv(s_inv)
        self.posterior_cov = s_cov
        self.posterior_mean = s_cov @ (sigma_inv * phiT_y)

    @torch.no_grad()
    def compute_posterior(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """Compute the closed-form posterior from inputs ``x`` and 2-D targets ``y`` (n, d)."""
        phi = self.get_basis_func(x)
        self._set_posterior(phi.T @ phi, phi.T @ y.to(phi.dtype))

    @torch.no_grad()
    def compute_posterior_from_loader(
        self, loader: torch.utils.data.DataLoader, device: torch.device
) -> None:
        """Accumulate Phi^T Phi and Phi^T Y over a dataloader, then compute the posterior.

        The loader must yield ``(batch_x, batch_y)`` with 2-D targets (e.g. one-hot).

        Raises:
            ValueError: if the loader is empty or yields non-2-D targets.
        """
        phiT_phi_accum: Optional[torch.Tensor] = None
        phiT_y_accum: Optional[torch.Tensor] = None
        for batch_x, batch_y in loader:
            phi = self.get_basis_func(batch_x.to(device))
            # Cast so integer/float64 targets cannot break the matmul dtype.
            targets = batch_y.to(device=device, dtype=phi.dtype)
            if targets.dim() != 2:
                raise ValueError(
                    f"compute_posterior_from_loader expects 2-D targets, got shape {tuple(targets.shape)}"
                )
            if phiT_phi_accum is None:
                phiT_phi_accum = phi.T @ phi
                phiT_y_accum = phi.T @ targets
            else:
                phiT_phi_accum = phiT_phi_accum + phi.T @ phi
                phiT_y_accum = phiT_y_accum + phi.T @ targets
        if phiT_phi_accum is None or phiT_y_accum is None:
            raise ValueError("compute_posterior_from_loader received an empty loader")
        self._set_posterior(phiT_phi_accum, phiT_y_accum)

    @torch.no_grad()
    def compute_predictive_variance(self, X: np.ndarray, batch_size: int = 1024) -> np.ndarray:
        """Predictive variance (epistemic + likelihood noise) for each sample in X.

        ``X`` is pushed through the backbone in chunks of ``batch_size`` so that a
        large unlabelled pool does not have to fit through the CNN at once. Before a
        posterior exists every sample gets ``likelihood_variance``, matching
        ``predict_with_uncertainty``.
        """
        dev = next(self.feature_extractor.parameters()).device
        chunks = []
        for start in range(0, len(X), batch_size):
            x_tensor = torch.from_numpy(X[start:start + batch_size]).float().to(dev)
            phi = self.get_basis_func(x_tensor)
            chunks.append(self._per_sample_variance(phi, True).cpu())
        if not chunks:
            return np.zeros(0, dtype=np.float32)
        return torch.cat(chunks).numpy()


class MFVI_HB(HierarchicalBayesModel):
    """Mean-field variational inference (MFVI) head for the hierarchical Bayes model.

    Approximate posterior q(W) = prod_{k,d} N(w_kd | M_kd, s_k^2). It has a
    (k, d) mean ``posterior_mean`` and a diagonal (k, k) covariance
    ``posterior_cov``, whose per-feature variances s_k^2 are shared by all d
    output columns.
    """

    def __init__(
        self,
        feature_extractor: Optional[ConvNNBackbone],
        num_filters: int = 32,
        kernel_size: int = 4,
        dense_layer: int = 128,
        img_rows: int = 28,
        img_cols: int = 28,
        num_outputs: int = 10,
        likelihood_variance: float = 1.0,
        prior_variance: float = 1.0,
        jitter: float = 1e-6,
        elbo_lr: float = 1e-3,
        vi_method: str ='closed',
        elbo_steps: int = 50,
        **kwargs
    ) -> None:
        super().__init__(
            feature_extractor=feature_extractor,
            num_filters=num_filters,
            kernel_size=kernel_size,
            dense_layer=dense_layer,
            img_rows=img_rows,
            img_cols=img_cols,
            num_outputs=num_outputs,
        )
        self.likelihood_variance = likelihood_variance
        self.prior_variance = prior_variance
        self.jitter = jitter
        self.elbo_lr = elbo_lr
        self.vi_method = vi_method
        # Number of Adam steps used by the gradient-based ELBO route.
        self.elbo_steps = elbo_steps
        self.posterior_cov: Optional[torch.Tensor] = None
        self.posterior_mean: Optional[torch.Tensor] = None
        # Noise variance fitted by the "coord" method. It is recorded for
        # inspection only; predictions keep using ``likelihood_variance``.
        self.estimated_likelihood_variance: Optional[float] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predictive mean using current posterior (raw basis phi(x) if none yet)."""
        phi = self.get_basis_func(x)
        if self.posterior_mean is None:
            return phi
        return phi @ self.posterior_mean

    def _compute_A_k(self, phi: torch.Tensor) -> torch.Tensor:
        """Computes A_k = \\sum_n phi_nk^2 for each feature k."""
        return torch.sum(phi ** 2, dim=0)

    def _update_M_closed_form(self, phi: torch.Tensor, y: torch.Tensor, sigma2: float) -> torch.Tensor:
        """Solve closed-form ridge regression for posterior mean M.

        M = (Phi^T Phi + (sigma^2 / s^2) I)^{-1} Phi^T Y
        """
        k = phi.shape[1]
        eye = torch.eye(k, device=phi.device, dtype=phi.dtype)
        lambda_reg = (sigma2 / self.prior_variance) * eye
        return torch.linalg.solve(phi.T @ phi + lambda_reg, phi.T @ y)

    def _update_variances_closed_form(self, A_k: torch.Tensor, sigma2: float) -> torch.Tensor:
        """Compute posterior variances sigma_kd^2 using closed form.
        sigma_kd^2 = 1 / (A_k / \\sigma^2 + 1 / s^2)
        Note: sigma_kd^2 depends on k but not on d (shared variance across outputs for each feature).
        Returns them as a diagonal (k, k) matrix.
        """
        var_diag = 1.0 / ((A_k / sigma2) + (1.0 / self.prior_variance))
        return torch.diag(var_diag)

    def _update_sigma2(self, y: torch.Tensor, phi: torch.Tensor, M: torch.Tensor, S: torch.Tensor) -> float:
        """Maximum-likelihood update of the observation-noise variance sigma^2.

        sigma^2* = (Res + V) / (N D), where
            Res = ||Y - Phi M||_F^2
            V   = E_q ||Phi (W - M)||_F^2 = D * sum_k A_k s_k^2.
        The factor D arises because each of the D output columns has its own
        weights with the shared per-feature variances s_k^2 = diag(S).

        Returns:
            sigma2: updated likelihood variance, floored at 1e-8.
        """
        n, d = y.shape
        residuals = y - phi @ M
        res = torch.sum(residuals ** 2)
        v = d * torch.sum(self._compute_A_k(phi) * torch.diag(S))
        sigma2_new = float((res + v) / (n * d))
        return max(sigma2_new, 1e-8)

    def _compute_posterior_coord_ascent(self, phi: torch.Tensor, y: torch.Tensor, max_iters: int=20) -> None:
        """Coordinate ascent over (M, S, sigma^2) until sigma^2 changes by < 1e-4.

        Stores M and S as the posterior. The final sigma^2 estimate is stored in
        ``estimated_likelihood_variance``.

        Raises:
            ValueError: if ``max_iters`` < 1 (no posterior would be produced).
        """
        if max_iters < 1:
            raise ValueError(f"max_iters must be >= 1, got {max_iters}")
        tolerance = 1e-4
        sigma2 = self.likelihood_variance
        A_k = self._compute_A_k(phi)  # independent of sigma^2, so computed once
        for _ in range(max_iters):
            M = self._update_M_closed_form(phi, y, sigma2)
            S = self._update_variances_closed_form(A_k, sigma2)
            sigma2_new = self._update_sigma2(y, phi, M, S)
            converged = abs(sigma2_new - sigma2) < tolerance
            sigma2 = sigma2_new
            if converged:
                break

        self.posterior_mean = M.detach()
        self.posterior_cov = S.detach()
        self.estimated_likelihood_variance = sigma2

    def _per_sample_variance(self, phi: torch.Tensor) -> torch.Tensor:
        """Per-row predictive variance phi_n^T S phi_n + likelihood noise; shape (n,).

        Before a posterior exists only the likelihood noise is returned.
        """
        if self.posterior_mean is None or self.posterior_cov is None:
            var = torch.zeros(phi.size(0), device=phi.device, dtype=phi.dtype)
        else:
            var = torch.sum((phi @ self.posterior_cov) * phi, dim=1)
        return var + self.likelihood_variance

    def predict_with_uncertainty(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Predictive mean and (n, num_outputs) variance including likelihood noise."""
        phi = self.get_basis_func(x)
        predictive_var = self._per_sample_variance(phi).unsqueeze(1).expand(-1, self.num_outputs)
        if self.posterior_mean is None or self.posterior_cov is None:
            return phi, predictive_var
        return phi @ self.posterior_mean, predictive_var

    def compute_posterior(self, x: torch.Tensor, y: torch.Tensor) -> None:
        """Compute posterior for MFVI head with configurable VI method.

        Routes based on `self.vi_method`:
        - "closed": closed-form M,S with fixed sigma^2 = likelihood_variance
        - "coord": coordinate-ascent updating M,S,sigma^2 until convergence
        - anything else: gradient-based ELBO optimisation

        The basis functions are always computed without autograd, so fitting
        the head never backpropagates into (or accumulates gradients in) the backbone.
        """
        with torch.no_grad():
            phi = self.get_basis_func(x)

        if self.vi_method == "closed":
            sigma2 = self.likelihood_variance
            with torch.no_grad():
                M = self._update_M_closed_form(phi, y, sigma2)
                S = self._update_variances_closed_form(self._compute_A_k(phi), sigma2)
            self.posterior_mean = M.detach()
            self.posterior_cov = S.detach()
            return

        if self.vi_method == "coord":
            with torch.no_grad():
                self._compute_posterior_coord_ascent(phi, y)
            return

        print(f"vi_method={self.vi_method!r}: optimising the ELBO by gradient descent")
        self._compute_posterior_elbo(phi, y)

    def _compute_posterior_elbo(self, phi: torch.Tensor, y: torch.Tensor) -> None:
        """Maximise the mean-field ELBO with Adam using its analytic expected log-likelihood.

        E_q log p(Y|W) = -ND/2 log(2 pi sigma^2)
                         - (||Y - Phi M||_F^2 + D * sum_k A_k s_k^2) / (2 sigma^2)
        KL(q||p) (up to a constant) = D/2 sum_k (s_k^2/s^2 - log s_k^2 - 1) + ||M||_F^2 / (2 s^2)
        """
        n, k = phi.shape
        d = y.shape[1]
        device = phi.device
        dtype = phi.dtype

        mu = torch.zeros((k, d), device=device, dtype=dtype, requires_grad=True)
        log_var_diag = torch.zeros((k,), device=device, dtype=dtype, requires_grad=True)

        optimizer = torch.optim.Adam([mu, log_var_diag], lr=self.elbo_lr)
        sigma_inv = 1.0 / self.likelihood_variance
        prior_inv = 1.0 / self.prior_variance
        A_k = self._compute_A_k(phi)
        norm_const = -0.5 * d * n * np.log(2 * np.pi * self.likelihood_variance)

        for _ in range(self.elbo_steps):
            var_diag = torch.exp(log_var_diag)
            residuals = y - phi @ mu
            # Exact expectation of the squared error under q: mean residual plus
            # the weight-variance contribution from every one of the d outputs.
            expected_sq_err = torch.sum(residuals ** 2) + d * torch.sum(A_k * var_diag)
            exp_log_lik = norm_const - 0.5 * sigma_inv * expected_sq_err

            kl_var = 0.5 * d * torch.sum(-torch.log(var_diag + 1e-8) - 1.0 + prior_inv * var_diag)
            kl_mu = 0.5 * prior_inv * torch.sum(mu ** 2)

            loss = -(exp_log_lik - (kl_var + kl_mu)) / max(n, 1)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        self.posterior_mean = mu.detach()
        self.posterior_cov = torch.diag(torch.exp(log_var_diag.detach()))

    @torch.no_grad()
    def compute_predictive_variance(self, X: np.ndarray, batch_size: int = 1024) -> np.ndarray:
        """Predictive variance (epistemic + likelihood noise) for each sample in X.

        ``X`` is processed in chunks of ``batch_size`` to bound backbone memory.
        Before a posterior exists every sample gets ``likelihood_variance``, matching
        ``predict_with_uncertainty``.
        """
        dev = next(self.feature_extractor.parameters()).device
        chunks = []
        for start in range(0, len(X), batch_size):
            x_tensor = torch.from_numpy(X[start:start + batch_size]).float().to(dev)
            chunks.append(self._per_sample_variance(self.get_basis_func(x_tensor)).cpu())
        if not chunks:
            return np.zeros(0, dtype=np.float32)
        return torch.cat(chunks).numpy()


def get_balanced_initial_set(
    X: np.ndarray,
    y: np.ndarray,
    samples_per_class: int = 2,
    num_classes: int = 10,
    rng: Optional[np.random.Generator] = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Draw ``samples_per_class`` examples of each class 0..num_classes-1 without replacement.

    Args:
        X, y: samples and integer labels (same length).
        samples_per_class: examples drawn per class.
        num_classes: classes 0..num_classes-1 are sampled.
        rng: optional NumPy Generator. When None, NumPy's legacy global RNG is
            used (so ``np.random.seed`` controls the draw).

    Returns:
        (X_init, y_init, X_pool, y_pool), where the pool holds every sample not
        selected for the initial set.

    Raises:
        ValueError: if some class has fewer than ``samples_per_class`` samples.
    """
    choose = np.random.choice if rng is None else rng.choice
    init_idx = []
    for cls in range(num_classes):
        cls_indices = np.where(y == cls)[0]
        if len(cls_indices) < samples_per_class:
            raise ValueError(
                f"Not enough samples for class {cls}: {len(cls_indices)} available, need {samples_per_class}"
            )
        init_idx.extend(choose(cls_indices, size=samples_per_class, replace=False))
    # dtype=int keeps an empty selection usable as an index array (the
    # default float dtype of np.array([]) would make the indexing below fail).
    init_idx = np.array(init_idx, dtype=int)
    pool_mask = np.ones(len(X), dtype=bool)
    pool_mask[init_idx] = False
    pool_idx = np.where(pool_mask)[0]

    return X[init_idx], y[init_idx], X[pool_idx], y[pool_idx]

class LoadData:
    """Download, split, and prepare MNIST for active learning.

    The MNIST training set is split with a generator seeded by ``seed`` into
    four parts: pretrain (``pretrain_size``), train (``train_size``),
    validation (``val_size``) and pool (the remainder). The train and pool parts
    are then merged and re-split into a class-balanced initial labelled set
    (``initial_per_class`` per class) and the unlabelled pool. The MNIST test
    set is kept whole.
    """

    def __init__(
        self,
        seed: int = 271,
        pretrain_size: int = 1000,
        val_size: int = 100,
        train_size: int = 20,
        root: str = "data",
        initial_per_class: int = 2,
    ) -> None:
        # Every split is loaded as a single DataLoader batch of its own size,
        # so each must be positive; fail before triggering a download.
        for name, value in (
            ("pretrain_size", pretrain_size),
            ("val_size", val_size),
            ("train_size", train_size),
        ):
            if value < 1:
                raise ValueError(f"{name} must be a positive integer, got {value}")
        self.seed = seed
        self.pretrain_size = pretrain_size
        self.train_size = train_size
        self.val_size = val_size
        self.root = root
        self.initial_per_class = initial_per_class
        self.mnist_train, self.mnist_test = self.download_dataset()
        self.pool_size = len(self.mnist_train) - self.train_size - self.val_size - self.pretrain_size
        if self.pool_size < 1:
            raise ValueError(
                f"pretrain_size + train_size + val_size = "
                f"{self.pretrain_size + self.train_size + self.val_size} leaves no pool "
                f"samples out of {len(self.mnist_train)} training images"
            )
        (
            self.X_pretrain_All,
            self.y_pretrain_All,
            self.X_train_All,
            self.y_train_All,
            self.X_val,
            self.y_val,
            self.X_pool,
            self.y_pool,
            self.X_test,
            self.y_test,
        ) = self.split_and_load_dataset()
        # Also overwrites self.X_pool / self.y_pool with the post-selection pool.
        self.X_init, self.y_init = self.preprocess_training_data()

    def tensor_to_np(self, tensor_data: torch.Tensor) -> np.ndarray:
        """Detach a tensor, move it to CPU and return it as a NumPy array."""
        return tensor_data.detach().cpu().numpy()

    def check_mnist_folder(self) -> bool:
        """Return True when ``<root>/MNIST`` is missing, i.e. a download is needed."""
        return not os.path.exists(os.path.join(self.root, "MNIST"))

    def download_dataset(self):
        """Load MNIST train/test sets (downloading only if absent), normalised with MNIST mean/std."""
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        download = self.check_mnist_folder()
        mnist_train = MNIST(
            self.root, train=True, download=download, transform=transform
        )
        mnist_test = MNIST(self.root, train=False, download=download, transform=transform)
        return mnist_train, mnist_test

    def split_and_load_dataset(self):
        """Seeded random split of the training set; each split is materialised as one tensor batch."""
        generator = torch.Generator().manual_seed(self.seed)
        pretrain_set, train_set, val_set, pool_set = random_split(
            self.mnist_train,
            [self.pretrain_size, self.train_size, self.val_size, self.pool_size],
            generator=generator,
        )
        pretrain_loader = DataLoader(
            dataset = pretrain_set, batch_size=self.pretrain_size,shuffle=True
        )
        train_loader = DataLoader(
            dataset=train_set, batch_size=self.train_size, shuffle=True
        )
        val_loader = DataLoader(dataset=val_set, batch_size=self.val_size, shuffle=True)
        pool_loader = DataLoader(
            dataset=pool_set, batch_size=self.pool_size, shuffle=True
        )
        test_loader = DataLoader(dataset=self.mnist_test, batch_size=10000, shuffle=True)
        # batch_size equals each split's size, so one batch holds the whole split.
        X_pretrain_All, y_pretrain_All = next(iter(pretrain_loader))
        X_train_All, y_train_All = next(iter(train_loader))
        X_val, y_val = next(iter(val_loader))
        X_pool, y_pool = next(iter(pool_loader))
        X_test, y_test = next(iter(test_loader))
        return X_pretrain_All, y_pretrain_All, X_train_All, y_train_All, X_val, y_val, X_pool, y_pool, X_test, y_test


    def preprocess_training_data(self):
        """Build a balanced initial set with equal samples per class using the helper.

        Draws from the union of the train and pool splits and replaces
        ``self.X_pool`` / ``self.y_pool`` with the samples that were not chosen.
        """
        per_class = self.initial_per_class
        X_concat = torch.cat([self.X_train_All, self.X_pool], dim=0).detach().cpu().numpy()
        y_concat = torch.cat([self.y_train_All, self.y_pool], dim=0).detach().cpu().numpy()
        X_init_np, y_init_np, X_pool_np, y_pool_np = get_balanced_initial_set(
            X_concat, y_concat, samples_per_class=per_class, num_classes=10,
        )
        X_init = torch.from_numpy(X_init_np).float()
        y_init = torch.from_numpy(y_init_np).long()
        self.X_pool = torch.from_numpy(X_pool_np).float()
        self.y_pool = torch.from_numpy(y_pool_np).long()

        print(f"Initial training data points: {X_init.shape[0]}")
        binc = np.bincount(y_init.detach().cpu().numpy(), minlength=10)
        print(f"Data distribution for each class: {binc}")
        return X_init, y_init


    def load_all(self):
        """Return (pretrain, init, val, pool, test) inputs/labels as NumPy arrays, X then y for each."""
        return (
            self.tensor_to_np(self.X_pretrain_All),
            self.tensor_to_np(self.y_pretrain_All),
            self.tensor_to_np(self.X_init),
            self.tensor_to_np(self.y_init),
            self.tensor_to_np(self.X_val),
            self.tensor_to_np(self.y_val),
            self.tensor_to_np(self.X_pool),
            self.tensor_to_np(self.y_pool),
            self.tensor_to_np(self.X_test),
            self.tensor_to_np(self.y_test),
        )


def uniform_acq(model: nn.Module, X_pool: np.ndarray, n_query: int = 10, **_):
    """Uniform-random acquisition baseline: pick up to ``n_query`` distinct pool indices.

    ``model`` is unused; it is accepted so every acquisition function shares one
    signature. Returns ``(query_idx, X_pool[query_idx])``. Both are empty when the
    pool or the budget is empty.
    """
    n_query = min(n_query, len(X_pool))
    if n_query <= 0:
        # Integer-typed empty index so that X_pool[query_idx] remains valid.
        query_idx = np.array([], dtype=int)
    else:
        query_idx = np.random.choice(len(X_pool), size=n_query, replace=False)
    return query_idx, X_pool[query_idx]

def predictive_variance_acq(model: nn.Module, X_pool: np.ndarray, n_query: int= 10, **_):
    """Max-variance acquisition: the ``n_query`` pool points with the largest predictive variance.

    Relies on ``model.compute_predictive_variance(X_pool)`` returning one score
    per row. Returns ``(query_idx, X_pool[query_idx])``, ordered by decreasing variance.
    """
    budget = min(n_query, len(X_pool))
    scores = model.compute_predictive_variance(X_pool)
    ranked = np.argsort(-scores)
    chosen = ranked[:budget]
    return chosen, X_pool[chosen]

def _make_loader(X: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool = True):
    dataset = TensorDataset(torch.from_numpy(X).float(), torch.from_numpy(y).long())
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

def _accuracy(model: nn.Module, X: np.ndarray, y: np.ndarray, device: torch.device) -> float:
    model.eval()
    with torch.no_grad():
        xb = torch.from_numpy(X).float().to(device)
        yb = torch.from_numpy(y).long().to(device)
        preds = torch.argmax(model(xb), dim=1)
        return float((preds == yb).float().mean().cpu().item())

def one_hot_labels(y_labels: np.ndarray, num_classes: int = 10) -> np.ndarray:
    """One-hot encode a 1-D array of integer class labels as a float32 (n, num_classes) array."""
    # Row i of the identity matrix is exactly the one-hot vector for class i.
    return np.eye(num_classes, dtype=np.float32)[y_labels.astype(int)]

def _mse_rmse(model: nn.Module, X: np.ndarray, y: np.ndarray, device: torch.device) -> tuple[float, float]:
    """MSE and RMSE between raw model outputs and 10-class one-hot targets.

    Puts the model in eval mode. The MSE is averaged over every (sample, class)
    entry; the RMSE is its square root.
    """
    model.eval()
    with torch.no_grad():
        inputs = torch.from_numpy(X).float().to(device)
        targets_np = one_hot_labels(y.astype(int), num_classes=10)
        targets = torch.from_numpy(targets_np).float().to(device)
        squared_error = (model(inputs) - targets) ** 2
        mse = squared_error.mean().item()
        return mse, float(np.sqrt(mse))

def _pretrain_loader(X: np.ndarray, y: np.ndarray, batch_size: int, task: str,
                     shuffle: bool) -> DataLoader:
    """Loader yielding integer labels for 'classification', one-hot float targets otherwise."""
    if task == 'classification':
        return _make_loader(X, y.astype(int), batch_size, shuffle=shuffle)
    dataset = TensorDataset(
        torch.from_numpy(X).float(),
        torch.from_numpy(one_hot_labels(y, num_classes=10)).float(),
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def _pretrain_evaluate(logits_fn, loader: DataLoader, task: str,
                       device: torch.device) -> tuple[float, float]:
    """Return (accuracy, squared error of softmax probs vs one-hot summed over classes, averaged per sample).

    Both values are NaN when the loader yields no samples.
    """
    correct = 0
    total = 0
    sq_err = 0.0
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)
            logits = logits_fn(xb)
            if task == 'classification':
                true_labels = yb
                targets = F.one_hot(yb, num_classes=logits.shape[1]).to(logits.dtype)
            else:
                true_labels = torch.argmax(yb, dim=1)
                targets = yb
            correct += (torch.argmax(logits, dim=1) == true_labels).sum().item()
            sq_err += torch.sum((torch.softmax(logits, dim=1) - targets) ** 2).item()
            total += xb.size(0)
    if total == 0:
        return float('nan'), float('nan')
    return correct / total, sq_err / total


def pretrainer(
        backbone_model: ConvNNBackbone, X_pretrain: np.ndarray, Y_pretrain: np.ndarray,
        epochs: int, lr: float, weight_decay: float, device: torch.device,
        batch_size: int = 128, task: str = 'classification',
        X_val: np.ndarray | None = None, y_val: np.ndarray | None = None,
        val_batch_size: int = 256) -> Tuple[ConvNNBackbone, float, float, float]:
    """Train the CNN backbone's 10-way output and report train accuracy plus validation metrics.

    Args:
        task: 'classification' trains with cross-entropy on integer labels; any
            other value trains with MSE against one-hot targets.
        X_val, y_val: optional validation data; both must be given to evaluate.

    Returns:
        (backbone_model, train_acc, val_acc, val_mse). The model is left in eval
        mode. val_mse is the per-sample squared error of softmax probabilities
        against one-hot labels (summed over classes). Validation values are NaN
        when no validation data is given.
    """
    if task == 'classification':
        crit = nn.CrossEntropyLoss()
    else:
        crit = nn.MSELoss()
    # Always train the full network output. A backbone built with
    # use_fc2_features=False returns fc1 features from forward(), which would
    # otherwise be fed to the loss as if they were 10-class logits.
    logits_fn = getattr(backbone_model, 'get_logits', backbone_model)

    backbone_model.to(device).train()
    opt = torch.optim.Adam(list(backbone_model.parameters()), lr=lr, weight_decay=weight_decay)
    loader = _pretrain_loader(X_pretrain, Y_pretrain, batch_size, task, shuffle=True)

    for _ in range(epochs):
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)
            opt.zero_grad()
            loss = crit(logits_fn(xb), yb)
            loss.backward()
            opt.step()

    backbone_model.eval()
    pretrain_acc, _ = _pretrain_evaluate(logits_fn, loader, task, device)

    val_acc = float('nan')
    val_mse = float('nan')
    if X_val is not None and y_val is not None:
        val_loader = _pretrain_loader(X_val, y_val, val_batch_size, task, shuffle=False)
        val_acc, val_mse = _pretrain_evaluate(logits_fn, val_loader, task, device)
    return backbone_model, pretrain_acc, val_acc, val_mse

def init_head_posterior(
    model: torch.nn.Module,
    X_init: np.ndarray,
    y_init: np.ndarray,
    device: torch.device,
    num_classes: int = 10,
) -> tuple[np.ndarray, np.ndarray]:
    """Fit the Bayesian head's posterior on the initial labelled set.

    Returns a private copy of ``X_init`` and the one-hot encoding of
    ``y_init``. These become the starting labelled arrays of the acquisition loop.
    """
    labeled_inputs = X_init.copy()
    labeled_targets = one_hot_labels(y_init, num_classes=num_classes)
    model.compute_posterior(
        x=torch.from_numpy(labeled_inputs).float().to(device),
        y=torch.from_numpy(labeled_targets).float().to(device),
    )
    return labeled_inputs, labeled_targets

def select_queries(
    model: torch.nn.Module,
    X_pool_run: np.ndarray,
    n_query: int,
    acq_fn: str = "predictive",
) -> np.ndarray:
    """Select indices from the pool according to the acquisition function.

    Args:
        model: Bayesian-head model. For ``acq_fn="predictive"`` it must expose
            ``compute_predictive_variance(X)``, returning one score per pool row.
        X_pool_run: Current unlabeled pool.
        n_query: Requested number of queries; clipped to the pool size.
        acq_fn: ``"predictive"`` takes the highest predictive variance first.
            ``"random"`` samples uniformly without replacement using the global
            ``np.random`` state.

    Returns:
        Integer indices into ``X_pool_run``. The array is empty when nothing
        can be queried.

    Raises:
        ValueError: If ``acq_fn`` is not ``"predictive"`` or ``"random"``.
    """
    if acq_fn not in ("predictive", "random"):
        # Previously any unrecognised name silently fell back to random
        # sampling, so a typo produced a mislabelled "random" experiment.
        raise ValueError(
            f"Unknown acquisition function {acq_fn!r}; expected 'predictive' or 'random'"
        )
    n_query = min(n_query, len(X_pool_run))
    if n_query <= 0:
        return np.array([], dtype=int)
    if acq_fn == "predictive":
        variances = model.compute_predictive_variance(X_pool_run)
        # Negate so argsort yields the largest variances first.
        return np.argsort(-variances)[:n_query]
    # Passing the population size draws the same indices as range(len(...)).
    return np.random.choice(len(X_pool_run), size=n_query, replace=False)

def update_sets_after_query(
    X_labeled: np.ndarray,
    y_labeled_oh: np.ndarray,
    X_pool_run: np.ndarray,
    y_pool_oh_run: np.ndarray,
    query_idx: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Move the queried pool rows into the labeled set.

    The rows selected by ``query_idx`` are appended, in ``query_idx`` order, to
    the labeled inputs and targets. The same rows are removed from the pool
    arrays.

    Returns:
        ``(X_labeled, y_labeled_oh, X_pool_run, y_pool_oh_run)`` after the move.
        The inputs are returned unchanged (same objects) when ``query_idx`` is
        empty.
    """
    if not query_idx.size:
        return X_labeled, y_labeled_oh, X_pool_run, y_pool_oh_run

    picked_X, picked_y = X_pool_run[query_idx], y_pool_oh_run[query_idx]
    grown_X = np.concatenate([X_labeled, picked_X], axis=0)
    grown_y = np.concatenate([y_labeled_oh, picked_y], axis=0)
    shrunk_X, shrunk_y = (
        np.delete(arr, query_idx, axis=0) for arr in (X_pool_run, y_pool_oh_run)
    )
    return grown_X, grown_y, shrunk_X, shrunk_y

def recompute_posterior(
    model: torch.nn.Module,
    X_labeled: np.ndarray,
    y_labeled_oh: np.ndarray,
    device: torch.device,
) -> None:
    """Refit the model's head posterior on the current labeled set.

    Both arrays are converted to float32 tensors on ``device``. They are then
    passed to ``model.compute_posterior`` as the keyword arguments ``x`` and
    ``y``.
    """
    def _as_tensor(arr: np.ndarray) -> torch.Tensor:
        return torch.from_numpy(arr).float().to(device)

    model.compute_posterior(x=_as_tensor(X_labeled), y=_as_tensor(y_labeled_oh))

def acquisition_round(
    model: torch.nn.Module,
    X_labeled: np.ndarray,
    y_labeled_oh: np.ndarray,
    X_pool_run: np.ndarray,
    y_pool_oh_run: np.ndarray,
    n_query: int,
    acq_fn: str,
    device: torch.device,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Run a single active-learning step.

    The step has three parts:
    1. Pick up to ``n_query`` pool indices with ``select_queries``.
    2. Move those rows from the pool into the labeled set.
    3. Refit the posterior on the enlarged labeled set.

    Returns:
        ``(X_labeled, y_labeled_oh, X_pool_run, y_pool_oh_run, query_idx)``:
        the updated sets plus the queried indices. The indices refer to the
        pool as it was *before* the rows were removed.
    """
    chosen = select_queries(model, X_pool_run, n_query=n_query, acq_fn=acq_fn)
    new_X_lab, new_y_lab, new_X_pool, new_y_pool = update_sets_after_query(
        X_labeled, y_labeled_oh, X_pool_run, y_pool_oh_run, chosen
    )
    recompute_posterior(model, new_X_lab, new_y_lab, device)
    return new_X_lab, new_y_lab, new_X_pool, new_y_pool, chosen




# Experiment configuration. An argparse.Namespace provides the attribute access
# (args.seed, args.query, ...) used by the rest of the notebook. Key order
# matches the original definition.
args = argparse.Namespace(**{
    "seed": 271,                              # passed to set_seed and LoadData
    "batch_size": 64,
    "pretrain_epochs": 100,                   # backbone pretraining epochs
    "acq_rounds": 95,                         # number of acquisition rounds
    "vi_method": 'closed',                    # one of: 'closed', 'coord'
    "pretrain_lr": 1e-3,
    "lr": 1e-3,
    "pretrain_weight_decay": 1e-2,
    "weight_decay": 1e-2,
    "query": 10,                              # samples acquired per round
    "pretrain_size": 200,                     # passed to LoadData(pretrain_size=...)
    "val_size": 1000,                         # passed to LoadData(val_size=...)
    "initial_labeled_per_class": 5,           # samples per class in the initial labeled set
    "pretrain_task": "classification",        # one of: 'regression', 'classification'
    "inference_type": "all",                  # one of: 'mfvi', 'analytic', 'all'
    "acq_fn": "all",                          # one of: 'predictive', 'random', 'all'
    "likelihood_variance": 1,
    "prior_variance": 1,
    "result_dir": "min_extension",
})

# Seed every RNG before any data split or weight initialisation.
set_seed(args.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

data_loader = LoadData(
    seed=args.seed,
    pretrain_size=args.pretrain_size,
    val_size=args.val_size,
    initial_per_class=args.initial_labeled_per_class,
)
(
    X_pretrain, y_pretrain,
    X_init, y_init,
    X_val, y_val,
    X_pool, y_pool,
    X_test, y_test,
) = data_loader.load_all()

# Name -> array lookup used by the rest of the notebook.
datasets = dict(
    X_pretrain=X_pretrain, y_pretrain=y_pretrain,
    X_init=X_init, y_init=y_init,
    X_val=X_val, y_val=y_val,
    X_pool=X_pool, y_pool=y_pool,
    X_test=X_test, y_test=y_test,
)

backbone = ConvNNBackbone(
    num_filters=32, kernel_size=4,
    img_rows=28, img_cols=28,
    dense_layer=128,
    use_fc2_features=True,
)
backbone, backbone_train_acc, backbone_val_acc, backbone_val_mse = pretrainer(
    backbone_model=backbone,
    X_pretrain=X_pretrain,
    Y_pretrain=y_pretrain,
    epochs=args.pretrain_epochs,
    lr=args.pretrain_lr,
    weight_decay=args.pretrain_weight_decay,
    device=device,
    task=args.pretrain_task,
    X_val=X_val,
    y_val=y_val,
)
print(f"Backbone pretraining complete. Validation MSE: {backbone_val_mse:.6f}")
backbone.freeze()

# Expand the "all" shortcuts into the concrete configurations to run.
# NOTE: `models` holds inference-type *names* (not model objects); the
# plotting cell iterates over it by name.
if args.inference_type == 'all':
    models = ['analytic', 'mfvi']
elif args.inference_type in ('analytic', 'mfvi'):
    models = [args.inference_type]
else:
    # Fail fast instead of letting an unknown name reach the experiment loop.
    raise ValueError(
        f"Unknown inference_type {args.inference_type!r}; "
        "expected 'analytic', 'mfvi' or 'all'"
    )

if args.acq_fn == 'all':
    acq_fns = ['predictive', 'random']
elif args.acq_fn in ('predictive', 'random'):
    acq_fns = [args.acq_fn]
else:
    raise ValueError(
        f"Unknown acq_fn {args.acq_fn!r}; expected 'predictive', 'random' or 'all'"
    )

# Validation learning curves, keyed by "<acq_fn>-<inference_type>".
# Entry i of each curve is the metric measured after acquisition round i + 1.
results_rmse = {}
results_mse = {}
for model_type in models:
    for acq_fn in acq_fns:
        # Reseed per configuration so every run starts from the same RNG state.
        # Otherwise the random baseline's np.random draws (and any stochastic
        # head initialisation) depend on which configurations ran before it.
        # For example, "random-mfvi" would continue the stream already consumed
        # by "random-analytic", and results would change with args.acq_fn or
        # args.inference_type.
        set_seed(args.seed)
        if model_type == 'analytic':
            model = Analytical_HB(feature_extractor=backbone,
                                  likelihood_variance=args.likelihood_variance,
                                  prior_variance=args.prior_variance)
        elif model_type == 'mfvi':
            model = MFVI_HB(feature_extractor=backbone,
                            vi_method=args.vi_method,
                            likelihood_variance=args.likelihood_variance,
                            prior_variance=args.prior_variance)
        else:
            # Falling through would silently reuse the previous configuration's
            # `model` (or raise NameError on the first pass).
            raise ValueError(f"Unknown inference type: {model_type!r}")

        # Fit the head on the initial labeled set. The returned arrays seed the
        # labeled set that each acquisition round grows.
        X_labeled, y_labeled_oh = init_head_posterior(
            model,
            datasets["X_init"],
            datasets["y_init"],
            device,
            num_classes=10,
        )

        # Per-configuration working copy of the unlabeled pool.
        X_pool_run = datasets["X_pool"].copy()
        y_pool_oh_run = one_hot_labels(datasets["y_pool"].astype(int), num_classes=10)

        rmse_curve = []
        mse_curve = []
        for rnd in range(args.acq_rounds):
            X_labeled, y_labeled_oh, X_pool_run, y_pool_oh_run, _ = acquisition_round(
                model=model,
                X_labeled=X_labeled,
                y_labeled_oh=y_labeled_oh,
                X_pool_run=X_pool_run,
                y_pool_oh_run=y_pool_oh_run,
                n_query=args.query,
                acq_fn=acq_fn,
                device=device,
            )
            val_mse, val_rmse = _mse_rmse(model, datasets["X_val"], datasets["y_val"], device)
            mse_curve.append(val_mse)
            rmse_curve.append(val_rmse)

            print(f"Round {rnd+1}/{args.acq_rounds} [{model_type}-{acq_fn}] | MSE: {val_mse:.6f} | RMSE: {val_rmse:.6f}")

        key = f"{acq_fn}-{model_type}"
        results_rmse[key] = rmse_curve
        results_mse[key] = mse_curve



In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

friendly_labels = {
    "predictive-analytic": "Predictive (Analytic)",
    "random-analytic": "Random (Analytic)",
    "predictive-mfvi": "Predictive (MFVI)",
    "random-mfvi": "Random (MFVI)",
}

# Colour encodes the inference type; line style encodes the acquisition function.
colors = {
    "analytic": "tab:blue",
    "mfvi": "tab:red",
}

linestyles = {
    "predictive": "-",
    "random": "--",
}

fig, ax = plt.subplots(figsize=(8, 5))
max_rounds = 0
if 'results_rmse' in globals():
    for inf_type in models:
        for acq in acq_fns:
            key = f"{acq}-{inf_type}"
            if key not in results_rmse:
                continue
            curve = results_rmse[key]
            # curve[i] is measured after acquisition round i + 1, so the
            # x-axis starts at round 1 rather than 0.
            x_vals = np.arange(1, len(curve) + 1)
            ax.plot(
                x_vals,
                curve,
                label=friendly_labels.get(key, key),
                linestyle=linestyles.get(acq, "-"),
                color=colors.get(inf_type, "tab:gray"),
            )
            max_rounds = max(max_rounds, len(curve))
ax.set_xlabel("Acquisition round")
ax.set_ylabel("RMSE")
# Fit the axis to the longest plotted curve instead of a hard-coded 100 rounds.
ax.set_xlim(left=0, right=max_rounds or 100)
ax.set_title("Validation RMSE: Random vs Predictive (Analytic vs MFVI)")
ax.grid()
# Only draw a legend when something was plotted (avoids matplotlib's warning).
if ax.get_legend_handles_labels()[0]:
    ax.legend()
plt.show()