In [1]:
import torch
import torch.nn as nn
from torch import Tensor 
from torch.nn.modules import Module
from einops import rearrange
from rl4co.utils.ops import batchify, unbatchify
from rl4co.models.zoo.common.autoregressive.decoder import PrecomputedCache
from rl4co.envs import RL4COEnvBase
from tensordict import TensorDict
from rl4co.envs import RL4COEnvBase
from rl4co.models.nn.env_embeddings import env_init_embedding
from rl4co.models.nn.attention import MultiHeadCrossAttention
from rl4co.models.zoo.matnet.encoder import MixedScoresSDPA
from rl4co.envs.scheduling.jssp import JSSPEnv
from rl4co.models.nn.ops import Normalization
from rl4co.models import DecoderOnlyPolicy, AutoregressivePolicy, L2DReinforce, AttentionModel, AutoregressiveDecoder
from rl4co.models.nn.graph.gcn import GCNEncoder
from rl4co.utils.ops import adj_to_pyg_edge_index

In [2]:
def pick_freest_gpu(gpu_stats_csv):
    """Return the index of the GPU with the most free memory.

    Args:
        gpu_stats_csv: Text produced by
            ``nvidia-smi --format=csv --query-gpu=memory.used,memory.free``,
            i.e. one header line followed by one line per GPU.

    Returns:
        Zero-based row index (as a plain ``int``) of the GPU whose
        ``memory.free`` value is numerically the largest.

    Raises:
        ValueError: If the text contains no GPU rows or a value without digits.
    """
    from io import StringIO

    import pandas as pd

    gpu_df = pd.read_csv(
        StringIO(gpu_stats_csv),
        names=['memory.used', 'memory.free'],
        skiprows=1,  # skip nvidia-smi's own header row
    )
    print('GPU usage:\n{}'.format(gpu_df))
    # The raw cells look like " 10000 MiB". Keep only the digits and convert to
    # integers so that the maximum is numeric; comparing the raw strings would
    # rank "900" above "10000".
    gpu_df['memory.free'] = (
        gpu_df['memory.free']
        .astype(str)
        .str.extract(r'(\d+)', expand=False)
        .astype(int)
    )
    idx = int(gpu_df['memory.free'].idxmax())
    print('Returning GPU{} with {} free MiB'.format(idx, gpu_df.iloc[idx]['memory.free']))
    return idx


