In [1]:
'''Libraries'''
import numpy as np, torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset, DataLoader, Subset
from laplace import Laplace
import matplotlib.pyplot as plt
import laplace
import pickle
import os
import random
import time
from scipy.stats import pearsonr, spearmanr

In [3]:
'''Dataset Loader for Digits dataset'''
# Seed the NumPy and PyTorch global RNGs.
# NOTE(review): Python's built-in `random` module is not seeded in this cell.
np.random.seed(0)
torch.manual_seed(0)

# Load the scikit-learn digits data as float32 features / int64 labels and make
# a stratified 75/25 train/test split with a fixed random_state.
digits = load_digits()
X, y = digits.data.astype(np.float32), digits.target.astype(np.int64)
Xtr, Xte, ytr, yte = train_test_split(X, y, test_size=0.25, stratify=y, random_state=42)

# Fit the standardiser on the training split only, then apply it to both splits
# (the names Xtr / Xte are rebound to the scaled arrays).
scaler = StandardScaler().fit(Xtr)
Xtr, Xte = scaler.transform(Xtr), scaler.transform(Xte)

# Wrap the NumPy arrays as torch tensors.
Xtr_t = torch.from_numpy(Xtr)
ytr_t = torch.from_numpy(ytr)
Xte_t = torch.from_numpy(Xte)
yte_t = torch.from_numpy(yte)

# Training dataset shared by the later cells; n_train is its sample count.
train_set = TensorDataset(Xtr_t, ytr_t)
print(train_set)
n_train = len(train_set)

<torch.utils.data.dataset.TensorDataset object at 0x000001607C40BBE0>


Task 1

In [4]:
def count_parameters(model, only_trainable=False):
    """
    Return the total number of scalar parameters held by ``model``.

    When ``only_trainable`` is True, only parameters whose ``requires_grad``
    flag is set are counted; otherwise every parameter is counted.
    """
    total = 0
    for param in model.parameters():
        if only_trainable and not param.requires_grad:
            continue
        total += param.numel()
    return total

