Testing saving and loading a model

In [None]:
import torch
import torch.nn as nn  # noqa: F401
import safety_gymnasium

from rl_vcf.rl.algos.sac.core import MLPActorCritic
from rl_vcf.rl.algos.ppo.core import MLPActorCritic as PPO_MLPActorCritic
from rl_vcf.rl.utils import make_env_safety, get_actor_structure

In [2]:
# Notebook-wide configuration.
seed = 0
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# end-to-end on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
num_envs = 4

# Build `num_envs` synchronous sub-environments; sub-environment `i` receives
# its own index and the seed `seed + i`.
# NOTE(review): the meaning of the trailing positional arguments
# (False, 5, False) is defined by `make_env_safety` and is not visible here —
# consider passing them by keyword once the signature is confirmed.
envs = safety_gymnasium.vector.SafetySyncVectorEnv(
    [
        make_env_safety(
            "SafetyPointReachAvoidReset1-v0",
            i,
            seed + i,
            False,
            5,
            False,
        )
        for i in range(num_envs)
    ]
)

In [3]:
# Build a freshly initialised SAC actor-critic sized for a single
# (unbatched) environment, then move it onto the target device.
# (A config-driven run would presumably derive the activation from
# `cfg.network.activation` instead of hard-coding ReLU here.)
agent = MLPActorCritic(
    envs.single_observation_space,
    envs.single_action_space,
    hidden_sizes=[256, 256],
    activation=nn.ReLU(),
)
agent = agent.to(device)

# Show the full module tree and the activation name reported by the actor.
print(agent)
print(agent.pi.get_activation())

MLPActorCritic(
  (pi): MLPSquashedGaussianActor(
    (net): Sequential(
      (0): Linear(in_features=44, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): ReLU()
    )
    (mu_layer): Linear(in_features=256, out_features=2, bias=True)
    (log_std_layer): Linear(in_features=256, out_features=2, bias=True)
  )
  (q1): MLPCritic(
    (net): Sequential(
      (0): Linear(in_features=46, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=1, bias=True)
      (5): Identity()
    )
  )
  (q2): MLPCritic(
    (net): Sequential(
      (0): Linear(in_features=46, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=1, bias=True)
      (5): Identity()
    )
  )
)
ReLU


In [4]:
# Reference values: first-layer weights of the actor and of the first critic,
# printed one after the other so later cells can be compared against them.
print(agent.pi.net[0].weight, agent.q1.net[0].weight, sep="\n")

Parameter containing:
tensor([[ 0.1478, -0.0357, -0.0117,  ...,  0.0288, -0.0008,  0.0562],
        [ 0.0965,  0.0658,  0.0555,  ...,  0.0754,  0.0075,  0.0837],
        [ 0.0432,  0.0340,  0.0793,  ...,  0.0730, -0.1242, -0.0525],
        ...,
        [ 0.0005,  0.0656,  0.0367,  ...,  0.0148,  0.0011, -0.0078],
        [-0.0021,  0.1091, -0.0414,  ...,  0.1039,  0.1211,  0.1118],
        [-0.0050, -0.0661, -0.1216,  ..., -0.0382,  0.0684, -0.0654]],
       device='cuda:0', requires_grad=True)
Parameter containing:
tensor([[ 0.0620,  0.0172, -0.1164,  ..., -0.0475,  0.0781, -0.0094],
        [-0.0394,  0.0615,  0.1285,  ..., -0.0195,  0.0522,  0.0555],
        [-0.0680,  0.0966, -0.0717,  ..., -0.0365,  0.0752,  0.1152],
        ...,
        [-0.0188, -0.0068,  0.0145,  ..., -0.1233, -0.1202, -0.0935],
        [-0.0079, -0.1318, -0.0112,  ...,  0.1308, -0.0311,  0.1395],
        [-0.0119, -0.0886,  0.1142,  ...,  0.0630,  0.0803,  0.0563]],
       device='cuda:0', requires_grad=True)


In [5]:
# Write two checkpoints to the working directory: the actor on its own, and
# the complete actor-critic (actor plus both critics).
for module, checkpoint_path in ((agent.pi, "test_policy.pt"), (agent, "test_agent.pt")):
    torch.save(module.state_dict(), checkpoint_path)

In [6]:
# Read the actor-only checkpoint back (tensors only, mapped onto `device`).
loaded_state_dict = torch.load(
    "test_policy.pt",
    map_location=device,
    weights_only=True,
)

# Recover the network structure (hidden layer sizes and activation name) from
# the checkpoint together with the observation/action spaces.
loaded_hidden_sizes, loaded_activation = get_actor_structure(
    loaded_state_dict,
    envs.single_observation_space,
    envs.single_action_space,
)
print(loaded_hidden_sizes, loaded_activation, sep="\n")

