# Server.py 
## Federated Learning Server
Accepts connections from a set number of clients. Has access to a repository of unlabeled public data. Once all clients send predictions on the public data, the server aggregates the results. Once the results are compiled, the server sends out the public dataset, with the newly assigned labels, to supplement client training.

In [1]:
# Star import first so the explicit imports below take precedence over any
# same-named exports of send_receive (preserves the original shadowing order).
from send_receive import *

import functools
import socket
import threading

import numpy as np
import torch
import torchvision.datasets as datasets

In [2]:
@functools.lru_cache(maxsize=1)
def _mnist_trainset():
    """Return the MNIST training split, constructed (and downloaded if needed) once per process."""
    return datasets.MNIST(root='./data', train=True, download=True, transform=None)


def load_features():
    """
    Return the server's public feature matrix.

    Uses the training images from index 50000 onward (the last 10,000),
    leaving the first 50,000 for the clients' private data. Images are
    converted to float and each image's two spatial dimensions are flattened
    into a single row.

    Returns:
        Float tensor with one flattened image per row.
    """
    images = _mnist_trainset().data[50000:, :]
    # .float() allocates a new tensor, so the cached dataset is never mutated.
    return images.float().flatten(start_dim=1, end_dim=2)


def load_labels():
    """
    Return the ground-truth labels matching ``load_features()``.

    Uses the training targets from index 50000 onward (the last 10,000),
    i.e. the same rows as the public features.

    Returns:
        A copy of the target tensor slice; cloning keeps callers from
        mutating the cached dataset through a view.
    """
    return _mnist_trainset().targets[50000:].clone()

# --- Experiment configuration -------------------------------------------
NUM_CLIENTS = 5    # number of clients the federation is configured for
NUM_ROUNDS = 50    # round count; handle_client iterates range(NUM_ROUNDS + 1)
NUM_CLASSES = 10   # MNIST digit classes

# --- Shared aggregation bookkeeping -------------------------------------
# NOTE(review): none of these three are read by the code visible in this
# notebook view; confirm whether other cells still use them.
logits_dict = dict()
num_responses, agreggation_done = 0, 0

# --- Public dataset held by the server ----------------------------------
X_pub, Y_pub = load_features(), load_labels()

In [None]:
def f_cronus(logits, eps=1e-6, lambda_thresh=9.0, max_iters=20):
    """
    Robust soft-label aggregation (the f_cronus aggregator from the paper).

    For every sample, the K per-model logit vectors are filtered iteratively.
    While the largest eigenvalue of their (ridge-regularised) sample
    covariance exceeds ``lambda_thresh``, every vector whose absolute
    projection onto the top eigenvector is not below a random threshold is
    discarded. The mean of the surviving vectors is the aggregate for that
    sample.

    Args:
        logits: tensor of shape (K, N, C) -- models x samples x classes.
        eps: ridge added to the covariance diagonal for numerical stability.
        lambda_thresh: stop filtering once the top eigenvalue is <= this.
        max_iters: maximum number of filtering rounds per sample.

    Returns:
        Tensor of shape (N, C) with the same dtype and device as ``logits``.

    Note:
        The threshold is drawn with ``torch.rand``, so results depend on the
        global torch RNG state.
    """
    K, N, C = logits.shape
    device = logits.device
    dtype = logits.dtype

    # Preserve the caller's dtype instead of always producing float32.
    agg = torch.zeros(N, C, device=device, dtype=dtype)
    # Loop-invariant ridge term; hoisted out of both loops.
    ridge = eps * torch.eye(C, device=device, dtype=dtype)

    for n in range(N):
        Y = logits[:, n, :]   # (K, C): every model's logits for sample n
        mu = Y.mean(dim=0)    # (C,): mean over models

        for _ in range(max_iters):
            X = Y - mu  # centred points, (remaining models, C)
            # Unbiased covariance; the max() guard avoids dividing by zero
            # when a single point remains.
            Sigma = (X.T @ X) / max(Y.shape[0] - 1, 1) + ridge

            # eigh returns eigenvalues in ascending order -> last is largest.
            eigvals, eigvecs = torch.linalg.eigh(Sigma)
            lambda_star = eigvals[-1]
            v_star = eigvecs[:, -1]

            # Spread along the dominant direction is small enough: done.
            if lambda_star <= lambda_thresh:
                break

            projections = torch.abs(X @ v_star)
            # Random threshold in [0, max projection]; points at or beyond
            # it are treated as outliers and removed.
            T = torch.rand(1, device=device).sqrt() * projections.max()

            mask = projections < T
            # Keep at least two points so the covariance stays meaningful.
            if mask.sum() <= 1:
                break

            Y = Y[mask]
            mu = Y.mean(dim=0)

        agg[n] = mu

    return agg

