# Server.py 
## Federated Learning Server
Accepts connections from a set number of clients. Has access to a repository of unlabeled public data. Once all clients send their predictions on the public data, the server aggregates the results. Once the results are compiled, the server sends out the public dataset with the newly assigned labels to supplement client training.

In [16]:
from send_receive import *
import socket
import threading
import torchvision.datasets as datasets
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [52]:
def load_features():
    """
    Load the public image pool as flat float vectors.

    Reads the MNIST training split from ./data (downloading it if missing),
    keeps the samples from index 50000 onward (the last 10,000 images; the
    earlier ones are left for the clients' private data) and flattens
    dimensions 1-2 of each image into a single axis.

    returns: float tensor with one row per public image
    """
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
    public_images = dataset.data[50000:]
    return public_images.float().flatten(start_dim=1, end_dim=2)

def load_labels():
    """
    Load the ground-truth labels of the public image pool.

    Reads the MNIST training split from ./data (downloading it if missing)
    and returns the targets from index 50000 onward, i.e. the labels of the
    last 10,000 samples; the earlier ones are left for the clients' private data.

    returns: tensor of labels, one per public image
    """
    dataset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
    return dataset.targets[50000:]

# Federation settings.
NUM_CLIENTS = 10   # number of clients taking part in the federation
NUM_ROUNDS = 50    # handle_client performs NUM_ROUNDS + 1 exchanges per connection
NUM_CLASSES = 10   # number of output classes per prediction

# NOTE(review): these three names are not referenced by the code visible in
# this file; presumably bookkeeping for collecting client responses —
# confirm they are still needed.
logits_dict = {}
num_responses = 0
agreggation_done = 0

# Public pool (features and their labels), loaded once when this cell/module runs.
X_pub = load_features()
Y_pub = load_labels()

In [44]:
def f_median(logits):
    """
    Median aggregation.

    Takes the element-wise median over dimension 0 of `logits` and returns
    only the median values (the accompanying indices are discarded).
    """
    medians, _ = logits.median(dim=0)
    return medians

def f_cronus(
    logits,
    eps=1e-3,
    lambda_thresh=9.0,
    max_iters=5
):
    """
    Robust Cronus aggregation.

    For every sample, the per-model logit vectors are averaged after
    iteratively discarding the vectors that project furthest onto the
    direction of largest variance.

    logits: Tensor [K, N, C]  (models × samples × classes)
    eps: value added to the covariance diagonal before the eigendecomposition
    lambda_thresh: trimming stops once the largest covariance eigenvalue
        is <= this value
    max_iters: maximum number of trimming passes per sample
    returns: Tensor [N, C] with the same dtype and device as `logits`

    The trimming threshold is drawn from torch's global RNG, so results are
    reproducible only under a fixed seed.
    """

    _, N, C = logits.shape
    device = logits.device

    # Allocate the result in the caller's dtype so float64/float16 logits are
    # not silently cast to the default float32.
    agg = torch.zeros(N, C, device=device, dtype=logits.dtype)

    # Diagonal regularizer; identical for every sample and iteration.
    ridge = eps * torch.eye(C, device=device)

    for n in range(N):

        # Y: [K, C] logits for sample n
        Y = logits[:, n, :]

        # Initial mean
        mu = Y.mean(dim=0)

        for _ in range(max_iters):

            # Centered data
            X = Y - mu

            # If no disagreement, stop
            if X.norm() < 1e-6:
                break

            # Empirical covariance (rank-deficient-safe) with diagonal regularization
            Sigma = (X.T @ X) / max(len(Y) - 1, 1) + ridge

            # Eigendecomposition with safety
            try:
                eigvals, eigvecs = torch.linalg.eigh(Sigma)
            except RuntimeError:
                # Covariance too ill-conditioned → skip trimming
                break

            # eigh returns eigenvalues in ascending order, so the last is the largest
            lambda_star = eigvals[-1]

            # If largest eigenvalue small enough, stop trimming
            if lambda_star <= lambda_thresh:
                break

            # Principal direction
            v_star = eigvecs[:, -1]

            # Project the centered samples onto the principal direction
            projections = torch.abs(X @ v_star)

            max_proj = projections.max()
            if max_proj < 1e-6:
                break

            # Randomized trimming threshold (Cronus)
            T = torch.sqrt(torch.rand(1, device=device)) * max_proj

            mask = projections < T

            # If too few samples left, stop (keeps the untrimmed set and its mean)
            if mask.sum() <= 1:
                break

            # Trim and recompute mean
            Y = Y[mask]
            mu = Y.mean(dim=0)

        agg[n] = mu

    return agg

