In [None]:
import torch
import warnings
import numpy as np
from modules import ISAB, PMA
from netcal.metrics import ENCE
warnings.filterwarnings("ignore", message=".*Can't initialize NVML.*",category=UserWarning)

In [6]:
data = np.load("data/central_dataset.npz")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

X_c = torch.from_numpy(data["X_train"]).float().to(device)
y_c = torch.from_numpy(data["y_train"]).float().unsqueeze(-1).to(device)
context_c = torch.cat((X_c, y_c), dim=1)

X_test = torch.from_numpy(data["X_test"]).float().to(device)
y_test = torch.from_numpy(data["y_test"]).float().unsqueeze(-1).to(device)

d_ctx = context_c.size(1)
d_x = X_c.size(1)
d_y = y_c.size(1)
K = 32

sigma2_true = 0.1
sigma_true = np.sqrt(sigma2_true)

isab = ISAB(d_ctx, d_ctx, 5, 128, ln=True).to(device)
isab_dict = torch.load("checkpoint/isab_model.pt", map_location=device)
isab.load_state_dict(isab_dict)
isab.eval()

W_cmps = []

for i in range(K):
    
    with torch.no_grad():
        predictive_c = isab(context_c)
    
    predictive_cx, predictive_cy = predictive_c[:, :d_x], predictive_c[:, d_x:]
    
    # X_cmp = torch.cat((X_c, predictive_cx), dim=0)
    # y_cmp = torch.cat((y_c, predictive_cy), dim=0)
    
    X_cmp = predictive_cx
    y_cmp = predictive_cy
    
    XTX_cmp = X_cmp.T @ X_cmp 
    XTy_cmp = X_cmp.T @ y_cmp
    w_cmp = torch.linalg.solve(XTX_cmp, XTy_cmp)
    
    W_cmps.append(w_cmp)

W_cmps_stacked = torch.stack(W_cmps, dim=0).squeeze(-1) # (K, d_x)

XTX_cnbl = X_c.T @ X_c
XTy_cnbl= X_c.T @ y_c
w_cnbl = torch.linalg.solve(XTX_cnbl, XTy_cnbl)

with torch.no_grad():
    y_test_cnbl = X_test @ w_cnbl
    mse_test_cnbl = torch.mean((y_test_cnbl - y_test) ** 2).item()
    print(f"Predictive MSE (centralized OLS, real data only): {mse_test_cnbl}")
    
with torch.no_grad():  
    y_test_cmps = torch.einsum('bd, nd->bn', W_cmps_stacked, X_test) # (K, N)
    y_test_cmp_mean = y_test_cmps.mean(dim=0).view(-1)
    y_test_cmp_var = y_test_cmps.var(dim=0, unbiased=True).view(-1)
    
    mse_test_cmp_mean = torch.mean((y_test_cmp_mean.view(-1, 1) - y_test) ** 2).item()
    print(f"Predictive MSE using mean of K MP samples: {mse_test_cmp_mean}")

y_test_np = y_test.view(-1).cpu().numpy()
y_test_mean_cnbl_np = y_test_cnbl.view(-1).cpu().numpy()
y_test_std_cnbl_np =  sigma_true * np.ones_like(y_test_mean_cnbl_np)

y_test_mean_cmp_np = y_test_cmp_mean.cpu().numpy()
y_test_cmp_var += sigma2_true
y_test_std_cmp_np = np.sqrt(y_test_cmp_var.cpu().numpy())

def ence_with_fallback(y_mean, y_std, y_true, bins=10, eps=1e-8):
    y_true = np.asarray(y_true).reshape(-1)
    y_mean = np.asarray(y_mean).reshape(-1)
    y_std  = np.asarray(y_std).reshape(-1)

    if np.allclose(y_std, y_std[0], rtol=1e-6, atol=1e-8):
        mse = np.mean((y_true - y_mean) ** 2)
        u_hat = np.mean(y_std ** 2) 
        ence = np.abs(mse - u_hat) / (max(mse, u_hat) + eps)
        return float(ence)

    ence_metric = ENCE(bins=bins)
    return float(ence_metric.measure((y_mean, y_std), y_true))

ence_cnbl = ence_with_fallback(y_test_mean_cnbl_np, y_test_std_cnbl_np, y_test_np)
ence_cmp = ence_with_fallback(y_test_mean_cmp_np, y_test_std_cmp_np, y_test_np)

print(f"ENCE (Non Bayesian OLS, real data only): {ence_cnbl}")
print(f"ENCE (CMP): {ence_cmp}")

