In [None]:
# No significant improvements from only training actor

import copy
import os
import shutil

import numpy as np
import torch
from omegaconf import OmegaConf, DictConfig
import seaborn as sns
from matplotlib import pyplot as plt
from tqdm import tqdm

from florl.common.util import aggregate_weighted_average, stateful_client
from florl.client.kitten.td3 import *

from strategy import RlFedAvg
from visualisation import *
from experiment_utils import *



# ---------------------------------------------------------------------------
# Experiment-wide constants
# ---------------------------------------------------------------------------
SEED = 0
NUM_CLIENTS = 5
TOTAL_ROUNDS = 150
FRAMES_PER_ROUND = 50
EXPERIMENT_REPEATS = 10
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Use seaborn's default theme for every figure in this notebook.
sns.set_theme()

# ---------------------------------------------------------------------------
# Configuration tree: "rl" is handed to the client factory, "fl" holds the
# per-round fit / evaluate settings sent by the strategy.
# ---------------------------------------------------------------------------
config = DictConfig(
    dict(
        rl=dict(
            env=dict(name="Pendulum-v1"),
            algorithm=dict(
                gamma=0.99,
                tau=0.005,
                lr=0.001,
                update_frequency=1,
                clip_grad_norm=1,
                critic=dict(features=64),
            ),
            memory=dict(
                type="experience_replay",
                # One slot per frame over all rounds, never fewer than 128.
                capacity=max(128, TOTAL_ROUNDS * FRAMES_PER_ROUND),
            ),
            train=dict(
                initial_collection_size=1024,
                minibatch_size=64,
            ),
        ),
        fl=dict(
            train_config=dict(frames=FRAMES_PER_ROUND),
            evaluate_config=dict(evaluation_repeats=1),
        ),
    )
)

# Plain-container copies of the FL sections, so the per-round config callbacks
# can extend them with ordinary dict operations.
train_config = OmegaConf.to_container(config.fl.train_config)
evaluate_config = OmegaConf.to_container(config.fl.evaluate_config)

client_factory = TD3ClientFactory(config)

In [None]:
# Directory wiped before every repeat.
# NOTE(review): presumably where `stateful_client` persists client state - confirm.
CONTEXT_WS = "florl_ws"


def _on_fit_config_fn(server_round: int):
    """Return the fit config for a round: the shared training config plus the round number."""
    return train_config | {"server_round": server_round}


def _on_evaluate_config_fn(server_round: int):
    """Return the evaluate config for a round: the shared evaluation config plus the round number."""
    return evaluate_config | {"server_round": server_round}


# One simulation history per experiment repeat (full-model FedAvg).
federated_results = []
rng = np.random.default_rng(seed=SEED)

for _ in tqdm(range(EXPERIMENT_REPEATS)):
    # Plain Python int, so the client ids derived from it are ints rather than
    # NumPy scalars.
    seed = int(rng.integers(0, 65535))

    # Start the repeat without a leftover workspace directory.
    if os.path.exists(CONTEXT_WS):
        shutil.rmtree(CONTEXT_WS)

    # Build a fresh strategy (and a fresh copy of the default knowledge) for
    # every repeat, so that no state held by the strategy can carry over from
    # one repeat into the next and the repeats stay independent.
    strategy = RlFedAvg(
        knowledge=copy.deepcopy(client_factory.create_default_knowledge(config=config["rl"])),
        on_fit_config_fn=_on_fit_config_fn,
        on_evaluate_config_fn=_on_evaluate_config_fn,
        fit_metrics_aggregation_fn=aggregate_weighted_average,
        evaluate_metrics_aggregation_fn=aggregate_weighted_average,
        evaluate_fn=get_evaluation_fn(client_factory.create_client(0, config["rl"])),
        accept_failures=False,
        inplace=False,
    )

    # Clients created so far in this repeat, keyed by seed-offset client id.
    initialized_clients = {}

    @stateful_client
    def build_client(cid: str) -> fl.client.Client:
        """Return the client for simulation id `cid`, creating it on first request.

        The id is offset by this repeat's seed before it is passed to the
        client factory, and the created client is cached under that offset id.
        """
        cid = int(cid) + seed
        if cid not in initialized_clients:
            initialized_clients[cid] = client_factory.create_client(
                cid=cid,
                config=config["rl"],
                enable_evaluation=False,
            )
        return initialized_clients[cid]

    hist = fl.simulation.start_simulation(
        client_fn=build_client,
        client_resources={'num_cpus': 1},
        config=fl.server.ServerConfig(num_rounds=TOTAL_ROUNDS),
        num_clients=NUM_CLIENTS,
        strategy=strategy,
    )

    federated_results.append(hist)

In [None]:
# Directory wiped before every repeat.
# NOTE(review): presumably where `stateful_client` persists client state - confirm.
CONTEXT_WS = "florl_ws"


