# Graph isomorphism PPO


## Setup


In [1]:
FORCE_CPU = True

SEED = 349287

BATCH_SIZE = 32

In [2]:
from typing import Optional

import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torch import nn
from torch.distributions import Categorical

from tensordict.nn import (
    TensorDictModule,
    TensorDictModuleBase,
    TensorDictSequential,
    ProbabilisticTensorDictSequential,
)
from tensordict.nn.distributions import CompositeDistribution
from tensordict.tensordict import TensorDict, TensorDictBase
from tensordict.nn import InteractionType
from torchrl.modules import ActorCriticOperator, SafeProbabilisticModule

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.data.tensor_specs import (
    CompositeSpec,
    DiscreteTensorSpec,
    BinaryDiscreteTensorSpec,
    MultiDiscreteTensorSpec,
    TensorSpec,
    UnboundedContinuousTensorSpec,
    Box,
)
from torchrl.envs import EnvBase
from torchrl.envs.utils import check_env_specs
from torchrl.modules import ProbabilisticActor
from torchrl.objectives import ClipPPOLoss, ValueEstimators

from torch_geometric.loader import DataLoader as GeometricDataLoader
from torch_geometric.data import Batch as GeometricBatch, Data as GeometricData

from jaxtyping import Float, Int, Bool

from pvg.graph_isomorphism import (
    GraphIsomorphismAgentBody,
    GraphIsomorphismAgentPolicyHead,
    GraphIsomorphismAgentCriticHead,
)
from pvg.parameters import Parameters
from pvg.graph_isomorphism.data import GraphIsomorphismData, GraphIsomorphismDataset
from pvg.utils.data import gi_data_to_tensordict
from pvg.constants import VERIFIER_AGENT_NUM, PROVER_AGENT_NUM

In [3]:
torch.manual_seed(SEED)
torch_generator = torch.Generator().manual_seed(SEED)

In [4]:
if not FORCE_CPU and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


In [5]:
params = Parameters(
    scenario="graph_isomorphism",
    trainer="ppo",
    dataset="eru10000",
)

## Specification


In [6]:
class AdjacencyMatrixBox(Box):
    """An abstract representation of the space of adjacency matrices."""

    def __init__(self, max_num_nodes: int):
        self.max_num_nodes = max_num_nodes

In [7]:
class AdjacencyMatrixSpec(TensorSpec):
    def __init__(
        self,
        max_num_nodes: int,
        shape: torch.Size | None = None,
        device: Optional[torch.device | str | int] = None,
        dtype: str | torch.dtype = torch.int32,
    ):
        self.max_num_nodes = max_num_nodes
        self.device = device
        self.dtype = dtype

        if shape is None:
            self.shape = torch.Size([max_num_nodes, max_num_nodes])
        else:
            if shape[-2:] != (max_num_nodes, max_num_nodes):
                raise ValueError(
                    f"The last two dimensions of the shape must be {max_num_nodes}. "
                    f"Got {shape[-2:]}."
                )
            self.shape = torch.Size(shape)

        self.space = AdjacencyMatrixBox(max_num_nodes)

    def is_in(self, val: torch.Tensor) -> bool:
        """Check if a value is a valid adjacency matrix."""

        # Basic type checks
        if not isinstance(val, torch.Tensor):
            return False
        if val.shape[-2:] != (self.max_num_nodes, self.max_num_nodes):
            return False
        if val.dtype != self.dtype:
            return False

        # Make sure the values are either 0 or 1
        if not torch.all(torch.isin(val, torch.tensor([0, 1], device=self.device))):
            return False

        # Make sure the matrix is symmetric
        if not torch.all(val.transpose(-1, -2) == val):
            return False

        # Make sure the diagonal is all zeros
        if not torch.all(torch.isin(torch.diagonal(val, dim1=-2, dim2=-1), 0)):
            return False

        return True

    def rand(self, shape: Optional[list[int] | torch.Size] = None) -> torch.Tensor:
        """Generate a random 1/2 Erdos-Renyi adjacency matrix."""

        if shape is None:
            shape = shape = torch.Size([])

        adjacency_values = torch.rand(*shape, *self.shape, device=device)
        adjacency = (adjacency_values < 0.5).to(self.dtype)
        adjacency = adjacency.triu(diagonal=1)
        adjacency += adjacency.transpose(1, 2).clone()

        return adjacency

    def _project(self, val: torch.Tensor) -> torch.Tensor:
        """Project a value to the space of valid adjacency matrices."""

        # Symmetrize the matrix
        val = (val + val.transpose(1, 2)) / 2

        # Make sure the diagonal is all zeros
        val[..., torch.arange(self.max_num_nodes), torch.arange(self.max_num_nodes)] = 0

        # Make sure the values are either 0 or 1
        return torch.clamp(torch.round(val), min=0, max=1).to(self.dtype)

