# Graph isomorphism PPO

## Setup

In [46]:
FORCE_CPU = True

SEED = 349287

BATCH_SIZE = 32

In [47]:
from typing import Optional

import torch
from torch import Tensor
from torch.utils.data import DataLoader

from tensordict.nn import TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.data.tensor_specs import (
    CompositeSpec,
    DiscreteTensorSpec,
    BinaryDiscreteTensorSpec,
    MultiDiscreteTensorSpec,
    TensorSpec,
    UnboundedContinuousTensorSpec,
    Box,
)
from torchrl.envs import EnvBase
from torchrl.envs.utils import check_env_specs
from torchrl.modules import ProbabilisticActor
from torchrl.objectives import ClipPPOLoss, ValueEstimators

from torch_geometric.loader import DataLoader as GeometricDataLoader
from torch_geometric.data import (
    Batch as GeometricBatch,
    Data as GeometricData
)

from jaxtyping import Float, Int, Bool

from pvg.graph_isomorphism import (
    GraphIsomorphismProver,
    GraphIsomorphismVerifier,
    GraphIsomorphismScenario,
)
from pvg.parameters import Parameters
from pvg.graph_isomorphism.data import GraphIsomorphismData, GraphIsomorphismDataset
from pvg.utils.data import gi_data_to_tensordict

In [48]:
torch.manual_seed(SEED)
torch_generator = torch.Generator().manual_seed(SEED)

In [49]:
if not FORCE_CPU and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


## Specification

In [50]:
class AdjacencyMatrixBox(Box):
    """An abstract representation of the space of adjacency matrices."""

    def __init__(self, max_num_nodes: int):
        self.max_num_nodes = max_num_nodes

In [51]:
class AdjacencyMatrixSpec(TensorSpec):
    def __init__(
        self,
        max_num_nodes: int,
        shape: torch.Size | None = None,
        device: Optional[torch.device | str | int] = None,
        dtype: str | torch.dtype = torch.int32,
    ):
        self.max_num_nodes = max_num_nodes
        self.device = device
        self.dtype = dtype

        if shape is None:
            self.shape = torch.Size([max_num_nodes, max_num_nodes])
        else:
            if shape[-2:] != (max_num_nodes, max_num_nodes):
                raise ValueError(
                    f"The last two dimensions of the shape must be {max_num_nodes}. "
                    f"Got {shape[-2:]}."
                )
            self.shape = torch.Size(shape)
            
        self.space = AdjacencyMatrixBox(max_num_nodes)

    def is_in(self, val: torch.Tensor) -> bool:
        """Check if a value is a valid adjacency matrix."""

        # Basic type checks
        if not isinstance(val, torch.Tensor):
            return False
        if val.shape[-2:] != (self.max_num_nodes, self.max_num_nodes):
            return False
        if val.dtype != self.dtype:
            return False

        # Make sure the values are either 0 or 1
        if not torch.all(torch.isin(val, torch.tensor([0, 1], device=self.device))):
            return False

        # Make sure the matrix is symmetric
        if val.transpose(-1, -2) != val:
            return False

        # Make sure the diagonal is all zeros
        if not torch.all(torch.isin(torch.diagonal(val), 0)):
            return False

        return True

    def rand(self, shape: Optional[list[int] | torch.Size] = None) -> torch.Tensor:
        """Generate a random 1/2 Erdos-Renyi adjacency matrix."""

        if shape is None:
            shape = shape = torch.Size([])

        adjacency_values = torch.rand(*shape, *self.shape, device=device)
        adjacency = (adjacency_values < 0.5).to(self.dtype)
        adjacency = adjacency.triu(diagonal=1)
        adjacency += adjacency.transpose(1, 2).clone()

        return adjacency
    
    def _project(self, val: torch.Tensor) -> torch.Tensor:
        """Project a value to the space of valid adjacency matrices."""

        # Symmetrize the matrix
        val = (val + val.transpose(1, 2)) / 2

        # Make sure the diagonal is all zeros
        val[..., torch.arange(self.max_num_nodes), torch.arange(self.max_num_nodes)] = 0

        # Make sure the values are either 0 or 1
        return torch.clamp(torch.round(val), min=0, max=1).to(self.dtype)

## Environment

In [52]:
def forgetful_cycle(iterable):
    """A version of cycle that doesn't save copies of the values"""
    while True:
        for i in iterable:
            yield i

