# Config

In [None]:
# Number of simulated clients handed to the Flower simulation; also sizes the plot colour palette.
NUM_CLIENTS = 4
# Number of federated rounds the server runs in each simulation.
TOTAL_ROUNDS = 50
# Frames each client is asked to train on per round; also reused as the environment episode length.
FRAMES_PER_ROUND = 100 
# How many times the whole federated simulation is repeated.
EXPERIMENT_REPEATS = 1
# Seed for the NumPy generator that draws one client-id offset per repeat.
SEED = 0

In [None]:
import os, shutil
import pickle

import flwr as fl
import numpy as np
import torch
from omegaconf import OmegaConf, DictConfig
import seaborn as sns
from matplotlib import pyplot as plt
from tqdm import tqdm

from florl.common.util import aggregate_weighted_average, stateful_client
from florl.client.kitten.qt_opt import *

from strategy import RlFedAvg
from visualisation import *
from experiment_utils import *


# Use CUDA when torch can see a GPU, otherwise run on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# An episode is capped at exactly one round's worth of frames.
episode_length = FRAMES_PER_ROUND

# Experiment configuration: "rl" is handed to the client factory, "fl" holds
# the per-round settings that the server sends to clients.
config = DictConfig(dict(
    rl=dict(
        env=dict(
            name="Pendulum-v1",
            max_episode_steps=episode_length,
        ),
        algorithm=dict(
            gamma=0.99,
            tau=0.005,
            lr=0.001,
            update_frequency=1,
            clip_grad_norm=1,
            critic=dict(features=64),
        ),
        memory=dict(
            type="experience_replay",
            # Large enough to keep every frame collected over the whole run
            # (never smaller than 128 transitions).
            capacity=max(128, TOTAL_ROUNDS * FRAMES_PER_ROUND),
        ),
        train=dict(
            initial_collection_size=episode_length,
            minibatch_size=32,
        ),
    ),
    fl=dict(
        train_config=dict(frames=FRAMES_PER_ROUND),
        evaluate_config=dict(evaluation_repeats=1),
    ),
))

# Plain-dict copies of the federated settings, merged with the round number
# by the strategy callbacks.
train_config = OmegaConf.to_container(config.fl.train_config)
evaluate_config = OmegaConf.to_container(config.fl.evaluate_config)

def _on_fit_config_fn(server_round: int):
    """Return the fit configuration for one round.

    The result is a new dict holding every entry of the module-level
    ``train_config`` plus the current ``server_round``.
    """
    return {**train_config, "server_round": server_round}
def _on_evaluate_config_fn(server_round: int):
    """Return the evaluation configuration for one round.

    The result is a new dict holding every entry of the module-level
    ``evaluate_config`` plus the current ``server_round``.
    """
    return {**evaluate_config, "server_round": server_round}

# Client

In [None]:
import copy
from flwr.common import Config


class MemoryClientFactory(QTOptClientFactory):
    """QT-Opt client factory that wraps every client it builds in ``MemoryClient``."""

    def create_client(self, cid: str, config: Config, **kwargs) -> MemoryClient:
        """Build the parent factory's client for ``cid`` and wrap it.

        All arguments are forwarded unchanged to ``QTOptClientFactory.create_client``.
        """
        return MemoryClient(super().create_client(cid, config, **kwargs))


# Factory shared by the evaluation client and every simulated client.
client_factory = MemoryClientFactory(config)

In [None]:
# Directory removed before every repeat so that no on-disk state survives
# from a previous run (presumably where `stateful_client` persists client
# state - confirm against florl).
CONTEXT_WS = "florl_ws"

# Client used only for server-side evaluation. `._client` takes the object
# wrapped by MemoryClient (see MemoryClientFactory.create_client).
evaluation_client = client_factory.create_client(0, config["rl"])._client

federated_results = []
rng = np.random.default_rng(seed=SEED)

for _ in tqdm(range(EXPERIMENT_REPEATS)):
    # Plain Python int so the derived client ids are not NumPy scalars.
    seed = int(rng.integers(0, 65535))
    if os.path.exists(CONTEXT_WS):
        shutil.rmtree(CONTEXT_WS)

    # Build a fresh strategy for every repeat. A strategy object carries
    # state (its initial knowledge and whatever it aggregates), so sharing
    # one instance would let repeat k start from the result of repeat k-1
    # and the repeats would no longer be independent.
    strategy = RlFedAvg(
        knowledge=copy.deepcopy(client_factory.create_default_knowledge(config=config["rl"])),
        on_fit_config_fn=_on_fit_config_fn,
        on_evaluate_config_fn=_on_evaluate_config_fn,
        fit_metrics_aggregation_fn=aggregate_weighted_average,
        evaluate_metrics_aggregation_fn=aggregate_weighted_average,
        evaluate_fn=get_evaluation_fn(evaluation_client),
        accept_failures=False,
        inplace=False
    )

    @stateful_client
    def build_client(cid: str) -> fl.client.Client:
        """Create the simulated client for ``cid``, offset by this repeat's seed."""
        cid = int(cid) + seed
        return client_factory.create_client(
            cid=cid,
            config=config["rl"],
            enable_evaluation=False
        )

    hist = fl.simulation.start_simulation(
        client_fn=build_client,
        client_resources={'num_cpus': 1},
        config=fl.server.ServerConfig(num_rounds=TOTAL_ROUNDS),
        num_clients=NUM_CLIENTS,
        strategy=strategy
    )

    federated_results.append(hist)

