# MatNet Model

In [39]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append(2*"../")

from einops import rearrange, repeat
import math
from dataclasses import dataclass
from omegaconf import OmegaConf, DictConfig

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchrl.modules.models.models import MLP
from torchrl.envs import EnvBase
from tensordict.tensordict import TensorDict

from rl4co.envs import ATSPEnv 
from rl4co.utils.ops import batchify, unbatchify, select_start_nodes
from rl4co.models.nn.attention import LogitAttention
from rl4co.models.nn.utils import decode_probs, get_log_likelihood
from rl4co.models.nn.env_context import env_context
from rl4co.models.nn.env_embedding import env_dynamic_embedding, env_init_embedding
from rl4co.data.dataset import tensordict_collate_fn
from rl4co.tasks.rl4co import RL4COLitModule

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Differences between AM and MatNet

1. MatNet uses a dual graph attention layer for processing the  set of source and destination nodes A and B separately
2. Mixed-score attention: this should make the network learn the "best" recipe
3. Initial node representation: zero-vectors for A nodes and one-hot vectors for B nodes

In [2]:
env = ATSPEnv(num_loc=10)
env.reset()

TensorDict(
    fields={
        action_mask: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.bool, is_shared=False),
        cost_matrix: Tensor(shape=torch.Size([10, 10]), device=cpu, dtype=torch.float32, is_shared=False),
        current_node: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        first_node: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        i: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

In [3]:
# Test out

# col_emb.shape: (batch, col_cnt, embedding)
# row_emb.shape: (batch, row_cnt, embedding)
# cost_mat.shape: (batch, row_cnt, col_cnt)
batch = 64
row_cnt = 20
col_cnt = 30


model_params = {
    'embedding_dim': 256,
    'sqrt_embedding_dim': 256**(1/2),
    'encoder_layer_num': 5,
    'qkv_dim': 16,
    'sqrt_qkv_dim': 16**(1/2),
    'head_num': 16,
    'logit_clipping': 10,
    'ff_hidden_dim': 512,
    'ms_hidden_dim': 16,
    'ms_layer1_init': (1/2)**(1/2),
    'ms_layer2_init': (1/16)**(1/2),
    'eval_type': 'argmax',
    'one_hot_seed_cnt': 20,  # must be >= node_cnt
}


row_emb = torch.randn(batch, row_cnt, model_params['embedding_dim'])
col_emb = torch.randn(batch, col_cnt, model_params['embedding_dim'])
cost_mat = torch.randn(batch, row_cnt, col_cnt)

## Ours (all we need)

In [4]:
class MixedScoreMHA(nn.Module):
    def __init__(self, 
                    embed_dim,
                    num_heads,
                    hidden_dim: int = 16,
                    qkv_dim: int = 16,
                    bias=False,
                    layer1_init: float = (1/2)**(1/2),
                    layer2_init: float = (1/16)**(1/2),
                    device=None,
                    dtype=None
        ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        assert (embed_dim % num_heads == 0), "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        # Project
        self.Wq = nn.Linear(embed_dim, num_heads*qkv_dim, bias=bias, **factory_kwargs)
        self.Wk = nn.Linear(embed_dim, num_heads*qkv_dim, bias=bias, **factory_kwargs)
        self.Wv = nn.Linear(embed_dim, num_heads*qkv_dim, bias=bias, **factory_kwargs)
        self.out_proj = nn.Linear(num_heads*qkv_dim, embed_dim, **factory_kwargs)

        # Init mix params
        self.mix1_weight = nn.Parameter(torch.empty(num_heads, 2, hidden_dim).uniform_(-layer1_init, layer1_init))
        self.mix1_bias = nn.Parameter(torch.empty(num_heads, hidden_dim).uniform_(-layer1_init, layer1_init))
        self.mix2_weight = nn.Parameter(torch.empty(num_heads, hidden_dim, 1).uniform_(-layer2_init, layer2_init))
        self.mix2_bias = nn.Parameter(torch.empty(num_heads, 1).uniform_(-layer2_init, layer2_init))

    def forward(self, q, k, v, matrix):
        # Project q, k, v and reshape to [batch, head_num, row_cnt, hidden_dim]
        # q, k, v = self.Wq(q), self.Wk(k), self.Wv(v)
        # q, k, v = map(lambda t: self._reshape_heads(t), (q, k, v))
        q = self._make_heads(self.Wq(q))
        k = self._make_heads(self.Wk(k))
        v = self._make_heads(self.Wv(v))

        # Prepare dot product and matrix score: [batch, head_num, row_cnt, col_cnt]
        dot_product = torch.einsum('...rd,...cd->...rc', q, k) / math.sqrt(q.shape[-1])
        matrix_score = repeat(matrix, 'b r c -> b h r c', h=self.num_heads)

        # Mix the scores. Use einsum for best performance
        two_scores = torch.stack((dot_product, matrix_score), dim=-1)
        ms1 = torch.einsum('bhrct,htd->brhcd', two_scores, self.mix1_weight)
        ms2 = torch.einsum('brhcd,hdt->brhct', F.relu(ms1), self.mix2_weight)
        mixed_scores = rearrange(ms2, 'b h r c 1 -> b r h c')

        # Softmax and multiply with values
        weights = F.softmax(mixed_scores, dim=3)
        out = torch.matmul(weights, v)
        
        # Project out
        out = rearrange(out, 'b h r d -> b r (h d)')
        return self.out_proj(out)

    def _make_heads(self, x):
        return rearrange(x, "b g (h s) -> b h g s", h=self.num_heads)
         
    
class AddAndInstanceNormalization(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.norm = nn.InstanceNorm1d(embedding_dim, affine=True, track_running_stats=False)

    def forward(self, input1, input2):
        # [batch, problem, embedding]
        added = input1 + input2
        normalized = self.norm(added.transpose(1, 2)).transpose(1, 2)
        return normalized
    
class EncodingBlock(nn.Module):
    def __init__(self,
                    embed_dim,
                    num_heads,
                    ms_hidden_dim=16,
                    ff_hidden_dim=512,
                    **mha_kwargs
                 ):
        super().__init__()
        self.mixed_score_mha = MixedScoreMHA(embed_dim, num_heads, ms_hidden_dim, **mha_kwargs)
        self.add_n_normalization_1 = AddAndInstanceNormalization(embed_dim)
        self.feed_forward = MLP(embed_dim, embed_dim, 1, ff_hidden_dim, activation_class=nn.ReLU)
        self.add_n_normalization_2 = AddAndInstanceNormalization(embed_dim)

    def forward(self, row_emb, col_emb, cost_mat):
        q, k, v = row_emb, col_emb, col_emb
        out_mha = self.mixed_score_mha(q, k, v, cost_mat)
        out1 = self.add_n_normalization_1(row_emb, out_mha)
        out2 = self.feed_forward(out1)
        out3 = self.add_n_normalization_2(out1, out2)
        return out3 # shape: (batch, row_cnt, embedding)
    

class EncoderLayer(nn.Module):
    def __init__(self, **kw):
        super().__init__()
        self.row_encoding_block = EncodingBlock(**kw)
        self.col_encoding_block = EncodingBlock(**kw)

    def forward(self, row_emb, col_emb, cost_mat):
        # row_emb.shape: (batch, row_cnt, embedding)
        # col_emb.shape: (batch, col_cnt, embedding)
        # cost_mat.shape: (batch, row_cnt, col_cnt)
        row_emb_out = self.row_encoding_block(row_emb, col_emb, cost_mat)
        col_emb_out = self.col_encoding_block(col_emb, row_emb, cost_mat.transpose(1, 2))
        return row_emb_out, col_emb_out
        

class MatNetEncoder(nn.Module):
    def __init__(self, num_layers, **kw):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer(**kw) for _ in range(num_layers)])

    def forward(self, row_emb, col_emb, cost_mat):
        for layer in self.layers:
            row_emb, col_emb = layer(row_emb, col_emb, cost_mat)
        return row_emb, col_emb

In [5]:
# encoder = MatNetEncoder(**model_params)

encoder = MatNetEncoder(num_layers=5, embed_dim=256, num_heads=16)
print('Number of parameters: {:.2f} MB'.format(sum(p.numel() for p in encoder.parameters() if p.requires_grad) / 1e6))
out = encoder(row_emb, col_emb, cost_mat)
print(out[0].shape, out[1].shape)
#print number of parameters


Number of parameters: 5.27 MB
torch.Size([64, 20, 256]) torch.Size([64, 30, 256])


## Create Decoder

In [24]:
@dataclass
class PrecomputedCache:
    row_embeddings: torch.Tensor
    # graph_context: torch.Tensor # TODO: check if used in MatNet
    glimpse_key: torch.Tensor
    glimpse_val: torch.Tensor
    logit_key: torch.Tensor


class Decoder(nn.Module):
    def __init__(self, env, embedding_dim, num_heads, num_starts=20, **logit_attn_kwargs):
        super(Decoder, self).__init__()

        self.env = env
        self.embedding_dim = embedding_dim

        assert embedding_dim % num_heads == 0

        self.context = env_context(self.env.name, {"embedding_dim": embedding_dim})
        self.dynamic_embedding = env_dynamic_embedding(
            self.env.name, {"embedding_dim": embedding_dim}
        )

        # For each node we compute (glimpse key, glimpse value ) so 2 * embedding_dim
        # In original implementation, this is separated but this is the same
        # Note that compared to original AM, we do not project the logit key
        self.project_node_embeddings = nn.Linear(
            embedding_dim, 2 * embedding_dim, bias=False
        )
    
        self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=False)

        # MHA
        self.logit_attention = LogitAttention(
            embedding_dim, num_heads, **logit_attn_kwargs
        )

        # POMO
        self.num_starts = max(num_starts, 1) # POMO = 1 is just normal REINFORCE

    def forward(
        self,
        td,
        embeddings,
        decode_type="sampling",
        softmax_temp=None,
        single_traj=False,
        num_starts=None,
    ):

        # Greedy multi-start decoding if num_starts > 1
        num_starts = (
            self.num_starts if num_starts is None else num_starts
        )  # substitute self.num_starts with num_starts
        assert not (
            "multistart" in decode_type and num_starts <= 1
        ), "Multi-start decoding requires `num_starts` > 1"

        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        cached_embeds = self._precompute(embeddings)

        # Collect outputs
        outputs = []
        actions = []

        if num_starts > 1 and not single_traj or "multistart" in decode_type:
            # POMO: first action is decided via select_start_nodes
            action = select_start_nodes(td, num_starts, self.env)

            # Expand td to batch_size * num_starts
            td = batchify(td, num_starts)

            td.set("action", action)
            td = self.env.step(td)["next"]
            log_p = torch.zeros_like(
                td["action_mask"], device=td.device
            )  # first log_p is 0, so p = log_p.exp() = 1

            outputs.append(log_p)
            actions.append(action)

        # Main decoding
        while not td["done"].all():

            log_p, mask = self._get_log_p(cached_embeds, td, softmax_temp, num_starts)

            # Select the indices of the next nodes in the sequences, result (batch_size) long
            action = decode_probs(log_p.exp(), mask, decode_type=decode_type)

            td.set("action", action)
            td = self.env.step(td)["next"]

            # Collect output of step
            outputs.append(log_p)
            actions.append(action)

        outputs, actions = torch.stack(outputs, 1), torch.stack(actions, 1)
        td.set("reward", self.env.get_reward(td, actions))
        return outputs, actions, td
    
    def _precompute(self, embeddings):
        # The projection of the node embeddings for the attention is calculated once up front
        row_embed, col_embed = embeddings
    
        (
            glimpse_key_fixed,
            glimpse_val_fixed,
        ) = self.project_node_embeddings(col_embed).chunk(2, dim=-1)
        
        # Organize in a dataclass for easy access
        cached_embeds = PrecomputedCache(
            row_embeddings=row_embed,
            glimpse_key=glimpse_key_fixed,
            glimpse_val=glimpse_val_fixed,
            logit_key=col_embed,
        )

        return cached_embeds
    

    def _get_log_p(self, cached, td, softmax_temp=None, num_starts=0):
        # Compute the query based on the context (computes automatically the first and last node context)

        # Unbatchify to [batch_size, num_starts, ...]. Has no effect if num_starts = 0
        td_unbatch = unbatchify(td, num_starts)

        step_context = self.context(cached.row_embeddings, td_unbatch)
        glimpse_q = step_context  # in POMO, no graph context is used to compute query
        glimpse_q = glimpse_q.unsqueeze(1) if glimpse_q.ndim == 2 else glimpse_q

        # Compute keys and values for the nodes
        (
            glimpse_key_dynamic,
            glimpse_val_dynamic,
            logit_key_dynamic,
        ) = self.dynamic_embedding(td_unbatch)
        glimpse_k = cached.glimpse_key + glimpse_key_dynamic
        glimpse_v = cached.glimpse_val + glimpse_val_dynamic
        logit_k = cached.logit_key + logit_key_dynamic

        # Get the mask
        mask = ~td_unbatch["action_mask"]

        # Compute logits
        log_p = self.logit_attention(
            glimpse_q, glimpse_k, glimpse_v, logit_k, mask, softmax_temp
        )

        # Now we need to reshape the logits and log_p to [batch_size*num_starts, num_nodes]
        # Note that rearranging order is important here
        log_p = rearrange(log_p, "b s l -> (s b) l") if num_starts > 1 else log_p
        mask = rearrange(mask, "b s l -> (s b) l") if num_starts > 1 else mask
        return log_p, mask