def f_cronus_bila(
    logits,
    eps=1e-3,
    lambda_thresh=9.0,
    max_iters=5,
    beta=5.0
):
    """
    Cronus + BiLA hybrid aggregation.

    Per sample: trim outlying logit vectors along the direction of largest
    variance (as in Cronus), then average the survivors with weights that
    decay exponentially with their distance to the trimmed mean.

    logits: Tensor [K, N, C]  (models × samples × classes)
    eps: value added to the covariance diagonal before the eigendecomposition
    lambda_thresh: trimming stops once the largest covariance eigenvalue
        is <= this value
    max_iters: maximum number of trimming passes per sample
    beta: sharpness of reliability weighting
    returns: Tensor [N, C]    (aggregated logits), same dtype and device
        as `logits`

    The trimming threshold is drawn from torch's global RNG, so results are
    reproducible only under a fixed seed.
    """

    _, N, C = logits.shape
    device = logits.device

    # Allocate the result in the caller's dtype so float64/float16 logits are
    # not silently cast to the default float32.
    agg = torch.zeros(N, C, device=device, dtype=logits.dtype)

    # Diagonal regularizer; identical for every sample and iteration.
    ridge = eps * torch.eye(C, device=device)

    for n in range(N):

        # ---- Step 1: Cronus trimming (logit space) ----
        Y = logits[:, n, :]          # [K, C]
        mu = Y.mean(dim=0)

        for _ in range(max_iters):

            X = Y - mu
            # No disagreement left between the remaining vectors
            if X.norm() < 1e-6:
                break

            Sigma = (X.T @ X) / max(len(Y) - 1, 1) + ridge

            try:
                eigvals, eigvecs = torch.linalg.eigh(Sigma)
            except RuntimeError:
                # Eigendecomposition failed → keep the current set untrimmed
                break

            # eigh returns eigenvalues in ascending order, so the last is the largest
            if eigvals[-1] <= lambda_thresh:
                break

            v = eigvecs[:, -1]
            proj = torch.abs(X @ v)

            # Randomized trimming threshold
            T = torch.sqrt(torch.rand(1, device=device)) * proj.max()
            mask = proj < T

            # Too few vectors would survive → keep the current set
            if mask.sum() <= 1:
                break

            Y = Y[mask]
            mu = Y.mean(dim=0)

        # ---- Step 2: BiLA-style reliability weighting ----
        # Reliability = agreement with trimmed mean
        distances = torch.norm(Y - mu, dim=1)          # [K']
        # softmax(-beta * d) equals exp(-beta * d) / sum(exp(-beta * d)) but
        # is evaluated stably; the explicit form underflows to 0/0 = NaN when
        # every remaining vector is far from the mean.
        weights = torch.softmax(-beta * distances, dim=0)

        # ---- Step 3: weighted aggregation ----
        agg[n] = (weights[:, None] * Y).sum(dim=0)

    return agg


