In [None]:
import os
import pickle as pkl

import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches
from matplotlib.ticker import FuncFormatter

from torchrl.envs.libs import PettingZooWrapper
from torchrl.modules import ProbabilisticActor, TruncatedNormal, MultiAgentMLP
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor

from public_datasets_game.mechanism import (
    PrivateFunding,
    QuadraticFunding,
    AssuranceContract,
)
from public_datasets_game.rotting_bandits import (
    RottingBanditsGame,
    SlidingWindowObsWrapper,
)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device(0)
else:
    device = torch.device("cpu")

# Number of repeats stored on disk for every experiment setting.
num_experiment_repeats = 5

# Batch geometry: one collected batch is made of `num_mini_batches` minibatches.
minibatch_size, num_mini_batches = 1000, 10
frames_per_batch = minibatch_size * num_mini_batches

# Step limit handed to the environment and to evaluation rollouts.
scenario_max_steps = 100


# One entry per environment size: ((name, num_bandits, num_arms), settings),
# where every size is evaluated under the same four
# (reward_allocation, mechanism) combinations.
experiment_settings = [
    (
        ("rb", bandit_count, arm_count),
        [
            ("collaborative", "private"),
            ("individual", "private"),
            ("individual", "quadratic"),
            ("individual", "assurance"),
        ],
    )
    for bandit_count, arm_count in [(3, 1), (5, 2)]
]


def create_env(env_num_bandits, env_num_arms, env_mechanism, env_reward_allocation):
    """Build the wrapped rotting-bandits environment for one experiment setting.

    The game is wrapped in ``SlidingWindowObsWrapper`` (windows 5/25/125,
    flattened observations) and then in ``PettingZooWrapper`` placed on the
    module-level ``device``.

    Args:
        env_num_bandits: Forwarded as ``num_bandits``; also scales
            ``cost_per_play`` (0.2 per bandit).
        env_num_arms: Forwarded as ``num_arms``.
        env_mechanism: One of ``"private"``, ``"quadratic"`` or ``"assurance"``.
        env_reward_allocation: Forwarded as ``reward_allocation``.

    Returns:
        The ``PettingZooWrapper`` instance around the windowed game.

    Raises:
        ValueError: If ``env_mechanism`` is not one of the supported names.
    """
    # Validate first so a typo fails with a clear message instead of an
    # UnboundLocalError on `mechanism` further down.
    if env_mechanism == "private":
        mechanism = PrivateFunding()
    elif env_mechanism == "quadratic":
        mechanism = QuadraticFunding()
    elif env_mechanism == "assurance":
        mechanism = AssuranceContract()
    else:
        raise ValueError(
            f"Unknown mechanism {env_mechanism!r}; "
            "expected 'private', 'quadratic' or 'assurance'."
        )

    game = RottingBanditsGame(
        num_bandits=env_num_bandits,
        num_arms=env_num_arms,
        mechanism=mechanism,
        max_steps=scenario_max_steps,
        cost_per_play=0.2 * env_num_bandits,
        infinite_horizon=True,
        reward_allocation=env_reward_allocation,
        deficit_resolution="tax",
        normalise_action_space=False,
        randomise_on_reset=True,
        return_funds_info=True,
    )
    windowed_game = SlidingWindowObsWrapper(
        env=game,
        window_sizes=[5, 25, 125],
        flatten_obs=True,
    )
    return PettingZooWrapper(windowed_game, device=device)


