In [None]:
from typing import Literal

from tqdm import tqdm
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.objectives import ClipPPOLoss, ValueEstimators
from torchrl.envs import TransformedEnv, RewardSum, ParallelEnv
from torchrl.envs.libs import PettingZooWrapper
from torchrl.envs.utils import check_env_specs, set_exploration_type, ExplorationType
from torchrl.modules import ProbabilisticActor, TruncatedNormal
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
import matplotlib.pyplot as plt

from public_datasets_game.rotting_bandits import (
    RottingBanditsGame,
    SlidingWindowObsWrapper,
)
from public_datasets_game.mechanism import PrivateFunding, QuadraticFundng


device = torch.device(0) if torch.cuda.is_available() else torch.device("cpu")

# Seed torch's global RNG (network initialisation, action sampling in this
# process). NOTE(review): env workers created by ParallelEnv are not explicitly
# seeded here, so training rollouts are only partially reproducible.
seed = 0
torch.manual_seed(seed)

# Data collection / PPO schedule
scenario_max_steps = 100  # passed to RottingBanditsGame as max_steps
minibatch_size = 1000
num_mini_batches = 10
num_iters = 100  # number of collector batches
num_epochs = 4  # PPO passes over each collected batch
frames_per_batch = num_mini_batches * minibatch_size
total_frames = frames_per_batch * num_iters
# At most 10 workers and at most one per `scenario_max_steps` frames of a batch,
# but never 0 (a ParallelEnv needs at least one worker).
num_train_envs = max(1, min(10, frames_per_batch // scenario_max_steps))

# Evaluation
eval_interval = 10  # evaluation period used by the training loop
eval_num_iters = 5  # number of eval envs (one seeded rollout each per evaluation)

# Game setup
env_num_bandits = 3
env_num_arms = 1
env_reward_allocation = "individual"  # alternative: "collaborative"
env_mechanism = "private"  # alternative: "quadratic"
assert env_mechanism in ("private", "quadratic"), (
    f"unknown env_mechanism: {env_mechanism!r}"
)

In [None]:
# Smoke test of the raw game (before torchrl wrapping): build it with the
# configured mechanism, reset it, and take one joint action sampled from each
# agent's action space. The step result is the cell output.
if env_mechanism == "private":
    mechanism = PrivateFunding()
elif env_mechanism == "quadratic":
    mechanism = QuadraticFundng()
else:
    raise ValueError(f"unknown env_mechanism: {env_mechanism!r}")

env = SlidingWindowObsWrapper(
    env=RottingBanditsGame(
        num_bandits=env_num_bandits,
        num_arms=env_num_arms,
        mechanism=mechanism,
        max_steps=scenario_max_steps,
        cost_per_play=0.5,
        infinite_horizon=True,
        reward_allocation=env_reward_allocation,
        deficit_resolution="tax",
    ),
    window_sizes=[10, 50, 250],
)

env.reset()

env.step({agent: env.action_space(agent).sample() for agent in env.agents})

In [None]:
class DeepSetsLayer(nn.Module):
    """One permutation-equivariant DeepSets layer.

    Each set element is first mapped through a small MLP (`transform`). The
    transformed elements are summed over the set dimension (dim -2). Each
    element is then updated from the concatenation of its own embedding and
    that set-wide sum (`update`). The last dimension stays `emb_dim`, and all
    leading dimensions are preserved.
    """

    def __init__(self, emb_dim: int):
        super().__init__()
        # Creation order fixes parameter initialisation under a given seed and
        # the state_dict keys ("transform.*" before "update.*").
        self.transform = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.LeakyReLU(),
            nn.Linear(emb_dim, emb_dim),
        )
        self.update = nn.Sequential(
            nn.Linear(2 * emb_dim, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, emb_dim),
        )

    def forward(self, x: torch.Tensor):
        """Map (..., set_size, emb_dim) -> (..., set_size, emb_dim)."""
        per_element = self.transform(x)
        # Set-wide summary, broadcast back onto every element of the set.
        pooled = per_element.sum(dim=-2, keepdim=True).expand_as(per_element)
        return self.update(torch.cat((per_element, pooled), dim=-1))


class DeepSetsActor(nn.Module):
    """DeepSets policy head producing two non-negative parameters per set element.

    Each element's `num_windows` features are embedded to `emb_dim` and passed
    through two DeepSetsLayer blocks. The result is projected to two numbers,
    and softplus makes both non-negative.

    The output lists every element's first parameter, then every element's
    second parameter, along the last dimension:
    (..., set_size, num_windows) -> (..., 2 * set_size). This is the
    [loc | scale] layout that the NormalParamExtractor placed after this
    module in the policy splits in half.

    NOTE(review): the scale half is softplus'd here and then mapped again by
    NormalParamExtractor's scale mapping. Confirm this double mapping is
    intended.
    """

    def __init__(self, num_windows: int, emb_dim: int = 32):
        super().__init__()
        # Same creation order as before: lin_in, two DeepSetsLayer, lin_out.
        self.lin_in = nn.Linear(num_windows, emb_dim)
        self.deepsets = nn.Sequential(*(DeepSetsLayer(emb_dim) for _ in range(2)))
        self.lin_out = nn.Linear(emb_dim, 2)
        self.softplus = nn.Softplus()

    def forward(self, x: torch.Tensor):
        hidden = self.deepsets(self.lin_in(x))
        params = self.softplus(self.lin_out(hidden))  # (..., set_size, 2), non-negative
        # unbind(-1) yields (loc, scale), each (..., set_size); join them end to end.
        return torch.cat(params.unbind(dim=-1), dim=-1)


class DeepSetsValue(nn.Module):
    """DeepSets critic producing one scalar per set.

    Each element's `num_windows` features are embedded and passed through two
    DeepSetsLayer blocks. The element embeddings are mean-pooled over the set
    dimension (dim -2), and a small MLP maps the pooled embedding to a single
    value: (..., set_size, num_windows) -> (..., 1).
    """

    def __init__(self, num_windows: int, emb_dim: int = 32):
        super().__init__()
        # Same creation order as before: lin_in, two DeepSetsLayer, output MLP.
        self.lin_in = nn.Linear(num_windows, emb_dim)
        self.deepsets = nn.Sequential(*(DeepSetsLayer(emb_dim) for _ in range(2)))
        self.mlp_out = nn.Sequential(
            nn.Linear(emb_dim, emb_dim), nn.LeakyReLU(), nn.Linear(emb_dim, 1)
        )

    def forward(self, x: torch.Tensor):
        element_emb = self.deepsets(self.lin_in(x))
        set_emb = element_emb.mean(dim=-2)  # order-invariant pooling over the set
        return self.mlp_out(set_emb)


def create_env(type: Literal["ref", "train", "eval"] = "eval"):
    """Build a torchrl environment for one role.

    - "ref": a single PettingZooWrapper on `device` with no transforms. It is
      used for spec checks and to size the networks.
    - "train": a ParallelEnv with `num_train_envs` workers.
    - "eval": a ParallelEnv with `eval_num_iters` workers.

    Train/eval workers are built on CPU. Each one wraps the game in a RewardSum
    transform that writes to ("agent", "episode_reward"). The ParallelEnv is
    placed on `device`. The game uses the mechanism selected by the global
    `env_mechanism`.

    Raises:
        ValueError: if `type` or `env_mechanism` is not recognised. The check
            runs here in the main process rather than inside a worker.
    """
    # NOTE: `type` shadows the builtin; the name is kept for callers using type=...
    if type not in ("ref", "train", "eval"):
        raise ValueError(f"unknown env type: {type!r}")

    if env_mechanism == "private":
        mechanism_cls = PrivateFunding
    elif env_mechanism == "quadratic":
        mechanism_cls = QuadraticFundng
    else:
        raise ValueError(f"unknown env_mechanism: {env_mechanism!r}")

    def _create(device):
        # A fresh mechanism instance per env, so workers never share one object.
        env = SlidingWindowObsWrapper(
            env=RottingBanditsGame(
                num_bandits=env_num_bandits,
                num_arms=env_num_arms,
                mechanism=mechanism_cls(),
                max_steps=scenario_max_steps,
                cost_per_play=0.5,
                infinite_horizon=True,
                reward_allocation=env_reward_allocation,
                deficit_resolution="tax",
            ),
            window_sizes=[10, 50, 250],
        )
        env = PettingZooWrapper(env, device=device)

        if type != "ref":
            # Train/eval envs track the per-agent episode return for logging.
            env = TransformedEnv(
                env,
                RewardSum(
                    in_keys=[env.reward_key], out_keys=[("agent", "episode_reward")]
                ),
            )

        return env

    if type == "ref":
        return _create(device)

    num_workers = num_train_envs if type == "train" else eval_num_iters
    return ParallelEnv(
        num_workers=num_workers,
        create_env_fn=lambda: _create("cpu"),
        device=device,
    )


# One unbatched reference env (spec checks + network sizing) and two batched
# ParallelEnvs, created in the order ref, train, eval.
ref_env, train_env, eval_env = (create_env(type=t) for t in ("ref", "train", "eval"))

check_env_specs(ref_env)

# Actor: the DeepSets backbone emits a [loc | scale] vector, which
# NormalParamExtractor splits into ("agent", "loc") and ("agent", "scale").
actor_net = torch.nn.Sequential(
    DeepSetsActor(num_windows=ref_env.num_windows),
    NormalParamExtractor(),
)
policy_module = TensorDictModule(
    module=actor_net,
    in_keys=ref_env.observation_keys,
    out_keys=[("agent", "loc"), ("agent", "scale")],
)
# Actions are drawn from a TruncatedNormal on [0, _agent_budget_per_collector_step].
# Log-probs are stored under ("agent", "sample_log_prob") for the PPO loss.
policy = ProbabilisticActor(
    module=policy_module,
    spec=ref_env.action_spec,
    in_keys=[("agent", "loc"), ("agent", "scale")],
    distribution_class=TruncatedNormal,
    distribution_kwargs=dict(
        low=0.0,
        high=ref_env._agent_budget_per_collector_step,
    ),
    out_keys=ref_env.action_keys,
    return_log_prob=True,
    log_prob_key=("agent", "sample_log_prob"),
)
# Critic: writes its value estimate to ("agent", "state_value").
value = TensorDictModule(
    module=DeepSetsValue(num_windows=ref_env.num_windows),
    in_keys=ref_env.observation_keys,
    out_keys=[("agent", "state_value")],
)

for net in (policy, value):
    net.to(device)

# Sanity check / lazy initialisation: one forward pass of each network on a reset.
td = ref_env.reset()
with torch.no_grad():
    action_shape = policy(td)[("agent", "action")].shape
    value_shape = value(td)[("agent", "state_value")].shape
print(action_shape)
print(value_shape)


In [None]:
# Collector: steps the batched training env with the stochastic policy and
# yields `frames_per_batch` frames per iteration until `total_frames`.
collector = SyncDataCollector(
    train_env,
    policy,
    device=device,
    storing_device=device,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    exploration_type=ExplorationType.RANDOM,
)

# Storage sized to exactly one collected batch. Sampling without replacement
# means `frames_per_batch // minibatch_size` draws cover the stored batch once.
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(frames_per_batch, device=device),
    sampler=SamplerWithoutReplacement(),
    batch_size=minibatch_size,
)