In [25]:
class MatNetInitEmbedding(nn.Module):
    def __init__(self, embedding_dim, one_hot_seed_cnt):
        super(MatNetInitEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        self.one_hot_seed_cnt = one_hot_seed_cnt

    def forward(self, td):
        # Generate initial embeddings: [batch, node, node]
        cost_mat = td["cost_matrix"]
        batch_size = cost_mat.size(0)
        node_cnt = cost_mat.size(1)
        row_emb = torch.zeros((batch_size, node_cnt, self.embedding_dim), device=cost_mat.device)
        col_emb = torch.zeros((batch_size, node_cnt, self.embedding_dim), device=cost_mat.device)
        # randomize col_emb: we refactor with topk
        rand = torch.rand(batch_size, self.one_hot_seed_cnt, device=cost_mat.device)
        _, rand_idx = rand.topk(node_cnt, dim=1, largest=False)
        b_idx, n_idx = torch.meshgrid(torch.arange(batch_size, device=cost_mat.device),
                                    torch.arange(node_cnt, device=cost_mat.device))
        col_emb[b_idx, n_idx, rand_idx] = 1
        return row_emb, col_emb

In [26]:
embedding = MatNetInitEmbedding(128, 50)

td = TensorDict({"cost_matrix": torch.randn(100, 50, 50)}, batch_size=100)
a, b = embedding(td)
a.shape, b.shape

(torch.Size([100, 50, 128]), torch.Size([100, 50, 128]))

In [27]:
class MatNetPolicy(nn.Module):
    def __init__(
        self,
        env: EnvBase,
        encoder: nn.Module = None,
        decoder: nn.Module = None,
        embedding_dim: int = 256,
        num_starts: int = 10,
        one_hot_seed_cnt: int = 20,
        num_encode_layers: int = 5,
        num_heads: int = 16,
        mask_inner: bool = True,
        train_decode_type: str = "sampling",
        val_decode_type: str = "greedy",
        test_decode_type: str = "greedy",
        **unused_kwargs
    ):
        super(MatNetPolicy, self).__init__()

        if len(unused_kwargs) > 0:
            print("Unused kwargs found in MatNetPolicy init: ", unused_kwargs)

        self.env = env

        self.init_embedding = MatNetInitEmbedding(embedding_dim, one_hot_seed_cnt)

        self.encoder = (
            MatNetEncoder(
                num_heads=num_heads,
                embed_dim=embedding_dim,
                num_layers=num_encode_layers,
            )
            if encoder is None
            else encoder
        )

        self.decoder = (
            Decoder(
                env,
                embedding_dim,
                num_heads,
                num_starts=num_starts,
                mask_inner=mask_inner,
            )
            if decoder is None
            else decoder
        )
        self.num_starts = num_starts
        self.train_decode_type = train_decode_type
        self.val_decode_type = val_decode_type
        self.test_decode_type = test_decode_type

    def forward(
        self,
        td: TensorDict,
        phase: str = "train",
        return_actions: bool = False,
        **decoder_kwargs,
    ) -> TensorDict:
        """Given observation, precompute embeddings and rollout"""

        # Set decoding type for policy, can be also greedy
        row_emb, col_emb = self.init_embedding(td)
        encoded_inputs = self.encoder(row_emb, col_emb, td["cost_matrix"])
        
        # Get decode type depending on phase
        if decoder_kwargs.get("decode_type", None) is None:
            decoder_kwargs["decode_type"] = getattr(self, f"{phase}_decode_type")

        # Main rollout
        log_p, actions, td = self.decoder(td, encoded_inputs, **decoder_kwargs)

        # Log likelyhood is calculated within the model since returning it per action does not work well with
        ll = get_log_likelihood(log_p, actions, td.get("mask", None))
        out = {
            "reward": td["reward"],
            "log_likelihood": ll,
            "actions": actions if return_actions else None,
        }

        return out


## Test policy only

In [28]:
from rich.traceback import install
install()

<bound method InteractiveShell.excepthook of <ipykernel.zmqshell.ZMQInteractiveShell object at 0x7eff8c1f45b0>>

In [50]:
env = ATSPEnv(num_loc=20)

dataset = env.dataset(batch_size=[10000])

dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,  # no need to shuffle, we're resampling every epoch
    num_workers=0,
    collate_fn=tensordict_collate_fn,
)