In [67]:
class BiLACMAggregator(nn.Module):
    """
    Stateful BiLA-CM aggregator

    Holds a small prior network (W1, W2), one learnable C×C matrix per model
    (pi) and an Adam optimizer over all of them. `update` takes a few
    gradient steps on the given batch and returns the resulting class
    posterior; `infer` computes the posterior without learning.
    """

    def __init__(self, num_models, num_classes, hidden_dim=64, lr=1e-3, device="cpu"):
        """
        num_models: K, number of models whose logits are aggregated
        num_classes: C, number of classes
        hidden_dim: width of the hidden layer of the prior network
        lr: Adam learning rate
        device: device the parameters live on and inputs are moved to
        """
        super().__init__()

        self.K = num_models
        self.C = num_classes
        self.device = device

        # α network
        self.W1 = nn.Linear(self.C, hidden_dim)
        self.W2 = nn.Linear(hidden_dim, self.C)

        # β confusion matrices
        self.pi = nn.Parameter(
            torch.ones(self.K, self.C, self.C) / self.C
        )

        # Place the parameters on the target device before the optimizer is
        # built (the order recommended by the torch.optim documentation); a
        # later `.to(device)` by the caller is then a no-op.
        self.to(device)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def _log_posterior(self, model_probs):
        """
        Unnormalized log-posterior over classes.

        model_probs: [K, N, C] per-model class probabilities
        returns: [N, C]
        """
        # α prior, computed from the model-averaged probabilities
        mean_probs = model_probs.mean(dim=0)        # [N, C]
        h = torch.tanh(self.W1(mean_probs))
        q_alpha = F.softmax(self.W2(h), dim=-1)

        # β likelihood, accumulated over models
        log_likelihood = torch.zeros(mean_probs.shape[0], self.C, device=self.device)
        for k in range(self.K):
            # Row-normalize the k-th matrix so each row is a distribution
            pi_k = F.softmax(self.pi[k], dim=-1)
            log_likelihood += torch.log(
                model_probs[k] @ pi_k.T + 1e-8
            )

        # 1e-8 guards both logarithms against log(0)
        return torch.log(q_alpha + 1e-8) + log_likelihood

    def update(self, logits, epochs=5):
        """
        logits: [K, N, C] (already trimmed!)
        epochs: number of optimizer steps taken on this batch
        returns: aggregated labels [N, C]
        raises: ValueError if K or C differ from the configured sizes
        """

        K, N, C = logits.shape
        # Explicit check instead of `assert`, which is stripped under `python -O`
        if K != self.K or C != self.C:
            raise ValueError(
                f"expected logits of shape [{self.K}, N, {self.C}], "
                f"got {tuple(logits.shape)}"
            )

        logits = logits.to(self.device)
        model_probs = F.softmax(logits, dim=-1)

        self.train()

        for _ in range(epochs):
            self.optimizer.zero_grad()

            log_post = self._log_posterior(model_probs)
            Y_hat = F.softmax(log_post, dim=-1)

            # Negative expectation of the unnormalized log-posterior under
            # its own softmax; Y_hat is not detached, so gradients flow
            # through both factors.
            loss = -(Y_hat * log_post).sum(dim=1).mean()
            loss.backward()
            self.optimizer.step()

        return self.infer(logits)

    @torch.no_grad()
    def infer(self, logits):
        """
        Inference only (no learning)

        logits: [K, N, C]
        returns: class posterior [N, C], each row summing to 1
        """
        logits = logits.to(self.device)
        model_probs = F.softmax(logits, dim=-1)
        return F.softmax(self._log_posterior(model_probs), dim=-1)

class CronusBiLAAggregator:
    """
    Chains Cronus trimming with a stateful BiLA aggregator.

    bila: aggregator object exposing a model count `K` and an
        `update(logits)` method
    """

    def __init__(self, bila):
        self.bila = bila

    def update(self, logits):
        """
        logits: [K, N, C]

        Collapses the K models into one robust [N, C] estimate with
        f_cronus, replicates that estimate K times along a new leading axis
        so the shape matches what the wrapped aggregator expects, and
        returns the result of its `update`.
        """
        robust_mean = f_cronus(logits)      # [N, C]
        num_models = self.bila.K
        replicated = robust_mean.unsqueeze(0).repeat(num_models, 1, 1)
        return self.bila.update(replicated)