# PPO clipped-surrogate loss. torchrl spells this keyword `normalize_advantage`;
# the misspelled `normalise_advantages` is not a ClipPPOLoss parameter.
loss_module = ClipPPOLoss(
    actor_network=policy, critic_network=value, normalize_advantage=False
)
# Point the loss at the per-agent ("agent", ...) entries written by the
# PettingZoo wrapper and the actor/critic above.
loss_module.set_keys(
    reward=ref_env.reward_key,
    action=ref_env.action_key,
    sample_log_prob=("agent", "sample_log_prob"),
    value=("agent", "state_value"),
    done=("agent", "done"),
    terminated=("agent", "terminated"),
)
loss_module.make_value_estimator(ValueEstimators.GAE, gamma=0.99, lmbda=0.95)
loss_module.to(device=device)
optim = torch.optim.Adam(loss_module.parameters(), 3e-4)

In [None]:
def evaluate_policy() -> float:
    """Evaluate `policy` deterministically on `eval_env`.

    Resets eval worker i with seed i (for `eval_num_iters` workers) and rolls
    out for up to `scenario_max_steps` steps with autograd disabled. Returns the
    mean of ("next", "agent", "episode_reward") over the entries where
    ("next", "agent", "done") is set. The result is NaN if no episode finished.
    """
    with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
        reset_td = eval_env.reset(
            list_of_kwargs=[{"seed": i} for i in range(eval_num_iters)]
        )
        eval_td = eval_env.rollout(
            max_steps=scenario_max_steps,
            policy=policy,
            auto_cast_to_device=True,
            auto_reset=False,
            tensordict=reset_td,
        )
    done = eval_td.get(("next", "agent", "done"))
    return eval_td.get(("next", "agent", "episode_reward"))[done].mean().item()