err2 = (y_test_cmp_mean.view(-1,1) - y_test).view(-1)**2
u = y_test_cmp_var
print("mean err2:", err2.mean().item(), "mean u:", u.mean().item())
print("corr(err2, u):", torch.corrcoef(torch.stack([err2, u]))[0,1].item())


Predictive MSE (centralized OLS, real data only): 0.1400873064994812
Predictive MSE using mean of K MP samples: 0.7694464325904846
ENCE (Non Bayesian OLS, real data only): 0.286159160128087
ENCE (CMP): 1.9356541463905939
mean err2: 0.7694464325904846 mean u: 0.10005893558263779
corr(err2, u): 0.0841580182313919


In [23]:
clients_x_train = []
clients_y_train = []
clients_x_test = []
clients_y_test = []
clients_context = []
clients_induces = None

M = 10

pma = PMA(d_ctx, 5, 5, ln=True).to(device)
pma_dict = torch.load("checkpoint/pma_model_federated_5.pt")
pma.load_state_dict(pma_dict)

for m in range(M):
    data = np.load(f"data/client_{m}.npz")
    
    X_m_train = torch.from_numpy(data["X_train"]).float().to(device)
    y_m_train = torch.from_numpy(data["y_train"]).unsqueeze(-1).float().to(device)
    X_m_test = torch.from_numpy(data["X_test"]).float().to(device)
    y_m_test = torch.from_numpy(data["y_test"]).unsqueeze(-1).float().to(device)
    context_m = torch.cat((X_m_train, y_m_train), dim=1)
    
    clients_x_train.append(X_m_train)
    clients_y_train.append(y_m_train)
    clients_x_test.append(X_m_test)
    clients_y_test.append(y_m_test)
    clients_context.append(context_m)
    
    induce_m = pma(context_m)
    clients_induces = induce_m if m == 0 else torch.cat((clients_induces, induce_m), dim=0)

W_fmp = []

for k in range(K):
    with torch.no_grad():
        pred_client_induces = isab(clients_induces)
    
    X_induce = clients_induces[:, :d_x]
    y_induce = clients_induces[:, d_x:]
    X_pred_induce = pred_client_induces[:, :d_x]
    y_pred_induce = pred_client_induces[:, d_x:]
    
    X_induce_aug = torch.cat((X_induce, X_pred_induce), dim=0)
    y_induce_aug = torch.cat((y_induce, y_pred_induce), dim=0)
    
    XTX_induce = X_induce_aug.T@X_induce_aug
    XTy_induce = X_induce_aug.T@y_induce_aug
    W_induce = torch.linalg.solve(XTX_induce, XTy_induce)
    
    W_fmp.append(W_induce)

with torch.no_grad():    
    W_fmp_mean = torch.mean(torch.stack(W_fmp, dim=0), dim=0)
    y_fmp_mean = X_test @ W_fmp_mean
    mse_fmp_mean = torch.mean((y_fmp_mean - y_test) ** 2).item()
    print(f"FMP Predictive MSE using mean of K MP samples: {mse_fmp_mean}")
    

FileNotFoundError: [Errno 2] No such file or directory: 'data/client_5.npz'

In [18]:
W_clients = []

client_induces = None

sum_mse = 0.0

for m in range(M):
    X_m = clients_x_train[m]
    y_m = clients_y_train[m]
    X_m_test = clients_x_test[m]
    y_m_test = clients_y_test[m]
    context_m = clients_context[m]

    W_m_augs = []
    
    for i in range(K):
        predictive_m = isab(context_m)
        predictive_mx, predictive_my = predictive_m[:, :d_x], predictive_m[:, d_x:]
        
        X_m_aug = torch.cat((X_m, predictive_mx), dim=0)
        y_m_aug = torch.cat((y_m, predictive_my), dim=0)
        
        XTX = X_m_aug.T @ X_m_aug
        XTy = X_m_aug.T @ y_m_aug
        w_m_aug = torch.linalg.solve(XTX, XTy)
        
        W_m_augs.append(w_m_aug)
        
    XTX_m = X_m.T @ X_m
    XTy_m = X_m.T @ y_m
    w_m_real = torch.linalg.solve(XTX_m, XTy_m)
    y_m_test_real = X_m_test @ w_m_real
    mse_m_test_real = torch.mean((y_m_test_real - y_m_test) ** 2).item()
    print(f"Client {m} Predictive MSE (Local OLS, real data only): {mse_m_test_real}")
    
    W_m_augs_mean = torch.mean(torch.stack(W_m_augs, dim=0), dim=0)
    y_m_test_augs_mean = X_m_test @ W_m_augs_mean
    mse_m_test_augs_mean = torch.mean((y_m_test_augs_mean - y_m_test) ** 2).item()
    print(f"Client {m} Predictive MSE using mean of K MP samples: {mse_m_test_augs_mean}")
    
    sum_mse += mse_m_test_augs_mean
  
    W_clients.append(W_m_augs)
    
