In [None]:
from typing import List, OrderedDict
import os

import flwr
import numpy as np
import torch
from ultralytics import YOLO

In [None]:
# Pre-trained detector whose weights seed every federated client below.
base_model = YOLO(model="runs/detect/train/weights/best.pt")

# Printing the whole state dict floods the cell output with tensor values;
# summarise entry count, total element count, and the first few keys instead.
base_state = base_model.state_dict()
total_values = sum(tensor.numel() for tensor in base_state.values())
print(f"{len(base_state)} state-dict entries, {total_values:,} values")
for key, tensor in list(base_state.items())[:5]:
    print(f"  {key}: {tuple(tensor.shape)} {tensor.dtype}")

In [None]:
def set_parameters(model, parameters: List[np.ndarray]):
    """Load a flat list of NumPy arrays into ``model``'s state dict, in place.

    The arrays are matched to ``model.state_dict()`` keys by position, so they
    must be in the same order that ``get_parameters`` produces.

    Args:
        model: A ``torch.nn.Module`` (the YOLO model in this notebook) to update.
        parameters: One array per state-dict entry, in state-dict order.

    Raises:
        ValueError: If the number of arrays differs from the number of
            state-dict entries (``zip`` would otherwise silently drop extras).
        RuntimeError: From ``load_state_dict`` if an array's shape does not
            match the corresponding model entry.
    """
    keys = list(model.state_dict().keys())
    parameters = list(parameters)
    if len(parameters) != len(keys):
        raise ValueError(
            f"expected {len(keys)} arrays for the model state dict, got {len(parameters)}"
        )
    # torch.tensor copies each array keeping its dtype and shape. The legacy
    # torch.Tensor constructor forces float32 and can misread a 0-d array
    # (e.g. a BatchNorm num_batches_tracked counter) as a size argument.
    state_dict = {key: torch.tensor(array) for key, array in zip(keys, parameters)}
    model.load_state_dict(state_dict, strict=True)

In [None]:
def get_parameters(model) -> List[np.ndarray]:
    """Snapshot ``model``'s state dict as a list of NumPy arrays.

    Entries come out in state-dict order, which is the positional order
    ``set_parameters`` expects when loading them back. Each tensor is moved
    to the CPU first because ``.numpy()`` only works on CPU tensors.
    """
    arrays: List[np.ndarray] = []
    for tensor in model.state_dict().values():
        arrays.append(tensor.cpu().numpy())
    return arrays

In [None]:
class YOLOClient(flwr.client.NumPyClient):
    """Flower NumPy client that trains and validates a YOLO model on one local dataset.

    The client uses ``datasets/<dataset>/config.yaml`` as its Ultralytics data
    config. It reports the number of files in ``datasets/<dataset>/images/train``
    as its example count.
    """

    def __init__(self, model, dataset):
        """
        Args:
            model: Ultralytics ``YOLO`` model this client trains and validates.
            dataset: Sub-directory name under ``datasets/``.
        """
        self.model = model
        self.config_path = f"datasets/{dataset}/config.yaml"
        # Number of training images; returned to the server as this client's weight.
        self.train_size = len(os.listdir(f"datasets/{dataset}/images/train"))

    def get_parameters(self, config):
        """Return the model's current state dict as a list of NumPy arrays."""
        return get_parameters(self.model)

    def fit(self, parameters, config):
        """Load the global weights, train locally, and return the updated weights.

        ``config["local_epochs"]`` sets the number of local epochs when the
        strategy sends it. It defaults to 1, the previously hard-coded value.
        """
        set_parameters(self.model, parameters)
        epochs = int(config.get("local_epochs", 1))
        self.model.train(data=self.config_path, epochs=epochs)
        return get_parameters(self.model), self.train_size, {}

    def evaluate(self, parameters, config):
        """Load the global weights and validate them on this client's dataset.

        Returns ``(1 - mAP, num_examples, {"accuracy": mAP})``, where mAP is
        ``metrics.box.map`` from Ultralytics validation.
        """
        set_parameters(self.model, parameters)
        # Pass the client's data config explicitly. Otherwise Ultralytics takes
        # `data` from the model's stored args (the dataset the checkpoint was
        # trained on), not this client's dataset.
        metrics = self.model.val(data=self.config_path)
        # Use a plain Python float so Flower's scalar serialization accepts it.
        accuracy = float(metrics.box.map)
        loss = 1.0 - accuracy
        # NOTE(review): this reports the training-image count as the evaluation
        # example count. Consider using the validation-split size instead.
        return loss, self.train_size, {"accuracy": accuracy}

In [None]:
datasets = ["citypersons", "roadsigns"]


def client_fn(cid: str) -> flwr.client.Client:
    """Build the Flower client for simulation client ``cid``.

    ``cid`` is used as an index into ``datasets``, so it must be "0", "1", ...

    Each client gets its own deep copy of ``base_model``. ``YOLOClient.fit`` and
    ``YOLOClient.evaluate`` overwrite the model's weights in place, and training
    mutates it further. Handing every client the same instance would therefore
    leak one client's state into the next whenever clients share a process.
    """
    import copy

    dataset = datasets[int(cid)]
    model = copy.deepcopy(base_model)
    return YOLOClient(model, dataset).to_client()

In [None]:
# FedAvg's minimum client counts default to 2. Tie them to the number of
# datasets (one simulated client each) so the run does not depend on that
# default happening to match len(datasets).
strategy = flwr.server.strategy.FedAvg(
    min_fit_clients=len(datasets),
    min_evaluate_clients=len(datasets),
    min_available_clients=len(datasets),
)
# Each client requests a whole GPU, so at most one client runs per GPU at a time.
client_resources = {"num_cpus": 1, "num_gpus": 1.0}

# Keep the returned History so per-round losses and metrics can be inspected.
history = flwr.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=len(datasets),
    config=flwr.server.ServerConfig(num_rounds=1),
    strategy=strategy,
    client_resources=client_resources,
)
history

In [None]:
# Display the driver-side base model's weights after the simulation.
# NOTE(review): nothing in this notebook writes the aggregated FedAvg parameters
# back into `base_model` (the simulation's result is never applied to it), so
# this most likely still shows the pre-federation weights. The simulated clients
# presumably train on their own copies of the model; confirm this before
# treating this output as the federated result.
base_model.state_dict()