# MatNet Model

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append(2*"../")

from einops import rearrange, repeat
import math
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchrl.modules.models.models import MLP
from torchrl.envs import EnvBase
from tensordict.tensordict import TensorDict

from rl4co.envs import ATSPEnv 
from rl4co.utils.ops import batchify, unbatchify
from rl4co.models.nn.attention import LogitAttention
from rl4co.models.nn.utils import decode_probs, get_log_likelihood
from rl4co.models.nn.env_context import env_context
from rl4co.models.nn.env_embedding import env_dynamic_embedding, env_init_embedding
from rl4co.data.dataset import TensorDictCollate

  from .autonotebook import tqdm as notebook_tqdm


## Differences between AM and MatNet

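1. MatNet uses a dual graph attention layer that processes the two node sets, source nodes A (rows) and destination nodes B (columns), with separate encoding blocks
2. Mixed-score attention: in each head, the dot-product attention score and the corresponding cost-matrix entry are combined by a small learned MLP, so the network learns the "best" recipe for mixing the two scores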
3. Initial node representation: zero-vectors for A nodes and one-hot vectors for B nodes

In [2]:
# Instantiate a small ATSP environment (10 locations) and display the fields of a freshly reset state
env = ATSPEnv(num_loc=10)
env.reset()

TensorDict(
    fields={
        action_mask: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.bool, is_shared=False),
        cost_matrix: Tensor(shape=torch.Size([10, 10]), device=cpu, dtype=torch.float32, is_shared=False),
        current_node: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        first_node: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        i: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

In [3]:
# Random test inputs for exercising the encoder defined below

# col_emb.shape: (batch, col_cnt, embedding)
# row_emb.shape: (batch, row_cnt, embedding)
# cost_mat.shape: (batch, row_cnt, col_cnt)
batch = 64
row_cnt = 20
col_cnt = 30


# Reference hyperparameters; in this cell only 'embedding_dim' is read (to size the random embeddings)
model_params = {
    'embedding_dim': 256,
    'sqrt_embedding_dim': 256**(1/2),
    'encoder_layer_num': 5,
    'qkv_dim': 16,
    'sqrt_qkv_dim': 16**(1/2),
    'head_num': 16,
    'logit_clipping': 10,
    'ff_hidden_dim': 512,
    'ms_hidden_dim': 16,
    'ms_layer1_init': (1/2)**(1/2),
    'ms_layer2_init': (1/16)**(1/2),
    'eval_type': 'argmax',
    'one_hot_seed_cnt': 20,  # must be >= node_cnt
}


# Rows and columns deliberately have different counts so that shape mix-ups surface immediately
row_emb = torch.randn(batch, row_cnt, model_params['embedding_dim'])
col_emb = torch.randn(batch, col_cnt, model_params['embedding_dim'])
cost_mat = torch.randn(batch, row_cnt, col_cnt)

## Ours (all we need)

In [4]:
class MixedScoreMHA(nn.Module):
    """Multi-head attention whose logits mix dot-product scores with an external score matrix.

    For every head, the scaled dot-product score and the matching entry of ``matrix``
    are fed through a two-layer per-head MLP (2 -> hidden_dim -> 1, ReLU in between).
    The MLP output is used as the attention logit.

    Args:
        embed_dim: size of the input and output embeddings.
        num_heads: number of attention heads; must divide ``embed_dim``.
        hidden_dim: hidden width of the per-head score-mixing MLP.
        qkv_dim: per-head size of the query/key/value projections.
        bias: whether the q/k/v projections use a bias term.
        layer1_init: bound of the uniform init for the first mixing layer.
        layer2_init: bound of the uniform init for the second mixing layer.
        device, dtype: factory arguments applied to every parameter of the module.
    """

    def __init__(self, 
                    embed_dim,
                    num_heads,
                    hidden_dim: int = 16,
                    qkv_dim: int = 16,
                    bias=False,
                    layer1_init: float = (1/2)**(1/2),
                    layer2_init: float = (1/16)**(1/2),
                    device=None,
                    dtype=None
        ):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        assert (embed_dim % num_heads == 0), "embed_dim must be divisible by num_heads"
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        # Project
        self.Wq = nn.Linear(embed_dim, num_heads*qkv_dim, bias=bias, **factory_kwargs)
        self.Wk = nn.Linear(embed_dim, num_heads*qkv_dim, bias=bias, **factory_kwargs)
        self.Wv = nn.Linear(embed_dim, num_heads*qkv_dim, bias=bias, **factory_kwargs)
        self.out_proj = nn.Linear(num_heads*qkv_dim, embed_dim, **factory_kwargs)

        # Init mix params; they honor device/dtype like the linear layers above
        def _uniform_param(bound, *shape):
            return nn.Parameter(torch.empty(*shape, **factory_kwargs).uniform_(-bound, bound))

        self.mix1_weight = _uniform_param(layer1_init, num_heads, 2, hidden_dim)
        self.mix1_bias = _uniform_param(layer1_init, num_heads, hidden_dim)
        self.mix2_weight = _uniform_param(layer2_init, num_heads, hidden_dim, 1)
        self.mix2_bias = _uniform_param(layer2_init, num_heads, 1)

    def forward(self, q, k, v, matrix):
        """Attend from ``q`` to ``k``/``v`` using mixed scores.

        Args:
            q: [batch, row_cnt, embed_dim] query embeddings.
            k: [batch, col_cnt, embed_dim] key embeddings.
            v: [batch, col_cnt, embed_dim] value embeddings.
            matrix: [batch, row_cnt, col_cnt] external scores mixed into the logits.

        Returns:
            Tensor of shape [batch, row_cnt, embed_dim].
        """
        # Project q, k, v and reshape to [batch, head_num, row_cnt or col_cnt, qkv_dim]
        q, k, v = self.Wq(q), self.Wk(k), self.Wv(v)
        q, k, v = (self._reshape_heads(t) for t in (q, k, v))

        # Prepare dot product and matrix score: [batch, head_num, row_cnt, col_cnt]
        dot_product = torch.einsum('...rd,...cd->...rc', q, k) / math.sqrt(q.shape[-1])
        matrix_score = matrix.unsqueeze(1).expand_as(dot_product)

        # Mix the scores with the per-head MLP; `h` is kept as a batch axis in the einsums
        two_scores = torch.stack((dot_product, matrix_score), dim=-1)  # [b, h, r, c, 2]
        ms1 = torch.einsum('bhrct,htd->bhrcd', two_scores, self.mix1_weight)
        ms1 = ms1 + self.mix1_bias[None, :, None, None, :]
        ms2 = torch.einsum('bhrcd,hdt->bhrct', F.relu(ms1), self.mix2_weight)
        # The second bias is constant over columns, so the softmax below is invariant to it;
        # it is applied anyway so that every registered parameter takes part in the graph.
        ms2 = ms2 + self.mix2_bias[None, :, None, None, :]
        mixed_scores = ms2.squeeze(-1)  # [b, h, r, c]

        # Softmax over columns and multiply with values
        weights = F.softmax(mixed_scores, dim=-1)
        out = torch.matmul(weights, v)  # [b, h, r, d]

        # Merge heads back to [b, r, h * d] and project out
        out = out.transpose(1, 2).flatten(start_dim=2)
        return self.out_proj(out)
    
    def _reshape_heads(self, x):
        # [b, n, h * d] -> [b, h, n, d]
        return x.view(x.shape[0], x.shape[1], self.num_heads, -1).transpose(1, 2)
         
    
class AddAndInstanceNormalization(nn.Module):
    """Residual sum of two tensors followed by affine instance normalization."""

    def __init__(self, embedding_dim):
        super().__init__()
        self.norm = nn.InstanceNorm1d(
            embedding_dim,
            affine=True,
            track_running_stats=False,
        )

    def forward(self, input1, input2):
        # Inputs are [batch, problem, embedding]; InstanceNorm1d expects the channel
        # (embedding) axis at position 1, so swap axes around the normalization.
        residual_sum = input1 + input2
        channels_first = residual_sum.swapaxes(1, 2)
        return self.norm(channels_first).swapaxes(1, 2)
    
class EncodingBlock(nn.Module):
    """One encoder block: mixed-score attention, add & norm, feed-forward, add & norm.

    Extra keyword arguments are forwarded to ``MixedScoreMHA``.
    """

    def __init__(self, embed_dim, num_heads, ms_hidden_dim=16, ff_hidden_dim=512, **mha_kwargs):
        super().__init__()
        self.mixed_score_mha = MixedScoreMHA(
            embed_dim, num_heads, ms_hidden_dim, **mha_kwargs
        )
        self.add_n_normalization_1 = AddAndInstanceNormalization(embed_dim)
        self.feed_forward = MLP(
            embed_dim, embed_dim, 1, ff_hidden_dim, activation_class=nn.ReLU
        )
        self.add_n_normalization_2 = AddAndInstanceNormalization(embed_dim)

    def forward(self, row_emb, col_emb, cost_mat):
        # Rows act as queries; columns provide both keys and values
        attended = self.mixed_score_mha(row_emb, col_emb, col_emb, cost_mat)
        hidden = self.add_n_normalization_1(row_emb, attended)
        # Result has shape (batch, row_cnt, embedding)
        return self.add_n_normalization_2(hidden, self.feed_forward(hidden))
    

class EncoderLayer(nn.Module):
    """Dual encoder layer: one block updates the rows, a second block updates the columns.

    Both blocks are built from the same keyword arguments and both read the
    embeddings that entered the layer (not each other's output).
    """

    def __init__(self, **kw):
        super().__init__()
        self.row_encoding_block = EncodingBlock(**kw)
        self.col_encoding_block = EncodingBlock(**kw)

    def forward(self, row_emb, col_emb, cost_mat):
        # row_emb: (batch, row_cnt, embedding)
        # col_emb: (batch, col_cnt, embedding)
        # cost_mat: (batch, row_cnt, col_cnt); the column block sees it transposed
        cost_mat_t = cost_mat.transpose(1, 2)
        updated_rows = self.row_encoding_block(row_emb, col_emb, cost_mat)
        updated_cols = self.col_encoding_block(col_emb, row_emb, cost_mat_t)
        return updated_rows, updated_cols
        

class MatNetEncoder(nn.Module):
    """Stack of ``num_layers`` EncoderLayer modules applied one after another.

    Remaining keyword arguments are passed unchanged to every EncoderLayer.
    """

    def __init__(self, num_layers, **kw):
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(EncoderLayer(**kw))
        self.layers = nn.ModuleList(stack)

    def forward(self, row_emb, col_emb, cost_mat):
        # The cost matrix is shared by all layers; only the embeddings are updated
        rows, cols = row_emb, col_emb
        for encoder_layer in self.layers:
            rows, cols = encoder_layer(rows, cols, cost_mat)
        return rows, cols

In [5]:
# Build a 5-layer encoder and run it once on the random test tensors
encoder = MatNetEncoder(num_layers=5, embed_dim=256, num_heads=16)

# Trainable parameter count, reported in millions of parameters (a count, not a memory size)
num_trainable_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
print('Number of parameters: {:.2f} M'.format(num_trainable_params / 1e6))

out = encoder(row_emb, col_emb, cost_mat)
print(out[0].shape, out[1].shape)


Number of parameters: 5.27 MB
torch.Size([64, 20, 256]) torch.Size([64, 30, 256])


## Create Decoder

In [6]:
def select_start_nodes(batch_size, num_nodes, device="cpu"):
    """Start-node selection for POMO.

    Returns a 1-D tensor of length ``batch_size * num_nodes`` holding the sequence
    ``0, 1, ..., num_nodes - 1`` repeated ``batch_size`` times.

    NOTE(review): this ordering has to agree with how ``batchify`` lays out the
    replicated batch, which is not visible here. If the copies are laid out copy-major,
    the arrangement needed would be ``repeat_interleave(batch_size)`` instead. TODO confirm.
    """
    node_ids = torch.arange(num_nodes, device=device)
    return node_ids.unsqueeze(0).repeat(batch_size, 1).flatten()

In [7]:
@dataclass
class PrecomputedCache:
    """Tensors computed once from the encoder output and reused at every decoding step."""
    # Row embeddings from the encoder; passed to the step-context module
    row_embeddings: torch.Tensor
    # graph_context: torch.Tensor # TODO: check if used in MatNet
    # Fixed glimpse keys and values, projected from the column embeddings
    glimpse_key: torch.Tensor
    glimpse_val: torch.Tensor
    # Keys for the final logits; the decoder stores the unprojected column embeddings here
    logit_key: torch.Tensor


class Decoder(nn.Module):
    """Autoregressive decoder that rolls out the environment step by step.

    At every step a query is built from the row embeddings and the current state,
    attended against keys/values derived from the column embeddings, and the resulting
    log-probabilities are used to pick the next action. With ``num_pomo > 1`` the batch
    is replicated and the first action of each copy is fixed by ``select_start_nodes``.

    Args:
        env: environment; ``env.name``, ``env.step`` and ``env.get_reward`` are used.
        embedding_dim: embedding size; must be divisible by ``num_heads``.
        num_heads: number of heads handed to ``LogitAttention``.
        num_pomo: number of rollouts per instance; values below 1 are clamped to 1.
        **logit_attn_kwargs: forwarded to ``LogitAttention``.
    """

    def __init__(self, env, embedding_dim, num_heads, num_pomo=20, **logit_attn_kwargs):
        super(Decoder, self).__init__()

        self.env = env
        self.embedding_dim = embedding_dim

        assert embedding_dim % num_heads == 0

        # Environment-specific modules looked up by env name (rl4co helpers, semantics defined there)
        self.context = env_context(self.env.name, {"embedding_dim": embedding_dim})
        self.dynamic_embedding = env_dynamic_embedding(
            self.env.name, {"embedding_dim": embedding_dim}
        )

        # For each node we compute (glimpse key, glimpse value ) so 2 * embedding_dim
        # In original implementation, this is separated but this is the same
        # Note that compared to original AM, we do not project the logit key
        self.project_node_embeddings = nn.Linear(
            embedding_dim, 2 * embedding_dim, bias=False
        )
    
        # NOTE(review): not referenced by any method of this class
        self.project_fixed_context = nn.Linear(embedding_dim, embedding_dim, bias=False)

        # MHA
        self.logit_attention = LogitAttention(
            embedding_dim, num_heads, **logit_attn_kwargs
        )

        # POMO
        self.num_pomo = max(num_pomo, 1) # POMO = 1 is just normal REINFORCE

    def forward(self, td, embeddings, decode_type="sampling"):
        """Roll out the environment until every batch element is done.

        Args:
            td: environment state (TensorDict); read for "done" and "action_mask".
            embeddings: ``(row_embeddings, col_embeddings)`` pair from the encoder.
            decode_type: passed through to ``decode_probs``.

        Returns:
            Tuple ``(outputs, actions, td)``: the per-step log-probabilities and actions,
            each stacked along dim 1, and the final state with "reward" set.
        """
        # Collect outputs
        outputs = []
        actions = []


        if self.num_pomo > 1:
            # POMO: first action is decided via select_start_nodes
            action = select_start_nodes(batch_size=td.shape[0], num_nodes=self.num_pomo, device=td.device)

            # Replicate td num_pomo times via rl4co's batchify (layout of the copies is defined there)
            td = batchify(td, self.num_pomo)

            td.set("action", action)
            td = self.env.step(td)["next"]
            # NOTE(review): the reset output earlier in this notebook shows action_mask as bool,
            # so zeros_like yields a bool placeholder here; the later torch.stack with the
            # floating-point log_p tensors presumably relies on dtype promotion. TODO confirm.
            log_p = torch.zeros_like(td['action_mask'], device=td.device) # first log_p is 0, so p = log_p.exp() = 1

            outputs.append(log_p.squeeze(1))
            actions.append(action)

        # Compute keys, values for the glimpse and keys for the logits once as they can be reused in every step
        cached_embeds = self._precompute(embeddings)

        # Here we suppose all the batch is done at the same time
        while not td["done"].all():
            # Compute the logits for the next node
            log_p, mask = self._get_log_p(cached_embeds, td)

            # Select the indices of the next nodes in the sequences, result (batch_size) long
            action = decode_probs(
                log_p.exp().squeeze(1), mask.squeeze(1), decode_type=decode_type
            )

            # Step the environment
            td.set("action", action)
            td = self.env.step(td)["next"]

            # Collect output of step
            outputs.append(log_p.squeeze(1))
            actions.append(action)

        outputs, actions = torch.stack(outputs, 1), torch.stack(actions, 1)
        td.set("reward", self.env.get_reward(td, actions))
        return outputs, actions, td

    def _precompute(self, embeddings):
        """Project the column embeddings once and package everything reused across steps."""
        # The projection of the node embeddings for the attention is calculated once up front
        row_embed, col_embed = embeddings

        # A single linear layer produces keys and values; split its output in two halves
        (
            glimpse_key_fixed,
            glimpse_val_fixed,
        ) = self.project_node_embeddings(col_embed).chunk(2, dim=-1)


        # Organize in a dataclass for easy access; each tensor is replicated for the POMO rollouts
        cached_embeds = PrecomputedCache(
            row_embeddings=batchify(row_embed, self.num_pomo),
            glimpse_key=batchify(glimpse_key_fixed, self.num_pomo),
            glimpse_val=batchify(glimpse_val_fixed, self.num_pomo),
            logit_key=batchify(col_embed, self.num_pomo),
        )

        return cached_embeds

    def _get_log_p(self, cached, td):
        """Return ``(log_p, mask)`` for the current step; ``mask`` is the negated action mask."""
        # Compute the query based on the context (computes automatically the first and last node context)
        step_context = self.context(cached.row_embeddings, td)
        # TODO: check whether this matches POMO (no graph context is added to the query here)
        glimpse_q = step_context.unsqueeze(1)  # presumably [batch, 1, embed_dim]; depends on the context module

        # Compute keys and values for the nodes: cached part plus state-dependent part
        glimpse_key_dynamic, glimpse_val_dynamic, logit_key_dynamic = self.dynamic_embedding(td)
        glimpse_k = cached.glimpse_key + glimpse_key_dynamic
        glimpse_v = cached.glimpse_val + glimpse_val_dynamic
        logit_k = cached.logit_key + logit_key_dynamic

        # Get the mask (True where the action is NOT allowed)
        mask = ~td["action_mask"]

        # Compute logits
        log_p = self.logit_attention(glimpse_q, glimpse_k, glimpse_v, logit_k, mask)

        return log_p, mask

In [8]:
class MatNetInitEmbedding(nn.Module):
    """Initial MatNet node embeddings: zero vectors for rows, random one-hot vectors for columns.

    For every instance, each column node receives a distinct hot index drawn without
    replacement from ``range(one_hot_seed_cnt)``.

    Args:
        embedding_dim: size of every embedding vector.
        one_hot_seed_cnt: number of candidate hot positions; must be at least the number
            of nodes and at most ``embedding_dim``.

    Raises:
        ValueError: if ``one_hot_seed_cnt > embedding_dim`` (at construction), or if an
            input has more nodes than ``one_hot_seed_cnt`` (in ``forward``).
    """

    def __init__(self, embedding_dim, one_hot_seed_cnt):
        super(MatNetInitEmbedding, self).__init__()
        if one_hot_seed_cnt > embedding_dim:
            # A hot index >= embedding_dim would fall outside the embedding vector
            raise ValueError(
                f"one_hot_seed_cnt ({one_hot_seed_cnt}) must be <= embedding_dim ({embedding_dim})"
            )
        self.embedding_dim = embedding_dim
        self.one_hot_seed_cnt = one_hot_seed_cnt

    def forward(self, td):
        """Build ``(row_emb, col_emb)``, each [batch, node, embedding_dim], from ``td["cost_matrix"]``."""
        cost_mat = td["cost_matrix"]
        batch_size, node_cnt = cost_mat.size(0), cost_mat.size(1)
        if node_cnt > self.one_hot_seed_cnt:
            # Distinct hot indices per node are only possible with enough seed positions
            raise ValueError(
                f"node count ({node_cnt}) must be <= one_hot_seed_cnt ({self.one_hot_seed_cnt})"
            )
        device = cost_mat.device
        row_emb = torch.zeros((batch_size, node_cnt, self.embedding_dim), device=device)
        col_emb = torch.zeros((batch_size, node_cnt, self.embedding_dim), device=device)
        # The indices of the node_cnt smallest random values form a sample without replacement
        rand = torch.rand(batch_size, self.one_hot_seed_cnt, device=device)
        _, rand_idx = rand.topk(node_cnt, dim=1, largest=False)
        # Set col_emb[b, n, rand_idx[b, n]] = 1 without building index grids
        col_emb.scatter_(2, rand_idx.unsqueeze(-1), 1.0)
        return row_emb, col_emb

In [9]:
# Smoke-test the initial embedding on 100 random 50x50 cost matrices
embedding = MatNetInitEmbedding(128, 50)

td = TensorDict({"cost_matrix": torch.randn(100, 50, 50)}, batch_size=100)
a, b = embedding(td)  # a: row embeddings, b: column embeddings
a.shape, b.shape

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


(torch.Size([100, 50, 128]), torch.Size([100, 50, 128]))

In [10]:
class MatNetPolicy(nn.Module):
    """MatNet policy: initial embedding -> dual encoder -> autoregressive decoder.

    Args:
        env: environment handed to the default decoder.
        encoder: optional custom encoder; a ``MatNetEncoder`` is built when None.
        decoder: optional custom decoder; a ``Decoder`` is built when None.
        embedding_dim: embedding size shared by all sub-modules.
        num_pomo: number of POMO rollouts per instance for the default decoder.
        one_hot_seed_cnt: number of candidate one-hot positions for the column embeddings.
        num_encode_layers: number of encoder layers of the default encoder.
        num_heads: number of attention heads.
        mask_inner: forwarded to the default decoder (and from there to ``LogitAttention``).
        train_decode_type, val_decode_type, test_decode_type: decode type used for each
            phase when the caller does not pass ``decode_type`` explicitly.
        **unused_kwargs: accepted and reported, otherwise ignored.
    """

    def __init__(
        self,
        env: EnvBase,
        encoder: nn.Module = None,
        decoder: nn.Module = None,
        embedding_dim: int = 256,
        num_pomo: int = 10,
        one_hot_seed_cnt: int = 20,
        num_encode_layers: int = 5,
        num_heads: int = 16,
        mask_inner: bool = True,
        train_decode_type: str = "sampling",
        val_decode_type: str = "greedy",
        test_decode_type: str = "greedy",
        **unused_kwargs
    ):
        super(MatNetPolicy, self).__init__()

        if len(unused_kwargs) > 0:
            print("Unused kwargs found in MatNetPolicy init: ", unused_kwargs)

        self.env = env

        self.init_embedding = MatNetInitEmbedding(embedding_dim, one_hot_seed_cnt)

        self.encoder = (
            MatNetEncoder(
                num_heads=num_heads,
                embed_dim=embedding_dim,
                num_layers=num_encode_layers,
            )
            if encoder is None
            else encoder
        )

        self.decoder = (
            Decoder(
                env,
                embedding_dim,
                num_heads,
                num_pomo=num_pomo,
                mask_inner=mask_inner,
            )
            if decoder is None
            else decoder
        )
        # Keep num_pomo in sync with what the decoder really uses: the default Decoder clamps
        # it to >= 1, and a custom decoder may carry its own value. Consumers reshape the
        # rollout results with this number, so a mismatch would corrupt the reshaping.
        self.num_pomo = getattr(self.decoder, "num_pomo", max(num_pomo, 1))
        self.train_decode_type = train_decode_type
        self.val_decode_type = val_decode_type
        self.test_decode_type = test_decode_type

    def forward(
        self,
        td: TensorDict,
        phase: str = "train",
        return_actions: bool = False,
        **decoder_kwargs,
    ) -> dict:
        """Given observation, precompute embeddings and rollout.

        Args:
            td: environment state; must contain "cost_matrix".
            phase: "train", "val" or "test"; selects the default decode type.
            return_actions: whether to include the actions in the result.
            **decoder_kwargs: forwarded to the decoder; a missing or None ``decode_type``
                is replaced by the decode type configured for ``phase``.

        Returns:
            dict with keys "reward", "log_likelihood" and "actions" (None unless
            ``return_actions`` is True).
        """

        # Embed and encode the instance
        row_emb, col_emb = self.init_embedding(td)
        encoded_inputs = self.encoder(row_emb, col_emb, td["cost_matrix"])
        
        # Get decode type depending on phase
        if decoder_kwargs.get("decode_type", None) is None:
            decoder_kwargs["decode_type"] = getattr(self, f"{phase}_decode_type")

        # Main rollout
        log_p, actions, td = self.decoder(td, encoded_inputs, **decoder_kwargs)

        # The log-likelihood of the whole trajectory is computed here rather than returned per action
        ll = get_log_likelihood(log_p, actions, td.get("mask", None))
        out = {
            "reward": td["reward"],
            "log_likelihood": ll,
            "actions": actions if return_actions else None,
        }

        return out


## Test policy only

In [11]:
env = ATSPEnv(num_loc=20)
env.name = "tsp" # TODO: make this automatic when creating env

# Fall back to CPU so this smoke test also runs on machines without a GPU
device = "cuda" if torch.cuda.is_available() else "cpu"

dataset = env.dataset(batch_size=[10000])

dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=False,  # instances are randomly generated, so shuffling adds nothing
    num_workers=0,
    collate_fn=TensorDictCollate(),
)

policy = MatNetPolicy(
    env,
    num_pomo=20,
).to(device)

# policy = torch.compile(policy)  # optional

# Run a single batch through the policy
td = next(iter(dataloader)).to(device)
td = env.reset(td)

out = policy(td, decode_type="sampling")

print(out)

{'reward': tensor([-3.8983, -3.2838, -3.1124,  ..., -3.7781, -3.8790, -2.8345],
       device='cuda:0'), 'log_likelihood': tensor([-29.3295, -38.0420, -36.7787,  ..., -31.7003, -41.2502, -29.0416],
       device='cuda:0', grad_fn=<SumBackward1>), 'actions': None}


## Create full MatNet: `env` + `policy` + `baseline`

In [12]:
from typing import Any, Optional, Tuple, Union

import lightning as L

from rl4co.utils.lightning import get_lightning_device
from rl4co.models.rl.reinforce import WarmupBaseline, RolloutBaseline, ExponentialBaseline, SharedBaseline
from rl4co.models.zoo.pomo.utils import get_best_actions
from rl4co.data.dataset import TensorDictDataset


class MatNet(nn.Module):
    """MatNet model: policy plus REINFORCE baseline.

    Args:
        env: environment, also handed to the default policy and to the baseline hooks.
        policy: policy module; a default ``MatNetPolicy(env)`` is built when None.
        baseline: REINFORCE baseline; a ``SharedBaseline()`` is built when None.
    """

    def __init__(self, env, policy=None, baseline=None):
        super().__init__()
        self.env = env
        self.policy = MatNetPolicy(env) if policy is None else policy
        self.baseline = SharedBaseline() if baseline is None else baseline
        # self.baseline = WarmupBaseline(RolloutBaseline()) if baseline is None else baseline
        self.num_pomo = self.policy.num_pomo

    def forward(
        self,
        td: TensorDict,
        phase: str = "train",
        decode_type: Optional[str] = None,
        return_actions: bool = False,
    ):
        """Evaluate model, get costs and log probabilities and compare with baseline.

        Args:
            td: environment state after reset.
            phase: "train", "val" or "test". It is forwarded to the policy so that the
                policy's phase-specific decode type applies; the loss is only computed
                for "train".
            decode_type: explicit decode type; None lets the policy choose by ``phase``.
            return_actions: whether to return the actions and the best actions per instance.

        Returns:
            The policy output dict extended with "max_reward" and "best_actions", plus
            "loss", "reinforce_loss", "bl_loss" and "bl_val" when ``phase == "train"``.
        """

        # Evaluate model, get costs and log probabilities.
        # `phase` must be passed on: otherwise validation and test would silently run with
        # the training decode type instead of the policy's val/test decode types.
        out = self.policy(
            td, phase=phase, decode_type=decode_type, return_actions=return_actions
        )

        # Max POMO reward [batch, num_pomo]
        reward = unbatchify(out["reward"], self.num_pomo)
        max_reward, max_idxs = reward.max(dim=1)
        out.update(
            {
                "max_reward": max_reward,
                "best_actions": get_best_actions(out["actions"], max_idxs)
                if return_actions
                else None,
            }
        )

        if phase == "train":
            costs = unbatchify(-out["reward"], self.num_pomo)
            ll = unbatchify(out["log_likelihood"], self.num_pomo)
            bl_val, bl_loss = self.baseline.eval(td, costs)

            # Calculate REINFORCE loss
            advantage = costs - bl_val
            reinforce_loss = (advantage * ll).mean()
            loss = reinforce_loss + bl_loss
            out.update(
                {
                    "loss": loss,
                    "reinforce_loss": reinforce_loss,
                    "bl_loss": bl_loss,
                    "bl_val": bl_val,
                }
            )

        return out

    def setup(self, lit_module):
        """Set up the baseline with the policy, the validation dataloader, the env and the device."""
        self.baseline.setup(
            self.policy,
            lit_module.val_dataloader(),
            self.env,
            device=get_lightning_device(lit_module),
        )

    def on_train_epoch_end(self, lit_module):
        """Forward the end-of-epoch event (with the current epoch) to the baseline."""
        self.baseline.epoch_callback(
            self.policy,
            lit_module.val_dataloader(),
            lit_module.current_epoch,
            self.env,
            device=get_lightning_device(lit_module),
        )

In [13]:

# Fall back to CPU so this smoke test also runs on machines without a GPU
device = "cuda" if torch.cuda.is_available() else "cpu"

model = MatNet(
    env,
    policy,
    # baseline=baseline,
).to(device)


# Run a single batch through the full model (policy + baseline)
td = next(iter(dataloader)).to(device)
td = env.reset(td)

out = model(td, decode_type="sampling")

print(out)

{'reward': tensor([-3.8317, -3.5553, -3.3876,  ..., -3.4134, -3.6070, -2.9557],
       device='cuda:0'), 'log_likelihood': tensor([-29.3433, -35.0586, -33.1406,  ..., -30.6922, -31.7506, -33.8129],
       device='cuda:0', grad_fn=<SumBackward1>), 'actions': None, 'max_reward': tensor([-3.4220, -3.0902, -2.7575, -2.9322, -3.0552, -3.5105, -2.8032, -3.1806,
        -2.5953, -3.4013, -3.7320, -3.1029, -3.4469, -3.9353, -2.2647, -2.9603,
        -2.4990, -3.4372, -3.0389, -3.0226, -3.2946, -2.6400, -3.5070, -2.8914,
        -3.6952, -3.4368, -2.9230, -4.0867, -4.2507, -3.0500, -3.0880, -2.5604,
        -3.6136, -2.9886, -2.5695, -2.7840, -2.4710, -2.5222, -2.5736, -3.1487,
        -2.8049, -2.9437, -2.8800, -3.3523, -3.3092, -2.8107, -1.9092, -2.6659,
        -2.5970, -2.6209, -3.3530, -2.2931, -2.2357, -3.2969, -3.2027, -3.3188,
        -2.4187, -2.3804, -2.7434, -3.0590, -2.8018, -2.7455, -3.1963, -2.6717],
       device='cuda:0'), 'best_actions': None, 'loss': tensor(0.0075, device='cud

## Lightning Module

In [14]:
class RL4COLitModule(L.LightningModule):
    def __init__(self, env, model, cfg):
        """
        Base LightningModule for Neural Combinatorial Optimization.

        Holds the environment, the model and the config; datasets are generated
        from the environment in `setup`.

        NOTE: simplified not to use Hydra instantiate here
        """

        super().__init__()
        # Store cfg in the checkpoint and expose it through `self.hparams`
        self.save_hyperparameters(cfg)
        self.cfg = cfg
        self.env = env
        self.model = model
        self.instantiate_metrics()

    def instantiate_metrics(self):
        """Read from the config which output keys are logged in each phase"""
        metrics_cfg = self.cfg.metrics
        self.train_metrics = metrics_cfg.get("train", ["loss", "reward"])
        self.val_metrics = metrics_cfg.get("val", ["reward"])
        self.test_metrics = metrics_cfg.get("test", ["reward"])
        self.log_on_step = metrics_cfg.get("log_on_step", True)

    def setup(self, stage="fit"):
        """Generate the datasets and let the model run its own setup hook, if it has one"""
        data_cfg = self.cfg.data
        self.train_dataset = self.env.dataset(data_cfg.train_size, "train")
        self.val_dataset = self.env.dataset(data_cfg.val_size, "val")
        # The test set falls back to the validation size when no test size is configured
        self.test_dataset = self.env.dataset(
            data_cfg.get("test_size", data_cfg.val_size), "test"
        )
        if hasattr(self.model, "setup"):
            self.model.setup(self)

    def configure_optimizers(self):
        """Adam over all parameters of this module"""
        optim_cfg = self.cfg.optim
        adam = torch.optim.Adam(
            self.parameters(), lr=optim_cfg.lr, weight_decay=optim_cfg.weight_decay
        )
        return [adam] # NOTE: for simplicity we do not include the scheduler here

    def shared_step(self, batch: Any, batch_idx: int, phase: str):
        """Reset the env on the batch, run the model and log the metrics selected for `phase`"""
        td = self.env.reset(batch)
        out = self.model(td, phase)
        # Log the mean of every tracked output, namespaced by phase
        tracked = getattr(self, f"{phase}_metrics")
        to_log = {}
        for key, value in out.items():
            if key in tracked:
                to_log[f"{phase}/{key}"] = value.mean()
        self.log_dict(
            to_log,
            on_step=self.log_on_step,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )

        return {"loss": out.get("loss", None)}

    def training_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase="train")

    def validation_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase="val")

    def test_step(self, batch: Any, batch_idx: int):
        return self.shared_step(batch, batch_idx, phase="test")

    def train_dataloader(self):
        return self._dataloader(self.train_dataset)

    def val_dataloader(self):
        return self._dataloader(self.val_dataset)

    def test_dataloader(self):
        return self._dataloader(self.test_dataset)

    def on_train_epoch_end(self):
        """Run the model's epoch-end hook, if any, then generate a fresh training dataset"""
        if hasattr(self.model, "on_train_epoch_end"):
            self.model.on_train_epoch_end(self)
        self.train_dataset = self.env.dataset(self.cfg.data.train_size, "train")

    def _dataloader(self, dataset):
        """Wrap `dataset` in a DataLoader configured from `cfg.data`"""
        data_cfg = self.cfg.data
        return DataLoader(
            dataset,
            batch_size=data_cfg.batch_size,
            shuffle=False,  # no need to shuffle, we're resampling every epoch
            num_workers=data_cfg.get("num_workers", 0),
            collate_fn=TensorDictCollate(),
        )


## Config

In [15]:
from omegaconf import OmegaConf, DictConfig

# Experiment configuration read by RL4COLitModule: data sizes, optimizer settings, logged metrics
config = DictConfig(
    {
        "data": {
            "train_size": 100000, # 100k training instances generated per epoch
            "val_size": 10000, 
            "batch_size": 64, # instances per batch
        },
        "optim": {
            "lr": 4e-4,
            "weight_decay": 1e-6,
        },
        "metrics": {
            "train": ["loss", "reward"],
            "val": ["reward"],
            "test": ["reward"],
            "log_on_step": True,
        },
        
    }
)


# Wrap env + model + config in the Lightning module used for training
lit_module = RL4COLitModule(env, model, config)

In [16]:
# Trick to make calculations faster
torch.set_float32_matmul_precision("medium")

# Use the GPU with 16-bit mixed precision when available, otherwise plain fp32 on CPU
use_gpu = torch.cuda.is_available()

# Trainer
trainer = L.Trainer(
    max_epochs=10, # 10
    accelerator="gpu" if use_gpu else "cpu",
    logger=None, # can replace with WandbLogger, TensorBoardLogger, etc.
    precision="16-mixed" if use_gpu else "32-true", # Lightning will handle casting to float16
    log_every_n_steps=1,   
    gradient_clip_val=1.0, # clip gradients to avoid exploding gradients!
    # The Lightning module regenerates its training dataset at the end of every epoch;
    # without this flag the Trainer keeps the first DataLoader and never sees the new data.
    reload_dataloaders_every_n_epochs=1,
)

# Fit the model
trainer.fit(lit_module)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type    | Params
----------------------------------
0 | env   | ATSPEnv | 0     
1 | model | MatNet  | 5.7 M 
----------------------------------
5.7 M     Trainable params
0         Non-trainable params
5.7 M     Total params
22.670    Total estimated model params size (MB)
2023-05-15 02:36:19.373024: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-05-15 02:36:19.389838: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical opera

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  rank_zero_warn(


                                                                           

  rank_zero_warn(


Epoch 0:  34%|███▍      | 534/1563 [00:28<00:54, 18.73it/s, v_num=6, train/reward_step=-3.48, train/loss_step=0.0287]  

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
