In [None]:
import copy
import logging
import os, shutil

import flwr as fl
from flwr.common.logger import logger

logger.setLevel(logging.WARNING)
import numpy as np
import torch
from omegaconf import OmegaConf, DictConfig
import seaborn as sns
from matplotlib import pyplot as plt
from tqdm import tqdm

from florl.common.util import aggregate_weighted_average, stateful_client
from florl.client.kitten.dqn import *

from strategy import RlFedAvg
from visualisation import *
from experiment_utils import *



# Plot styling shared by every figure in this notebook.
sns.set_theme("paper")

# --- Experiment constants ----------------------------------------------------
# NOTE(review): DEVICE is not referenced anywhere else in the visible cells.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
NUM_CLIENTS = 10         # passed as num_clients to the simulation
TOTAL_ROUNDS = 100       # passed as num_rounds to the server config
FRAMES_PER_ROUND = 25    # sent to clients as "frames" in every fit config
EXPERIMENT_REPEATS = 20  # simulations per swept setting (30 was the earlier, commented-out value)
SEED = 1

# --- RL / FL configuration ---------------------------------------------------
config = DictConfig(
    {
        "rl": {
            "env": {"name": "CartPole"},
            "algorithm": {
                "gamma": 0.99,
                "tau": 0.005,
                "lr": 0.001,
                "update_frequency": 1,
                "clip_grad_norm": 1,
                "critic": {"features": 64},
            },
            "memory": {
                "type": "experience_replay",
                # TOTAL_ROUNDS * FRAMES_PER_ROUND entries, but never fewer than 128.
                "capacity": max(TOTAL_ROUNDS * FRAMES_PER_ROUND, 128),
            },
            "train": {"initial_collection_size": 1024, "minibatch_size": 64},
        },
        "fl": {
            "train_config": {},
            "evaluate_config": {"evaluation_repeats": 5},
        },
    }
)

# Containers extracted from the "fl" section; the per-round config functions
# merge these with round-specific values.
train_config, evaluate_config = (
    OmegaConf.to_container(config["fl"][section])
    for section in ("train_config", "evaluate_config")
)

def _on_evaluate_config_fn(server_round: int):
    """Return the evaluate config for one round.

    The result is a new dict holding every entry of the module-level
    ``evaluate_config`` plus ``"server_round"`` set to ``server_round``;
    ``evaluate_config`` itself is not modified.
    """
    round_config = dict(evaluate_config)
    round_config["server_round"] = server_round
    return round_config

# Seed the global NumPy and PyTorch generators with the notebook-wide SEED.
np.random.seed(SEED)
torch.manual_seed(SEED)

class ClientFactory(DQNClientFactory):
    """DQN client factory that can wrap its environment in ``FixedResetWrapper``.

    Args:
        config: forwarded unchanged to ``DQNClientFactory.__init__``.
        fixed_reset: when truthy, ``self.env`` is replaced by
            ``FixedResetWrapper(self.env)`` after the base constructor has run.
        device: forwarded unchanged to ``DQNClientFactory.__init__``.
    """

    def __init__(self, config, fixed_reset: bool = False, device: str = "cpu") -> None:
        super().__init__(config, device)
        if not fixed_reset:
            return
        # NOTE(review): relies on the base constructor having set self.env;
        # presumably the wrapper makes environment resets repeatable - confirm.
        self.env = FixedResetWrapper(self.env)

# Directory removed before every simulation in the sweep cell.
# NOTE(review): presumably where stateful_client persists client context - confirm.
CONTEXT_WS = "florl_ws"

# Shared factory for all training clients and the evaluation client; uses the
# default device ("cpu") and the fixed-reset environment wrapper.
client_factory = ClientFactory(config, fixed_reset=True)

# Ask PyTorch to use deterministic implementations.
torch.use_deterministic_algorithms(True)

In [None]:
def _make_strategy(proximal_mu: float) -> RlFedAvg:
    """Build a fresh ``RlFedAvg`` strategy for a single simulation.

    Args:
        proximal_mu: value sent to clients as ``"proximal_mu"`` in every fit
            config produced by this strategy.

    Returns:
        A new strategy with its own deep-copied default knowledge and its own
        centralized evaluation function.
    """

    def _on_fit_config_fn(server_round: int):
        # proximal_mu is bound by the enclosing call, not by the sweep's loop
        # variable, so the config cannot drift if the loop moves on.
        return train_config | {
            "server_round": server_round,
            "proximal_mu": proximal_mu,
            "frames": FRAMES_PER_ROUND,
        }

    return RlFedAvg(
        knowledge=copy.deepcopy(client_factory.create_default_knowledge(config=config["rl"])),
        on_fit_config_fn=_on_fit_config_fn,
        on_evaluate_config_fn=_on_evaluate_config_fn,
        fit_metrics_aggregation_fn=aggregate_weighted_average,
        evaluate_metrics_aggregation_fn=aggregate_weighted_average,
        evaluate_fn=get_evaluation_fn(client_factory.create_client(0, config["rl"])),
        accept_failures=False,
        inplace=False,
    )


