# Novel Extension

In [None]:
import os
import argparse
import requests
import tarfile
import math
from PIL import Image
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.models as models
from torchvision import transforms
from typing import Callable, Tuple

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def set_seed(seed: int) -> None:
    """Seed every RNG the notebook draws from and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global RNG (used by
    ``random_acq``) and torch's CPU/CUDA RNGs, and disables cuDNN autotuning
    so convolution algorithms are chosen deterministically.

    Args:
        seed: the seed applied to all generators.
    """
    import random  # local import: only needed here

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

url = "https://s3.amazonaws.com/fast-ai-imagelocal/biwi_head_pose.tgz"
save_path = "biwi_head_pose.tgz"
data_dir = "biwi_head_pose"

if not os.path.exists(data_dir):
    print("Downloading BIWI dataset...")
    # Stream the archive to disk. raise_for_status() makes HTTP errors fail
    # loudly instead of saving an error page as the "archive"; the context
    # manager releases the connection; the timeout avoids hanging forever.
    with requests.get(url, stream=True, timeout=60) as response:
        response.raise_for_status()
        with open(save_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

    print("Extracting...")
    with tarfile.open(save_path, "r:gz") as tar:
        # The "data" extraction filter rejects absolute paths, ".." components
        # and special files; fall back to plain extraction on Pythons that
        # lack extraction filters.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(filter="data")
        else:
            tar.extractall()
    print("Done!")
else:
    print("Dataset already exists.")

In [None]:
class BiwiDataset(Dataset):
    """BIWI head-pose frames paired with Euler-angle labels in degrees.

    Items are ``(image, angles)``:
    - ``image`` is the RGB frame, after ``transform`` if one was given.
    - ``angles`` is a float32 tensor ``[x, y, z]``. It is derived from the 3x3
      rotation matrix in the first three lines of the frame's ``*_pose.txt``
      file.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Index only RGB frames located one directory level below root_dir.
        self.image_paths = sorted(glob(f"{root_dir}/*/*_rgb.jpg"))

    def select_faces(self, face_ids):
        """Replace the frame index with the RGB frames of the given sub-folders.

        Each id is formatted zero-padded to two characters (``3`` -> ``"03"``).
        Frames are sorted within a folder, and folders follow the order of
        ``face_ids``.
        """
        self.image_paths = []
        for face_id in face_ids:
            folder_pattern = f"{self.root_dir}/{face_id:0>2}/*_rgb.jpg"
            self.image_paths.extend(sorted(glob(folder_pattern)))

    def __len__(self):
        return len(self.image_paths)

    def convert_matrix_to_euler(self, rotation_matrix):
        """Convert a 3x3 rotation matrix to ``[x, y, z]`` Euler angles in degrees.

        This is the extraction for R = Rz @ Ry @ Rx. When cos(y) is ~0
        (gimbal lock), z is fixed at 0 and x absorbs the remaining rotation.
        """
        m = rotation_matrix
        cos_y = np.sqrt(m[0, 0] * m[0, 0] + m[1, 0] * m[1, 0])
        angle_y = np.arctan2(-m[2, 0], cos_y)
        if cos_y < 1e-6:
            angle_x = np.arctan2(-m[1, 2], m[1, 1])
            angle_z = 0
        else:
            angle_x = np.arctan2(m[2, 1], m[2, 2])
            angle_z = np.arctan2(m[1, 0], m[0, 0])
        return np.array([angle_x, angle_y, angle_z]) * (180 / np.pi)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert("RGB")
        # The label file sits next to the frame: <name>_rgb.jpg -> <name>_pose.txt
        pose_path = img_path.replace("_rgb.jpg", "_pose.txt")
        with open(pose_path, "r") as pose_file:
            pose_lines = pose_file.readlines()
        # The first three rows hold the rotation matrix.
        rotation = np.array(
            [[float(token) for token in pose_lines[row].split()] for row in range(3)]
        )

        # Convert Matrix to Angles - (Pitch, Yaw, Roll)
        angles = self.convert_matrix_to_euler(rotation).astype(np.float32)

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(angles)


# ImageNet preprocessing expected by the pretrained EfficientNet backbone.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Sequences 01-20 form the labeled/unlabeled pool (train + validation).
# Sequences 21-24 are held out for testing, so no test frame comes from a
# sequence used for training. (Selecting 01-24 for the test set put
# training-sequence frames into the test set.)
# NOTE(review): BIWI records some people more than once; confirm that
# sequences 21-24 do not share identities with 01-20 if a strict
# person-disjoint split is required.
TRAIN_FACE_IDS = list(range(1, 21))
TEST_FACE_IDS = list(range(21, 25))

pool_dataset = BiwiDataset(data_dir, transform=transform)
pool_dataset.select_faces(TRAIN_FACE_IDS)
test_dataset = BiwiDataset(data_dir, transform=transform)
test_dataset.select_faces(TEST_FACE_IDS)

def make_loader(dataset, indices, batch_size, shuffle):
    """Build a DataLoader over the items of ``dataset`` at ``indices`` (a 1-D tensor)."""
    chosen_positions = indices.tolist()
    restricted = Subset(dataset, chosen_positions)
    return DataLoader(restricted, batch_size=batch_size, shuffle=shuffle)

def prep_data(seed, pool_ratio, pool_size, val_size, test_size):
    """
    Prepare data for the experiments.

    A seeded permutation of the module-level ``pool_dataset`` is split into
    [initially labeled | unlabeled pool | validation]. ``test_size`` points
    are drawn from a seeded permutation of the module-level ``test_dataset``.

    Args:
        seed: seed for both permutations, so splits are reproducible.
        pool_ratio: fraction of ``pool_size`` that starts labeled.
        pool_size: number of labeled + unlabeled-pool points.
        val_size: number of validation points (disjoint from the pool).
        test_size: number of test points.
    Returns:
        loaders : Tuple[loader] init, pool, deterministic, val, test loaders
        pool_dataset : full dataset of available points
        idxs: Tuple[torch.Tensor] labeled, pool indices - to keep consistent starting splits across runs
    Raises:
        ValueError: if either dataset is too small for the requested sizes,
            or ``pool_ratio`` is outside [0, 1].
    """
    total = len(pool_dataset)
    if total < pool_size + val_size:
        raise ValueError("Not enough points available in dataset")
    # Slicing would otherwise silently return fewer test points than requested.
    if len(test_dataset) < test_size:
        raise ValueError("Not enough points available in test dataset")
    if not 0.0 <= pool_ratio <= 1.0:
        raise ValueError("pool_ratio must be in [0, 1]")
    init_size = int(pool_ratio * pool_size)
    # Make splits
    perm = torch.randperm(total, generator=torch.Generator().manual_seed(seed))
    idx_labeled = perm[:init_size]
    idx_pool = perm[init_size:pool_size]
    idx_val = perm[pool_size : pool_size + val_size]
    test_perm = torch.randperm(len(test_dataset), generator=torch.Generator().manual_seed(seed))
    idx_test = test_perm[:test_size]
    # Make loaders. Only the initial labeled loader shuffles; evaluation
    # loaders keep a fixed order (metrics are order-independent anyway).
    init_loader = make_loader(pool_dataset, idx_labeled, batch_size=32, shuffle=True)
    pool_loader = make_loader(pool_dataset, idx_pool, batch_size=64, shuffle=False)
    determ_train_loader = make_loader(pool_dataset, torch.cat((idx_labeled, idx_pool)), batch_size=128, shuffle=False)
    val_loader = make_loader(pool_dataset, idx_val, batch_size=64, shuffle=False)
    test_loader = make_loader(test_dataset, idx_test, batch_size=64, shuffle=False)
    loaders = (init_loader, pool_loader, determ_train_loader, val_loader, test_loader)
    idxs = (idx_labeled, idx_pool)
    return loaders, pool_dataset, idxs


def compute_label_stats(loader: DataLoader, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the per-dimension label mean and (unbiased) std + 1e-6, both on ``device``.

    Every batch of ``loader`` is read; the inputs are ignored. The 1e-6 keeps
    later divisions by the std finite.
    """
    with torch.no_grad():
        labels = torch.cat([targets for _, targets in loader], dim=0)
    mean = labels.mean(dim=0).to(device)
    std = (labels.std(dim=0) + 1e-6).to(device)
    return mean, std

In [None]:
class EfficientNetBasis(nn.Module):
    """Frozen ImageNet-pretrained EfficientNet-B0 feature extractor.

    Maps an image batch to globally average-pooled convolutional features,
    one flat vector per image.

    All parameters are frozen (``requires_grad=False``). Freezing parameters
    does not stop normalization layers from updating their running-statistic
    buffers in train mode, so the convolutional trunk is also pinned to eval
    mode. This holds even when a parent model calls ``.train()``, which keeps
    the pretrained features truly fixed.
    """

    def __init__(self):
        super().__init__()
        # Load the full model
        weights = models.EfficientNet_B0_Weights.DEFAULT
        full_model = models.efficientnet_b0(weights=weights)
        # EfficientNet has .features (convolutions) and .classifier (linear);
        # only the convolutional trunk is kept.
        self.features = full_model.features
        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten()
        )
        # Freeze parameters
        for param in self.parameters():
            param.requires_grad = False
        # Freshly built modules start in train mode; apply the eval pinning now.
        self.train()

    def train(self, mode: bool = True):
        """Set the module's mode, but always keep ``features`` in eval mode.

        ``nn.Module.eval()`` and a parent's ``.train()`` both route through
        this method, so the trunk can never switch to train-mode behavior.
        """
        super().train(mode)
        self.features.eval()
        return self

    def forward(self, x):
        x = self.features(x)
        x = self.pool(x)
        return x