## Environment


In [8]:
def forgetful_cycle(iterable):
    """A version of cycle that doesn't save copies of the values"""
    while True:
        for i in iterable:
            yield i

In [9]:
class VariableGeometricDataCycler:
    """A loader that cycles through geometric data, but allows the batch size to vary.

    Parameters
    ----------
    dataloader : GeometricDataLoader
        The base dataloader to use. This dataloader will be cycled through.
    """

    def __init__(self, dataloader: GeometricDataLoader):
        self.dataloader = dataloader
        self.dataloader_iter = iter(forgetful_cycle(self.dataloader))
        self.remainder: Optional[list] = None

    def get_batch(self, batch_size: int) -> GeometricBatch:
        """Get a batch of data from the dataloader with the given batch size.

        If the dataloader is exhausted, it will be reset.

        Parameters
        ----------
        batch_size : int
            The size of the batch to return.

        Returns
        -------
        batch : Tensor
            A batch of data with the given batch size.
        """

        left_to_sample = batch_size
        batch_components = []

        # Start by sampling from the remainder from the previous sampling
        if self.remainder is not None:
            batch_components.extend(self.remainder[:left_to_sample])
            if len(self.remainder) <= left_to_sample:
                left_to_sample -= len(self.remainder)
                self.remainder = None
            else:
                self.remainder = self.remainder[left_to_sample:]
                left_to_sample = 0

        # Keep sampling batches until we have enough
        while left_to_sample > 0:
            batch = next(self.dataloader_iter)
            batch_components.extend(batch[:left_to_sample])
            if len(batch) <= left_to_sample:
                left_to_sample -= len(batch)
            else:
                self.remainder = batch[left_to_sample:]
                left_to_sample = 0

        # Concatenate the batch components into a single batch
        batch = GeometricBatch.from_data_list(
            batch_components, follow_batch=self.dataloader.follow_batch
        )
        return batch

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.dataloader!r})"