def _make_client_fn(seed: int):
    """Return the ``client_fn`` for one simulation.

    Clients are created lazily and cached, so repeated requests for the same
    cid within a simulation return the same client object.

    Args:
        seed: offset added to the integer form of the simulation's cid; the sum
            is the cid handed to ``client_factory.create_client``.
    """
    initialized_clients = {}

    @stateful_client
    def build_client(cid: str) -> fl.client.Client:
        client_id = int(cid) + seed
        if client_id not in initialized_clients:
            initialized_clients[client_id] = client_factory.create_client(
                cid=client_id,
                config=config["rl"],
                enable_evaluation=False,
            )
        return initialized_clients[client_id]

    return build_client


# all_results[k] holds the EXPERIMENT_REPEATS simulation histories obtained for
# the k-th proximal_mu value below.
all_results = []
for proximal_mu in [0, 0.01, 0.1, 1.0, 10.0]:
    # Re-created for every value so each proximal_mu sees the same sequence of
    # repeat seeds.
    rng = np.random.default_rng(seed=SEED)

    federated_results = []
    for _ in tqdm(range(EXPERIMENT_REPEATS)):
        # Plain int so client ids are ordinary Python integers, not NumPy scalars.
        seed = int(rng.integers(0, 65535))

        # Start every repeat from an empty context directory.
        if os.path.exists(CONTEXT_WS):
            shutil.rmtree(CONTEXT_WS)

        # A new strategy per repeat: the strategy owns server-side state (it is
        # constructed around a knowledge object), so sharing one instance could
        # let a repeat start from what the previous repeat left behind.
        # NOTE(review): confirm against RlFedAvg's implementation.
        strategy = _make_strategy(proximal_mu)

        hist = fl.simulation.start_simulation(
            client_fn=_make_client_fn(seed),
            client_resources={"num_cpus": 1},
            config=fl.server.ServerConfig(num_rounds=TOTAL_ROUNDS),
            num_clients=NUM_CLIENTS,
            strategy=strategy,
        )
        federated_results.append(hist)

    all_results.append(federated_results)

In [None]:
# Last centralized 'reward' entry of the first repeat for the last swept
# proximal_mu value (10.0). NOTE(review): presumably a (round, value) pair as
# recorded by Flower's History - confirm.
all_results[-1][0].metrics_centralized['reward'][-1]

In [None]:
# Same inspection for the second swept proximal_mu value (0.01): last
# centralized 'reward' entry of its first repeat.
all_results[1][0].metrics_centralized['reward'][-1]

In [None]:
# Styling is applied before the figure exists so it is in effect for the axes
# created below. 'none' keeps text as text in SVG exports.
plt.rcParams['svg.fonttype'] = 'none'
sns.set_theme("paper")

# proximal_mu values in the order the sweep cell appended them to all_results;
# used only to label the curves. Keep in sync with the sweep cell.
SWEEP_PROXIMAL_MUS = [0, 0.01, 0.1, 1.0, 10.0]

fig, axs = plt.subplots(1, 2)
fig.set_size_inches(14, 5)
fig.suptitle(f"DQN with RlFedAvg on {config['rl']['env']['name']}: proximal_mu sweep")

# Axis titles and labels do not depend on the plotted setting, so set them once.
ax_losses, ax_rewards = axs
ax_losses.set_title("Training Loss (Distributed)")
ax_losses.set_ylabel("Average Loss")
ax_losses.set_xlabel("Round")
ax_rewards.set_title("Evaluation Reward (Centralised)")
ax_rewards.set_ylabel("Average Episode Reward")
ax_rewards.set_xlabel("Round")

colors = sns.color_palette()
rounds = list(range(TOTAL_ROUNDS))

for i, federated_result in enumerate(all_results):
    federated_losses, federated_rewards = get_federated_metrics(
        federated_result,
        EXPERIMENT_REPEATS,
        NUM_CLIENTS,
        TOTAL_ROUNDS,
        centralised_evaluation=True,
    )

    # One distinct legend entry and colour per swept value; fall back to the
    # position if all_results has more entries than the list above.
    if i < len(SWEEP_PROXIMAL_MUS):
        label = f"proximal_mu={SWEEP_PROXIMAL_MUS[i]}"
    else:
        label = f"setting {i}"
    color = colors[i]

    plot_losses(ax=ax_losses,
                xs=rounds,
                losses=federated_losses,
                label=label,
                color=color,
                hatch="//")

    plot_rewards(ax=ax_rewards,
                 xs=rounds,
                 rewards=federated_rewards,
                 label=label,
                 color=color)

# Both panels use the same labels and colours, so one figure-level legend built
# from the reward axis covers them.
handles, labels = ax_rewards.get_legend_handles_labels()
fig.legend(handles, labels, loc='upper right');

In [None]:
for result in all_results:
    