print(sum_mse / M)


Client 0 Predictive MSE (Local OLS, real data only): 1.5319807529449463
Client 0 Predictive MSE using mean of K MP samples: 0.7516019940376282
Client 1 Predictive MSE (Local OLS, real data only): 0.924592137336731
Client 1 Predictive MSE using mean of K MP samples: 0.5601908564567566
Client 2 Predictive MSE (Local OLS, real data only): 2.896517038345337
Client 2 Predictive MSE using mean of K MP samples: 1.1962809562683105
Client 3 Predictive MSE (Local OLS, real data only): 2.299551486968994
Client 3 Predictive MSE using mean of K MP samples: 1.0130255222320557
Client 4 Predictive MSE (Local OLS, real data only): 0.6699424386024475
Client 4 Predictive MSE using mean of K MP samples: 0.593198835849762
Client 5 Predictive MSE (Local OLS, real data only): 0.6372154951095581
Client 5 Predictive MSE using mean of K MP samples: 0.6953699588775635
Client 6 Predictive MSE (Local OLS, real data only): 12.974235534667969
Client 6 Predictive MSE using mean of K MP samples: 1.4980835914611816
Cli

In [8]:
Sigma_inv_list = []
precision_sum = torch.zeros(d_x, d_x, device=device)

# for m in range(M):
#     W_m = torch.stack([w.squeeze(-1) for w in W_clients[m]], dim=0)  # (K, d)
# 
#     var_m = W_m.var(dim=0, unbiased=True)    # (d,)
#     eps = 1e-6
#     Sigma_m = torch.diag(var_m + eps)        # (d, d)
# 
#     Sigma_inv_m = torch.linalg.inv(Sigma_m)  # (d, d)
#     Sigma_inv_list.append(Sigma_inv_m)
# 
#     precision_sum += Sigma_inv_m             #∑_m Σ_m^{-1}

for m in range(M):

    W_m = torch.stack([w.view(-1) for w in W_clients[m]], dim=0)   # (K, d)


    mean_m = W_m.mean(dim=0, keepdim=True)                         # (1, d)
    centered = W_m - mean_m                                        # (K, d)

    cov_m = centered.T @ centered / (K - 1)                    # (d, d)

    Sigma_inv_m = torch.linalg.pinv(cov_m)                        # (d, d)
    Sigma_inv_list.append(Sigma_inv_m)

    precision_sum += Sigma_inv_m 

W_cfmp = [] 

for k in range(K):
    
    weighted_sum_k = torch.zeros(d_x, 1, device=device)  # ∑_m Σ_m^{-1} w_m^{(k)}

    for m in range(M):
        w_m_k = W_clients[m][k]   # (d, 1)
        Sigma_inv_m = Sigma_inv_list[m]
        weighted_sum_k += Sigma_inv_m @ w_m_k

    w_cfmp_k = torch.linalg.solve(precision_sum, weighted_sum_k)   # (d, 1)
    W_cfmp.append(w_cfmp_k)

W_cfmp_mean = torch.mean(torch.stack(W_cfmp, dim=0), dim=0)
y_cfmp_mean = X_test @ W_cfmp_mean
mse_cfmp_mean = torch.mean((y_cfmp_mean - y_test) ** 2).item()
print(f"CFMP Predictive MSE using mean of K MP samples: {mse_cfmp_mean}")


CFMP Predictive MSE using mean of K MP samples: 0.7305290699005127


In [12]:
def _rbf_kernel(x: torch.Tensor, y: torch.Tensor, sigmas=None):

    if sigmas is None:
        sigmas = torch.tensor([0.5, 1.0, 2.0, 5.0, 10.0], device=x.device, dtype=x.dtype)
    elif not torch.is_tensor(sigmas):
        sigmas = torch.tensor(sigmas, device=x.device, dtype=x.dtype)

    x_norm = (x ** 2).sum(dim=1, keepdim=True)  # (n, 1)
    y_norm = (y ** 2).sum(dim=1, keepdim=True)  # (m, 1)

    dist2 = x_norm + y_norm.T - 2.0 * (x @ y.T)

    sigmas = sigmas.view(-1, 1, 1)
    gamma = 1.0 / (2.0 * sigmas ** 2)

    kernel = torch.exp(-gamma * dist2)  # (num_sigmas, n, m)
    kernel = kernel.mean(dim=0)  # (n, m)
    return kernel


