Skip to content
Merged
2 changes: 1 addition & 1 deletion .github/workflows/python-test-coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
coveralls
coveralls --service=github
137 changes: 137 additions & 0 deletions auto_trainer/test_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import unittest

from auto_trainer import trainer
from unittest import mock

import importlib
import sys


class TestAutoTrainer(unittest.TestCase):
  """Unit tests for `auto_trainer.trainer`, with W&B both disabled and mocked."""

  def setUp(self):
    # Every test flips the module-level `trainer._WANDB` switch. Restore the
    # value it had before the test so results don't depend on execution order.
    self.addCleanup(setattr, trainer, '_WANDB', trainer._WANDB)

  def _reload_trainer_with_wandb(self, mock_wandb):
    """Reload `trainer` with `mock_wandb` standing in for the `wandb` module.

    `sys.modules['wandb']` is only patched for the duration of the reload;
    W&B support is then force-enabled on the reloaded module. A second reload
    is scheduled as a cleanup so that the mocked module does not leak into
    other tests.

    Args:
      mock_wandb: The object to expose as `wandb` while `trainer` is reloaded.
    """
    # Cleanups run last-in-first-out: this reload executes before the
    # `_WANDB` restore registered in `setUp`, so the restored flag wins.
    self.addCleanup(importlib.reload, trainer)

    with mock.patch.dict('sys.modules', {'wandb': mock_wandb}):
      importlib.reload(trainer)

    trainer._WANDB = True
    self.assertTrue(trainer._WANDB)

  def test_synced_config_no_wandb(self):
    """Without W&B the parameters pass through untouched and no run exists."""
    param = {'test_key': 'test_value'}

    trainer._WANDB = False
    self.assertFalse(trainer._WANDB)

    config, run = trainer.get_synced_config(param, ['bunch', 'of', 'tags'])

    self.assertDictEqual(config, param)
    self.assertIsNone(run)

  def test_synced_config_wandb(self):
    """With W&B, `wandb.init` receives the config/tags and supplies the run."""
    param = {'test_key': 'test_value'}
    tags = ['bunch', 'of', 'tags']

    mock_run = mock.MagicMock(config=param)
    mock_wandb = mock.MagicMock()
    mock_wandb.init.return_value = mock_run

    self._reload_trainer_with_wandb(mock_wandb)
    config, run = trainer.get_synced_config(param, tags)

    mock_wandb.init.assert_called_once()

    _, kwargs = mock_wandb.init.call_args
    self.assertDictEqual(kwargs['config'], param)
    self.assertListEqual(kwargs['tags'], tags)
    self.assertDictEqual(config, param)
    self.assertEqual(run, mock_run)

  def test_trainer_no_wandb(self):
    """Without W&B, `train` builds, trains and saves the selected model."""
    trainer._WANDB = False
    self.assertFalse(trainer._WANDB)

    algo = 'test_ago'
    policy = 'test_policy'
    episodes = 69

    parameters = mock.MagicMock(algorithm=algo, policy=policy,
                                episodes=episodes)
    fake_env = mock.MagicMock()

    mock_learn = mock.MagicMock()
    mock_save = mock.MagicMock()
    mock_model = mock.MagicMock()
    mock_model.learn = mock_learn
    mock_model.save = mock_save

    mock_model_cls = mock.MagicMock()
    mock_model_cls.return_value = mock_model

    # Swap the algorithm registry for one holding only the mocked class.
    with mock.patch.dict(trainer.SUPPORTED_ALGORITHMS,
                         {algo: mock_model_cls}, clear=True):
      model, config, run = trainer.train(fake_env, parameters, None)

    self.assertEqual(model, mock_model)
    self.assertEqual(config, parameters)
    self.assertIsNone(run)

    mock_model_cls.assert_called_once()
    args, _ = mock_model_cls.call_args
    self.assertTupleEqual(args, (policy, fake_env))

    mock_learn.assert_called_once()
    mock_learn_args, _ = mock_learn.call_args
    self.assertTupleEqual(mock_learn_args, (episodes, ))

    mock_save.assert_called_once()
    mock_save_args, _ = mock_save.call_args

    # If not using wandb, the model is saved under a run name built from the
    # 2-digit representations of MonthDayHourMinSec; hence the run name should
    # be 10 characters long.
    self.assertEqual(len(mock_save_args[0]), 10)

  def test_trainer_wandb(self):
    """With a W&B run, logs and the saved model go under the run directory."""
    algo = 'test_ago'
    policy = 'test_policy'
    episodes = 69

    parameters = mock.MagicMock(algorithm=algo, policy=policy,
                                episodes=episodes)
    mock_env = mock.MagicMock()
    mock_run = mock.MagicMock(dir='test_dir')

    mock_save = mock.MagicMock()
    mock_model = mock.MagicMock()
    mock_model.learn = mock.MagicMock()
    mock_model.save = mock_save

    mock_model_cls = mock.MagicMock()
    mock_model_cls.return_value = mock_model

    self._reload_trainer_with_wandb(mock.MagicMock())

    with mock.patch.dict(trainer.SUPPORTED_ALGORITHMS,
                         {algo: mock_model_cls}, clear=True):
      trainer.train(mock_env, parameters, None, run=mock_run)

    # Fail with a readable message (rather than a TypeError while unpacking
    # `call_args`) if the model was never constructed or saved.
    mock_model_cls.assert_called()
    _, kwargs = mock_model_cls.call_args
    self.assertEqual(kwargs['tensorboard_log'], mock_run.dir)

    mock_save.assert_called()
    args, _ = mock_save.call_args
    self.assertEqual(args[0], '{}/model'.format(mock_run.dir))

# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()
47 changes: 43 additions & 4 deletions auto_trainer/trainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import List, Text

from datetime import datetime
import stable_baselines

import gym
import logging
import os
import stable_baselines


PROJECT_NAME = 'solo-rl-experiments'
Expand All @@ -21,7 +25,22 @@
_WANDB = False


def get_synced_config(parameters, tags):
def get_synced_config(parameters, tags: List[Text]):
"""Sync the config with wandb (if necessary) and return the new config

Args:
parameters (Any W&B supported type): The current hyperparameters to use. If
W&B is enabled and is actively making a sweep, these hyperparameters will
get updated to W&B's sweep. This can be any type supported by W&B,
including Dicts and argparse.Namespace objects.
tags (List[Text]): Tags that describe the run. Note that this is basically
useless if W&B is disabled.

Returns:
The (hyperparameter config, W&B run object). Obviously, if W&B is disabled,
then the run object will be None and the hyperparameter config will be what
was passed in.
"""
if not _WANDB:
return parameters, None

Expand All @@ -37,7 +56,27 @@ def get_synced_config(parameters, tags):
return config, run


def train(env, parameters, tags, full_logging=True, log_freq=100, run=None):
def train(env: gym.Env, parameters, tags: List[Text],
full_logging: bool = False, log_freq: int = 100, run = None):
"""Train a model.

Args:
env (gym.Env): Gym environment to train on.
parameters (Any W&B supported type): The hyperparameters to train with.
Refer to the W&B documentation for all of the supported types.
tags (List[Text]): List of tags that describe this run. Doesn't do anything
if `run` is not None.
full_logging (bool, optional): Whether or not to log *everything*. Can fill
up space quickly. Defaults to False.
log_freq (int, optional): How many steps to write the logs. Defaults to 100.
run ([wandb.Run], optional): A current W&B run. Use this if you want to
reuse a current run, i.e. train a model, do things to it, and continue
training it. If this is None, a new run will be created via
`get_synced_config`. Defaults to None.

Returns:
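The (model, hyperparameter config, W&B run object) tuple. The run object is
None if W&B is disabled and no run was passed in.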
"""
if run:
config = parameters
else:
Expand All @@ -50,7 +89,7 @@ def train(env, parameters, tags, full_logging=True, log_freq=100, run=None):
verbose=1)

if _WANDB: wandb.tensorboard.monkeypatch._notify_tensorboard_logdir(
os.path.join(run.dir, '{}_1'.format(_DEFAULT_RUN_NAME)))
os.path.join(run.dir, '{}_1'.format(_DEFAULT_RUN_NAME)))

model.learn(config.episodes, tb_log_name=_DEFAULT_RUN_NAME,
log_interval=log_freq)
Expand Down
2 changes: 1 addition & 1 deletion experiments/notebooks/solo8v2vanilla-ppo2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@
"extension": ".py",
"format_name": "percent",
"format_version": "1.3",
"jupytext_version": "1.7.1"
"jupytext_version": "1.9.1"
}
},
"kernelspec": {
Expand Down
2 changes: 1 addition & 1 deletion experiments/train/autotrainer-sb-solo-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.7.1
# jupytext_version: 1.9.1
# kernelspec:
# display_name: Python 3
# language: python
Expand Down
2 changes: 1 addition & 1 deletion experiments/train/solo8v2vanilla-ppo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.7.1
# jupytext_version: 1.9.1
# kernelspec:
# display_name: Python 3
# language: python
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
'stable-baselines',
'numpy<1.19.0,>=1.16.0',
'jupytext',
'gym'
],
extras_require={
'cpu': ['tensorflow>=1.15.0,<2'],
Expand Down