diff --git a/.github/workflows/python-test-coverage.yml b/.github/workflows/python-test-coverage.yml index edfb4aa..da6b241 100644 --- a/.github/workflows/python-test-coverage.yml +++ b/.github/workflows/python-test-coverage.yml @@ -35,4 +35,4 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: | - coveralls + coveralls --service=github diff --git a/auto_trainer/test_trainer.py b/auto_trainer/test_trainer.py new file mode 100644 index 0000000..7eaa635 --- /dev/null +++ b/auto_trainer/test_trainer.py @@ -0,0 +1,137 @@ +import unittest + +from auto_trainer import trainer +from unittest import mock + +import importlib +import sys + + +class TestAutoTrainer(unittest.TestCase): + def test_synced_config_no_wandb(self): + param = {'test_key': 'test_value'} + + trainer._WANDB = False + self.assertFalse(trainer._WANDB) + + config, run = trainer.get_synced_config(param, ['bunch', 'of', 'tags']) + + self.assertDictEqual(config, param) + self.assertIsNone(run) + + def test_synced_config_wandb(self): + mock_wandb = mock.MagicMock() + + param = {'test_key': 'test_value'} + tags = ['bunch', 'of', 'tags'] + + mock_run = mock.MagicMock(config=param) + mock_wandb.init = mock.MagicMock() + mock_wandb.init.return_value = mock_run + + if 'wandb' in sys.modules: + import wandb + del wandb + with mock.patch.dict('sys.modules', {'wandb': mock_wandb}): + importlib.reload(trainer) + + trainer._WANDB = True + self.assertTrue(trainer._WANDB) + config, run = trainer.get_synced_config(param, tags) + + mock_wandb.init.assert_called_once() + + _, kwargs = mock_wandb.init.call_args + self.assertDictEqual(kwargs['config'], param) + self.assertListEqual(kwargs['tags'], tags) + self.assertDictEqual(config, param) + self.assertEqual(run, mock_run) + + def test_trainer_no_wandb(self): + trainer._WANDB = False + self.assertFalse(trainer._WANDB) + + algo = 'test_ago' + policy = 'test_policy' + episodes = 69 + + parameters = mock.MagicMock(algorithm=algo, policy=policy, + episodes=episodes) + 
fake_env = mock.MagicMock() + + mock_learn = mock.MagicMock() + mock_save = mock.MagicMock() + mock_model = mock.MagicMock() + mock_model.learn = mock_learn + mock_model.save = mock_save + + mock_model_cls = mock.MagicMock() + mock_model_cls.return_value = mock_model + + with mock.patch.dict(trainer.SUPPORTED_ALGORITHMS, + {algo: mock_model_cls}, clear=True): + model, config, run = trainer.train(fake_env, parameters, None) + + self.assertEqual(model, mock_model) + self.assertEqual(config, parameters) + self.assertIsNone(run) + + mock_model_cls.assert_called_once() + args, _ = mock_model_cls.call_args + self.assertTupleEqual(args, (policy, fake_env)) + + mock_learn.assert_called_once() + mock_learn_args, _ = mock_learn.call_args + self.assertTupleEqual(mock_learn_args, (episodes, )) + + mock_save.assert_called_once() + mock_save_args, _ = mock_save.call_args + + # If not using wandb, should save in 2-digit reps of + # MonthDayHourMinSec; henceforth, the length of the run name should be 10 + # digits long + self.assertEqual(len(mock_save_args[0]), 10) + + def test_trainer_wandb(self): + algo = 'test_ago' + policy = 'test_policy' + episodes = 69 + + parameters = mock.MagicMock(algorithm=algo, policy=policy, + episodes=episodes) + mock_env = mock.MagicMock() + mock_run = mock.MagicMock(dir='test_dir') + + mock_learn = mock.MagicMock() + mock_save = mock.MagicMock() + mock_model = mock.MagicMock() + mock_model.learn = mock_learn + mock_model.save = mock_save + + mock_model_cls = mock.MagicMock() + mock_model_cls.return_value = mock_model + + mock_wandb = mock.MagicMock() + if 'wandb' in sys.modules: + import wandb + del wandb + with mock.patch.dict('sys.modules', {'wandb': mock_wandb}): + importlib.reload(trainer) + + trainer._WANDB = True + self.assertTrue(trainer._WANDB) + + with mock.patch.dict(trainer.SUPPORTED_ALGORITHMS, + {algo: mock_model_cls}, clear=True): + model, config, run = trainer.train(mock_env, parameters, None, + run=mock_run) + + _, kwargs = 
mock_model_cls.call_args + self.assertEqual(kwargs['tensorboard_log'], mock_run.dir) + + args, _ = mock_save.call_args + self.assertEqual(args[0], '{}/model'.format(mock_run.dir)) + + +if __name__ == '__main__': + pass \ No newline at end of file diff --git a/auto_trainer/trainer.py b/auto_trainer/trainer.py index 6d97b65..490cefb 100644 --- a/auto_trainer/trainer.py +++ b/auto_trainer/trainer.py @@ -1,7 +1,11 @@ +from typing import List, Text + from datetime import datetime -import stable_baselines + +import gym import logging import os +import stable_baselines PROJECT_NAME = 'solo-rl-experiments' @@ -21,7 +25,22 @@ _WANDB = False -def get_synced_config(parameters, tags): +def get_synced_config(parameters, tags: List[Text]): + """Sync the config with wandb (if necessary) and return the new config + + Args: + parameters (Any W&B supported type): The current hyperparameters to use. If + W&B is enabled and is actively making a sweep, these hyperparameters will + get updated to W&B's sweep. This can be any type supported by W&B, + including Dicts and argparse.Namespace objects. + tags (List[Text]): Tags that describe the run. Note that this is basically + useless if W&B is disabled. + + Returns: + The (hyperparameter config, W&B run object). Obviously, if W&B is disabled, + then the run object will be None and the hyperparameter config will be what + was passed in. + """ if not _WANDB: return parameters, None @@ -37,7 +56,27 @@ def get_synced_config(parameters, tags): return config, run -def train(env, parameters, tags, full_logging=True, log_freq=100, run=None): +def train(env: gym.Env, parameters, tags: List[Text], + full_logging: bool = False, log_freq: int = 100, run = None): + """Train a model. + + Args: + env (gym.Env): Gym environment to train on. + parameters (Any W&B supported type): The hyperparameters to train with. + Refer to W&B for all of the support types. + tags (List[Text]): List of tags that describe this run. 
Doesn't do anything + if `run` is not None. + full_logging (bool, optional): Whether or not to log *everything*. Can fill + up space quickly. Defaults to False. + log_freq (int, optional): How many steps to write the logs. Defaults to 100. + run ([wandb.Run], optional): A current W&B run. Use this if you want to + reuse a current run, i.e. train a model, do things to it, and continue + training it. If this is None, a new run will be created via + `get_synced_config`. Defaults to None. + + Returns: + The (trained model, hyperparameter config, W&B run object or None) tuple. + """ if run: config = parameters else: @@ -50,7 +89,7 @@ def train(env, parameters, tags, full_logging=True, log_freq=100, run=None): verbose=1) if _WANDB: wandb.tensorboard.monkeypatch._notify_tensorboard_logdir( - os.path.join(run.dir, '{}_1'.format(_DEFAULT_RUN_NAME))) + os.path.join(run.dir, '{}_1'.format(_DEFAULT_RUN_NAME))) model.learn(config.episodes, tb_log_name=_DEFAULT_RUN_NAME, log_interval=log_freq) diff --git a/experiments/notebooks/solo8v2vanilla-ppo2.ipynb b/experiments/notebooks/solo8v2vanilla-ppo2.ipynb index 4e5dd61..8bc5063 100644 --- a/experiments/notebooks/solo8v2vanilla-ppo2.ipynb +++ b/experiments/notebooks/solo8v2vanilla-ppo2.ipynb @@ -435,7 +435,7 @@ "extension": ".py", "format_name": "percent", "format_version": "1.3", - "jupytext_version": "1.7.1" + "jupytext_version": "1.9.1" } }, "kernelspec": { diff --git a/experiments/train/autotrainer-sb-solo-demo.py b/experiments/train/autotrainer-sb-solo-demo.py index 8f0d5d4..ea98a76 100644 --- a/experiments/train/autotrainer-sb-solo-demo.py +++ b/experiments/train/autotrainer-sb-solo-demo.py @@ -5,7 +5,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.7.1 +# jupytext_version: 1.9.1 # kernelspec: # display_name: Python 3 # language: python diff --git a/experiments/train/solo8v2vanilla-ppo2.py b/experiments/train/solo8v2vanilla-ppo2.py index 446e114..d96dedf 100644 --- a/experiments/train/solo8v2vanilla-ppo2.py +++ 
b/experiments/train/solo8v2vanilla-ppo2.py @@ -5,7 +5,7 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.7.1 +# jupytext_version: 1.9.1 # kernelspec: # display_name: Python 3 # language: python diff --git a/setup.py b/setup.py index 560d39a..492704b 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,7 @@ 'stable-baselines', 'numpy<1.19.0,>=1.16.0', 'jupytext', + 'gym' ], extras_require={ 'cpu': ['tensorflow>=1.15.0,<2'],