# Server.py 
## Federated Learning Server
Accepts connection from a set number of clients. Has access to a repository of unlabeled public data. Once all clients send predictions on public data, server aggregates results. Once resulrs are compiled the server sends out the public dataset to supliment client training with the newly assigned labels.

In [16]:
from send_receive import *
import socket
import threading
import torchvision.datasets as datasets
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [52]:
def load_features():
    mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
    X_train = mnist_trainset.data[50000:,:] #load the last 10,000 images leaving the rest for the clients private data
    X_train = X_train.float().flatten(start_dim=1, end_dim=2) # Flatten training images
    return X_train

def load_labels():
    mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=None)
    Y_train = mnist_trainset.targets[50000:] #load the last 10,000 images leaving the rest for the clients private data
    return Y_train

NUM_CLIENTS = 10
NUM_ROUNDS = 50
NUM_CLASSES = 10

logits_dict = {}
num_responses = 0
agreggation_done = 0

X_pub = load_features()
Y_pub = load_labels()

In [44]:
def f_median(logits):

    mu = logits.median(dim=0).values

    return mu

def f_cronus(
    logits,
    eps=1e-3,
    lambda_thresh=9.0,
    max_iters=5
):
    """
    Robust Cronus aggregation.

    logits: Tensor [K, N, C]  (models × samples × classes)
    returns: Tensor [N, C]
    """

    K, N, C = logits.shape
    device = logits.device

    agg = torch.zeros(N, C, device=device)

    for n in range(N):

        # Y: [K, C] logits for sample n
        Y = logits[:, n, :]

        # Initial mean
        mu = Y.mean(dim=0)

        for _ in range(max_iters):

            # Centered data
            X = Y - mu

            # If no disagreement, stop
            if X.norm() < 1e-6:
                break

            # Empirical covariance (rank-deficient-safe)
            Sigma = (X.T @ X) / max(len(Y) - 1, 1)

            # Diagonal regularization
            Sigma = Sigma + eps * torch.eye(C, device=device)

            # Eigendecomposition with safety
            try:
                eigvals, eigvecs = torch.linalg.eigh(Sigma)
            except RuntimeError:
                # Covariance too ill-conditioned → skip trimming
                break

            lambda_star = eigvals[-1]

            # If largest eigenvalue small enough, stop trimming
            if lambda_star <= lambda_thresh:
                break

            # Principal direction
            v_star = eigvecs[:, -1]

            # Project samples onto principal direction
            projections = torch.abs((Y - mu) @ v_star)

            max_proj = projections.max()
            if max_proj < 1e-6:
                break

            # Randomized trimming threshold (Cronus)
            T = torch.sqrt(torch.rand(1, device=device)) * max_proj

            mask = projections < T

            # If too few samples left, stop
            if mask.sum() <= 1:
                break

            # Trim and recompute mean
            Y = Y[mask]
            mu = Y.mean(dim=0)

        agg[n] = mu

    return agg

def f_cronus_bila(
    logits,
    eps=1e-3,
    lambda_thresh=9.0,
    max_iters=5,
    beta=5.0
):
    """
    Cronus + BiLA hybrid aggregation.

    logits: Tensor [K, N, C]  (models × samples × classes)
    returns: Tensor [N, C]    (aggregated logits)

    beta: sharpness of reliability weighting
    """

    K, N, C = logits.shape
    device = logits.device

    agg = torch.zeros(N, C, device=device)

    for n in range(N):

        # ---- Step 1: Cronus trimming (logit space) ----
        Y = logits[:, n, :]          # [K, C]
        mu = Y.mean(dim=0)

        for _ in range(max_iters):

            X = Y - mu
            if X.norm() < 1e-6:
                break

            Sigma = (X.T @ X) / max(len(Y) - 1, 1)
            Sigma = Sigma + eps * torch.eye(C, device=device)

            try:
                eigvals, eigvecs = torch.linalg.eigh(Sigma)
            except RuntimeError:
                break

            if eigvals[-1] <= lambda_thresh:
                break

            v = eigvecs[:, -1]
            proj = torch.abs((Y - mu) @ v)

            T = torch.sqrt(torch.rand(1, device=device)) * proj.max()
            mask = proj < T

            if mask.sum() <= 1:
                break

            Y = Y[mask]
            mu = Y.mean(dim=0)

        # ---- Step 2: BiLA-style reliability weighting ----
        # Reliability = agreement with trimmed mean
        distances = torch.norm(Y - mu, dim=1)          # [K']
        weights = torch.exp(-beta * distances)
        weights = weights / weights.sum()

        # ---- Step 3: weighted aggregation ----
        agg[n] = (weights[:, None] * Y).sum(dim=0)

    return agg