In [53]:
class VariableGeometricDataCycler:
    """A loader that cycles through geometric data, but allows the batch size to vary.

    Parameters
    ----------
    dataloader : GeometricDataLoader
        The base dataloader to use. This dataloader will be cycled through.
    """

    def __init__(self, dataloader: GeometricDataLoader):
        self.dataloader = dataloader
        self.dataloader_iter = iter(forgetful_cycle(self.dataloader))
        self.remainder: Optional[Tensor] = None

    def get_batch(self, batch_size: int) -> GeometricBatch:
        """Get a batch of data from the dataloader with the given batch size.

        If the dataloader is exhausted, it will be reset.

        Parameters
        ----------
        batch_size : int
            The size of the batch to return.

        Returns
        -------
        batch : Tensor
            A batch of data with the given batch size.
        """

        left_to_sample = batch_size
        batch_components = []

        # Start by sampling from the remainder from the previous sampling
        if self.remainder is not None:
            batch_components.extend(self.remainder[:left_to_sample])
            if self.remainder.shape[0] <= left_to_sample:
                left_to_sample -= self.remainder.shape[0]
                self.remainder = None
            else:
                self.remainder = self.remainder[left_to_sample:]
                left_to_sample = 0

        # Keep sampling batches until we have enough
        while left_to_sample > 0:
            batch = next(self.dataloader_iter)
            batch_components.extend(batch[:left_to_sample])
            if len(batch) <= left_to_sample:
                left_to_sample -= len(batch)
            else:
                self.remainder = batch[left_to_sample:]
                left_to_sample = 0

        # Concatenate the batch components into a single batch
        batch = GeometricBatch.from_data_list(
            batch_components, follow_batch=self.dataloader.follow_batch
        )
        return batch

In [54]:
from typing import Optional