def create_policy(ref_env, num_agents=None, num_arms=None):
    """Build the stochastic multi-agent policy matching ``ref_env``.

    The network is a decentralised, parameter-shared ``MultiAgentMLP`` whose
    output is split into ``loc``/``scale`` by ``NormalParamExtractor``; actions
    are sampled from a ``TruncatedNormal`` bounded by 0 and the action spec's
    upper bound.

    Args:
        ref_env: Environment providing ``num_windows``, ``action_spec`` and
            ``action_keys``.
        num_agents: Number of agents for the MLP. When ``None`` the
            module-level ``env_num_bandits`` is used (the original behaviour).
        num_arms: Number of arms; the per-agent input size is
            ``ref_env.num_windows * num_arms``. When ``None`` the module-level
            ``env_num_arms`` is used (the original behaviour).

    Returns:
        A ``ProbabilisticActor`` that writes the action to
        ``ref_env.action_keys`` and the log-probability to
        ``("agent", "sample_log_prob")``.
    """
    # Fall back to the notebook globals only when the caller does not say
    # which environment size the policy is for.
    if num_agents is None:
        num_agents = env_num_bandits
    if num_arms is None:
        num_arms = env_num_arms

    action_spec = ref_env.action_spec

    policy_module = TensorDictModule(
        module=torch.nn.Sequential(
            MultiAgentMLP(
                n_agent_inputs=ref_env.num_windows * num_arms,
                # Twice the action size: one half for loc, one half for scale.
                n_agent_outputs=action_spec.shape[-1] * 2,
                n_agents=num_agents,
                centralized=False,
                share_params=True,
                device=device,
                depth=2,
                num_cells=128,
                activation_class=nn.Tanh,
            ),
            NormalParamExtractor(),
        ),
        # A list holding one nested key, so it cannot be read as two
        # separate top-level keys.
        in_keys=[("agent", "observation")],
        out_keys=[("agent", "loc"), ("agent", "scale")],
    )
    policy = ProbabilisticActor(
        module=policy_module,
        spec=action_spec,
        in_keys=[("agent", "loc"), ("agent", "scale")],
        distribution_class=TruncatedNormal,
        distribution_kwargs={
            "low": 0.0,
            "high": action_spec.space.high,
        },
        out_keys=ref_env.action_keys,
        return_log_prob=True,
        log_prob_key=("agent", "sample_log_prob"),
    )
    return policy


# Results keyed by
# (env_name, num_bandits, num_arms, reward_allocation, mechanism, repeat).
experiments = {}

for (env_name, env_num_bandits, env_num_arms), settings in experiment_settings:
    for env_reward_allocation, env_mechanism in settings:
        for experiment_repeat in range(num_experiment_repeats):
            output_dir = f"data/rb_{env_num_bandits}_{env_num_arms}_{env_reward_allocation}_{env_mechanism}_{experiment_repeat}"

            # Report runs that are missing on disk and move on.
            if not os.path.exists(output_dir):
                print(f"{output_dir} does not exist.")
                continue

            train_path = os.path.join(output_dir, "train_episode_reward_mean_list")
            with open(train_path, "rb") as fp:
                train_rewards = pkl.load(fp)

            eval_path = os.path.join(output_dir, "eval_episode_reward_mean_list")
            with open(eval_path, "rb") as fp:
                eval_rewards = pkl.load(fp)

            policy_state = torch.load(
                os.path.join(output_dir, "policy"), map_location=device
            )

            experiment_key = (
                env_name,
                env_num_bandits,
                env_num_arms,
                env_reward_allocation,
                env_mechanism,
                experiment_repeat,
            )
            experiments[experiment_key] = dict(
                train_rewards=train_rewards,
                eval_rewards=eval_rewards,
                policy_state=policy_state,
            )

# Show which experiment keys were found on disk.
print("Loaded experiments:", list(experiments))


In [None]:
# Per-experiment mean funding over 100 seeded evaluation rollouts, cached on
# disk. Delete the cache file to force a re-evaluation.
funding_cache_path = "data/funding_amt.pkl"

if os.path.exists(funding_cache_path):
    with open(funding_cache_path, "rb") as fp:
        funding_amt = pkl.load(fp)
