In [None]:
from tqdm import tqdm
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.objectives import ClipPPOLoss, ValueEstimators
from torchrl.envs import TransformedEnv, RewardSum
from torchrl.envs.libs import PettingZooWrapper
from torchrl.envs.utils import check_env_specs
from torchrl.modules import ProbabilisticActor, TruncatedNormal
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
import matplotlib.pyplot as plt


from public_datasets_game.rotting_bandits import (
    RottingBanditsGame,
    SlidingWindowObsWrapper,
)
from public_datasets_game.mechanism import PrivateFunding


# Run on the first CUDA device when available, otherwise on CPU.
device = torch.device(0) if torch.cuda.is_available() else torch.device("cpu")

# Seed torch's global RNG (network initialisation, action sampling) so that
# re-running the notebook from a fresh kernel is repeatable.
# NOTE(review): the game environment's own randomness is not seeded here —
# confirm how RottingBanditsGame draws random numbers if exact reproducibility
# of rollouts is required.
seed = 0
torch.manual_seed(seed)

# PPO hyper-parameters.
minibatch_size = 100  # transitions per optimiser step
n_mini_batches = 10  # optimiser steps per epoch over one collected batch
n_iters = 100  # number of collector iterations (collected batches)
n_epochs = 4  # passes over each collected batch
frames_per_batch = n_mini_batches * minibatch_size
total_frames = frames_per_batch * n_iters

In [None]:
class DeepSetsLayer(nn.Module):
    def __init__(self, emb_dim: int):
        """Permutation-equivariant DeepSets layer with mean pooling.

        Every set element (along dim -2) is concatenated with the mean of the
        whole set and passed through one shared linear map plus LeakyReLU.

        Args:
            emb_dim: feature width of each set element; the output keeps it.
        """
        super().__init__()
        # The linear map sees [element, set-mean], hence twice the width.
        self.linear = nn.Linear(emb_dim * 2, emb_dim)
        self.sigma = nn.LeakyReLU()

    def forward(self, x: torch.Tensor):
        """Map ``(..., set_size, emb_dim)`` to a tensor of the same shape."""
        pooled = x.mean(dim=-2, keepdim=True).expand_as(x)
        features = torch.cat((x, pooled), dim=-1)
        return self.sigma(self.linear(features))


class DeepSetsActor(nn.Module):
    def __init__(self, num_windows: int, emb_dim: int = 32):
        """DeepSets policy network producing Normal parameters per set element.

        The output concatenates all location values followed by all scale
        logits on the last axis: ``[loc_0..loc_{k-1}, s_0..s_{k-1}]``. This is
        the layout ``NormalParamExtractor`` expects, since it splits the last
        axis into a loc half and a scale half.

        Args:
            num_windows: number of input features per set element.
            emb_dim: hidden embedding width.
        """
        super().__init__()
        self.lin_in = nn.Linear(num_windows, emb_dim)
        self.deepsets = nn.Sequential(
            DeepSetsLayer(emb_dim),
            DeepSetsLayer(emb_dim),
        )
        self.lin_out = nn.Linear(emb_dim, 2)
        self.softplus = nn.Softplus()

    def forward(self, x: torch.Tensor):
        """Return ``(..., 2 * set_size)``: softplus'd locs, then raw scale logits.

        Assumes x is ``(..., set_size, num_windows)``, presumably one row per
        arm — TODO confirm against the env's observation spec.
        """
        x = self.lin_in(x)
        x = self.deepsets(x)
        x = self.lin_out(x)

        # Keep the location non-negative (actions live in [0, budget]).
        loc = self.softplus(x[..., 0])
        # Leave the scale channel unsquashed. NormalParamExtractor applies its
        # own positive mapping (biased softplus with f(0) = 1). Softplus-ing it
        # here first would feed that mapping only inputs >= 0 and pin the
        # policy std at >= 1.0.
        scale = x[..., 1]

        return torch.cat([loc, scale], dim=-1)


class DeepSetsValue(nn.Module):
    def __init__(self, num_windows: int, emb_dim: int = 32):
        """Permutation-invariant DeepSets critic: one scalar per input set.

        Args:
            num_windows: number of input features per set element.
            emb_dim: hidden embedding width.
        """
        super().__init__()
        self.lin_in = nn.Linear(num_windows, emb_dim)
        # Two stacked equivariant layers (indices 0 and 1 in the state dict).
        self.deepsets = nn.Sequential(*(DeepSetsLayer(emb_dim) for _ in range(2)))
        head_layers = [
            nn.Linear(emb_dim, emb_dim),
            nn.LeakyReLU(),
            nn.Linear(emb_dim, 1),
        ]
        self.mlp_out = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor):
        """Map ``(..., set_size, num_windows)`` to ``(..., 1)``."""
        embedded = self.deepsets(self.lin_in(x))
        # Mean-pooling over the set axis makes the estimate order-independent.
        return self.mlp_out(embedded.mean(dim=-2))


# Build the rotting-bandits game and wrap its observations with sliding
# windows of sizes 10 / 50 / 250. Then expose it to TorchRL via PettingZooWrapper.
# NOTE(review): num_bandits presumably sets the number of agents and
# infinite_horizon may suppress episode ends — confirm in
# public_datasets_game.rotting_bandits.
env = SlidingWindowObsWrapper(
    env=RottingBanditsGame(
        num_bandits=3,
        num_arms=5,
        mechanism=PrivateFunding(),
        max_steps=100,
        cost_per_play=0.5,
        infinite_horizon=True,
    ),
    window_sizes=[10, 50, 250],
)
env = PettingZooWrapper(env, device=device)

