Знакомство с RL4CO

In [4]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, clear_output
import time
import networkx as nx
import matplotlib.pyplot as plt
from rl4co.envs import FJSPEnv
from rl4co.models.zoo.l2d import L2DModel
from rl4co.models.zoo.l2d.policy import L2DPolicy
from rl4co.models.zoo.l2d.decoder import L2DDecoder
from rl4co.models.nn.graph.hgnn import HetGNNEncoder
from rl4co.utils.trainer import RL4COTrainer

In [5]:
from src.datasets import FJSPBenchmarksDataset

In [6]:
dataset = FJSPBenchmarksDataset(file_pattern=f"../data/jsp/benchmarks/*/*.fjs")

In [7]:
test_inst, _ = dataset[1]

In [8]:
generator_params = {
  "num_jobs": 5,  # the total number of jobs
  "num_machines": 6,  # the total number of machines that can process operations
  "min_ops_per_job": 1,  # minimum number of operatios per job
  "max_ops_per_job": 2,  # maximum number of operations per job
  "min_processing_time": 1,  # the minimum time required for a machine to process an operation
  "max_processing_time": 20,  # the maximum time required for a machine to process an operation
  "min_eligible_ma_per_op": 1,  # the minimum number of machines capable to process an operation
  "max_eligible_ma_per_op": 2,  # the maximum number of machines capable to process an operation
}

In [12]:
env = FJSPEnv(generator_params=generator_params)
td = env.reset(batch_size=[1])

In [10]:
env = FJSPEnv()
td = env.reset(test_inst.td_init)

In [13]:
td