In [76]:
def handle_client(conn, addr, aggregator):
    """
    Serve one connected client until the rounds are finished or the
    connection is reset.

    Sends the public features once, then for each round receives a tensor
    from the client, passes it through `aggregator`, prints summary
    statistics and sends the aggregate back.

    conn: connected client socket; always closed before returning
    addr: peer address, used only in log output
    aggregator: callable mapping the received tensor to the tensor sent back
        (for the stateful Cronus+BiLA path, pass the bound `update` method of
        a CronusBiLAAggregator instance)
    """
    print(f"[+] Connected: {addr}")

    try:

        # SEND PUBLIC DATA
        # Uses the module-level X_pub instead of reloading the dataset for
        # every connection.
        send_tensor(conn, X_pub)

        # NOTE(review): NUM_ROUNDS + 1 exchanges — presumably mirrors the
        # client-side loop; confirm against the client.
        for r in range(NUM_ROUNDS+1):

            print(f"Round {r} Aggregation")

            logits = recv_tensor(conn)

            aggregate_logits = aggregator(logits)

            # Quick sanity statistics of the aggregate
            print(
                aggregate_logits.min().item(),
                aggregate_logits.max().item(),
                aggregate_logits.std().item()
            )

            send_tensor(conn, aggregate_logits)

            print(aggregate_logits.shape)
            print(aggregate_logits[-1])

    except ConnectionResetError:
        print(f"[-] Connection reset by {addr}")
    finally:
        conn.close()
        print(f"[-] Disconnected: {addr}")

def start_server(HOST, PORT, aggregator=None):
    """
    Listen on (HOST, PORT) and serve every accepted client on its own
    daemon thread until interrupted with Ctrl+C.

    HOST: interface/hostname to bind to
    PORT: TCP port to bind to
    aggregator: callable handed to handle_client for aggregating the
        received logits; defaults to f_cronus
    """
    if aggregator is None:
        aggregator = f_cronus

    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    # Setup is inside the try block so the socket is closed even when
    # bind/listen fails (e.g. the port is already in use).
    try:
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((HOST, PORT))
        server.listen()

        print(f"[SERVER] Listening on {HOST}:{PORT}")

        while True:
            conn, addr = server.accept()
            # Daemon threads do not keep the process alive after shutdown
            thread = threading.Thread(
                target=handle_client,
                args=(conn, addr, aggregator),
                daemon=True
            )
            thread.start()

    except KeyboardInterrupt:
        print("\n[SERVER] Shutdown requested (Ctrl+C)")

    finally:
        server.close()
        print("[SERVER] Socket closed")

In [None]:
# Address the server socket binds to.
HOST = "localhost"
PORT = 65435
# Blocks in the accept loop until interrupted (Ctrl+C).
start_server(HOST, PORT)

[SERVER] Listening on localhost:65435
[+] Connected: ('127.0.0.1', 53054)
Round 0 Aggregation
-18.526174545288086 29.712066650390625 5.030490398406982
torch.Size([10000, 10])
tensor([ 1.5993, -2.6973,  0.2837, -1.9181, -0.4026,  1.7058, -1.8876,  0.4350,
         7.2955,  2.8986])
Round 1 Aggregation
-19.46892547607422 33.65013885498047 5.060451984405518
torch.Size([10000, 10])
tensor([ 0.3316, -3.7999, -0.9993, -0.6673, -1.1489,  2.3024, -1.5566,  0.2128,
         6.5410,  2.6132])
Round 2 Aggregation
-18.50763511657715 32.01768493652344 5.062962532043457
torch.Size([10000, 10])
tensor([ 1.8491, -1.8711,  1.5974, -2.1387, -1.4464,  1.9788, -1.7829,  1.3851,
         7.4017,  3.4867])
Round 3 Aggregation
-22.22128677368164 30.03365707397461 5.071204662322998
torch.Size([10000, 10])
tensor([-1.5020, -4.2646, -1.5978, -0.1891, -2.4775,  2.5669, -1.6420,  0.4172,
         6.6001,  2.5744])
Round 4 Aggregation
-18.984630584716797 31.97455596923828 5.079565048217773
torch.Size([10000, 10])