else:
    funding_amt = {}
    for key, exp in experiments.items():
        # create_policy(env) reads the module-level env_num_bandits and
        # env_num_arms, so the key must be unpacked into exactly these names.
        (
            env_name,
            env_num_bandits,
            env_num_arms,
            env_reward_allocation,
            env_mechanism,
            repeat,
        ) = key
        # Create the environment for the experiment
        env = create_env(
            env_num_bandits, env_num_arms, env_mechanism, env_reward_allocation
        )
        # Create the policy using the created environment
        policy = create_policy(env)
        # Load the saved policy state
        policy.load_state_dict(exp["policy_state"])
        print(f"Loaded environment and policy for experiment: {key}")

        funding = []
        # Evaluation only: skip autograd bookkeeping during the rollouts.
        with torch.no_grad():
            for seed in range(100):
                # NOTE(review): rollout(auto_reset=True) presumably resets the
                # env again without a seed - confirm the seeding still has the
                # intended effect.
                env.reset(seed=seed)
                td = env.rollout(
                    max_steps=scenario_max_steps, policy=policy, auto_reset=True
                )
                info = td.get(("next", "agent", "info"))

                funding.append(info["funding"].mean().item())
        funding_amt[key] = funding

    with open(funding_cache_path, "wb") as fp:
        pkl.dump(funding_amt, fp)