TensorDict(
    fields={
        action_mask: Tensor(shape=torch.Size([1, 31]), device=cpu, dtype=torch.bool, is_shared=False),
        busy_until: Tensor(shape=torch.Size([1, 6]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        end_op_per_job: Tensor(shape=torch.Size([1, 5]), device=cpu, dtype=torch.int64, is_shared=False),
        finish_times: Tensor(shape=torch.Size([1, 10]), device=cpu, dtype=torch.float32, is_shared=False),
        is_ready: Tensor(shape=torch.Size([1, 10]), device=cpu, dtype=torch.bool, is_shared=False),
        job_done: Tensor(shape=torch.Size([1, 5]), device=cpu, dtype=torch.bool, is_shared=False),
        job_in_process: Tensor(shape=torch.Size([1, 5]), device=cpu, dtype=torch.bool, is_shared=False),
        job_ops_adj: Tensor(shape=torch.Size([1, 5, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        lbs: Tensor(shape=torch.Size([1, 10]), de

In [13]:
encoder = HetGNNEncoder(embed_dim=32, num_layers=2)
(op_emb, ma_emb), init = encoder(td)
print(ma_emb.shape)
print(op_emb.shape)

torch.Size([1, 5, 32])
torch.Size([1, 10, 32])


In [39]:
decoder = L2DDecoder(env_name=env.name, embed_dim=32)
logits, mask = decoder(td, (op_emb, ma_emb), num_starts=0)

# (1 + num_jobs * num_machines)
print(logits.shape)

torch.Size([1, 2001])


In [40]:
def make_step(td):
    logits, mask = decoder(td, (op_emb, ma_emb), num_starts=0)
    action = logits.masked_fill(~mask, -torch.inf).argmax(1)
    td["action"] = action
    td = env.step(td)["next"]
    return td

In [None]:
env.render(td, 0)
# Update plot within a for loop
while not td["done"].all():
    # Clear the previous output for the next iteration
    # clear_output(wait=True)

    td = make_step(td)
    # env.render(td, 0)
    # Display updated plot
    # display(plt.gcf())
    
    # Pause for a moment to see the changes
    # time.sleep(.4)

In [44]:
td["time"]

tensor([624.])

In [49]:
if torch.cuda.is_available():
    accelerator = "gpu"
    batch_size = 256
    train_data_size = 2_000
    embed_dim = 128
    num_encoder_layers = 4
else:
    accelerator = "cpu"
    batch_size = 32
    train_data_size = 1_000
    embed_dim = 64
    num_encoder_layers = 2

In [50]:
# Policy: neural network, in this case with encoder-decoder architecture
policy = L2DPolicy(embed_dim=embed_dim, num_encoder_layers=num_encoder_layers, env_name="fjsp")

# Model: default is AM with REINFORCE and greedy rollout baseline
model = L2DModel(env,
                 policy=policy, 
                 baseline="rollout",
                 batch_size=batch_size,
                 train_data_size=train_data_size,
                 val_data_size=1_000,
                 optimizer_kwargs={"lr": 1e-4})

trainer = RL4COTrainer(
    max_epochs=30,
    accelerator=accelerator,
    devices=1,
    logger=None,
)

/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [51]:
trainer.fit(model)

val_file not set. Generating dataset instead
test_file not set. Generating dataset instead
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type           | Params | Mode 
----------------------------------------------------
0 | env      | FJSPEnv        | 0      | train
1 | policy   | L2DPolicy      | 585 K  | train
2 | baseline | WarmupBaseline | 585 K  | train
----------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.682     Total estimated model params size (MB)
124       Modules in train mode
120       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (8) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=30` reached.


In [None]:
td = env.reset(test_inst.td_init)

In [65]:
(op_emb, ma_emb), init = model.policy.encoder(td)
print(ma_emb.shape)
print(op_emb.shape)

torch.Size([1, 20, 128])
torch.Size([1, 500, 128])


In [66]:
def make_step_model(td):
    logits, mask = model.policy.decoder(td, (op_emb, ma_emb), num_starts=0)
    action = logits.masked_fill(~mask, -torch.inf).argmax(1)
    td["action"] = action
    td = env.step(td)["next"]
    return td

In [67]:
# env.render(td, 0)
# Update plot within a for loop
while not td["done"].all():
    # Clear the previous output for the next iteration
    # clear_output(wait=True)

    td = make_step_model(td)
    # env.render(td, 0)
    # # Display updated plot
    # display(plt.gcf())
    
    # # Pause for a moment to see the changes
    # time.sleep(.4)

In [68]:
td["time"]

tensor([619.])

In [70]:
policy2 = L2DPolicy(embed_dim=embed_dim, num_encoder_layers=num_encoder_layers, env_name="fjsp")


model2 = L2DModel(env,
                 policy=policy2, 
                 baseline="rollout",
                 batch_size=batch_size,
                 train_data_size=train_data_size,
                 val_data_size=1_000,
                 optimizer_kwargs={"lr": 1e-4})

/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.


In [None]:
model2.load_from_checkpoint("./checkpoints/l2d_checkpoints.ckpt")

/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.
/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/core/saving.py:195: Found keys that are not in the model state dict but in the checkpoint: ['baseline.baseline.policy.encoder.init_embedding.init_ops_embed.weight', 'baseline.baseline.policy.encoder.init_embedding.pos_encoder.pe', 'baseline.baseline.policy.encoder.init_embedding.init_ma_embed.weight', 'baseline.base

L2DModel(
  (env): FJSPEnv()
  (policy): L2DPolicy(
    (encoder): HetGNNEncoder(
      (init_embedding): FJSPInitEmbedding(
        (init_ops_embed): Linear(in_features=5, out_features=128, bias=False)
        (pos_encoder): PositionalEncoding(
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (init_ma_embed): Linear(in_features=1, out_features=128, bias=False)
        (edge_embed): Linear(in_features=1, out_features=128, bias=False)
      )
      (layers): ModuleList(
        (0-3): 4 x HetGNNBlock(
          (hgnn1): HetGNNLayer(
            (activation): ReLU()
          )
          (hgnn2): HetGNNLayer(
            (activation): ReLU()
          )
          (ffn1): TransformerFFN(
            (ops): ModuleDict(
              (norm1): Normalization(
                (normalizer): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
              (ffn): Sequential(
                (0): Linear(in_features=128, out_features

In [72]:
model3 = L2DModel.load_from_checkpoint("./checkpoints/l2d_checkpoints.ckpt")

/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['policy'])`.
/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/core/saving.py:195: Found keys that are not in the model state dict but in the checkpoint: ['baseline.baseline.policy.encoder.init_embedding.init_ops_embed.weight', 'baseline.baseline.policy.encoder.init_embedding.pos_encoder.pe', 'baseline.baseline.policy.encoder.init_embedding.init_ma_embed.weight', 'baseline.base

In [81]:
model3 = model3.to("cpu")

In [82]:
def make_step_model3(td):
    logits, mask = model3.policy.decoder(td, (op_emb, ma_emb), num_starts=0)
    action = logits.masked_fill(~mask, -torch.inf).argmax(1)
    td["action"] = action
    td = env.step(td)["next"]
    return td

In [83]:
# env.render(td, 0)
# Update plot within a for loop
while not td["done"].all():
    # Clear the previous output for the next iteration
    # clear_output(wait=True)

    td = make_step_model3(td)
    # env.render(td, 0)
    # # Display updated plot
    # display(plt.gcf())
    
    # # Pause for a moment to see the changes
    # time.sleep(.4)

In [84]:
td["time"]

tensor([619.])

In [86]:
from src.solvers.jsp.fjsp.rl4co.l2d import FJSPL2DSolver

In [88]:
solver = FJSPL2DSolver(train_data_size=train_data_size, batch_size=batch_size, max_epochs=15, accelerator=accelerator, lr=0.0001, val_data_size=1_000, 
                       embed_dim=embed_dim, num_encoder_layers=num_encoder_layers)

In [89]:
solver.fit(generator_params)

Unused keyword arguments: num_jobs, num_machines, min_ops_per_job, max_ops_per_job, min_processing_time, max_processing_time, min_eligible_ma_per_op, max_eligible_ma_per_op. Please check the base class documentation at https://rl4co.readthedocs.io/en/latest/_content/api/envs/base.html. In case you would like to pass data generation arguments, please pass a `generator` method instead or for example: `generator_kwargs=dict(num_loc=50)` to the constructor.
/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'env' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['env'])`.
/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'policy' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/home/daniil/programming/diplom/mtrlgnn/.venv/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (8) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=15` reached.


In [90]:
solver.save_model()

'/home/daniil/programming/diplom/mtrlgnn/notebooks/lightning_logs/version_6/checkpoints/epoch=14-step=120.ckpt'

In [91]:
solver2 = FJSPL2DSolver.load_from_checkpoint(solver.save_model())

RuntimeError: Error(s) in loading state_dict for FJSPL2DSolver:
	Unexpected key(s) in state_dict: "baseline.baseline.policy.encoder.init_embedding.init_ops_embed.weight", "baseline.baseline.policy.encoder.init_embedding.pos_encoder.pe", "baseline.baseline.policy.encoder.init_embedding.init_ma_embed.weight", "baseline.baseline.policy.encoder.init_embedding.edge_embed.weight", "baseline.baseline.policy.encoder.layers.0.hgnn1.self_attn", "baseline.baseline.policy.encoder.layers.0.hgnn1.cross_attn", "baseline.baseline.policy.encoder.layers.0.hgnn1.edge_attn", "baseline.baseline.policy.encoder.layers.0.hgnn2.self_attn", "baseline.baseline.policy.encoder.layers.0.hgnn2.cross_attn", "baseline.baseline.policy.encoder.layers.0.hgnn2.edge_attn", "baseline.baseline.policy.encoder.layers.0.ffn1.ops.norm1.normalizer.weight", "baseline.baseline.policy.encoder.layers.0.ffn1.ops.norm1.normalizer.bias", "baseline.baseline.policy.encoder.layers.0.ffn1.ops.norm1.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.0.ffn1.ops.norm1.normalizer.running_var", "baseline.baseline.policy.encoder.layers.0.ffn1.ops.norm1.normalizer.num_batches_tracked", "baseline.baseline.policy.encoder.layers.0.ffn1.ops.ffn.0.weight", "baseline.baseline.policy.encoder.layers.0.ffn1.ops.ffn.0.bias", "baseline.baseline.policy.encoder.layers.0.ffn1.ops.ffn.2.weight", "baseline.baseline.policy.encoder.layers.0.ffn1.ops.ffn.2.bias", "baseline.baseline.policy.encoder.layers.0.ffn1.ops.norm2.normalizer.weight", "baseline.baseline.policy.encoder.layers.0.ffn1.ops.norm2.normalizer.bias", "baseline.baseline.policy.encoder.layers.0.ffn1.ops.norm2.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.0.ffn1.ops.norm2.normalizer.running_var", "baseline.baseline.policy.encoder.layers.0.ffn1.ops.norm2.normalizer.num_batches_tracked", "baseline.baseline.policy.encoder.layers.0.ffn2.ops.norm1.normalizer.weight", "baseline.baseline.policy.encoder.layers.0.ffn2.ops.norm1.normalizer.bias", "baseline.baseline.policy.encoder.layers.0.ffn2.ops.norm1.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.0.ffn2.ops.norm1.normalizer.running_var", "baseline.baseline.policy.encoder.layers.0.ffn2.ops.norm1.normalizer.num_batches_tracked", "baseline.baseline.policy.encoder.layers.0.ffn2.ops.ffn.0.weight", "baseline.baseline.policy.encoder.layers.0.ffn2.ops.ffn.0.bias", "baseline.baseline.policy.encoder.layers.0.ffn2.ops.ffn.2.weight", "baseline.baseline.policy.encoder.layers.0.ffn2.ops.ffn.2.bias", "baseline.baseline.policy.encoder.layers.0.ffn2.ops.norm2.normalizer.weight", "baseline.baseline.policy.encoder.layers.0.ffn2.ops.norm2.normalizer.bias", "baseline.baseline.policy.encoder.layers.0.ffn2.ops.norm2.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.0.ffn2.ops.norm2.normalizer.running_var", "baseline.baseline.policy.encoder.layers.0.ffn2.ops.norm2.normalizer.num_batches_tracked", "baseline.baseline.policy.encoder.layers.1.hgnn1.self_attn", "baseline.baseline.policy.encoder.layers.1.hgnn1.cross_attn", "baseline.baseline.policy.encoder.layers.1.hgnn1.edge_attn", "baseline.baseline.policy.encoder.layers.1.hgnn2.self_attn", "baseline.baseline.policy.encoder.layers.1.hgnn2.cross_attn", "baseline.baseline.policy.encoder.layers.1.hgnn2.edge_attn", "baseline.baseline.policy.encoder.layers.1.ffn1.ops.norm1.normalizer.weight", "baseline.baseline.policy.encoder.layers.1.ffn1.ops.norm1.normalizer.bias", "baseline.baseline.policy.encoder.layers.1.ffn1.ops.norm1.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.1.ffn1.ops.norm1.normalizer.running_var", "baseline.baseline.policy.encoder.layers.1.ffn1.ops.norm1.normalizer.num_batches_tracked", "baseline.baseline.policy.encoder.layers.1.ffn1.ops.ffn.0.weight", "baseline.baseline.policy.encoder.layers.1.ffn1.ops.ffn.0.bias", "baseline.baseline.policy.encoder.layers.1.ffn1.ops.ffn.2.weight", "baseline.baseline.policy.encoder.layers.1.ffn1.ops.ffn.2.bias", "baseline.baseline.policy.encoder.layers.1.ffn1.ops.norm2.normalizer.weight", "baseline.baseline.policy.encoder.layers.1.ffn1.ops.norm2.normalizer.bias", "baseline.baseline.policy.encoder.layers.1.ffn1.ops.norm2.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.1.ffn1.ops.norm2.normalizer.running_var", "baseline.baseline.policy.encoder.layers.1.ffn1.ops.norm2.normalizer.num_batches_tracked", "baseline.baseline.policy.encoder.layers.1.ffn2.ops.norm1.normalizer.weight", "baseline.baseline.policy.encoder.layers.1.ffn2.ops.norm1.normalizer.bias", "baseline.baseline.policy.encoder.layers.1.ffn2.ops.norm1.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.1.ffn2.ops.norm1.normalizer.running_var", "baseline.baseline.policy.encoder.layers.1.ffn2.ops.norm1.normalizer.num_batches_tracked", "baseline.baseline.policy.encoder.layers.1.ffn2.ops.ffn.0.weight", "baseline.baseline.policy.encoder.layers.1.ffn2.ops.ffn.0.bias", "baseline.baseline.policy.encoder.layers.1.ffn2.ops.ffn.2.weight", "baseline.baseline.policy.encoder.layers.1.ffn2.ops.ffn.2.bias", "baseline.baseline.policy.encoder.layers.1.ffn2.ops.norm2.normalizer.weight", "baseline.baseline.policy.encoder.layers.1.ffn2.ops.norm2.normalizer.bias", "baseline.baseline.policy.encoder.layers.1.ffn2.ops.norm2.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.1.ffn2.ops.norm2.normalizer.running_var", "baseline.baseline.policy.encoder.layers.1.ffn2.ops.norm2.normalizer.num_batches_tracked", "baseline.baseline.policy.encoder.layers.2.hgnn1.self_attn", "baseline.baseline.policy.encoder.layers.2.hgnn1.cross_attn", "baseline.baseline.policy.encoder.layers.2.hgnn1.edge_attn", "baseline.baseline.policy.encoder.layers.2.hgnn2.self_attn", "baseline.baseline.policy.encoder.layers.2.hgnn2.cross_attn", "baseline.baseline.policy.encoder.layers.2.hgnn2.edge_attn", "baseline.baseline.policy.encoder.layers.2.ffn1.ops.norm1.normalizer.weight", "baseline.baseline.policy.encoder.layers.2.ffn1.ops.norm1.normalizer.bias", "baseline.baseline.policy.encoder.layers.2.ffn1.ops.norm1.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.2.ffn1.ops.norm1.normalizer.running_var", "baseline.baseline.policy.encoder.layers.2.ffn1.ops.norm1.normalizer.num_batches_tracked", "baseline.baseline.policy.encoder.layers.2.ffn1.ops.ffn.0.weight", "baseline.baseline.policy.encoder.layers.2.ffn1.ops.ffn.0.bias", "baseline.baseline.policy.encoder.layers.2.ffn1.ops.ffn.2.weight", "baseline.baseline.policy.encoder.layers.2.ffn1.ops.ffn.2.bias", "baseline.baseline.policy.encoder.layers.2.ffn1.ops.norm2.normalizer.weight", "baseline.baseline.policy.encoder.layers.2.ffn1.ops.norm2.normalizer.bias", "baseline.baseline.policy.encoder.layers.2.ffn1.ops.norm2.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.2.ffn1.ops.norm2.normalizer.running_var", "baseline.baseline.policy.encoder.layers.2.ffn1.ops.norm2.normalizer.num_batches_tracked", "baseline.baseline.policy.encoder.layers.2.ffn2.ops.norm1.normalizer.weight", "baseline.baseline.policy.encoder.layers.2.ffn2.ops.norm1.normalizer.bias", "baseline.baseline.policy.encoder.layers.2.ffn2.ops.norm1.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.2.ffn2.ops.norm1.normalizer.running_var", "baseline.baseline.policy.encoder.layers.2.ffn2.ops.norm1.normalizer.num_batches_tracked", "baseline.baseline.policy.encoder.layers.2.ffn2.ops.ffn.0.weight", "baseline.baseline.policy.encoder.layers.2.ffn2.ops.ffn.0.bias", "baseline.baseline.policy.encoder.layers.2.ffn2.ops.ffn.2.weight", "baseline.baseline.policy.encoder.layers.2.ffn2.ops.ffn.2.bias", "baseline.baseline.policy.encoder.layers.2.ffn2.ops.norm2.normalizer.weight", "baseline.baseline.policy.encoder.layers.2.ffn2.ops.norm2.normalizer.bias", "baseline.baseline.policy.encoder.layers.2.ffn2.ops.norm2.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.2.ffn2.ops.norm2.normalizer.running_var", "baseline.baseline.policy.encoder.layers.2.ffn2.ops.norm2.normalizer.num_batches_tracked", "baseline.baseline.policy.encoder.layers.3.hgnn1.self_attn", "baseline.baseline.policy.encoder.layers.3.hgnn1.cross_attn", "baseline.baseline.policy.encoder.layers.3.hgnn1.edge_attn", "baseline.baseline.policy.encoder.layers.3.hgnn2.self_attn", "baseline.baseline.policy.encoder.layers.3.hgnn2.cross_attn", "baseline.baseline.policy.encoder.layers.3.hgnn2.edge_attn", "baseline.baseline.policy.encoder.layers.3.ffn1.ops.norm1.normalizer.weight", "baseline.baseline.policy.encoder.layers.3.ffn1.ops.norm1.normalizer.bias", "baseline.baseline.policy.encoder.layers.3.ffn1.ops.norm1.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.3.ffn1.ops.norm1.normalizer.running_var", "baseline.baseline.policy.encoder.layers.3.ffn1.ops.norm1.normalizer.num_batches_tracked", "baseline.baseline.policy.encoder.layers.3.ffn1.ops.ffn.0.weight", "baseline.baseline.policy.encoder.layers.3.ffn1.ops.ffn.0.bias", "baseline.baseline.policy.encoder.layers.3.ffn1.ops.ffn.2.weight", "baseline.baseline.policy.encoder.layers.3.ffn1.ops.ffn.2.bias", "baseline.baseline.policy.encoder.layers.3.ffn1.ops.norm2.normalizer.weight", "baseline.baseline.policy.encoder.layers.3.ffn1.ops.norm2.normalizer.bias", "baseline.baseline.policy.encoder.layers.3.ffn1.ops.norm2.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.3.ffn1.ops.norm2.normalizer.running_var", "baseline.baseline.policy.encoder.layers.3.ffn1.ops.norm2.normalizer.num_batches_tracked", "baseline.baseline.policy.encoder.layers.3.ffn2.ops.norm1.normalizer.weight", "baseline.baseline.policy.encoder.layers.3.ffn2.ops.norm1.normalizer.bias", "baseline.baseline.policy.encoder.layers.3.ffn2.ops.norm1.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.3.ffn2.ops.norm1.normalizer.running_var", "baseline.baseline.policy.encoder.layers.3.ffn2.ops.norm1.normalizer.num_batches_tracked", "baseline.baseline.policy.encoder.layers.3.ffn2.ops.ffn.0.weight", "baseline.baseline.policy.encoder.layers.3.ffn2.ops.ffn.0.bias", "baseline.baseline.policy.encoder.layers.3.ffn2.ops.ffn.2.weight", "baseline.baseline.policy.encoder.layers.3.ffn2.ops.ffn.2.bias", "baseline.baseline.policy.encoder.layers.3.ffn2.ops.norm2.normalizer.weight", "baseline.baseline.policy.encoder.layers.3.ffn2.ops.norm2.normalizer.bias", "baseline.baseline.policy.encoder.layers.3.ffn2.ops.norm2.normalizer.running_mean", "baseline.baseline.policy.encoder.layers.3.ffn2.ops.norm2.normalizer.running_var", "baseline.baseline.policy.encoder.layers.3.ffn2.ops.norm2.normalizer.num_batches_tracked", "baseline.baseline.policy.decoder.actor.dummy", "baseline.baseline.policy.decoder.actor.mlp.lins.0.weight", "baseline.baseline.policy.decoder.actor.mlp.lins.0.bias", "baseline.baseline.policy.decoder.actor.mlp.lins.1.weight", "baseline.baseline.policy.decoder.actor.mlp.lins.1.bias", "baseline.baseline.policy.decoder.actor.mlp.lins.2.weight", "baseline.baseline.policy.decoder.actor.mlp.lins.2.bias", "policy.encoder.layers.1.hgnn1.self_attn", "policy.encoder.layers.1.hgnn1.cross_attn", "policy.encoder.layers.1.hgnn1.edge_attn", "policy.encoder.layers.1.hgnn2.self_attn", "policy.encoder.layers.1.hgnn2.cross_attn", "policy.encoder.layers.1.hgnn2.edge_attn", "policy.encoder.layers.1.ffn1.ops.norm1.normalizer.weight", "policy.encoder.layers.1.ffn1.ops.norm1.normalizer.bias", "policy.encoder.layers.1.ffn1.ops.norm1.normalizer.running_mean", "policy.encoder.layers.1.ffn1.ops.norm1.normalizer.running_var", "policy.encoder.layers.1.ffn1.ops.norm1.normalizer.num_batches_tracked", "policy.encoder.layers.1.ffn1.ops.ffn.0.weight", "policy.encoder.layers.1.ffn1.ops.ffn.0.bias", "policy.encoder.layers.1.ffn1.ops.ffn.2.weight", "policy.encoder.layers.1.ffn1.ops.ffn.2.bias", "policy.encoder.layers.1.ffn1.ops.norm2.normalizer.weight", "policy.encoder.layers.1.ffn1.ops.norm2.normalizer.bias", "policy.encoder.layers.1.ffn1.ops.norm2.normalizer.running_mean", "policy.encoder.layers.1.ffn1.ops.norm2.normalizer.running_var", "policy.encoder.layers.1.ffn1.ops.norm2.normalizer.num_batches_tracked", "policy.encoder.layers.1.ffn2.ops.norm1.normalizer.weight", "policy.encoder.layers.1.ffn2.ops.norm1.normalizer.bias", "policy.encoder.layers.1.ffn2.ops.norm1.normalizer.running_mean", "policy.encoder.layers.1.ffn2.ops.norm1.normalizer.running_var", "policy.encoder.layers.1.ffn2.ops.norm1.normalizer.num_batches_tracked", "policy.encoder.layers.1.ffn2.ops.ffn.0.weight", "policy.encoder.layers.1.ffn2.ops.ffn.0.bias", "policy.encoder.layers.1.ffn2.ops.ffn.2.weight", "policy.encoder.layers.1.ffn2.ops.ffn.2.bias", "policy.encoder.layers.1.ffn2.ops.norm2.normalizer.weight", "policy.encoder.layers.1.ffn2.ops.norm2.normalizer.bias", "policy.encoder.layers.1.ffn2.ops.norm2.normalizer.running_mean", "policy.encoder.layers.1.ffn2.ops.norm2.normalizer.running_var", "policy.encoder.layers.1.ffn2.ops.norm2.normalizer.num_batches_tracked", "policy.encoder.layers.2.hgnn1.self_attn", "policy.encoder.layers.2.hgnn1.cross_attn", "policy.encoder.layers.2.hgnn1.edge_attn", "policy.encoder.layers.2.hgnn2.self_attn", "policy.encoder.layers.2.hgnn2.cross_attn", "policy.encoder.layers.2.hgnn2.edge_attn", "policy.encoder.layers.2.ffn1.ops.norm1.normalizer.weight", "policy.encoder.layers.2.ffn1.ops.norm1.normalizer.bias", "policy.encoder.layers.2.ffn1.ops.norm1.normalizer.running_mean", "policy.encoder.layers.2.ffn1.ops.norm1.normalizer.running_var", "policy.encoder.layers.2.ffn1.ops.norm1.normalizer.num_batches_tracked", "policy.encoder.layers.2.ffn1.ops.ffn.0.weight", "policy.encoder.layers.2.ffn1.ops.ffn.0.bias", "policy.encoder.layers.2.ffn1.ops.ffn.2.weight", "policy.encoder.layers.2.ffn1.ops.ffn.2.bias", "policy.encoder.layers.2.ffn1.ops.norm2.normalizer.weight", "policy.encoder.layers.2.ffn1.ops.norm2.normalizer.bias", "policy.encoder.layers.2.ffn1.ops.norm2.normalizer.running_mean", "policy.encoder.layers.2.ffn1.ops.norm2.normalizer.running_var", "policy.encoder.layers.2.ffn1.ops.norm2.normalizer.num_batches_tracked", "policy.encoder.layers.2.ffn2.ops.norm1.normalizer.weight", "policy.encoder.layers.2.ffn2.ops.norm1.normalizer.bias", "policy.encoder.layers.2.ffn2.ops.norm1.normalizer.running_mean", "policy.encoder.layers.2.ffn2.ops.norm1.normalizer.running_var", "policy.encoder.layers.2.ffn2.ops.norm1.normalizer.num_batches_tracked", "policy.encoder.layers.2.ffn2.ops.ffn.0.weight", "policy.encoder.layers.2.ffn2.ops.ffn.0.bias", "policy.encoder.layers.2.ffn2.ops.ffn.2.weight", "policy.encoder.layers.2.ffn2.ops.ffn.2.bias", "policy.encoder.layers.2.ffn2.ops.norm2.normalizer.weight", "policy.encoder.layers.2.ffn2.ops.norm2.normalizer.bias", "policy.encoder.layers.2.ffn2.ops.norm2.normalizer.running_mean", "policy.encoder.layers.2.ffn2.ops.norm2.normalizer.running_var", "policy.encoder.layers.2.ffn2.ops.norm2.normalizer.num_batches_tracked", "policy.encoder.layers.3.hgnn1.self_attn", "policy.encoder.layers.3.hgnn1.cross_attn", "policy.encoder.layers.3.hgnn1.edge_attn", "policy.encoder.layers.3.hgnn2.self_attn", "policy.encoder.layers.3.hgnn2.cross_attn", "policy.encoder.layers.3.hgnn2.edge_attn", "policy.encoder.layers.3.ffn1.ops.norm1.normalizer.weight", "policy.encoder.layers.3.ffn1.ops.norm1.normalizer.bias", "policy.encoder.layers.3.ffn1.ops.norm1.normalizer.running_mean", "policy.encoder.layers.3.ffn1.ops.norm1.normalizer.running_var", "policy.encoder.layers.3.ffn1.ops.norm1.normalizer.num_batches_tracked", "policy.encoder.layers.3.ffn1.ops.ffn.0.weight", "policy.encoder.layers.3.ffn1.ops.ffn.0.bias", "policy.encoder.layers.3.ffn1.ops.ffn.2.weight", "policy.encoder.layers.3.ffn1.ops.ffn.2.bias", "policy.encoder.layers.3.ffn1.ops.norm2.normalizer.weight", "policy.encoder.layers.3.ffn1.ops.norm2.normalizer.bias", "policy.encoder.layers.3.ffn1.ops.norm2.normalizer.running_mean", "policy.encoder.layers.3.ffn1.ops.norm2.normalizer.running_var", "policy.encoder.layers.3.ffn1.ops.norm2.normalizer.num_batches_tracked", "policy.encoder.layers.3.ffn2.ops.norm1.normalizer.weight", "policy.encoder.layers.3.ffn2.ops.norm1.normalizer.bias", "policy.encoder.layers.3.ffn2.ops.norm1.normalizer.running_mean", "policy.encoder.layers.3.ffn2.ops.norm1.normalizer.running_var", "policy.encoder.layers.3.ffn2.ops.norm1.normalizer.num_batches_tracked", "policy.encoder.layers.3.ffn2.ops.ffn.0.weight", "policy.encoder.layers.3.ffn2.ops.ffn.0.bias", "policy.encoder.layers.3.ffn2.ops.ffn.2.weight", "policy.encoder.layers.3.ffn2.ops.ffn.2.bias", "policy.encoder.layers.3.ffn2.ops.norm2.normalizer.weight", "policy.encoder.layers.3.ffn2.ops.norm2.normalizer.bias", "policy.encoder.layers.3.ffn2.ops.norm2.normalizer.running_mean", "policy.encoder.layers.3.ffn2.ops.norm2.normalizer.running_var", "policy.encoder.layers.3.ffn2.ops.norm2.normalizer.num_batches_tracked". 
	size mismatch for policy.encoder.init_embedding.init_ops_embed.weight: copying a param with shape torch.Size([128, 5]) from checkpoint, the shape in current model is torch.Size([32, 5]).
	size mismatch for policy.encoder.init_embedding.pos_encoder.pe: copying a param with shape torch.Size([1, 1000, 128]) from checkpoint, the shape in current model is torch.Size([1, 1000, 32]).
	size mismatch for policy.encoder.init_embedding.init_ma_embed.weight: copying a param with shape torch.Size([128, 1]) from checkpoint, the shape in current model is torch.Size([32, 1]).
	size mismatch for policy.encoder.init_embedding.edge_embed.weight: copying a param with shape torch.Size([128, 1]) from checkpoint, the shape in current model is torch.Size([32, 1]).
	size mismatch for policy.encoder.layers.0.hgnn1.self_attn: copying a param with shape torch.Size([128, 1]) from checkpoint, the shape in current model is torch.Size([32, 1]).
	size mismatch for policy.encoder.layers.0.hgnn1.cross_attn: copying a param with shape torch.Size([128, 1]) from checkpoint, the shape in current model is torch.Size([32, 1]).
	size mismatch for policy.encoder.layers.0.hgnn1.edge_attn: copying a param with shape torch.Size([128, 1]) from checkpoint, the shape in current model is torch.Size([32, 1]).
	size mismatch for policy.encoder.layers.0.hgnn2.self_attn: copying a param with shape torch.Size([128, 1]) from checkpoint, the shape in current model is torch.Size([32, 1]).
	size mismatch for policy.encoder.layers.0.hgnn2.cross_attn: copying a param with shape torch.Size([128, 1]) from checkpoint, the shape in current model is torch.Size([32, 1]).
	size mismatch for policy.encoder.layers.0.hgnn2.edge_attn: copying a param with shape torch.Size([128, 1]) from checkpoint, the shape in current model is torch.Size([32, 1]).
	size mismatch for policy.encoder.layers.0.ffn1.ops.norm1.normalizer.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn1.ops.norm1.normalizer.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn1.ops.norm1.normalizer.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn1.ops.norm1.normalizer.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn1.ops.ffn.0.weight: copying a param with shape torch.Size([256, 128]) from checkpoint, the shape in current model is torch.Size([64, 32]).
	size mismatch for policy.encoder.layers.0.ffn1.ops.ffn.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for policy.encoder.layers.0.ffn1.ops.ffn.2.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([32, 64]).
	size mismatch for policy.encoder.layers.0.ffn1.ops.ffn.2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn1.ops.norm2.normalizer.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn1.ops.norm2.normalizer.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn1.ops.norm2.normalizer.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn1.ops.norm2.normalizer.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn2.ops.norm1.normalizer.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn2.ops.norm1.normalizer.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn2.ops.norm1.normalizer.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn2.ops.norm1.normalizer.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn2.ops.ffn.0.weight: copying a param with shape torch.Size([256, 128]) from checkpoint, the shape in current model is torch.Size([64, 32]).
	size mismatch for policy.encoder.layers.0.ffn2.ops.ffn.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for policy.encoder.layers.0.ffn2.ops.ffn.2.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([32, 64]).
	size mismatch for policy.encoder.layers.0.ffn2.ops.ffn.2.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn2.ops.norm2.normalizer.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn2.ops.norm2.normalizer.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn2.ops.norm2.normalizer.running_mean: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.encoder.layers.0.ffn2.ops.norm2.normalizer.running_var: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.decoder.actor.dummy: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for policy.decoder.actor.mlp.lins.0.weight: copying a param with shape torch.Size([128, 256]) from checkpoint, the shape in current model is torch.Size([32, 64]).
	size mismatch for policy.decoder.actor.mlp.lins.0.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.decoder.actor.mlp.lins.1.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([32, 32]).
	size mismatch for policy.decoder.actor.mlp.lins.1.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([32]).
	size mismatch for policy.decoder.actor.mlp.lins.2.weight: copying a param with shape torch.Size([1, 128]) from checkpoint, the shape in current model is torch.Size([1, 32]).

In [92]:
checkpoint = torch.load(solver.save_model(), map_location='cpu')
print(checkpoint.keys())  # Проверьте наличие 'hyper_parameters'
print(checkpoint['hyper_parameters'])  # Какие параметры там сохранены

UnpicklingError: Weights only load failed. This file can still be loaded, to do so you have two options, [1mdo those steps only if you trust the source of the checkpoint[0m. 
	(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.
	(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.
	WeightsUnpickler error: Unsupported global: GLOBAL rl4co.envs.scheduling.fjsp.env.FJSPEnv was not an allowed global by default. Please use `torch.serialization.add_safe_globals([FJSPEnv])` or the `torch.serialization.safe_globals([FJSPEnv])` context manager to allowlist this global if you trust this class/function.

Check the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.

In [94]:
torch.save(solver.policy, "l2d.pt")