# Research notebook on classification with federated learning - Simulation mode

## Preparation

Before we begin with the actual code, let's make sure that we have everything we need.

### Installing dependencies

Next, we install and import the necessary packages:

## Importing needed librairies

In [80]:
import torchvision.transforms as transforms
import importlib, os
import flwr as fl
from torch.utils.data import DataLoader, random_split
import torch
DEVICE = torch.device("cpu")
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")

import client.src.pipeline as pl
import client.src.client as flclient
import server.src.server as clserver
from server.main import Args
from models import ViT, ResNet18
import utils

importlib.reload(pl)
importlib.reload(utils)
importlib.reload(flclient)
importlib.reload(clserver)
#importlib.reload(flclient)

NUM_CLIENTS = 8
MODEL = ViT
NUM_CLASSES=3

Training on cpu


In [75]:
def load_datasets(dataset_dir: str, num_clients: int):
    train = os.path.join(dataset_dir, "train")
    test = os.path.join(dataset_dir, "valid")
    d_train = pl.ClassifyDataset(train)
    d_test = pl.ClassifyDataset(test)
    # Split training set into `num_clients` partitions to simulate different local datasets
    partition_size = int(len(d_train) // num_clients)
    lengths = [partition_size] * num_clients
    print("Len d_train {}".format(len(d_train)))
    print("Partition size {} Lengths {}".format(partition_size, lengths))
    datasets = random_split(d_train, lengths, torch.Generator().manual_seed(42))

    # Split each partition into train/val and create DataLoader
    trainloaders = []
    valloaders = []
    for ds in datasets:
        len_val = len(ds) // 10  # 10 % validation set
        len_train = len(ds) - len_val
        lengths = [len_train, len_val]
        ds_train, ds_val = random_split(ds, lengths, torch.Generator().manual_seed(42))
        trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
        valloaders.append(DataLoader(ds_val, batch_size=32))
        
    testloader = DataLoader(d_test, batch_size=32)
    return trainloaders, valloaders, testloader

In [76]:
dataset = "labeling/roboflow/Gear_Classify.v3-gear-fl-raw"
trainloaders, valloaders, testloader = load_datasets(dataset, NUM_CLIENTS)

[DATASET] Loading labels ..
[LABEL] Getting label 0 -> gear_red
[LABEL] Getting label 1 -> pic
[LABEL] Getting label 2 -> gear_black
[DATASET] Done
[DATASET] Loading labels ..
[LABEL] Getting label 0 -> gear_red
[LABEL] Getting label 1 -> pic
[LABEL] Getting label 2 -> gear_black
[DATASET] Done
Len d_train 344
Partition size 43 Lengths [43, 43, 43, 43, 43, 43, 43, 43]


In [77]:
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, cid, net, trainloader, valloader):
        self.cid = cid
        self.net = net
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self):
        print(f"[Client {self.cid}] get_parameters")
        return utils.get_parameters(self.net)

    def fit(self, parameters, config):
        print(f"[Client {self.cid}] fit, config: {config}")
        utils.set_parameters(self.net, parameters)
        utils.train(self.net, self.trainloader, epochs=1)
        return utils.get_parameters(self.net), len(self.trainloader), {}

    def evaluate(self, parameters, config):
        print(f"[Client {self.cid}] evaluate, config: {config}")
        utils.set_parameters(self.net, parameters)
        loss, accuracy = utils.test(self.net, self.valloader)
        return float(loss), len(self.valloader), {"accuracy": float(accuracy)}

def client_fn(cid) -> FlowerClient:
    global MODEL
    global DEVICE
    global trainloaders
    global testloader
    net = MODEL(n_classes=NUM_CLIENTS).to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return FlowerClient(cid, net, trainloader, valloader)

In [78]:
# create a classifier client
def client_fn_(cid) -> flclient.GearClassifyClient:
    global MODEL
    global DEVICE
    global trainloaders
    global testloader
    net = MODEL(n_classes=NUM_CLASSES).to(DEVICE)
    trainloader = trainloaders[int(cid)]
    valloader = valloaders[int(cid)]
    return flclient.GearClassifyClient(cid, net, trainloader, valloader)

In [79]:
# Create an instance of the model and get the parameters
server_model = ResNet18(n_classes=NUM_CLASSES)
# Create an instance of the model and get the parameters
params = utils.get_parameters(server_model)

# ass parameters to the Strategy for server-side parameter initialization
strategy = fl.server.strategy.FedAvg(
    fraction_fit=0.3,
    fraction_eval=0.3,
    min_fit_clients=3,
    min_eval_clients=3,
    min_available_clients=NUM_CLIENTS,
    initial_parameters=fl.common.weights_to_parameters(params),
)

# Start simulation
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=NUM_CLIENTS,
    num_rounds=3,  # Just three rounds
    strategy=strategy,
)

INFO flower 2022-06-28 11:02:43,341 | app.py:158 | Ray initialized with resources: {'node:127.0.0.1': 1.0, 'memory': 12341373338.0, 'CPU': 10.0, 'object_store_memory': 2147483648.0}
--- Logging error ---
Traceback (most recent call last):
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/logging/handlers.py", line 1196, in emit
    h.endheaders()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1276, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1036, in _send_output
    self.send(msg)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 976, in send
    self.connect()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 948, in connect
    (self.host,self.port), self.timeout, self.source_address)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3