policy = MatNetPolicy(
    env,
    num_starts=20,
).to("cuda")

# model = torch.compile(model)

td = next(iter(dataloader)).to("cuda")
td = env.reset(td)

out = policy(td, decode_type="sampling")

print(out)

{'reward': tensor([-2.7901, -3.0211, -4.2714,  ..., -2.7211, -3.7595, -3.9496],
       device='cuda:0'), 'log_likelihood': tensor([-40.8242, -34.1915, -34.9612,  ..., -34.8064, -35.7479, -31.4031],
       device='cuda:0', grad_fn=<SumBackward1>), 'actions': None}


## Create full MatNet: `env` + `policy` + `baseline`

In [51]:
from typing import Any, Optional, Tuple, Union

import lightning as L

from rl4co.utils.lightning import get_lightning_device
from rl4co.models.rl.reinforce.baselines import WarmupBaseline, RolloutBaseline, ExponentialBaseline, SharedBaseline
from rl4co.models.zoo.pomo import POMO


class MatNet(POMO):

    def __init__(
        self,
        env,
        policy=None,
        baseline=None,
        num_starts=10,
        num_augment=0,
        **policy_kwargs
    ):
        super(POMO, self).__init__(env, policy, baseline)
        self.policy = (
            MatNetPolicy(self.env, num_starts=num_starts, **policy_kwargs)
            if policy is None
            else policy
        )

        self.baseline = SharedBaseline() if baseline is None else baseline

        # POMO parameters
        self.num_augment = num_augment
        self.augment = None # TODO: remove from POMO?
        