[256, 256]
ReLU


In [7]:
# Read the full actor-critic checkpoint back and list the entries it holds.
loaded_agent_state_dict = torch.load(
    "test_agent.pt",
    map_location=device,
    weights_only=True,
)
print(loaded_agent_state_dict.keys())

odict_keys(['pi.act_scale', 'pi.act_bias', 'pi.activation', 'pi.net.0.weight', 'pi.net.0.bias', 'pi.net.2.weight', 'pi.net.2.bias', 'pi.mu_layer.weight', 'pi.mu_layer.bias', 'pi.log_std_layer.weight', 'pi.log_std_layer.bias', 'q1.net.0.weight', 'q1.net.0.bias', 'q1.net.2.weight', 'q1.net.2.bias', 'q1.net.4.weight', 'q1.net.4.bias', 'q2.net.0.weight', 'q2.net.0.bias', 'q2.net.2.weight', 'q2.net.2.bias', 'q2.net.4.weight', 'q2.net.4.bias'])


In [7]:
# Rebuild a SAC actor-critic using only the structure recovered from the
# checkpoint (hidden sizes and activation name).
# The activation name comes from a file on disk, so resolve it with an
# attribute lookup on `torch.nn` rather than `eval`: an unknown name raises
# AttributeError instead of executing arbitrary code.
new_agent = MLPActorCritic(
    envs.single_observation_space,
    envs.single_action_space,
    loaded_hidden_sizes,
    getattr(nn, loaded_activation)(),
)

# Show the rebuilt module tree and the activation name reported by the actor.
print(new_agent)
print(new_agent.pi.get_activation())

MLPActorCritic(
  (pi): MLPSquashedGaussianActor(
    (net): Sequential(
      (0): Linear(in_features=44, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): ReLU()
    )
    (mu_layer): Linear(in_features=256, out_features=2, bias=True)
    (log_std_layer): Linear(in_features=256, out_features=2, bias=True)
  )
  (q1): MLPCritic(
    (net): Sequential(
      (0): Linear(in_features=46, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=1, bias=True)
      (5): Identity()
    )
  )
  (q2): MLPCritic(
    (net): Sequential(
      (0): Linear(in_features=46, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=1, bias=True)
      (5): Identity()
    )
  )
)
ReLU


In [8]:
# Weights before loading: both networks still hold their random initialisation.
print(new_agent.pi.net[0].weight)
print(new_agent.q1.net[0].weight)
# Keep a CPU copy of the critic weights so we can check they are left alone.
q1_weight_before = new_agent.q1.net[0].weight.detach().cpu().clone()

# Load the actor checkpoint only; strict=True raises on any missing or
# unexpected key. Then move the whole agent onto the target device.
new_agent.pi.load_state_dict(loaded_state_dict, strict=True)
new_agent.to(device)
print(new_agent.pi.net[0].weight)
print(new_agent.q1.net[0].weight)

# Check the round trip programmatically instead of comparing printed tensors
# by eye: every actor entry must match the actor that was saved ...
restored_pi_state = new_agent.pi.state_dict()
for key, saved_tensor in agent.pi.state_dict().items():
    assert torch.equal(
        restored_pi_state[key].cpu(), saved_tensor.cpu()
    ), f"actor entry {key!r} differs after the save/load round trip"
# ... and the critic, which was not part of the checkpoint, must be unchanged.
assert torch.equal(
    new_agent.q1.net[0].weight.detach().cpu(), q1_weight_before
), "critic weights changed even though only the actor was loaded"

Parameter containing:
tensor([[ 0.0991,  0.0378,  0.1220,  ...,  0.0729, -0.0199, -0.0672],
        [-0.0644, -0.0175,  0.0420,  ...,  0.0534,  0.0914,  0.1075],
        [ 0.1151,  0.1133,  0.1427,  ..., -0.0643,  0.0560, -0.0704],
        ...,
        [ 0.1182,  0.0561, -0.0307,  ...,  0.1106, -0.0657,  0.0514],
        [-0.1204,  0.1128, -0.1491,  ..., -0.1187,  0.0498,  0.1347],
        [-0.0143, -0.0027,  0.0584,  ...,  0.1497,  0.1215,  0.0417]],
       requires_grad=True)
Parameter containing:
tensor([[-0.0281, -0.0707,  0.1460,  ..., -0.0548,  0.0095, -0.0572],
        [ 0.1247,  0.0280, -0.0600,  ..., -0.1309,  0.1410, -0.0218],
        [ 0.0449, -0.1462, -0.0489,  ..., -0.0056, -0.0667,  0.0621],
        ...,
        [-0.1035,  0.0369, -0.1432,  ...,  0.1383, -0.1434, -0.1373],
        [-0.0095,  0.0188, -0.0195,  ...,  0.1038, -0.0553, -0.0301],
        [-0.0209, -0.0409,  0.0363,  ..., -0.0510, -0.1251, -0.0905]],
       requires_grad=True)