class GraphIsomorphismEnv(EnvBase):
    def __init__(self, params: Parameters, device: torch.device | str = device):
        super().__init__(device=device)
        self.params = params

        # Create a random number generator
        self.rng = torch.Generator(device=device)

        # Load the dataset
        self.dataset = GraphIsomorphismDataset(params)
        self.data_cycler: Optional[VariableGeometricDataCycler] = None

        # Compute the maximum number of nodes in the dataset
        self.max_num_nodes = 0
        for data in self.dataset:
            self.max_num_nodes = max(
                self.max_num_nodes, data.x_a.shape[0], data.x_b.shape[0]
            )

        # Set the environment shape to the batch size
        self.batch_size = (params.batch_size,)

        # The spec for the observation space: agents see the adjacency matrix and the
        # messages sent so far
        self.observation_spec = CompositeSpec(
            adjacency=AdjacencyMatrixSpec(
                self.max_num_nodes,
                shape=(params.batch_size, 2, self.max_num_nodes, self.max_num_nodes),
            ),
            message=BinaryDiscreteTensorSpec(
                params.max_message_rounds,
                shape=(
                    params.batch_size,
                    2,
                    self.max_num_nodes,
                    params.max_message_rounds,
                ),
            ),
            shape=(params.batch_size,),
        )

        # The spec for the state space: the true label and the round number
        self.state_spec = CompositeSpec(
            y=BinaryDiscreteTensorSpec(1, shape=(params.batch_size, 1)),
            round=DiscreteTensorSpec(
                params.max_message_rounds, shape=(params.batch_size,)
            ),
            shape=(params.batch_size,),
        )

        # The action space has shape (batch_size, num_agents, num_actions). Each agent
        # chooses both a node and a decision (accept, reject or continue). We add a
        # dummy node at the beginning (shifting the indices by 1) to represent the case
        # where the agent does not choose any node (when it's not their turn).
        self.action_spec = CompositeSpec(
            agents=CompositeSpec(
                action=MultiDiscreteTensorSpec(
                    [self.max_num_nodes + 1, 3], shape=(params.batch_size, 2, 2)
                ),
                shape=(params.batch_size,),
            ),
            shape=(params.batch_size,),
        )

        self.reward_spec = CompositeSpec(
            agents=CompositeSpec(
                reward=UnboundedContinuousTensorSpec(shape=(params.batch_size, 2)),
                shape=(params.batch_size,),
            ),
            shape=(params.batch_size,),
        )

        self.done_spec = CompositeSpec(
            done=BinaryDiscreteTensorSpec(
                params.batch_size, shape=(params.batch_size,), dtype=torch.bool
            ),
            shape=(params.batch_size,),
        )

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        """Perform a step in the environment."""
        # Extract the tensors from the dict
        y: Int[Tensor, "batch"] = tensordict["y"]
        message: Int[Tensor, "batch graph node message_round"] = tensordict["message"]
        round: Int[Tensor, "batch"] = tensordict["round"]
        action: Int[Tensor, "batch agent action"] = tensordict["agents", "action"]
        done: Bool[Tensor, "batch"] = tensordict["done"]
        agent_index = round % 2
        node_selected = action[..., 0]
        decision = action[..., 1]

        # Determine which graph contains the selected node and which node it is there
        # (batch agent)
        which_graph = node_selected < self.max_num_nodes
        # (batch agent)
        graph_node = torch.where(
            which_graph, node_selected, node_selected - self.max_num_nodes
        )

        # Write the node selected by the agent whose turn it is as a (one-hot) message
        message[
            agent_index,
            which_graph[torch.arange(which_graph.shape[0]), agent_index].int(),
            graph_node[torch.arange(which_graph.shape[0]), agent_index],
            round,
        ] = 1

        # If the verifier has made a guess, compute the reward and terminate the episode
        verifier_decision_made = (agent_index == 0) & (decision[:, 0] != 0)
        done = done | verifier_decision_made
        reward_verifier = (done & decision[:, 0] == y).float()
        reward_verifier = reward_verifier * self.params.verifier_reward
        reward_prover = (done & decision[:, 0] == 1).float()
        reward_prover = reward_prover * self.params.prover_reward

        # If we reach the end of the episode and the verifier has not made a guess,
        # terminate it with a negative reward for the verifier
        done = done | (round >= self.params.max_message_rounds - 1)
        reward_verifier[
            (round >= self.params.max_message_rounds - 1) & ~verifier_decision_made
        ] = self.params.verifier_terminated_penalty

        # Stack the rewards for the two agents
        reward = torch.stack([reward_verifier, reward_prover], dim=-1)

        # Put everything together
        out = TensorDict(
            adjacency=tensordict["adjacency"],
            message=message,
            y=y,
            round=round + 1,
            done=done,
            reward=reward,
        )
        return out

    def _reset(self, tensordict: Optional[TensorDictBase] = None) -> TensorDictBase:
        """(Partially) reset the environment.

        For each episode which is done, takes a new sample from the dataset and resets
        the episode.
        """

        # If no tensordict is given, we're starting afresh
        if tensordict is None:
            tensordict = TensorDict(
                dict(
                    adjacency=torch.empty(
                        *self.batch_size,
                        2,
                        self.max_num_nodes,
                        self.max_num_nodes,
                        device=self.device,
                        dtype=torch.int,
                    ),
                    message=torch.empty(
                        *self.batch_size,
                        2,
                        self.max_num_nodes,
                        self.params.max_message_rounds,
                        device=self.device,
                        dtype=torch.int,
                    ),
                    y=torch.empty(
                        *self.batch_size,
                        1,
                        device=self.device,
                        dtype=torch.int,
                    ),
                    round=torch.empty(
                        *self.batch_size,
                        device=self.device,
                        dtype=torch.int,
                    ),
                    done=torch.empty(
                        *self.batch_size,
                        device=self.device,
                        dtype=torch.bool,
                    ),
                ),
                batch_size=self.batch_size,
            )

            new_mask = torch.ones(
                *self.batch_size, dtype=torch.bool, device=self.device
            )

        else:
            new_mask = tensordict["done"]

        # If we don't have a data cycler yet, create one
        if self.data_cycler is None:
            dataloader = GeometricDataLoader(
                self.dataset,
                batch_size=self.params.batch_size,
                follow_batch=["x_a", "x_b"],
                shuffle=True,
                generator=self.rng,
            )
            self.data_cycler = VariableGeometricDataCycler(dataloader)

        # Sample a new batch of data for the episodes that are done
        batch = self.data_cycler.get_batch(new_mask.sum().item())
        batch_tensordict = gi_data_to_tensordict(
            batch, node_dim_size=self.max_num_nodes
        )

        # Copy the new data into the output
        tensordict["adjacency"][new_mask] = batch_tensordict["adjacency"]
        tensordict["message"][new_mask] = torch.zeros_like(
            tensordict["message"][new_mask]
        )
        tensordict["y"][new_mask] = (batch.wl_score == -1).int().unsqueeze(-1)
        tensordict["round"][new_mask] = 0
        tensordict["done"][new_mask] = False

        return tensordict

    def _set_seed(self, seed: int | None):
        self.rng.manual_seed(seed)

## Test environment

In [55]:
params = Parameters(
    scenario="graph_isomorphism",
    trainer="ppo",
    dataset="eru10000",
)

In [56]:
env = GraphIsomorphismEnv(params)

In [57]:
check_env_specs(env)

RuntimeError: stack expects each tensor to be equal size, but got [64, 64] at entry 0 and [64] at entry 1