In [1]:
import src
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

## Init Agent

In [2]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Value network: a stack of fully connected layers that narrows a
# 16-feature input down to a single output (16 -> 16 -> 8 -> 4 -> 2 -> 1),
# with a Hardsigmoid activation after every Linear layer.
#
# Each Linear's in_features must equal the out_features of the Linear before
# it; otherwise the forward pass raises a shape-mismatch RuntimeError.
model = nn.Sequential(
    nn.Linear(16, 16),
    nn.Hardsigmoid(),
    nn.Linear(16, 8),
    nn.Hardsigmoid(),
    nn.Linear(8, 4),  # in_features is 8 to match the 8 outputs of the layer above
    nn.Hardsigmoid(),
    nn.Linear(4, 2),
    nn.Hardsigmoid(),
    nn.Linear(2, 1),
    # NOTE(review): the final Hardsigmoid bounds the output to [0, 1] —
    # confirm the training targets are scaled to that range.
    nn.Hardsigmoid(),
)

In [4]:
# Wrap the network in the project's NNFunc value function, handing it the
# model and the selected device.
# NOTE(review): presumably NNFunc moves the model onto `device` — confirm in src.value_func.
value_func = src.value_func.NNFunc(model, device)

In [5]:
# Build the agent around the value function; the agent is what plays and trains below.
agent = src.Agent(value_func)

## Play A Game

In [7]:
# First argument passed to src.Game.
# NOTE(review): presumably the number of games simulated in parallel — confirm in src.Game.
batch_size = 1000
# Create the game on the same device as the value function.
game = src.Game(batch_size, device=device)

In [8]:
# Play the freshly created game with the agent and display the mean of the
# returned values (cast to float first).
# NOTE(review): the returned values look like per-game final scores — confirm in src.Agent.play_game.
agent.play_game(game).float().mean()

tensor(2549.6602, device='cuda:0')

In [9]:
%%timeit
# Time new_game() followed by one play_game() call; new_game() is inside the
# timed region, so its cost is included in the measurement.
# NOTE(review): presumably new_game() resets the batch so each run starts fresh — confirm.
game.new_game()
agent.play_game(game)

9.21 s ± 840 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Train Agent

In [11]:
# Training configuration passed to the agent's training session.
num_games = 500

# Epsilon parameters forwarded positionally to init_training_session.
# NOTE(review): presumably an epsilon-greedy exploration schedule that decays
# from eps_start towards eps_end — confirm in src.Agent.
eps_start = 0.9
eps_end = 0.1
eps_decay = 100

# Adam with its default hyperparameters, optimising the value network's weights.
optimizer = optim.Adam(model.parameters())
loss_func = F.smooth_l1_loss

session = agent.init_training_session(
    num_games,
    game,
    eps_start,
    eps_end,
    eps_decay,
    loss_func=loss_func,
    optimizer=optimizer,
)

In [12]:
# Drive the training session: each iteration yields a 4-tuple that is printed
# as one space-separated line (index, game_time, mean, std).
# NOTE(review): in the recorded output game_time only increases, so it looks
# like cumulative elapsed seconds rather than per-game time — confirm.
for index, game_time, mean, std in session:
    print(index, game_time, mean, std)

0 4.9837141036987305 1170.0360107421875 579.3958740234375
1 9.496177911758423 1181.89208984375 553.3045043945312
2 13.768308639526367 1152.612060546875 549.4242553710938
3 18.45579481124878 1174.1080322265625 570.6180419921875
4 23.573424100875854 1187.5841064453125 556.6300659179688
5 27.645490646362305 1221.1201171875 584.0358276367188
6 31.982757329940796 1222.2440185546875 580.6599731445312
7 36.46521353721619 1230.10400390625 603.3790283203125
8 41.733492851257324 1227.612060546875 599.7770385742188
9 47.24485468864441 1229.412109375 613.2137451171875
10 52.48392152786255 1202.5 568.0673828125
11 57.622153997421265 1237.2840576171875 578.6387329101562
12 63.82865047454834 1225.6680908203125 609.2220458984375
13 69.45188689231873 1219.0400390625 608.67529296875
14 74.17082238197327 1242.7720947265625 596.6329345703125
15 79.47590780258179 1225.8360595703125 575.0524291992188
16 84.45513272285461 1234.800048828125 598.2525634765625
17 90.01956081390381 1260.248046875 594.27569580078

In [18]:
# Draw one random 0-dim tensor directly on `device`.
# NOTE(review): looks like a leftover scratch/sanity check — consider removing.
torch.rand((), device=device)

tensor(0.2128, device='cuda:0')