# MatNet Model

In [11]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append(2*"../")

from einops import rearrange, repeat
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchrl.modules.models.models import MLP
from torchrl.envs import EnvBase
from tensordict.tensordict import TensorDict

from rl4co.envs import ATSPEnv 
from rl4co.utils.ops import batchify, unbatchify
from rl4co.models.nn.attention import LogitAttention
from rl4co.models.nn.utils import decode_probs, get_log_likelihood
from rl4co.models.nn.env_context import env_context
from rl4co.models.nn.env_embedding import env_dynamic_embedding, env_init_embedding
from rl4co.data.dataset import TensorDictCollate

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Differences between AM and MatNet

1. MatNet uses a dual graph attention layer for processing the  set of source and destination nodes A and B separately
2. Mixed-score attention: this should make the network learn the "best" recipe
3. Initial node representation: zero-vectors for A nodes and one-hot vectors for B nodes

In [2]:
env = ATSPEnv(num_loc=10)
env.reset()

TensorDict(
    fields={
        action_mask: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.bool, is_shared=False),
        cost_matrix: Tensor(shape=torch.Size([10, 10]), device=cpu, dtype=torch.float32, is_shared=False),
        current_node: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        first_node: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        i: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

In [3]:
# Test out

# col_emb.shape: (batch, col_cnt, embedding)
# row_emb.shape: (batch, row_cnt, embedding)
# cost_mat.shape: (batch, row_cnt, col_cnt)
batch = 64
row_cnt = 20
col_cnt = 30


model_params = {
    'embedding_dim': 256,
    'sqrt_embedding_dim': 256**(1/2),
    'encoder_layer_num': 5,
    'qkv_dim': 16,
    'sqrt_qkv_dim': 16**(1/2),
    'head_num': 16,
    'logit_clipping': 10,
    'ff_hidden_dim': 512,
    'ms_hidden_dim': 16,
    'ms_layer1_init': (1/2)**(1/2),
    'ms_layer2_init': (1/16)**(1/2),
    'eval_type': 'argmax',
    'one_hot_seed_cnt': 20,  # must be >= node_cnt
}


row_emb = torch.randn(batch, row_cnt, model_params['embedding_dim'])
col_emb = torch.randn(batch, col_cnt, model_params['embedding_dim'])
cost_mat = torch.randn(batch, row_cnt, col_cnt)

## Ours (all we need)

In [4]:
class MixedScoreMHA(nn.Module):
    def __init__(self, 
                    embed_dim,
                    num_heads,
                    hidden_dim: int = 16,
                    qkv_dim: int = 16,
                    bias=False,
                    layer1_init: float = (1/2)**(1/2),
                    layer2_init: float = (1/16)**(1/2),
                    device=None,
                    dtype=None
        ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        assert (embed_dim % num_heads == 0), "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        # Project
        self.Wq = nn.Linear(embed_dim, num_heads*qkv_dim, bias=bias, **factory_kwargs)
        self.Wk = nn.Linear(embed_dim, num_heads*qkv_dim, bias=bias, **factory_kwargs)
        self.Wv = nn.Linear(embed_dim, num_heads*qkv_dim, bias=bias, **factory_kwargs)
        self.out_proj = nn.Linear(num_heads*qkv_dim, embed_dim, **factory_kwargs)

        # Init mix params
        self.mix1_weight = nn.Parameter(torch.empty(num_heads, 2, hidden_dim).uniform_(-layer1_init, layer1_init))
        self.mix1_bias = nn.Parameter(torch.empty(num_heads, hidden_dim).uniform_(-layer1_init, layer1_init))
        self.mix2_weight = nn.Parameter(torch.empty(num_heads, hidden_dim, 1).uniform_(-layer2_init, layer2_init))
        self.mix2_bias = nn.Parameter(torch.empty(num_heads, 1).uniform_(-layer2_init, layer2_init))

    def forward(self, q, k, v, matrix):
        # Project q, k, v and reshape to [batch, head_num, row_cnt, hidden_dim]
        q, k, v = self.Wq(q), self.Wk(k), self.Wv(v)
        q, k, v = map(lambda t: self._reshape_heads(t), (q, k, v))

        # Prepare dot product and matrix score: [batch, head_num, row_cnt, col_cnt]
        dot_product = torch.einsum('...rd,...cd->...rc', q, k) / math.sqrt(q.shape[-1])
        matrix_score = repeat(matrix, 'b r c -> b h r c', h=self.num_heads)

        # Mix the scores. Use einsum for best performance
        two_scores = torch.stack((dot_product, matrix_score), dim=-1)
        ms1 = torch.einsum('bhrct,htd->brhcd', two_scores, self.mix1_weight)
        ms2 = torch.einsum('brhcd,hdt->brhct', F.relu(ms1), self.mix2_weight)
        mixed_scores = rearrange(ms2, 'b h r c 1 -> b r h c')

        # Softmax and multiply with values
        weights = F.softmax(mixed_scores, dim=3)
        out = torch.matmul(weights, v)
        
        # Project out
        out = rearrange(out, 'b h r d -> b r (h d)')
        return self.out_proj(out)
    
    def _reshape_heads(self, x):
        # same as rearrange(v, 'b r (h d) -> b h r d', h=self.num_heads) but faster
        return x.view(x.shape[0], x.shape[1], self.num_heads, -1).transpose(1, 2)
         
    
class AddAndInstanceNormalization(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.norm = nn.InstanceNorm1d(embedding_dim, affine=True, track_running_stats=False)

    def forward(self, input1, input2):
        # [batch, problem, embedding]
        added = input1 + input2
        normalized = self.norm(added.transpose(1, 2)).transpose(1, 2)
        return normalized
    
class EncodingBlock(nn.Module):
    def __init__(self,
                    embed_dim,
                    num_heads,
                    ms_hidden_dim=16,
                    ff_hidden_dim=512,
                    **mha_kwargs
                 ):
        super().__init__()
        self.mixed_score_mha = MixedScoreMHA(embed_dim, num_heads, ms_hidden_dim, **mha_kwargs)
        self.add_n_normalization_1 = AddAndInstanceNormalization(embed_dim)
        self.feed_forward = MLP(embed_dim, embed_dim, 1, ff_hidden_dim, activation_class=nn.ReLU)
        self.add_n_normalization_2 = AddAndInstanceNormalization(embed_dim)

    def forward(self, row_emb, col_emb, cost_mat):
        q, k, v = row_emb, col_emb, col_emb
        out_mha = self.mixed_score_mha(q, k, v, cost_mat)
        out1 = self.add_n_normalization_1(row_emb, out_mha)
        out2 = self.feed_forward(out1)
        out3 = self.add_n_normalization_2(out1, out2)
        return out3 # shape: (batch, row_cnt, embedding)
    

class EncoderLayer(nn.Module):
    def __init__(self, **kw):
        super().__init__()
        self.row_encoding_block = EncodingBlock(**kw)
        self.col_encoding_block = EncodingBlock(**kw)

    def forward(self, row_emb, col_emb, cost_mat):
        # row_emb.shape: (batch, row_cnt, embedding)
        # col_emb.shape: (batch, col_cnt, embedding)
        # cost_mat.shape: (batch, row_cnt, col_cnt)
        row_emb_out = self.row_encoding_block(row_emb, col_emb, cost_mat)
        col_emb_out = self.col_encoding_block(col_emb, row_emb, cost_mat.transpose(1, 2))
        return row_emb_out, col_emb_out
        

class MatNetEncoder(nn.Module):
    def __init__(self, num_layers, **kw):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer(**kw) for _ in range(num_layers)])

    def forward(self, row_emb, col_emb, cost_mat):
        for layer in self.layers:
            row_emb, col_emb = layer(row_emb, col_emb, cost_mat)
        return row_emb, col_emb

In [5]:
# encoder = MatNetEncoder(**model_params)

encoder = MatNetEncoder(num_layers=5, embed_dim=256, num_heads=16)
print(encoder)
print('Number of parameters: {:.2f} MB'.format(sum(p.numel() for p in encoder.parameters() if p.requires_grad) / 1e6))
out = encoder(row_emb, col_emb, cost_mat)
print(out[0].shape, out[1].shape)
#print number of parameters


MatNetEncoder(
  (layers): ModuleList(
    (0-4): 5 x EncoderLayer(
      (row_encoding_block): EncodingBlock(
        (mixed_score_mha): MixedScoreMHA(
          (Wq): Linear(in_features=256, out_features=256, bias=False)
          (Wk): Linear(in_features=256, out_features=256, bias=False)
          (Wv): Linear(in_features=256, out_features=256, bias=False)
          (out_proj): Linear(in_features=256, out_features=256, bias=True)
        )
        (add_n_normalization_1): AddAndInstanceNormalization(
          (norm): InstanceNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        )
        (feed_forward): MLP(
          (0): Linear(in_features=256, out_features=512, bias=True)
          (1): ReLU()
          (2): Linear(in_features=512, out_features=256, bias=True)
        )
        (add_n_normalization_2): AddAndInstanceNormalization(
          (norm): InstanceNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        )
     

In [6]:
%timeit encoder(row_emb, col_emb, cost_mat)

129 ms ± 1.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Create Decoder

In [7]:
def select_start_nodes(batch_size, num_nodes, device="cpu"):
    """Node selection strategy for POMO
    Selects different start nodes for each batch element
    """
    selected = torch.arange(num_nodes, device=device).repeat(batch_size)  # TODO: check
    return selected

In [8]:
@dataclass
class PrecomputedCache:
    row_embeddings: torch.Tensor
    column_embeddings: torch.Tensor
    graph_context: torch.Tensor # TODO: check if used in MatNet
    glimpse_key: torch.Tensor
    glimpse_val: torch.Tensor
    logit_key: torch.Tensor


class Decoder(nn.Module):
    def __init__(self, env, embedding_dim, num_heads, num_pomo=20, **logit_attn_kwargs):
        super(Decoder, self).__init__()

        self.env = env
        self.embedding_dim = embedding_dim
        self.n_heads = num_heads

        assert embedding_dim % num_heads == 0

        self.context = env_context(self.env.name, {"embedding_dim": embedding_dim})
        self.dynamic_embedding = env_dynamic_embedding(
            self.env.name, {"embedding_dim": embedding_dim}
        )

        # For each node we compute (glimpse key, glimpse value, logit key) so 3 * embedding_dim
        self.project_node_embeddings = nn.Linear(
            embedding_dim, 3 * embedding_dim, bias=False
        )
        self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=False)

        # MHA
        self.logit_attention = LogitAttention(
            embedding_dim, num_heads, **logit_attn_kwargs
        )

        # POMO
        self.num_pomo = max(num_pomo, 1) # POMO = 1 is just normal REINFORCE

    def forward(self, td, embeddings, decode_type="sampling"):
        # Collect outputs
        outputs = []
        actions = []

        if self.num_pomo > 1:
            # POMO: first action is decided via select_start_nodes
            action = select_start_nodes(batch_size=td.shape[0], num_nodes=self.num_pomo, device=td.device)

            # # Expand td to batch_size * num_pomo
            td = batchify(td, self.num_pomo)

            td.set("action", action[:, None])
            td = self.env.step(td)["next"]
            log_p = torch.zeros_like(td['action_mask'], device=td.device) # first log_p is 0, so p = log_p.exp() = 1

            outputs.append(log_p.squeeze(1))
            actions.append(action)
        
        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        cached_embeds = self._precompute(embeddings)        

        # Here we suppose all the batch is done at the same time
        while not td["done"].all():  
            # Compute the logits for the next node
            log_p, mask = self._get_log_p(cached_embeds, td)

            # Select the indices of the next nodes in the sequences, result (batch_size) long
            action = decode_probs(
                log_p.exp().squeeze(1), mask.squeeze(1), decode_type=decode_type
            )

            # Step the environment
            td.set("action", action[:, None])
            td = self.env.step(td)["next"]

            # Collect output of step
            outputs.append(log_p.squeeze(1))
            actions.append(action)

        outputs, actions = torch.stack(outputs, 1), torch.stack(actions, 1)
        td.set("reward", self.env.get_reward(td, actions))
        return outputs, actions, td
    
    def _precompute(self, embeddings):       
        # The projection of the node embeddings for the attention is calculated once up front
        (
            glimpse_key_fixed,
            glimpse_val_fixed,
            logit_key_fixed,
        ) = self.project_node_embeddings(embeddings[:, None, :, :]).chunk(3, dim=-1)

        # Organize in a dataclass for easy access
        cached_embeds = PrecomputedCache(
            node_embeddings=batchify(embeddings, self.num_pomo),
            glimpse_key=batchify(self.logit_attention._make_heads(glimpse_key_fixed), self.num_pomo),
            glimpse_val=batchify(self.logit_attention._make_heads(glimpse_val_fixed), self.num_pomo),
            logit_key=batchify(logit_key_fixed, self.num_pomo)
        )

        return cached_embeds

    def _get_log_p(self, cached, td):
        # Compute the query based on the context (computes automatically the first and last node context)
        step_context = self.context(cached.node_embeddings, td)
        query = step_context # in POMO, no graph context (trick for overfit) # [batch, 1, embed_dim] # TODO: check if this is the same as POMO

        # Compute keys and values for the nodes
        glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = self.dynamic_embedding(td)
        glimpse_key = cached.glimpse_key + glimpse_key_dynamic
        glimpse_key = cached.glimpse_val + glimpse_val_dynamic
        logit_key = cached.logit_key + logit_key_dynamic

        # Get the mask
        mask = ~td["action_mask"]
        mask = mask.unsqueeze(1) if mask.dim() == 2 else mask

        # Compute logits
        log_p = self.logit_attention(query, glimpse_key, glimpse_key, logit_key, mask)

        return log_p, mask

In [71]:
# problems = reset_state.problems
# # problems.shape: (batch, node, node)

# batch_size = problems.size(0)
# node_cnt = problems.size(1)
# embedding_dim = self.model_params['embedding_dim']

# row_emb = torch.zeros(size=(batch_size, node_cnt, embedding_dim))
# # emb.shape: (batch, node, embedding)
# col_emb = torch.zeros(size=(batch_size, node_cnt, embedding_dim))
# # shape: (batch, node, embedding)

# seed_cnt = self.model_params['one_hot_seed_cnt']
# rand = torch.rand(batch_size, seed_cnt)
# batch_rand_perm = rand.argsort(dim=1)
# rand_idx = batch_rand_perm[:, :node_cnt]

# b_idx = torch.arange(batch_size)[:, None].expand(batch_size, node_cnt)
# n_idx = torch.arange(node_cnt)[None, :].expand(batch_size, node_cnt)
# col_emb[b_idx, n_idx, rand_idx] = 


class MatNetInitEmbedding(nn.Module):
    def __init__(self, embedding_dim, one_hot_seed_cnt):
        super(MatNetInitEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        self.one_hot_seed_cnt = one_hot_seed_cnt

    def forward(self, td):
        # Generate initial embeddings: [batch, node, node]
        cost_mat = td["cost_matrix"]
        batch_size = cost_mat.size(0)
        node_cnt = cost_mat.size(1)
        row_emb = torch.zeros_like(cost_mat, device=cost_mat.device)
        col_emb = torch.zeros_like(cost_mat, device=cost_mat.device)
        # randomize col_emb: we refactor with topk
        rand = torch.rand(batch_size, self.one_hot_seed_cnt, device=cost_mat.device)
        _, rand_idx = rand.topk(node_cnt, dim=1, largest=False)
        b_idx, n_idx = torch.meshgrid(torch.arange(batch_size, device=cost_mat.device),
                                    torch.arange(node_cnt, device=cost_mat.device))
        col_emb[b_idx, n_idx, rand_idx] = 1
        print("col_emb", col_emb.shape)
        print("row_emb", row_emb.shape)
        return row_emb, col_emb

In [73]:
class MatNetPolicy(nn.Module):
    def __init__(
        self,
        env: EnvBase,
        encoder: nn.Module = None,
        decoder: nn.Module = None,
        embedding_dim: int = 128,
        num_pomo: int = 10,
        one_hot_seed_cnt: int = 20,
        num_encode_layers: int = 5,
        num_heads: int = 16,
        mask_inner: bool = True,
        train_decode_type: str = "sampling",
        val_decode_type: str = "greedy",
        test_decode_type: str = "greedy",
        **unused_kwargs
    ):
        super(MatNetPolicy, self).__init__()

        if len(unused_kwargs) > 0:
            print("Unused kwargs found in MatNetPolicy init: ", unused_kwargs)

        self.env = env

        self.init_embedding = MatNetInitEmbedding(embedding_dim, one_hot_seed_cnt)

        self.encoder = (
            MatNetEncoder(
                num_heads=num_heads,
                embed_dim=embedding_dim,
                num_layers=num_encode_layers,
            )
            if encoder is None
            else encoder
        )

        self.decoder = (
            Decoder(
                env,
                embedding_dim,
                num_heads,
                num_pomo=num_pomo,
                mask_inner=mask_inner,
            )
            if decoder is None
            else decoder
        )
        self.num_pomo = num_pomo
        self.train_decode_type = train_decode_type
        self.val_decode_type = val_decode_type
        self.test_decode_type = test_decode_type

    def forward(
        self,
        td: TensorDict,
        phase: str = "train",
        return_actions: bool = False,
        **decoder_kwargs,
    ) -> TensorDict:
        """Given observation, precompute embeddings and rollout"""

        # Set decoding type for policy, can be also greedy
        row_emb, col_emb = self.init_embedding(td)
        encoded_inputs = self.encoder(row_emb, col_emb, td["cost_matrix"])

        # Get decode type depending on phase
        if decoder_kwargs.get("decode_type", None) is None:
            decoder_kwargs["decode_type"] = getattr(self, f"{phase}_decode_type")

        # Main rollout
        log_p, actions, td = self.decoder(td, encoded_inputs, **decoder_kwargs)

        # Log likelyhood is calculated within the model since returning it per action does not work well with
        ll = get_log_likelihood(log_p, actions, td.get("mask", None))
        out = {
            "reward": td["reward"],
            "log_likelihood": ll,
            "actions": actions if return_actions else None,
        }

        return out


## Test environment

In [74]:
env = ATSPEnv(num_loc=20)
env.name = "tsp" # TODO: make this automatic when creating env

dataset = env.dataset(batch_size=[10000])

dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,  # no need to shuffle, we're resampling every epoch
    num_workers=0,
    collate_fn=TensorDictCollate(),
)

policy = MatNetPolicy(
    env,
).to("cuda")

# model = torch.compile(model)

td = next(iter(dataloader)).to("cuda")
td = env.reset(td)

out = policy(td, decode_type="sampling")

print(out)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (1280x20 and 128x256)