In [1]:
from src.agent_logic_refactored import Agent
from src import Game
import math
import torch

In [2]:
class RandomPolicy:
    """Baseline policy that scores every action with uniform random noise.

    The policy holds no learnable state; it only remembers the device it was
    created for.
    """

    def __init__(self, device):
        """Store the torch device associated with this policy."""
        self._device = device

    @property
    def device(self):
        """The device handed to the constructor (read-only)."""
        return self._device

    def evaluate(self, game, action, out):
        """Overwrite ``out`` with samples drawn uniformly from [0, 1).

        ``action`` is accepted for interface compatibility but is not used.
        When ``game.batch_size`` equals 0 a scalar (0-dim) sample is requested;
        otherwise ``game.batch_size`` samples are requested. Nothing is
        returned: the result is written through ``out``.
        """
        batch = game.batch_size
        # A batch size of 0 is treated as "unbatched" and maps to an empty shape.
        size = () if batch == 0 else batch
        torch.rand(size, out=out)

    def update(self, state, reward, next_state):
        """Do nothing: a random policy does not learn from transitions."""
        return None

In [3]:
def create_eps_func(eps_start, eps_end, eps_decay):
    """Build an exponentially decaying epsilon schedule.

    The returned callable maps a step ``index`` to

        eps_end + (eps_start - eps_end) * exp(-index / eps_decay)

    so it yields ``eps_start`` at index 0 and, for a positive ``eps_decay``,
    approaches ``eps_end`` as ``index`` grows.
    """
    def eps_func(index):
        # Distance between the initial and the asymptotic epsilon.
        span = eps_start - eps_end
        # Fraction of that distance still remaining at this step.
        remaining = math.exp(-1.0 * index / eps_decay)
        return eps_end + span * remaining

    return eps_func

In [4]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs
# (more slowly) on machines where CUDA is not available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
policy = RandomPolicy(device=device)
agent = Agent(policy)
# A batch of one million; presumably each batch entry is an independent
# game instance -- confirm against the Game implementation.
game = Game(batch_size=int(1e6), device=device)
# Exploration schedule: 0.9 at step 0, decaying exponentially toward 0.1
# with a decay constant of 100 steps.
eps_func = create_eps_func(.9, .1, 100)

In [5]:
# Run the agent on the batched game once; the value returned by play_game
# is the last expression of the cell and is therefore displayed as its output.
agent.play_game(game)

tensor([1204, 1172,  864,  ..., 1428,  648,  516], device='cuda:0',
       dtype=torch.int32)

In [6]:
# Benchmark a full play-through of the batched game.
# NOTE(review): CUDA operations are queued asynchronously, so unless
# play_game synchronizes internally this may measure launch overhead rather
# than total GPU time -- confirm, or add torch.cuda.synchronize() to the
# timed statement.
%timeit agent.play_game(game)

3.38 ms ± 24 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
# Create the training session that the next cell iterates over.
# NOTE(review): the first argument (100) is presumably the number of
# training iterations -- confirm against Agent.init_training_session.
session = agent.init_training_session(100, game, eps_func)

In [8]:
# Step through the training session and log one line per yielded record.
# Each record is a 4-tuple; judging by the names, the last two entries are
# presumably the mean and standard deviation of the batch results -- confirm
# against the session implementation.
for record in session:
    index, game_time, mean, std = record
    print(index, game_time, mean, std)

0 5.540538549423218 1095.079833984375 535.0057983398438
1 10.685228109359741 1095.64697265625 535.2234497070312
2 15.804671049118042 1094.5399169921875 535.6159057617188
3 21.375129222869873 1095.3037109375 535.3302001953125
4 26.804004669189453 1094.604248046875 534.7326049804688
5 32.073079109191895 1094.182861328125 534.87646484375
6 37.37539100646973 1095.751953125 534.96826171875
7 42.86806321144104 1094.6795654296875 534.4783935546875
8 48.449432611465454 1094.732177734375 535.0352172851562
9 54.11925721168518 1094.66064453125 534.9308471679688
10 59.49124574661255 1093.892333984375 534.56982421875
11 64.61862969398499 1095.505615234375 535.4302368164062
12 70.00340580940247 1094.425048828125 534.63623046875
13 75.12910771369934 1095.3380126953125 535.1553955078125
14 80.50660872459412 1095.861572265625 535.6160888671875
15 86.17013192176819 1095.285400390625 535.2312622070312
16 91.77530527114868 1094.6209716796875 534.6950073242188
17 97.51133489608765 1095.617431640625 534.841