In [67]:
class BiLACMAggregator(nn.Module):
    """
    Stateful BiLA-CM aggregator
    """

    def __init__(self, num_models, num_classes, hidden_dim=64, lr=1e-3, device="cpu"):
        super().__init__()

        self.K = num_models
        self.C = num_classes
        self.device = device

        # α network
        self.W1 = nn.Linear(self.C, hidden_dim)
        self.W2 = nn.Linear(hidden_dim, self.C)

        # β confusion matrices
        self.pi = nn.Parameter(
            torch.ones(self.K, self.C, self.C) / self.C
        )

        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def update(self, logits, epochs=5):
        """
        logits: [K, N, C] (already trimmed!)
        returns: aggregated labels [N, C]
        """

        K, N, C = logits.shape
        assert K == self.K and C == self.C

        logits = logits.to(self.device)
        model_probs = F.softmax(logits, dim=-1)

        self.train()

        for _ in range(epochs):
            self.optimizer.zero_grad()

            # α prior
            mean_probs = model_probs.mean(dim=0)        # [N, C]
            h = torch.tanh(self.W1(mean_probs))
            q_alpha = F.softmax(self.W2(h), dim=-1)

            # β likelihood
            log_likelihood = torch.zeros(N, C, device=self.device)

            for k in range(K):
                pi_k = F.softmax(self.pi[k], dim=-1)
                log_likelihood += torch.log(
                    model_probs[k] @ pi_k.T + 1e-8
                )

            log_post = torch.log(q_alpha + 1e-8) + log_likelihood
            Y_hat = F.softmax(log_post, dim=-1)

            loss = -(Y_hat * log_post).sum(dim=1).mean()
            loss.backward()
            self.optimizer.step()

        return self.infer(logits)

    @torch.no_grad()
    def infer(self, logits):
        """
        Inference only (no learning)
        """
        logits = logits.to(self.device)
        model_probs = F.softmax(logits, dim=-1)

        mean_probs = model_probs.mean(dim=0)
        h = torch.tanh(self.W1(mean_probs))
        q_alpha = F.softmax(self.W2(h), dim=-1)

        log_likelihood = torch.zeros(mean_probs.shape[0], self.C, device=self.device)
        for k in range(self.K):
            pi_k = F.softmax(self.pi[k], dim=-1)
            log_likelihood += torch.log(
                model_probs[k] @ pi_k.T + 1e-8
            )

        log_post = torch.log(q_alpha + 1e-8) + log_likelihood
        return F.softmax(log_post, dim=-1)

class CronusBiLAAggregator:
    def __init__(self, bila):
        self.bila = bila

    def update(self, logits):
        """
        logits: [K, N, C]
        """
        trimmed_logits = f_cronus(logits)      # [N, C]
        trimmed_logits = trimmed_logits.unsqueeze(0).repeat(
            self.bila.K, 1, 1
        )
        return self.bila.update(trimmed_logits)

In [None]:
def handle_client(conn, addr, aggregator):
    X_pub = load_features()
    print(f"[+] Connected: {addr}")

    try:

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        bila = BiLACMAggregator(
            num_models=NUM_CLIENTS,
            num_classes=10,
            lr=1e-3,
            device=device
        ).to(device)
        
        aggregatorr = CronusBiLAAggregator(bila)

        # SEND PUBLIC DATA
        send_tensor(conn, X_pub)
        
        for r in range(NUM_ROUNDS+1):

            print(f"Round {r} Aggregation")

            logits = recv_tensor(conn)

            #aggregate_logits = aggregatorr.update(logits)
            
            aggregate_logits = aggregator(logits)

            print(
                aggregate_logits.min().item(),
                aggregate_logits.max().item(),
                aggregate_logits.std().item()
            )


            send_tensor(conn, aggregate_logits)

            print(aggregate_logits.shape)
            print(aggregate_logits[-1])

    except ConnectionResetError:
        print(f"[-] Connection reset by {addr}")
    finally:
        conn.close()
        print(f"[-] Disconnected: {addr}")

def start_server(HOST, PORT):
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    server.bind((HOST, PORT))
    server.listen()

    print(f"[SERVER] Listening on {HOST}:{PORT}")
    
    try:
        while True:
            conn, addr = server.accept()
            thread = threading.Thread(
                target=handle_client,
                args=(conn, addr, f_cronus),
                daemon=True
            )
            thread.start()

    except KeyboardInterrupt:
        print("\n[SERVER] Shutdown requested (Ctrl+C)")

    finally:
        server.close()
        print("[SERVER] Socket closed")