Parameter containing:
tensor([[ 0.

In [9]:
# Build a PPO actor-critic with the structure recovered from the SAC actor
# checkpoint, so the saved policy can be loaded into it.
# The activation name comes from a file on disk, so resolve it with an
# attribute lookup on `torch.nn` rather than `eval`: an unknown name raises
# AttributeError instead of executing arbitrary code.
# NOTE(review): the final positional `True` presumably selects the
# squashed-Gaussian actor (the printed module shows MLPSquashedGaussianActor)
# — confirm against the PPO `MLPActorCritic` signature and pass it by keyword.
ppo_agent = PPO_MLPActorCritic(
    envs.single_observation_space,
    envs.single_action_space,
    loaded_hidden_sizes,
    getattr(nn, loaded_activation)(),
    True,
)

# Show the module tree and the activation name reported by the actor.
print(ppo_agent)
print(ppo_agent.pi.get_activation())

MLPActorCritic(
  (pi): MLPSquashedGaussianActor(
    (net): Sequential(
      (0): Linear(in_features=44, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): ReLU()
    )
    (mu_layer): Linear(in_features=256, out_features=2, bias=True)
    (log_std_layer): Linear(in_features=256, out_features=2, bias=True)
  )
  (v): MLPCritic(
    (net): Sequential(
      (0): Linear(in_features=44, out_features=256, bias=True)
      (1): ReLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): ReLU()
      (4): Linear(in_features=256, out_features=1, bias=True)
      (5): Identity()
    )
  )
)
ReLU


In [10]:
# Weights before loading: actor and value network still hold their random
# initialisation.
print(ppo_agent.pi.net[0].weight)
print(ppo_agent.v.net[0].weight)
# Keep a CPU copy of the value-network weights so we can check they are left alone.
v_weight_before = ppo_agent.v.net[0].weight.detach().cpu().clone()

# Load the SAC actor checkpoint into the PPO actor; strict=True raises on any
# missing or unexpected key. Then move the whole agent onto the target device.
ppo_agent.pi.load_state_dict(loaded_state_dict, strict=True)
ppo_agent.to(device)
print(ppo_agent.pi.net[0].weight)
print(ppo_agent.v.net[0].weight)

# Check the transfer programmatically instead of comparing printed tensors by
# eye: every actor entry must match the SAC actor that was saved ...
ppo_pi_state = ppo_agent.pi.state_dict()
for key, saved_tensor in agent.pi.state_dict().items():
    assert torch.equal(
        ppo_pi_state[key].cpu(), saved_tensor.cpu()
    ), f"actor entry {key!r} differs after loading into the PPO agent"
# ... and the value network, which was not part of the checkpoint, must be unchanged.
assert torch.equal(
    ppo_agent.v.net[0].weight.detach().cpu(), v_weight_before
), "value-network weights changed even though only the actor was loaded"

Parameter containing:
tensor([[ 0.0014, -0.0523,  0.0203,  ...,  0.1303,  0.0200,  0.0484],
        [-0.0374, -0.1132, -0.1378,  ..., -0.1458,  0.0101,  0.0622],
        [ 0.1842, -0.0429, -0.0459,  ..., -0.0407, -0.0374, -0.0652],
        ...,
        [-0.2086,  0.2150,  0.1201,  ..., -0.0284, -0.0738, -0.0856],
        [-0.1056,  0.0692, -0.0201,  ...,  0.0167, -0.0083,  0.0625],
        [ 0.0227, -0.0055, -0.0102,  ..., -0.0587,  0.0105, -0.1043]],
       requires_grad=True)
Parameter containing:
tensor([[-0.1743, -0.1366, -0.0315,  ..., -0.1350, -0.0731, -0.1107],
        [ 0.0010, -0.1279,  0.1020,  ...,  0.1600, -0.0319,  0.1045],
        [ 0.0131,  0.0114,  0.0543,  ..., -0.1943, -0.0229,  0.0381],
        ...,
        [-0.0665,  0.0321, -0.0273,  ..., -0.0963,  0.1352, -0.0017],
        [-0.1477, -0.1125, -0.0514,  ...,  0.0006,  0.0386,  0.0470],
        [-0.0880, -0.0126, -0.0906,  ..., -0.0248,  0.0592, -0.1519]],
       requires_grad=True)
Parameter containing:
tensor([[ 0.