In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensordict import TensorDict

In [7]:
from rl4co.models.nn.attention import MultiHeadCrossAttention, MultiHeadAttention
from rl4co.models.zoo.matnet.encoder import MixedScoresSDPA
from rl4co.envs.scheduling.jssp import JSSPEnv
from rl4co.models.nn.ops import Normalization
from rl4co.models.zoo.common.decoder_only.policy import DecoderOnlyPolicy
from rl4co.models.nn.graph.gcn import GCNEncoder
from rl4co.models.nn.graph.graphCNN import GraphCNN
from rl4co.models.nn.ops import PositionalEncoding

In [8]:
# Instantiate one self-attention and one cross-attention block with arguments (128, 8).
# For MultiHeadCrossAttention these match the (embedding_dim, num_heads) usage in EncoderBlock below.
# NOTE(review): presumably the same meaning for MultiHeadAttention — confirm against rl4co.
# Neither object is referenced again in the visible part of this notebook.
self_attn_block = MultiHeadAttention(128, 8)
cross_attn_block = MultiHeadCrossAttention(128, 8)

In [9]:
# Job-shop scheduling environment used by every cell below.
# NOTE(review): the two positional arguments are presumably (num_jobs, num_machines) — confirm.
env = JSSPEnv(6, 6)

In [10]:
# Reset the environment with batch_size=10; `td` is the state fed to the encoder further down.
# NOTE(review): this calls the underscore-prefixed `_reset` directly — check whether the
# public reset entry point should be used instead.
td = env._reset(batch_size=10)

In [11]:
class JSSPInitEmbedding(nn.Module):
    """Linear initial embeddings for operations and machines.

    Operation features are stacked in the order (durations, lower_bounds,
    finished_mark) and projected by ``init_job_embed``. The single machine
    feature is the maximum over dim 1 of ``end_times`` gathered along dim 2
    by ``machines``, projected by ``init_ma_embed``. Durations, lower bounds
    and end times are divided by 1000 before projection.

    Args:
        embedding_dim: Output size of both linear projections.
        linear_bias: Passed as the ``bias`` argument of both ``nn.Linear`` layers.
    """

    def __init__(self, embedding_dim, linear_bias=False):
        super(JSSPInitEmbedding, self).__init__()

        self.init_job_embed = nn.Linear(3, embedding_dim, bias=linear_bias)
        self.init_ma_embed = nn.Linear(1, embedding_dim, bias=linear_bias)
        # No positional encoding in this variant.
        self.pos_encoder = None

    def forward(self, td: TensorDict):
        """Return ``(operation_embeddings, machine_embeddings)`` for ``td``."""
        batch_dims = td.batch_size

        def flat(key):
            # Collapse everything after the batch dims into a single axis.
            return td[key].reshape(*batch_dims, -1)

        op_features = torch.stack(
            (
                flat("durations") / 1000,
                flat("lower_bounds") / 1000,
                flat("finished_mark"),
            ),
            dim=-1,
        )
        op_embedding = self.init_job_embed(op_features)

        latest_end = td["end_times"].gather(2, td["machines"]).max(1).values
        machine_feature = latest_end.reshape(*batch_dims, -1, 1) / 1000
        machine_embedding = self.init_ma_embed(machine_feature)

        return op_embedding, machine_embedding


