In [None]:
import os, shutil
from typing import Any, List

import flwr as fl
from flwr.server.history import History
import numpy as np
import torch
from omegaconf import OmegaConf, DictConfig
import seaborn as sns
from matplotlib import pyplot as plt
from tqdm import tqdm
from florl.common.util import aggregate_weighted_average, stateful_client

from qtoptavg import *

sns.set_theme()

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Experiment size: NUM_CLIENTS federated clients, TOTAL_ROUNDS rounds each,
# and FRAMES_PER_ROUND frames per training round (fed into train_config).
NUM_CLIENTS = 5
TOTAL_ROUNDS = 20
FRAMES_PER_ROUND = 25
# Independent repeats per setting; 2 keeps iteration fast (the full study used 30).
EXPERIMENT_REPEATS = 2

# Single experiment configuration: the "rl" section is handed to the client
# factory when building agents, the "fl" section seeds the per-round configs.
config = OmegaConf.create({
    "rl": {
        "env": {"name": "Pendulum-v1"},
        "algorithm": {
            "gamma": 0.99,
            "tau": 0.005,
            "lr": 0.001,
            "update_frequency": 1,
            "clip_grad_norm": 1,
            "critic": {"features": 64},
        },
        "memory": {
            "type": "experience_replay",
            # Sized to the frames gathered across all training rounds
            # (TOTAL_ROUNDS * FRAMES_PER_ROUND), with a floor of 128.
            # NOTE(review): with the current constants this evaluates to 500,
            # below initial_collection_size (1024) — confirm in qtoptavg whether
            # the initial collection is expected to fit in the buffer.
            "capacity": max(128, TOTAL_ROUNDS * FRAMES_PER_ROUND),
        },
        "train": {
            "initial_collection_size": 1024,
            "minibatch_size": 64,
        },
    },
    "fl": {
        "train_config": {"frames": FRAMES_PER_ROUND},
        "evaluate_config": {"evaluation_repeats": 1},
    },
})

# Plain-dict copies of the FL sub-configs; the per-round callbacks below each
# return a fresh dict built from one of them plus the current server round.
train_config = OmegaConf.to_container(config.fl.train_config)
evaluate_config = OmegaConf.to_container(config.fl.evaluate_config)


def _on_fit_config_fn(server_round: int):
    """Build the config handed to clients for the fit step of ``server_round``."""
    return {**train_config, "server_round": server_round}


def _on_evaluate_config_fn(server_round: int):
    """Build the config handed to clients for the evaluate step of ``server_round``."""
    return {**evaluate_config, "server_round": server_round}


# Shared factory used for both the baseline client and every federated client.
client_factory = QTOptClientFactory(config)

# Baseline

In [None]:
baseline_results = []
for seed in range(EXPERIMENT_REPEATS):
    client = client_factory.create_dqn_client(seed, config["rl"])

    # Manually run the per-round loop a federated client goes through, without
    # any aggregation. Rounds are 1-indexed to match Flower's server_round.
    hist_fit = []
    evaluation_reward = []
    for server_round in tqdm(range(1, TOTAL_ROUNDS + 1), desc=f"baseline seed {seed}"):
        # Send the same per-round config the federated server sends via
        # _on_fit_config_fn (train_config plus server_round), so both arms of
        # the comparison see identical training inputs.
        _, metrics = client.train(_on_fit_config_fn(server_round))
        hist_fit.append(metrics)
        # NOTE(review): reaches into the private `_evaluator` attribute of the client.
        evaluation_reward.append(
            client._evaluator.evaluate(
                client.policy,
                repeats=config["fl"]["evaluate_config"]["evaluation_repeats"],
            )
        )

    baseline_results.append((hist_fit, evaluation_reward))

# Federated

In [None]:
# Workspace directory for persisted client state (presumably written by
# florl's stateful_client — confirm); it is wiped before each federated repeat.
CONTEXT_WS = "florl_ws"

# FedAvg with weighted-average aggregation of both fit and evaluate metrics.
# To compare against FedProx, construct fl.server.strategy.FedProx instead and
# pass proximal_mu (0.1 was the value tried previously).
strategy = fl.server.strategy.FedAvg(
    on_fit_config_fn=_on_fit_config_fn,
    on_evaluate_config_fn=_on_evaluate_config_fn,
    fit_metrics_aggregation_fn=aggregate_weighted_average,
    evaluate_metrics_aggregation_fn=aggregate_weighted_average,
    # Do not tolerate failed clients; Flower's FedAvg skips aggregation for a
    # round that reports failures when this is False.
    accept_failures=False,
)

