### Scope: aggregation-rule (AGR) testing only — no sockets, no full implementation, just logit aggregation

In [1]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torchvision.datasets as datasets
from scipy.stats import norm

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# Stats
class RunningStats:
    """
    Track the running mean and covariance of vectors in R^C using
    Welford-style online (streaming) updates.

    Purely observational: vectors passed to ``update`` are never modified.

    Attributes (all plain instance attributes):
        dim:    dimensionality C of the tracked vectors.
        device: device the accumulators live on.
        eps:    diagonal jitter added by ``covariance(regularize=True)``.
        n:      number of vectors observed so far.
        mean:   running mean, shape (C,).
        M2:     running sum of centered outer products, shape (C, C);
                ``M2 / (n - 1)`` is the unbiased sample covariance.
    """

    def __init__(self, dim, device="cpu", eps=1e-6):
        self.dim = dim
        self.device = device
        self.eps = eps

        self.n = 0
        self.mean = torch.zeros(dim, device=device)
        self.M2 = torch.zeros(dim, dim, device=device)  # sum of outer products

    @torch.no_grad()
    def update(self, x):
        """
        Fold one observation into the running statistics.

        x: Tensor of shape (C,)
        """
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        # Welford: pairing the pre- and post-update deviations keeps M2 exact.
        delta2 = x - self.mean
        self.M2 += torch.outer(delta, delta2)

    def covariance(self, regularize=False):
        """
        Unbiased sample covariance, shape (C, C).

        All zeros while fewer than two vectors have been observed. With
        ``regularize=True`` ``eps * I`` is added so the result is strictly
        positive definite.
        """
        cov = self.M2 / max(self.n - 1, 1)
        if regularize:
            cov = cov + self.eps * torch.eye(self.dim, device=cov.device, dtype=cov.dtype)
        return cov

    def std(self):
        """Per-class standard deviation, shape (C,)."""
        # Clamp guards against tiny negative variances from float round-off.
        return torch.sqrt(torch.diag(self.covariance()).clamp_min(0.0))


class PublicPredictionObserver:
    """
    Observe the aggregated public predictions round by round and keep
    per-sample running statistics (one ``RunningStats`` per public sample).
    """

    def __init__(self, num_classes, device="cpu"):
        self.num_classes = num_classes
        self.device = device
        self.stats = None  # list of RunningStats, created lazily by observe()

    def reset(self):
        """Forget everything observed so far."""
        self.stats = None

    @torch.no_grad()
    def observe(self, predictions):
        """
        Record one round of predictions.

        predictions: Tensor of shape (N, C) with C == num_classes. N must be
        the same on every call, because row i is always folded into the
        running statistic of sample i.

        Raises:
            ValueError: if the class count or the sample count is wrong.
        """
        N, C = predictions.shape
        if C != self.num_classes:
            raise ValueError(f"expected {self.num_classes} classes, got {C}")

        if self.stats is None:
            self.stats = [RunningStats(C, device=self.device) for _ in range(N)]
        elif N != len(self.stats):
            raise ValueError(f"expected {len(self.stats)} samples per round, got {N}")

        for running, row in zip(self.stats, predictions):
            running.update(row)

    def _require_stats(self):
        """Return the per-sample trackers, failing clearly if nothing was observed."""
        if self.stats is None:
            raise RuntimeError("observe() must be called before querying statistics")
        return self.stats

    def mean(self):
        """Per-sample running mean, shape (N, C)."""
        return torch.stack([s.mean for s in self._require_stats()])

    def std(self):
        """Per-sample running standard deviation, shape (N, C)."""
        # RunningStats.std is a method: it must be called, not stacked as-is.
        return torch.stack([s.std() for s in self._require_stats()])

    
def inv_phi(n):
    """
    Inverse CDF (quantile / probit) of the standard normal distribution.

    n: probability strictly inside (0, 1); anything else (including NaN)
       raises ValueError.
    """
    if 0.0 < n < 1.0:
        return norm.ppf(n)
    raise ValueError("n must be in (0, 1)")


