# Federated Learning on a Single Machine using Sockets

This notebook demonstrates a simple horizontal federated learning system using TCP socket communication, running entirely on a single machine.

**Structure**:
- `model.py`: Defines the shared neural network model.
- `federated_server.py`: Coordinates model aggregation.
- `federated_client.py`: Simulates each client performing local training.

## model.py

In [1]:
# model.py
import torch.nn as nn

class SimpleNN(nn.Module):
    """Small fully connected classifier shared by the server and all clients.

    Architecture: Linear(2 -> 10), ReLU, Linear(10 -> 2). The layers live in
    a single ``nn.Sequential`` stored as ``self.fc``.
    """

    def __init__(self):
        super().__init__()
        # Layers are built in forward order, then wrapped into one Sequential.
        hidden = nn.Linear(2, 10)
        activation = nn.ReLU()
        head = nn.Linear(10, 2)
        self.fc = nn.Sequential(hidden, activation, head)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the result."""
        logits = self.fc(x)
        return logits

## federated_client.py

In [None]:
# federated_client.py
import copy
import pickle
import socket

import torch
import torch.nn.functional as F
import torch.optim as optim

from FL.model import SimpleNN

# Address of the federation server; 127.0.0.1 is the loopback interface, so
# all traffic stays on this machine.
HOST = '127.0.0.1'
# TCP port the client connects to.
PORT = 8000

def local_train(model, X, y, epochs=5, lr=0.01):
    """Train a private copy of ``model`` on ``(X, y)`` and return its weights.

    The model passed in is never modified: a deep copy is trained instead.
    Each of the ``epochs`` iterations is one full-batch SGD step (learning
    rate ``lr``) on the cross-entropy loss between ``model(X)`` and ``y``.

    Returns:
        The ``state_dict()`` of the trained copy.
    """
    trainee = copy.deepcopy(model)
    sgd = optim.SGD(trainee.parameters(), lr=lr)

    for _step in range(epochs):
        sgd.zero_grad()
        logits = trainee(X)
        F.cross_entropy(logits, y).backward()
        sgd.step()

    return trainee.state_dict()

def run_client(num_rounds=10):
    """Connect to the server and take part in ``num_rounds`` training rounds.

    Each round the client receives the pickled global ``state_dict`` from the
    server, loads it into its local model, trains on its private synthetic
    data via ``local_train`` and sends the pickled updated weights back.

    Args:
        num_rounds: Number of receive/train/send rounds to perform. The
            default of 10 matches the number of rounds the server runs.

    Raises:
        EOFError: If the server closes the connection before a complete set
            of weights has been received.
    """
    model = SimpleNN()
    # Synthetic local dataset: 100 random 2-D points, labelled 1 when the two
    # coordinates sum to a positive value and 0 otherwise.
    X = torch.randn(100, 2)
    y = (X[:, 0] + X[:, 1] > 0).long()

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.connect((HOST, PORT))
        print("Connected to server.")

        # TCP is a byte stream: one recv() may return only part of a pickled
        # message (or more than one). Reading through a buffered file object
        # lets pickle.load() consume exactly one complete pickle per call and
        # keeps any surplus bytes buffered for the next round.
        with s.makefile('rb') as incoming:
            for round_idx in range(num_rounds):
                # NOTE(security): unpickling executes arbitrary code chosen by
                # the sender. Acceptable only because the peer is a trusted
                # process on localhost; do not expose this to untrusted hosts.
                global_weights = pickle.load(incoming)
                model.load_state_dict(global_weights)

                updated_weights = local_train(model, X, y)
                s.sendall(pickle.dumps(updated_weights))
                print(f"Round {round_idx+1} completed.")


if __name__ == '__main__':
    run_client()

## federated_server.py

In [None]:
# federated_server.py
import socket, pickle, copy
import threading
import torch
from FL.model import SimpleNN

# Loopback address and TCP port the server binds to and listens on.
HOST = '127.0.0.1'
PORT = 8000
# Number of client connections accepted before training starts; also the
# number of client updates that must arrive before a round is aggregated.
NUM_CLIENTS = 2

# Handler threads, one per accepted client connection (appended by run_server).
clients = []

def average_weights(weight_list):
    """Return the element-wise mean of several weight dictionaries.

    Args:
        weight_list: Non-empty sequence of dicts that share the same keys
            (e.g. model ``state_dict``s).

    Returns:
        A new dict with the same keys, where each value is the sum of that
        entry over all dicts divided by the number of dicts. The inputs are
        not modified, because accumulation happens on a deep copy of the
        first dict.
    """
    count = len(weight_list)
    averaged = copy.deepcopy(weight_list[0])
    remaining = weight_list[1:]

    for name in averaged:
        # Accumulate the other contributions onto the copied first entry.
        for state in remaining:
            averaged[name] += state[name]
        averaged[name] = averaged[name] / count

    return averaged

def client_thread(conn, addr, client_id, shared_weights, lock, num_rounds=10):
    """Serve one connected client for ``num_rounds`` federated rounds.

    Per round: send the current global weights, receive the client's updated
    weights, add them to the pending-updates list, then wait until that list
    has been replaced (which signals that the round was aggregated) before
    starting the next round. The connection is closed when the function
    returns or raises.

    Args:
        conn: Connected socket for this client.
        addr: Peer address (unused here).
        client_id: Index of this client (unused here).
        shared_weights: Two-element list: ``[0]`` holds the current global
            weights, ``[1]`` the list of updates collected for the round in
            progress. The aggregator is expected to rebind ``[1]`` to a new
            list after each aggregation.
        lock: Lock guarding every access to ``shared_weights``.
        num_rounds: Number of rounds to serve.

    Raises:
        EOFError: If the client disconnects before sending a complete update.
    """
    import time  # local import: only needed for the short polling sleep below

    # A buffered reader lets pickle.load() consume exactly one message even
    # when TCP delivers it in several pieces or merges it with the next one.
    with conn, conn.makefile('rb') as incoming:
        for rnd in range(num_rounds):
            # Serialize under the lock, but do the blocking network send
            # outside it so one slow client cannot stall the other threads.
            with lock:
                payload = pickle.dumps(shared_weights[0])
            conn.sendall(payload)

            # NOTE(security): unpickling runs arbitrary code chosen by the
            # sender; only acceptable for trusted local clients.
            client_weights = pickle.load(incoming)

            with lock:
                pending = shared_weights[1]
                pending.append(client_weights)

            if rnd == num_rounds - 1:
                break  # nothing further to send, so no need to wait

            # Do not start the next round until this round's updates were
            # aggregated. Otherwise the stale global weights would be sent
            # again and a second update from this client could be counted.
            while True:
                with lock:
                    if shared_weights[1] is not pending:
                        break
                time.sleep(0.01)

def run_server():
    """Accept ``NUM_CLIENTS`` clients and run 10 rounds of weight averaging.

    One handler thread per client exchanges weights with that client. This
    function waits until ``NUM_CLIENTS`` updates are pending, replaces the
    global weights with their average and clears the pending list. It repeats
    this for 10 rounds, then joins the handler threads and closes every
    accepted connection.
    """
    import time  # local import: only needed for the short polling sleep below

    global_model = SimpleNN()
    # [0] = current global weights, [1] = updates received for the open round.
    shared_weights = [global_model.state_dict(), []]
    lock = threading.Lock()
    connections = []
    workers = []

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Allow an immediate restart without "Address already in use".
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind((HOST, PORT))
        s.listen()
        print(f"Server listening on {HOST}:{PORT}")

        try:
            for i in range(NUM_CLIENTS):
                conn, addr = s.accept()
                connections.append(conn)
                print(f"Client {i} connected from {addr}")
                worker = threading.Thread(target=client_thread, args=(conn, addr, i, shared_weights, lock))
                workers.append(worker)
                clients.append(worker)

            # Start only the threads created by this call. The module-level
            # list may still hold threads from an earlier run, and a thread
            # can be started only once.
            for t in workers:
                t.start()

            for rnd in range(10):
                while True:
                    # Check and aggregate under a single lock acquisition so
                    # no update can slip in between the test and the reset.
                    with lock:
                        if len(shared_weights[1]) == NUM_CLIENTS:
                            shared_weights[0] = average_weights(shared_weights[1])
                            # Rebind instead of clearing in place, so code
                            # still holding the old list is unaffected.
                            shared_weights[1] = []
                            break
                    # Yield the CPU and the lock to the handler threads
                    # instead of spinning at full speed.
                    time.sleep(0.01)

                print(f"Round {rnd+1} aggregation completed.")

            for t in workers:
                t.join()
        finally:
            # Closing an already-closed socket is a no-op, so this is safe
            # regardless of what the handler threads did with it.
            for conn in connections:
                conn.close()


if __name__ == '__main__':
    run_server()

## How to Run

1. Open a terminal and run the server:

```bash
python federated_server.py
```

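2. Open one terminal per client and run the command below in each. The server waits until `NUM_CLIENTS` (2 by default) clients have connected before training starts: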

```bash
python federated_client.py
```

- All communication occurs via `127.0.0.1` (localhost).
- This setup simulates federated learning on a single machine.