Эксперименты с DQN для комбинаторных задач

In [3]:
from rl4co.envs.common.base import RL4COEnvBase

from rl4co.models.zoo.l2d.policy import L2DPolicy

from typing import IO, Any, Optional, cast

import torch
import torch.nn as nn

from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
from lightning.pytorch.core.saving import _load_from_checkpoint
from tensordict import TensorDict
from typing_extensions import Self

from rl4co.envs.common.base import RL4COEnvBase
from rl4co.models.rl.common.base import RL4COLitModule
from rl4co.utils.lightning import get_lightning_device

In [None]:
class DQN(RL4COLitModule):
    """Lightning module that fits a Q-value policy with a mean-squared-error loss.

    The policy is expected to return a dict holding at least
    ``"predicted_q_values"``, ``"actions"`` and the episode reward. The
    regression targets are produced by ``make_fact_q_values``, which is not
    defined in this class and must be available at module scope.
    """

    def __init__(
        self,
        env: RL4COEnvBase,
        policy: nn.Module,
        **kwargs,
    ):
        """Store the policy, record hyperparameters and create the MSE loss.

        Args:
            env: Environment used to reset batches and roll out the policy.
            policy: Module called as ``policy(td, env, phase=..., ...)``.
            **kwargs: Forwarded unchanged to ``RL4COLitModule.__init__``.
        """
        super().__init__(env, policy, **kwargs)
        self.policy = policy

        self.save_hyperparameters(logger=False)
        self.loss = torch.nn.MSELoss()

    def shared_step(
        self,
        batch: Any,
        batch_idx: int,
        phase: str,
        dataloader_idx: Optional[int] = None,
    ):
        """Run one rollout of the policy and, when training, compute the loss.

        Args:
            batch: Batch passed to ``self.env.reset``.
            batch_idx: Index of the batch (unused here).
            phase: Current phase; the loss is only computed for ``"train"``.
            dataloader_idx: Forwarded to ``self.log_metrics``.

        Returns:
            Dict with key ``"loss"`` (``None`` outside training) merged with
            whatever ``self.log_metrics`` returns.
        """
        td = self.env.reset(batch)
        # Forward pass: construct a solution and collect the predicted Q-values.
        out = self.policy(td, self.env, phase=phase, select_best=phase != "train")

        if phase == "train":
            out = self.calculate_loss(out)

        metrics = self.log_metrics(out, phase, dataloader_idx=dataloader_idx)
        return {"loss": out.get("loss", None), **metrics}

    def calculate_loss(
        self,
        policy_out: dict,
    ):
        """Add the MSE loss between predicted and target Q-values to ``policy_out``.

        Args:
            policy_out: Policy output with ``"predicted_q_values"``,
                ``"actions"`` and the reward under ``"rewards"`` or
                ``"reward"``.

        Returns:
            The same dict, updated in place with a ``"loss"`` entry.

        Raises:
            KeyError: If a required entry is missing from ``policy_out``.
        """
        predicted_q_values = policy_out["predicted_q_values"]
        actions = policy_out["actions"]
        # The policy in this file emits the singular key "reward"; the plural
        # key is still honoured first so other policies keep working.
        if "rewards" in policy_out:
            rewards = policy_out["rewards"]
        else:
            rewards = policy_out["reward"]

        # NOTE(review): `make_fact_q_values` is defined outside this class;
        # whether it detaches the targets from the graph is not visible here.
        fact_q_values = make_fact_q_values(predicted_q_values, rewards, actions)

        # Main loss function
        loss = self.loss(predicted_q_values, fact_q_values)
        policy_out.update({"loss": loss})
        return policy_out

    def set_decode_type_multistart(self, phase: str):
        """Set decode type to `multistart` for train, val and test in policy.
        For example, if the decode type is `greedy`, it will be set to `multistart_greedy`.

        Nothing happens when the policy has no such attribute, when it is
        ``None``, or when it already contains ``"multistart"``.

        Args:
            phase: Phase to set decode type for. Must be one of `train`, `val` or `test`.
        """
        attribute = f"{phase}_decode_type"
        # A default of None makes a missing attribute follow the same
        # "nothing to do" path as an explicit None instead of raising.
        attr_get = getattr(self.policy, attribute, None)
        if attr_get is None or "multistart" in attr_get:
            return
        setattr(self.policy, attribute, f"multistart_{attr_get}")

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path: _PATH | IO,
        map_location: _MAP_LOCATION_TYPE = None,
        hparams_file: Optional[_PATH] = None,
        strict: bool = False,
        **kwargs: Any,
    ) -> Self:
        """Load the model from a checkpoint with non-strict key matching.

        Note:
            This wraps Lightning's ``_load_from_checkpoint``. The ``strict``
            argument is accepted for interface compatibility, but loading is
            always performed with ``strict=False``.

        Args:
            checkpoint_path: Path or file object of the checkpoint.
            map_location: Forwarded to ``_load_from_checkpoint``.
            hparams_file: Forwarded to ``_load_from_checkpoint``.
            strict: Ignored; non-strict loading is always used.
            **kwargs: Forwarded to ``_load_from_checkpoint``.

        Returns:
            The loaded module, cast to ``Self``.
        """
        # Do not use strict, whatever the caller asked for.
        strict = False

        loaded = _load_from_checkpoint(
            cls,
            checkpoint_path,
            map_location,
            hparams_file,
            strict,
            **kwargs,
        )

        return cast(Self, loaded)