def mmd_rbf(X: torch.Tensor, Y: torch.Tensor, sigmas=None):
    X = X.detach()
    Y = Y.detach()

    Kxx = _rbf_kernel(X, X, sigmas=sigmas)
    Kyy = _rbf_kernel(Y, Y, sigmas=sigmas)
    Kxy = _rbf_kernel(X, Y, sigmas=sigmas)
    
    n = X.size(0)
    m = Y.size(0)

    # unbiased MMD^2    E[k(x,x')] - E[k(x,y)]*2 + E[k(y,y')]
    sum_Kxx = (Kxx.sum() - Kxx.diag().sum()) / (n * (n - 1))
    sum_Kyy = (Kyy.sum() - Kyy.diag().sum()) / (m * (m - 1))
    sum_Kxy = Kxy.mean()

    dis = sum_Kxx + sum_Kyy - 2.0 * sum_Kxy
    dis = torch.clamp(dis, min=0.0) 

    return dis

In [13]:
# CMP samples: centralized martingale posterior samples w_c^{(k)}
W_cmp = torch.stack([w.view(-1) for w in W_cmps], dim=0) # (K, d)

# CFMP samples: aggregated martingale posterior samples w_cfmp^{(k)}
W_cfmp = torch.stack([w.view(-1) for w in W_cfmp], dim=0)  # (K, d)

# FMP samples: aggregated induced martingale posterior samples w_fmp^{(k)}
W_fmp = torch.stack([w.view(-1) for w in W_fmp], dim=0)

mmd2 = mmd_rbf(W_cmp, W_cfmp)

print("MMD^2 between CMP and CFMP parameter samples:", mmd2.item())

mmd2 = mmd_rbf(W_cmp, W_fmp)

print("MMD^2 between CMP and FMP parameter samples:", mmd2.item())

MMD^2 between CMP and CFMP parameter samples: 0.26233935356140137
MMD^2 between CMP and FMP parameter samples: 7.462501525878906e-05


In [14]:
mmd2_local_list = []

for m in range(M):
    W_local_m = torch.stack([w.view(-1) for w in W_clients[m]], dim=0).to(device)  # (K, d)

    mmd2_m = mmd_rbf(W_cmp, W_local_m)
    mmd2_local_list.append(mmd2_m.item())

    print(f"Client {m}: MMD^2 between CMP and local MP = {mmd2_m.item():.6f}")
    
    X_m_test = clients_x_test[m]
    y_m_test = clients_y_test[m]
    
    W_cfmp_mean = torch.mean(W_cfmp, dim=0)
    y_cfmp_mean = X_m_test @ W_cfmp_mean
    mse_m_test_cfmp_mean = torch.mean((y_cfmp_mean - y_m_test) ** 2).item()
    print(f"Predictive MSE using mean of K CFMP samples: {mse_m_test_cfmp_mean}")
    
    W_fmp_mean = torch.mean(W_fmp, dim=0)
    y_fmp_mean = X_m_test @ W_fmp_mean
    mse_m_test_fmp_mean = torch.mean((y_fmp_mean - y_m_test) ** 2).item()
    print(f"Predictive MSE using mean of K FMP samples: {mse_m_test_fmp_mean}")
    
print("Average MMD^2 between CMP and local MPs:", sum(mmd2_local_list) / len(mmd2_local_list))

Client 0: MMD^2 between CMP and local MP = 0.242160
Predictive MSE using mean of K CFMP samples: 1.7601854801177979
Predictive MSE using mean of K FMP samples: 1.458784580230713
Client 1: MMD^2 between CMP and local MP = 0.176329
Predictive MSE using mean of K CFMP samples: 1.7120976448059082
Predictive MSE using mean of K FMP samples: 1.461951494216919
Client 2: MMD^2 between CMP and local MP = 0.417859
Predictive MSE using mean of K CFMP samples: 1.4670337438583374
Predictive MSE using mean of K FMP samples: 1.2806888818740845
Client 3: MMD^2 between CMP and local MP = 0.293935
Predictive MSE using mean of K CFMP samples: 1.5326499938964844
Predictive MSE using mean of K FMP samples: 1.2608627080917358
Client 4: MMD^2 between CMP and local MP = 0.241214
Predictive MSE using mean of K CFMP samples: 1.613072395324707
Predictive MSE using mean of K FMP samples: 1.322301983833313
Client 5: MMD^2 between CMP and local MP = 0.225421
Predictive MSE using mean of K CFMP samples: 1.5717130899