In [10]:
class GraphIsomorphismEnv(EnvBase):
    def __init__(
        self,
        params: Parameters,
        device: torch.device | str = device,
        int_dtype: torch.dtype = torch.int,
    ):
        super().__init__(device=device)
        self.params = params
        self.int_dtype = int_dtype

        # Create a random number generator
        self.rng = torch.Generator(device=device)

        # Load the dataset
        self.dataset = GraphIsomorphismDataset(params)
        self.data_cycler: Optional[VariableGeometricDataCycler] = None

        # Compute the maximum number of nodes in the dataset
        self.max_num_nodes = 0
        for data in self.dataset:
            self.max_num_nodes = max(
                self.max_num_nodes, data.x_a.shape[0], data.x_b.shape[0]
            )

        # The number of environments is the number of episodes we can fit in a batch
        self.num_envs = params.ppo.frames_per_batch // params.max_message_rounds
        self.batch_size = (self.num_envs,)

        # The spec for the observation space: agents see the adjacency matrix and the
        # messages sent so far. The "message" field contains the most recent message.
        self.observation_spec = CompositeSpec(
            adjacency=AdjacencyMatrixSpec(
                self.max_num_nodes,
                shape=(self.num_envs, 2, self.max_num_nodes, self.max_num_nodes),
                dtype=self.int_dtype,
            ),
            x=BinaryDiscreteTensorSpec(
                params.max_message_rounds,
                shape=(
                    self.num_envs,
                    2,
                    self.max_num_nodes,
                    params.max_message_rounds,
                ),
                dtype=torch.float,
            ),
            node_mask=BinaryDiscreteTensorSpec(
                self.max_num_nodes,
                shape=(
                    self.num_envs,
                    2,
                    self.max_num_nodes,
                ),
                dtype=torch.bool,
            ),
            message=DiscreteTensorSpec(
                2 * self.max_num_nodes,
                shape=(self.num_envs,),
                dtype=torch.long,
            ),
            round=DiscreteTensorSpec(
                params.max_message_rounds,
                shape=(self.num_envs,),
                dtype=self.int_dtype,
            ),
            shape=(self.num_envs,),
        ).to(self.device)

        # The spec for the state space: the true label and the round number
        self.state_spec = CompositeSpec(
            y=BinaryDiscreteTensorSpec(
                1,
                shape=(self.num_envs, 1),
                dtype=self.int_dtype,
            ),
            shape=(self.num_envs,),
        ).to(self.device)

        # Each action space has shape (batch_size, num_agents). Each agent chooses both
        # a node and a decision: reject, accept or continue (represented as 0, 1 or 2).
        # The node is is a number between 0 and 2 * max_num_nodes - 1. If it is less
        # than max_num_nodes, it is a node in the first graph, otherwise it is a node in
        # the second graph. The verifier is agent 0 and the prover is agent 1.
        self.action_spec = CompositeSpec(
            agents=CompositeSpec(
                node_selected=DiscreteTensorSpec(
                    2 * self.max_num_nodes,
                    shape=(self.num_envs, 2),
                    dtype=self.int_dtype,
                ),
                decision=DiscreteTensorSpec(
                    3,
                    shape=(self.num_envs, 2),
                    dtype=self.int_dtype,
                ),
                shape=(self.num_envs,),
            ),
            shape=(self.num_envs,),
        ).to(self.device)

        self.reward_spec = CompositeSpec(
            agents=CompositeSpec(
                reward=UnboundedContinuousTensorSpec(shape=(self.num_envs, 2)),
                shape=(self.num_envs,),
            ),
            shape=(self.num_envs,),
        ).to(self.device)

        self.done_spec = CompositeSpec(
            done=BinaryDiscreteTensorSpec(
                self.num_envs, shape=(self.num_envs,), dtype=torch.bool
            ),
            shape=(self.num_envs,),
        ).to(self.device)

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Perform a step in the environment."""

        # Extract the tensors from the dict
        y: Int[Tensor, "batch"] = tensordict["y"]
        x: Float[Tensor, "batch graph node message_round"] = tensordict["x"]
        round: Int[Tensor, "batch"] = tensordict["round"]
        node_selected: Int[Tensor, "batch agent"] = tensordict[
            "agents", "node_selected"
        ]
        decision: Int[Tensor, "batch agent"] = tensordict["agents", "decision"]
        done: Bool[Tensor, "batch"] = tensordict["done"]

        # Compute index of the agent whose turn it is
        agent_index: Int[Tensor, "batch"] = round % 2
        if PROVER_AGENT_NUM == 0:
            agent_index = 1 - agent_index

        # Determine which graph contains the selected node and which node it is there
        # (batch agent)
        which_graph = node_selected >= self.max_num_nodes
        # (batch agent)
        graph_node = torch.where(
            which_graph, node_selected - self.max_num_nodes, node_selected
        )

        # Write the node selected by the agent whose turn it is as a (one-hot) message
        x[
            torch.arange(x.shape[0]),
            which_graph[torch.arange(which_graph.shape[0]), agent_index].int(),
            graph_node[torch.arange(which_graph.shape[0]), agent_index],
            round,
        ] = 1

        # Set the node selected by the agent whose turn it is as the message
        message = node_selected[
            torch.arange(node_selected.shape[0]), agent_index
        ].long()

        # If the verifier has made a guess, compute the reward and terminate the episode
        verifier_decision_made = (agent_index == VERIFIER_AGENT_NUM) & (
            decision[:, VERIFIER_AGENT_NUM] != 2
        )
        done = done | verifier_decision_made
        reward_verifier = (
            verifier_decision_made & (decision[:, VERIFIER_AGENT_NUM] == y.squeeze())
        ).float()
        reward_verifier = reward_verifier * self.params.verifier_reward
        reward_prover = (
            verifier_decision_made & (decision[:, VERIFIER_AGENT_NUM] == 1)
        ).float()
        reward_prover = reward_prover * self.params.prover_reward

        # If we reach the end of the episode and the verifier has not made a guess,
        # terminate it with a negative reward for the verifier
        done = done | (round >= self.params.max_message_rounds - 1)
        reward_verifier[
            (round >= self.params.max_message_rounds - 1) & ~verifier_decision_made
        ] = self.params.verifier_terminated_penalty

        # Stack the rewards for the two agents
        reward = torch.stack([reward_verifier, reward_prover], dim=-1)

        # Put everything together
        next = TensorDict(
            dict(
                adjacency=tensordict["adjacency"],
                x=x,
                node_mask=tensordict["node_mask"],
                message=message,
                round=round + 1,
                done=done,
                agents=TensorDict(dict(reward=reward), batch_size=self.batch_size),
            ),
            batch_size=self.batch_size,
        )
        return next

    def _reset(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
        """(Partially) reset the environment.

        For each episode which is done, takes a new sample from the dataset and resets
        the episode.
        """

        # If no tensordict is given, we're starting afresh
        if tensordict is None:
            tensordict = TensorDict(
                dict(
                    adjacency=torch.empty(
                        *self.batch_size,
                        2,
                        self.max_num_nodes,
                        self.max_num_nodes,
                        device=self.device,
                        dtype=self.int_dtype,
                    ),
                    x=torch.empty(
                        *self.batch_size,
                        2,
                        self.max_num_nodes,
                        self.params.max_message_rounds,
                        device=self.device,
                        dtype=torch.float,
                    ),
                    node_mask=torch.empty(
                        *self.batch_size,
                        2,
                        self.max_num_nodes,
                        device=self.device,
                        dtype=torch.bool,
                    ),
                    message=torch.empty(
                        *self.batch_size,
                        device=self.device,
                        dtype=torch.long,
                    ),
                    y=torch.empty(
                        *self.batch_size,
                        1,
                        device=self.device,
                        dtype=self.int_dtype,
                    ),
                    round=torch.empty(
                        *self.batch_size,
                        device=self.device,
                        dtype=self.int_dtype,
                    ),
                    done=torch.empty(
                        *self.batch_size,
                        device=self.device,
                        dtype=torch.bool,
                    ),
                ),
                batch_size=self.batch_size,
            )

            new_mask = torch.ones(
                *self.batch_size, dtype=torch.bool, device=self.device
            )

        else:
            new_mask = tensordict["done"]
            tensordict = tensordict.clone()

        # If we don't have a data cycler yet, create one
        if self.data_cycler is None:
            dataloader = GeometricDataLoader(
                self.dataset,
                batch_size=self.num_envs,
                follow_batch=["x_a", "x_b"],
                shuffle=True,
                generator=self.rng,
            )
            self.data_cycler = VariableGeometricDataCycler(dataloader)

        # Sample a new batch of data for the episodes that are done
        batch = self.data_cycler.get_batch(new_mask.sum().item())
        batch_tensordict = gi_data_to_tensordict(
            batch, node_dim_size=self.max_num_nodes
        )

        # Copy the new data into the output
        tensordict["adjacency"][new_mask] = batch_tensordict["adjacency"]
        tensordict["x"][new_mask] = torch.zeros_like(tensordict["x"][new_mask])
        tensordict["node_mask"][new_mask] = batch_tensordict["node_mask"]
        tensordict["message"][new_mask] = 0
        tensordict["y"][new_mask] = (batch.wl_score == -1).int().unsqueeze(-1)
        tensordict["round"][new_mask] = 0
        tensordict["done"][new_mask] = False

        return tensordict

    def _set_seed(self, seed: int | None):
        self.rng.manual_seed(seed)

## Test environment


In [11]:
env = GraphIsomorphismEnv(params)

In [12]:
check_env_specs(env)

check_env_specs succeeded!


In [13]:
def printer(env, tensordict):
    to_print = []
    for i in range(2):
        to_print.append(
            f"[{tensordict['round'][i].item()}] "
            f"y: {tensordict['y'][i].item()} "
            f"P: ({tensordict['agents', 'node_selected'][i, 1].item():>2}) "
            f" {tensordict['next', 'agents', 'reward'][i, 1].item():>2} "
            f"V: ({tensordict['agents', 'node_selected'][i, 0].item():>2}, "
            f" {tensordict['agents', 'decision'][i, 0].item():>2}) "
            f" {tensordict['next', 'agents', 'reward'][i, 0].item():>2} "
        )
    print(" | ".join(to_print))
    # print(tensordict["message"][:2, ..., :3].transpose(-1, -2))
    # print(tensordict["message"][:2, ..., :3])


with torch.no_grad():
    out = env.rollout(
        max_steps=40,
        callback=printer,
        auto_cast_to_device=True,
        break_when_any_done=False,
    )

[0] y: 0 P: (16)  0.0 V: (15,   2)  0.0  | [0] y: 1 P: (16)  1.0 V: (11,   1)  1.0 
[1] y: 0 P: ( 5)  0.0 V: (11,   0)  0.0  | [0] y: 0 P: (12)  0.0 V: (12,   2)  0.0 
[2] y: 0 P: ( 1)  0.0 V: ( 9,   0)  1.0  | [1] y: 0 P: (18)  0.0 V: (14,   1)  0.0 
[0] y: 0 P: ( 5)  0.0 V: (17,   0)  1.0  | [2] y: 0 P: (13)  0.0 V: ( 0,   2)  0.0 
[0] y: 0 P: (11)  0.0 V: (21,   0)  1.0  | [3] y: 0 P: (16)  0.0 V: (19,   2)  0.0 
[0] y: 1 P: (14)  0.0 V: ( 3,   2)  0.0  | [4] y: 0 P: (10)  0.0 V: (13,   0)  1.0 
[1] y: 1 P: (11)  0.0 V: ( 3,   2)  0.0  | [0] y: 0 P: ( 7)  1.0 V: (17,   1)  0.0 
[2] y: 1 P: (13)  0.0 V: ( 6,   2)  0.0  | [0] y: 1 P: (16)  0.0 V: (16,   2)  0.0 
[3] y: 1 P: ( 0)  0.0 V: (17,   1)  0.0  | [1] y: 1 P: (15)  0.0 V: ( 9,   0)  0.0 
[4] y: 1 P: (21)  1.0 V: (17,   1)  1.0  | [2] y: 1 P: (10)  0.0 V: (16,   0)  0.0 
[0] y: 0 P: ( 2)  0.0 V: (17,   0)  1.0  | [0] y: 0 P: ( 3)  1.0 V: ( 0,   1)  0.0 
[0] y: 1 P: (20)  1.0 V: (17,   1)  1.0  | [0] y: 0 P: ( 9)  0.0 V: ( 1,   2

In [14]:
out

TensorDict(
    fields={
        adjacency: Tensor(shape=torch.Size([125, 40, 2, 11, 11]), device=cpu, dtype=torch.int32, is_shared=False),
        agents: TensorDict(
            fields={
                decision: Tensor(shape=torch.Size([125, 40, 2]), device=cpu, dtype=torch.int32, is_shared=False),
                node_selected: Tensor(shape=torch.Size([125, 40, 2]), device=cpu, dtype=torch.int32, is_shared=False)},
            batch_size=torch.Size([125, 40]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([125, 40]), device=cpu, dtype=torch.bool, is_shared=False),
        message: Tensor(shape=torch.Size([125, 40]), device=cpu, dtype=torch.int64, is_shared=False),
        next: TensorDict(
            fields={
                adjacency: Tensor(shape=torch.Size([125, 40, 2, 11, 11]), device=cpu, dtype=torch.int32, is_shared=False),
                agents: TensorDict(
                    fields={
                        reward: Tensor(sha

## Policy and critic


In [15]:
class GraphIsomorphismCombinedBody(TensorDictModuleBase):
    in_keys = ("round", "x", "adjacency", "message")
    out_keys = ("round", "node_level_repr", "graph_level_repr")

    def __init__(
        self,
        prover_body: GraphIsomorphismAgentBody,
        verifier_body: GraphIsomorphismAgentBody,
    ) -> None:
        super().__init__()
        self.prover_body = prover_body
        self.verifier_body = verifier_body

    def forward(self, tensordict: TensorDictBase) -> TensorDict:
        round: Int[Tensor, "batch"] = tensordict["round"]

        # Compute the index of the agent whose turn it is
        prover_turn: Int[Tensor, "batch"] = round % 2 == 0

        batch_size = tensordict.batch_size[0]

        # Build tensordicts to feed to each agent when it's their turn
        input_prover_dict = {}
        input_verifier_dict = {}
        for key in self.prover_body.in_keys:
            if key == "ignore_message":
                input_prover_dict[key] = torch.zeros(
                    prover_turn.sum().item(), device=tensordict.device, dtype=torch.bool
                )
                input_verifier_dict[key] = round[~prover_turn] == 0
            else:
                input_prover_dict[key] = tensordict[key][prover_turn]
                input_verifier_dict[key] = tensordict[key][~prover_turn]
        input_prover = TensorDict(
            input_prover_dict,
            batch_size=(prover_turn.sum().item()),
        )
        input_verifier = TensorDict(
            input_verifier_dict,
            batch_size=((~prover_turn).sum().item()),
        )

        # Run the prover and verifier on the inputs to get the hidden representations
        out_prover = self.prover_body(input_prover)
        out_verifier = self.verifier_body(input_verifier)

        # Combine the representations of the two agents
        node_level_repr = torch.zeros(
            batch_size,
            *out_prover["node_level_repr"].shape[1:],
            dtype=out_prover["node_level_repr"].dtype,
            device=out_prover["node_level_repr"].device,
        )
        node_level_repr[prover_turn] = out_prover["node_level_repr"]
        node_level_repr[~prover_turn] = out_verifier["node_level_repr"]
        graph_level_repr = torch.zeros(
            batch_size,
            *out_prover["graph_level_repr"].shape[1:],
            dtype=out_prover["graph_level_repr"].dtype,
            device=out_prover["graph_level_repr"].device,
        )
        graph_level_repr[prover_turn] = out_prover["graph_level_repr"]
        graph_level_repr[~prover_turn] = out_verifier["graph_level_repr"]

        return tensordict.update(
            dict(node_level_repr=node_level_repr, graph_level_repr=graph_level_repr)
        )

In [16]:
class GraphIsomorphismCombinedPolicyHead(TensorDictModuleBase):
    in_keys = ("round", "node_level_repr", "graph_level_repr")
    out_keys = (("agents", "node_selected_logits"), ("agents", "decision_logits"))

    def __init__(
        self,
        prover_policy_head: GraphIsomorphismAgentPolicyHead,
        verifier_policy_head: GraphIsomorphismAgentPolicyHead,
    ):
        super().__init__()
        self.prover_policy_head = prover_policy_head
        self.verifier_policy_head = verifier_policy_head

    def forward(self, tensordict: TensorDictBase) -> TensorDict:
        round: Int[Tensor, "batch"] = tensordict["round"]
        batch_size = round.shape[0]

        # Compute the index of the agent whose turn it is
        prover_turn: Int[Tensor, "batch"] = (round % 2 == 1)

        # Build tensordicts to feed to each agent when it's their turn
        input_prover = TensorDict(
            {
                key: tensordict[key][prover_turn]
                for key in self.prover_policy_head.in_keys
            },
            batch_size=(prover_turn.sum().item()),
        )
        input_verifier = TensorDict(
            {
                key: tensordict[key][~prover_turn]
                for key in self.verifier_policy_head.in_keys
            },
            batch_size=((~prover_turn).sum().item()),
        )

        # Run the policy heads to obtain the probability distributions
        out_prover = self.prover_policy_head(input_prover)
        out_verifier = self.verifier_policy_head(input_verifier)

        # The combined action distribution logits of the two agents, which defaults to
        # zeros
        node_selected_logits = torch.zeros(
            batch_size,
            2,
            *out_prover["node_selected_logits"].shape[1:],
            dtype=out_prover["node_selected_logits"].dtype,
            device=out_prover.device,
        )
        decision_logits = torch.zeros(
            batch_size,
            2,
            *out_verifier["decision_logits"].shape[1:],
            dtype=out_verifier["decision_logits"].dtype,
            device=out_verifier.device,
        )

        # Inset the agents' action distributions into the combined action distribution
        # logits
        node_selected_logits[prover_turn, PROVER_AGENT_NUM] = out_prover[
            "node_selected_logits"
        ]
        node_selected_logits[~prover_turn, VERIFIER_AGENT_NUM] = out_verifier[
            "node_selected_logits"
        ]
        decision_logits[~prover_turn, VERIFIER_AGENT_NUM] = out_verifier[
            "decision_logits"
        ]

        return tensordict.update(
            dict(
                agents=TensorDict(
                    dict(
                        node_selected_logits=node_selected_logits,
                        decision_logits=decision_logits,
                    ),
                    batch_size=tensordict.batch_size,
                )
            )
        )

In [17]:
class GraphIsomorphismCombinedCriticHead(TensorDictModuleBase):
    in_keys = ("round", "node_level_repr", "graph_level_repr")
    out_keys = (("agents", "value"), )

    def __init__(
        self,
        prover_critic_head: GraphIsomorphismAgentCriticHead,
        verifier_critic_head: GraphIsomorphismAgentCriticHead,
    ):
        super().__init__()
        self.prover_critic_head = prover_critic_head
        self.verifier_critic_head = verifier_critic_head

    def forward(self, tensordict: TensorDictBase) -> TensorDict:
        round: Int[Tensor, "batch"] = tensordict["round"]
        batch_size = round.shape[0]

        # Compute the index of the agent whose turn it is
        prover_turn: Int[Tensor, "batch"] = (round % 2 == 1)

        # Build tensordicts to feed to each agent when it's their turn
        input_prover = TensorDict(
            {
                key: tensordict[key][prover_turn]
                for key in self.prover_critic_head.in_keys
            },
            batch_size=(prover_turn.sum().item()),
        )
        input_verifier = TensorDict(
            {
                key: tensordict[key][~prover_turn]
                for key in self.verifier_critic_head.in_keys
            },
            batch_size=((~prover_turn).sum().item()),
        )

        # Run the critic heads to obtain the probability distributions
        out_prover = self.prover_critic_head(input_prover)
        out_verifier = self.verifier_critic_head(input_verifier)

        # # If we don't have a cached value, or if the batch size is different, initialise
        # # a new one with zeros.
        # if self.value_cache is None or self.value_cache.shape[0] != batch_size:
        #     self.value_cache = torch.zeros(
        #         batch_size,
        #         2,
        #         dtype=out_prover.dtype,
        #         device=out_prover.device,
        #     )

        value = torch.zeros(
            batch_size,
            2,
            dtype=out_prover["value"].dtype,
            device=out_prover.device,
        )

        # Inset the agents' values when it's their turn
        value[prover_turn, PROVER_AGENT_NUM] = out_prover["value"]
        value[~prover_turn, VERIFIER_AGENT_NUM] = out_verifier["value"]

        return tensordict.update(
            dict(
                agents=TensorDict(
                    dict(value=value),
                    batch_size=tensordict.batch_size,
                )
            ),
        )

In [18]:
class CompositeCategoricalDistribution(CompositeDistribution):
    def __init__(self, **categorical_params):
        batch_size = None
        composite_params = {}
        cat_param_kwargs_names = ("logits", "probs")
        for name, param in categorical_params.items():
            for kwarg_name in cat_param_kwargs_names:
                if name.endswith("_" + kwarg_name):
                    composite_params[name[: -len(kwarg_name) - 1]] = {kwarg_name: param}

                    # Make sure all the categorical parameters have the same batch size
                    if batch_size is None:
                        batch_size = param.shape[0]
                    elif batch_size != param.shape[0]:
                        raise ValueError(
                            "All categorical parameters must have the same batch size."
                        )

        composite_params_td = TensorDict(composite_params, batch_size=batch_size)

        super().__init__(
            params=composite_params_td,
            distribution_map={key: Categorical for key in composite_params},
        )

In [19]:
prover_body = GraphIsomorphismAgentBody(params, "prover", device=device)
verifier_body = GraphIsomorphismAgentBody(params, "verifier", device=device)
prover_policy_head = GraphIsomorphismAgentPolicyHead(params, "prover", device=device)
verifier_policy_head = GraphIsomorphismAgentPolicyHead(
    params, "verifier", device=device
)
prover_critic_head = GraphIsomorphismAgentCriticHead(params, "prover", device=device)
verifier_critic_head = GraphIsomorphismAgentCriticHead(
    params, "verifier", device=device
)

In [20]:
body = GraphIsomorphismCombinedBody(prover_body, verifier_body)

In [21]:
policy_head = GraphIsomorphismCombinedPolicyHead(prover_policy_head, verifier_policy_head)

In [22]:
policy_prob_module = SafeProbabilisticModule(
    spec=env.action_spec,
    distribution_class=CompositeCategoricalDistribution,
    in_keys=dict(
        node_selected_logits=("agents", "node_selected_logits"),
        decision_logits=("agents", "decision_logits"),
    ),
    out_keys=env.action_keys,
)

In [23]:
critic_head = GraphIsomorphismCombinedCriticHead(
    prover_critic_head, verifier_critic_head
)

In [24]:
full_model = ActorCriticOperator(body, policy_head, critic_head)

In [25]:
full_model.get_policy_operator()(env.reset())

TensorDict(
    fields={
        adjacency: Tensor(shape=torch.Size([125, 2, 11, 11]), device=cpu, dtype=torch.int32, is_shared=False),
        agents: TensorDict(
            fields={
                decision_logits: Tensor(shape=torch.Size([125, 2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                node_selected_logits: Tensor(shape=torch.Size([125, 2, 22]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([125]),
            device=None,
            is_shared=False),
        done: Tensor(shape=torch.Size([125]), device=cpu, dtype=torch.bool, is_shared=False),
        graph_level_repr: Tensor(shape=torch.Size([125, 2, 16]), device=cpu, dtype=torch.float32, is_shared=False),
        message: Tensor(shape=torch.Size([125]), device=cpu, dtype=torch.int64, is_shared=False),
        node_level_repr: Tensor(shape=torch.Size([125, 2, 11, 16]), device=cpu, dtype=torch.float32, is_shared=False),
        node_mask: Tensor(shape=torch.Si

In [26]:
env.reset()

TensorDict(
    fields={
        adjacency: Tensor(shape=torch.Size([125, 2, 11, 11]), device=cpu, dtype=torch.int32, is_shared=False),
        done: Tensor(shape=torch.Size([125]), device=cpu, dtype=torch.bool, is_shared=False),
        message: Tensor(shape=torch.Size([125]), device=cpu, dtype=torch.int64, is_shared=False),
        node_mask: Tensor(shape=torch.Size([125, 2, 11]), device=cpu, dtype=torch.bool, is_shared=False),
        round: Tensor(shape=torch.Size([125]), device=cpu, dtype=torch.int32, is_shared=False),
        terminated: Tensor(shape=torch.Size([125]), device=cpu, dtype=torch.bool, is_shared=False),
        x: Tensor(shape=torch.Size([125, 2, 11, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        y: Tensor(shape=torch.Size([125, 1]), device=cpu, dtype=torch.int32, is_shared=False)},
    batch_size=torch.Size([125]),
    device=None,
    is_shared=False)

In [27]:
full_model.get_critic_operator()(env.reset())

TensorDict(
    fields={
        adjacency: Tensor(shape=torch.Size([125, 2, 11, 11]), device=cpu, dtype=torch.int32, is_shared=False),
        agents: TensorDict(
            fields={
                decision_logits: Tensor(shape=torch.Size([125, 2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                node_selected_logits: Tensor(shape=torch.Size([125, 2, 22]), device=cpu, dtype=torch.float32, is_shared=False),
                value: Tensor(shape=torch.Size([125, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([125]),
            device=None,
            is_shared=False),
        done: Tensor(shape=torch.Size([125]), device=cpu, dtype=torch.bool, is_shared=False),
        graph_level_repr: Tensor(shape=torch.Size([125, 2, 16]), device=cpu, dtype=torch.float32, is_shared=False),
        message: Tensor(shape=torch.Size([125]), device=cpu, dtype=torch.int64, is_shared=False),
        node_level_repr: Tensor(shape=torch.Size([

In [28]:
policy_head

GraphIsomorphismCombinedPolicyHead(
  (prover_policy_head): GraphIsomorphismAgentPolicyHead(
    (node_selector): TensorDictModule(
        module=Sequential(
          (0): Linear(in_features=16, out_features=16, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=16, out_features=1, bias=True)
          (3): Rearrange('batch pair node d_out -> batch (pair node) d_out')
        ),
        device=cpu,
        in_keys=['node_level_repr'],
        out_keys=['node_selected_logits'])
  )
  (verifier_policy_head): GraphIsomorphismAgentPolicyHead(
    (node_selector): TensorDictModule(
        module=Sequential(
          (0): Linear(in_features=16, out_features=16, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=16, out_features=1, bias=True)
          (3): Rearrange('batch pair node d_out -> batch (pair node) d_out')
        ),
        device=cpu,
        in_keys=['node_level_repr'],
        out_keys=['node_selected_logits'])
    (d

In [29]:
policy = ProbabilisticTensorDictSequential(full_model, policy_head, policy_prob_module)

In [30]:
policy(env.reset())

TensorDict(
    fields={
        adjacency: Tensor(shape=torch.Size([125, 2, 11, 11]), device=cpu, dtype=torch.int32, is_shared=False),
        agents: TensorDict(
            fields={
                decision_logits: Tensor(shape=torch.Size([125, 2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                node_selected_logits: Tensor(shape=torch.Size([125, 2, 22]), device=cpu, dtype=torch.float32, is_shared=False),
                value: Tensor(shape=torch.Size([125, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([125]),
            device=None,
            is_shared=False),
        decision: Tensor(shape=torch.Size([125, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([125]), device=cpu, dtype=torch.bool, is_shared=False),
        graph_level_repr: Tensor(shape=torch.Size([125, 2, 16]), device=cpu, dtype=torch.float32, is_shared=False),
        message: Tensor(shape=torch.Size([125]

## Replay buffer

In [32]:
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(
        params.ppo.frames_per_batch, device=device
    ),
    sampler=SamplerWithoutReplacement(),
    batch_size=params.ppo.minibatch_size,
)
replay_buffer

ReplayBuffer(storage=<torchrl.data.replay_buffers.storages.LazyTensorStorage object at 0x7f974327fed0>, sampler=<torchrl.data.replay_buffers.samplers.SamplerWithoutReplacement object at 0x7f974327e450>, writer=<torchrl.data.replay_buffers.writers.RoundRobinWriter object at 0x7f97749932d0>)