In [2]:
%load_ext autoreload
%autoreload 2

import sys; sys.path.append('../../')
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from tensordict import TensorDict
from torch.utils.data import DataLoader
from torchrl.envs import EnvBase

from rl4co.envs import TSPEnv
from rl4co.data.dataset import TensorDictCollate
from rl4co.models.nn.env_context import env_context
from rl4co.models.nn.env_embedding import env_init_embedding
from rl4co.models.nn.attention import LogitAttention
from rl4co.models.nn.utils import decode_probs, get_log_likelihood
from rl4co.models.zoo.mdam.encoder import GraphAttentionEncoder

  from .autonotebook import tqdm as notebook_tqdm


### ***1. Create environment***

Test on the TSP environment.

In [4]:
def random_policy(td):
    """Helper function to select a random action from available actions"""
    action = torch.multinomial(td["action_mask"].float(), 1).squeeze(-1)
    td.set("action", action)
    return td


def rollout(env, td, policy):
    """Helper function to rollout a policy"""
    actions = []
    while not td["done"].all():
        td = policy(td)
        actions.append(td["action"])
        td = env.step(td)["next"]
    actions = torch.stack(actions, dim=1)
    td.set("action", actions)
    return td

env = TSPEnv()

td = env.reset(batch_size=[5])
td = rollout(env, td, random_policy)
print(td)

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 20]), device=cpu, dtype=torch.int64, is_shared=False),
        action_mask: Tensor(shape=torch.Size([5, 20]), device=cpu, dtype=torch.bool, is_shared=False),
        current_node: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        first_node: Tensor(shape=torch.Size([5]), device=cpu, dtype=torch.int64, is_shared=False),
        i: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        locs: Tensor(shape=torch.Size([5, 20, 2]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=cpu,
    is_shared=False)


### ***2. Encoder***

In MDAM model, the encoder is different with our `GraphAttentionEncoder` which is shared with model models. Now I just ues the MDAM's model and make it works, later we can reconstruct it. 

Since the code is pretty long, I put it in the `rl4co/models/zoo/mdam/encoder.py`.

### ***3. Decoder***

For now the decoder is inside the policy.

### ***4. Policy***

This is the whole model combines the encoder and decoder. I didn't modify this policy a lot. For now, let's keep the decoder inside the policy.

Discussion points:
- Different with other model, the MDAM requiresl a series of tensordict input. Do this inside the model or outside the model
- Normalize for the function `get_the_log_p`;

TODO list:
- [x] init embedding for the MDAM model -> [DONE]: checked it's fine to use ours;
- [ ] split the decoder and the policy, the hard part is the decoder will call the encoder, so this may be a problem; actually not, the embedding in the decoder is used repeatly, so theoritically we don't need to call the encoder inside the decoder;


In [7]:
@dataclass
class PrecomputedCache:
    node_embeddings: torch.Tensor
    graph_context: torch.Tensor
    glimpse_key: torch.Tensor
    glimpse_val: torch.Tensor
    logit_key: torch.Tensor

class AttentionModelPolicy(nn.Module):
    def __init__(
        self,
        env: EnvBase,
        encoder: nn.Module = None,
        decoder: nn.Module = None,
        embedding_dim: int = 128,
        num_encode_layers: int = 3,
        num_heads: int = 8,
        num_paths: int = 5,
        eg_step_gap: int = 200,
        normalization: str = "batch",
        mask_inner: bool = True,
        force_flash_attn: bool = False,
        train_decode_type: str = "sampling",
        val_decode_type: str = "greedy",
        test_decode_type: str = "greedy",
        **unused_kw
    ):
        super(AttentionModelPolicy, self).__init__()
        if len(unused_kw) > 0: print(f"Unused kwargs: {unused_kw}")

        self.env = env
        self.init_embedding = env_init_embedding(
            self.env.name, {"embedding_dim": embedding_dim}
        )

        self.encoder = (
            GraphAttentionEncoder(
                num_heads=num_heads,
                embed_dim=embedding_dim,
                num_layers=num_encode_layers,
                normalization=normalization,
                force_flash_attn=force_flash_attn,
            )
            if encoder is None
            else encoder
        )

        self.train_decode_type = train_decode_type
        self.val_decode_type = val_decode_type
        self.test_decode_type = test_decode_type
        
        # [Decoder]
        self.context = [
            env_context(self.env.name, {"embedding_dim": embedding_dim}) for _ in range(self.num_path)
        ]
        self.logit_attention = LogitAttention(
            embedding_dim, 
            num_heads, 
            mask_inner=mask_inner,
            force_flash_attn=force_flash_attn,
        )

        # TODO: other features
        self.num_path = num_paths
        self.eg_step_gap = eg_step_gap

    def forward(
        self,
        td: TensorDict,
        phase: str = "train",
        return_actions: bool = False,
        **decoder_kwargs,
    ) -> TensorDict:
        # SECTION: Encode and get embeddings
        embedding = self.init_embedding(td)
        encoded_inputs, _, attn, V, h_old = self.encoder(embedding)

        # Get decode type depending on phase
        if decoder_kwargs.get("decode_type", None) is None:
            decoder_kwargs["decode_type"] = getattr(self, f"{phase}_decode_type")

        # SECTION: Decoder first step: calculate for the decoder divergence loss
        # Cost list and log likelihood list along with path
        output_list = []
        td_list = [self.env.reset(td) for i in range(self.num_paths)]
        for i in range(self.num_paths):  
            # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
            fixed = self._precompute(encoded_inputs, path_index=i)
            log_p, _ = self._get_log_p(fixed, td_list[i], i)

            # Collect output of step
            output_list.append(log_p[:, 0, :]) # TODO: for vrp, ignore the first one (depot)
            output_list[-1] = torch.max(output_list[-1], torch.ones(output_list[-1].shape, dtype=output_list[-1].dtype, device=outputs[-1].device) * (-1e9)) # for the kl loss

        if self.num_paths > 1: # TODO: add a check for the baseline
            kl_divergences = []
            for _i in range(self.num_paths):
                for _j in range(self.num_paths):
                    if _i==_j:
                        continue
                    kl_divergence = torch.sum(torch.exp(output_list[_i]) * (output_list[_i] - output_list[_j]), -1)
                    kl_divergences.append(kl_divergence)
            loss_kl_divergence = torch.stack(kl_divergences, 0).mean()

        # SECTION: Decoder rest step: calculate for other decoder divergence loss
        # Cost list and log likelihood list along with path
        reward_list = []; output_list = []; action_list = []; ll_list = []
        td_list = [self.env.reset(td) for _ in range(self.num_paths)]
        for i in range(self.num_paths):
            outputs, actions = [], []
            embeddings, _, attn, V, h_old = self.embedder(self._init_embed(td))
            fixed = self._precompute(embeddings, path_index=i)
            j = 0
            while not (self.shrink_size is None and td_list[i].all_finished()):
                if j > 1 and j % self.eg_step_gap == 0:
                    if not self.is_vrp:
                        mask_attn = mask ^ mask_first
                    else:
                        mask_attn = mask
                    embeddings, _ = self.embedder.change(attn, V, h_old, mask_attn, self.is_tsp)
                    fixed = self._precompute(embeddings, path_index=i)
                log_p, mask = self._get_log_p(fixed, td_list[i], i)
                if j == 0:
                    mask_first = mask

                # Select the indices of the next nodes in the sequences, result (batch_size) long
                action = decode_probs(log_p.exp()[:, 0, :], mask[:, 0, :], decode_type=decoder_kwargs["decode_type"])

                td_list[i].set("action", action)
                td_list[i] = self.env.step(td_list[i])["next"]

                # Collect output of step
                outputs.append(log_p[:, 0, :])
                actions.append(action)
                j += 1

            outputs, actions = torch.stack(outputs, 1), torch.stack(actions, 1)
            reward = self.env.get_reward(td, actions)
            ll = self.get_log_likelihood(outputs, actions, mask)

            reward_list.append(reward)
            output_list.append(outputs)
            action_list.append(actions)
            ll_list.append(ll)

        # SECTION: Policy output part
        out = {
            "reward": torch.stack(reward_list),
            "log_likelihood": torch.stack(ll_list),
            "kl_divergence": loss_kl_divergence,
            "actions": actions if return_actions else None,
        }
        return out

    def _precompute(self, embeddings, num_steps=1, path_index=None):
        ''' Decoder '''
        # The fixed context projection of the graph embedding is calculated only once for efficiency
        graph_embed = embeddings.mean(1)

        # Fixed context = (batch_size, 1, embed_dim) to make broadcastable with parallel timesteps
        fixed_context = self.project_fixed_context[path_index](graph_embed)[:, None, :]

        # The projection of the node embeddings for the attention is calculated once up front
        glimpse_key_fixed, glimpse_val_fixed, logit_key_fixed = \
            self.project_node_embeddings[path_index](embeddings[:, None, :, :]).chunk(3, dim=-1)

        fixed = PrecomputedCache(
            node_embeddings=embeddings,
            graph_context=fixed_context,
            glimpse_key=self._make_heads(glimpse_key_fixed, num_steps),
            glimpse_val=self._make_heads(glimpse_val_fixed, num_steps),
            logit_key=logit_key_fixed.contiguous(),
        )
        return fixed

    def _get_log_p(self, cached, td, path_idx, softmax_temp):
        step_context = self.context[path_idx](cached.node_embeddings, td)  # [batch, embed_dim]
        glimpse_q = (cached.graph_context + step_context).unsqueeze(1)  # [batch, 1, embed_dim]

        # Compute keys and values for the nodes
        (
            glimpse_key_dynamic,
            glimpse_val_dynamic,
            logit_key_dynamic,
        ) = self.dynamic_embedding(td)
        glimpse_k = cached.glimpse_key + glimpse_key_dynamic
        glimpse_v = cached.glimpse_val + glimpse_val_dynamic
        logit_k = cached.logit_key + logit_key_dynamic

        # Get the mask
        mask = ~td["action_mask"]
        
        # Compute log prob: MHA + single-head attention
        log_p, _ = self._one_to_many_logits(
            glimpse_q,
            glimpse_k,
            glimpse_v,
            logit_k,
            mask,
            path_idx
        )

        return log_p, mask

    def _one_to_many_logits(self, query, glimpse_K, glimpse_V, logit_K, mask, path_index):
        batch_size, num_steps, embed_dim = query.size()
        key_size = val_size = embed_dim // self.n_heads

        # Compute the glimpse, rearrange dimensions so the dimensions are (n_heads, batch_size, num_steps, 1, key_size)
        glimpse_Q = query.view(batch_size, num_steps, self.n_heads, 1, key_size).permute(2, 0, 1, 3, 4)

        # Batch matrix multiplication to compute compatibilities (n_heads, batch_size, num_steps, graph_size)
        compatibility = torch.matmul(glimpse_Q, glimpse_K.transpose(-2, -1)) / math.sqrt(glimpse_Q.size(-1))
        if self.mask_inner:
            assert self.mask_logits, "Cannot mask inner without masking logits"
            compatibility[mask[None, :, :, None, :].expand_as(compatibility)] = -math.inf

        # Batch matrix multiplication to compute heads (n_heads, batch_size, num_steps, val_size)
        heads = torch.matmul(F.softmax(compatibility, dim=-1), glimpse_V)

        # Project to get glimpse/updated context node embedding (batch_size, num_steps, embedding_dim)
        glimpse = self.project_out[path_index](
            heads.permute(1, 2, 3, 0, 4).contiguous().view(-1, num_steps, 1, self.n_heads * val_size))

        # Now projecting the glimpse is not needed since this can be absorbed into project_out
        # final_Q = self.project_glimpse(glimpse)
        final_Q = glimpse

        # Batch matrix multiplication to compute logits (batch_size, num_steps, graph_size)
        # logits = 'compatibility'
        logits = torch.matmul(final_Q, logit_K.transpose(-2, -1)).squeeze(-2) / math.sqrt(final_Q.size(-1))

        # From the logits compute the probabilities by clipping, masking and softmax
        if self.tanh_clipping > 0:
            logits = F.tanh(logits) * self.tanh_clipping
        if self.mask_logits:
            logits[mask] = -math.inf

        return logits, glimpse.squeeze(-2)

### ***5. Test: Test the Policy Only***

In [6]:
# Load the environment with test data
env = TSPEnv()

dataset = env.dataset(batch_size=[10000])

dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,
    num_workers=0,
    collate_fn=TensorDictCollate(),
)

