# Federated Learning with Multiple Devices using Socket Communication

This is a minimal implementation of **Horizontal Federated Learning (FedAvg)** using multiple processes or physical machines. Communication between the server and clients is achieved via `socket`-based TCP connections.

## File Structure
```
federated_server.py      ← Server (aggregates model weights)  
federated_client.py      ← Client (performs local training)  
model.py                 ← Shared PyTorch model definition  
launch_clients.py        ← Helper that starts several clients from one terminal  
```

## 1. model.py — Shared PyTorch Model

In [None]:
import torch.nn as nn

class SimpleNN(nn.Module):
    """Small fully connected network: 2 inputs -> 10 hidden units (ReLU) -> 2 outputs."""

    def __init__(self):
        super().__init__()
        # The stack is stored under the attribute name `fc`, so the state-dict
        # keys are `fc.0.*` and `fc.2.*` (the ReLU at index 1 has no parameters).
        layers = [
            nn.Linear(2, 10),
            nn.ReLU(),
            nn.Linear(10, 2),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` through the layer stack and return the final linear layer's output."""
        scores = self.fc(x)
        return scores

## 2. federated_client.py — Client Code

In [None]:
# federated_client.py

import socket, pickle, copy
import torch
import torch.optim as optim
import torch.nn.functional as F
from model import SimpleNN
import argparse  # ← 追加

HOST = 'localhost'
PORT = 8000

def local_train(model, X, y, epochs=5, lr=0.01):
    """Train a private copy of `model` on (X, y) and return the copy's state dict.

    The model passed in is never modified: all updates are applied to a deep
    copy. Each epoch is one full-batch SGD step on the cross-entropy loss.

    Args:
        model: Network whose current weights are the starting point.
        X: Input batch passed to the network.
        y: Target class indices for `X`.
        epochs: Number of full-batch gradient steps.
        lr: SGD learning rate.

    Returns:
        The state dict of the trained copy.
    """
    trained = copy.deepcopy(model)
    sgd = optim.SGD(trained.parameters(), lr=lr)
    for _epoch in range(epochs):
        sgd.zero_grad()
        batch_loss = F.cross_entropy(trained(X), y)
        batch_loss.backward()
        sgd.step()
    return trained.state_dict()

def run_client(client_id):
    """Connect to the server and take part in federated training until told to stop.

    Each round the client receives the global state dict, trains a copy of it on
    its own synthetic data, sends the updated weights back, and prints the
    accuracy of the received global weights on freshly sampled test points.
    The loop ends when the server sends the string "FIN" or closes the
    connection, so the number of rounds is decided by the server.

    Args:
        client_id: Identifier used only to label this client's log output.
    """
    model = SimpleNN()
    X = torch.randn(100, 2)
    y = (X[:, 0] + X[:, 1] > 0).long()

    acc = None  # stays None if the server stops us before any round completes
    rounds_done = 0

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.connect((HOST, PORT))
        print(f"[Client {client_id}] Connected to server.")

        # TCP is a byte stream: a single recv() may return only part of a
        # pickled message. A buffered file wrapper lets pickle.load() read
        # exactly one complete message per call without changing the wire format.
        with s.makefile('rb') as reader:
            while True:
                try:
                    # NOTE(security): unpickling data from the network executes
                    # arbitrary code if the peer is malicious; only connect to
                    # a trusted server.
                    message = pickle.load(reader)
                except (EOFError, ConnectionError):
                    print(f"[Client {client_id}] Server closed the connection.")
                    break
                if message == "FIN":
                    break  # gracefully exit
                model.load_state_dict(message)

                updated_weights = local_train(model, X, y)
                s.sendall(pickle.dumps(updated_weights))
                rounds_done += 1
                print(f"[Client {client_id}] Round {rounds_done} completed.")

                # Report test accuracy. `model` still holds the global weights
                # received this round, because local_train works on a copy.
                test_X = torch.randn(50, 2)
                test_y = (test_X[:, 0] + test_X[:, 1] > 0).long()
                acc = local_test(model, test_X, test_y)
                print(f"[Client {client_id}] Local test accuracy: {acc:.2%}")

    if acc is None:
        print(f"[Client {client_id}] No training rounds were completed.")
    else:
        print(f"[Client {client_id}] Final local test accuracy: {acc:.2%}")

def local_test(model, X, y):
    """Return the fraction of rows in `X` whose arg-max prediction equals `y`.

    The forward pass runs with gradient tracking disabled.

    Args:
        model: Callable mapping `X` to per-class scores along dimension 1.
        X: Input batch.
        y: Expected class index for each row of `X`.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    with torch.no_grad():
        scores = model(X)
        predicted = scores.argmax(dim=1)
        hits = predicted.eq(y).float()
        return hits.mean().item()

if __name__ == '__main__':
    # Script entry point: read the client identifier from the command line
    # (defaults to 0) and start the client with it.
    cli = argparse.ArgumentParser()
    cli.add_argument("--id", type=int, default=0, help="Client ID")
    run_client(cli.parse_args().id)

## 3. federated_server.py — Server Code

In [None]:
# federated_server.py
# run: python federated_server.py --clients 7 --rounds 10
import socket
import pickle
import threading
import torch
import argparse
from model import SimpleNN

HOST = 'localhost'
PORT = 8000
clients = []

def handle_client(conn, addr):
    """Log a newly accepted connection and add it to the module-level `clients` list.

    Args:
        conn: Connected socket for the client.
        addr: Peer address returned by `accept()`; used only in the log line.
    """
    banner = f"[Connected] {addr}"
    print(banner)
    clients.append(conn)

def accept_clients(server_socket, num_clients):
    """Block until `num_clients` connections are registered in `clients`.

    Args:
        server_socket: Listening socket to accept connections on.
        num_clients: Number of registered clients to wait for.
    """
    while len(clients) < num_clients:
        conn, addr = server_socket.accept()
        # Register on this thread. Doing it on a background thread made the
        # loop condition race with the append: the check could run before the
        # new connection was recorded, so the loop would block in an extra
        # accept() (or admit more clients than requested).
        handle_client(conn, addr)
    print(f"\n[Info] {num_clients} clients connected. Starting training...\n")

def aggregate_weights(weights_list):
    """Return the element-wise mean of several state dicts (FedAvg with equal weights).

    The inputs are left untouched; the result is built from fresh tensors.

    Args:
        weights_list: Non-empty sequence of state dicts that all contain the
            keys of the first one, with tensors of matching shape.

    Returns:
        A dict with the same keys (in the same order) as `weights_list[0]`,
        each mapped to the average of that entry across all inputs.

    Raises:
        ValueError: If `weights_list` is empty.
    """
    if not weights_list:
        raise ValueError("weights_list must contain at least one state dict")

    num_updates = len(weights_list)
    avg_weights = {}
    for key, first in weights_list[0].items():
        # Clone before accumulating so the first client's tensors are not
        # overwritten by the running sum.
        total = first.clone()
        for weights in weights_list[1:]:
            total += weights[key]
        avg_weights[key] = total / num_updates
    return avg_weights

def main(num_clients, num_rounds):
    """Run the federated-averaging server.

    Waits for `num_clients` connections, then for each round sends the current
    global state dict to every client, collects their updated weights, and
    loads the average back into the global model. After the last round every
    client is sent the string "FIN" and all connections are closed.

    Args:
        num_clients: Number of client connections to wait for before training.
        num_rounds: Number of send/collect/aggregate rounds to run.
    """
    model = SimpleNN()
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((HOST, PORT))
        s.listen()
        print(f"[Listening] Waiting for {num_clients} clients to connect...")

        accept_clients(s, num_clients)

        # Work on a snapshot so the connection list cannot change mid-training.
        conns = list(clients)
        # TCP is a byte stream: one recv() may hold only part of a pickled
        # update. A buffered reader per connection lets pickle.load() consume
        # exactly one complete message per call; the wire format is unchanged.
        readers = [conn.makefile('rb') for conn in conns]

        try:
            for rnd in range(num_rounds):
                print(f"[Round {rnd + 1}] Sending model to all clients...")
                data = pickle.dumps(model.state_dict())

                for conn in conns:
                    try:
                        conn.sendall(data)
                    except OSError as e:
                        print(f"[Error] Failed to send model to client: {e}")

                weights_list = []
                for idx, reader in enumerate(readers):
                    try:
                        # NOTE(security): unpickling network data executes
                        # arbitrary code if a client is malicious; only accept
                        # trusted clients.
                        weights = pickle.load(reader)
                    except EOFError:
                        # The client closed its side without sending an update.
                        print(f"[Warning] Client {idx+1} sent no data.")
                        continue
                    except Exception as e:
                        # Boundary handler: one bad client must not stop the round.
                        print(f"[Error] Failed to receive from Client {idx + 1}: {e}")
                        continue
                    weights_list.append(weights)
                    print(f"[Server] Received update from Client {idx + 1}")

                if weights_list:
                    model.load_state_dict(aggregate_weights(weights_list))
                    print(f"[Round {rnd + 1}] Aggregation complete with {len(weights_list)} clients.\n")
                else:
                    print(f"[Round {rnd + 1}] No updates received.\n")

            # Send the finish signal while the connections are still open.
            fin = pickle.dumps("FIN")
            for conn in conns:
                try:
                    conn.sendall(fin)
                except OSError as e:
                    # Expected when a client has already finished and disconnected.
                    print(f"[Warning] Could not send FIN to client: {e}")
        finally:
            # A socket is only released once its makefile() wrapper is closed too.
            for reader in readers:
                reader.close()
            for conn in conns:
                conn.close()

        print("[Finished] Training completed.")

if __name__ == '__main__':
    # Script entry point: read how many clients to wait for and how many
    # training rounds to run, then start the server.
    cli = argparse.ArgumentParser()
    cli.add_argument('--clients', type=int, default=5, help='Number of clients to wait for')
    cli.add_argument('--rounds', type=int, default=10, help='Number of training rounds')
    options = cli.parse_args()

    main(options.clients, options.rounds)

## 4. launch_clients.py — Launch clients

In [None]:
# launch_clients.py
# run: python launch_clients.py --clients 7
import subprocess
import time
import argparse

def launch_clients(num_clients):
    """Start `num_clients` client processes and wait for all of them to exit.

    Clients are numbered from 1 and started half a second apart. Any client
    that exits with a non-zero status is reported once it has finished.

    Args:
        num_clients: Number of `federated_client.py` processes to start.
    """
    import sys  # local import: only needed to locate the running interpreter

    # Use the interpreter running this script, so the clients see the same
    # installed packages; a bare "python" may be a different or missing one.
    interpreter = sys.executable or "python"

    processes = []
    for i in range(num_clients):
        print(f"Launching client {i + 1}")
        p = subprocess.Popen([interpreter, "federated_client.py", "--id", str(i + 1)])
        processes.append(p)
        time.sleep(0.5)

    for client_id, p in enumerate(processes, start=1):
        exit_code = p.wait()
        if exit_code != 0:
            print(f"Client {client_id} exited with code {exit_code}")

if __name__ == "__main__":
    # Script entry point: read how many clients to start from the command
    # line (defaults to 3) and launch them.
    cli = argparse.ArgumentParser()
    cli.add_argument("--clients", type=int, default=3, help="Number of clients to launch")
    options = cli.parse_args()

    launch_clients(options.clients)

## How to Run

1. Start the server:
```bash
python federated_server.py --clients 7 --rounds 10
```

2.1 Then launch each client (in separate terminals or on separate devices), giving each one a distinct `--id` so their log lines can be told apart:
```bash
python federated_client.py --id 1
```

or

2.2 Launch multiple clients (in 1 terminal):
```bash
python launch_clients.py --clients 7
```

## Running on Multiple Devices

To run this system across multiple physical machines (e.g., in the same LAN or Wi-Fi network):

1. **On the server machine**, find its local IP address using:
```bash
ipconfig      # on Windows
ifconfig      # on Linux/macOS
```
Look for an IP like `192.168.x.x`.

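2. **Update the `HOST` line on both sides.** On each client, point it at the server:
```python
HOST = '192.168.x.x'  # Replace with the server's actual IP
```
On the server, bind to all network interfaces instead of `localhost`, which only accepts connections from the same machine:
```python
HOST = '0.0.0.0'
```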

3. Make sure:
- Both server and clients are on the same network.
- Port 8000 is not blocked by a firewall.
- Python and required libraries are installed on all machines.

## Further Notes

This is a minimal educational prototype. For production-grade use, consider:

- gRPC / HTTP-based communication  
- TLS encryption  
- Secure Aggregation  
- Differential Privacy  
- Real frameworks: [Flower](https://flower.dev/), [FedML](https://fedml.ai/), [PySyft](https://github.com/OpenMined/PySyft)