In [75]:
HOST = "localhost"
PORT = 65435
start_server(HOST, PORT)

[SERVER] Listening on localhost:65435
[+] Connected: ('127.0.0.1', 39318)
Round 0 Aggregation
0.05408193916082382 0.18214373290538788 0.024950189515948296
torch.Size([10000, 10])
tensor([0.0797, 0.1520, 0.0641, 0.1011, 0.0961, 0.0986, 0.0879, 0.1182, 0.1162,
        0.0860], device='cuda:0')
Round 1 Aggregation
0.046316053718328476 0.20929008722305298 0.03069021925330162
torch.Size([10000, 10])
tensor([0.0678, 0.1609, 0.0583, 0.0934, 0.1024, 0.0933, 0.1017, 0.1265, 0.1211,
        0.0746], device='cuda:0')
Round 2 Aggregation
0.03956795483827591 0.23690105974674225 0.03629045933485031
torch.Size([10000, 10])
tensor([0.0571, 0.1677, 0.0556, 0.0823, 0.1142, 0.0833, 0.1203, 0.1353, 0.1198,
        0.0645], device='cuda:0')
Round 3 Aggregation
0.033815596252679825 0.2674647271633148 0.0432104766368866
torch.Size([10000, 10])
tensor([0.0502, 0.1797, 0.0487, 0.0734, 0.1219, 0.0731, 0.1262, 0.1429, 0.1262,
        0.0578], device='cuda:0')
Round 4 Aggregation
0.029620544984936714 0.2951952517

