diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6195660283..4e5e5bb3e9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -56,6 +56,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed the end of batch size mismatch ([#389](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/389))
 - Fixed `batch_size` parameter for DataModules remaining ([#344](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/344))
 - Fixed CIFAR `num_samples` ([#432](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/432))
+- Fixed DQN `run_n_episodes` using the wrong environment variable ([#525](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/525))

 ## [0.2.5] - 2020-10-12

diff --git a/pl_bolts/models/rl/dqn_model.py b/pl_bolts/models/rl/dqn_model.py
index ff51ecdb6f..2c98d5efa2 100644
--- a/pl_bolts/models/rl/dqn_model.py
+++ b/pl_bolts/models/rl/dqn_model.py
@@ -171,7 +171,7 @@ def run_n_episodes(self, env, n_epsiodes: int = 1, epsilon: float = 1.0) -> List
             while not done:
                 self.agent.epsilon = epsilon
                 action = self.agent(episode_state, self.device)
-                next_state, reward, done, _ = self.env.step(action[0])
+                next_state, reward, done, _ = env.step(action[0])
                 episode_state = next_state
                 episode_reward += reward

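For context, here is a minimal sketch of why the loop must step the `env` argument rather than `self.env`. This is not the pl_bolts implementation: `DQNSketch` and the random-action stand-in for the agent are hypothetical, and it assumes the classic `gym` API where `step` returns a 4-tuple. Before the fix, a caller passing a fresh evaluation env would have it reset but never advanced, while rewards came from stepping the training env's unrelated state.

```python
"""Hypothetical sketch of the run_n_episodes bug fixed in PR #525."""
import gym


class DQNSketch:
    def __init__(self):
        # Training environment held by the model.
        self.env = gym.make("CartPole-v0")

    def run_n_episodes(self, env, n_episodes: int = 1):
        """Evaluate for n_episodes on the env passed in by the caller."""
        rewards = []
        for _ in range(n_episodes):
            state, done, episode_reward = env.reset(), False, 0.0
            while not done:
                # Random action stands in for the agent's policy here.
                action = env.action_space.sample()
                # Fixed line: step `env`, not `self.env`, so the caller's
                # environment is the one actually being evaluated.
                state, reward, done, _ = env.step(action)
                episode_reward += reward
            rewards.append(episode_reward)
        return rewards


if __name__ == "__main__":
    model = DQNSketch()
    test_env = gym.make("CartPole-v0")  # separate env, as in evaluation
    print(model.run_n_episodes(test_env, n_episodes=2))
```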