def get_models():
    """
    Build the family of ten freshly initialised classifiers used in the
    experiments, ordered from the smallest to the largest parameter count.

    Every network maps 64 input features to 10 output scores. The modules are
    constructed in list order, so the sequence of random weight
    initialisations is the same on every call for a given RNG state.
    """
    return [
        # 1. Simplest model: a single linear layer
        nn.Linear(64, 10),

        # 2. More parameters: Linear + Linear
        nn.Sequential(
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        ),

        # 3. Deeper: Linear + Linear + Linear
        nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        ),

        # 4. Deeper and wider
        nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ),

        # 5. BatchNorm added
        nn.Sequential(
            nn.Linear(64, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        ),

        # 6. Dropout added
        nn.Sequential(
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ),

        # 7. Deeper still, wider still
        nn.Sequential(
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ),

        # 8. More layers with BatchNorm and Dropout
        nn.Sequential(
            nn.Linear(64, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ),

        # 9. Deeper and wider, with mixed activations
        nn.Sequential(
            nn.Linear(64, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.Tanh(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ),

        # 10. Largest model
        nn.Sequential(
            nn.Linear(64, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ),
    ]

def evaluate(loader, model):
    """
    Compute the classification accuracy of ``model`` over ``loader``.

    Args:
        loader: iterable yielding ``(inputs, labels)`` batches.
        model: module mapping a batch of inputs to per-class scores; the
            predicted class is the argmax along dimension 1.

    Returns:
        float: fraction of samples whose predicted class equals the label.
        ``nan`` is returned when the loader yields no samples, because the
        accuracy is undefined in that case (previously this raised
        ZeroDivisionError).

    Note:
        The model is switched to eval mode and is left in eval mode on return.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    if total == 0:
        # Nothing was evaluated: report "undefined" instead of dividing by zero.
        return float("nan")
    return correct / total

def cumulative_explained_ratio(val, alpha=0.9):
    """
    Fraction of the largest values needed to reach a share ``alpha`` of the total.

    Args:
        val: list or array of eigenvalues (may contain negative entries).
            Any shape is accepted; the values are flattened.
        alpha: target cumulative proportion of the total (e.g. 0.9).

    Returns:
        ratio (float): ``count / number_of_values``.
        count (int): how many of the largest values are needed for their
            cumulative sum to reach ``alpha * total`` (a 1-based count, not a
            0-based index). It is capped at the number of values, so an
            ``alpha`` above 1 yields ``(1.0, number_of_values)``.

        ``(0.0, 0)`` is returned when there are no values or when no value is
        positive.
    """
    # Flatten so that len/sort/cumsum all refer to the same 1-D set of values.
    val = np.asarray(val, dtype=float).ravel()
    # Negative values are clipped to zero so they do not contribute to the
    # explained curvature.
    val = np.maximum(val, 0)

    if val.size == 0 or np.sum(val) == 0:
        return 0.0, 0

    sorted_vals = np.sort(val)[::-1]  # descending order
    cumvals = np.cumsum(sorted_vals)
    threshold = alpha * cumvals[-1]

    # First position whose cumulative sum is >= threshold.
    idx = np.searchsorted(cumvals, threshold)
    # searchsorted returns len(cumvals) when the threshold is never reached
    # (alpha > 1); cap the count so the ratio stays within [0, 1].
    count = min(int(idx) + 1, val.size)
    return count / val.size, count

In [None]:
# Experiment: for every architecture from get_models(), train it, fit a
# diagonal Laplace approximation, and record how concentrated the resulting
# curvature values are. Repeated n_repeats times; results are keyed by
# log(number of parameters).
n_repeats = 10
alpha = 0.9
results = {}

for repeat in range(n_repeats):
    print(f"\n=== Repetition {repeat+1}/{n_repeats} ===")
    
    # Fresh initialization on every repetition.
    # NOTE(review): the models are built before the per-model seeding below, so
    # their initial weights depend on the RNG state left by whatever ran before
    # this call, not on `repeat`.
    models = get_models()

    for model in models:
        n_params = count_parameters(model)
        x_val = np.log(n_params)

        print(f"\nTraining model with {n_params} parameters (repeat {repeat+1})")

        # Re-seed before training each model so every model in a repetition
        # starts its training from the same RNG state.
        torch.manual_seed(repeat)
        np.random.seed(repeat)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
        train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
        # NOTE(review): test_loader and train_eval_loader are created here but
        # are not used anywhere in this cell.
        test_loader = DataLoader(TensorDataset(Xte_t, yte_t), batch_size=256, shuffle=False)
        train_eval_loader = DataLoader(train_set, batch_size=256, shuffle=False)

        n_epochs_monitor = 30
        for epoch in range(n_epochs_monitor):
            model.train()
            # running_loss is accumulated but never reported in this cell.
            running_loss = 0.0
            for inputs, labels in train_loader:
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

        # after training: fit a diagonal-structure Laplace approximation on the
        # (shuffled) training loader and pull its H attribute out as NumPy.
        # presumably la.H is the diagonal of the Hessian approximation — confirm
        # against the laplace-torch documentation.
        la = Laplace(model, 'classification', subset_of_weights='all', hessian_structure='diag')
        la.fit(train_loader)
        val = la.H.cpu().numpy()

        ratio, idx = cumulative_explained_ratio(val, alpha=alpha)

        # Store both the ratio and the full array of values for this run.
        results.setdefault(x_val, []).append({
            'ratio': ratio,
            'eigenvalues': val
        })

# Persist everything once, after all repetitions have finished.
with open("results_curvature_concentration.pkl", "wb") as f:
    pickle.dump(results, f)

In [None]:
def _safe_extract_eigvals(obj):
    """
    Try to find an eigenvalue array (np.ndarray or list-like) inside obj.
    Return a 1D numpy array if found, else return None.
    Handles: np.ndarray, list/tuple of arrays, dict with keys like 'eigenvalues','eigvals','H','values',
    and nested structures.
    """
    # direct numpy array
    if isinstance(obj, np.ndarray):
        return obj.flatten()

    # numpy scalar or Python scalar -> not an eigenvalue array
    if isinstance(obj, (np.floating, np.integer, float, int, np.float64, np.int64)):
        return None

    # list/tuple -> try each element
    if isinstance(obj, (list, tuple)):
        for item in obj:
            arr = _safe_extract_eigvals(item)
            if arr is not None:
                return arr
        return None

    # dict -> try common keys first, then recurse over values
    if isinstance(obj, dict):
        # common possible keys that hold eigenvalues
        preferred_keys = ['eigenvalues', 'eigvals', 'eigvals_', 'H', 'values', 'vals', 'eigs', 'eigen']
        for k in preferred_keys:
            if k in obj:
                arr = _safe_extract_eigvals(obj[k])
                if arr is not None:
                    return arr

        # if none of preferred keys, recurse values (but skip scalar values)
        for v in obj.values():
            arr = _safe_extract_eigvals(v)
            if arr is not None:
                return arr
        return None

    # other types (pandas, torch tensor) - handle torch tensors
    try:
        import torch
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu().numpy().flatten()
    except Exception:
        pass

    # fallback: not recognized
    return None

def _collect_ratios(items):
    """Return every usable ratio (as float) found in one bucket of results."""
    numeric = (float, int, np.floating, np.integer)
    ratios = []
    for it in items:
        if isinstance(it, dict):
            if 'ratio' in it:
                try:
                    ratios.append(float(it['ratio']))
                except (TypeError, ValueError):
                    # Best effort: an unconvertible ratio is skipped, not fatal.
                    continue
            else:
                # Dict without a 'ratio' key: take its first numeric value.
                for v in it.values():
                    if isinstance(v, numeric):
                        ratios.append(float(v))
                        break
        elif isinstance(it, numeric):
            ratios.append(float(it))
    return ratios


def plot_cumulative_explained_ratio(results, alpha=0.9, save_path=None):
    """
    Plot the cumulative explained ratio against log(number of parameters).

    Args:
        results: dict ``{log(n_params): [items]}`` where each item is either a
            number or a dict carrying a 'ratio' entry.
        alpha: cumulative proportion used when the ratios were computed; it is
            only used to label the y axis.
        save_path: output image path. When None, a timestamped file is written
            under ``figs/``. The ``figs/`` directory is created only in that
            case; for an explicit path only that path's parent is created.

    For every x value the median ratio is drawn with the standard deviation as
    the error bar. Nothing is drawn (and None is returned) when no ratio data
    is found.
    """
    xs = []
    medians = []
    stds = []
    for x_val, items in results.items():
        ratios = _collect_ratios(items)
        if len(ratios) == 0:
            print(f"⚠️  No ratio entries found for x={x_val}, skipping.")
            continue
        xs.append(float(x_val))
        medians.append(np.median(ratios))
        stds.append(np.std(ratios))

    if len(xs) == 0:
        print("⚠️ No ratio data found in results. Nothing to plot.")
        return

    # Sort by x so the connecting line runs left to right.
    xs = np.array(xs)
    order = np.argsort(xs)
    xs = xs[order]
    medians = np.array(medians)[order]
    stds = np.array(stds)[order]

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.errorbar(xs, medians, yerr=stds, fmt='o-', capsize=5, ecolor='gray', elinewidth=1.5)
    ax.set_xlabel("log(Number of parameters)")
    ax.set_ylabel(f"Cumulative explained ratio (α={alpha})")
    ax.set_title("Model Complexity vs. Curvature Concentration")
    ax.grid(True, alpha=0.3)

    if save_path is None:
        # Default location: only now is the figs/ directory needed.
        os.makedirs("figs", exist_ok=True)
        save_path = f"figs/cumulative_explained_ratio_{time.strftime('%Y%m%d_%H%M%S')}.png"
    else:
        # Make sure the caller's target directory exists.
        d = os.path.dirname(save_path)
        if d:
            os.makedirs(d, exist_ok=True)

    fig.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close(fig)
    print(f"✅ Saved cumulative ratio plot → {save_path}")

def plot_logdet_vs_tracebound(results, save_path=None, clip_eps=1e-8):
    """
    Plot the log-determinant against its trace bound for every model size.

    Args:
        results: dict ``{log(n_params): [items]}``. Each item can be a dict
            containing 'eigenvalues' (or 'eigvals', 'H', ...), a bare array of
            eigenvalues, or nested lists/dicts of those.
        save_path: output image path. When None, a timestamped file is written
            under ``figs/``. The ``figs/`` directory is created only in that
            case; for an explicit path only that path's parent is created.
        clip_eps: lower bound applied to the eigenvalues before taking logs,
            so non-positive values do not produce -inf/nan.

    Per trial, after dropping non-finite values and clipping:
        logdet      = sum(log(eig))
        trace_bound = d * log(sum(eig) / d)      with d = number of values
    The median across trials is drawn with the standard deviation as the error
    bar. Nothing is drawn (and None is returned) when no usable data is found.
    """
    xs = []
    logdet_medians = []
    logdet_stds = []
    trace_medians = []
    trace_stds = []

    for x_val, items in results.items():
        # Collect every eigenvalue array found under this x value.
        eig_lists = []
        for it in items:
            arr = _safe_extract_eigvals(it)
            if arr is None:
                continue
            try:
                arr = np.array(arr, dtype=float).flatten()
            except (TypeError, ValueError):
                # Best effort: arrays that cannot be read as floats are skipped.
                continue
            eig_lists.append(arr)

        if len(eig_lists) == 0:
            print(f"⚠️ No eigenvalue arrays found for x={x_val}, skipping.")
            continue

        logdets = []
        tracebounds = []
        for eig in eig_lists:
            eig = eig[np.isfinite(eig)]
            if eig.size == 0:
                continue
            eig = np.clip(eig, clip_eps, None)  # avoid non-positive values
            d = eig.size
            logdets.append(np.sum(np.log(eig)))
            tracebounds.append(d * np.log(np.sum(eig) / d))

        if len(logdets) == 0:
            print(f"⚠️ After sanitization no valid eigvals for x={x_val}, skipping.")
            continue

        xs.append(float(x_val))
        logdet_medians.append(np.median(logdets))
        logdet_stds.append(np.std(logdets))
        trace_medians.append(np.median(tracebounds))
        trace_stds.append(np.std(tracebounds))

    if len(xs) == 0:
        print("⚠️ No valid data to plot for logdet vs tracebound.")
        return

    # Sort by x so the connecting lines run left to right.
    xs = np.array(xs)
    order = np.argsort(xs)
    xs = xs[order]
    logdet_medians = np.array(logdet_medians)[order]
    logdet_stds = np.array(logdet_stds)[order]
    trace_medians = np.array(trace_medians)[order]
    trace_stds = np.array(trace_stds)[order]

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.errorbar(xs, logdet_medians, yerr=logdet_stds, fmt='o-', label='log|H| (∑ log λ_i)')
    ax.errorbar(xs, trace_medians, yerr=trace_stds, fmt='s--', label='trace bound (d·log(tr/d))')
    ax.set_xlabel("log(Number of parameters)")
    ax.set_ylabel("Value")
    ax.set_title("Log-determinant vs. Trace Bound (Curvature Scale)")
    ax.legend()
    ax.grid(True, alpha=0.3)

    if save_path is None:
        # Default location: only now is the figs/ directory needed.
        os.makedirs("figs", exist_ok=True)
        save_path = f"figs/logdet_vs_tracebound_{time.strftime('%Y%m%d_%H%M%S')}.png"
    else:
        # Make sure the caller's target directory exists.
        d = os.path.dirname(save_path)
        if d:
            os.makedirs(d, exist_ok=True)

    fig.savefig(save_path, bbox_inches='tight', dpi=300)
    plt.close(fig)
    print(f"✅ Saved logdet/trace plot → {save_path}")

# Reload the results written by the experiment cell and render both figures.
# NOTE: pickle.load can execute arbitrary code; only load a file that this
# notebook itself produced.
with open("results_curvature_concentration.pkl", "rb") as f:
    results = pickle.load(f)
plot_logdet_vs_tracebound(results, save_path='logdet_vs_tracebound.png')
plot_cumulative_explained_ratio(results, alpha=0.9, save_path='cumulative_explained_ratio.png')

‚úÖ Saved logdet/trace plot ‚Üí logdet_vs_tracebound.png
‚úÖ Saved cumulative ratio plot ‚Üí cumulative_explained_ratio.png


Task 2

In [None]:
# Seed every RNG this cell relies on. Python's `random` has to be seeded too:
# random.shuffle() below decides which samples form the labeled pool, so
# without it the split changes on every run.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Data preparation: reload the digits data and rebuild the standardised
# train/test tensors (same split parameters as the first data cell).
digits = load_digits()
x, y = digits.data.astype(np.float32), digits.target.astype(np.int64)
Xtr, Xte, ytr, yte = train_test_split(x, y, test_size=0.25, stratify=y, random_state=42)
scaler = StandardScaler().fit(Xtr)
Xtr, Xte = scaler.transform(Xtr), scaler.transform(Xte)

Xtr_t = torch.from_numpy(Xtr)          # (N_train, D)
ytr_t = torch.from_numpy(ytr).long()   # (N_train,)
Xte_t = torch.from_numpy(Xte)
yte_t = torch.from_numpy(yte)

# Randomly pick the initial labeled pool; everything else starts unlabeled.
num_pretrain = 30
indices = list(range(len(Xtr_t)))
random.shuffle(indices)
labeled_indices = indices[:num_pretrain]       # indices into the full training dataset
unlabeled_indices = indices[num_pretrain:]

# One base dataset holding the whole training split.
full_train_dataset = TensorDataset(Xtr_t, ytr_t)

# Labeled / unlabeled pools are Subset views over the base dataset
# (NumPy arrays are never passed to TensorDataset directly).
train_set_Labeled = Subset(full_train_dataset, labeled_indices)
train_set_Unlabeled = Subset(full_train_dataset, unlabeled_indices)
test_set = TensorDataset(Xte_t, yte_t)


'''
labeled_train_loader = DataLoader(labeled_train_set, batch_size=32, shuffle=True)
unlabeled_train_loader = DataLoader(unlabeled_train_set, batch_size=32, shuffle=True)
test_loader = DataLoader(TensorDataset(Xte_t, yte_t), batch_size=256, shuffle=False)
'''

'\nlabeled_train_loader = DataLoader(labeled_train_set, batch_size=32, shuffle=True)\nunlabeled_train_loader = DataLoader(unlabeled_train_set, batch_size=32, shuffle=True)\ntest_loader = DataLoader(TensorDataset(Xte_t, yte_t), batch_size=256, shuffle=False)\n'

In [None]:
# Build the model family and keep the first entry, which get_models() defines
# as the single nn.Linear(64, 10) layer.
models = get_models()
model = models[0]

# Loss and optimizer for that model (same SGD settings as the Task 1 loop).
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

def train_model(model, train_set_Labeled, criterion, optimizer, n_epochs=10, batch_size=32):
    """
    Train ``model`` in place on the labeled data and return it.

    Args:
        model: network to optimise.
        train_set_Labeled: either a Dataset of ``(input, label)`` samples or a
            ready-made DataLoader. A Dataset is wrapped in a shuffling
            DataLoader so the model always receives batched input; iterating
            the Dataset directly would feed unbatched single samples, which
            layers such as BatchNorm1d reject.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer already bound to ``model``'s parameters.
        n_epochs: number of passes over the data.
        batch_size: mini-batch size used when a Dataset is wrapped; ignored
            when a DataLoader is passed in.

    Returns:
        The same ``model`` object, left in training mode.
    """
    if isinstance(train_set_Labeled, DataLoader):
        loader = train_set_Labeled
    else:
        loader = DataLoader(train_set_Labeled, batch_size=batch_size, shuffle=True)

    for _ in range(n_epochs):
        model.train()
        for inputs, labels in loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
    return model

def return_hessian_eigenvalues(model, train_set_Labeled):
    """
    Fit a diagonal-structure Laplace approximation of ``model`` over all
    weights on ``train_set_Labeled`` and return its ``H`` attribute.

    The dataset is served in shuffled mini-batches of 32.
    presumably H is the diagonal of the Hessian approximation, whose entries
    are then also its eigenvalues — confirm against the laplace-torch docs.
    """
    laplace_approx = Laplace(
        model,
        'classification',
        subset_of_weights='all',
        hessian_structure='diag',
    )
    labeled_loader = DataLoader(train_set_Labeled, batch_size=32, shuffle=True)
    laplace_approx.fit(labeled_loader)
    return laplace_approx.H

# Quick look at the curvature values on the labeled pool.
# NOTE(review): train_model() is defined above but never called in this cell,
# so `model` still has its initial weights here — confirm that is intended.
val = return_hessian_eigenvalues(model, train_set_Labeled)

print(val)

def AL_finetune(model, AL_train_set_Labeled, criterion, optimizer, n_epochs=10):
    """
    Placeholder for the active-learning fine-tuning step.

    Not implemented yet: none of the arguments are used and the function
    always returns 0.
    """
    placeholder_result = 0
    return placeholder_result

tensor([0.0000e+00, 1.4624e+00, 3.4087e+00, 4.6034e+00, 3.4122e+00, 3.8105e+00,
        8.0984e+00, 4.9618e-01, 1.2243e-02, 3.2575e+00, 2.9832e+00, 2.3532e+00,
        2.4032e+00, 3.5176e+00, 1.0824e+01, 3.1944e+00, 4.5796e-03, 2.7673e+00,
        3.3481e+00, 3.6376e+00, 2.8563e+00, 3.2538e+00, 8.6323e+00, 1.2087e+01,
        2.2881e-03, 1.5376e+00, 3.1003e+00, 3.2242e+00, 3.1270e+00, 3.7002e+00,
        5.4377e+00, 8.6454e+01, 0.0000e+00, 2.2927e+00, 3.1022e+00, 2.6843e+00,
        2.5696e+00, 3.1175e+00, 3.5350e+00, 0.0000e+00, 7.9825e+01, 5.3062e+00,
        3.6115e+00, 2.3732e+00, 2.2527e+00, 3.7384e+00, 3.4849e+00, 2.2184e-02,
        2.2463e+02, 1.4312e+01, 3.4326e+00, 3.7420e+00, 3.0437e+00, 2.0159e+00,
        3.0648e+00, 3.6572e+00, 2.2881e-03, 8.6221e-01, 3.2258e+00, 4.5959e+00,
        2.3432e+00, 2.5126e+00, 4.7221e+00, 1.6042e+01, 0.0000e+00, 1.3897e+00,
        2.5226e+00, 1.8919e+00, 2.8552e+00, 2.2612e+00, 2.6523e+00, 2.7762e-01,
        1.0749e-02, 1.7517e+00, 2.1462e+

In [None]:
# Inspect the unlabeled pool. The data-preparation cell names this Subset
# `train_set_Unlabeled`; the previous name `unlabeled_train_set` is not defined
# anywhere and fails with NameError on a fresh kernel.
train_set_Unlabeled



<torch.utils.data.dataset.Subset at 0x160842f5f30>

In [None]:
def fast_jacobian(model, x):
    """
    Jacobian of the model outputs with respect to its trainable parameters.

    Args:
        model: network whose parameters are differentiated; its gradients are
            cleared with ``zero_grad()`` first.
        x: a single input ``(input_dim,)`` or a batch ``(b, input_dim)``.

    Returns:
        Tensor of shape ``(b, d_out, n_params)``; the parameter axis follows
        the order of ``model.parameters()`` (trainable ones only), each
        flattened.

    The model is evaluated one sample at a time and one backward pass is run
    per output component.
    """
    model.zero_grad()
    # Always work with an explicit batch dimension.
    if x.ndim == 1:
        x = x.unsqueeze(0)

    trainable = [weight for weight in model.parameters() if weight.requires_grad]

    def _sample_jacobian(sample):
        """Jacobian (d_out, n_params) for one sample of shape (1, input_dim)."""
        out = model(sample)
        rows = []
        for k in range(out.shape[-1]):
            # The graph is reused for every output component of this sample.
            grads = torch.autograd.grad(out[0, k], trainable, retain_graph=True)
            rows.append(torch.cat([g.reshape(-1) for g in grads]))
        return torch.stack(rows, dim=0)

    per_sample = [_sample_jacobian(x[i : i + 1]) for i in range(x.shape[0])]
    return torch.stack(per_sample, dim=0)

def compute_outcome_hessian_from_model(model, inputs):
    """
    Per-sample matrix ``diag(p) - p p^T`` built from the softmax of the model
    outputs.

    Args:
        model: callable mapping ``(b, d)`` inputs to ``(b, d_out)`` scores.
        inputs: a single input ``(d,)`` or a batch ``(b, d)``.

    Returns:
        Tensor of shape ``(b, d_out, d_out)``.
    """
    batch = inputs.unsqueeze(0) if inputs.ndim == 1 else inputs
    probs = torch.softmax(model(batch), dim=-1)        # (b, d_out)
    outer = probs[:, :, None] * probs[:, None, :]      # (b, d_out, d_out)
    return torch.diag_embed(probs) - outer

def symmetric_matrix_sqrt(A, eps=1e-12):
    """
    Matrix square root of a symmetric matrix via its eigendecomposition.

    A: (n, n) or (batch, n, n)
    eps: floor applied to the eigenvalues before the square root is taken
    returns: A_sqrt with the same shape as A
    """
    was_matrix = A.dim() == 2
    batched = A.unsqueeze(0) if was_matrix else A

    eigvals, eigvecs = torch.linalg.eigh(batched)
    roots = torch.sqrt(torch.clamp(eigvals, min=eps))

    # V diag(sqrt(w)) V^T, written as a column scaling followed by a matmul.
    scaled_vecs = eigvecs * roots.unsqueeze(-2)
    result = scaled_vecs @ eigvecs.transpose(-2, -1)

    return result[0] if was_matrix else result

def low_rank_updated_part(model, x, return_batch: bool = False):
    """
    Build the factor ``U = J^T H^{1/2}`` for every input, where J comes from
    ``fast_jacobian`` and H from ``compute_outcome_hessian_from_model``.

    Returns:
      - if return_batch=True: U_batch of shape (b, n_params, d_out)
      - else: U_all of shape (n_params, b * d_out)  (backward-compatible)
    """
    if x.ndim == 1:
        x = x.unsqueeze(0)

    outcome_hessian = compute_outcome_hessian_from_model(model, x)   # (b, d_out, d_out)
    jacobians = fast_jacobian(model, x)                              # (b, d_out, n_params)
    hessian_root = symmetric_matrix_sqrt(outcome_hessian)            # (b, d_out, d_out)

    # (b, n_params, d_out) @ (b, d_out, d_out) -> (b, n_params, d_out)
    factors = torch.matmul(jacobians.transpose(1, 2), hessian_root)

    if not return_batch:
        # Lay the per-sample factors side by side along the column axis.
        n_samples, n_params, n_out = factors.shape
        return factors.permute(1, 0, 2).reshape(n_params, n_samples * n_out)
    return factors

def DoptScore_per_sample(model, x, Hessian, eps=1e-10):
    """
    D-optimality score for each input sample.

    - x: (input_dim,) or (b, input_dim)
    - Hessian: torch tensor (n_params,) holding a diagonal
    - eps: added to the diagonal before it is inverted

    For every sample the score is ``log det(I + U^T H^{-1} U)`` with U taken
    from ``low_rank_updated_part``.
    Returns: torch tensor of shape (b,); an entry is NaN when the determinant
    is not positive.
    """
    if x.ndim == 1:
        x = x.unsqueeze(0)
    factors = low_rank_updated_part(model, x, return_batch=True)   # (b, n_params, d_out)
    inv_diag = 1.0 / (Hessian + eps)                               # (n_params,)

    def _score(U):
        """Log-determinant score for one (n_params, d_out) factor."""
        weighted = inv_diag.unsqueeze(1) * U                       # (n_params, d_out)
        gram = torch.eye(U.shape[1], device=U.device) + (U.T @ weighted)
        # slogdet is used for numerical stability.
        sign, logabsdet = torch.linalg.slogdet(gram)
        if sign > 0:
            return logabsdet
        return torch.tensor(float('nan'), device=gram.device)

    return torch.stack([_score(factors[i]) for i in range(factors.shape[0])])  # (b,)

def AoptScore_per_sample(model, x, Hessian, eps=1e-10):
    """
    A-optimality reduction for each input sample.

    - x: (input_dim,) or (b, input_dim)
    - Hessian: torch tensor (n_params,) holding a diagonal
    - eps: added to the diagonal before inversion and to the small system
      matrix before solving

    With ``C = H^{-1} U`` and ``A = I + U^T C`` the per-sample value is
    ``trace(A^{-1} C^T C)``.
    Returns: torch tensor of shape (b,) with the per-sample Delta values.
    """
    if x.ndim == 1:
        x = x.unsqueeze(0)
    factors = low_rank_updated_part(model, x, return_batch=True)   # (b, n_params, d_out)
    inv_diag = 1.0 / (Hessian + eps)                               # (n_params,)

    def _delta(U):
        """Trace reduction for one (n_params, d_out) factor."""
        weighted = inv_diag.unsqueeze(1) * U                       # (n_params, d_out)
        system = torch.eye(U.shape[1], device=U.device) + (U.T @ weighted)
        system = system + eps * torch.eye(system.shape[0], device=system.device)
        rhs = weighted.T @ weighted                                # (d_out, d_out)
        return torch.trace(torch.linalg.solve(system, rhs))

    return torch.stack([_delta(factors[i]) for i in range(factors.shape[0])])  # (b,)


# Train the smallest model on a random 256-sample subset of the training set,
# then fit a diagonal Laplace approximation on that same subset.
models = get_models()
model = models[0]

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Seed Python's `random` before shuffling: it is the RNG that picks the subset,
# and it is not covered by the NumPy / PyTorch seeds.
random.seed(0)
indices = list(range(len(train_set)))
random.shuffle(indices)
subset_indices = indices[:256]
train_subset = Subset(train_set, subset_indices)
train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)

n_epochs_monitor = 30
for epoch in range(n_epochs_monitor):
    model.train()
    # running_loss is accumulated but not reported.
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

# Curvature of the trained model on the training subset.
la = Laplace(model, 'classification', subset_of_weights='all', hessian_structure='diag')
la.fit(train_loader)
Hessian = la.H
Values = []
# NOTE(review): this binds x to the whole (input, label) tuple of sample 0.
x = train_set[0]
'''
for i in range(len(train_set) - 256) :
    x = train_set[256 + i][0]
    Values.append([DoptScore(model, x, Hessian).item(), AoptScore(model, x, Hessian).item()])
Values = np.array(Values)
print(Values.shape)
'''

(1091, 2)


In [33]:
# Score a single training input (the feature tensor of sample 0) with the
# A-optimality criterion, using the model and Hessian from the previous cell.
x = train_set[0][0]
AoptScore_per_sample(model, x, Hessian)

tensor([22.7606], grad_fn=<StackBackward0>)