# Server.py 
## Federated Learning Server
Accepts connections from a fixed number of clients. Has access to a repository of unlabeled public data. Once all clients send their predictions on the public data, the server aggregates them. Once the results are compiled, the server sends the public dataset, with the newly assigned labels, back to the clients to supplement their training.

In [2]:
from send_receive import *
import socket
import threading
import torchvision.datasets as datasets
import numpy as np

In [None]:
def load_features(root='./data', start=50000):
    """Load the server's public MNIST images as a flat float tensor.

    The MNIST training split is downloaded into ``root`` if it is not already
    there. Every image from index ``start`` onward is kept. With the default
    ``start`` this is the last 10,000 images, leaving the rest for the
    clients' private data.

    Args:
        root: Directory passed to ``torchvision.datasets.MNIST``.
        start: Index of the first training image to include in the public set.

    Returns:
        A float tensor with one flattened image per row, i.e. shape
        ``(num_images, height * width)``.
    """
    mnist_trainset = datasets.MNIST(root=root, train=True, download=True, transform=None)
    X_train = mnist_trainset.data[start:, :]
    # Collapse the two pixel dimensions (dims 1 and 2) into one feature vector per image.
    X_train = X_train.float().flatten(start_dim=1, end_dim=2)
    return X_train

def load_labels(root='./data', start=50000):
    """Load the ground-truth labels of the server's public MNIST images.

    Returns the MNIST training-set targets from index ``start`` onward. With
    the default ``start`` these are the labels of the last 10,000 images,
    leaving the rest for the clients' private data. Keep ``start`` equal to
    the slice used for the public images so labels stay aligned with them.

    Args:
        root: Directory passed to ``torchvision.datasets.MNIST``.
        start: Index of the first training label to include.

    Returns:
        The selected slice of ``MNIST.targets``.
    """
    mnist_trainset = datasets.MNIST(root=root, train=True, download=True, transform=None)
    Y_train = mnist_trainset.targets[start:]
    return Y_train

# Experiment configuration.
NUM_CLIENTS, NUM_ROUNDS, NUM_CLASSES = 10, 20, 10

# Shared server state.
# NOTE(review): nothing visible in this file reads or writes these three
# names; confirm whether they are still needed.
logits_dict = {}
num_responses, agreggation_done = 0, 0

# Server-held public split of MNIST: flattened images and their true labels.
# NOTE(review): Y_pub is not used by the visible code.
X_pub, Y_pub = load_features(), load_labels()

In [5]:
def aggregation_rule(logits):
    """Convert model logits into hard labels: average over models, then argmax.

    Args:
        logits: Tensor of shape ``(num_models, samples, classes)``. A 2-D
            ``(samples, classes)`` tensor is also accepted and treated as the
            output of a single model.

    Returns:
        A ``long`` tensor of shape ``(samples,)`` holding the predicted class
        index for each sample.
    """
    if logits.dim() == 2:
        # A single model's logits lack the leading model axis. Add it so that
        # mean(dim=0) averages over models rather than collapsing the samples.
        logits = logits.unsqueeze(0)
    mean_logits = logits.mean(dim=0)  # average across models -> (samples, classes)
    labels = mean_logits.argmax(dim=1)  # most likely class per sample
    return labels.long()

def handle_client(conn, addr, public_data=None):
    """Serve one client connection for ``NUM_ROUNDS`` rounds of label exchange.

    The public feature tensor is sent once. Then, in each round, the client's
    logits are received, reduced to hard labels with ``aggregation_rule`` and
    sent back on the same connection. The socket is always closed on exit.

    Args:
        conn: Connected socket for this client.
        addr: Client address; used only in log messages.
        public_data: Public feature tensor to send. If ``None`` it is loaded
            with ``load_features()`` (the original per-client behavior).
    """
    print(f"[+] Connected: {addr}")

    try:
        if public_data is None:
            public_data = load_features()

        # Send the public data once, before the first round.
        send_tensor(conn, public_data)

        for r in range(NUM_ROUNDS):
            print(f"Round {r} Aggregation")

            logits = recv_tensor(conn)

            # NOTE(review): labels are derived from this client's logits only.
            # No cross-client aggregation happens here, although the notebook
            # header describes waiting for all clients. Confirm this is intended.
            labels = aggregation_rule(logits)

            send_tensor(conn, labels)

    except ConnectionResetError:
        print(f"[-] Connection reset by {addr}")
    except ConnectionError as exc:
        # Broken pipe / aborted connection: the client went away mid-exchange.
        print(f"[-] Connection error with {addr}: {exc}")
    finally:
        conn.close()
        print(f"[-] Disconnected: {addr}")

def start_server(HOST, PORT, public_data=None):
    """Listen on ``(HOST, PORT)`` and serve each client on its own thread.

    Loops forever accepting connections until interrupted (for example by
    ``KeyboardInterrupt``). The listening socket is managed by a ``with``
    block, so it is closed even when the loop exits via an exception.

    Args:
        HOST: Interface to bind.
        PORT: TCP port to bind.
        public_data: Public feature tensor handed to every client thread.
            Defaults to the module-level ``X_pub``.
    """
    if public_data is None:
        public_data = X_pub

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        # Allow quick rebinding of the port after a restart.
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind((HOST, PORT))
        server.listen()

        print(f"[SERVER] Listening on {HOST}:{PORT}")

        while True:
            conn, addr = server.accept()
            # Daemon threads do not keep the process alive once the main thread exits.
            thread = threading.Thread(
                target=handle_client, args=(conn, addr, public_data), daemon=True
            )
            thread.start()

# Address the server binds to; start_server blocks, accepting clients until interrupted.
HOST, PORT = "localhost", 65437
start_server(HOST, PORT)

[SERVER] Listening on localhost:65437
[+] Connected: ('127.0.0.1', 38638)
Round 0 Aggregation
Round 1 Aggregation
Round 2 Aggregation
Round 3 Aggregation
Round 4 Aggregation
Round 5 Aggregation
[+] Connected: ('127.0.0.1', 52252)
Round 0 Aggregation
Round 1 Aggregation
Round 2 Aggregation
Round 3 Aggregation
Round 4 Aggregation
Round 5 Aggregation
Round 6 Aggregation
Round 7 Aggregation
Round 8 Aggregation
Round 9 Aggregation
[-] Disconnected: ('127.0.0.1', 52252)
[+] Connected: ('127.0.0.1', 55018)
Round 0 Aggregation
Round 1 Aggregation
Round 2 Aggregation
Round 3 Aggregation
Round 4 Aggregation
Round 5 Aggregation
Round 6 Aggregation
Round 7 Aggregation
[+] Connected: ('127.0.0.1', 40952)
Round 0 Aggregation
Round 1 Aggregation
Round 2 Aggregation
Round 3 Aggregation
Round 4 Aggregation
Round 5 Aggregation
Round 6 Aggregation
Round 7 Aggregation
Round 8 Aggregation
Round 9 Aggregation
[-] Disconnected: ('127.0.0.1', 40952)


KeyboardInterrupt: 