[2m[36m(launch_and_fit pid=35827)[0m [Client 2] fit, config: {}


--- Logging error ---
Traceback (most recent call last):
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/logging/handlers.py", line 1196, in emit
    h.endheaders()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1276, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1036, in _send_output
    self.send(msg)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 976, in send
    self.connect()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 948, in connect
    (self.host,self.port), self.timeout, self.source_address)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/socket.py", line 728, in create_connection
    raise err
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/socket.py", line 716, in create_connection
    sock.connect

[2m[36m(launch_and_fit pid=35828)[0m [Client 4] fit, config: {}
[2m[36m(launch_and_fit pid=35826)[0m [Client 6] fit, config: {}
[2m[36m(launch_and_evaluate pid=35828)[0m [Client 5] evaluate, config: {}


DEBUG flower 2022-06-28 11:02:49,484 | server.py:174 | evaluate_round received 0 results and 3 failures
--- Logging error ---
Traceback (most recent call last):
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/logging/handlers.py", line 1196, in emit
    h.endheaders()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1276, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1036, in _send_output
    self.send(msg)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 976, in send
    self.connect()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 948, in connect
    (self.host,self.port), self.timeout, self.source_address)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/socket.py", line 728, in create_connection
    raise err
  File "/Users/hug

[2m[36m(launch_and_evaluate pid=35827)[0m [Client 2] evaluate, config: {}
[2m[36m(launch_and_evaluate pid=35826)[0m [Client 3] evaluate, config: {}
[2m[36m(launch_and_fit pid=35826)[0m [Client 0] fit, config: {}


DEBUG flower 2022-06-28 11:02:51,273 | server.py:220 | fit_round received 0 results and 3 failures
--- Logging error ---
Traceback (most recent call last):
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/logging/handlers.py", line 1196, in emit
    h.endheaders()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1276, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1036, in _send_output
    self.send(msg)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 976, in send
    self.connect()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 948, in connect
    (self.host,self.port), self.timeout, self.source_address)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/socket.py", line 728, in create_connection
    raise err
  File "/Users/hugo/opt

[2m[36m(launch_and_fit pid=35827)[0m [Client 3] fit, config: {}
[2m[36m(launch_and_fit pid=35828)[0m [Client 5] fit, config: {}


--- Logging error ---
Traceback (most recent call last):
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/logging/handlers.py", line 1196, in emit
    h.endheaders()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1276, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1036, in _send_output
    self.send(msg)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 976, in send
    self.connect()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 948, in connect
    (self.host,self.port), self.timeout, self.source_address)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/socket.py", line 728, in create_connection
    raise err
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/socket.py", line 716, in create_connection
    sock.connect

[2m[36m(launch_and_evaluate pid=35827)[0m [Client 1] evaluate, config: {}


DEBUG flower 2022-06-28 11:02:53,801 | server.py:174 | evaluate_round received 0 results and 3 failures
--- Logging error ---
Traceback (most recent call last):
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/logging/handlers.py", line 1196, in emit
    h.endheaders()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1276, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1036, in _send_output
    self.send(msg)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 976, in send
    self.connect()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 948, in connect
    (self.host,self.port), self.timeout, self.source_address)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/socket.py", line 728, in create_connection
    raise err
  File "/Users/hug

[2m[36m(launch_and_evaluate pid=35828)[0m [Client 2] evaluate, config: {}


--- Logging error ---
Traceback (most recent call last):
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/logging/handlers.py", line 1196, in emit
    h.endheaders()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1276, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1036, in _send_output
    self.send(msg)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 976, in send
    self.connect()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 948, in connect
    (self.host,self.port), self.timeout, self.source_address)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/socket.py", line 728, in create_connection
    raise err
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/socket.py", line 716, in create_connection
    sock.connect

[2m[36m(launch_and_evaluate pid=35827)[0m [Client 3] evaluate, config: {}
[2m[36m(launch_and_fit pid=35827)[0m [Client 1] fit, config: {}
[2m[36m(launch_and_fit pid=35828)[0m [Client 2] fit, config: {}


DEBUG flower 2022-06-28 11:02:56,025 | server.py:220 | fit_round received 0 results and 3 failures
--- Logging error ---
Traceback (most recent call last):
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/logging/handlers.py", line 1196, in emit
    h.endheaders()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1276, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1036, in _send_output
    self.send(msg)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 976, in send
    self.connect()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 948, in connect
    (self.host,self.port), self.timeout, self.source_address)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/socket.py", line 728, in create_connection
    raise err
  File "/Users/hugo/opt

[2m[36m(launch_and_fit pid=35826)[0m [Client 3] fit, config: {}
[2m[36m(launch_and_evaluate pid=35826)[0m [Client 2] evaluate, config: {}


DEBUG flower 2022-06-28 11:02:57,546 | server.py:174 | evaluate_round received 0 results and 3 failures
--- Logging error ---
Traceback (most recent call last):
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/logging/handlers.py", line 1196, in emit
    h.endheaders()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1276, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1036, in _send_output
    self.send(msg)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 976, in send
    self.connect()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 948, in connect
    (self.host,self.port), self.timeout, self.source_address)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/socket.py", line 728, in create_connection
    raise err
  File "/Users/hug

[2m[36m(launch_and_evaluate pid=35828)[0m [Client 5] evaluate, config: {}
[2m[36m(launch_and_evaluate pid=35827)[0m [Client 3] evaluate, config: {}


--- Logging error ---
Traceback (most recent call last):
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/logging/handlers.py", line 1196, in emit
    h.endheaders()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1276, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 1036, in _send_output
    self.send(msg)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 976, in send
    self.connect()
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/http/client.py", line 948, in connect
    (self.host,self.port), self.timeout, self.source_address)
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/socket.py", line 728, in create_connection
    raise err
  File "/Users/hugo/opt/anaconda3/envs/flower3.7/lib/python3.7/socket.py", line 716, in create_connection
    sock.connect