train_episode_reward_mean_list = []
eval_episode_reward_mean_list = []

with tqdm(total=num_iters, desc="episode_reward_mean = 0") as pbar:
    for iteration, sampling_td in enumerate(collector):
        # GAE advantages / value targets are computed once per collected batch,
        # using the critic as it is before this batch's updates.
        with torch.no_grad():
            loss_module.value_estimator(
                sampling_td,
                params=loss_module.critic_network_params,
                target_params=loss_module.target_critic_network_params,
            )
        data_view = sampling_td.reshape(-1)
        replay_buffer.extend(data_view)

        for epoch in range(num_epochs):
            for _ in range(frames_per_batch // minibatch_size):
                minibatch: TensorDict = replay_buffer.sample()
                loss_vals = loss_module(minibatch)

                loss_value = (
                    loss_vals["loss_objective"]
                    + loss_vals["loss_critic"]
                    + loss_vals["loss_entropy"]
                )

                loss_value.backward()

                torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 1.0)

                optim.step()
                optim.zero_grad()

        collector.update_policy_weights_()

        # Keyed on the collector iteration (not the epoch) so `eval_interval`
        # counts iterations. It fires on the first iteration and evaluates the
        # policy after all epochs on this batch.
        if iteration % eval_interval == 0:
            eval_episode_reward_mean_list.append(evaluate_policy())

        # Logging
        done = sampling_td.get(("next", "agent", "done"))
        episode_reward_mean = (
            sampling_td.get(("next", "agent", "episode_reward"))[done].mean().item()
        )
        train_episode_reward_mean_list.append(episode_reward_mean)
        latest_eval = (
            eval_episode_reward_mean_list[-1]
            if eval_episode_reward_mean_list
            else float("nan")
        )
        pbar.set_description(
            f"train_reward_mean = {episode_reward_mean}, eval_reward_mean = {latest_eval}",
            refresh=False,
        )
        pbar.update()


In [None]:
# Learning curves. Training reward is logged once per collector iteration and
# evaluation reward once per evaluation run, so each gets its own x-axis.
fig, (ax_train, ax_eval) = plt.subplots(1, 2, figsize=(12, 4))
ax_train.plot(train_episode_reward_mean_list)
ax_train.set(
    title="Training (stochastic policy)",
    xlabel="collector iteration",
    ylabel="mean episode reward",
)
ax_eval.plot(eval_episode_reward_mean_list, marker="o")
ax_eval.set(
    title="Evaluation (deterministic policy)",
    xlabel="evaluation run",
    ylabel="mean episode reward",
)
fig.tight_layout()
plt.show()