In [None]:
def moving_average(data, window_size=5):
    """Return the simple moving average of a 1-D sequence.

    Only fully covered windows are averaged, so the result has
    ``len(data) - window_size + 1`` entries.

    Args:
        data: 1-D sequence of numbers.
        window_size: Number of consecutive samples per average (>= 1).

    Returns:
        A NumPy array of window means; empty when ``data`` has fewer than
        ``window_size`` samples.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    values = np.asarray(data)
    # np.convolve(mode="valid") swaps its operands when the kernel is longer
    # than the data and then returns values that are not window means (and it
    # raises on empty input), so short inputs are handled explicitly.
    if values.size < window_size:
        return np.empty(0, dtype=float)

    return np.convolve(values, np.ones(window_size) / window_size, mode="valid")


# Request constrained layout at creation time rather than through the
# discouraged Figure.set_constrained_layout setter.
fig, all_axs = plt.subplots(2, 2, figsize=(10, 6), constrained_layout=True)

fig.suptitle("MAPPO Training on Rotting Bandits")

axs_train = all_axs[0]  # First row for training rewards
axs_funding = all_axs[1]  # Second row for funding box plots

# Width of the moving average applied to the training curves.
smoothing_window = 5

# --- TRAINING REWARDS (Top Row) ---
for i, ((env_name, env_num_bandits, env_num_arms), settings) in enumerate(
    experiment_settings
):
    for env_reward_allocation, env_mechanism in settings:
        # Collect the training curve of every repeat that was found on disk.
        setting_train_rewards = []
        for repeat in range(num_experiment_repeats):
            key = (
                env_name,
                env_num_bandits,
                env_num_arms,
                env_reward_allocation,
                env_mechanism,
                repeat,
            )
            if key in experiments:
                setting_train_rewards.append(experiments[key]["train_rewards"])

        if not setting_train_rewards:
            continue

        # Repeats may differ in length; cut them to the shortest one so they
        # stack into a rectangular array. This is a no-op when all lengths
        # already agree.
        common_len = min(len(rewards) for rewards in setting_train_rewards)
        rewards_array = np.array(
            [rewards[:common_len] for rewards in setting_train_rewards]
        )
        mean_rewards = np.mean(rewards_array, axis=0)
        std_rewards = np.std(rewards_array, axis=0)

        # Smooth the curves using a moving average
        ma_mean = moving_average(mean_rewards, smoothing_window)
        ma_std = moving_average(std_rewards, smoothing_window)

        # Shift each smoothed point by half a window (2 batches for a window
        # of 5), as the original hard-coded offset did.
        x_offset = (smoothing_window - 1) // 2 * frames_per_batch
        x_smoothed = np.arange(len(ma_mean)) * frames_per_batch + x_offset

        ax = axs_train[i]
        ax.plot(
            x_smoothed, ma_mean, label=f"{env_reward_allocation}-{env_mechanism}"
        )
        ax.fill_between(x_smoothed, ma_mean - ma_std, ma_mean + ma_std, alpha=0.3)


def million_formatter(x, pos):
    """Format an axis tick value, abbreviating values of one million or more.

    Args:
        x: Tick value.
        pos: Tick position passed by ``FuncFormatter``; unused.

    Returns:
        ``"<n> million"`` for values >= 1e6, otherwise the integer part of
        ``x`` as a string.
    """
    if x >= 1e6:
        # ":g" keeps fractional millions (1.5e6 -> "1.5 million"); ":.0f"
        # rounded them, so neighbouring ticks could show the same label.
        return f"{x / 1e6:g} million"
    return str(int(x))


# Label the training-curve axes (top row).
for ax, exp in zip(axs_train, experiment_settings):
    ax.set_xlabel("Collected Frames")
    ax.set_ylabel("Mean Episode Reward")
    ax.set_title(f"Training Curve ({exp[0][1]} consumer | {exp[0][2]} producer)")
    ax.xaxis.set_major_formatter(FuncFormatter(million_formatter))

# --- FUNDING BOX PLOTS (Bottom Row) ---

color_map = list(mcolors.TABLEAU_COLORS.values())

for i, ((env_name, env_num_bandits, env_num_arms), settings) in enumerate(
    experiment_settings
):
    # One box per (setting, repeat); boxes of the same setting share a colour.
    funding_data = []
    colors = []

    for j, (env_reward_allocation, env_mechanism) in enumerate(settings):
        color = color_map[j]

        for repeat in range(num_experiment_repeats):
            key = (
                env_name,
                env_num_bandits,
                env_num_arms,
                env_reward_allocation,
                env_mechanism,
                repeat,
            )
            if key in funding_amt:
                funding_data.append(funding_amt[key])
                colors.append(color)

    if funding_data:
        # Vertical boxes are the default orientation.
        box = axs_funding[i].boxplot(funding_data, patch_artist=True)
        for patch, c in zip(box["boxes"], colors):
            patch.set_facecolor(c)
        axs_funding[i].set_ylabel("Funding Distribution")
        axs_funding[i].set_title(
            f"Funding Distributions ({env_num_bandits} consumer | {env_num_arms} producer)"
        )


# Build the legend from an explicit settings list (the last experiment group)
# instead of relying on the `settings` variable left over from the loop above.
legend_settings = experiment_settings[-1][1]
legend_patches = [
    mpatches.Patch(
        color=color_map[j], label=f"{env_reward_allocation}, {env_mechanism}"
    )
    for j, (env_reward_allocation, env_mechanism) in enumerate(legend_settings)
]


fig.legend(
    handles=legend_patches,
    title="Reward Allocation & Mechanism",
    loc="lower center",
    bbox_to_anchor=(0.5, -0.1),  # Centers below the subplots
    ncol=len(legend_patches),  # Spread legend items in a single row
)

# Save the figure before showing it. The legend is anchored below the figure
# area, so a tight bounding box is needed to keep it in the saved file.
fig.savefig("data/basic_training.svg", bbox_inches="tight")

plt.show()


In [None]:
# Scaling study: 3, 9 and 27 bandits with a single arm, each under the same
# four (reward_allocation, mechanism) combinations.
experiment_settings = [
    (
        ("rb", bandit_count, 1),
        [
            ("collaborative", "private"),
            ("individual", "private"),
            ("individual", "quadratic"),
            ("individual", "assurance"),
        ],
    )
    for bandit_count in (3, 9, 27)
]

# Reload results for the scaling settings, keyed by
# (env_name, num_bandits, num_arms, reward_allocation, mechanism, repeat).
experiments = {}

for (env_name, env_num_bandits, env_num_arms), settings in experiment_settings:
    for env_reward_allocation, env_mechanism in settings:
        for experiment_repeat in range(num_experiment_repeats):
            output_dir = f"data/rb_{env_num_bandits}_{env_num_arms}_{env_reward_allocation}_{env_mechanism}_{experiment_repeat}"

            # Report runs that are missing on disk and move on.
            if not os.path.exists(output_dir):
                print(f"{output_dir} does not exist.")
                continue

            train_path = os.path.join(output_dir, "train_episode_reward_mean_list")
            with open(train_path, "rb") as fp:
                train_rewards = pkl.load(fp)

            eval_path = os.path.join(output_dir, "eval_episode_reward_mean_list")
            with open(eval_path, "rb") as fp:
                eval_rewards = pkl.load(fp)

            policy_state = torch.load(
                os.path.join(output_dir, "policy"), map_location=device
            )

            experiment_key = (
                env_name,
                env_num_bandits,
                env_num_arms,
                env_reward_allocation,
                env_mechanism,
                experiment_repeat,
            )
            experiments[experiment_key] = dict(
                train_rewards=train_rewards,
                eval_rewards=eval_rewards,
                policy_state=policy_state,
            )

# Display order of the "<reward_allocation>-<mechanism>" labels. Defined once,
# before the loop, so it exists even when no experiments were loaded.
label_order = [
    "collaborative-private",
    "individual-private",
    "individual-quadratic",
    "individual-assurance",
]

# Extract final rewards, averaging over last 5 iterations.
# Layout: final_rewards[num_bandits][label] -> list with one value per repeat.
final_rewards = {}

for key, data in experiments.items():
    env_name, env_num_bandits, env_num_arms, env_reward_allocation, env_mechanism, _ = (
        key
    )
    last_5_avg = np.mean(data["train_rewards"][-5:])  # Average of last 5 iterations
    label = f"{env_reward_allocation}-{env_mechanism}"

    # setdefault creates the nested containers the first time they are needed.
    final_rewards.setdefault(env_num_bandits, {}).setdefault(label, []).append(
        last_5_avg
    )

# Reduce the repeats of every (num_bandits, label) pair to a (mean, std) tuple.
# The 27-bandit group gets a std of 0 instead of a computed one.
bar_data = {}
for num_bandits, settings in final_rewards.items():
    skip_std = num_bandits == 27
    bar_data[num_bandits] = {
        label: (
            np.mean(settings[label]),
            0 if skip_std else np.std(settings[label]),
        )
        for label in label_order
        if label in settings
    }


# Plotting grouped bar chart: one group per bandit count, one bar per label.
group_sizes = sorted(bar_data.keys())
num_groups = len(bar_data)
bar_labels = label_order
num_bars = len(bar_labels)
x = np.arange(num_groups)  # Positions for groups
bar_width = 0.15

fig, ax = plt.subplots(figsize=(10, 6))
colors = plt.cm.tab10.colors  # Use categorical color map

# Placeholder for a label with no data in a group: the bar is then simply not
# drawn instead of the whole cell failing with a KeyError.
missing_stats = (np.nan, np.nan)

for i, label in enumerate(bar_labels):
    stats = [bar_data[num].get(label, missing_stats) for num in group_sizes]
    means = [mean_val for mean_val, _ in stats]

    # Skip std for 27 bandits in error bars. np.nan is used because the
    # np.NaN alias was removed in NumPy 2.0.
    yerr = [
        np.nan if num == 27 else std_val
        for (_, std_val), num in zip(stats, group_sizes)
    ]

    ax.bar(
        x + i * bar_width,
        means,
        yerr=yerr,
        width=bar_width,
        label=label,
        color=colors[i % len(colors)],
        capsize=5,
    )

# Centre each tick under its group of bars.
ax.set_xticks(x + (num_bars - 1) * bar_width / 2)
ax.set_xticklabels(group_sizes)
ax.set_xlabel("Number of Agents")
ax.set_ylabel("Trained Episode Reward")
ax.set_yscale("log")
ax.set_title("Scaling Number of Consumers (1 Producer)")
ax.legend(title="Reward Allocation - Mechanism")

# Save before showing so the file is written whatever the backend does to the
# figure on show().
fig.savefig("data/scaling.svg")

plt.show()