Add save/load weights for policies and refactor action distributions
araffin committed Mar 31, 2020
1 parent b782f3a commit fdecd51
Showing 11 changed files with 319 additions and 211 deletions.
4 changes: 4 additions & 0 deletions docs/misc/changelog.rst
@@ -10,10 +10,12 @@ Pre-Release 0.4.0a0 (WIP)
Breaking Changes:
^^^^^^^^^^^^^^^^^
- Removed CEMRL
- Models saved with previous versions cannot be loaded (because of the pre-processing)

New Features:
^^^^^^^^^^^^^
- Add support for Discrete observation spaces
- Add saving/loading for policy weights, so the policy can be used without the model

Bug Fixes:
^^^^^^^^^^
@@ -26,6 +28,8 @@ Others:
^^^^^^^
- Refactor handling of observation and action spaces
- Refactored features extraction to have proper preprocessing
- Refactored action distributions


Documentation:
^^^^^^^^^^^^^^
14 changes: 9 additions & 5 deletions tests/test_distributions.py
@@ -38,7 +38,8 @@ def test_squashed_gaussian(model_class):
gaussian_mean = th.rand(N_SAMPLES, N_ACTIONS)
dist = SquashedDiagGaussianDistribution(N_ACTIONS)
_, log_std = dist.proba_distribution_net(N_FEATURES)
actions, _ = dist.proba_distribution(gaussian_mean, log_std)
dist = dist.proba_distribution(gaussian_mean, log_std)
actions = dist.get_action()
assert th.max(th.abs(actions)) <= 1.0

def test_sde_distribution():
@@ -51,7 +52,8 @@ def test_sde_distribution():
_, log_std = dist.proba_distribution_net(N_FEATURES)
dist.sample_weights(log_std, batch_size=N_SAMPLES)

actions, _ = dist.proba_distribution(deterministic_actions, log_std, state)
dist = dist.proba_distribution(deterministic_actions, log_std, state)
actions = dist.get_action()

assert th.allclose(actions.mean(), dist.distribution.mean.mean(), rtol=1e-3)
assert th.allclose(actions.std(), dist.distribution.scale.mean(), rtol=1e-3)
@@ -71,11 +73,12 @@ def test_entropy(dist):
_, log_std = dist.proba_distribution_net(N_FEATURES, log_std_init=th.log(th.tensor(0.2)))

if isinstance(dist, DiagGaussianDistribution):
actions, dist = dist.proba_distribution(deterministic_actions, log_std)
dist = dist.proba_distribution(deterministic_actions, log_std)
else:
dist.sample_weights(log_std, batch_size=N_SAMPLES)
actions, dist = dist.proba_distribution(deterministic_actions, log_std, state)
dist = dist.proba_distribution(deterministic_actions, log_std, state)

actions = dist.get_action()
entropy = dist.entropy()
log_prob = dist.log_prob(actions)
assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=5e-3)
@@ -88,8 +91,9 @@ def test_categorical():
set_random_seed(1)
state = th.rand(N_SAMPLES, N_FEATURES)
action_logits = th.rand(N_SAMPLES, N_ACTIONS)
actions, dist = dist.proba_distribution(action_logits)
dist = dist.proba_distribution(action_logits)

actions = dist.get_action()
entropy = dist.entropy()
log_prob = dist.log_prob(actions)
assert th.allclose(entropy.mean(), -log_prob.mean(), rtol=1e-4)
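
For context, the refactored distribution API exercised by these tests follows a build-then-sample pattern: proba_distribution() now returns the distribution object itself, and actions are drawn explicitly with get_action(). A minimal sketch, assuming the import path torchy_baselines.common.distributions and the constructor/network signatures implied by the tests (the dimensions below are illustrative):

    import torch as th
    from torchy_baselines.common.distributions import DiagGaussianDistribution

    action_dim, latent_dim, batch_size = 4, 64, 8
    dist = DiagGaussianDistribution(action_dim)
    # proba_distribution_net() returns the mean network and the log_std parameter
    action_net, log_std = dist.proba_distribution_net(latent_dim)

    mean_actions = action_net(th.rand(batch_size, latent_dim))
    # Previously returned (actions, dist); now returns the distribution itself
    dist = dist.proba_distribution(mean_actions, log_std)

    actions = dist.get_action()        # explicit sampling step
    log_prob = dist.log_prob(actions)  # log-probability of the sampled actions
    entropy = dist.entropy()
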
2 changes: 1 addition & 1 deletion tests/test_identity.py
@@ -20,7 +20,7 @@ def test_continuous(model_class):
env = IdentityEnvBox(eps=0.5)

n_steps = {
A2C: 3000,
A2C: 3500,
PPO: 3000,
SAC: 700,
TD3: 500
60 changes: 59 additions & 1 deletion tests/test_save_load.py
@@ -16,7 +16,7 @@
SAC,
]


#
@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_save_load(model_class):
"""
@@ -160,3 +160,61 @@ def test_save_load_replay_buffer(model_class):

# clear file from os
os.remove(replay_path)


@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_save_load_policy(model_class):
"""
Test saving and loading policy only.
:param model_class: (BaseRLModel) an RL model
"""
env = DummyVecEnv([lambda: IdentityEnvBox(10)])

# create model
model = model_class('MlpPolicy', env, policy_kwargs=dict(net_arch=[16]), verbose=1, create_eval_env=True)
model.learn(total_timesteps=500, eval_freq=250)

env.reset()
observations = np.array([env.step(env.action_space.sample())[0] for _ in range(10)])
observations = observations.reshape(10, -1)

policy = model.policy

# Get dictionary of current parameters
params = deepcopy(policy.state_dict())

# Modify all parameters to be random values
random_params = dict((param_name, th.rand_like(param)) for param_name, param in params.items())

# Update model parameters with the new random values
policy.load_state_dict(random_params)

new_params = policy.state_dict()
# Check that all params are different now
for k in params:
assert not th.allclose(params[k], new_params[k]), "Parameters did not change as expected."

params = new_params

# get selected actions
selected_actions, _ = policy.predict(observations, deterministic=True)

# Save and load policy
policy.save("./logs/policy_weights.pkl")
# del policy
policy.load("./logs/policy_weights.pkl")

# check if params are still the same after load
new_params = policy.state_dict()

# Check that all params are the same as before save load procedure now
for key in params:
assert th.allclose(params[key], new_params[key]), "Policy parameters not the same after save and load."

# check if model still selects the same actions
new_selected_actions, _ = policy.predict(observations, deterministic=True)
assert np.allclose(selected_actions, new_selected_actions, 1e-4)

# clear file from os
os.remove("./logs/policy_weights.pkl")
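
The test above exercises the new changelog feature: policy weights can be saved and reloaded independently of the model. A minimal usage sketch, assuming the import paths below (PPO, DummyVecEnv and IdentityEnvBox as used elsewhere in the test suite; paths and hyperparameters are illustrative). Note that, as in the test, load() here reloads weights into an existing policy instance:

    from torchy_baselines import PPO
    from torchy_baselines.common.vec_env import DummyVecEnv
    from torchy_baselines.common.identity_env import IdentityEnvBox

    env = DummyVecEnv([lambda: IdentityEnvBox(10)])
    model = PPO('MlpPolicy', env, verbose=0)
    model.learn(total_timesteps=1000)

    # Save only the policy weights, not the full model
    model.policy.save("ppo_policy.pkl")

    # Later: reload the weights and act without the surrounding model
    policy = model.policy
    policy.load("ppo_policy.pkl")
    obs = env.reset()
    action, _ = policy.predict(obs, deterministic=True)
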
126 changes: 4 additions & 122 deletions torchy_baselines/common/base_class.py
@@ -158,27 +158,6 @@ def _get_eval_env(self, eval_env: Optional[GymEnv]) -> Optional[GymEnv]:
assert eval_env.num_envs == 1
return eval_env

def scale_action(self, action: np.ndarray) -> np.ndarray:
"""
Rescale the action from [low, high] to [-1, 1]
(no need for symmetric action space)
:param action: (np.ndarray) Action to scale
:return: (np.ndarray) Scaled action
"""
low, high = self.action_space.low, self.action_space.high
return 2.0 * ((action - low) / (high - low)) - 1.0

def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
"""
Rescale the action from [-1, 1] to [low, high]
(no need for symmetric action space)
:param scaled_action: Action to un-scale
"""
low, high = self.action_space.low, self.action_space.high
return low + (0.5 * (scaled_action + 1.0) * (high - low))

def _setup_lr_schedule(self) -> None:
"""Transform to callable if needed."""
self.lr_schedule = get_schedule_fn(self.learning_rate)
@@ -318,57 +297,6 @@ def learn(self, total_timesteps: int,
"""
raise NotImplementedError()

@staticmethod
def _is_vectorized_observation(observation: np.ndarray, observation_space: gym.spaces.Space) -> bool:
"""
For every observation type, detects and validates the shape,
then returns whether or not the observation is vectorized.
:param observation: (np.ndarray) the input observation to validate
:param observation_space: (gym.spaces) the observation space
:return: (bool) whether the given observation is vectorized or not
"""
if isinstance(observation_space, gym.spaces.Box):
if observation.shape == observation_space.shape:
return False
elif observation.shape[1:] == observation_space.shape:
return True
else:
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
"Box environment, please use {} ".format(observation_space.shape) +
"or (n_env, {}) for the observation shape."
.format(", ".join(map(str, observation_space.shape))))
elif isinstance(observation_space, gym.spaces.Discrete):
if observation.shape == (): # A numpy array of a number, has shape empty tuple '()'
return False
elif len(observation.shape) == 1:
return True
else:
raise ValueError("Error: Unexpected observation shape {} for ".format(observation.shape) +
"Discrete environment, please use (1,) or (n_env, 1) for the observation shape.")
# TODO: add support for MultiDiscrete and MultiBinary observation spaces
# elif isinstance(observation_space, gym.spaces.MultiDiscrete):
# if observation.shape == (len(observation_space.nvec),):
# return False
# elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec):
# return True
# else:
# raise ValueError("Error: Unexpected observation shape {} for MultiDiscrete ".format(observation.shape) +
# "environment, please use ({},) or ".format(len(observation_space.nvec)) +
# "(n_env, {}) for the observation shape.".format(len(observation_space.nvec)))
# elif isinstance(observation_space, gym.spaces.MultiBinary):
# if observation.shape == (observation_space.n,):
# return False
# elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n:
# return True
# else:
# raise ValueError("Error: Unexpected observation shape {} for MultiBinary ".format(observation.shape) +
# "environment, please use ({},) or ".format(observation_space.n) +
# "(n_env, {}) for the observation shape.".format(observation_space.n))
else:
raise ValueError("Error: Cannot determine if the observation is vectorized with the space type {}."
.format(observation_space))

def predict(self, observation: np.ndarray,
state: Optional[np.ndarray] = None,
mask: Optional[np.ndarray] = None,
@@ -383,36 +311,7 @@ def predict(self, observation: np.ndarray,
:return: (Tuple[np.ndarray, Optional[np.ndarray]]) the model's action and the next state
(used in recurrent policies)
"""
# TODO: move this block to BasePolicy
# if state is None:
# state = self.initial_state
# if mask is None:
# mask = [False for _ in range(self.n_envs)]
observation = np.array(observation)
vectorized_env = self._is_vectorized_observation(observation, self.observation_space)

observation = observation.reshape((-1,) + self.observation_space.shape)
observation = th.as_tensor(observation).to(self.device)
with th.no_grad():
actions = self.policy.predict(observation, deterministic=deterministic)
# Convert to numpy
actions = actions.cpu().numpy()

# Rescale to proper domain when using squashing
if isinstance(self.action_space, gym.spaces.Box) and self.policy.squash_output:
actions = self.unscale_action(actions)

clipped_actions = actions
# Clip the actions to avoid out of bound error when using gaussian distribution
if isinstance(self.action_space, gym.spaces.Box) and not self.policy.squash_output:
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

if not vectorized_env:
if state is not None:
raise ValueError("Error: The environment must be vectorized when using recurrent policies.")
clipped_actions = clipped_actions[0]

return clipped_actions, state
return self.policy.predict(observation, state, mask, deterministic)

@classmethod
def load(cls, load_path: str, env: Optional[GymEnv] = None, **kwargs):
@@ -484,10 +383,7 @@ def _load_from_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[D
raise ValueError(f"Error: the file {load_path} could not be found")

# set device to cpu if cuda is not available
if th.cuda.is_available():
device = th.device('cuda')
else:
device = th.device('cpu')
device = th.device('cuda') if th.cuda.is_available() else th.device('cpu')

# Open the zip archive and load data
try:
@@ -534,20 +430,6 @@ def _load_from_file(load_path: str, load_data: bool = True) -> (Tuple[Optional[D
# load the parameters with the right `map_location`
params[os.path.splitext(file_path)[0]] = th.load(file_content, map_location=device)

# for backward compatibility
if params.get('params') is not None:
params_copy = {}
for name in params:
if name == 'params':
params_copy['policy'] = params[name]
elif name == 'opt':
params_copy['policy.optimizer'] = params[name]
# Special case for SAC
elif name == 'ent_coef_optimizer':
params_copy[name] = params[name]
else:
params_copy[name + '.optimizer'] = params[name]
params = params_copy
except zipfile.BadZipFile:
# load_path wasn't a zip file
raise ValueError(f"Error: the file {load_path} wasn't a zip-file")
@@ -925,7 +807,7 @@ def collect_rollouts(self,
unscaled_action, _ = self.predict(obs, deterministic=False)

# Rescale the action from [low, high] to [-1, 1]
scaled_action = self.scale_action(unscaled_action)
scaled_action = self.policy.scale_action(unscaled_action)

if self.use_sde:
# When using SDE, the action can be out of bounds
Expand All @@ -941,7 +823,7 @@ def collect_rollouts(self,
clipped_action = np.clip(clipped_action + action_noise(), -1, 1)

# Rescale and perform action
new_obs, reward, done, infos = env.step(self.unscale_action(clipped_action))
new_obs, reward, done, infos = env.step(self.policy.unscale_action(clipped_action))

# Only stop training if return value is False, not when it is None.
if callback.on_step() is False:
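
The collect_rollouts() changes above call scale_action()/unscale_action() on the policy instead of on the base model, so the helpers removed from base_class.py presumably now live on the policy class (that file is not shown in this diff). For reference, a self-contained sketch that mirrors the removed implementations; the class name and constructor are illustrative only:

    import numpy as np
    import gym

    class BasePolicySketch:
        """Illustrative only: shows where the action-scaling helpers plausibly moved."""
        def __init__(self, action_space: gym.spaces.Box):
            self.action_space = action_space

        def scale_action(self, action: np.ndarray) -> np.ndarray:
            # Rescale the action from [low, high] to [-1, 1]
            low, high = self.action_space.low, self.action_space.high
            return 2.0 * ((action - low) / (high - low)) - 1.0

        def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
            # Rescale the action from [-1, 1] to [low, high]
            low, high = self.action_space.low, self.action_space.high
            return low + (0.5 * (scaled_action + 1.0) * (high - low))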