In [None]:
from abc import ABC, abstractmethod
from typing import Optional, Tuple

class BayesianRegressionHead(ABC):
    """Abstract interface for a Bayesian linear-regression head on fixed features.

    Concrete heads fit a posterior over the (basis_dim x num_outputs) weight
    matrix from features ``phi`` and targets ``y``. The fit is stored in
    ``posterior_mean`` / ``posterior_cov``; both are ``None`` until
    ``compute_posterior`` has run.
    """

    def __init__(
        self,
        basis_dim: int,
        num_outputs: int,
        likelihood_variance: float = 1.0,
        prior_variance: float = 1.0,
        jitter: float = 1e-6,
    ):
        # Problem dimensions.
        self.basis_dim = basis_dim
        self.num_outputs = num_outputs
        # Observation-noise variance, isotropic prior variance, and diagonal
        # jitter for numerical stability.
        self.likelihood_variance = likelihood_variance
        self.prior_variance = prior_variance
        self.jitter = jitter
        # Posterior moments; empty until a fit.
        self.posterior_mean: Optional[torch.Tensor] = None  # shape (basis_dim, num_outputs)
        self.posterior_cov: Optional[torch.Tensor] = None   # shape depends on the subclass

    @abstractmethod
    def forward(self, phi: torch.Tensor) -> torch.Tensor:
        """Map features ``phi`` of shape (batch, basis_dim) to predictions (batch, num_outputs)."""

    @abstractmethod
    def predict_with_uncertainty(self, phi: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return a ``(mean, variance)`` pair of predictions for ``phi``."""

    @abstractmethod
    def compute_posterior(self, phi: torch.Tensor, y: torch.Tensor) -> None:
        """Fit the weight posterior to features ``phi`` and targets ``y``."""


class AnalyticalBayesianHead(BayesianRegressionHead):
    """Exact (conjugate Gaussian) Bayesian linear regression head.

    Prior w ~ N(0, prior_variance * I); Gaussian noise with variance sigma^2 =
    ``likelihood_variance``. Every output column shares the posterior
    covariance S = (phi^T phi / sigma^2 + I / prior_variance + jitter * I)^{-1},
    and the posterior mean is S phi^T y / sigma^2.
    """

    def forward(self, phi: torch.Tensor) -> torch.Tensor:
        # Before fitting there are no weights, so the raw features pass through.
        if self.posterior_mean is not None:
            return phi @ self.posterior_mean
        return phi

    def predict_with_uncertainty(
        self, phi: torch.Tensor, include_likelihood_noise: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the predictive mean and a per-output predictive variance.

        The variance is phi_i^T S phi_i (plus sigma^2 when
        ``include_likelihood_noise``), repeated across all output columns.
        Before fitting, the raw features and noise-only variance are returned.
        """
        if self.posterior_mean is None or self.posterior_cov is None:
            variance = torch.zeros(phi.size(0), device=phi.device, dtype=phi.dtype)
            if include_likelihood_noise:
                variance = variance + self.likelihood_variance
            return phi, variance.unsqueeze(1).expand(-1, self.num_outputs)

        mean = phi @ self.posterior_mean
        # Row-wise quadratic form phi_i^T S phi_i.
        epistemic = torch.sum((phi @ self.posterior_cov) * phi, dim=1)
        if include_likelihood_noise:
            variance = epistemic + self.likelihood_variance
        else:
            variance = epistemic
        return mean, variance.unsqueeze(1).expand(-1, self.num_outputs)

    @torch.no_grad()
    def compute_posterior(self, phi: torch.Tensor, y: torch.Tensor) -> None:
        """Compute the closed-form posterior mean and covariance from ``phi`` and ``y``."""
        noise_precision = 1.0 / self.likelihood_variance
        prior_precision = 1.0 / self.prior_variance
        identity = torch.eye(phi.shape[1], device=phi.device, dtype=phi.dtype)

        precision = noise_precision * (phi.T @ phi) + prior_precision * identity
        precision = precision + self.jitter * identity
        covariance = torch.linalg.inv(precision)
        mean = covariance @ (noise_precision * (phi.T @ y))

        self.posterior_cov = covariance
        self.posterior_mean = mean


class MFVIBayesianHead(BayesianRegressionHead):
    """Mean-field variational inference Bayesian linear regression (closed-form optimum).

    Model: y = phi w + noise, with noise ~ N(0, sigma^2 = likelihood_variance)
    and prior w ~ N(0, prior_variance * I), independently per output column.
    The exact posterior is Gaussian with precision
    A = phi^T phi / sigma^2 + I / prior_variance (+ jitter * I).

    The optimal fully factorized Gaussian q(w) = prod_i N(w_i | m_i, v_i) is
    the fixed point of coordinate-ascent mean-field VI. It has:
      - m = A^{-1} phi^T y / sigma^2 (the exact posterior mean);
      - v_i = 1 / A_ii, the inverse of the precision diagonal.
    v_i is never larger than the exact marginal variance (A^{-1})_ii, which is
    mean-field VI's usual variance under-estimation. Storing diag(A^{-1})
    would give the exact marginals rather than the mean-field solution.
    ``posterior_cov`` is kept as a (basis_dim, basis_dim) diagonal matrix.
    """
    def __init__(
        self,
        basis_dim: int,
        num_outputs: int,
        likelihood_variance: float = 1.0,
        prior_variance: float = 1.0,
        jitter: float = 1e-6,
    ):
        super().__init__(basis_dim, num_outputs, likelihood_variance, prior_variance, jitter)

    def forward(self, phi: torch.Tensor) -> torch.Tensor:
        # Before fitting there are no weights, so the raw features pass through.
        if self.posterior_mean is None:
            return phi
        return phi @ self.posterior_mean

    def predict_with_uncertainty(self, phi: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the predictive mean and per-output variance phi^T V phi + sigma^2."""
        if self.posterior_mean is None or self.posterior_cov is None:
            base_var = torch.zeros(phi.size(0), device=phi.device, dtype=phi.dtype)
            base_var = base_var + self.likelihood_variance
            pred_var = base_var.unsqueeze(1).expand(-1, self.num_outputs)
            return phi, pred_var
        predictive_mean = phi @ self.posterior_mean
        phi_S = phi @ self.posterior_cov
        epistemic = torch.sum(phi_S * phi, dim=1)
        total_var = epistemic + self.likelihood_variance
        pred_var = total_var.unsqueeze(1).expand(-1, self.num_outputs)

        return predictive_mean, pred_var

    @torch.no_grad()
    def compute_posterior(self, phi: torch.Tensor, y: torch.Tensor) -> None:
        """Set the mean-field optimum: exact posterior mean and variances 1 / diag(A)."""
        sigma2 = self.likelihood_variance
        prior_var = self.prior_variance
        D = phi.shape[1]

        eye = torch.eye(D, device=phi.device, dtype=phi.dtype)
        precision = (1.0 / sigma2) * (phi.T @ phi) + (1.0 / prior_var) * eye
        precision = precision + self.jitter * eye

        # Solve A mu = phi^T y / sigma^2 directly; no explicit inverse is needed.
        mu = torch.linalg.solve(precision, (1.0 / sigma2) * (phi.T @ y))
        # Mean-field variances: inverse of the precision diagonal.
        mf_var = 1.0 / torch.diagonal(precision)

        self.posterior_cov = torch.diag(mf_var).detach()
        self.posterior_mean = mu.detach()

class LaplaceMFVIBayesianHead(BayesianRegressionHead):
    """Mean-field variational Bayesian linear regression with a Laplace likelihood.

    Model:
      - y_k = phi w_k + noise, with Laplace(0, b) noise, b = ``likelihood_scale``.
        The noise variance 2 b^2 is stored as ``likelihood_variance``.
      - Gaussian prior w ~ N(0, prior_variance * I).

    The variational posterior is a fully factorized Gaussian
    q(w) = N(q_mu, diag(exp(q_log_var))). It is fitted by minimizing the
    negative ELBO with Adam and the reparameterization trick.
    ``posterior_cov`` holds the per-weight variances with shape
    (basis_dim, num_outputs), not a full covariance matrix.
    """

    # Range the log-variances are clamped to, during and after optimization.
    _LOG_VAR_MIN = -10.0
    _LOG_VAR_MAX = 2.0

    def __init__(
        self,
        basis_dim: int,
        num_outputs: int,
        prior_variance: float = 1.0,
        jitter: float = 1e-6,
        likelihood_scale: float = 1.0,
        learning_rate: float = 1e-3,
        num_iters: int = 1000,
        batch_size: int = 64,
        **kwargs
    ):
        # Extra keyword arguments are accepted and ignored.
        super().__init__(basis_dim, num_outputs, 2*(likelihood_scale**2), prior_variance, jitter)
        self.likelihood_scale = likelihood_scale
        self.lr = learning_rate
        # Number of full passes (epochs) over the data in compute_posterior.
        self.steps = num_iters
        self.batch_size = batch_size

    def forward(self, phi: torch.Tensor) -> torch.Tensor:
        """Predictive mean using posterior mean weights.

        Before fitting, the raw features are returned (as in the other heads).
        """
        if self.posterior_mean is None:
            return phi
        return phi @ self.posterior_mean

    def predict_with_uncertainty(self, phi: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predictive mean and variance.

        Variance per output k = sum_i phi_i^2 Var[w_ik] + 2 b^2 (Laplace noise).
        """
        if self.posterior_mean is None or self.posterior_cov is None:
            base_var = torch.zeros(phi.size(0), device=phi.device, dtype=phi.dtype)
            base_var = base_var + self.likelihood_variance
            pred_var = base_var.unsqueeze(1).expand(-1, self.num_outputs)
            return phi, pred_var
        predictive_mean = phi @ self.posterior_mean
        var_w = self.posterior_cov
        phi_sq = phi ** 2
        epistemic_var = phi_sq @ var_w
        aleatoric_var = 2 * (self.likelihood_scale ** 2)
        predictive_var = epistemic_var + aleatoric_var

        return predictive_mean, predictive_var

    def _kl_divergence(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """
        Compute KL(q(w) || p(w)) where both are Gaussian.
        p(w) ~ N(0, prior_variance * I)
        q(w) ~ N(mu, exp(log_var) * I)
        """
        var = torch.exp(log_var)
        prior_var = self.prior_variance
        kl = 0.5 * torch.sum(
            (var / prior_var) +
            (mu**2 / prior_var) -
            1.0 -
            (log_var - math.log(prior_var))
        )
        return kl

    def _clamped_log_var(self, log_var: torch.Tensor) -> torch.Tensor:
        """Clamp log-variances to [_LOG_VAR_MIN, _LOG_VAR_MAX]."""
        return torch.clamp(log_var, min=self._LOG_VAR_MIN, max=self._LOG_VAR_MAX)

    def compute_posterior(self, phi: torch.Tensor, y: torch.Tensor) -> None:
        """
        Fit q(w) by minibatch gradient descent (Adam) on the negative ELBO.

        Per-datapoint objective:
            sum_k |y_k - phi w_k| / b   (Laplace NLL, constant log(2b) dropped)
          + KL(q || p) / N.
        The data term sums over output dimensions so it matches the KL, which
        sums over every weight of every output. Optimization runs under
        ``torch.enable_grad()`` so the fit also works when called from a
        ``no_grad`` context. The stored variances use the same clamped values
        the optimizer saw.
        """
        device = phi.device
        N, D = phi.shape
        num_outputs = y.shape[1]
        q_mu = torch.zeros(D, num_outputs, device=device, dtype=phi.dtype, requires_grad=True)
        with torch.no_grad():
            q_mu.normal_(0, 0.01)
        # Start the variances at the prior variance.
        q_log_var = torch.full(
            (D, num_outputs), math.log(self.prior_variance),
            device=device, dtype=phi.dtype, requires_grad=True,
        )
        optimizer = torch.optim.Adam([q_mu, q_log_var], lr=self.lr)
        dataset = torch.utils.data.TensorDataset(phi, y)
        loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        with torch.enable_grad():
            for _ in range(self.steps):
                for batch_phi, batch_y in loader:
                    optimizer.zero_grad()
                    log_var = self._clamped_log_var(q_log_var)
                    std = torch.exp(0.5 * log_var)
                    # Reparameterized sample w = mu + sigma * eps.
                    w_sample = q_mu + std * torch.randn_like(std)
                    y_pred = batch_phi @ w_sample
                    abs_err = torch.abs(batch_y - y_pred).sum(dim=1)
                    nll = abs_err.mean() / self.likelihood_scale
                    kl = self._kl_divergence(q_mu, log_var)
                    loss = nll + kl / N

                    loss.backward()
                    torch.nn.utils.clip_grad_norm_([q_mu, q_log_var], max_norm=1.0)
                    optimizer.step()
        self.posterior_mean = q_mu.detach()
        self.posterior_cov = torch.exp(self._clamped_log_var(q_log_var.detach()))

    @torch.no_grad()
    def update_likelihood_scale(self, mae_value: float) -> None:
        """
        Updates the internal likelihood scale (b) based on an externally
        computed Mean Absolute Error.

        For a Laplace distribution with known location, the mean absolute
        deviation is the maximum-likelihood estimate of b. The scale is
        clamped to [1e-6, 2.0], and ``likelihood_variance`` = 2 b^2 is derived
        from the clamped value so both attributes stay consistent.
        """
        new_scale = min(max(float(mae_value), 1e-6), 2.0)
        self.likelihood_scale = new_scale
        self.likelihood_variance = 2 * (new_scale ** 2)


class BayesianEfficientNetModel(nn.Module):
    """Combines the EfficientNet backbone + Bayesian linear head.

    ``head_type`` selects the head:
      - ``"analytical"``: exact Gaussian posterior.
      - ``"mfvi"``: mean-field Gaussian posterior.
      - ``"laplace_mfvi"``: gradient-fitted mean-field posterior under a
        Laplace likelihood with scale ``likelihood_scale``.

    The head is a plain object, not an ``nn.Module``, so ``.to(device)`` does
    not move it. Its posterior tensors are created on whatever device the
    features have when ``fit_posterior`` runs.
    """
    def __init__(
        self,
        backbone: nn.Module,
        num_outputs: int = 3,
        head_type: str = "analytical",
        likelihood_variance: float = 1.0,
        prior_variance: float = 1.0,
        likelihood_scale: float = 1.0,
    ):
        super().__init__()
        self.backbone = backbone
        self.num_outputs = num_outputs
        self.likelihood_variance = likelihood_variance
        self.head_type = head_type
        basis_dim = self._infer_basis_dim(backbone)
        if head_type == "analytical":
            self.head = AnalyticalBayesianHead(
                basis_dim=basis_dim,
                num_outputs=num_outputs,
                likelihood_variance=likelihood_variance,
                prior_variance=prior_variance,
            )
        elif head_type == "mfvi":
            self.head = MFVIBayesianHead(
                basis_dim=basis_dim,
                num_outputs=num_outputs,
                likelihood_variance=likelihood_variance,
                prior_variance=prior_variance,
            )
        elif head_type == "laplace_mfvi":
            self.head = LaplaceMFVIBayesianHead(
                basis_dim=basis_dim,
                num_outputs=num_outputs,
                prior_variance=prior_variance,
                likelihood_scale=likelihood_scale,
            )
        else:
            raise ValueError(f"Unknown head_type: {head_type}")

    @staticmethod
    def _infer_basis_dim(backbone: nn.Module) -> int:
        """Return the feature size by pushing one dummy 224x224 RGB image through ``backbone``.

        The probe runs in eval mode, so train-mode side effects (e.g.
        BatchNorm running-statistic updates from random noise) cannot alter
        the backbone. It runs on the backbone's own device, and the
        backbone's previous mode is restored afterwards.
        """
        was_training = backbone.training
        first_param = next(backbone.parameters(), None)
        probe_device = first_param.device if first_param is not None else torch.device("cpu")
        backbone.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.randn(1, 3, 224, 224, device=probe_device)
                return backbone(dummy_input).shape[1]
        finally:
            backbone.train(was_training)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        phi = self.backbone(x)
        return self.head.forward(phi)

    def predict_with_uncertainty(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        phi = self.backbone(x)
        return self.head.predict_with_uncertainty(phi)

    def fit_posterior(self, loader: torch.utils.data.DataLoader, device: torch.device, y_mean: Optional[torch.Tensor] = None, y_std: Optional[torch.Tensor] = None) -> None:
        """Extract backbone features for every batch of ``loader`` and fit the head.

        Targets are standardized with ``(y - y_mean) / y_std`` when both stats
        are given.
        """
        self.backbone.eval()
        with torch.no_grad():
            phi_list = []
            y_list = []
            for batch_x, batch_y in loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)
                phi = self.backbone(batch_x)
                phi_list.append(phi)
                y_list.append(batch_y)

            phi_all = torch.cat(phi_list, dim=0)
            y_all = torch.cat(y_list, dim=0)

            if y_mean is not None and y_std is not None:
                y_all = (y_all - y_mean) / y_std

        self.head.compute_posterior(phi_all, y_all)


In [None]:
# Deterministic EfficientNet regressor baseline
class DeterministicRegressor(nn.Module):
    """Uses the same EfficientNet backbone but with a single linear head (no Bayesian posterior)."""
    def __init__(self, backbone: nn.Module, num_outputs: int = 3):
        super().__init__()
        self.backbone = backbone
        # Infer the basis dimension with one dummy image. The probe runs in
        # eval mode (so e.g. BatchNorm running stats are not updated by random
        # noise) and on the backbone's device; the previous mode is restored.
        was_training = backbone.training
        first_param = next(backbone.parameters(), None)
        probe_device = first_param.device if first_param is not None else torch.device("cpu")
        backbone.eval()
        try:
            with torch.no_grad():
                dummy = torch.randn(1, 3, 224, 224, device=probe_device)
                basis_dim = self.backbone(dummy).shape[1]
        finally:
            backbone.train(was_training)
        self.head = nn.Linear(basis_dim, num_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        phi = self.backbone(x)
        return self.head(phi)


def train_deterministic(
    model: DeterministicRegressor,
    loader: DataLoader,
    device: torch.device,
    y_mean: torch.Tensor,
    y_std: torch.Tensor,
    loss_fn: Optional[Callable],
    epochs: int = 10,
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    subset_frac: float | None = None,
    seed: int = 271,
 ):
    if subset_frac is not None:
        full_dataset = loader.dataset
        full_size = len(full_dataset)
        new_size = int(subset_frac*full_size)
        g = torch.Generator().manual_seed(seed)
        subset_idx = torch.randperm(full_size, generator=g)[:new_size]
        loader = DataLoader(
            Subset(full_dataset, subset_idx),
            batch_size=loader.batch_size,
            shuffle=False,  # keep the sampled set fixed
            num_workers=getattr(loader, "num_workers", 0),
            pin_memory=getattr(loader, "pin_memory", False),
        )
    model.train()
    optimizer = torch.optim.Adam(model.head.parameters(), lr=lr, weight_decay=weight_decay)
    if loss_fn is None:
        loss_fn = nn.L1Loss()
    for epoch in range(epochs):
        epoch_loss = 0.0
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            y_std_safe = y_std.to(device)
            y_mean_safe = y_mean.to(device)
            targets_std = (batch_y - y_mean_safe) / y_std_safe
            preds_std = model(batch_x)
            loss = loss_fn(preds_std, targets_std)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * batch_x.size(0)
        epoch_loss /= len(loader.dataset)
    return epoch_loss


def evaluate_deterministic(model, loader, device, y_mean, y_std):
    """Return MSE / RMSE / MAE of ``model`` on ``loader`` in original target units.

    The model outputs standardized predictions. They are mapped back with
    ``pred * y_std + y_mean`` before being compared to the raw targets.
    """
    model.eval()
    pred_chunks = []
    target_chunks = []
    with torch.no_grad():
        for inputs, labels in loader:
            pred_chunks.append(model(inputs.to(device)))
            target_chunks.append(labels.to(device))
    predicted = torch.cat(pred_chunks, dim=0) * y_std + y_mean
    actual = torch.cat(target_chunks, dim=0)
    residual = predicted - actual
    mse = torch.mean(residual ** 2).item()
    return {
        "mse": mse,
        "rmse": math.sqrt(mse),
        "mae": torch.mean(torch.abs(residual)).item(),
    }

In [None]:
def variance_acq(model: BayesianEfficientNetModel, pool_loader: DataLoader, device: torch.device, n_query: int = 10, y_std: Optional[torch.Tensor] = None) -> list:
    """Uncertainty sampling: return the pool positions with the largest predictive variance.

    Each item's score is its predictive variance averaged over the outputs.
    When ``y_std`` is given, variances are first multiplied by ``y_std ** 2``
    (back to original target units). Positions count items in the loader's
    iteration order, starting at 0; at most ``n_query`` are returned,
    highest score first.
    """
    model.backbone.eval()

    batch_scores = []
    with torch.no_grad():
        for inputs, _ in pool_loader:
            _, variance = model.predict_with_uncertainty(inputs.to(device))
            if y_std is not None:
                variance = variance * (y_std.to(device) ** 2)
            batch_scores.append(torch.mean(variance, dim=1).cpu())

    scores = torch.cat(batch_scores, dim=0)
    _, ranked = torch.topk(scores, k=min(n_query, len(scores)))
    # Scores are concatenated in iteration order, so a score's index is its pool position.
    return ranked.tolist()


def random_acq(model: "BayesianEfficientNetModel", pool_loader: DataLoader, device: torch.device, n_query: int = 10, y_std: Optional[torch.Tensor] = None) -> list:
    """
    Uniform random sampling: select up to ``n_query`` distinct pool positions.

    ``model``, ``device`` and ``y_std`` are unused. They are accepted so this
    function can be called through the same ``acquisition_fn`` interface as
    the other acquisition functions.

    The pool size is taken from ``len(pool_loader.dataset)`` rather than by
    iterating the loader, which would load and transform every pool image
    just to count them. This assumes the loader visits every dataset item
    exactly once (no ``drop_last`` or restricting sampler).

    Returns:
        List of plain ``int`` positions in ``[0, len(pool_loader.dataset))``,
        drawn with NumPy's global RNG.
    """
    total_pool_size = len(pool_loader.dataset)
    picks = np.random.choice(total_pool_size, size=min(n_query, total_pool_size), replace=False)
    return [int(i) for i in picks]

def max_entropy_acq(model: BayesianEfficientNetModel, pool_loader: DataLoader, device: torch.device, n_query: int = 10, y_std: Optional[torch.Tensor] = None) -> list:
    """
    Return the pool positions whose predictive distribution has the largest
    differential entropy.

    The outputs are treated as independent Gaussians, so an item's entropy is
    0.5 * sum_k log(2*pi*e*var_k). Before the log, variances are multiplied by
    ``y_std ** 2`` (when given) and floored at 1e-8. Positions count items in
    the loader's iteration order; at most ``n_query`` are returned,
    highest entropy first.
    """
    model.backbone.eval()
    batch_entropies = []
    with torch.no_grad():
        for inputs, _ in pool_loader:
            _, var = model.predict_with_uncertainty(inputs.to(device))
            if y_std is not None:
                var = var * (y_std.to(device) ** 2)
            var = var.clamp(min=1e-8)
            item_entropy = 0.5 * torch.sum(torch.log(2 * math.pi * math.e * var), dim=1)
            batch_entropies.append(item_entropy.cpu())
    entropies = torch.cat(batch_entropies, dim=0)
    _, ranked = torch.topk(entropies, k=min(n_query, len(entropies)))
    # Entropies are concatenated in iteration order, so an index is a pool position.
    return ranked.tolist()

In [None]:
def evaluate_model(model: BayesianEfficientNetModel, loader: DataLoader, device: torch.device, y_mean: Optional[torch.Tensor] = None, y_std: Optional[torch.Tensor] = None) -> dict:
    """Return MSE / RMSE / MAE of the model's predictive mean over ``loader``.

    Metrics are computed on the CPU. When both ``y_mean`` and ``y_std`` are
    given, predictions are de-standardized (``pred * y_std + y_mean``) before
    comparison with the raw targets.
    """
    model.backbone.eval()
    pred_parts = []
    target_parts = []

    with torch.no_grad():
        for inputs, labels in loader:
            means, _ = model.predict_with_uncertainty(inputs.to(device))
            pred_parts.append(means.cpu())
            target_parts.append(labels.cpu())

    predicted = torch.cat(pred_parts, dim=0)
    actual = torch.cat(target_parts, dim=0)
    if y_mean is not None and y_std is not None:
        predicted = predicted * y_std.cpu() + y_mean.cpu()
    residual = predicted - actual
    mse = torch.mean(residual ** 2).item()
    return {
        "mse": mse,
        "rmse": math.sqrt(mse),
        "mae": torch.mean(torch.abs(residual)).item(),
    }


def active_learning_loop(
    model: BayesianEfficientNetModel,
    dataset: Dataset,
    idx_labeled: torch.Tensor,
    idx_pool: torch.Tensor,
    val_loader: DataLoader,
    device: torch.device,
    acquisition_fn,
    acq_rounds: int = 10,
    n_query: int = 10,
    batch_size: int = 64,
    test_loader: Optional[DataLoader] = None,
) -> dict:
    """Active learning loop: iteratively select and label points from pool.

    Each round:
      1. Compute label mean/std on the current labeled set, and fit the
         head's posterior on labels standardized with them.
      2. Evaluate that fitted model on ``val_loader``. Log the metrics with
         the size of the labeled set the model was fitted on.
      3. Query ``n_query`` pool points with ``acquisition_fn`` and move them
         from the pool to the labeled set (skipped once the pool is empty).

    If ``test_loader`` is given, the posterior is refitted on the final
    labeled set, which includes the last round's queries. Test predictions
    are then de-standardized with the same statistics the posterior was
    fitted with.

    Returns:
        History dict with per-round ``iteration``, ``n_labeled``, ``val_mse``,
        ``val_rmse`` and ``val_mae`` lists. If ``test_loader`` is given, it
        also has ``final_test_rmse``, ``final_test_mae`` and
        ``final_n_labeled``.
    """
    history = {
        "iteration": [],
        "n_labeled": [],
        "val_mse": [],
        "val_rmse": [],
        "val_mae": [],
    }

    def _subset_loader(indices: torch.Tensor, loader_batch_size: int, shuffle: bool) -> DataLoader:
        """DataLoader over the items of `dataset` at `indices`."""
        return DataLoader(Subset(dataset, indices.tolist()), batch_size=loader_batch_size,
                          shuffle=shuffle, num_workers=2, pin_memory=True)

    def _fit(indices: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Fit the posterior on `indices`; return the label stats used for standardization."""
        y_mean, y_std = compute_label_stats(_subset_loader(indices, batch_size, shuffle=False), device)
        model.fit_posterior(_subset_loader(indices, batch_size, shuffle=True), device, y_mean=y_mean, y_std=y_std)
        return y_mean, y_std

    for iteration in range(acq_rounds):
        y_mean, y_std = _fit(idx_labeled)
        n_fitted = len(idx_labeled)
        val_metrics = evaluate_model(model, val_loader, device, y_mean=y_mean, y_std=y_std)

        if len(idx_pool) > 0:
            pool_loader = _subset_loader(idx_pool, 64, shuffle=False)
            selected_pool_indices = acquisition_fn(model, pool_loader, device, n_query=n_query, y_std=y_std)
            # Acquisition returns positions within the pool; map them to dataset indices.
            selected_full_indices = [idx_pool[i].item() for i in selected_pool_indices]
            selected_set = set(selected_full_indices)
            idx_labeled = torch.cat([idx_labeled, torch.tensor(selected_full_indices, dtype=idx_labeled.dtype)])
            idx_pool = torch.tensor([i for i in idx_pool.tolist() if i not in selected_set], dtype=idx_pool.dtype)

        history["iteration"].append(iteration + 1)
        history["n_labeled"].append(n_fitted)
        history["val_mse"].append(val_metrics["mse"])
        history["val_rmse"].append(val_metrics["rmse"])
        history["val_mae"].append(val_metrics["mae"])
        print(f"Iteration {iteration + 1}/{acq_rounds} | Labeled set size: {n_fitted} | Val RMSE: {val_metrics['rmse']:.3f} | Val MAE: {val_metrics['mae']:.3f}")

    if test_loader is not None:
        # Refit so the posterior and the de-standardization stats come from the same labeled set.
        y_mean_final, y_std_final = _fit(idx_labeled)
        test_metrics = evaluate_model(model, test_loader, device, y_mean=y_mean_final, y_std=y_std_final)
        history["final_test_rmse"] = test_metrics["rmse"]
        history["final_test_mae"] = test_metrics["mae"]
        history["final_n_labeled"] = len(idx_labeled)
    return history

In [None]:
print("="*70)
print("RUNNING ALL ACTIVE LEARNING EXPERIMENTS")
print("="*70)

experiments = {}

# Configuration for all experiments
args = argparse.Namespace(
    seed = 271,
    acq_rounds=110,
    n_query=25,
    batch_size=128,
    pool_ratio=0.05,
    pool_size=5000,
    val_size=500,
    test_size=1000,
    likelihood_variance=1.0,
    prior_variance=1.0,
)

model_configs = [
    ("mfvi", "MFVI"),
    ("analytical", "Analytical"),
]

acquisition_configs = [
    (variance_acq, "Variance"),
    (random_acq, "Random"),
]
set_seed(args.seed)
(init_loader, pool_loader, determ_train_loader, val_loader, test_loader), pool_dataset, (idx_labeled, idx_pool) = prep_data(
    seed=args.seed,
    pool_ratio=args.pool_ratio,
    pool_size=args.pool_size,
    val_size=args.val_size,
    test_size=args.test_size,
)
for head_type, head_label in model_configs:
    for acq_fn, acq_label in acquisition_configs:
        exp_name = f"{head_label}_{acq_label}"
        print(f"\n{'='*70}")
        print(f"Experiment: {exp_name}")
        print(f"{'='*70}")

        # Reseed per experiment so every run starts from the same RNG state
        # (model construction, loader shuffling, random acquisition),
        # independent of how many experiments ran before it.
        set_seed(args.seed)

        # Create fresh model
        backbone = EfficientNetBasis()
        model = BayesianEfficientNetModel(
            backbone=backbone,
            num_outputs=3,
            head_type=head_type,
            likelihood_variance=args.likelihood_variance,
            prior_variance=args.prior_variance,
        ).to(device)

        # Create fresh index tensors
        idx_labeled_al = idx_labeled.clone()
        idx_pool_al = idx_pool.clone()

        history = active_learning_loop(
            model=model,
            dataset=pool_dataset,
            idx_labeled=idx_labeled_al,
            idx_pool=idx_pool_al,
            val_loader=val_loader,
            device=device,
            acquisition_fn=acq_fn,
            test_loader=test_loader,
            acq_rounds=args.acq_rounds,
            n_query = args.n_query,
            batch_size=args.batch_size,
        )
        experiments[exp_name] = history
        # Guard against an empty history (acq_rounds == 0).
        final_val_mae = history["val_mae"][-1] if history["val_mae"] else float("nan")
        final_val_rmse = history["val_rmse"][-1] if history["val_rmse"] else float("nan")
        final_test_mae = history.get("final_test_mae")
        final_test_rmse = history.get("final_test_rmse")
        print(f"Final Val MAE: {final_val_mae:.4f}, Final Val RMSE: {final_val_rmse:.4f}")
        if final_test_mae is not None:
            print(f"Final Test MAE: {final_test_mae:.4f}, Final Test RMSE: {final_test_rmse:.4f}")
        if history["n_labeled"]:
            print(f"Labeled samples: {history['n_labeled'][-1]}")

print(f"\n{'='*70}")
print("All experiments completed!")
print(f"{'='*70}\n")

In [None]:
print("\n" + "="*70)
print("RUNNING DETERMINISTIC REGRESSOR (RANDOM SELECTION)")
print("="*70)

# Match sample sizes to AL checkpoints
init_size = len(idx_labeled)
final_al_size = init_size + args.acq_rounds * args.n_query
total_determ_data = len(determ_train_loader.dataset)

# Create checkpoints matching AL progression: one every `checkpoint_interval`
# acquisition rounds.
checkpoint_interval = 10
det_sample_sizes = [init_size + i * args.n_query * checkpoint_interval
                    for i in range(args.acq_rounds // checkpoint_interval + 1)]

# Label statistics over the entire deterministic training pool. These are NOT
# used for training below, which standardizes per checkpoint.
# NOTE(review): kept only in case later cells reference it; remove if unused.
y_mean_det, y_std_det = compute_label_stats(determ_train_loader, device)

# Per-loss training setup: (loss factory, number of epochs).
det_loss_configs = {
    "L1": (nn.L1Loss, 50),
    "MSE": (nn.MSELoss, 30),
}
det_full_dataset = determ_train_loader.dataset

for loss_fn, (make_loss, det_epochs) in det_loss_configs.items():
    print(f"{'='*70}")
    print(f"Experiment: Deterministic_{loss_fn}")
    print(f"{'='*70}")
    det_history = {
        "iteration": [],
        "n_labeled": [],
        "val_mse": [],
        "val_rmse": [],
        "val_mae": [],
    }
    for idx, sample_size in enumerate(det_sample_sizes):
        # Fixed random subset of the training pool for this checkpoint,
        # seeded per checkpoint.
        g = torch.Generator().manual_seed(args.seed + idx)
        subset_idx = torch.randperm(total_determ_data, generator=g)[:sample_size]
        subset_loader = DataLoader(
            Subset(det_full_dataset, subset_idx.tolist()),
            batch_size=determ_train_loader.batch_size,
            shuffle=False,
        )
        # Standardize with statistics of the labels this checkpoint actually
        # trains on. The AL runs likewise use labeled-set statistics only,
        # so no label information from unselected points leaks in.
        y_mean_ckpt, y_std_ckpt = compute_label_stats(subset_loader, device)

        backbone_det = EfficientNetBasis()
        det_model = DeterministicRegressor(backbone_det, num_outputs=3).to(device)
        final_loss = train_deterministic(
            det_model,
            subset_loader,
            device,
            y_mean_ckpt,
            y_std_ckpt,
            loss_fn=make_loss(),
            epochs=det_epochs,
            lr=1e-3,
            weight_decay=1e-4,
            subset_frac=None,
            seed=args.seed + idx
        )
        val_metrics_det = evaluate_deterministic(det_model, val_loader, device, y_mean_ckpt, y_std_ckpt)

        det_history["iteration"].append(idx + 1)
        det_history["n_labeled"].append(sample_size)
        det_history["val_mse"].append(val_metrics_det["mse"])
        det_history["val_rmse"].append(val_metrics_det["rmse"])
        det_history["val_mae"].append(val_metrics_det["mae"])
        print(f"Iteration {idx + 1}/{len(det_sample_sizes)} | Labeled set size: {sample_size} | Val RMSE: {val_metrics_det['rmse']:.4f}, Val MAE: {val_metrics_det['mae']:.4f}")

    # The test uses the model (and matching stats) from the largest checkpoint.
    test_metrics_det = evaluate_deterministic(det_model, test_loader, device, y_mean_ckpt, y_std_ckpt)
    det_history["final_test_rmse"] = test_metrics_det["rmse"]
    det_history["final_test_mae"] = test_metrics_det["mae"]
    print(f"Final Val RMSE: {det_history['val_rmse'][-1]:.4f} | Final Val MAE: {det_history['val_mae'][-1]:.4f}")
    print(f"Final Test RMSE: {test_metrics_det['rmse']:.4f} | Final Test MAE: {test_metrics_det['mae']:.4f}")
    print(f"\nDeterministic trained on: {det_history['n_labeled']} samples")
    experiments[f"Deterministic_{loss_fn}"] = det_history

In [None]:
# Plot validation metrics
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
tab20 = plt.cm.tab20.colors

# Visual encoding per experiment. Each method uses two adjacent tab20 entries.
# Solid vs dashed lines separate the two variants of a method. Square markers
# flag the deterministic baselines. Names missing from these maps fall back to
# grey, solid, and no marker.
colors = {
    "Analytical_Variance": tab20[0],
    "Analytical_Random": tab20[1],
    "MFVI_Variance": tab20[6],
    "MFVI_Random": tab20[7],
    "Deterministic_L1": tab20[4],
    "Deterministic_MSE": tab20[5],
}
linestyles = {
    "Analytical_Variance": "-",
    "Analytical_Random": "--",
    "MFVI_Variance": "-",
    "MFVI_Random": "--",
    "Deterministic_L1": "-",
    "Deterministic_MSE": "--",
}
markers = {
    "Analytical_Variance": None,
    "Analytical_Random": None,
    "MFVI_Variance": None,
    "MFVI_Random": None,
    "Deterministic_L1": "s",
    "Deterministic_MSE": "s",
}

# One panel per validation metric: (history key, y-axis label, panel title).
panel_specs = [
    ("val_mae", "Validation MAE", "Validation MAE vs Labeled Samples"),
    ("val_rmse", "Validation RMSE", "Validation RMSE vs Labeled Samples"),
]
for ax, (metric_key, y_label, panel_title) in zip(axes, panel_specs):
    for exp_name, history in experiments.items():
        marker = markers.get(exp_name, None)
        ax.plot(
            history["n_labeled"],
            history[metric_key],
            color=colors.get(exp_name, "tab:gray"),
            linestyle=linestyles.get(exp_name, "-"),
            marker=marker,
            markersize=4 if marker else 0,
            linewidth=2.5,
            label=exp_name.replace("_", " | "),
        )
    ax.set_xlabel("# Labeled Samples")
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.grid(True, alpha=0.3)
    ax.legend()

plt.tight_layout()
plt.savefig("novel_extension.png", dpi=300, bbox_inches="tight")
plt.show()

# Summary table based on MAE and RMSE
def _summary_value(value):
    """Return *value* as a float, or NaN when it is missing (None).

    Keeps the fixed-width ``.4f`` columns below from raising TypeError when an
    experiment stored None for a metric instead of omitting the key.
    """
    return float("nan") if value is None else float(value)


summary_header = (
    f"{'Configuration':<30} {'Final Val MAE':<15} {'Final Val RMSE':<17} "
    f"{'Final Test MAE':<15} {'Final Test RMSE':<17}"
)
# Size the horizontal rules to the header row so the borders line up with the
# columns. The header is wider than a fixed 80-character rule.
summary_rule_width = len(summary_header)

print("\n" + "=" * summary_rule_width)
print("SUMMARY OF ALL EXPERIMENTS (MAE / RMSE)")
print("=" * summary_rule_width)
print(summary_header)
print("-" * summary_rule_width)
# Rows are ordered alphabetically by experiment name.
for exp_name, history in sorted(experiments.items()):
    final_val_mae = _summary_value(history["val_mae"][-1])
    final_val_rmse = _summary_value(history["val_rmse"][-1])
    # Test metrics may be absent or None for a run; both show as NaN.
    final_test_mae = _summary_value(history.get("final_test_mae"))
    final_test_rmse = _summary_value(history.get("final_test_rmse"))
    print(
        f"{exp_name:<30} "
        f"{final_val_mae:<15.4f} "
        f"{final_val_rmse:<17.4f} "
        f"{final_test_mae:<15.4f} "
        f"{final_test_rmse:<17.4f}"
    )
print("=" * summary_rule_width)