In [277]:
class L2DPolicyDQN(nn.Module):
    """Epsilon-greedy Q-value policy built from an encoder and a decoder.

    The encoder is called once per rollout; the decoder is called at every
    step and must return ``(q_values, mask)``, where ``mask`` is a boolean
    tensor marking the feasible actions.
    """

    def __init__(
        self,
        encoder,
        decoder,
        env_name: str = "tsp",
        epsilon: float = 0.05,
        train_decode_type: str = "sampling",
        val_decode_type: str = "greedy",
        test_decode_type: str = "greedy",
        **unused_kw,
    ):
        """Store the sub-modules and exploration settings.

        Args:
            encoder: Callable returning ``(hidden, init_embeds)`` for a state.
            decoder: Callable invoked as ``decoder(td, *hidden)``.
            env_name: Name of the environment (stored, not used in ``forward``).
            epsilon: Probability of taking a random feasible action in training.
            train_decode_type: Stored as an attribute; not read by ``forward``.
            val_decode_type: Stored as an attribute; not read by ``forward``.
            test_decode_type: Stored as an attribute; not read by ``forward``.
            **unused_kw: Accepted and ignored.
        """
        super().__init__()

        self.env_name = env_name

        # Encoder and decoder
        self.encoder = encoder
        self.decoder = decoder

        # Decoding strategies
        self.epsilon = epsilon
        self.train_decode_type = train_decode_type
        self.val_decode_type = val_decode_type
        self.test_decode_type = test_decode_type

    @staticmethod
    def _epsilon_greedy(q_values, mask, epsilon: float):
        """Pick one feasible action per batch row, epsilon-greedily."""
        # Greedy choice: highest Q-value among the feasible actions.
        masked_q = q_values.masked_fill(~mask, -float("inf"))  # (bs, n_actions)
        greedy_actions = masked_q.argmax(dim=-1)  # (bs,)
        if epsilon <= 0.0:
            # No exploration requested: skip the random draw entirely.
            return greedy_actions

        # Uniform draw over feasible actions: feasible entries weigh 1, the
        # rest 0, so infeasible actions can never be sampled.
        random_actions = torch.multinomial(mask.float(), num_samples=1).squeeze(-1)

        # Decide per batch row whether to explore or exploit.
        use_random = torch.rand(q_values.size(0), device=q_values.device) < epsilon
        return torch.where(use_random, random_actions, greedy_actions)

    def forward(
        self,
        td: TensorDict,
        env: Optional[str | RL4COEnvBase] = None,
        phase: str = "train",
        calc_reward: bool = True,
        return_actions: bool = True,
        max_steps=1_000_000,
        **decoding_kwargs,
    ) -> dict:
        """Roll out the policy until every instance in the batch is done.

        Args:
            td: TensorDict containing the environment state.
            env: Instantiated environment used for stepping and for the
                reward. Passing ``None`` or an environment name is not
                supported.
            phase: Phase of the algorithm (train, val, test). Random
                exploration with ``self.epsilon`` happens only for ``"train"``.
            calc_reward: Whether to overwrite ``td["reward"]`` with
                ``env.get_reward`` after the rollout.
            return_actions: Whether to include the chosen actions in the output.
            max_steps: Maximum number of decoding steps for sanity check to
                avoid infinite loops if envs are buggy (i.e. do not reach `done`).
            decoding_kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Dict with ``"reward"`` (``td["reward"]``), ``"predicted_q_values"``,
            ``"masks"`` and ``"done_padding"`` (per-step tensors concatenated
            along dim 1) and, if requested, ``"actions"`` with one column per
            decoding step.

        Raises:
            ValueError: If ``env`` is ``None`` or a string.
            RuntimeError: If the batch is already done, so no step is decoded.
        """
        if env is None or isinstance(env, str):
            raise ValueError(
                "L2DPolicyDQN.forward requires an instantiated environment; "
                f"got {env!r}"
            )

        # Encoder: only the hidden state is needed for decoding.
        hidden, _ = self.encoder(td)

        epsilon = self.epsilon if phase == "train" else 0.0

        # Main decoding: loop until all sequences are done
        step = 0
        predicted_q_values = []
        masks = []
        done_padding = []
        selected_actions = []

        while not td["done"].all():
            q_values, mask = self.decoder(td, *hidden)

            predicted_q_values.append(q_values)
            # Done flags *before* the step, so finished rows can be told apart.
            done_padding.append(td["done"])
            masks.append(mask)

            selected_action = self._epsilon_greedy(q_values, mask, epsilon)
            selected_actions.append(selected_action)

            td.set("action", selected_action)
            td = env.step(td)["next"]
            step += 1
            if step > max_steps:
                break

        if not selected_actions:
            raise RuntimeError(
                "No decoding step was performed: every instance was already done"
            )

        # One column per decoding step.
        actions = torch.stack(selected_actions, dim=1)

        # Output dictionary construction
        if calc_reward:
            td.set("reward", env.get_reward(td, actions))

        outdict = {
            "reward": td["reward"],
            "predicted_q_values": torch.concat(predicted_q_values, dim=1),
            "masks": torch.concat(masks, dim=1),
            "done_padding": torch.concat(done_padding, dim=1),
        }

        if return_actions:
            outdict["actions"] = actions

        return outdict

