# MatNet Model

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append(2*"../")

import torch
from collections import defaultdict
from typing import Optional

from rl4co.envs import TSPEnv, ATSPEnv 


  from .autonotebook import tqdm as notebook_tqdm


## Differences between AM and MatNet

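1. MatNet uses a dual graph attention layer that processes the two node sets separately: source nodes A (the rows of the cost matrix) and destination nodes B (its columns)
2. Mixed-score attention: a small per-head MLP learns the "best" recipe for combining the learned dot-product attention score with the external cost-matrix score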
3. Initial node representation: zero-vectors for A nodes and one-hot vectors for B nodes

In [2]:
# Build a 10-node asymmetric TSP environment and display its initial state TensorDict
# (the output lists the fields, including the (10, 10) cost_matrix).
env = ATSPEnv(num_loc=10)
env.reset()

TensorDict(
    fields={
        action_mask: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.bool, is_shared=False),
        cost_matrix: Tensor(shape=torch.Size([10, 10]), device=cpu, dtype=torch.float32, is_shared=False),
        current_node: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        first_node: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        i: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

In [3]:

# Random example inputs for the mixed-score attention comparison below.
# Rows (queries) and columns (keys/values) use different counts: 30 vs 20.
# NOTE(review): no torch.manual_seed is set, so these tensors differ between runs.
batch_size = 128
head_num = 8
row_cnt = 30
col_cnt = 20
qkv_dim = 64

q = torch.randn(batch_size, head_num, row_cnt, qkv_dim)  # (batch, head_num, row_cnt, qkv_dim)
k = torch.randn(batch_size, head_num, col_cnt, qkv_dim)  # (batch, head_num, col_cnt, qkv_dim)
v = torch.randn(batch_size, head_num, col_cnt, qkv_dim)  # (batch, head_num, col_cnt, qkv_dim)
cost_mat = torch.randn(batch_size, row_cnt, col_cnt)  # (batch, row_cnt, col_cnt)

# Hyper-parameters for the mixed-score attention modules:
# - sqrt_qkv_dim: divisor applied to the q.k dot product
# - ms_hidden_dim: hidden width of the per-head score-mixing MLP
# - ms_layer1_init / ms_layer2_init: half-width of the uniform init range of its two layers
model_params = {
        'head_num': head_num,
        'qkv_dim': qkv_dim,
        'sqrt_qkv_dim': qkv_dim**(1/2),
        'ms_hidden_dim': 16,
        'ms_layer1_init': (1/2)**(1/2),
        'ms_layer2_init': (1/16)**(1/2)
        }

## MixedScore MHA

In [4]:
# Original


import torch
import torch.nn as nn
import torch.nn.functional as F



class MixedScore_MultiHeadAttention(nn.Module):
    """Reference mixed-score attention, as in the original MatNet code.

    For every (row, col) pair each head has two scores: the scaled query/key dot
    product and the external cost-matrix entry. A tiny per-head MLP
    (2 -> ms_hidden_dim -> 1, ReLU in between) mixes them into one score. The scores
    are softmaxed over the columns and used to weight the values. The heads are then
    concatenated.

    Expected ``model_params`` keys: ``head_num``, ``qkv_dim``, ``sqrt_qkv_dim``,
    ``ms_hidden_dim``, ``ms_layer1_init``, ``ms_layer2_init``.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params

        n_heads = model_params['head_num']
        n_hidden = model_params['ms_hidden_dim']
        bound1 = model_params['ms_layer1_init']
        bound2 = model_params['ms_layer2_init']

        def sample_uniform(bound, *shape):
            # Same sampler as upstream, so the global RNG is consumed identically.
            return torch.distributions.Uniform(low=-bound, high=bound).sample(shape)

        # Draw order (w1, b1, w2, b2) is kept so seeded initialisation is unchanged.
        self.mix1_weight = nn.Parameter(sample_uniform(bound1, n_heads, 2, n_hidden))  # (head, 2, ms_hidden)
        self.mix1_bias = nn.Parameter(sample_uniform(bound1, n_heads, n_hidden))  # (head, ms_hidden)
        self.mix2_weight = nn.Parameter(sample_uniform(bound2, n_heads, n_hidden, 1))  # (head, ms_hidden, 1)
        self.mix2_bias = nn.Parameter(sample_uniform(bound2, n_heads, 1))  # (head, 1)

    def forward(self, q, k, v, cost_mat):
        """Mix the dot-product and cost-matrix scores, then attend over the columns.

        Args:
            q: (batch, head_num, row_cnt, qkv_dim)
            k, v: (batch, head_num, col_cnt, qkv_dim)
            cost_mat: (batch, row_cnt, col_cnt), shared by all heads
        Returns:
            (batch, row_cnt, head_num * qkv_dim)
        """
        n_batch, n_rows, n_cols = q.size(0), q.size(2), k.size(2)
        n_heads = self.model_params['head_num']
        head_dim = self.model_params['qkv_dim']
        scale = self.model_params['sqrt_qkv_dim']

        # Score 1: scaled dot product. Score 2: the cost matrix, repeated for every head.
        attn_score = q @ k.transpose(2, 3) / scale  # (batch, head, row, col)
        cost_score = cost_mat.unsqueeze(1).expand(n_batch, n_heads, n_rows, n_cols)
        paired = torch.stack([attn_score, cost_score], dim=4)  # (batch, head, row, col, 2)

        # Put heads just before the matrix dims so matmul broadcasts the per-head weights.
        paired = paired.transpose(1, 2)  # (batch, row, head, col, 2)
        hidden = F.relu(paired @ self.mix1_weight + self.mix1_bias[None, None, :, None, :])
        mixed = hidden @ self.mix2_weight + self.mix2_bias[None, None, :, None, :]  # (batch, row, head, col, 1)
        mixed = mixed.transpose(1, 2).squeeze(4)  # (batch, head, row, col)

        attn = torch.softmax(mixed, dim=3)
        context = attn @ v  # (batch, head, row, qkv_dim)
        return context.transpose(1, 2).reshape(n_batch, n_rows, n_heads * head_dim)


In [5]:
# Run the original implementation on the example inputs:
# q: (batch, head_num, row_cnt, qkv_dim); k, v: (batch, head_num, col_cnt, qkv_dim)
# cost_mat: (batch, row_cnt, col_cnt)

ms_mha_old = MixedScore_MultiHeadAttention(**model_params)
out = ms_mha_old(q, k, v, cost_mat)
print(out.shape)  # (batch, row_cnt, head_num * qkv_dim)

# Get num of params
def count_parameters(model):
    """Return the number of trainable scalar parameters (``requires_grad=True``) in ``model``."""
    trainable = (param for param in model.parameters() if param.requires_grad)
    return sum(param.numel() for param in trainable)

# Expected 520 = mix1 (8*2*16 + 8*16) + mix2 (8*16*1 + 8) for head_num=8, ms_hidden_dim=16.
print(count_parameters(ms_mha_old))

torch.Size([128, 30, 512])
520


In [6]:
# Benchmark the original implementation: 5 runs x 20 loops.
%timeit -r 5 -n 20 ms_mha_old(q, k, v, cost_mat)

12.5 ms ± 318 µs per loop (mean ± std. dev. of 5 runs, 20 loops each)


In [7]:
from einops import rearrange, repeat
import math


class MixedScoreMHA_noweights(nn.Module):
    """Mixed-score attention core, without the Q/K/V/output projections.

    Drop-in replacement for ``MixedScore_MultiHeadAttention``. It takes ``q``, ``k`` and
    ``v`` that are already projected and split into heads. It then mixes the scaled
    dot-product score with the external ``matrix`` score through a per-head
    2 -> ms_hidden_dim -> 1 MLP.

    Args:
        head_num: number of attention heads.
        ms_hidden_dim: hidden width of the per-head score-mixing MLP.
        ms_layer1_init / ms_layer2_init: half-width of the uniform init range of the
            first / second mixing layer (weights and biases).
        qkv_dim: optional per-head size, stored for introspection only. The forward
            pass reads the size from ``q`` instead.
        **kwargs: ignored, so a full ``model_params`` dict can be splatted in.
    """

    def __init__(self,
                 head_num: int,
                 ms_hidden_dim: int = 16,
                 ms_layer1_init: float = (1/2)**(1/2),
                 ms_layer2_init: float = (1/16)**(1/2),
                 qkv_dim: Optional[int] = None,
                 **kwargs):
        super().__init__()

        self.head_num = head_num
        self.ms_hidden_dim = ms_hidden_dim
        # Taken from the arguments rather than a notebook global (which made the class
        # fail with NameError without it). Informational only; no int() truncation.
        self.qkv_dim = qkv_dim
        self.sqrt_qkv_dim = math.sqrt(qkv_dim) if qkv_dim is not None else None

        # Initialize weights and biases
        self.mix1_weight = nn.Parameter(torch.empty(head_num, 2, ms_hidden_dim))
        self.mix1_bias = nn.Parameter(torch.empty(head_num, ms_hidden_dim))
        self.mix2_weight = nn.Parameter(torch.empty(head_num, ms_hidden_dim, 1))
        self.mix2_bias = nn.Parameter(torch.empty(head_num, 1))
        nn.init.uniform_(self.mix1_weight, -ms_layer1_init, ms_layer1_init)
        nn.init.uniform_(self.mix1_bias, -ms_layer1_init, ms_layer1_init)
        nn.init.uniform_(self.mix2_weight, -ms_layer2_init, ms_layer2_init)
        nn.init.uniform_(self.mix2_bias, -ms_layer2_init, ms_layer2_init)

    def forward(self, q, k, v, matrix):
        """Mix the dot-product and matrix scores per head, then attend over the columns.

        Args:
            q: (batch, head_num, row_cnt, qkv_dim)
            k, v: (batch, head_num, col_cnt, qkv_dim)
            matrix: (batch, row_cnt, col_cnt), shared by all heads
        Returns:
            (batch, row_cnt, head_num * qkv_dim)
        """
        # Two scores per (row, col) pair: [batch, head, row, col, 2]
        dot_product = torch.einsum('bhrd,bhcd->bhrc', q, k) / math.sqrt(q.shape[-1])
        matrix_score = matrix.unsqueeze(1).expand(-1, self.head_num, -1, -1)
        two_scores = torch.stack((dot_product, matrix_score), dim=-1)

        # Per-head MLP 2 -> ms_hidden_dim -> 1, kept in head-major layout (no transposes)
        ms1 = torch.einsum('bhrct,htd->bhrcd', two_scores, self.mix1_weight) \
            + self.mix1_bias[None, :, None, None, :]
        ms2 = torch.einsum('bhrcd,hdo->bhrco', F.relu(ms1), self.mix2_weight) \
            + self.mix2_bias[None, :, None, None, :]
        mixed_scores = ms2.squeeze(-1)  # [batch, head, row, col]

        # Softmax over columns and multiply with values
        weights = F.softmax(mixed_scores, dim=-1)
        out = torch.matmul(weights, v)  # [batch, head, row, qkv_dim]
        return out.transpose(1, 2).reshape(out.shape[0], out.shape[2], -1)

In [8]:

ms_mha = MixedScoreMHA_noweights(**model_params)
out = ms_mha(q, k, v, cost_mat)
print(out.shape)

# Reuse the count_parameters helper defined in the benchmark cell of the original
# implementation instead of redefining it here.
print(count_parameters(ms_mha))

# Sanity checks against the original implementation: the same number of trainable
# parameters, and identical outputs once both share the same mixing weights.
# load_state_dict works because the parameter names match.
assert count_parameters(ms_mha) == count_parameters(ms_mha_old)
ms_mha.load_state_dict(ms_mha_old.state_dict())
assert torch.allclose(ms_mha(q, k, v, cost_mat), ms_mha_old(q, k, v, cost_mat), atol=1e-5)

torch.Size([128, 30, 512])
520


In [9]:
# Benchmark the refactored core with the same settings as the original implementation.
%timeit -r 5 -n 20 ms_mha(q, k, v, cost_mat)

13.1 ms ± 78.1 µs per loop (mean ± std. dev. of 5 runs, 20 loops each)


## Encoder

In [10]:
class Encoder(nn.Module):
    """Stack of ``encoder_layer_num`` MatNet encoder layers (original implementation).

    Each layer jointly refines the row and column embeddings. The same cost matrix is
    passed to every layer.
    """

    def __init__(self, **model_params):
        super().__init__()
        n_layers = model_params['encoder_layer_num']
        self.layers = nn.ModuleList(EncoderLayer(**model_params) for _ in range(n_layers))

    def forward(self, row_emb, col_emb, cost_mat):
        """Apply every layer in order.

        Args:
            row_emb: (batch, row_cnt, embedding)
            col_emb: (batch, col_cnt, embedding)
            cost_mat: (batch, row_cnt, col_cnt)
        Returns:
            tuple ``(row_emb, col_emb)`` produced by the last layer, or the inputs
            unchanged when there are no layers.
        """
        for encoder_layer in self.layers:
            row_emb, col_emb = encoder_layer(row_emb, col_emb, cost_mat)
        return (row_emb, col_emb)


class EncoderLayer(nn.Module):
    """One MatNet encoder layer made of two independent encoding blocks.

    The row block updates the row embeddings by attending to the columns with
    ``cost_mat``. The column block updates the column embeddings by attending to the
    rows with the transposed cost matrix. Both blocks read the layer's input
    embeddings, not each other's outputs.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.row_encoding_block = EncodingBlock(**model_params)
        self.col_encoding_block = EncodingBlock(**model_params)

    def forward(self, row_emb, col_emb, cost_mat):
        """
        Args:
            row_emb: (batch, row_cnt, embedding)
            col_emb: (batch, col_cnt, embedding)
            cost_mat: (batch, row_cnt, col_cnt)
        Returns:
            ``(new_row_emb, new_col_emb)``, produced by the row and column blocks.
        """
        cost_mat_t = cost_mat.transpose(1, 2)  # (batch, col_cnt, row_cnt)
        updated_rows = self.row_encoding_block(row_emb, col_emb, cost_mat)
        updated_cols = self.col_encoding_block(col_emb, row_emb, cost_mat_t)
        return updated_rows, updated_cols


class EncodingBlock(nn.Module):
    """Mixed-score attention plus feed-forward, each followed by add & instance norm (original).

    Rows attend to columns: queries come from ``row_emb``, keys/values from
    ``col_emb``, and ``cost_mat`` is mixed into the attention scores. Calling the block
    with the roles swapped and ``cost_mat.transpose(1, 2)`` updates the columns instead.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.model_params = model_params
        emb_dim = model_params['embedding_dim']
        proj_dim = model_params['head_num'] * model_params['qkv_dim']

        # Submodules are created in the upstream order (this matters for seeded init).
        self.Wq, self.Wk, self.Wv = (nn.Linear(emb_dim, proj_dim, bias=False) for _ in range(3))
        self.mixed_score_MHA = MixedScore_MultiHeadAttention(**model_params)
        self.multi_head_combine = nn.Linear(proj_dim, emb_dim)

        self.add_n_normalization_1 = AddAndInstanceNormalization(**model_params)
        self.feed_forward = FeedForward(**model_params)
        self.add_n_normalization_2 = AddAndInstanceNormalization(**model_params)

    def forward(self, row_emb, col_emb, cost_mat):
        """
        Args:
            row_emb: (batch, row_cnt, embedding), the query side
            col_emb: (batch, col_cnt, embedding), the key/value side
            cost_mat: (batch, row_cnt, col_cnt)
        Returns:
            (batch, row_cnt, embedding)
        """
        n_heads = self.model_params['head_num']

        def split_heads(x):
            # (batch, n, head_num * qkv_dim) -> (batch, head_num, n, qkv_dim)
            return reshape_by_heads(x, head_num=n_heads)

        attended = self.mixed_score_MHA(
            split_heads(self.Wq(row_emb)),
            split_heads(self.Wk(col_emb)),
            split_heads(self.Wv(col_emb)),
            cost_mat,
        )  # (batch, row_cnt, head_num * qkv_dim)

        # Residual add + norm after the attention sub-layer, then after the feed-forward.
        hidden = self.add_n_normalization_1(row_emb, self.multi_head_combine(attended))
        return self.add_n_normalization_2(hidden, self.feed_forward(hidden))


class FeedForward(nn.Module):
    """Two-layer MLP applied on the last axis: embedding -> ff_hidden_dim -> embedding, ReLU between."""

    def __init__(self, **model_params):
        super().__init__()
        d_model, d_hidden = model_params['embedding_dim'], model_params['ff_hidden_dim']
        self.W1 = nn.Linear(d_model, d_hidden)
        self.W2 = nn.Linear(d_hidden, d_model)

    def forward(self, input1):
        """Map ``input1`` (e.g. (batch, problem, embedding)) to the same shape."""
        hidden = F.relu(self.W1(input1))
        return self.W2(hidden)
    

def reshape_by_heads(qkv, head_num):
    """Split the packed head dimension and move the heads to axis 1.

    (batch, n, head_num * key_dim) -> (batch, head_num, n, key_dim).
    ``n`` can be either 1 or the problem size.
    """
    batch_s, n = qkv.size(0), qkv.size(1)
    per_head = qkv.reshape(batch_s, n, head_num, -1)  # (batch, n, head_num, key_dim)
    return per_head.transpose(1, 2)



class AddAndInstanceNormalization(nn.Module):
    """Residual add followed by InstanceNorm1d over the problem (node) axis.

    Each embedding channel is normalised across the nodes of one instance, then
    scaled and shifted by learnable per-channel affine parameters.
    """

    def __init__(self, **model_params):
        super().__init__()
        self.norm = nn.InstanceNorm1d(
            model_params['embedding_dim'], affine=True, track_running_stats=False
        )

    def forward(self, input1, input2):
        """input1, input2: (batch, problem, embedding) -> (batch, problem, embedding)."""
        channels_first = (input1 + input2).transpose(1, 2)  # (batch, embedding, problem)
        return self.norm(channels_first).transpose(1, 2)

In [11]:
# Test out the original encoder on a random rectangular instance.
# NOTE: this redefines `model_params` for the encoder experiments; the earlier dict was
# only used by the standalone attention benchmarks.

model_params = {
    'embedding_dim': 256,
    'sqrt_embedding_dim': 256**(1/2),
    'encoder_layer_num': 5,
    'qkv_dim': 16,
    'sqrt_qkv_dim': 16**(1/2),
    'head_num': 8,
    'logit_clipping': 10,
    'ff_hidden_dim': 512,
    'ms_hidden_dim': 16,
    'ms_layer1_init': (1/2)**(1/2),
    'ms_layer2_init': (1/16)**(1/2),
    'eval_type': 'argmax',
    'one_hot_seed_cnt': 20,  # must be >= node_cnt
}


encoder = Encoder(**model_params)


# Shapes: row_emb (batch, row_cnt, embedding), col_emb (batch, col_cnt, embedding),
# cost_mat (batch, row_cnt, col_cnt)
batch = 64
row_cnt = 20
col_cnt = 30

row_emb = torch.randn(batch, row_cnt, model_params['embedding_dim'])
col_emb = torch.randn(batch, col_cnt, model_params['embedding_dim'])
cost_mat = torch.randn(batch, row_cnt, col_cnt)


out = encoder(row_emb, col_emb, cost_mat)
print(encoder)

print(out[0].shape, out[1].shape)
# numel / 1e6 is a parameter count in millions, not a size in megabytes.
print('Number of parameters: {:.2f} M'.format(count_parameters(encoder) / 1e6))

Encoder(
  (layers): ModuleList(
    (0-4): 5 x EncoderLayer(
      (row_encoding_block): EncodingBlock(
        (Wq): Linear(in_features=256, out_features=128, bias=False)
        (Wk): Linear(in_features=256, out_features=128, bias=False)
        (Wv): Linear(in_features=256, out_features=128, bias=False)
        (mixed_score_MHA): MixedScore_MultiHeadAttention()
        (multi_head_combine): Linear(in_features=128, out_features=256, bias=True)
        (add_n_normalization_1): AddAndInstanceNormalization(
          (norm): InstanceNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        )
        (feed_forward): FeedForward(
          (W1): Linear(in_features=256, out_features=512, bias=True)
          (W2): Linear(in_features=512, out_features=256, bias=True)
        )
        (add_n_normalization_2): AddAndInstanceNormalization(
          (norm): InstanceNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        )
      )
      

In [12]:
# Benchmark a full forward pass of the original 5-layer encoder.
%timeit encoder(row_emb, col_emb, cost_mat)

67.9 ms ± 599 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


## Ours (all we need)

In [13]:
from torchrl.modules.models.models import MLP

In [14]:
# Pretty tracebacks via rich. install() returns the previous excepthook; the trailing
# `;` keeps its repr from being printed as cell output.
from rich.traceback import install
install();

<bound method InteractiveShell.excepthook of <ipykernel.zmqshell.ZMQInteractiveShell object at 0x7f3e00133b20>>

In [15]:
class MixedScoreMHA(nn.Module):
    """Mixed-score multi-head attention with built-in Q/K/V and output projections.

    For every (row, col) pair each head has two scores: the scaled dot product of the
    projected query/key, and the external ``matrix`` entry. A per-head MLP
    (2 -> hidden_dim -> 1, ReLU) mixes them. The mixed scores are softmaxed over the
    columns and used to weight the values. The concatenated heads are then projected
    back to ``embed_dim``.

    Args:
        embed_dim: input/output embedding size.
        num_heads: number of attention heads.
        hidden_dim: hidden width of the score-mixing MLP.
        qkv_dim: per-head query/key/value size.
        bias: whether the Q/K/V projections have a bias.
        layer1_init / layer2_init: half-width of the uniform init range of the first /
            second mixing layer (weights and biases).
        device, dtype: factory kwargs applied to every parameter.
    """

    def __init__(self,
                    embed_dim,
                    num_heads,
                    hidden_dim: int = 16,
                    qkv_dim: int = 16,
                    bias=False,
                    layer1_init: float = (1/2)**(1/2),
                    layer2_init: float = (1/16)**(1/2),
                    device=None,
                    dtype=None
        ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        assert (embed_dim % num_heads == 0), "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        # Q/K/V projections to num_heads * qkv_dim, and the output projection back
        self.Wq = nn.Linear(embed_dim, num_heads*qkv_dim, bias=bias, **factory_kwargs)
        self.Wk = nn.Linear(embed_dim, num_heads*qkv_dim, bias=bias, **factory_kwargs)
        self.Wv = nn.Linear(embed_dim, num_heads*qkv_dim, bias=bias, **factory_kwargs)
        self.out_proj = nn.Linear(num_heads*qkv_dim, embed_dim, **factory_kwargs)

        # Score-mixing MLP parameters with uniform init. factory_kwargs puts them on the
        # same device/dtype as the projections.
        self.mix1_weight = nn.Parameter(
            torch.empty(num_heads, 2, hidden_dim, **factory_kwargs).uniform_(-layer1_init, layer1_init))
        self.mix1_bias = nn.Parameter(
            torch.empty(num_heads, hidden_dim, **factory_kwargs).uniform_(-layer1_init, layer1_init))
        self.mix2_weight = nn.Parameter(
            torch.empty(num_heads, hidden_dim, 1, **factory_kwargs).uniform_(-layer2_init, layer2_init))
        self.mix2_bias = nn.Parameter(
            torch.empty(num_heads, 1, **factory_kwargs).uniform_(-layer2_init, layer2_init))

    def forward(self, q, k, v, matrix):
        """
        Args:
            q: (batch, row_cnt, embed_dim), the query-side embeddings
            k, v: (batch, col_cnt, embed_dim), the key/value-side embeddings
            matrix: (batch, row_cnt, col_cnt), external scores shared by all heads
        Returns:
            (batch, row_cnt, embed_dim)
        """
        # Project and split heads: [batch, num_heads, n, qkv_dim]
        q, k, v = self.Wq(q), self.Wk(k), self.Wv(v)
        q, k, v = map(self._reshape_heads, (q, k, v))

        # Two scores per (row, col) pair: [batch, num_heads, row_cnt, col_cnt, 2]
        dot_product = torch.einsum('bhrd,bhcd->bhrc', q, k) / math.sqrt(q.shape[-1])
        matrix_score = matrix.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        two_scores = torch.stack((dot_product, matrix_score), dim=-1)

        # Per-head mixing MLP with both biases. mix2_bias shifts all of a head's scores
        # by the same amount, which the softmax ignores; it is kept to match the reference.
        ms1 = torch.einsum('bhrct,htd->bhrcd', two_scores, self.mix1_weight) \
            + self.mix1_bias[None, :, None, None, :]
        ms2 = torch.einsum('bhrcd,hdo->bhrco', F.relu(ms1), self.mix2_weight) \
            + self.mix2_bias[None, :, None, None, :]
        mixed_scores = ms2.squeeze(-1)  # [batch, num_heads, row_cnt, col_cnt]

        # Softmax over columns and weight the values
        weights = F.softmax(mixed_scores, dim=-1)
        out = torch.matmul(weights, v)  # [batch, num_heads, row_cnt, qkv_dim]

        # Merge heads and project out
        out = out.transpose(1, 2).reshape(out.shape[0], out.shape[2], -1)
        return self.out_proj(out)
    
    def _reshape_heads(self, x):
        # same as rearrange(v, 'b r (h d) -> b h r d', h=self.num_heads) but faster
        return x.view(x.shape[0], x.shape[1], self.num_heads, -1).transpose(1, 2)
         


class Encoder(nn.Module):
    """Stack of ``kw['encoder_layer_num']`` EncoderLayer modules that share one cost matrix."""

    def __init__(self, **kw):
        super().__init__()
        self.layers = nn.ModuleList(EncoderLayer(**kw) for _ in range(kw['encoder_layer_num']))

    def forward(self, row_emb, col_emb, cost_mat):
        """Run the layers in order and return the final ``(row_emb, col_emb)`` pair.

        Shapes: row_emb (batch, row_cnt, emb), col_emb (batch, col_cnt, emb),
        cost_mat (batch, row_cnt, col_cnt).
        """
        for encoder_layer in self.layers:
            row_emb, col_emb = encoder_layer(row_emb, col_emb, cost_mat)
        return (row_emb, col_emb)


class EncoderLayer(nn.Module):
    """Row and column encoding blocks that update both embedding sets from the same inputs."""

    def __init__(self, **kw):
        super().__init__()
        self.row_encoding_block = EncodingBlock(**kw)
        self.col_encoding_block = EncodingBlock(**kw)

    def forward(self, row_emb, col_emb, cost_mat):
        """
        Args:
            row_emb: (batch, row_cnt, embedding)
            col_emb: (batch, col_cnt, embedding)
            cost_mat: (batch, row_cnt, col_cnt); the column block sees its transpose
        Returns:
            ``(new_row_emb, new_col_emb)``
        """
        return (
            self.row_encoding_block(row_emb, col_emb, cost_mat),
            self.col_encoding_block(col_emb, row_emb, cost_mat.transpose(1, 2)),
        )


class EncodingBlock(nn.Module):
    """Mixed-score MHA and feed-forward, each followed by add & instance norm.

    Queries come from ``row_emb`` and keys/values from ``col_emb``. ``cost_mat``
    (batch, row_cnt, col_cnt) is mixed into the attention scores.

    Reads ``embedding_dim``, ``head_num``, ``ms_hidden_dim`` and ``ff_hidden_dim`` from
    ``kw``. ``qkv_dim``, ``ms_layer1_init`` and ``ms_layer2_init`` are forwarded to
    MixedScoreMHA when present; otherwise its defaults apply.
    """

    def __init__(self, **kw):
        super().__init__()
        embed_dim = kw['embedding_dim']

        # Forward the optional attention hyper-parameters so the values in model_params
        # take effect instead of being silently replaced by MixedScoreMHA's defaults.
        # model_params key -> MixedScoreMHA keyword
        key_map = {'qkv_dim': 'qkv_dim', 'ms_layer1_init': 'layer1_init', 'ms_layer2_init': 'layer2_init'}
        mha_kwargs = {arg: kw[key] for key, arg in key_map.items() if key in kw}

        self.mixed_score_mha = MixedScoreMHA(embed_dim, kw['head_num'], kw['ms_hidden_dim'], **mha_kwargs)
        self.add_n_normalization_1 = AddAndInstanceNormalization(embed_dim)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, kw['ff_hidden_dim']),
            nn.ReLU(),
            nn.Linear(kw['ff_hidden_dim'], embed_dim)
        )
        self.add_n_normalization_2 = AddAndInstanceNormalization(embed_dim)

    def forward(self, row_emb, col_emb, cost_mat):
        """
        Args:
            row_emb: (batch, row_cnt, embedding), the query side
            col_emb: (batch, col_cnt, embedding), the key/value side
            cost_mat: (batch, row_cnt, col_cnt)
        Returns:
            (batch, row_cnt, embedding)
        """
        out_mha = self.mixed_score_mha(row_emb, col_emb, col_emb, cost_mat)
        out1 = self.add_n_normalization_1(row_emb, out_mha)
        out2 = self.feed_forward(out1)
        return self.add_n_normalization_2(out1, out2)
        
    
class AddAndInstanceNormalization(nn.Module):
    """Sum two (batch, problem, embedding) tensors, then instance-normalise over the problem axis."""

    def __init__(self, embedding_dim):
        super().__init__()
        self.norm = nn.InstanceNorm1d(embedding_dim, affine=True, track_running_stats=False)

    def forward(self, input1, input2):
        summed = input1 + input2
        # InstanceNorm1d expects channels first: (batch, embedding, problem)
        normalized = self.norm(summed.transpose(1, 2))
        return normalized.transpose(1, 2)

In [16]:
encoder = Encoder(**model_params)
print(encoder)
# numel / 1e6 is a parameter count in millions, not a size in megabytes.
print('Number of parameters: {:.2f} M'.format(count_parameters(encoder) / 1e6))
out = encoder(row_emb, col_emb, cost_mat)
print(out[0].shape, out[1].shape)


Encoder(
  (layers): ModuleList(
    (0-4): 5 x EncoderLayer(
      (row_encoding_block): EncodingBlock(
        (mixed_score_mha): MixedScoreMHA(
          (Wq): Linear(in_features=256, out_features=128, bias=False)
          (Wk): Linear(in_features=256, out_features=128, bias=False)
          (Wv): Linear(in_features=256, out_features=128, bias=False)
          (out_proj): Linear(in_features=128, out_features=256, bias=True)
        )
        (add_n_normalization_1): AddAndInstanceNormalization(
          (norm): InstanceNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        )
        (feed_forward): Sequential(
          (0): Linear(in_features=256, out_features=512, bias=True)
          (1): ReLU()
          (2): Linear(in_features=512, out_features=256, bias=True)
        )
        (add_n_normalization_2): AddAndInstanceNormalization(
          (norm): InstanceNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
        )
    

In [17]:
# Benchmark the refactored encoder on the same inputs, for comparison with the original encoder timing.
%timeit encoder(row_emb, col_emb, cost_mat)

57.1 ms ± 278 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
