# MatNet Model

In [2]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append(2*"../")

import torch
from collections import defaultdict
from typing import Optional

from rl4co.envs import TSPEnv, ATSPEnv 


## Differences between AM and MatNet

1. MatNet uses a dual graph attention layer that processes the two sets of nodes, sources (A) and destinations (B), separately
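2. Mixed-score attention: the scaled dot-product attention scores are combined with the cost-matrix entries by a small learned MLP, so the network itself learns the "best" recipe for mixing the two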
3. Initial node representation: zero-vectors for A nodes and one-hot vectors for B nodes

In [3]:
# Create an ATSP environment with 10 locations and call reset(); as the last
# expression of the cell, the value returned by reset() is what gets displayed
# (the recorded output below is a TensorDict).
env = ATSPEnv(num_loc=10)
env.reset()

TensorDict(
    fields={
        action_mask: Tensor(shape=torch.Size([1, 10]), device=cpu, dtype=torch.bool, is_shared=False),
        current_node: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        first_node: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        i: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        observation: Tensor(shape=torch.Size([10, 10]), device=cpu, dtype=torch.float32, is_shared=False),
        reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

## MixedScore MHA

In [50]:
# Original


import torch
import torch.nn as nn
import torch.nn.functional as F


class AddAndInstanceNormalization(nn.Module):
    """Sum two tensors and instance-normalize the result.

    The constructor reads ``embedding_dim`` from ``model_params`` and builds an
    affine ``nn.InstanceNorm1d`` with that many channels and no running
    statistics.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.norm = nn.InstanceNorm1d(
            model_params['embedding_dim'],
            affine=True,
            track_running_stats=False,
        )

    def forward(self, input1, input2):
        """Return ``norm(input1 + input2)``.

        Args:
            input1: tensor of shape (batch, problem, embedding).
            input2: tensor of shape (batch, problem, embedding).

        Returns:
            Tensor of shape (batch, problem, embedding).
        """
        residual_sum = input1 + input2
        # shape: (batch, problem, embedding)

        # InstanceNorm1d treats axis 1 as the channel axis, so the embedding
        # axis is swapped into that position for the call and swapped back
        # afterwards.
        channels_first = residual_sum.transpose(1, 2)
        # shape: (batch, embedding, problem)

        return self.norm(channels_first).transpose(1, 2)
        # shape: (batch, problem, embedding)


class FeedForward(nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Linear.

    ``model_params`` must provide ``embedding_dim`` (input and output width)
    and ``ff_hidden_dim`` (width of the hidden layer).
    """

    def __init__(self, **model_params):
        super().__init__()
        width = model_params['embedding_dim']
        hidden_width = model_params['ff_hidden_dim']

        self.W1 = nn.Linear(in_features=width, out_features=hidden_width)
        self.W2 = nn.Linear(in_features=hidden_width, out_features=width)

    def forward(self, input1):
        """Apply the block to ``input1`` of shape (batch, problem, embedding).

        The linear layers act on the last axis only, so the output has the
        same shape as the input.
        """
        hidden = self.W1(input1)
        activated = F.relu(hidden)
        return self.W2(activated)


class MixedScore_MultiHeadAttention(nn.Module):
    """Multi-head attention whose scores are mixed with a cost matrix.

    For every (row, column) pair the scaled dot-product score and the
    corresponding cost-matrix entry form a 2-vector. A tiny per-head
    two-layer network (2 -> ms_hidden_dim -> 1, with ReLU in between) maps
    that 2-vector to the final attention logit. No mask is applied.

    Keys read from ``model_params``:
        at construction: ``head_num``, ``ms_hidden_dim``,
            ``ms_layer1_init``, ``ms_layer2_init``
        at forward time: ``head_num``, ``qkv_dim``, ``sqrt_qkv_dim``

    The mixing weights and biases are sampled from
    ``Uniform(-ms_layerN_init, +ms_layerN_init)``.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params

        head_num = model_params['head_num']
        ms_hidden_dim = model_params['ms_hidden_dim']
        mix1_init = model_params['ms_layer1_init']
        mix2_init = model_params['ms_layer2_init']

        # Use the public ``torch.distributions`` path rather than going
        # through the ``torch.torch`` self-reference.
        mix1_dist = torch.distributions.Uniform(low=-mix1_init, high=mix1_init)
        mix2_dist = torch.distributions.Uniform(low=-mix2_init, high=mix2_init)

        # Sampling order (weight, bias, weight, bias) is kept so that a given
        # RNG seed yields the same parameters as before.
        self.mix1_weight = nn.Parameter(mix1_dist.sample((head_num, 2, ms_hidden_dim)))
        # shape: (head, 2, ms_hidden)
        self.mix1_bias = nn.Parameter(mix1_dist.sample((head_num, ms_hidden_dim)))
        # shape: (head, ms_hidden)

        self.mix2_weight = nn.Parameter(mix2_dist.sample((head_num, ms_hidden_dim, 1)))
        # shape: (head, ms_hidden, 1)
        self.mix2_bias = nn.Parameter(mix2_dist.sample((head_num, 1)))
        # shape: (head, 1)

    def forward(self, q, k, v, cost_mat):
        """Compute mixed-score attention.

        Args:
            q: queries, shape (batch, head_num, row_cnt, qkv_dim).
            k: keys, shape (batch, head_num, col_cnt, qkv_dim).
            v: values, shape (batch, head_num, col_cnt, qkv_dim).
            cost_mat: cost matrix, shape (batch, row_cnt, col_cnt).

        Returns:
            Tensor of shape (batch, row_cnt, head_num * qkv_dim) with the
            per-head outputs concatenated along the last axis.
        """
        batch_size = q.size(0)
        row_cnt = q.size(2)
        col_cnt = k.size(2)

        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']
        sqrt_qkv_dim = self.model_params['sqrt_qkv_dim']

        dot_product_score = torch.matmul(q, k.transpose(2, 3)) / sqrt_qkv_dim
        # shape: (batch, head_num, row_cnt, col_cnt)

        # The same cost matrix is presented to every head (a view, no copy).
        cost_mat_score = cost_mat[:, None, :, :].expand(batch_size, head_num, row_cnt, col_cnt)
        # shape: (batch, head_num, row_cnt, col_cnt)

        two_scores = torch.stack((dot_product_score, cost_mat_score), dim=4)
        # shape: (batch, head_num, row_cnt, col_cnt, 2)

        # Move the head axis next to the (col, 2) matrix axes so that the
        # batched matmul pairs each head with its own mixing weights.
        two_scores_transposed = two_scores.transpose(1, 2)
        # shape: (batch, row_cnt, head_num, col_cnt, 2)

        ms1 = torch.matmul(two_scores_transposed, self.mix1_weight)
        ms1 = ms1 + self.mix1_bias[None, None, :, None, :]
        # shape: (batch, row_cnt, head_num, col_cnt, ms_hidden_dim)

        ms1_activated = F.relu(ms1)

        ms2 = torch.matmul(ms1_activated, self.mix2_weight)
        ms2 = ms2 + self.mix2_bias[None, None, :, None, :]
        # shape: (batch, row_cnt, head_num, col_cnt, 1)

        mixed_scores = ms2.transpose(1, 2).squeeze(4)
        # shape: (batch, head_num, row_cnt, col_cnt)

        # Functional softmax: no module object is built on every call.
        weights = F.softmax(mixed_scores, dim=3)
        # shape: (batch, head_num, row_cnt, col_cnt)

        out = torch.matmul(weights, v)
        # shape: (batch, head_num, row_cnt, qkv_dim)

        out_concat = out.transpose(1, 2).reshape(batch_size, row_cnt, head_num * qkv_dim)
        # shape: (batch, row_cnt, head_num*qkv_dim)

        return out_concat


In [52]:
# Build example inputs with the shapes the attention module expects:
#   q:        (batch, head_num, row_cnt, qkv_dim)
#   k, v:     (batch, head_num, col_cnt, qkv_dim)
#   cost_mat: (batch, row_cnt, col_cnt)

# Seed the RNG so the random inputs and the randomly sampled mixing weights
# are identical every time this cell is run.
torch.manual_seed(0)

batch_size = 32
head_num = 8
row_cnt = 30
col_cnt = 20
qkv_dim = 64
ms_hidden_dim = 16

q = torch.randn(batch_size, head_num, row_cnt, qkv_dim)
k = torch.randn(batch_size, head_num, col_cnt, qkv_dim)
v = torch.randn(batch_size, head_num, col_cnt, qkv_dim)
cost_mat = torch.randn(batch_size, row_cnt, col_cnt)

model_params = {
        'head_num': head_num,
        'qkv_dim': qkv_dim,
        'sqrt_qkv_dim': qkv_dim**(1/2),
        'ms_hidden_dim': ms_hidden_dim,
        # Init bounds: sqrt(1/2) for the first mixing layer and
        # sqrt(1/ms_hidden_dim) for the second.
        'ms_layer1_init': (1/2)**(1/2),
        'ms_layer2_init': (1/ms_hidden_dim)**(1/2)
        }

ms_mha_old = MixedScore_MultiHeadAttention(**model_params)
out = ms_mha_old(q, k, v, cost_mat)
print(out.shape)

torch.Size([32, 30, 512])


In [53]:
# Time one forward pass of the reference implementation on the example inputs.
%timeit ms_mha_old(q, k, v, cost_mat)

781 µs ± 2.77 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Ours

In [77]:
from torchrl.modules.models import MLP


class MixedScoreMHA(nn.Module):
    """Mixed-score multi-head attention with an MLP as the score mixer.

    The scaled dot-product score and the cost-matrix entry of every
    (row, column) pair are stacked into a 2-vector and mapped to a single
    attention logit by ``self.mlp``. One MLP instance is applied to all
    heads, so the mixing weights are shared across heads.

    Args:
        head_num: number of attention heads.
        qkv_dim: per-head query/key/value dimension.
        ms_hidden_dim: hidden width (``num_cells``) of the mixing MLP.
        sqrt_qkv_dim: divisor applied to the dot-product scores. When left
            as ``None`` it is derived from ``qkv_dim`` as ``qkv_dim ** 0.5``.
        **kwargs: accepted and ignored, so a larger parameter dict can be
            unpacked into the constructor.
    """

    def __init__(self, head_num = 8,
                    qkv_dim = 64,
                    ms_hidden_dim = 16,
                    sqrt_qkv_dim = None,
                    **kwargs
            ):

        super().__init__()
        self.head_num = head_num
        self.qkv_dim = qkv_dim
        self.ms_hidden_dim = ms_hidden_dim
        # A default written as ``qkv_dim**(1/2)`` in the signature is evaluated
        # once, at definition time, against whatever *global* ``qkv_dim``
        # exists -- not against the argument. Resolve it here instead so the
        # scale always follows the ``qkv_dim`` that was actually passed.
        if sqrt_qkv_dim is None:
            sqrt_qkv_dim = qkv_dim ** (1 / 2)
        self.sqrt_qkv_dim = sqrt_qkv_dim

        # NOTE: we refactor the code to use MLP, but we do not use the initialization MatNet used
        # this is not mentioned in the paper of course
        self.mlp  = MLP(2, 1, depth=1, num_cells=ms_hidden_dim, activation_class=nn.ReLU)

    def forward(self, q, k, v, cost_mat):
        """Compute mixed-score attention.

        Args:
            q: queries, shape (batch, head_num, row_cnt, qkv_dim).
            k: keys, shape (batch, head_num, col_cnt, qkv_dim).
            v: values, shape (batch, head_num, col_cnt, qkv_dim).
            cost_mat: cost matrix, shape (batch, row_cnt, col_cnt).

        Returns:
            Tensor of shape (batch, row_cnt, head_num * qkv_dim).
        """
        dot_product_score = torch.matmul(q, k.transpose(2, 3)) / self.sqrt_qkv_dim
        # shape: (batch, head_num, row_cnt, col_cnt)

        # expand() broadcasts the cost matrix over the head axis as a view;
        # torch.stack below materializes the combined tensor anyway.
        cost_mat_score = cost_mat.unsqueeze(1).expand(-1, self.head_num, -1, -1)
        two_scores = torch.stack((dot_product_score, cost_mat_score), dim=-1)
        # shape: (batch, head_num, row_cnt, col_cnt, 2)

        # The MLP only acts on the trailing size-2 axis, so the head and row
        # axes can stay where they are.
        mixed_scores = self.mlp(two_scores).squeeze(-1)
        # shape: (batch, head_num, row_cnt, col_cnt)

        weights = torch.softmax(mixed_scores, dim=3)
        out = torch.matmul(weights, v)
        # shape: (batch, head_num, row_cnt, qkv_dim)

        out_concat = out.transpose(1, 2).reshape(q.size(0), q.size(2), self.head_num * self.qkv_dim)
        return out_concat

In [90]:
# Stand-alone instance of the score-mixing MLP, for inspection.
# The hidden width is taken from `model_params` (defined in the example-input
# cell above) because no notebook-level `ms_hidden_dim` variable is assigned
# before this point.
mlp  = MLP(2, 1, depth=1, num_cells=model_params['ms_hidden_dim'], activation_class=nn.ReLU)
# Trailing expression so the cell displays the object's type, matching the
# output recorded for this cell.
type(mlp)

torchrl.modules.models.models.MLP

In [87]:
class MixedScoreMHA2(nn.Module):
    """Dict-configured mixed-score attention that mixes scores with an MLP.

    Takes the same ``**model_params`` style of configuration as
    ``MixedScore_MultiHeadAttention`` but replaces the hand-written mixing
    parameters with a single ``MLP(2 -> ms_hidden_dim -> 1)``.

    Keys read from ``model_params``:
        at construction: ``ms_hidden_dim``
        at forward time: ``head_num``, ``qkv_dim``, ``sqrt_qkv_dim``

    ``ms_layer1_init`` / ``ms_layer2_init`` may be present in the dict but are
    not used: the MLP keeps its own default initialization.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params

        ms_hidden_dim = model_params['ms_hidden_dim']
        self.mlp  = MLP(2, 1, depth=1, num_cells=ms_hidden_dim, activation_class=nn.ReLU)

    def forward(self, q, k, v, cost_mat):
        """Compute mixed-score attention.

        Args:
            q: queries, shape (batch, head_num, row_cnt, qkv_dim).
            k: keys, shape (batch, head_num, col_cnt, qkv_dim).
            v: values, shape (batch, head_num, col_cnt, qkv_dim).
            cost_mat: cost matrix, shape (batch, row_cnt, col_cnt).

        Returns:
            Tensor of shape (batch, row_cnt, head_num * qkv_dim).
        """
        batch_size = q.size(0)
        row_cnt = q.size(2)
        col_cnt = k.size(2)

        head_num = self.model_params['head_num']
        qkv_dim = self.model_params['qkv_dim']
        sqrt_qkv_dim = self.model_params['sqrt_qkv_dim']

        dot_product_score = torch.matmul(q, k.transpose(2, 3)) / sqrt_qkv_dim
        # shape: (batch, head_num, row_cnt, col_cnt)

        cost_mat_score = cost_mat[:, None, :, :].expand(batch_size, head_num, row_cnt, col_cnt)
        # shape: (batch, head_num, row_cnt, col_cnt)

        two_scores = torch.stack((dot_product_score, cost_mat_score), dim=4)
        # shape: (batch, head_num, row_cnt, col_cnt, 2)

        # The MLP is applied along the trailing size-2 axis only, so it is fed
        # the stacked scores directly; no head/row transpose is involved.
        mixed_scores = self.mlp(two_scores).squeeze(-1)
        # shape: (batch, head_num, row_cnt, col_cnt)

        weights = torch.softmax(mixed_scores, dim=3)
        # shape: (batch, head_num, row_cnt, col_cnt)

        out = torch.matmul(weights, v)
        # shape: (batch, head_num, row_cnt, qkv_dim)

        out_concat = out.transpose(1, 2).reshape(batch_size, row_cnt, head_num * qkv_dim)
        # shape: (batch, row_cnt, head_num*qkv_dim)

        return out_concat


In [88]:
# Build example inputs with the shapes the attention module expects:
#   q:        (batch, head_num, row_cnt, qkv_dim)
#   k, v:     (batch, head_num, col_cnt, qkv_dim)
#   cost_mat: (batch, row_cnt, col_cnt)

# Seed the RNG so the random inputs and the MLP's initial weights are
# identical every time this cell is run.
torch.manual_seed(0)

batch_size = 32
head_num = 8
row_cnt = 30
col_cnt = 20
qkv_dim = 64
ms_hidden_dim = 16

q = torch.randn(batch_size, head_num, row_cnt, qkv_dim)
k = torch.randn(batch_size, head_num, col_cnt, qkv_dim)
v = torch.randn(batch_size, head_num, col_cnt, qkv_dim)
cost_mat = torch.randn(batch_size, row_cnt, col_cnt)

model_params = {
        'head_num': head_num,
        'qkv_dim': qkv_dim,
        'sqrt_qkv_dim': qkv_dim**(1/2),
        'ms_hidden_dim': ms_hidden_dim,
        # Init bounds: sqrt(1/2) for the first mixing layer and
        # sqrt(1/ms_hidden_dim) for the second.
        'ms_layer1_init': (1/2)**(1/2),
        'ms_layer2_init': (1/ms_hidden_dim)**(1/2)
        }

ms_mha = MixedScoreMHA2(**model_params)
out = ms_mha(q, k, v, cost_mat)
print(out.shape)

torch.Size([32, 30, 512])


In [89]:
# Time one forward pass of the MLP-based variant on the example inputs.
%timeit ms_mha(q, k, v, cost_mat)

1.59 ms ± 5.52 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