def get_free_gpu():
    """Query ``nvidia-smi`` and return the index of the GPU with the most free memory."""
    import subprocess

    gpu_stats = subprocess.check_output(
        ["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]
    )
    return pick_freest_gpu(gpu_stats.decode("utf-8"))


try:
    free_gpu_id = get_free_gpu()
    torch.cuda.set_device(free_gpu_id)
    print(torch.cuda.is_available())
    print(torch.cuda.device_count())
    print(torch.cuda.current_device())
except Exception:
    # Best effort only: without nvidia-smi / pandas / a CUDA device we simply keep
    # the default device. `Exception` (rather than a bare `except`) lets
    # KeyboardInterrupt and SystemExit propagate.
    print("Could not set a default GPU")

Could not set a default GPU


In [3]:
# Use the currently selected CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
else:
    device = torch.device("cpu")

In [4]:
# Job-shop scheduling environment placed on `device`. The two positional 6s are
# presumably the number of jobs and the number of machines — TODO confirm against
# the JSSPEnv signature.
env = JSSPEnv(6, 6, device=device)

In [5]:
# Sample a batch of 10 instances to smoke-test the modules defined below.
# NOTE(review): this calls the private `_reset` hook directly; presumably the public
# `env.reset(...)` is the intended entry point — confirm before reusing this cell.
td = env._reset(batch_size=10)

In [6]:
# class EncoderBlock(nn.Module):

#     def __init__(
#         self, 
#         embedding_dim=128, 
#         num_heads=8, 
#         num_scores=1,
#         feed_forward_hidden=256,
#         normalization="batch",
#     ):
#         super(EncoderBlock, self).__init__()
#         ms = MixedScoresSDPA(num_heads, num_scores=num_scores)
#         self.cross_attn_block = MultiHeadCrossAttention(
#             embedding_dim, num_heads, sdpa_fn=ms
#         )
#         self.F_a = nn.ModuleDict(
#             {
#                 "norm1": Normalization(embedding_dim, normalization),
#                 "ffn": nn.Sequential(
#                     nn.Linear(embedding_dim, feed_forward_hidden),
#                     nn.ReLU(),
#                     nn.Linear(feed_forward_hidden, embedding_dim),
#                 ),
#                 "norm2": Normalization(embedding_dim, normalization),
#             }
#         )

#     def forward(self, x, dmat=None, mask=None):
#         x_out = self.cross_attn_block(x, x, cross_attn_mask=mask, dmat=dmat)
#         x_emb_out = self.F_a["norm1"](x + x_out)
#         x_emb_out = self.F_a["norm2"](x_emb_out + self.F_a["ffn"](x_emb_out))
#         return x_emb_out


# class AttnEncoder(nn.Module):

#     def __init__(self, embedding_dim=128, num_heads=8, num_layers=3, linear_bias=False):
#         super(Encoder, self).__init__()
#         self.embedding_dim = embedding_dim
#         self.init_emb = JSSPInitEmbedding(embedding_dim)
#         self.layers = nn.ModuleList([
#             EncoderBlock(embedding_dim, num_heads, num_scores=1) for _ in range(num_layers)
#         ])
    
#     def forward(self, td):
#         bs, num_jobs, num_ops = td["durations"].shape
        
#         init_emb = self.init_emb(td)
#         op_emb = init_emb.clone()
#         # dmat = torch.stack(tuple(td.select("ops_on_same_ma_adj", "adjacency").values()), dim=-1)
#         dmat = td["adjacency"]
        
#         for layer in self.layers:
#             op_emb = layer(op_emb, dmat=dmat)
        
#         job_init_emb = init_emb.gather(1, td["next_op"][...,None].expand(bs, num_jobs, self.embedding_dim))
#         job_emb = op_emb.gather(1, td["next_op"][...,None].expand(bs, num_jobs, self.embedding_dim))
        
#         return job_emb, job_init_emb

class EncoderBlock(nn.Module):
    """One encoder layer: mixed-score attention of the input over itself, then a feed-forward net.

    Each of the two sub-layers is wrapped in a residual connection followed by a
    normalization layer.

    Args:
        embedding_dim: Width of the embeddings entering and leaving the block.
        num_heads: Number of attention heads.
        num_scores: Forwarded to ``MixedScoresSDPA`` as ``num_scores``.
        feed_forward_hidden: Hidden width of the two-layer feed-forward network.
        normalization: Normalization type forwarded to ``Normalization``.
    """

    def __init__(
        self,
        embedding_dim=128,
        num_heads=8,
        num_scores=1,
        feed_forward_hidden=256,
        normalization="batch",
    ):
        super().__init__()
        mixed_sdpa = MixedScoresSDPA(num_heads, num_scores=num_scores)
        self.cross_attn_block = MultiHeadCrossAttention(
            embedding_dim, num_heads, sdpa_fn=mixed_sdpa
        )

        # Sub-modules are built and registered in the order norm1 -> ffn -> norm2.
        post_attn_norm = Normalization(embedding_dim, normalization)
        feed_forward = nn.Sequential(
            nn.Linear(embedding_dim, feed_forward_hidden),
            nn.ReLU(),
            nn.Linear(feed_forward_hidden, embedding_dim),
        )
        post_ffn_norm = Normalization(embedding_dim, normalization)

        self.F_a = nn.ModuleDict()
        self.F_a["norm1"] = post_attn_norm
        self.F_a["ffn"] = feed_forward
        self.F_a["norm2"] = post_ffn_norm

    def forward(self, x, dmat=None, mask=None):
        """Apply the block to ``x``.

        Args:
            x: Input embeddings; passed as both arguments of the attention block.
            dmat: Extra score matrix forwarded to the attention block as ``dmat``.
            mask: Forwarded to the attention block as ``cross_attn_mask``.

        Returns:
            The output of the second normalization layer.
        """
        attn_out = self.cross_attn_block(x, x, cross_attn_mask=mask, dmat=dmat)
        hidden = self.F_a["norm1"](x + attn_out)
        ffn_out = self.F_a["ffn"](hidden)
        return self.F_a["norm2"](hidden + ffn_out)


class AttnEncoder(nn.Module):
    """Stack of ``EncoderBlock`` layers applied on top of an initial embedding.

    Args:
        env_name: Environment name, or an ``RL4COEnvBase`` instance whose ``name`` is used.
        embedding_dim: Width of the embeddings.
        num_heads: Number of attention heads in every block.
        num_layers: Number of ``EncoderBlock`` layers.
        init_embedding: Optional module producing the initial embeddings from the
            tensordict; when ``None`` it is obtained from ``env_init_embedding``.
    """

    def __init__(
        self,
        env_name,
        embedding_dim=128,
        num_heads=8,
        num_layers=3,
        init_embedding: nn.Module = None,
    ):
        super().__init__()
        self.env_name = env_name.name if isinstance(env_name, RL4COEnvBase) else env_name

        if init_embedding is None:
            init_embedding = env_init_embedding(
                self.env_name, {"embedding_dim": embedding_dim}
            )
        self.init_embedding = init_embedding
        self.embedding_dim = embedding_dim

        # num_scores=2 matches the two matrices stacked into `dmat` in `forward`.
        blocks = [
            EncoderBlock(embedding_dim, num_heads, num_scores=2)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, td):
        """Encode the instance held in ``td``.

        Args:
            td: Tensordict providing ``"durations"``, ``"ops_on_same_ma_adj"`` and
                ``"ops_of_same_job"``, plus whatever the initial embedding reads.

        Returns:
            A tuple ``(op_emb, None)``: the output of the last block and a ``None``
            placeholder in the second position.
        """
        # The unpacked values are not used below; the unpack only requires that
        # "durations" has exactly three dimensions.
        bs, num_jobs, num_ops = td["durations"].shape

        init_emb = self.init_embedding(td)

        # Stack the two relation matrices on a new trailing axis.
        relation_mats = td.select("ops_on_same_ma_adj", "ops_of_same_job").values()
        dmat = torch.stack(tuple(relation_mats), dim=-1)

        op_emb = init_emb.clone()
        for block in self.layers:
            op_emb = block(op_emb, dmat=dmat)

        # No gathering by td["next_op"] happens here; the embeddings returned are
        # exactly what the last block produced.
        return op_emb, None


class AttnDecoder(AutoregressiveDecoder):
    """Autoregressive decoder that gathers the cached embeddings at ``td["next_op"]``.

    ``_get_log_p`` selects, along dim 1, the cached node embeddings and the cached
    glimpse/logit keys and values at the indices stored in ``td["next_op"]`` before
    computing the attention logits.
    """

    def __init__(
            self, 
            env_name, 
            embedding_dim: int, 
            num_heads: int, 
            use_graph_context: bool = True, 
            linear_bias: bool = False, 
            context_embedding: Module = None, 
            dynamic_embedding: Module = None, 
            **logit_attn_kwargs
        ):
        """Forward every argument unchanged to ``AutoregressiveDecoder.__init__``."""
        super().__init__(
            env_name, 
            embedding_dim, 
            num_heads, 
            use_graph_context, 
            linear_bias, 
            context_embedding, 
            dynamic_embedding, 
            **logit_attn_kwargs
        )

    def _get_log_p(
        self,
        cached: PrecomputedCache,
        td: TensorDict,
        softmax_temp: float = None,
        num_starts: int = 0,
    ):
        """Compute log-probabilities over the entries selected by ``td["next_op"]``.

        Args:
            cached: Precomputed cache; its ``node_embeddings``, ``graph_context``,
                ``glimpse_key``, ``glimpse_val`` and ``logit_key`` fields are read.
            td: Current state; ``"next_op"`` and ``"action_mask"`` are read here.
            softmax_temp: Forwarded to ``self.logit_attention``.
            num_starts: Number of starts for multi-start decoding; values ``<= 1``
                skip all multi-start handling.

        Returns:
            Tuple ``(log_p, mask)`` where ``mask`` is the negation of
            ``td["action_mask"]``.
        """
        # Index for gather along dim 1, broadcast over the embedding dimension.
        # NOTE(review): this is built from `td` before the multi-start
        # batchify/unbatchify below — confirm the batch sizes of `td` and the cache
        # agree when num_starts > 1.
        next_op = td["next_op"][..., None].expand(-1, -1, self.embedding_dim)
        # Get precomputed (cached) embeddings
        node_embeds_cache, graph_context_cache = (
            cached.node_embeddings.gather(1, next_op),
            cached.graph_context,
        )
        glimpse_k_stat, glimpse_v_stat, logit_k_stat = (
            cached.glimpse_key.gather(1, next_op),
            cached.glimpse_val.gather(1, next_op),
            cached.logit_key.gather(1, next_op),
        )  # presumably [B, N, H] with N = next_op.shape[1] — TODO confirm
        has_dyn_emb_multi_start = self.is_dynamic_embedding and num_starts > 1

        # Handle efficient multi-start decoding
        if has_dyn_emb_multi_start:
            # if num_starts > 0 and we have some dynamic embeddings, we need to reshape them to [B*S, ...]
            # since keys and values are not shared across starts (i.e. the episodes modify these embeddings at each step)
            glimpse_k_stat = batchify(glimpse_k_stat, num_starts)
            glimpse_v_stat = batchify(glimpse_v_stat, num_starts)
            logit_k_stat = batchify(logit_k_stat, num_starts)
            node_embeds_cache = batchify(node_embeds_cache, num_starts)
            graph_context_cache = (
                batchify(graph_context_cache, num_starts)
                if isinstance(graph_context_cache, Tensor)
                else graph_context_cache
            )
        elif num_starts > 1:
            td = unbatchify(td, num_starts)
            if isinstance(graph_context_cache, Tensor):
                # add a dimension for num_starts (will automatically be broadcasted during addition)
                graph_context_cache = graph_context_cache.unsqueeze(1)

        # NOTE(review): the context embedding receives the full, ungathered
        # `cached.node_embeddings`; the gathered `node_embeds_cache` is not used
        # after this point — confirm this is intended.
        step_context = self.context_embedding(cached.node_embeddings, td)
        glimpse_q = step_context + graph_context_cache
        glimpse_q = (
            glimpse_q.unsqueeze(1) if glimpse_q.ndim == 2 else glimpse_q
        )  # add seq_len dim if not present

        # Compute dynamic embeddings and add to static embeddings
        glimpse_k_dyn, glimpse_v_dyn, logit_k_dyn = self.dynamic_embedding(td)
        glimpse_k = glimpse_k_stat + glimpse_k_dyn
        glimpse_v = glimpse_v_stat + glimpse_v_dyn
        logit_k = logit_k_stat + logit_k_dyn

        # Get the mask (True marks entries that must NOT be selected)
        mask = ~td["action_mask"]

        # Compute logits
        log_p = self.logit_attention(
            glimpse_q, glimpse_k, glimpse_v, logit_k, mask, softmax_temp
        )

        # Now we need to reshape the logits and log_p to [B*S,N,...] if num_starts > 1 without dynamic embeddings
        # note that rearranging order is important here
        if num_starts > 1 and not has_dyn_emb_multi_start:
            log_p = rearrange(log_p, "b s l -> (s b) l", s=num_starts)
            mask = rearrange(mask, "b s l -> (s b) l", s=num_starts)
        return log_p, mask

      
class L2DEncoder(GCNEncoder):
    """GCN encoder for "jssp" that returns the embeddings gathered at ``td["next_op"]``.

    Args:
        embedding_dim: Width of the embeddings, forwarded to ``GCNEncoder``.
        num_layers: Number of layers, forwarded to ``GCNEncoder``.
    """

    def __init__(self, embedding_dim, num_layers):
        super().__init__(
            "jssp",
            embedding_dim,
            num_layers,
            # Edges come from the "adjacency" entry; the second argument is ignored.
            edge_idx_fn=lambda td, _: adj_to_pyg_edge_index(td["adjacency"]),
        )

    def forward(self, td):
        """Encode ``td`` and keep only the rows indexed by ``td["next_op"]``.

        Returns:
            Tuple ``(job_emb, job_init_emb)``: the parent encoder's output embeddings
            and its initial embeddings, each gathered along dim 1 at ``td["next_op"]``.
        """
        bs, num_jobs, _ = td["durations"].shape
        op_emb, init_emb = super().forward(td)

        # One shared gather index, broadcast over the embedding dimension.
        gather_idx = td["next_op"][..., None].expand(bs, num_jobs, self.embedding_dim)
        return op_emb.gather(1, gather_idx), init_emb.gather(1, gather_idx)

In [7]:
# Hyperparameters shared by the modules built in the following cells.
emb_dim = 128  # embedding width used for the encoders, the decoder and the L2D policy
layers = 3  # number of layers in each encoder

In [8]:
# Two alternative encoders for the same environment: the GCN-based L2D encoder
# and the attention-based encoder defined above.
l2d_enc = L2DEncoder(emb_dim, layers)
attn_enc = AttnEncoder("jssp", emb_dim, num_layers=layers)

In [9]:
# Decoder paired with `attn_enc`; all remaining arguments keep the defaults declared
# in AttnDecoder.__init__.
attn_dec = AttnDecoder(env_name="jssp", embedding_dim=emb_dim, num_heads=8)

In [10]:
# Smoke-test both encoders on the sampled batch; the second return value is discarded.
# L2DEncoder returns embeddings gathered at td["next_op"], while AttnEncoder returns the
# output of its last block without gathering, so the two printed shapes differ in dim 1.
l2d_emb, _ = l2d_enc(td)
attn_emb, _ = attn_enc(td)
print(l2d_emb.shape)
print(attn_emb.shape)

torch.Size([10, 6, 128])
torch.Size([10, 36, 128])


In [13]:
# Smoke-test the decoder on the attention encoder's output; the result is discarded.
_ = attn_dec(td, attn_emb, env=env)

In [14]:
# Policy built around the L2D encoder, which is passed in as the feature extractor.
# NOTE(review): `env_name` receives the env object here, whereas the attention policy
# is given the string "jssp" — confirm DecoderOnlyPolicy accepts both.
l2d_policy = DecoderOnlyPolicy(
    env_name=env, 
    embedding_dim=emb_dim, 
    feature_extractor=l2d_enc
)

In [16]:
# Encoder/decoder policy assembled from the attention encoder and decoder built above.
attn_policy = AutoregressivePolicy(
    env_name="jssp",
    encoder=attn_enc, 
    decoder=attn_dec
)

In [17]:
# Training wrapper for the L2D policy. The batch and dataset sizes are the same values
# passed to the AttentionModel wrapper, so the two runs use matching settings.
l2d_model = L2DReinforce(
    env, 
    l2d_policy,
    batch_size = 100,
    val_batch_size = None,
    test_batch_size = None,
    train_data_size = 10000,
    val_data_size = 1000,
    test_data_size = 1000
)

/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.


In [18]:
# Training wrapper for the attention policy. The batch and dataset sizes are the same
# values passed to the L2DReinforce wrapper, so the two runs use matching settings.
attn_model = AttentionModel(
    env, 
    attn_policy,
    batch_size = 100,
    val_batch_size = None,
    test_batch_size = None,
    train_data_size = 10000,
    val_data_size = 1000,
    test_data_size = 1000
)

In [19]:
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at
# the top so that all dependencies are visible in one place.
from rl4co.utils.trainer import RL4COTrainer

# Single-epoch trainer; every other setting is left at the RL4COTrainer defaults.
trainer = RL4COTrainer(
    max_epochs=1
)

/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:551: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:67: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable T

In [20]:
# Train the attention model for the single epoch configured on `trainer`.
trainer.fit(attn_model)

val_file not set. Generating dataset instead
test_file not set. Generating dataset instead

  | Name     | Type                 | Params
--------------------------------------------------
0 | env      | JSSPEnv              | 0     
1 | policy   | AutoregressivePolicy | 499 K 
2 | baseline | WarmupBaseline       | 499 K 
--------------------------------------------------
998 K     Trainable params
0         Non-trainable params
998 K     Total params
3.995     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
# Train the L2D model with its own single-epoch trainer. A Lightning trainer keeps its
# fit-loop progress after `fit` returns, so reusing `trainer` once the attention model
# has consumed the `max_epochs=1` budget would end this run without training.
l2d_trainer = RL4COTrainer(max_epochs=1)
l2d_trainer.fit(l2d_model)