In [3]:
# BENIGN CLIENT
class MnistModel(nn.Module):
    """Benign party: a 784-256-64-10 ReLU MLP over flattened 28x28 images."""

    def __init__(self):
        super().__init__()
        # Registration order fixes how parameter init draws from the global RNG.
        self.lin1 = nn.Linear(784, 256)
        self.lin2 = nn.Linear(256, 64)
        self.lin3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return unnormalized class logits, shape (batch, 10)."""
        hidden = x
        for layer in (self.lin1, self.lin2):
            hidden = F.relu(layer(hidden))
        return self.lin3(hidden)

In [4]:
# MALICIOUS CLIENT
class LIE_Model(nn.Module):
    """
    Malicious party running an LIE ("A Little Is Enough"-style) attack.

    Until it has been through more than ``active_round`` rounds it is an
    ordinary 784-256-64-10 ReLU MLP. After that, whenever it is in eval mode
    (e.g. while its public predictions are collected), it ignores its input
    and returns ``calculate_v()``: the observed mean plus z_max observed
    standard deviations of the aggregated public predictions that
    ``observer`` tracks.
    """

    def __init__(self, active_round, num_malicious, num_models, observer):
        super().__init__()
        self.lin1 = nn.Linear(784, 256)
        self.lin2 = nn.Linear(256, 64)
        self.lin3 = nn.Linear(64, 10)
        self.active_round = active_round  # attack once num_rounds exceeds this
        self.num_rounds = 0  # advanced by the caller once per collaboration round
        self.observer = observer
        self.num_malicious = num_malicious
        self.num_models = num_models

    def _observed_std(self):
        """
        Per-sample, per-class std of the observed predictions, shape (N, C).

        Read straight from each tracker's ``M2`` / ``n`` accumulators as the
        unbiased estimate sqrt(diag(M2) / max(n - 1, 1)); it is zero until at
        least two rounds have been observed.
        """
        rows = []
        for s in self.observer.stats:
            var = torch.diagonal(s.M2) / max(s.n - 1, 1)
            rows.append(var.clamp_min(0.0).sqrt())
        return torch.stack(rows)

    def calculate_v(self):
        """
        Malicious prediction vectors mean + z_max * std, shape (N, C).

        z_max = Phi^{-1}(1 - m / n) for m malicious parties out of n.
        """
        # NOTE(review): the LIE paper (Baruch et al., 2019) uses
        # z_max = Phi^{-1}((n - s) / n) with s = floor(n / 2 + 1) - m; the
        # simpler 1 - m / n used here gives a larger shift — confirm intent.
        z_max = float(inv_phi(1 - (self.num_malicious / self.num_models)))
        # A bare "+ z_max" shifts every class equally and is erased by a
        # softmax over the output (as predict_probs applies); scaling by the
        # per-class std makes the perturbation survive it.
        return self.observer.mean() + self._observed_std() * z_max

    def forward(self, x):
        # Attack path: eval mode and past the activation round. The result
        # has one row per tracked public sample regardless of x.
        # NOTE(review): the output is consumed as logits (softmaxed later)
        # although it lives in probability space — confirm this is intended.
        if not self.training and self.num_rounds > self.active_round:
            return self.calculate_v()
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
        return self.lin3(x)


In [5]:
def train_ce_fullbatch(model, X, y, optimizer, epochs):
    """
    Run ``epochs`` full-batch cross-entropy steps of ``optimizer`` on (X, y).

    The model is switched to (and left in) training mode; its parameters are
    updated in place and nothing is returned.
    """
    model.train()

    def _one_step():
        optimizer.zero_grad()
        F.cross_entropy(model(X), y).backward()
        optimizer.step()

    for _epoch in range(epochs):
        _one_step()

def train_distill_fullbatch(model, X_pub, teacher_probs, optimizer, epochs, eps=1e-12):
    """
    Distill the shared teacher prediction vectors into ``model`` on X_pub.

    Each of the ``epochs`` full-batch steps minimizes KL(teacher || student),
    i.e. cross-entropy against the soft targets up to a constant, averaged
    over the batch. Teacher probabilities are clamped to at least ``eps``.
    The model is left in training mode.
    """
    # KLDivLoss takes log-probabilities as input and probabilities as target.
    kl_loss = nn.KLDivLoss(reduction="batchmean")
    model.train()
    for _epoch in range(epochs):
        optimizer.zero_grad()
        student_log_q = model(X_pub).log_softmax(dim=1)
        target_p = teacher_probs.clamp_min(eps)
        kl_loss(student_log_q, target_p).backward()
        optimizer.step()

@torch.no_grad()
def predict_probs(model, X):
    """
    Return softmax class probabilities of ``model`` on X (softmax over dim 1).

    Side effect: the model is switched to eval mode and left there.
    """
    model.eval()
    logits = model(X)
    return logits.softmax(dim=1)

# -----------------------------
# RobustFilter aggregation (randomized Cronus-style filter)
# - Iteratively trims party vectors that are outlying along the top
#   principal direction, then averages the survivors
# -----------------------------
def _cronus_filter_sample(Y, eps, lambda_thresh, max_iters, generator):
    """Randomized robust mean of the party vectors Y [K, C] for one sample."""
    C = Y.shape[1]
    eye = torch.eye(C, device=Y.device, dtype=Y.dtype)
    mu = Y.mean(dim=0)

    for _ in range(max_iters):
        X = Y - mu

        # Parties agree: nothing to trim
        if X.norm() < 1e-6:
            break

        # Empirical covariance (rank-deficient-safe) plus diagonal regularization
        Sigma = (X.T @ X) / max(len(Y) - 1, 1) + eps * eye

        try:
            eigvals, eigvecs = torch.linalg.eigh(Sigma)
        except RuntimeError:
            # Covariance too ill-conditioned → skip trimming
            break

        # eigh returns eigenvalues in ascending order: the last is the largest
        if eigvals[-1] <= lambda_thresh:
            break

        v_star = eigvecs[:, -1]
        projections = torch.abs(X @ v_star)
        max_proj = projections.max()
        if max_proj < 1e-6:
            break

        # Randomized threshold: a party is dropped with probability
        # (projection / max_proj)^2, since P(sqrt(U) * max_proj <= p) = (p / max_proj)^2.
        T = torch.sqrt(torch.rand(1, device=Y.device, generator=generator)) * max_proj
        mask = projections < T

        # Too few survivors: keep the current estimate
        if mask.sum() <= 1:
            break

        Y = Y[mask]
        mu = Y.mean(dim=0)

    return mu


def f_cronus(
    logits,
    eps=1e-3,
    lambda_thresh=9.0,
    max_iters=5,
    generator=None,
):
    """
    Robust Cronus aggregation with a randomized filter, per public sample.

    For each sample: while the top eigenvalue of the (regularized) covariance
    of the surviving party vectors exceeds ``lambda_thresh``, draw
    T = sqrt(U) * max|projection| with U ~ Uniform[0, 1) and drop every party
    whose |projection| onto the top eigenvector is >= T. Stops early when the
    parties agree, eigh fails, or at most one party would survive.

    Args:
        logits: Tensor [K, N, C] (parties × samples × classes). Any real
            vectors work; this notebook passes softmax probabilities.
        eps: diagonal jitter added to the covariance before eigh.
        lambda_thresh: stop trimming once the top eigenvalue is <= this.
        max_iters: maximum trimming rounds per sample.
        generator: optional torch.Generator for reproducible trimming; it
            must live on the same device as ``logits``.

    Returns:
        Tensor [N, C] with the same dtype and device as ``logits``.

    Raises:
        ValueError: if ``logits`` holds no parties (K == 0).

    NOTE(review): for probability vectors every ||p - mu||^2 <= 2, so the top
    eigenvalue is at most about 4 and, with the default lambda_thresh=9, no
    trimming ever happens (the result is the plain mean). Lower lambda_thresh
    or aggregate raw logits to actually exercise the filter.
    """
    K, N, C = logits.shape
    if K == 0:
        raise ValueError("f_cronus needs at least one party (K >= 1)")

    agg = torch.zeros(N, C, device=logits.device, dtype=logits.dtype)
    for n in range(N):
        agg[n] = _cronus_filter_sample(logits[:, n, :], eps, lambda_thresh, max_iters, generator)
    return agg

# -----------------------------
# Cronus aggregation (Algorithm 6) with the practical modifications used in
# the paper's evaluation:
# - stop when lambda_star <= 9
# - deterministic filtering: remove an eps/2 fraction each iteration
# - repeat filtering 2 times
# -----------------------------
def _cronus_deterministic_filter(Y, eps, lambda_thresh, iters, jitter):
    """Deterministically filtered mean of the party vectors Y (K, C) for one sample."""
    C = Y.shape[1]
    eye = torch.eye(C, device=Y.device, dtype=Y.dtype)

    for _ in range(iters):
        mu = Y.mean(dim=0)
        X = Y - mu
        Sigma = (X.T @ X) / max(Y.shape[0] - 1, 1) + jitter * eye

        # eigh returns eigenvalues in ascending order: the last is the largest
        eigvals, eigvecs = torch.linalg.eigh(Sigma)
        if eigvals[-1] <= lambda_thresh:
            break

        projections = torch.abs(X @ eigvecs[:, -1])  # (|Y|,)
        m = Y.shape[0]
        drop = int(math.floor((eps / 2.0) * m))

        # Nothing to drop, or dropping would leave fewer than 2 parties
        if drop <= 0 or (m - drop) < 2:
            break

        # Keep the parties with the smallest projections
        keep_idx = torch.argsort(projections)[: m - drop]
        Y = Y[keep_idx]

    return Y.mean(dim=0)


def cronus_aggregate_probs(
    probs_KNC: torch.Tensor,
    eps: float,
    lambda_thresh: float = 9.0,
    iters: int = 2,
    jitter: float = 1e-6,
):
    """
    Cronus aggregation of prediction vectors with deterministic filtering.

    probs_KNC: (K, N, C) prediction vectors (probabilities) from K parties
    returns:   (N, C) aggregated prediction vectors, clamped to >= 0 and
               renormalized so each row sums to (approximately) 1

    Applies Algorithm 6's filter independently to each public sample, using
    the paper's practical variant: stop once the top covariance eigenvalue is
    <= lambda_thresh; otherwise drop the floor(eps/2 * m) of the m surviving
    parties that lie farthest along the top eigenvector; at most ``iters``
    rounds.

    Raises:
        ValueError: if there are no parties (K == 0).

    NOTE(review): for probability vectors every ||p - mu||^2 <= 2, so the top
    eigenvalue is at most about 4 and the default lambda_thresh=9 never
    triggers filtering; confirm the threshold's intended scale.
    """
    K, N, C = probs_KNC.shape
    if K == 0:
        raise ValueError("cronus_aggregate_probs needs at least one party (K >= 1)")

    out = torch.empty((N, C), device=probs_KNC.device, dtype=probs_KNC.dtype)
    for n in range(N):
        out[n] = _cronus_deterministic_filter(probs_KNC[:, n, :], eps, lambda_thresh, iters, jitter)

    # Keep each row a valid probability vector (numerical cleanup)
    out = out.clamp_min(0.0)
    return out / (out.sum(dim=1, keepdim=True) + 1e-12)


In [6]:
# Data initialization
observer = PublicPredictionObserver(num_classes=10, device=device)

mnist_train = datasets.MNIST(root="./data", train=True, download=True)
mnist_test = datasets.MNIST(root="./data", train=False, download=True)

# Training images 0..49999 become the parties' private data; images
# 50000..59999 form the public set X_pub (its labels are never used).
# NOTE(review): pixels stay raw 0-255 floats (no /255 or mean/std
# normalization) — confirm this is intended.
X_train = mnist_train.data[:50000].float().flatten(1).to(device)
Y_train = mnist_train.targets[:50000].to(device)

X_pub = mnist_train.data[50000:].float().flatten(1).to(device)

X_test = mnist_test.data.float().flatten(1).to(device)
Y_test = mnist_test.targets.to(device)

NUM_PARTIES = 28
T1 = 50              # private-only initialization epochs per party
T2 = 50              # collaboration rounds
eps_adv = 0.1        # eps_adv, lambda_thresh and agg_iters only feed the
lambda_thresh = 9.0  # commented-out cronus_aggregate_probs alternative
agg_iters = 2
NUM_MALICIOUS = 5

# Split private data across parties (simple IID split); the last
# len(X_train) % NUM_PARTIES shuffled samples are left unused.
per_party = len(X_train) // NUM_PARTIES
perm = torch.randperm(len(X_train), device=device)

X_parts = []
Y_parts = []
for i in range(NUM_PARTIES):
    idx = perm[i * per_party : (i + 1) * per_party]
    X_parts.append(X_train[idx])
    Y_parts.append(Y_train[idx])

# Parties 0..NUM_MALICIOUS-1 run the LIE attack and the rest are benign; with
# NUM_MALICIOUS == 0 the first list is simply empty.
models = [
    LIE_Model(
        active_round=4,
        num_malicious=NUM_MALICIOUS,
        num_models=NUM_PARTIES,
        observer=observer,
    ).to(device)
    for _ in range(NUM_MALICIOUS)
] + [MnistModel().to(device) for _ in range(NUM_MALICIOUS, NUM_PARTIES)]

# --- Initialization phase: private data only, Adam lr=5e-4, T1 full-batch epochs
for i in range(NUM_PARTIES):
    opt = Adam(models[i].parameters(), lr=5e-4)
    train_ce_fullbatch(models[i], X_parts[i], Y_parts[i], opt, epochs=T1)

# Initial public predictions Y^0_i = PREDICT(theta_i; X_pub), aggregated into Y_bar
with torch.no_grad():
    probs_stack = torch.stack([predict_probs(m, X_pub) for m in models], dim=0)  # (K, N, C)
# Alternative AGR:
# Y_bar = cronus_aggregate_probs(probs_stack, eps=eps_adv, lambda_thresh=lambda_thresh, iters=agg_iters)
Y_bar = f_cronus(probs_stack)
observer.observe(Y_bar)

In [7]:
# --- Collaboration phase: T2 rounds of local training + server aggregation
for t in range(T2):
    # Each party updates on D_i ∪ D_p: one full-batch CE epoch on its private
    # shard, then one epoch distilling the current aggregate Y_bar on X_pub.
    # NOTE(review): a single SGD optimizer (lr=1e-3) drives both steps, whereas
    # the paper is described as using private Adam + public SGD — confirm.
    for i in range(NUM_PARTIES):
        opt = SGD(models[i].parameters(), lr=1e-3)
        train_ce_fullbatch(models[i], X_parts[i], Y_parts[i], opt, epochs=1)
        train_distill_fullbatch(models[i], X_pub, Y_bar.detach(), opt, epochs=1)

    # Parties send prediction vectors on X_pub; the server aggregates them
    # into Y_bar^{t+1}, which the observer records for the LIE parties.
    with torch.no_grad():
        probs_stack = torch.stack([predict_probs(m, X_pub) for m in models], dim=0)
    # Alternative AGR:
    # Y_bar = cronus_aggregate_probs(probs_stack, eps=eps_adv, lambda_thresh=lambda_thresh, iters=agg_iters)
    Y_bar = f_cronus(probs_stack)
    observer.observe(Y_bar)

    # Logging. predict_probs left every model in eval mode, so an LIE party
    # past its activation round would return its attack vector for X_pub
    # instead of predictions on X_test; report those parties as attacking.
    with torch.no_grad():
        for i, party_model in enumerate(models):
            if (
                isinstance(party_model, LIE_Model)
                and not party_model.training
                and party_model.num_rounds > party_model.active_round
            ):
                print(f"Collab epoch {t:02d}, party {i}, attacking (no test error)")
                continue
            preds = party_model(X_test).argmax(dim=1)
            err = (preds != Y_test).float().mean().item()
            print(f"Collab epoch {t:02d}, party {i}, error {err:.4f}")

    # Advance the round counter the LIE parties use to decide when to attack.
    for lie_model in models[:NUM_MALICIOUS]:
        lie_model.num_rounds += 1

print("Finished")

Collab epoch 00, party 0, error 0.1124
Collab epoch 00, party 1, error 0.1210
Collab epoch 00, party 2, error 0.1173
Collab epoch 00, party 3, error 0.1246
Collab epoch 00, party 4, error 0.1220
Collab epoch 00, party 5, error 0.1094
Collab epoch 00, party 6, error 0.1158
Collab epoch 00, party 7, error 0.1265
Collab epoch 00, party 8, error 0.1268
Collab epoch 00, party 9, error 0.1122
Collab epoch 00, party 10, error 0.1114
Collab epoch 00, party 11, error 0.1318
Collab epoch 00, party 12, error 0.1190
Collab epoch 00, party 13, error 0.1191
Collab epoch 00, party 14, error 0.1146
Collab epoch 00, party 15, error 0.1077
Collab epoch 00, party 16, error 0.1179
Collab epoch 00, party 17, error 0.1141
Collab epoch 00, party 18, error 0.1209
Collab epoch 00, party 19, error 0.1144
Collab epoch 00, party 20, error 0.1223
Collab epoch 00, party 21, error 0.1096
Collab epoch 00, party 22, error 0.1084
Collab epoch 00, party 23, error 0.1078
Collab epoch 00, party 24, error 0.1076
Collab epo