diff --git a/pokemongo_bot/cell_workers/__init__.py b/pokemongo_bot/cell_workers/__init__.py index 8a700ef5ba..d9198790c7 100644 --- a/pokemongo_bot/cell_workers/__init__.py +++ b/pokemongo_bot/cell_workers/__init__.py @@ -12,3 +12,4 @@ from handle_soft_ban import HandleSoftBan from follow_path import FollowPath from follow_spiral import FollowSpiral +from base_task import BaseTask diff --git a/pokemongo_bot/cell_workers/base_task.py b/pokemongo_bot/cell_workers/base_task.py new file mode 100644 index 0000000000..4d73a68443 --- /dev/null +++ b/pokemongo_bot/cell_workers/base_task.py @@ -0,0 +1,14 @@ +class BaseTask(object): + def __init__(self, bot, config): + self.bot = bot + self.config = config + self._validate_work_exists() + self.initialize() + + def _validate_work_exists(self): + method = getattr(self, 'work', None) + if not method or not callable(method): + raise NotImplementedError('Missing "work" method') + + def initialize(self): + pass diff --git a/pokemongo_bot/test/base_task_test.py b/pokemongo_bot/test/base_task_test.py new file mode 100644 index 0000000000..ed87c44e86 --- /dev/null +++ b/pokemongo_bot/test/base_task_test.py @@ -0,0 +1,43 @@ +import unittest +import json +from pokemongo_bot.cell_workers import BaseTask + +class FakeTask(BaseTask): + def initialize(self): + self.foo = 'foo' + + def work(self): + pass + +class FakeTaskWithoutInitialize(BaseTask): + def work(self): + pass + +class FakeTaskWithoutWork(BaseTask): + pass + +class BaseTaskTest(unittest.TestCase): + def setUp(self): + self.bot = {} + self.config = {} + + def test_initialize_called(self): + task = FakeTask(self.bot, self.config) + self.assertIs(task.bot, self.bot) + self.assertIs(task.config, self.config) + self.assertEquals(task.foo, 'foo') + + def test_does_not_throw_without_initialize(self): + FakeTaskWithoutInitialize(self.bot, self.config) + + def test_throws_without_work(self): + self.assertRaisesRegexp( + NotImplementedError, + 'Missing "work" method', + FakeTaskWithoutWork, + self.bot, + self.config + ) + +if __name__ == '__main__': + unittest.main()