In [278]:
class L2DModelDQN(DQN):
    """DQN model restricted to the job-shop environments ``fjsp`` and ``jssp``."""

    def __init__(
        self,
        env: RL4COEnvBase,
        policy: Optional[L2DPolicyDQN] = None,
        policy_kwargs: Optional[dict] = None,
        **kwargs,
    ):
        """Validate the environment and build a default policy when none is given.

        Args:
            env: Environment whose ``name`` must be ``"fjsp"`` or ``"jssp"``.
            policy: Policy to train. When ``None``, an ``L2DPolicyDQN`` is
                created from ``policy_kwargs``.
            policy_kwargs: Keyword arguments for the default policy. They must
                supply ``encoder`` and ``decoder``, which ``L2DPolicyDQN``
                requires. ``None`` is treated as an empty dict.
            **kwargs: Forwarded to ``DQN.__init__``.
        """
        assert env.name in [
            "fjsp",
            "jssp",
        ], "L2DModel currently only works for Job-Shop Scheduling Problems"
        if policy is None:
            # A None default avoids sharing one mutable dict between calls.
            policy = L2DPolicyDQN(env_name=env.name, **(policy_kwargs or {}))

        super().__init__(env, policy, **kwargs)

In [279]:
from rl4co.envs import FJSPEnv