class JSSPInitEmbedding(nn.Module):
    """Initial operation embedding with a positional encoding of the in-job index.

    Four per-operation features are stacked in the order (start_times,
    durations, lower_bounds, finished_mark), projected linearly, and then
    passed through ``PositionalEncoding`` together with each operation's
    position inside its job.

    Args:
        embedding_dim: Output size of the linear projection and the size
            handed to ``PositionalEncoding``.
        linear_bias: Passed as the ``bias`` argument of ``nn.Linear``.
    """

    def __init__(self, embedding_dim, linear_bias=False):
        super().__init__()

        self.init_job_embed = nn.Linear(4, embedding_dim, linear_bias)
        self.pos_encoder = PositionalEncoding(embedding_dim)

    def forward(self, td: TensorDict):
        """Return the positionally encoded operation embeddings for ``td``.

        ``td["durations"]`` must be 3-dimensional; its shape is unpacked as
        ``(bs, jobs, ops)``.
        """
        bs, jobs, ops = td["durations"].shape

        # Time-like features are divided by 1000; finished_mark is used as is.
        durations = td["durations"].reshape(bs, -1) / 1000
        start_times = td["start_times"].reshape(bs, -1) / 1000
        lower_bounds = td["lower_bounds"].reshape(bs, -1) / 1000
        finished_mark = td["finished_mark"].reshape(bs, -1)
        job_feat = torch.stack(
            (start_times, durations, lower_bounds, finished_mark), dim=-1
        )
        job_emb = self.init_job_embed(job_feat)

        # Position of every operation within its job: 0..ops-1 repeated per job.
        # Built on the embeddings' device so the index never lives on a different
        # device than the tensor it is used with (e.g. when td is on GPU).
        seq_pos = (
            torch.arange(ops, device=job_emb.device).repeat(jobs)[None].expand(bs, -1)
        )
        job_emb = self.pos_encoder(job_emb, seq_pos)

        return job_emb

In [12]:
# op_ma_map = F.one_hot(td["machines"].argsort(-1).reshape(*td.batch_size, -1), env.num_machines).to(torch.float32)
# ma_emb = op_ma_map.transpose(-2, -1).bmm(job_emb)

In [52]:
class EncoderBlock(nn.Module):
    """Attention block: mixed-score cross attention followed by a feed-forward net.

    Each of the two sub-layers is wrapped in a residual connection and a
    ``Normalization`` layer.

    Args:
        embedding_dim: Feature size of the input and output.
        num_heads: Number of heads given to the attention modules.
        num_scores: Forwarded to ``MixedScoresSDPA`` as ``num_scores``.
        feed_forward_hidden: Hidden width of the feed-forward network.
        normalization: Second argument given to both ``Normalization`` layers.
    """

    def __init__(
        self,
        embedding_dim=128,
        num_heads=8,
        num_scores=1,
        feed_forward_hidden=256,
        normalization="batch",
    ):
        super().__init__()
        mixed_sdpa = MixedScoresSDPA(num_heads, num_scores=num_scores)
        self.cross_attn_block = MultiHeadCrossAttention(
            embedding_dim, num_heads, sdpa_fn=mixed_sdpa
        )

        # Sub-modules are created in the order norm1 -> ffn -> norm2.
        post_attn_norm = Normalization(embedding_dim, normalization)
        feed_forward = nn.Sequential(
            nn.Linear(embedding_dim, feed_forward_hidden),
            nn.ReLU(),
            nn.Linear(feed_forward_hidden, embedding_dim),
        )
        post_ffn_norm = Normalization(embedding_dim, normalization)
        self.F_a = nn.ModuleDict(
            {"norm1": post_attn_norm, "ffn": feed_forward, "norm2": post_ffn_norm}
        )

    def forward(self, x, dmat=None, mask=None):
        """Apply attention of ``x`` over itself, then the feed-forward net.

        ``mask`` is handed to the attention module as ``cross_attn_mask`` and
        ``dmat`` as ``dmat``.
        """
        attended = self.cross_attn_block(x, x, cross_attn_mask=mask, dmat=dmat)
        hidden = self.F_a["norm1"](x + attended)
        transformed = self.F_a["ffn"](hidden)
        return self.F_a["norm2"](hidden + transformed)


