Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] TOML saving and loading hyperparams #286

Merged
merged 10 commits into from
Sep 12, 2020
Merged
48 changes: 48 additions & 0 deletions docs/source/usage/tutorials/Saving and loading.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
Saving and Loading Weights and Hyperparameters with GenRL
=========================================================

In reinforcement learning we often want to checkpoint a model during training. GenRL saves your hyperparameters to a TOML file and your network weights as a PyTorch ``state_dict``.

The following sample code trains an agent and saves checkpoints:

.. code-block:: python
import gym
import shutil

from genrl.agents import VPG
from genrl.environments.suite import VectorEnv
from genrl.core import NormalActionNoise
from genrl.trainers import OnPolicyTrainer

env = VectorEnv("CartPole-v0", 2)
algo = VPG("mlp", env, batch_size=5, replay_size=100)

trainer = OnPolicyTrainer(
algo,
env,
log_mode=["stdout"],
logdir="./logs",
save_interval=100,
epochs=100,
evaluate_episodes=2,
)
trainer.train()
trainer.evaluate()
shutil.rmtree("./logs")

If you already have saved weights and hyperparameters files, you can load them onto the model by passing their paths to the trainer, as shown below:

.. code-block:: python
trainer = OnPolicyTrainer(
algo,
env,
log_mode=["stdout"],
logdir="./logs",
save_interval=100,
epochs=100,
evaluate_episodes=2,
load_weights="./checkpoints/VPG_CartPole-v0/1-log-0.pt",
load_hyperparams="./checkpoints/VPG_CartPole-v0/1-log-0.toml",
)


1 change: 1 addition & 0 deletions docs/source/usage/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ Tutorials
Using Custom Policies
Using A2C
using_vpg
Saving and loading
10 changes: 5 additions & 5 deletions genrl/agents/deep/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def get_hyperparams(self) -> Dict[str, Any]:

Returns:
hyperparams (:obj:`dict`): Hyperparameters to be saved
weights (:obj:`torch.Tensor`): Neural network weights
"""
hyperparams = {
"network": self.network,
Expand All @@ -190,17 +191,16 @@ def get_hyperparams(self) -> Dict[str, Any]:
"lr_policy": self.lr_policy,
"lr_value": self.lr_value,
"rollout_size": self.rollout_size,
"weights": self.ac.state_dict(),
}
return hyperparams
return hyperparams, self.ac.state_dict()

def load_weights(self, weights) -> None:
def _load_weights(self, weights) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to add the _?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has the same name as the parameter, and we are not going to let users call this directly :P

"""Load weights for the agent from pretrained model

Args:
weights (:obj:`dict`): Dictionary of different neural net weights
weights (:obj:`torch.Tensor`): neural net weights
"""
self.ac.load_state_dict(weights["weights"])
self.ac.load_state_dict(weights)

def get_logging_params(self) -> Dict[str, Any]:
"""Gets relevant parameters for logging
Expand Down
4 changes: 2 additions & 2 deletions genrl/agents/deep/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ def get_hyperparams(self) -> Dict[str, Any]:
"""
raise NotImplementedError

def load_weights(self, weights) -> None:
def _load_weights(self, weights) -> None:
"""Load weights for the agent from pretrained model

Args:
weights (:obj:`dict`): Dictionary of different neural net weights
weights (:obj:`torch.tensor`): neural net weights
"""

raise NotImplementedError
Expand Down
6 changes: 3 additions & 3 deletions genrl/agents/deep/base/offpolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,10 +271,10 @@ def get_p_loss(self, states: torch.Tensor) -> torch.Tensor:
policy_loss = -torch.mean(q_values)
return policy_loss

def load_weights(self, weights) -> None:
def _load_weights(self, weights) -> None:
"""Load weights for the agent from pretrained model

Args:
weights (:obj:`dict`): Dictionary of different neural net weights
weights (:obj:`torch.Tensor`): neural net weights
"""
self.ac.load_state_dict(weights["weights"])
self.ac.load_state_dict(weights)
4 changes: 2 additions & 2 deletions genrl/agents/deep/ddpg/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def get_hyperparams(self) -> Dict[str, Any]:

Returns:
hyperparams (:obj:`dict`): Hyperparameters to be saved
weights (:obj:`torch.Tensor`): Neural Network weights
"""
hyperparams = {
"network": self.network,
Expand All @@ -119,9 +120,8 @@ def get_hyperparams(self) -> Dict[str, Any]:
"noise_std": self.noise_std,
"lr_policy": self.lr_policy,
"lr_value": self.lr_value,
"weights": self.ac.state_dict(),
}
return hyperparams
return hyperparams, self.ac.state_dict()