federated_results = []

for seed in tqdm(range(EXPERIMENT_REPEATS)):
    # Start every repeat from an empty client-state workspace.
    if os.path.exists(CONTEXT_WS):
        shutil.rmtree(CONTEXT_WS)

    # Clients created during this repeat, keyed by their integer seed.
    # NOTE(review): with Ray-backed simulation, client_fn may execute in worker
    # processes that each hold their own copy of this dict — confirm reuse.
    initialized_clients = {}

    @stateful_client
    def build_client(cid: str) -> fl.client.Client:
        """Return the client for Flower id ``cid``, creating it on first use.

        The seed ``int(cid) + seed * NUM_CLIENTS`` is distinct for every
        (repeat, client) pair, assuming cids range over 0..NUM_CLIENTS-1.
        """
        # Compute the lookup key *before* the membership test. Previously the
        # str cid was checked against int keys, so reuse could never trigger.
        client_seed = int(cid) + seed * NUM_CLIENTS
        if client_seed in initialized_clients:
            print(f"Client {cid} reused")
        else:
            initialized_clients[client_seed] = client_factory.create_dqn_client(
                client_seed, config=config["rl"]
            )
        return initialized_clients[client_seed]

    hist = fl.simulation.start_simulation(
        client_fn=build_client,
        client_resources={'num_cpus': 1},
        config=fl.server.ServerConfig(num_rounds=TOTAL_ROUNDS),
        num_clients=NUM_CLIENTS,
        strategy=strategy,
    )

    federated_results.append(hist)

In [None]:
def _stack_client_metric(per_run_series) -> np.ndarray:
    """Turn nested per-run/per-round/per-client metric pairs into client trajectories.

    Args:
        per_run_series: nested sequence indexed ``[run][round][client]`` whose
            leaves are 2-element pairs; element 1 is the metric value.

    Returns:
        Array of shape ``(runs * clients, rounds)`` holding element 1 of every pair,
        one row per (run, client).

    Raises:
        ValueError: if the input is empty or ragged (not a rectangular
            ``(runs, rounds, clients, 2)`` block).
    """
    stacked = np.array(per_run_series)
    if stacked.ndim != 4 or stacked.shape[-1] != 2:
        raise ValueError(
            "expected metrics shaped (runs, rounds, clients, 2), got shape "
            f"{stacked.shape}"
        )
    runs, rounds, clients, _ = stacked.shape
    # (runs, rounds, clients, 2) -> (runs, clients, rounds, 2) so that each row
    # after the reshape is one client's trajectory over rounds.
    # NOTE(review): this assumes clients appear in the same order every round;
    # if the 'all' lists are ordered by result arrival, rows mix clients.
    return stacked.transpose(0, 2, 1, 3).reshape(runs * clients, rounds, 2)[:, :, 1]


def get_federated_metrics(results: "List[History]"):
    """Extract per-client training loss and evaluation reward from Flower histories.

    Reads ``metrics_distributed_fit["loss"]`` and ``metrics_distributed["reward"]``
    of every history, taking the ``"all"`` entry of each round's aggregated
    metrics. Element 0 of each per-client pair is presumably a client
    identifier or example count from aggregate_weighted_average — TODO confirm;
    only element 1 (the value) is kept.

    Dimensions are taken from the data itself, so partial or differently sized
    experiments work without touching the module-level constants.

    Returns:
        ``(losses, rewards)``, each shaped ``(runs * clients, rounds)``.

    Raises:
        ValueError: if ``results`` is empty or the metrics are ragged.
    """
    losses = _stack_client_metric(
        [[entry[1]["all"] for entry in hist.metrics_distributed_fit["loss"]] for hist in results]
    )
    rewards = _stack_client_metric(
        [[entry[1]["all"] for entry in hist.metrics_distributed["reward"]] for hist in results]
    )
    return losses, rewards


