In [None]:
import os
from typing import Literal
import pickle as pkl

from tqdm import tqdm
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.objectives import ClipPPOLoss, ValueEstimators
from torchrl.envs import TransformedEnv, RewardSum, ParallelEnv
from torchrl.envs.libs import PettingZooWrapper
from torchrl.envs.utils import check_env_specs, set_exploration_type, ExplorationType
from torchrl.modules import ProbabilisticActor, TruncatedNormal, MultiAgentMLP
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage

from public_datasets_game.rotting_bandits import (
    RottingBanditsGame,
    SlidingWindowObsWrapper,
)
from public_datasets_game.mechanism import (
    PrivateFunding,
    QuadraticFunding,
    AssuranceContract,
)


# --- Hardware -----------------------------------------------------------------
# First CUDA device when available, otherwise CPU.
device = torch.device(0 if torch.cuda.is_available() else "cpu")

# --- Scenario -----------------------------------------------------------------
scenario_max_steps = 100  # forwarded to RottingBanditsGame(max_steps=...)
env_num_bandits = 9  # also the agent count given to the MultiAgentMLP networks
env_num_arms = 1

# --- PPO schedule -------------------------------------------------------------
minibatch_size = 1000
num_mini_batches = 10
num_iters = 100  # collector iterations per experiment
num_epochs = 4  # optimisation passes over each collected batch
frames_per_batch = minibatch_size * num_mini_batches
total_frames = num_iters * frames_per_batch
# Parallel training workers: at most 10, and no more than the number of
# max-length episodes that fit into one batch.
num_train_envs = min(10, frames_per_batch // scenario_max_steps)
num_experiment_repeats = 5

# --- Evaluation ---------------------------------------------------------------
eval_interval = 10  # evaluation frequency used by the training loop
eval_num_iters = 5  # number of parallel eval envs, each reset with its own seed


def create_env(
    env_mechanism, env_reward_allocation, type: Literal["ref", "train", "eval"] = "eval"
):
    """Build a rotting-bandits environment for the given funding mechanism.

    Each underlying game is a ``RottingBanditsGame`` with ``env_num_bandits``
    bandits and ``env_num_arms`` arms, observed through a
    ``SlidingWindowObsWrapper`` (windows 5/25/125, flattened) and exposed to
    TorchRL through ``PettingZooWrapper``.

    Args:
        env_mechanism: Funding mechanism name: ``"private"``
            (``PrivateFunding``), ``"quadratic"`` (``QuadraticFunding``) or
            ``"assurance"`` (``AssuranceContract``).
        env_reward_allocation: Forwarded unchanged as ``reward_allocation``.
        type: Environment flavour (the parameter name shadows the builtin but
            is kept so existing keyword calls keep working):

            * ``"ref"``: a single environment on ``device`` with no
              ``RewardSum`` transform.
            * ``"train"``: a ``ParallelEnv`` with ``num_train_envs`` workers.
              Each worker builds its env on CPU with a ``RewardSum`` that
              writes ``("agent", "episode_reward")``. The batched env is
              created with ``device=device``.
            * ``"eval"``: like ``"train"`` but with ``eval_num_iters`` workers.

    Returns:
        The constructed TorchRL environment.

    Raises:
        ValueError: If ``env_mechanism`` or ``type`` is not one of the names
            above. Both are checked before any environment is constructed.
    """
    valid_mechanisms = ("private", "quadratic", "assurance")
    if env_mechanism not in valid_mechanisms:
        raise ValueError(
            f"Unknown env_mechanism {env_mechanism!r}; expected one of {valid_mechanisms}"
        )
    valid_types = ("ref", "train", "eval")
    if type not in valid_types:
        raise ValueError(f"Unknown type {type!r}; expected one of {valid_types}")

    def _create(device):
        # The mechanism is instantiated inside the factory, so every call
        # (e.g. each ParallelEnv worker) gets its own mechanism object.
        if env_mechanism == "private":
            mechanism = PrivateFunding()
        elif env_mechanism == "quadratic":
            mechanism = QuadraticFunding()
        else:  # "assurance" -- validated above
            mechanism = AssuranceContract()

        env = SlidingWindowObsWrapper(
            env=RottingBanditsGame(
                num_bandits=env_num_bandits,
                num_arms=env_num_arms,
                mechanism=mechanism,
                max_steps=scenario_max_steps,
                cost_per_play=0.6,
                infinite_horizon=True,
                reward_allocation=env_reward_allocation,
                deficit_resolution="tax",
                normalise_action_space=False,
                randomise_on_reset=True,
                return_funds_info=False,
            ),
            window_sizes=[5, 25, 125],
            flatten_obs=True,
        )
        env = PettingZooWrapper(env, device=device)

        if type in ("train", "eval"):
            # Accumulate the per-episode return so the training loop can log it.
            env = TransformedEnv(
                env,
                RewardSum(
                    in_keys=[env.reward_key], out_keys=[("agent", "episode_reward")]
                ),
            )

        return env

    if type == "ref":
        return _create(device)

    # "train" and "eval" differ only in the number of parallel workers.
    num_workers = num_train_envs if type == "train" else eval_num_iters
    return ParallelEnv(
        num_workers=num_workers,
        create_env_fn=lambda: _create("cpu"),
        device=device,
    )


In [None]:
def _mean_done_episode_reward(td):
    """Return the mean accumulated episode reward over steps where an episode ended.

    ``("next", "agent", "done")`` is used as a boolean mask over
    ``("next", "agent", "episode_reward")``, so only completed-episode returns
    are averaged. If no episode ended inside ``td``, the mask selects nothing
    and the mean of the empty selection is returned (NaN in torch).
    """
    done = td.get(("next", "agent", "done"))
    return td.get(("next", "agent", "episode_reward"))[done].mean().item()


def _evaluate_policy(policy, eval_env):
    """Roll ``policy`` out deterministically on every ``eval_env`` worker.

    Worker ``i`` is reset with seed ``i``, so every evaluation starts from the
    same states. The rollout is capped at ``scenario_max_steps`` steps.
    Returns the mean return of the episodes that finished (see
    ``_mean_done_episode_reward``).
    """
    # no_grad: evaluation must not build an autograd graph through the policy.
    with torch.no_grad(), set_exploration_type(ExplorationType.DETERMINISTIC):
        reset_td = eval_env.reset(
            list_of_kwargs=[{"seed": i} for i in range(eval_num_iters)]
        )
        eval_td = eval_env.rollout(
            max_steps=scenario_max_steps,
            policy=policy,
            auto_cast_to_device=True,
            auto_reset=False,
            tensordict=reset_td,
        )
    return _mean_done_episode_reward(eval_td)


for env_reward_allocation, env_mechanism in [
    ("collaborative", "private"),
    ("individual", "private"),
    ("individual", "quadratic"),
    # ("individual", "assurance"),
]:
    for experiment_repeat in range(num_experiment_repeats):
        output_dir = f"data/rb_{env_num_bandits}_{env_num_arms}_{env_reward_allocation}_{env_mechanism}_{experiment_repeat}"

        # Results are only written after a run completes, so an existing
        # directory means this (configuration, repeat) pair is already done.
        if os.path.exists(output_dir):
            print(f"Skipping {output_dir}: already exists")
            continue

        ref_env = create_env(env_mechanism, env_reward_allocation, type="ref")
        train_env = create_env(env_mechanism, env_reward_allocation, type="train")
        eval_env = create_env(env_mechanism, env_reward_allocation, type="eval")
        collector = None

        try:
            check_env_specs(ref_env)

            # Actor: per-agent MLP (shared parameters, decentralised inputs)
            # producing loc/scale for a TruncatedNormal action distribution.
            policy_module = TensorDictModule(
                module=torch.nn.Sequential(
                    MultiAgentMLP(
                        n_agent_inputs=ref_env.num_windows * env_num_arms,
                        n_agent_outputs=ref_env.action_spec.shape[-1] * 2,
                        n_agents=env_num_bandits,
                        centralized=False,
                        share_params=True,
                        device=device,
                        depth=2,
                        num_cells=128,
                        activation_class=nn.Tanh,
                    ),
                    NormalParamExtractor(),
                ),
                in_keys=("agent", "observation"),
                out_keys=[("agent", "loc"), ("agent", "scale")],
            )
            policy = ProbabilisticActor(
                module=policy_module,
                spec=ref_env.action_spec,
                in_keys=[("agent", "loc"), ("agent", "scale")],
                distribution_class=TruncatedNormal,
                distribution_kwargs={
                    "low": 0.0,
                    "high": ref_env.action_spec.space.high,
                },
                out_keys=ref_env.action_keys,
                return_log_prob=True,
                log_prob_key=("agent", "sample_log_prob"),
            )
            # Critic: centralized=True MultiAgentMLP with shared parameters.
            value = TensorDictModule(
                MultiAgentMLP(
                    n_agent_inputs=ref_env.num_windows * env_num_arms,
                    n_agent_outputs=1,
                    n_agents=env_num_bandits,
                    centralized=True,
                    share_params=True,
                    device=device,
                    depth=2,
                    num_cells=128,
                    activation_class=nn.Tanh,
                ),
                in_keys=("agent", "observation"),
                out_keys=[("agent", "state_value")],
            )

            policy.to(device)
            value.to(device)

            # Check / Initialise: one forward pass on a reset state, printing
            # the action and value shapes.
            td = ref_env.reset()
            with torch.no_grad():
                print(policy(td)[("agent", "action")].shape)
                print(value(td)[("agent", "state_value")].shape)

            collector = SyncDataCollector(
                train_env,
                policy,
                device=device,
                storing_device=device,
                frames_per_batch=frames_per_batch,
                total_frames=total_frames,
                exploration_type=ExplorationType.RANDOM,
            )

            # Holds exactly one collected batch; sampled without replacement so
            # each epoch is one pass over the batch in minibatches.
            replay_buffer = ReplayBuffer(
                storage=LazyTensorStorage(frames_per_batch, device=device),
                sampler=SamplerWithoutReplacement(),
                batch_size=minibatch_size,
            )

            # `normalize_advantage` is ClipPPOLoss's keyword for this option.
            loss_module = ClipPPOLoss(
                actor_network=policy,
                critic_network=value,
                normalize_advantage=False,
            )
            loss_module.set_keys(
                reward=ref_env.reward_key,
                action=ref_env.action_key,
                sample_log_prob=("agent", "sample_log_prob"),
                value=("agent", "state_value"),
                done=("agent", "done"),
                terminated=("agent", "terminated"),
            )
            loss_module.make_value_estimator(
                ValueEstimators.GAE, gamma=0.99, lmbda=0.95
            )
            loss_module.to(device=device)
            optim = torch.optim.Adam(loss_module.parameters(), 3e-4)

            # One entry per collector iteration.
            train_episode_reward_mean_list = []
            # One entry per evaluation, i.e. at collector iterations
            # 0, eval_interval, 2 * eval_interval, ...
            eval_episode_reward_mean_list = []

            with tqdm(total=num_iters, desc="episode_reward_mean = 0") as pbar:
                for iteration, sampling_td in enumerate(collector):
                    # Compute GAE targets once per batch, before any update.
                    with torch.no_grad():
                        loss_module.value_estimator(
                            sampling_td,
                            params=loss_module.critic_network_params,
                            target_params=loss_module.target_critic_network_params,
                        )
                    replay_buffer.extend(sampling_td.reshape(-1))

                    for _ in range(num_epochs):
                        for _ in range(frames_per_batch // minibatch_size):
                            minibatch: TensorDict = replay_buffer.sample()
                            loss_vals = loss_module(minibatch)

                            loss_value = (
                                loss_vals["loss_objective"]
                                + loss_vals["loss_critic"]
                                + loss_vals["loss_entropy"]
                            )

                            loss_value.backward()
                            torch.nn.utils.clip_grad_norm_(
                                loss_module.parameters(), 1.0
                            )
                            optim.step()
                            optim.zero_grad()

                    collector.update_policy_weights_()

                    # Evaluate only after all PPO epochs of this iteration have
                    # run. `eval_interval` counts collector iterations, not
                    # epochs. Iteration 0 always evaluates, so the list below is
                    # never empty when it is read for logging.
                    if iteration % eval_interval == 0:
                        eval_episode_reward_mean_list.append(
                            _evaluate_policy(policy, eval_env)
                        )

                    # Logging
                    train_reward_mean = _mean_done_episode_reward(sampling_td)
                    train_episode_reward_mean_list.append(train_reward_mean)
                    pbar.set_description(
                        f"train_reward_mean = {train_reward_mean}, eval_reward_mean = {eval_episode_reward_mean_list[-1]}",
                        refresh=False,
                    )
                    pbar.update()
        finally:
            # ParallelEnv runs worker processes; release them so successive
            # experiments do not accumulate workers. SyncDataCollector.shutdown()
            # also closes the env it wraps (train_env), hence the is_closed guard.
            if collector is not None:
                collector.shutdown()
            for batched_env in (train_env, eval_env):
                if not batched_env.is_closed:
                    batched_env.close()

        os.makedirs(output_dir)
        with open(
            os.path.join(output_dir, "train_episode_reward_mean_list"), "wb"
        ) as fp:
            pkl.dump(train_episode_reward_mean_list, fp)
        with open(
            os.path.join(output_dir, "eval_episode_reward_mean_list"), "wb"
        ) as fp:
            pkl.dump(eval_episode_reward_mean_list, fp)
        torch.save(policy.state_dict(), os.path.join(output_dir, "policy"))