def get_logging_params(self) -> Dict[str, Any]:
"""Gets relevant parameters for logging
Expand Down
7 changes: 4 additions & 3 deletions genrl/agents/deep/dqn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def get_hyperparams(self) -> Dict[str, Any]:

Returns:
hyperparams (:obj:`dict`): Hyperparameters to be saved
weights (:obj:`torch.Tensor`): Neural network weights
"""
hyperparams = {
"gamma": self.gamma,
Expand All @@ -226,15 +227,15 @@ def get_hyperparams(self) -> Dict[str, Any]:
"weights": self.model.state_dict(),
"timestep": self.timestep,
}
return hyperparams
return hyperparams, self.model.state_dict()

def load_weights(self, weights) -> None:
"""Load weights for the agent from pretrained model

Args:
weights (:obj:`Dict`): Dictionary of different neural net weights
weights (:obj:`torch.Tensor`): neural net weights
"""
self.model.load_state_dict(weights["weights"])
self.model.load_state_dict(weights)

def get_logging_params(self) -> Dict[str, Any]:
"""Gets relevant parameters for logging
Expand Down
8 changes: 4 additions & 4 deletions genrl/agents/deep/ppo1/ppo1.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def get_hyperparams(self) -> Dict[str, Any]:

Returns:
hyperparams (:obj:`dict`): Hyperparameters to be saved
weights (:obj:`torch.Tensor`): Neural network weights
"""
hyperparams = {
"network": self.network,
Expand All @@ -197,18 +198,17 @@ def get_hyperparams(self) -> Dict[str, Any]:
"lr_policy": self.lr_policy,
"lr_value": self.lr_value,
"rollout_size": self.rollout_size,
"weights": self.ac.state_dict(),
}

return hyperparams
return hyperparams, self.ac.state_dict()

def load_weights(self, weights) -> None:
def _load_weights(self, weights) -> None:
"""Load weights for the agent from pretrained model

Args:
weights (:obj:`dict`): Dictionary of different neural net weights
"""
self.ac.load_state_dict(weights["weights"])
self.ac.load_state_dict(weights)

def get_logging_params(self) -> Dict[str, Any]:
"""Gets relevant parameters for logging
Expand Down
4 changes: 2 additions & 2 deletions genrl/agents/deep/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ def get_hyperparams(self) -> Dict[str, Any]:

Returns:
hyperparams (:obj:`dict`): Hyperparameters to be saved
weights (:obj:`torch.Tensor`): Neural network weights
"""
hyperparams = {
"network": self.network,
Expand All @@ -232,9 +233,8 @@ def get_hyperparams(self) -> Dict[str, Any]:
"entropy_tuning": self.entropy_tuning,
"alpha": self.alpha,
"polyak": self.polyak,
"weights": self.ac.state_dict(),
}
return hyperparams
return hyperparams, self.ac.state_dict()

def get_logging_params(self) -> Dict[str, Any]:
"""Gets relevant parameters for logging
Expand Down
4 changes: 2 additions & 2 deletions genrl/agents/deep/td3/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def get_hyperparams(self) -> Dict[str, Any]:

Returns:
hyperparams (:obj:`dict`): Hyperparameters to be saved
weights (:obj:`torch.Tensor`): Neural network weights
"""
hyperparams = {
"network": self.network,
Expand All @@ -138,10 +139,9 @@ def get_hyperparams(self) -> Dict[str, Any]:
"polyak": self.polyak,
"policy_frequency": self.policy_frequency,
"noise_std": self.noise_std,
"weights": self.ac.state_dict(),
}

return hyperparams
return hyperparams, self.ac.state_dict()

def get_logging_params(self) -> Dict[str, Any]:
"""Gets relevant parameters for logging
Expand Down
8 changes: 4 additions & 4 deletions genrl/agents/deep/vpg/vpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,25 +141,25 @@ def get_hyperparams(self) -> Dict[str, Any]:

Returns:
hyperparams (:obj:`dict`): Hyperparameters to be saved
weights (:obj:`torch.Tensor`): Neural network weights
"""
hyperparams = {
"network": self.network,
"batch_size": self.batch_size,
"gamma": self.gamma,
"lr_policy": self.lr_policy,
"rollout_size": self.rollout_size,
"weights": self.ac.state_dict(),
}

return hyperparams
return hyperparams, self.actor.state_dict()