def f_bila(logits):
    """
    Placeholder for the f_bila logit aggregator.

    Not implemented yet. Raising explicitly keeps this cell syntactically
    valid, so the other definitions in it can still be used.

    Args:
        logits: intended to have the same (K, N, C) layout as ``f_cronus``
            takes -- TODO confirm once implemented.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("f_bila aggregation is not implemented yet")

In [4]:
def handle_client(conn, addr):
    """
    Serve one federated-learning client over an accepted socket.

    Sends the public feature matrix once. Then, for NUM_ROUNDS + 1 rounds,
    receives a logits tensor from the client, aggregates it with
    ``f_cronus`` and sends the aggregate back. The connection is always
    closed on exit.

    Args:
        conn: connected socket returned by ``socket.accept()``.
        addr: peer address; used only for log messages.
    """
    print(f"[+] Connected: {addr}")

    try:
        # Reuse the module-level public features rather than re-loading
        # MNIST for every connection.
        send_tensor(conn, X_pub)

        # NOTE(review): iterates rounds 0..NUM_ROUNDS inclusive; the client
        # presumably expects the same count -- confirm before changing.
        for r in range(NUM_ROUNDS + 1):

            print(f"Round {r} Aggregation")

            logits = recv_tensor(conn)

            aggregate_logits = f_cronus(logits)

            print(
                aggregate_logits.min().item(),
                aggregate_logits.max().item(),
                aggregate_logits.std().item()
            )

            send_tensor(conn, aggregate_logits)

            print(aggregate_logits.shape)
            print(aggregate_logits[0])

    except ConnectionResetError:
        print(f"[-] Connection reset by {addr}")
    except ConnectionError as exc:
        # Broken pipe / aborted connection: log it instead of letting the
        # worker thread die with a traceback.
        print(f"[-] Connection error with {addr}: {exc!r}")
    finally:
        conn.close()
        print(f"[-] Disconnected: {addr}")

def start_server(HOST, PORT):
    """
    Listen on (HOST, PORT) and serve each accepted connection in its own thread.

    Every accepted connection is handed to ``handle_client`` in a daemon
    thread. Runs until interrupted with Ctrl+C. The listening socket is
    closed on every exit path, including a failed ``bind``/``listen``.

    Args:
        HOST: interface/hostname to bind to.
        PORT: TCP port to listen on.
    """
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    # bind/listen sit inside the try so the socket is released even when
    # they fail (e.g. address already in use).
    try:
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((HOST, PORT))
        server.listen()

        print(f"[SERVER] Listening on {HOST}:{PORT}")

        while True:
            conn, addr = server.accept()
            # Daemon threads do not keep the process alive after shutdown.
            thread = threading.Thread(
                target=handle_client,
                args=(conn, addr),
                daemon=True
            )
            thread.start()

    except KeyboardInterrupt:
        print("\n[SERVER] Shutdown requested (Ctrl+C)")

    finally:
        server.close()
        print("[SERVER] Socket closed")


HOST = "localhost"
PORT = 65435

# start_server blocks forever, so only run it when executed directly
# (script or notebook kernel, where __name__ == "__main__"), not on import.
if __name__ == "__main__":
    start_server(HOST, PORT)

[SERVER] Listening on localhost:65435
[+] Connected: ('127.0.0.1', 46138)
Round 0 Aggregation


  tensor = torch.from_numpy(array).to(device)


-20.667842864990234 30.283143997192383 5.040340900421143
torch.Size([10000, 10])
tensor([-1.5914,  0.1358,  2.8258,  7.4749, -3.4463,  0.9359, -2.5753, -1.3921,
         1.9621, -1.5695])
Round 1 Aggregation
-19.27682876586914 31.80224609375 5.325206279754639
torch.Size([10000, 10])
tensor([-2.3871, -0.4122,  0.9024,  6.6995, -1.8927, -0.0992, -3.8802, -1.1381,
         2.7006, -2.0063])
Round 2 Aggregation
-19.576099395751953 32.090240478515625 5.372494220733643
torch.Size([10000, 10])
tensor([-2.4395, -0.4692,  0.8933,  6.7248, -1.8513, -0.1144, -3.9192, -1.1104,
         2.7310, -2.0114])
Round 3 Aggregation
-19.817298889160156 32.3421630859375 5.417448997497559
torch.Size([10000, 10])
tensor([-2.4757, -0.5069,  0.9309,  6.7641, -1.8484, -0.1188, -3.9752, -1.0873,
         2.7442, -2.0296])
Round 4 Aggregation
-20.019306182861328 32.554203033447266 5.458367824554443
torch.Size([10000, 10])
tensor([-2.5091, -0.5328,  0.9769,  6.8009, -1.8640, -0.1200, -4.0409, -1.0738,
         2.746