from rl4co.models.nn.graph.hgnn import HetGNNEncoder

from rl4co.models.zoo.l2d.decoder import FJSPActor

In [280]:
# Build the two network parts used by the policy below; both use an
# embedding size of 32.
encoder = HetGNNEncoder(embed_dim=32, num_layers=2)
decoder = FJSPActor(
                    embed_dim=32,
                    hidden_dim=64,
                    hidden_layers=2,
                )

In [281]:
# Settings handed to the FJSP instance generator.
generator_params = dict(
    num_jobs=5,  # the total number of jobs
    num_machines=5,  # the total number of machines that can process operations
    min_ops_per_job=1,  # minimum number of operations per job
    max_ops_per_job=3,  # maximum number of operations per job
    min_processing_time=1,  # the minimum time required for a machine to process an operation
    max_processing_time=20,  # the maximum time required for a machine to process an operation
    min_eligible_ma_per_op=1,  # the minimum number of machines capable to process an operation
    max_eligible_ma_per_op=2,  # the maximum number of machines capable to process an operation
)

In [282]:
# Create the FJSP environment from the generator settings and reset it with
# batch_size=[5] to obtain the initial state `td`.
env = FJSPEnv(generator_params=generator_params)
td = env.reset(batch_size=[5])

In [283]:
# Smoke-test the encoder: it returns a pair of embeddings plus a second value.
# Recorded shapes: op_emb (5, 15, 32) and ma_emb (5, 5, 32).
(op_emb, ma_emb), init = encoder(td)
op_emb.shape, ma_emb.shape

(torch.Size([5, 15, 32]), torch.Size([5, 5, 32]))

In [284]:
# Smoke-test the decoder on the initial state; the recorded output shows
# q_values and mask both with shape (5, 26).
q_values, mask = decoder(td, op_emb, ma_emb)
q_values.shape, mask.shape

(torch.Size([5, 26]), torch.Size([5, 26]))

In [285]:
# Display the decoder output for the first element of the batch.
q_values[0]

tensor([ 0.0225,  0.0602,  0.0643,  0.0592,  0.0512, -0.0026,  0.0361,  0.0428,
         0.0304,  0.0115, -0.0265,  0.0770,  0.0753,  0.0375,  0.0300,  0.0170,
         0.0697,  0.0621,  0.0387,  0.0256, -0.0186,  0.0972,  0.1043,  0.0989,
         0.0795, -0.0563], grad_fn=<SelectBackward0>)

In [286]:
# Wrap the encoder and decoder in the epsilon-greedy policy (default epsilon).
pol = L2DPolicyDQN(encoder=encoder, decoder=decoder, env_name="fjsp")

In [287]:
# Run one full rollout with the default arguments (phase="train") and display
# the returned dict.
pol(td, env)

{'reward': tensor([-41., -47., -44., -41., -37.]),
 'predicted_q_values': tensor([[ 0.0225,  0.0602,  0.0643,  ...,  0.0989,  0.0795, -0.0563],
         [ 0.0225,  0.0024, -0.0277,  ...,  0.0151,  0.0247, -0.0510],
         [ 0.0225,  0.0836,  0.1845,  ..., -0.0801, -0.0380, -0.1437],
         [ 0.0225, -0.0531,  0.0271,  ..., -0.0960, -0.0930, -0.0650],
         [ 0.0225,  0.1208,  0.1369,  ..., -0.0022, -0.1089,  0.0940]],
        grad_fn=<CatBackward0>),
 'masks': tensor([[False, False,  True,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False,  True, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]]),
 'done_padding': tensor([[False, False, False, False, False, False, False, False, False, False,
          False,  True],
         [False, False, False, False, False, False, False, False, False, False,
          False, False],
      