def load_weights(self, weights) -> None:
def _load_weights(self, weights) -> None:
"""Load weights for the agent from pretrained model

Args:
weights (:obj:`dict`): Dictionary of different neural net weights
"""
self.ac.load_state_dict(weights["weights"])
self.actor.load_state_dict(weights)

def get_logging_params(self) -> Dict[str, Any]:
"""Gets relevant parameters for logging
Expand Down
46 changes: 27 additions & 19 deletions genrl/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import gym
import numpy as np
import toml
import torch

from genrl.environments.vec_env import VecEnv
Expand All @@ -28,7 +29,8 @@ class Trainer(ABC):
save_interval (int): Timesteps between successive saves of the agent's important hyperparameters
save_model (str): Directory where the checkpoints of agent parameters should be saved
run_num (int): A run number allotted to the save of parameters
load_model (str): File to load saved parameter checkpoint from
load_weights (str): Weights file
load_hyperparams (str): File to load hyperparameters
render (bool): True if environment is to be rendered during training, else False
evaluate_episodes (int): Number of episodes to evaluate for
seed (int): Set seed for reproducibility
Expand All @@ -48,7 +50,8 @@ def __init__(
save_interval: int = 0,
save_model: str = "checkpoints",
run_num: int = None,
load_model: str = None,
load_weights: str = None,
load_hyperparams: str = None,
render: bool = False,
evaluate_episodes: int = 50,
seed: Optional[int] = None,
Expand All @@ -65,7 +68,8 @@ def __init__(
self.save_interval = save_interval
self.save_model = save_model
self.run_num = run_num
self.load_model = load_model
self.load_weights = load_weights
self.load_hyperparams = load_hyperparams
self.render = render
self.evaluate_episodes = evaluate_episodes

Expand Down Expand Up @@ -146,30 +150,34 @@ def save(self, timestep: int) -> None:
run_num = int(last_path[len(path) + 1 :].split("-")[0]) + 1
self.run_num = run_num

torch.save(
self.agent.get_hyperparams(),
"{}/{}-log-{}.pt".format(path, run_num, timestep),
)
filename_hyperparams = "{}/{}-log-{}.toml".format(path, run_num, timestep)
filename_weights = "{}/{}-log-{}.pt".format(path, run_num, timestep)
hyperparameters, weights = self.agent.get_hyperparams()
with open(filename_hyperparams, mode="w") as f:
toml.dump(hyperparameters, f)

torch.save(weights, filename_weights)

def load(self):
"""Function to load saved parameters of a given agent"""
path = self.load_model
try:
self.checkpoint = torch.load(path)
except FileNotFoundError:
raise Exception("Invalid File Name")
self.checkpoint_hyperparams = {}
with open(self.load_hyperparams, mode="r") as f:
self.checkpoint_hyperparams = toml.load(f, _dict=dict)

weights = {}

for key, item in self.checkpoint.items():
if "weights" not in key:
for key, item in self.checkpoint_hyperparams.items():
setattr(self, key, item)
else:
weights[key] = item

self.agent.load_weights(weights)
except FileNotFoundError:
raise Exception("Invalid hyperparameters File Name")

try:
self.checkpoint_weights = torch.load(self.load_weights)
self.agent._load_weights(self.checkpoint_weights)
except FileNotFoundError:
raise Exception("Invalid weights File Name")

print("Loaded Pretrained Model!")
print("Loaded Pretrained Model weights and hyperparameters!")

@property
def n_envs(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion genrl/trainers/offpolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def check_game_over_status(self, timestep: int, dones: List[bool]) -> bool:

def train(self) -> None:
"""Main training method"""
if self.load_model is not None:
if self.load_weights is not None or self.load_hyperparams is not None:
self.load()

state = self.env.reset()
Expand Down
2 changes: 1 addition & 1 deletion genrl/trainers/onpolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, *args, **kwargs):

def train(self) -> None:
"""Main training method"""
if self.load_model is not None:
if self.load_weights is not None or self.load_hyperparams is not None:
self.load()

for epoch in range(self.epochs):
Expand Down
6 changes: 5 additions & 1 deletion tests/test_deep/test_common/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ def test_load_params():
env = VectorEnv("CartPole-v0", 1)
algo = PPO1("mlp", env)
trainer = OnPolicyTrainer(
algo, env, epochs=0, load_model="test_ckpt/PPO1_CartPole-v0/0-log-0.pt"
algo,
env,
epochs=0,
load_hyperparams="test_ckpt/PPO1_CartPole-v0/0-log-0.toml",
load_weights="test_ckpt/PPO1_CartPole-v0/0-log-0.pt",
)
trainer.train()

Expand Down