def _on_fit_config_fn(server_round: int):
    """Return the fit config for a round: training config, round number and shard selector.

    NOTE(review): "shards" presumably restricts federation to the actor and
    target-actor networks - confirm against the client / strategy implementation.
    """
    return train_config | {"server_round": server_round, "shards": "actor|actor_target"}


def _on_evaluate_config_fn(server_round: int):
    """Return the evaluate config for a round: evaluation config, round number and shard selector."""
    return evaluate_config | {"server_round": server_round, "shards": "actor|actor_target"}


# One simulation history per experiment repeat (actor-only variant).
federated_results_2 = []
# Re-seeded with the same SEED, so this cell draws its own reproducible
# sequence of per-repeat seeds.
rng = np.random.default_rng(seed=SEED)

for _ in tqdm(range(EXPERIMENT_REPEATS)):
    # Plain Python int, so the client ids derived from it are ints rather than
    # NumPy scalars.
    seed = int(rng.integers(0, 65535))

    # Start the repeat without a leftover workspace directory.
    if os.path.exists(CONTEXT_WS):
        shutil.rmtree(CONTEXT_WS)

    # Build a fresh strategy (and a fresh copy of the default knowledge) for
    # every repeat, so that no state held by the strategy can carry over from
    # one repeat into the next and the repeats stay independent.
    strategy = RlFedAvg(
        knowledge=copy.deepcopy(client_factory.create_default_knowledge(config=config["rl"])),
        on_fit_config_fn=_on_fit_config_fn,
        on_evaluate_config_fn=_on_evaluate_config_fn,
        fit_metrics_aggregation_fn=aggregate_weighted_average,
        evaluate_metrics_aggregation_fn=aggregate_weighted_average,
        evaluate_fn=get_evaluation_fn(client_factory.create_client(0, config["rl"])),
        accept_failures=False,
        inplace=False,
    )

    # Clients created so far in this repeat, keyed by seed-offset client id.
    initialized_clients = {}

    @stateful_client
    def build_client(cid: str) -> fl.client.Client:
        """Return the client for simulation id `cid`, creating it on first request.

        The id is offset by this repeat's seed before it is passed to the
        client factory, and the created client is cached under that offset id.
        """
        cid = int(cid) + seed
        if cid not in initialized_clients:
            initialized_clients[cid] = client_factory.create_client(
                cid=cid,
                config=config["rl"],
                enable_evaluation=False,
            )
        return initialized_clients[cid]

    hist = fl.simulation.start_simulation(
        client_fn=build_client,
        client_resources={'num_cpus': 1},
        config=fl.server.ServerConfig(num_rounds=TOTAL_ROUNDS),
        num_clients=NUM_CLIENTS,
        strategy=strategy,
    )

    federated_results_2.append(hist)

In [None]:
# Two side-by-side panels: training loss (left) and evaluation reward (right),
# each comparing the full-model run against the actor-only run.
fig, axs = plt.subplots(nrows=1, ncols=2)
fig.set_size_inches(12, 5)
fig.suptitle("TD3Avg / TD3Avg (Actor) on Pendulum")

# x-axis values shared by every curve: one entry per federated round.
rounds = [*range(TOTAL_ROUNDS)]

# Per-experiment loss and reward series extracted from the simulation histories.
federated_losses, federated_rewards = get_federated_metrics(
    federated_results, EXPERIMENT_REPEATS, NUM_CLIENTS, TOTAL_ROUNDS,
    centralised_evaluation=True,
)
federated_losses_2, federated_rewards_2 = get_federated_metrics(
    federated_results_2, EXPERIMENT_REPEATS, NUM_CLIENTS, TOTAL_ROUNDS,
    centralised_evaluation=True,
)

# Left panel: training loss.
ax_losses = axs[0]
ax_losses.set(title="Training Loss", ylabel="Average Loss", xlabel="Round")
for curve_losses, curve_label, curve_color in (
    (federated_losses, "TD3 FedAvg", "green"),
    (federated_losses_2, "TD3 FedAvg (Actor)", "blue"),
):
    plot_losses(
        ax=ax_losses,
        xs=rounds,
        losses=curve_losses,
        label=curve_label,
        color=curve_color,
    )

# Right panel: evaluation reward.
ax_rewards = axs[1]
ax_rewards.set(title="Evaluation Reward", ylabel="Average Episode Reward", xlabel="Round")
for curve_rewards, curve_label, curve_color in (
    (federated_rewards, "TD3 FedAvg", "green"),
    (federated_rewards_2, "TD3 FedAvg (Actor)", "blue"),
):
    plot_rewards(
        ax=ax_rewards,
        xs=rounds,
        rewards=curve_rewards,
        label=curve_label,
        color=curve_color,
    )

# A single figure-level legend, built from the reward panel's entries
# (both panels use the same labels).
handles, labels = ax_rewards.get_legend_handles_labels()
fig.legend(handles, labels, loc='upper right')