class Encoder(nn.Module):
    """Attention-based encoder over operation embeddings.

    Operations are embedded with ``JSSPInitEmbedding`` and refined by a stack
    of ``EncoderBlock`` layers. Two adjacency tensors from the input are
    stacked on a trailing axis and handed to each layer as ``dmat``. The
    result is reduced to one embedding per job by gathering the row selected
    by ``td["next_op"]``.

    Args:
        embedding_dim: Feature size used throughout the encoder.
        num_heads: Number of attention heads per block.
        num_scores: Forwarded to each ``EncoderBlock``.
            NOTE(review): ``forward`` stacks two matrices into ``dmat``, so
            this presumably needs to be 2 — confirm against MixedScoresSDPA.
        num_l: Number of ``EncoderBlock`` layers.
        linear_bias: Forwarded to ``JSSPInitEmbedding``.
    """

    def __init__(self, embedding_dim=128, num_heads=8, num_scores=1, num_l=3, linear_bias=False):
        # Zero-argument super(): the name `Encoder` is rebound by later class
        # definitions in this notebook, so resolving it at call time is fragile.
        super().__init__()
        self.embedding_dim = embedding_dim
        # linear_bias was previously accepted but never passed on.
        self.init_emb = JSSPInitEmbedding(embedding_dim, linear_bias)
        self.layers = nn.ModuleList(
            [EncoderBlock(embedding_dim, num_heads, num_scores) for _ in range(num_l)]
        )

    def forward(self, td):
        """Return ``(job_emb, job_init_emb)``, each of shape (bs, num_jobs, embedding_dim)."""
        bs, num_jobs, _ = td["durations"].shape

        init_emb = self.init_emb(td)
        op_emb = init_emb.clone()
        dmat = torch.stack(
            tuple(td.select("ops_on_same_ma_adj", "adjacency").values()), dim=-1
        )

        for layer in self.layers:
            op_emb = layer(op_emb, dmat=dmat)

        # One gather index, shared by the refined and the initial embeddings.
        next_op_idx = td["next_op"][..., None].expand(bs, num_jobs, self.embedding_dim)
        job_init_emb = init_emb.gather(1, next_op_idx)
        job_emb = op_emb.gather(1, next_op_idx)

        return job_emb, job_init_emb

      
class Encoder(GCNEncoder):
    """GCN-based encoder that returns one embedding per job.

    The parent ``GCNEncoder`` is configured with a ``JSSPInitEmbedding`` and
    an edge-index function derived from ``td["adjacency"]``. ``forward``
    gathers, for every job, the row selected by ``td["next_op"]``.

    Args:
        embedding_dim: Feature size used by the embedding and the parent encoder.
        num_layers: Forwarded to ``GCNEncoder``.
    """

    def __init__(self, embedding_dim, num_layers):

        def edge_idx_fn(td, *args, **kwargs):
            # Indices of the non-zero adjacency entries, transposed so that each
            # column is one entry.
            # NOTE(review): for a batched adjacency the result has one row per
            # tensor dimension (including batch) — confirm this is the layout
            # GCNEncoder expects.
            return torch.permute(td["adjacency"].nonzero(), (1, 0))

        # Use the constructor argument. The previous code read the notebook
        # global `emb_dim`, which is only defined in a later cell and may not
        # match `embedding_dim`.
        init_embedding = JSSPInitEmbedding(embedding_dim)

        super().__init__(
            "jssp",
            embedding_dim,
            num_layers,
            init_embedding=init_embedding,
            edge_idx_fn=edge_idx_fn,
        )
        # forward() reads self.embedding_dim; set it here rather than relying on
        # the parent class to provide it.
        self.embedding_dim = embedding_dim

    def forward(self, td):
        """Return ``(job_emb, job_init_emb)``, each of shape (bs, num_jobs, embedding_dim)."""
        bs, num_jobs, _ = td["durations"].shape
        op_emb, init_emb = super().forward(td)

        # One gather index, shared by the refined and the initial embeddings.
        next_op_idx = td["next_op"][..., None].expand(bs, num_jobs, self.embedding_dim)
        job_init_emb = init_emb.gather(1, next_op_idx)
        job_emb = op_emb.gather(1, next_op_idx)

        return job_emb, job_init_emb


class Encoder(GraphCNN):
    """GraphCNN-based encoder reduced to one embedding per job.

    Args:
        embedding_dim: Feature size; forwarded to ``GraphCNN`` and stored on
            the instance for use in ``forward``.
        num_layers: Forwarded to ``GraphCNN``.
    """

    def __init__(self, embedding_dim, num_layers):
        super().__init__("jssp", embedding_dim, num_layers)
        self.embedding_dim = embedding_dim

    def forward(self, td):
        """Return ``(job_emb, job_init_emb)`` gathered at ``td["next_op"]``."""
        batch, n_jobs, _ = td["durations"].shape
        node_emb, node_init_emb = super().forward(td)

        # Broadcast each job's selected index over the feature axis.
        pick = td["next_op"][..., None].expand(batch, n_jobs, self.embedding_dim)
        selected_init = node_init_emb.gather(1, pick)
        selected = node_emb.gather(1, pick)

        return selected, selected_init