def _plot_mean_band(ax, xs, values, label: str, color: str) -> None:
    """Draw the across-run mean of ``values`` with a shaded ±1.96·std band.

    Args:
        ax: Matplotlib axes to draw on.
        xs: x-coordinates, one per column of ``values``.
        values: 2-D array-like; rows are runs/clients, columns are rounds.
        label: Legend label for the mean line.
        color: Colour used for BOTH the mean line and its band.
    """
    values = np.asarray(values)
    mean = values.mean(axis=0)
    std = values.std(axis=0)
    ax.plot(xs, mean, color=color, label=label)
    # ±1.96·std describes the spread across runs (≈95% of runs if normally
    # distributed); it is not a 95% confidence interval of the mean.
    ax.fill_between(
        x=xs,
        y1=mean - std * 1.96,
        y2=mean + std * 1.96,
        alpha=0.2,
        color=color,
    )


def plot_losses(ax, xs, losses: List[Any], label: str, color: str="green"):
    """Plot mean training loss per round with a ±1.96·std band across rows.

    ``losses`` may be a nested list or an ndarray of shape ``(runs, rounds)``.
    """
    _plot_mean_band(ax, xs, losses, label, color)


def plot_rewards(ax, xs, rewards: List[Any], label: str, color: str="green"):
    """Plot mean evaluation reward per round with a ±1.96·std band across rows.

    The band now uses ``color`` (it was previously hard-coded green, so every
    series got a green band regardless of its line colour).
    """
    _plot_mean_band(ax, xs, rewards, label, color)

def plot_fed_results(ax_loss, ax_reward, xs, results: List[Any], label: str, color: str="green"):
    """Plot federated training loss and evaluation reward for a set of runs.

    Args:
        ax_loss: axes that receive the training-loss curve.
        ax_reward: axes that receive the evaluation-reward curve.
        xs: round numbers for the x-axis, one per recorded round.
        results: Flower ``History`` objects, one per experiment repeat.
        label: legend label used for both curves.
        color: line/band colour used for both curves.
    """
    losses, rewards = get_federated_metrics(results)
    # The previous body called plot_losses without its data arguments (a
    # TypeError) and never plotted rewards.
    plot_losses(ax_loss, xs, losses, label=label, color=color)
    plot_rewards(ax_reward, xs, rewards, label=label, color=color)


In [None]:
fig, axs = plt.subplots(1, 2)
fig.set_size_inches(12, 5)
# Name the environment from the config instead of hard-coding it: the old
# title said "CartPole" while the runs use config["rl"]["env"]["name"].
fig.suptitle(f"QTOpt/QTOptAvg on {config['rl']['env']['name']}")

# Flower numbers server rounds from 1, so the x-axis does too.
rounds = list(range(1, TOTAL_ROUNDS + 1))

# Loss & rewards: baseline arrays are (repeats, rounds); federated arrays are
# (repeats * clients, rounds), one row per client trajectory.
baseline_losses = np.array([[step_metrics['loss'] for step_metrics in ex[0]] for ex in baseline_results])
baseline_rewards = np.array([ex[1] for ex in baseline_results])
federated_losses, federated_rewards = get_federated_metrics(federated_results)

ax_losses = axs[0]
ax_losses.set_title("Training Loss")
ax_losses.set_ylabel("Average Loss")
ax_losses.set_xlabel("Round")

plot_losses(ax=ax_losses,
            xs=rounds,
            losses=baseline_losses,
            label="QTOpt Centralised",
            color="red")

plot_losses(ax=ax_losses,
            xs=rounds,
            losses=federated_losses,
            label="QTOpt FedAvg",
            color="green")

# Evaluation Reward
ax_rewards = axs[1]
ax_rewards.set_title("Evaluation Reward")
ax_rewards.set_ylabel("Average Episode Reward")
ax_rewards.set_xlabel("Round")

plot_rewards(ax=ax_rewards,
             xs=rounds,
             rewards=baseline_rewards,
             label="QTOpt Centralised",
             color="red")

plot_rewards(ax=ax_rewards,
             xs=rounds,
             rewards=federated_rewards,
             label="QTOpt FedAvg",
             color="green")

# Both panels use the same labels/colours, so one figure-level legend suffices;
# the trailing ';' suppresses the Legend object's repr in the notebook output.
handles, labels = ax_rewards.get_legend_handles_labels()
fig.legend(handles, labels, loc='upper right');