: 

In [47]:

model = MatNet(
    env,
    policy,
    # baseline=baseline,
).to("cuda")


td = next(iter(dataloader)).to("cuda")
td = env.reset(td)

out = model(td, decode_type="sampling")

print(out)

{'reward': tensor([-4.0538, -4.0305, -3.1932,  ..., -3.9774, -3.8262, -3.2630],
       device='cuda:0'), 'log_likelihood': tensor([-33.7423, -40.1478, -33.1871,  ..., -34.4182, -31.3874, -35.5232],
       device='cuda:0', grad_fn=<SumBackward1>), 'actions': None, 'loss': tensor(0.0009, device='cuda:0', grad_fn=<SubBackward0>), 'reinforce_loss': tensor(0.0009, device='cuda:0', grad_fn=<NegBackward0>), 'bl_loss': 0, 'bl_val': tensor([[-4.0845],
        [-4.1233],
        [-3.5481],
        [-3.4527],
        [-2.4980],
        [-3.0476],
        [-4.1127],
        [-3.8137],
        [-3.4986],
        [-3.9167],
        [-4.8775],
        [-2.4045],
        [-3.0212],
        [-4.1176],
        [-3.0078],
        [-3.0771],
        [-4.5094],
        [-3.5632],
        [-3.5146],
        [-3.9960],
        [-3.2412],
        [-4.0907],
        [-2.6833],
        [-3.0754],
        [-2.9327],
        [-3.5346],
        [-3.1042],
        [-3.6088],
        [-3.1072],
        [-3.2694],
  

## Config

In [48]:

config = DictConfig(
    {
        "data": {
            "train_size": 100000, # with 1 epochs, this is 1k samples
            "val_size": 10000, 
            "batch_size": 64, #64,
        },
        "optim": {
            "lr": 1e-4,
            "weight_decay": 1e-6,
        },
        "num_epochs": 10,        
    }
)


lit_module = RL4COLitModule(cfg=config, env=env, model=model)

In [49]:
# Trainer
trainer = L.Trainer(
    max_epochs=config.num_epochs, # only few epochs
    accelerator="gpu", # use GPU if available, else you can use others as "cpu"
    devices=[0], # GPU number, or multiple GPUs [0, 1, 2, ...]
    logger=None, # can replace with WandbLogger, TensorBoardLogger, etc.
    precision="16-mixed", # Lightning will handle faster training with mixed precision
    gradient_clip_val=1.0, # clip gradients to avoid exploding gradients
    reload_dataloaders_every_n_epochs=1, # necessary for sampling new data
)

# Fit the model
trainer.fit(lit_module)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
No optimizer specified, using default

  | Name  | Type    | Params
----------------------------------
0 | env   | ATSPEnv | 0     
1 | model | MatNet  | 5.7 M 
----------------------------------
5.7 M     Trainable params
0         Non-trainable params
5.7 M     Total params
22.670    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