In [53]:
# Hyperparameters shared by the encoder and the policy built below.
emb_dim = 128  # embedding size passed to Encoder and DecoderOnlyPolicy
layers = 3  # second argument (num_layers) of Encoder

In [54]:
# `Encoder` is bound to whichever class definition was executed last; if the
# definitions cell ran top to bottom, that is the GraphCNN-based variant.
enc = Encoder(emb_dim, layers)

In [55]:
# enc = Encoder(num_scores=2)

In [56]:
# Smoke-test the encoder on the reset state; the second return value
# (the gathered initial embeddings) is discarded.
emb, _ = enc(td)

In [57]:
# Display the shape of the per-job embeddings returned by the encoder.
emb.shape

torch.Size([10, 6, 128])

In [58]:
# Build the policy with the encoder above as its feature extractor.
# NOTE(review): the environment object itself is passed as `env_name`; confirm that
# DecoderOnlyPolicy accepts an env instance here and not only a name string.
jssp_policy = DecoderOnlyPolicy(
    env_name=env, embedding_dim=emb_dim, feature_extractor=enc
)

In [59]:
from rl4co.models.rl import PPO, REINFORCE
class L2DReinforce(REINFORCE):
    """REINFORCE model that falls back to a ``DecoderOnlyPolicy``.

    Args:
        env: Environment; ``env.name`` is used when the default policy is built.
        policy: Policy to train. If ``None``, a ``DecoderOnlyPolicy`` is
            created from ``env.name`` and ``policy_kwargs``.
        baseline: Forwarded to ``REINFORCE``.
        policy_kwargs: Keyword arguments for the default policy; ignored when
            ``policy`` is given. ``None`` means no extra arguments.
        baseline_kwargs: Forwarded to ``REINFORCE``. ``None`` means an empty dict.
        **kwargs: Forwarded to ``REINFORCE``.
    """

    def __init__(
        self,
        env,
        policy=None,
        baseline="rollout",
        policy_kwargs=None,
        baseline_kwargs=None,
        **kwargs,
    ):
        # Fresh dicts per call: a `{}` default is created once and shared by
        # every instance, so a mutation would leak between models.
        policy_kwargs = {} if policy_kwargs is None else policy_kwargs
        baseline_kwargs = {} if baseline_kwargs is None else baseline_kwargs

        if policy is None:
            policy = DecoderOnlyPolicy(env.name, **policy_kwargs)

        super().__init__(env, policy, baseline, baseline_kwargs, **kwargs)

In [60]:
# Wrap the environment and policy in the REINFORCE model. The keyword arguments are
# forwarded through **kwargs to the REINFORCE base class.
# NOTE(review): val/test batch sizes are None — presumably they fall back to
# `batch_size`; confirm against the rl4co base module.
model = L2DReinforce(
    env, 
    jssp_policy,
    batch_size = 50,
    val_batch_size = None,
    test_batch_size = None,
    train_data_size = 10000,
    val_data_size = 1000,
    test_data_size = 1000
)

/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.


In [61]:
from rl4co.utils.trainer import RL4COTrainer

# Single-epoch trainer: enough to check that the pipeline runs end to end.
trainer = RL4COTrainer(
    max_epochs=1
)

/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:551: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [62]:
# Train for the configured single epoch. No validation/test files are set, so the
# datasets are generated (see the log output below).
trainer.fit(model)

val_file not set. Generating dataset instead
test_file not set. Generating dataset instead

  | Name     | Type              | Params
-----------------------------------------------
0 | env      | JSSPEnv           | 0     
1 | policy   | DecoderOnlyPolicy | 83.6 K
2 | baseline | WarmupBaseline    | 83.6 K
-----------------------------------------------
167 K     Trainable params
0         Non-trainable params
167 K     Total params
0.669     Total estimated model params size (MB)


Sanity Checking: |                                                                | 0/? [00:00<?, ?it/s]

/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/luttmann/opt/miniconda3/envs/rl4co/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Training: |                                                                       | 0/? [00:00<?, ?it/s]

Validation: |                                                                     | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