# Analysis

In [None]:
def pendulum_dimensionality_reduction(states: torch.Tensor, max_speed: float = 8.0):
    """Project Pendulum states onto two coordinates for plotting.

    Assumes the Pendulum-v1 observation layout ``[cos(theta), sin(theta),
    angular velocity]`` along the last axis of a 2-D ``states`` tensor.

    Args:
        states: Tensor indexed as ``states[:, k]`` with at least 3 columns.
        max_speed: Bound used to scale the third column; 8.0 is the
            Pendulum-v1 angular-velocity limit.

    Returns:
        ``(x, y)`` where ``x`` is sin(theta) and ``y`` is the third column
        divided by ``max_speed``.
    """
    # X: sin(theta). torch.atan2 takes (y, x) = (sin, cos); the previous
    # argument order (cos, sin) returned cos(theta) instead. Going through
    # atan2 keeps the result normalised even if (cos, sin) is not unit length.
    x = torch.atan2(states[:, 1], states[:, 0]).sin()
    # Y: scale the angular velocity into [-1, 1].
    y = states[:, 2] / max_speed
    return x, y

sns.set_theme()

# One column per sampled round. Top row: everything in each client's replay
# buffer; bottom row: only the data collected most recently.
num_panels = 5
fig, axs = plt.subplots(2, num_panels, sharex='col', sharey='row')
fig.set_size_inches(20, 8)

fig.suptitle("Replay Buffer Heterogeneity")
colors = sns.color_palette("husl", NUM_CLIENTS)
markers = ["o", "s", "P", "^", "X"]

# Text annotations for each row
fig.text(0.05, 0.7, "All Data in Replay Buffer", ha='center', va='center', rotation='vertical', fontsize=12)
fig.text(0.05, 0.25, "Most Recent Collected Episode", ha='center', va='center', rotation='vertical', fontsize=12)

fit_metrics = federated_results[0].metrics_distributed_fit
ids_ = fit_metrics["id"]
# Stable client numbering taken from the order of ids reported in the first recorded round.
id_to_index = {id_: i for i, id_ in enumerate(entry[1] for entry in ids_[0][1]["all"])}
num_recorded = len(fit_metrics["rb"])

for col in range(num_panels):
    # Midpoint of the col-th slice of the history. Always a valid index, and
    # equal to 5, 15, 25, 35, 45 for a 50-round run.
    history_index = (num_recorded * (2 * col + 1)) // (2 * num_panels)
    # Each history entry is (server_round, value). The list index is not the
    # round number (rounds start at 1), so read the round from the entry.
    server_round = fit_metrics["rb"][history_index][0]

    # NOTE(review): pickle.loads is only safe because these payloads come
    # from our own simulated clients; never point this at untrusted results.
    all_states = [pickle.loads(blob) for blob in fit_metrics["rb"][history_index][1]]
    new_states = [pickle.loads(blob) for blob in fit_metrics["rb_new"][history_index][1]]
    round_ids = [entry[1] for entry in ids_[history_index][1]["all"]]
    # Average number of stored transitions; used to trim each full-buffer dump.
    buffer_size = int(fit_metrics["rb_size"][history_index][1]["avg"])

    axs[0][col].set_title(f"Round {server_round}")

    rows = (
        (axs[0][col], all_states, buffer_size),  # Total
        (axs[1][col], new_states, None),         # New
    )
    for ax, client_states, limit in rows:
        # Draw clients in index order in both rows so z-order and the legend
        # (taken from the bottom-right axes) are consistent.
        ordered = sorted(zip(round_ids, client_states), key=lambda pair: id_to_index[pair[0]])
        for client_id, states in ordered:
            index = id_to_index[client_id]
            if limit is not None:
                states = states[:limit]
            x, y = pendulum_dimensionality_reduction(states)
            ax.scatter(
                x, y,
                color=colors[index],
                s=5,
                # Cycle markers so more than len(markers) clients does not raise.
                marker=markers[index % len(markers)],
                label=f"Client {index}",
            )
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)

for ax in axs.flat:
    ax.label_outer()
for row in range(2):
    # The y coordinate is states[:, 2] / 8, i.e. the normalised angular velocity, not a torque.
    axs[row][0].set_ylabel("Angular Velocity (Normalised)")
for col in range(num_panels):
    axs[1][col].set_xlabel('Sin(θ)')

handles, labels = axs[-1, -1].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.23, 1.03), ncol=3, title="Clients", fontsize='small', title_fontsize='medium', markerscale=0.75);