Exception in thread Thread-40 (handle_client):
Traceback (most recent call last):
  File [35m"/usr/lib64/python3.14/threading.py"[0m, line [35m1081[0m, in [35m_bootstrap_inner[0m
    [31mself._context.run[0m[1;31m(self.run)[0m
    [31m~~~~~~~~~~~~~~~~~[0m[1;31m^^^^^^^^^^[0m
  File [35m"/usr/lib64/python3.14/threading.py"[0m, line [35m1023[0m, in [35mrun[0m
    [31mself._target[0m[1;31m(*self._args, **self._kwargs)[0m
    [31m~~~~~~~~~~~~[0m[1;31m^^^^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
  File [35m"/tmp/ipykernel_15226/1749648197.py"[0m, line [35m25[0m, in [35mhandle_client[0m
    logits = recv_tensor(conn)
  File [35m"/home/imadk/Desktop/Cronus_Experiment/send_receive.py"[0m, line [35m42[0m, in [35mrecv_tensor[0m
    header_len_bytes = recv_all(sock, 4)
  File [35m"/home/imadk/Desktop/Cronus_Experiment/send_receive.py"[0m, line [35m33[0m, in [35mrecv_all[0m
    raise ConnectionError("Socket closed while receiving data")
[1;35mConnectionError[0m: 

[-] Disconnected: ('127.0.0.1', 46062)
[+] Connected: ('127.0.0.1', 38210)
Round 0 Aggregation


Exception in thread Thread-41 (handle_client):
Traceback (most recent call last):
  File [35m"/usr/lib64/python3.14/threading.py"[0m, line [35m1081[0m, in [35m_bootstrap_inner[0m
    [31mself._context.run[0m[1;31m(self.run)[0m
    [31m~~~~~~~~~~~~~~~~~[0m[1;31m^^^^^^^^^^[0m
  File [35m"/usr/lib64/python3.14/threading.py"[0m, line [35m1023[0m, in [35mrun[0m
    [31mself._target[0m[1;31m(*self._args, **self._kwargs)[0m
    [31m~~~~~~~~~~~~[0m[1;31m^^^^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
  File [35m"/tmp/ipykernel_15226/1749648197.py"[0m, line [35m25[0m, in [35mhandle_client[0m
    logits = recv_tensor(conn)
  File [35m"/home/imadk/Desktop/Cronus_Experiment/send_receive.py"[0m, line [35m42[0m, in [35mrecv_tensor[0m
    header_len_bytes = recv_all(sock, 4)
  File [35m"/home/imadk/Desktop/Cronus_Experiment/send_receive.py"[0m, line [35m33[0m, in [35mrecv_all[0m
    raise ConnectionError("Socket closed while receiving data")
[1;35mConnectionError[0m: 

0.05843349173665047 0.16657042503356934 0.022614723071455956
torch.Size([10000, 10])
tensor([0.0779, 0.1358, 0.1116, 0.0886, 0.0949, 0.1164, 0.0844, 0.1197, 0.0825,
        0.0883], device='cuda:0')
Round 1 Aggregation
[-] Disconnected: ('127.0.0.1', 38210)
[+] Connected: ('127.0.0.1', 33774)
Round 0 Aggregation


Exception in thread Thread-42 (handle_client):
Traceback (most recent call last):
  File [35m"/usr/lib64/python3.14/threading.py"[0m, line [35m1081[0m, in [35m_bootstrap_inner[0m
    [31mself._context.run[0m[1;31m(self.run)[0m
    [31m~~~~~~~~~~~~~~~~~[0m[1;31m^^^^^^^^^^[0m
  File [35m"/usr/lib64/python3.14/threading.py"[0m, line [35m1023[0m, in [35mrun[0m
    [31mself._target[0m[1;31m(*self._args, **self._kwargs)[0m
    [31m~~~~~~~~~~~~[0m[1;31m^^^^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
  File [35m"/tmp/ipykernel_15226/1749648197.py"[0m, line [35m25[0m, in [35mhandle_client[0m
    logits = recv_tensor(conn)
  File [35m"/home/imadk/Desktop/Cronus_Experiment/send_receive.py"[0m, line [35m42[0m, in [35mrecv_tensor[0m
    header_len_bytes = recv_all(sock, 4)
  File [35m"/home/imadk/Desktop/Cronus_Experiment/send_receive.py"[0m, line [35m33[0m, in [35mrecv_all[0m
    raise ConnectionError("Socket closed while receiving data")
[1;35mConnectionError[0m: 

0.06050444766879082 0.16569891571998596 0.026293687522411346
torch.Size([10000, 10])
tensor([0.1184, 0.0801, 0.0976, 0.0641, 0.1303, 0.1263, 0.0737, 0.0874, 0.1154,
        0.1067], device='cuda:0')
Round 1 Aggregation
[-] Disconnected: ('127.0.0.1', 33774)
[+] Connected: ('127.0.0.1', 51856)
Round 0 Aggregation
0.06070582568645477 0.16278550028800964 0.02185731567442417
torch.Size([10000, 10])
tensor([0.0911, 0.1313, 0.0902, 0.1471, 0.0811, 0.1089, 0.0744, 0.0635, 0.0907,
        0.1216], device='cuda:0')
Round 1 Aggregation
[+] Connected: ('127.0.0.1', 54982)
Round 0 Aggregation
0.060003332793712616 0.15729689598083496 0.021297140046954155
torch.Size([10000, 10])
tensor([0.1264, 0.0883, 0.0960, 0.1379, 0.0650, 0.0667, 0.1301, 0.0853, 0.0819,
        0.1223], device='cuda:0')
Round 1 Aggregation
0.050467077642679214 0.18450391292572021 0.029402820393443108
torch.Size([10000, 10])
tensor([0.1349, 0.0824, 0.0923, 0.1542, 0.0563, 0.0578, 0.1422, 0.0758, 0.0719,
        0.1322], device='c

Exception in thread Thread-44 (handle_client):
Traceback (most recent call last):
  File [35m"/usr/lib64/python3.14/threading.py"[0m, line [35m1081[0m, in [35m_bootstrap_inner[0m
    [31mself._context.run[0m[1;31m(self.run)[0m
    [31m~~~~~~~~~~~~~~~~~[0m[1;31m^^^^^^^^^^[0m
  File [35m"/usr/lib64/python3.14/threading.py"[0m, line [35m1023[0m, in [35mrun[0m
    [31mself._target[0m[1;31m(*self._args, **self._kwargs)[0m
    [31m~~~~~~~~~~~~[0m[1;31m^^^^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
  File [35m"/tmp/ipykernel_15226/1749648197.py"[0m, line [35m38[0m, in [35mhandle_client[0m
    [31msend_tensor[0m[1;31m(conn, aggregate_logits)[0m
    [31m~~~~~~~~~~~[0m[1;31m^^^^^^^^^^^^^^^^^^^^^^^^[0m
  File [35m"/home/imadk/Desktop/Cronus_Experiment/send_receive.py"[0m, line [35m23[0m, in [35msend_tensor[0m
    [31msock.sendall[0m[1;31m(header)[0m
    [31m~~~~~~~~~~~~[0m[1;31m^^^^^^^^[0m
[1;35mBrokenPipeError[0m: [35m[Errno 32] Broken pipe[0m


0.0011796214384958148 0.7888602614402771 0.17711085081100464
[-] Disconnected: ('127.0.0.1', 54982)

[SERVER] Shutdown requested (Ctrl+C)
[SERVER] Socket closed