td = next(iter(dataloader)).to("cuda")
td = env.reset(td)
print(td)

TensorDict(
    fields={
        action_mask: Tensor(shape=torch.Size([64, 20]), device=cuda:0, dtype=torch.bool, is_shared=True),
        current_node: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        done: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.bool, is_shared=True),
        first_node: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        i: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        locs: Tensor(shape=torch.Size([64, 20, 2]), device=cuda:0, dtype=torch.float32, is_shared=True),
        reward: Tensor(shape=torch.Size([64, 1]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([64]),
    device=cuda,
    is_shared=True)


In [8]:
policy = AttentionModelPolicy(
    env,
).to("cuda")
out = policy(td, decode_type="sampling", return_actions=False)
print(out)

AttributeError: 'AttentionModelPolicy' object has no attribute 'num_path'

### ***-1. Trash bin***

Other stuff

In [None]:
    # def _get_log_p(self, fixed, td, path_index, normalize=True):
    #     ''' Decoder '''
    #     # Compute query = context node embedding
    #     query = fixed.context_node_projected + \
    #         self.project_step_context[path_index](self._get_parallel_step_context(fixed.node_embeddings, td))

    #     # Compute keys and values for the nodes
    #     glimpse_K, glimpse_V, logit_K = self._get_attention_node_data(fixed, td)

    #     # Compute the mask
    #     mask = ~td["action_mask"]

    #     # Compute logits (unnormalized log_p)
    #     log_p, _ = self._one_to_many_logits(query, glimpse_K, glimpse_V, logit_K, mask, path_index)

    #     if normalize:
    #         log_p = F.log_softmax(log_p / self.temp, dim=-1)

    #     assert not torch.isnan(log_p).any()

    #     return log_p, mask
    
    # def _get_parallel_step_context(self, embeddings, td, from_depot=False):
    #     """
    #     Returns the context per step, optionally for multiple steps at once (for efficient evaluation of the model)
        
    #     :param embeddings: (batch_size, graph_size, embed_dim)
    #     :param prev_a: (batch_size, num_steps)
    #     :param first_a: Only used when num_steps = 1, action of first step or None if first step
    #     :return: (batch_size, num_steps, context_dim)
    #     """

    #     current_node = td["current_node"]
    #     batch_size, num_steps = current_node.size()

    #     if self.env.name == 'vrp':
    #         # Embedding of previous node + remaining capacity
    #         if from_depot:
    #             # 1st dimension is node idx, but we do not squeeze it since we want to insert step dimension
    #             # i.e. we actually want embeddings[:, 0, :][:, None, :] which is equivalent
    #             return torch.cat(
    #                 (
    #                     embeddings[:, 0:1, :].expand(batch_size, num_steps, embeddings.size(-1)),
    #                     # used capacity is 0 after visiting depot
    #                     self.problem.VEHICLE_CAPACITY - torch.zeros_like(td["used_capacity"][:, :, None])
    #                 ),
    #                 -1
    #             )
    #         else:
    #             return torch.cat(
    #                 (
    #                     torch.gather(
    #                         embeddings,
    #                         1,
    #                         current_node.contiguous()
    #                             .view(batch_size, num_steps, 1)
    #                             .expand(batch_size, num_steps, embeddings.size(-1))
    #                     ).view(batch_size, num_steps, embeddings.size(-1)),
    #                     self.problem.VEHICLE_CAPACITY - td["used_capacity"][:, :, None]
    #                 ),
    #                 -1
    #             )
    #     elif self.is_orienteering or self.env.name == "pctsp":
    #         return torch.cat(
    #             (
    #                 torch.gather(
    #                     embeddings,
    #                     1,
    #                     current_node.contiguous()
    #                         .view(batch_size, num_steps, 1)
    #                         .expand(batch_size, num_steps, embeddings.size(-1))
    #                 ).view(batch_size, num_steps, embeddings.size(-1)),
    #                 (
    #                     td["capacity_length"][:, :, None]
    #                     # TODO
    #                     # if self.is_orienteering
    #                     # else td[""].get_remaining_prize_to_collect()[:, :, None]
    #                 )
    #             ),
    #             -1
    #         )
    #     else:  # TSP
    #         if num_steps == 1:  # We need to special case if we have only 1 step, may be the first or not
    #             if state.i.item() == 0:
    #                 # First and only step, ignore prev_a (this is a placeholder)
    #                 return self.W_placeholder[None, None, :].expand(batch_size, 1, self.W_placeholder.size(-1))
    #             else:
    #                 return embeddings.gather(
    #                     1,
    #                     torch.cat((state.first_a, current_node), 1)[:, :, None].expand(batch_size, 2, embeddings.size(-1))
    #                 ).view(batch_size, 1, -1)
    #         # More than one step, assume always starting with first
    #         embeddings_per_step = embeddings.gather(
    #             1,
    #             current_node[:, 1:, None].expand(batch_size, num_steps - 1, embeddings.size(-1))
    #         )
    #         return torch.cat((
    #             # First step placeholder, cat in dim 1 (time steps)
    #             self.W_placeholder[None, None, :].expand(batch_size, 1, self.W_placeholder.size(-1)),
    #             # Second step, concatenate embedding of first with embedding of current/previous (in dim 2, context dim)
    #             torch.cat((
    #                 embeddings_per_step[:, 0:1, :].expand(batch_size, num_steps - 1, embeddings.size(-1)),
    #                 embeddings_per_step
    #             ), 2)
    #         ), 1)