In [None]:
import flwr as fl
import torch
from omegaconf import OmegaConf, DictConfig
from matplotlib import pyplot as plt
from tqdm import tqdm

from common.interface import aggregate_weighted_average
from gorila import create_dqn_client


# Client-count constant.
# NOTE(review): the simulation cell further down passes a literal num_clients
# instead of this constant -- confirm which value is intended.
NUM_CLIENTS = 10

# Torch device: CUDA when torch reports it as available, CPU otherwise.
# NOTE(review): not referenced by any of the visible cells.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Experiment configuration.
#   rl -- passed (as part of the whole config) to create_dqn_client.
#   fl -- per-round dictionaries sent to the clients by the strategy callbacks.
config = DictConfig({
    "rl": {
        "env": {"name": "CartPole-v1"},
        "algorithm": {
            # NOTE(review): "ddpg" is configured although the clients are built
            # with create_dqn_client -- confirm this is the intended type.
            "type": "ddpg",
            "gamma": 0.99,
            "tau": 0.005,
            "lr": 0.001,
            "update_frequency": 1,
            "clip_grad_norm": 1,
            "critic": {"features": 64},
        },
        "memory": {"type": "experience_replay", "capacity": 20000},
        "train": {"initial_collection_size": 512, "minibatch_size": 32},
    },
    "fl": {
        "train_config": {"frames": 100},
        "evaluate_config": {"evaluation_repeats": 1},
    },
})

# Convert the two FL sub-sections to plain containers; the strategy callbacks
# merge them with a "server_round" entry for every round.
train_config = OmegaConf.to_container(config.fl.train_config)
evaluate_config = OmegaConf.to_container(config.fl.evaluate_config)

def _on_fit_config_fn(server_round: int):
    """Build the configuration sent to clients for one fit round.

    Returns a new dict holding every entry of ``train_config`` plus a
    ``"server_round"`` entry set to ``server_round`` (the latter wins if
    ``train_config`` already contains that key).
    """
    return {**train_config, "server_round": server_round}
def _on_evaluate_config_fn(server_round: int):
    """Build the configuration sent to clients for one evaluation round.

    Returns a new dict holding every entry of ``evaluate_config`` plus a
    ``"server_round"`` entry set to ``server_round`` (the latter wins if
    ``evaluate_config`` already contains that key).
    """
    return {**evaluate_config, "server_round": server_round}

# Baseline: a single client trained locally, without federation

In [None]:
# Single client (id 0) built from the shared config; it is trained directly,
# without going through the Flower server.
client = create_dqn_client(0, config)

# Number of baseline iterations; same value as num_rounds in the federated
# simulation cell so the two curves cover the same number of rounds.
NUM_BASELINE_ROUNDS = 100

# Manually run through the training loop
hist_fit = []           # training metrics per round (the plotting cell reads 'loss')
evaluation_reward = []  # evaluator result per round
for _round in tqdm(range(NUM_BASELINE_ROUNDS)):
    # Use the per-round training settings from fl.train_config instead of a
    # duplicated literal, so a config change applies to the baseline as well.
    # A fresh dict is passed each time so train_config itself stays untouched
    # should train() modify its argument.
    metrics, _ = client.train(client.algorithm.critic.net, dict(train_config))
    hist_fit.append(metrics)
    evaluation_reward.append(client.evaluator.evaluate())



In [None]:
# Baseline results: training loss (left) and evaluation reward (right),
# one point per iteration of the manual training loop.
fig, (loss_ax, reward_ax) = plt.subplots(1, 2, figsize=(10, 5))

# TODO: Confidence bounds

# Loss
loss_values = [metrics['loss'] for metrics in hist_fit]
# Number the points from 1 so the x-axis reads as a round count.
loss_ax.plot(range(1, len(loss_values) + 1), loss_values)
loss_ax.set(title="Training Loss", ylabel="Average Loss", xlabel="Round")

# Evaluate reward
reward_ax.plot(range(1, len(evaluation_reward) + 1), evaluation_reward)
reward_ax.set(title="Evaluation Reward", ylabel="Average Reward", xlabel="Round")

# Keep the right panel's y-label from overlapping the left panel, and end the
# cell with show() so no Text(...) repr is echoed as cell output.
fig.tight_layout()
plt.show()

# Federated: FedAvg simulation across multiple clients

In [None]:
def _make_client(cid):
    """Create the Flower client for the simulated client id ``cid``."""
    return create_dqn_client(int(cid), config=config).to_client()


# FedAvg strategy: per-round configs come from the two callbacks defined with
# the config, fit and evaluate metrics are both combined with
# aggregate_weighted_average, and accept_failures is switched off.
strategy = fl.server.strategy.FedAvg(
    accept_failures=False,
    on_fit_config_fn=_on_fit_config_fn,
    on_evaluate_config_fn=_on_evaluate_config_fn,
    fit_metrics_aggregation_fn=aggregate_weighted_average,
    evaluate_metrics_aggregation_fn=aggregate_weighted_average,
)

# Run the simulation; the returned history is read by the plotting cell.
# NOTE(review): num_clients is the literal 2 while NUM_CLIENTS = 10 is defined
# in the first cell and never used -- confirm which value is intended.
hist = fl.simulation.start_simulation(
    strategy=strategy,
    num_clients=2,
    config=fl.server.ServerConfig(num_rounds=100),
    client_resources={'num_cpus': 1},
    client_fn=_make_client,
)

In [None]:
# Federated results: aggregated training loss (left) and aggregated evaluation
# reward (right), read from the simulation history.
fig, (loss_ax, reward_ax) = plt.subplots(1, 2, figsize=(10, 5))

# TODO: Confidence bounds

# Each history entry is indexed as (server round, aggregated metric); the
# aggregated metric exposes the value to plot under 'avg'. The round number is
# used as the x coordinate so the axis matches its "Round" label.

# Loss
loss_history = hist.metrics_distributed_fit['loss']
loss_ax.plot(
    [entry[0] for entry in loss_history],
    [entry[1]['avg'] for entry in loss_history],
)
loss_ax.set(title="Training Loss", ylabel="Average Loss", xlabel="Round")

# Evaluate reward
reward_history = hist.metrics_distributed['reward']
reward_ax.plot(
    [entry[0] for entry in reward_history],
    [entry[1]['avg'] for entry in reward_history],
)
reward_ax.set(title="Evaluation Reward", ylabel="Average Reward", xlabel="Round")

# Keep the right panel's y-label from overlapping the left panel, and end the
# cell with show() so no Text(...) repr is echoed as cell output.
fig.tight_layout()
plt.show()