# Verify that the specs the wrapper advertises match the data it produces.
check_env_specs(env)

# Actor: DeepSets network -> NormalParamExtractor, writing the distribution
# parameters to ("agent", "loc") and ("agent", "scale").
# env.num_windows and env._agent_budget_per_collector_step are attributes of
# the wrapped game, reached through PettingZooWrapper's attribute forwarding.
policy_module = TensorDictModule(
    module=torch.nn.Sequential(
        DeepSetsActor(num_windows=env.num_windows), NormalParamExtractor()
    ),
    in_keys=env.observation_keys,
    out_keys=[("agent", "loc"), ("agent", "scale")],
)
# Actions are sampled from a Normal truncated to [0, per-step budget]. The
# log-prob is stored under ("agent", "sample_log_prob"), the key the PPO loss
# is later configured to read.
policy = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=[("agent", "loc"), ("agent", "scale")],
    distribution_class=TruncatedNormal,
    distribution_kwargs={
        "low": 0.0,
        "high": env._agent_budget_per_collector_step,
    },
    out_keys=env.action_keys,
    return_log_prob=True,
    log_prob_key=("agent", "sample_log_prob"),
)
# Critic: writes its value estimate to ("agent", "state_value").
value = TensorDictModule(
    module=DeepSetsValue(num_windows=env.num_windows),
    in_keys=env.observation_keys,
    out_keys=[("agent", "state_value")],
)

policy.to(device)
value.to(device)

# Smoke test: run both networks once on a reset state and print output shapes.
td = env.reset()
with torch.no_grad():
    print(policy(td)[("agent", "action")].shape)
    print(value(td)[("agent", "state_value")].shape)

# Accumulate each agent's reward over an episode into ("agent", "episode_reward").
# The training loop reads this key for logging.
env = TransformedEnv(
    env,
    RewardSum(in_keys=[env.reward_key], out_keys=[("agent", "episode_reward")]),
)


In [None]:
# Roll out `policy` in `env`, yielding `frames_per_batch` transitions per
# iteration until `total_frames` transitions have been collected.
collector = SyncDataCollector(
    env,
    policy,
    device=device,
    storing_device=device,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
)

# On-policy buffer holding exactly one collected batch (each new batch
# overwrites the previous one). Minibatches are drawn without replacement.
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(frames_per_batch, device=device),
    sampler=SamplerWithoutReplacement(),
    batch_size=minibatch_size,
)

# Clipped-objective PPO. The keyword is `normalize_advantage`; the previous
# spelling `normalise_advantages` is not a ClipPPOLoss argument.
loss_module = ClipPPOLoss(
    actor_network=policy, critic_network=value, normalize_advantage=False
)
# Point the loss at the multi-agent ("agent", ...) entries instead of the
# default top-level keys.
loss_module.set_keys(
    reward=env.reward_key,
    action=env.action_key,
    sample_log_prob=("agent", "sample_log_prob"),
    value=("agent", "state_value"),
    done=("agent", "done"),
    terminated=("agent", "terminated"),
)
# Generalised Advantage Estimation with discount 0.99 and lambda 0.95.
loss_module.make_value_estimator(ValueEstimators.GAE, gamma=0.99, lmbda=0.95)
loss_module.to(device=device)
optim = torch.optim.Adam(loss_module.parameters(), 3e-4)

In [None]:
pbar = tqdm(total=n_iters, desc="episode_reward_mean = 0")

episode_reward_mean_list = []
for sampling_td in collector:
    # Compute advantages / value targets for the whole batch in place. They are
    # fixed targets for the PPO epochs below, so no gradients are tracked.
    with torch.no_grad():
        loss_module.value_estimator(
            sampling_td,
            params=loss_module.critic_network_params,
            target_params=loss_module.target_critic_network_params,
        )
    # Flatten the batch dimensions and load the batch into the on-policy buffer.
    data_view = sampling_td.reshape(-1)
    replay_buffer.extend(data_view)

    for _ in range(n_epochs):
        for _ in range(frames_per_batch // minibatch_size):
            minibatch: TensorDict = replay_buffer.sample()
            loss_vals = loss_module(minibatch)

            loss_value = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )

            loss_value.backward()

            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), 1.0)

            optim.step()
            optim.zero_grad()

    # Push the updated weights to the collector's copy of the policy.
    collector.update_policy_weights_()

    # Logging: mean cumulative reward over entries where an episode ended.
    done = sampling_td.get(("next", "agent", "done"))
    if done.any():
        episode_reward_mean = (
            sampling_td.get(("next", "agent", "episode_reward"))[done].mean().item()
        )
        description = f"episode_reward_mean = {episode_reward_mean}"
    else:
        # No episode finished in this batch. Record NaN explicitly, instead of
        # averaging an empty selection, so the list stays aligned with iterations.
        episode_reward_mean = float("nan")
        description = "episode_reward_mean = n/a (no finished episode)"
    episode_reward_mean_list.append(episode_reward_mean)
    pbar.set_description(description, refresh=False)
    pbar.update()

pbar.close()


In [None]:
# Learning curve: iterations with no finished episode are NaN and show as gaps.
fig, ax = plt.subplots()
ax.plot(episode_reward_mean_list)
ax.set(
    xlabel="Training iteration",
    ylabel="Mean episode reward (finished episodes)",
    title="PPO on RottingBanditsGame",
)
plt.show()