Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Load individual elements if state dict load fails #5213

Merged
merged 22 commits into from
Apr 6, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
49a8a23
load individual elements if state dict load fails
andrewcoh Apr 1, 2021
35980f1
clean up exception catching
andrewcoh Apr 1, 2021
38479f6
use debug statements
andrewcoh Apr 1, 2021
88eff1a
use load_state_dict strict = False
andrewcoh Apr 1, 2021
d67f11f
add unexpected keys
andrewcoh Apr 1, 2021
f6afb5f
fix typo in warning for unexpected keys
andrewcoh Apr 1, 2021
a79718e
add load different reward tests
andrewcoh Apr 1, 2021
7aa6ae4
add debug with error print out
andrewcoh Apr 5, 2021
a663ffa
add doc change
andrewcoh Apr 5, 2021
0230f2c
test convolutions can be loaded properly
andrewcoh Apr 5, 2021
4a57d60
add check that layers still have different dimensions
andrewcoh Apr 5, 2021
46891ec
add special case for non nn.Module load and comment
andrewcoh Apr 5, 2021
d07ff7c
Update ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py
andrewcoh Apr 5, 2021
b8bae51
Update ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py
andrewcoh Apr 5, 2021
ca8dda8
Update ml-agents/mlagents/trainers/model_saver/torch_model_saver.py
andrewcoh Apr 5, 2021
5a5abaa
Update ml-agents/mlagents/trainers/model_saver/torch_model_saver.py
andrewcoh Apr 5, 2021
e6b87d5
Update docs/Training-ML-Agents.md
andrewcoh Apr 5, 2021
2a20494
Update ml-agents/mlagents/trainers/model_saver/torch_model_saver.py
andrewcoh Apr 5, 2021
5a3e598
update changelog
andrewcoh Apr 5, 2021
cccd32d
Merge branch 'fix-resume-imi' of https://github.com/Unity-Technologie…
andrewcoh Apr 5, 2021
ac5e0e4
update changelog comment
andrewcoh Apr 5, 2021
41c8793
fix typo
andrewcoh Apr 5, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/Training-ML-Agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ Python by using both the `--resume` and `--inference` flags. Note that if you
want to run inference in Unity, you should use the
[Unity Inference Engine](Getting-Started.md#running-a-pre-trained-model).

Additionally, if the network architecture changes, you may still load an existing model,
but ML-Agents will only load the parts of the model it can load and ignore all others. For instance,
if you add a new reward signal, the existing model will load but the new reward signal
will be initialized from scratch. If you have a model with a visual encoder (CNN) but
change the `hidden_units`, the CNN will be loaded but the body of the network will be
initialized from scratch.

Alternatively, you might want to start a new training run but _initialize_ it
using an already-trained model. You may want to do this, for instance, if your
environment changed and you want a new model, but the old behavior is still
Expand Down
29 changes: 28 additions & 1 deletion ml-agents/mlagents/trainers/model_saver/torch_model_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,34 @@ def _load_model(
policy = cast(TorchPolicy, policy)

for name, mod in modules.items():
mod.load_state_dict(saved_state_dict[name])
try:
if isinstance(mod, torch.nn.Module):
missing_keys, unexpected_keys = mod.load_state_dict(
saved_state_dict[name], strict=False
)
if missing_keys:
logger.warning(
f"Did not find these keys {missing_keys} in checkpoint. Initializing."
)
if unexpected_keys:
logger.warning(
f"Did not expect these keys {unexpected_keys} in checkpoint. Ignoring."
)
else:
# optimizers are treated separately
andrewcoh marked this conversation as resolved.
Show resolved Hide resolved
mod.load_state_dict(saved_state_dict[name])

# KeyError is raised if the module was not present in the last run but is being
# accessed in the saved_state_dict.
# ValueError is raised by the optimizer's load_state_dict if the parameters
# have changed. Note, the optimizer uses a completely different load_state_dict
# function because it is not an nn.Module.
# RuntimeError is raised by PyTorch if there is a size mismatch between modules
# of the same name. This will still partially assign values to those layers that
# have not changed shape.
except (KeyError, ValueError, RuntimeError) as err:
logger.warning(f"Failed to load for module {name}. Initializing")
logger.debug(f"Module loading error : {err}")

if reset_global_steps:
policy.set_step(0)
Expand Down
46 changes: 46 additions & 0 deletions ml-agents/mlagents/trainers/tests/torch/saver/test_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
from mlagents.trainers.settings import (
TrainerSettings,
NetworkSettings,
EncoderType,
PPOSettings,
SACSettings,
POCASettings,
Expand Down Expand Up @@ -70,6 +72,50 @@ def test_load_save_policy(tmp_path):
assert policy3.get_current_step() == 0


@pytest.mark.parametrize("vis_encode_type", ["resnet", "nature_cnn", "match3"])
def test_load_policy_different_hidden_units(tmp_path, vis_encode_type):
    """Load a checkpoint into a policy built with a different ``hidden_units``.

    A visual policy with ``hidden_units=12`` is saved, then the checkpoint is
    loaded into a policy with ``hidden_units=10``. The test asserts that the
    parameters with more than two dimensions (the convolution weights) differ
    before the load and are equal after it, while parameters whose first
    dimension was 12 in the saved policy still have first dimension 10 in the
    loading policy.
    """

    def _conv_weights(actor):
        # Parameters with more than two dimensions are treated as convolutions.
        return [param for param in actor.parameters() if len(param.shape) > 2]

    def _assert_hidden_dims_still_differ(saved_actor, loaded_actor):
        # Layers sized by hidden_units must keep the dimension of their own
        # network settings (12 for the saved policy, 10 for the loading one).
        pairs = zip(saved_actor.parameters(), loaded_actor.parameters())
        for saved_param, loaded_param in pairs:
            if saved_param.shape[0] == 12:
                assert loaded_param.shape[0] == 10

    run_dir = os.path.join(tmp_path, "runid1")

    # Build and checkpoint the original policy.
    saved_settings = TrainerSettings()
    saved_settings.network_settings = NetworkSettings(
        hidden_units=12, vis_encode_type=EncoderType(vis_encode_type)
    )
    saved_policy = create_policy_mock(saved_settings, use_visual=True)
    saved_convs = _conv_weights(saved_policy.actor)

    saver = TorchModelSaver(saved_settings, run_dir)
    saver.register(saved_policy)
    saver.initialize_or_load(saved_policy)
    saved_policy.set_step(2000)
    saver.save_checkpoint("MockBrain", 2000)

    # Build a second policy with a different hidden size and load into it.
    loading_settings = TrainerSettings()
    loading_settings.network_settings = NetworkSettings(
        hidden_units=10, vis_encode_type=EncoderType(vis_encode_type)
    )
    loader = TorchModelSaver(loading_settings, run_dir, load=True)
    loading_policy = create_policy_mock(loading_settings, use_visual=True)
    loading_convs = _conv_weights(loading_policy.actor)

    # Before the load every pair of convolution weights must differ.
    assert not any(
        torch.equal(saved, fresh) for saved, fresh in zip(saved_convs, loading_convs)
    )
    _assert_hidden_dims_still_differ(saved_policy.actor, loading_policy.actor)

    loader.register(loading_policy)
    loader.initialize_or_load(loading_policy)

    # After the load every pair of convolution weights must match.
    assert all(
        torch.equal(saved, loaded) for saved, loaded in zip(saved_convs, loading_convs)
    )
    _assert_hidden_dims_still_differ(saved_policy.actor, loading_policy.actor)


@pytest.mark.parametrize(
"optimizer",
[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@

import numpy as np

from mlagents_envs.logging_util import WARNING
from mlagents.trainers.ppo.optimizer_torch import TorchPPOOptimizer
from mlagents.trainers.sac.optimizer_torch import TorchSACOptimizer
from mlagents.trainers.poca.optimizer_torch import TorchPOCAOptimizer
from mlagents.trainers.model_saver.torch_model_saver import TorchModelSaver
from mlagents.trainers.settings import (
TrainerSettings,
Expand All @@ -14,12 +16,14 @@
RNDSettings,
PPOSettings,
SACSettings,
POCASettings,
)
from mlagents.trainers.tests.torch.test_policy import create_policy_mock
from mlagents.trainers.tests.torch.test_reward_providers.utils import (
create_agent_buffer,
)


DEMO_PATH = (
os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
+ "/test.demo"
Expand All @@ -28,8 +32,12 @@

@pytest.mark.parametrize(
"optimizer",
[(TorchPPOOptimizer, PPOSettings), (TorchSACOptimizer, SACSettings)],
ids=["ppo", "sac"],
[
(TorchPPOOptimizer, PPOSettings),
(TorchSACOptimizer, SACSettings),
(TorchPOCAOptimizer, POCASettings),
],
ids=["ppo", "sac", "poca"],
)
def test_reward_provider_save(tmp_path, optimizer):
OptimizerClass, HyperparametersClass = optimizer
Expand Down Expand Up @@ -87,3 +95,55 @@ def test_reward_provider_save(tmp_path, optimizer):
rp_1 = optimizer.reward_signals[reward_name]
rp_2 = optimizer2.reward_signals[reward_name]
assert np.array_equal(rp_1.evaluate(data), rp_2.evaluate(data))


@pytest.mark.parametrize(
    "optimizer",
    [
        (TorchPPOOptimizer, PPOSettings),
        (TorchSACOptimizer, SACSettings),
        (TorchPOCAOptimizer, POCASettings),
    ],
    ids=["ppo", "sac", "poca"],
)
def test_load_different_reward_provider(caplog, tmp_path, optimizer):
    """Load a checkpoint into an optimizer configured with other reward signals.

    A policy and optimizer using Curiosity and RND reward signals are saved.
    A second policy and optimizer using only GAIL then load that checkpoint.
    The test asserts the number of critic value-head streams on each side
    (2 and 1) and that the load logs at least one WARNING-level record.
    """
    optimizer_cls, hyperparameters_cls = optimizer

    def _make_settings(reward_signals):
        # Fresh trainer settings using this parametrization's hyperparameters.
        settings = TrainerSettings()
        settings.hyperparameters = hyperparameters_cls()
        settings.reward_signals = reward_signals
        return settings

    saved_settings = _make_settings(
        {
            RewardSignalType.CURIOSITY: CuriositySettings(),
            RewardSignalType.RND: RNDSettings(),
        }
    )
    saved_policy = create_policy_mock(saved_settings, use_discrete=False)
    saved_optimizer = optimizer_cls(saved_policy, saved_settings)

    # save at path 1
    run_dir = os.path.join(tmp_path, "runid1")
    saver = TorchModelSaver(saved_settings, run_dir)
    for component in (saved_policy, saved_optimizer):
        saver.register(component)
    saver.initialize_or_load()
    assert len(saved_optimizer.critic.value_heads.stream_names) == 2
    saved_policy.set_step(2000)
    saver.save_checkpoint("MockBrain", 2000)

    # create a new optimizer and policy with a different reward signal set
    loading_settings = _make_settings(
        {RewardSignalType.GAIL: GAILSettings(demo_path=DEMO_PATH)}
    )
    loading_policy = create_policy_mock(loading_settings, use_discrete=False)
    loading_optimizer = optimizer_cls(loading_policy, loading_settings)

    # load weights
    loader = TorchModelSaver(loading_settings, run_dir, load=True)
    for component in (loading_policy, loading_optimizer):
        loader.register(component)
    assert len(loading_optimizer.critic.value_heads.stream_names) == 1
    loader.initialize_or_load()  # This is to load the optimizers

    # The load must have logged at least one WARNING-level record.
    warning_messages = [
        record.message for record in caplog.records if record.levelno == WARNING
    ]
    assert len(warning_messages) > 0