In [1]:
from board import Board
from env import Env
from qtypes import *
from striga_scheme import StrigaDeterministicMovementScheme
from train import train
from witcher_scheme import WitcherMovementSchemeSARSA

import matplotlib

import matplotlib.pyplot as plt
import matplotlib.animation as animation

from qtypes import Obstacle

# Every image drawn with imshow is mapped through this colormap by default.
plt.rcParams["image.cmap"] = "nipy_spectral"

# Scalar cell values used by visualize(): the witcher's cell, the striga's cell,
# and the increment added on top of a cell that was attacked this turn.
witcher_color, striga_color, attacked_color = 0.6, 0.15, 0.1

# Module-level handle to the running animation; visualize() stores the
# FuncAnimation here so that a reference to it outlives the call.
ani = None

def visualize(env, turns=1000, striga_max_hp=3):
    """Animate a game played by ``env`` on a matplotlib figure.

    Each animation frame advances the environment by one step, renders the
    board with the striga, the witcher and the attacked cells highlighted,
    and shows the turn number and the striga's remaining hit points in the
    title. Once ``env.game_res`` is set, the last rendered board is kept and
    the title switches to the game result.

    Args:
        env: Game environment. Must provide ``advance()``, ``board`` (with
            ``numpy_repr()`` and indexing by position), ``state`` (with
            ``striga_position`` / ``witcher_position``), ``attacked_positions``,
            ``hits`` and ``game_res``.
        turns: Maximum number of frames (and therefore environment steps)
            to play.
        striga_max_hp: Hit points the striga starts with; the title shows
            ``striga_max_hp - env.hits``.

    Returns:
        The ``matplotlib.animation.FuncAnimation`` driving the figure. It is
        also stored in the module-level ``ani`` so it is not garbage
        collected while the figure is displayed.

    Side effects:
        Mutates ``env``: advances it and resets ``env.attacked_positions``
        to an empty list after every rendered frame.
    """
    ## this backend works terribly on macos
    fig, ax = plt.subplots(figsize=(8, 8))
    done = False

    def init():
        # Explicit (empty) initial draw. Without an init_func, FuncAnimation
        # draws the first frame as its initial state, which would call
        # update(0) -- and thus env.advance() -- one extra time.
        ax.set_axis_off()

    def update(i):
        nonlocal done
        if done:
            # Keep the last board image on screen; only the title changes.
            ax.set_title("Game over: {}".format(env.game_res), fontsize=20)
            ax.set_axis_off()
            return

        env.advance()
        # NOTE(review): assumes numpy_repr() returns an array that is safe to
        # modify in place (i.e. not the board's own storage) -- confirm.
        im = env.board.numpy_repr()
        im[env.state.striga_position.y, env.state.striga_position.x] = striga_color
        im[env.state.witcher_position.y, env.state.witcher_position.x] = witcher_color
        for p in env.attacked_positions:
            # Walls keep their own colour even when they lie in an attacked cell.
            if env.board[p] != Obstacle.WALL:
                im[p.y, p.x] = im[p.y, p.x] + attacked_color

        # Attacked cells are only highlighted for the frame they occurred in.
        env.attacked_positions = []
        if env.game_res is not None:
            done = True

        # Drop the previous frame's image before drawing the new one;
        # otherwise every frame adds another AxesImage to the axes and
        # rendering gets slower and slower.
        ax.clear()
        ax.imshow(im)
        ax.set_title(
            "Turn {}, striga_hp: {}".format(i, striga_max_hp - env.hits),
            fontsize=20,
        )
        ax.set_axis_off()

    global ani
    # repeat=False: the environment is stateful and is not reset, so looping
    # the frame counter would only restart the "Turn" label while the game
    # keeps going. `turns` is therefore a hard upper bound on played steps.
    ani = animation.FuncAnimation(
        fig,
        update,
        frames=turns,
        init_func=init,
        interval=10,
        repeat=False,
    )
    plt.show()
    return ani

In [2]:
# Board dimensions (in cells).
width = 7
height = 7

# The castle sits in the corner cell (width - 1, 0); the three wall cells
# below are exactly the cells adjacent to that corner.
walls = [
    Position(wall_a, wall_b)
    for wall_a, wall_b in (
        (width - 2, 0),
        (width - 2, 1),
        (width - 1, 1),
    )
]
castle_position = Position(width - 1, 0)
board = Board(width, height, castle_position, walls)

# Agents and their movement schemes.
# NOTE(review): the meaning of the second scheme argument (3) is not visible
# here -- confirm against StrigaDeterministicMovementScheme.
striga = Striga(StrigaDeterministicMovementScheme(board, 3))
witcher = Witcher(WitcherMovementSchemeSARSA(board))

# Starting positions for every game.
# NOTE(review): which position belongs to the striga and which to the witcher
# depends on QState's parameter order -- confirm in qtypes.
initial_state = QState(Position(0, 0), Position(3, 3))

# Run training; progress is printed by train() itself.
train(board, striga, witcher, initial_state)

0 0 426
0 0 840
1 0 882
2 0 936
2 0 1098
2 0 1536
3 0 1746
3 0 1842
4 0 1914
4 0 1944
4 0 2004
4 0 2016
4 0 2376
5 0 2436
5 0 2448
5 0 2526
6 0 2778
6 0 2850
6 0 2874
9 1 3109
9 1 3145
9 1 3949
10 1 3961
13 2 4038
13 2 4568
13 2 4754
13 2 4868
13 2 4892
15 2 5012
15 2 5546
18 3 5757
18 3 6195
18 3 6237
18 3 6279
18 3 6291
21 4 6458
24 5 6493
27 6 6602
28 6 6678
28 6 6948
28 6 6990
28 6 7056
29 6 7104
32 7 7657
32 7 7705
32 7 7837
35 8 7908
35 8 7976
38 9 8055
38 9 8073
38 9 8229
39 9 8241
40 9 8337
40 9 8523
41 9 8787
44 10 9128
47 11 9263
47 11 9269
48 11 9311
48 11 9371
48 11 9473
48 11 9587
49 11 9683
52 12 9740
55 13 9831
57 13 9943
57 13 9985
57 13 10027
59 13 10225
59 13 10261
59 13 10345
61 13 10429
61 13 10453
62 13 10531
65 14 10698
65 14 10814
67 14 10880
68 14 10982
71 15 11159
72 15 11229
75 16 11294
75 16 11332
76 16 11368
79 17 11485
79 17 11555
80 17 11603
80 17 11657
83 18 11774
83 18 12228
83 18 12396
83 18 12462
86 19 12559
89 20 12668
92 21 12779
93 21 12807
94 21 12

1460 420 50940
1463 421 50965
1466 422 50990
1469 423 51039
1472 424 51076
1475 425 51095
1478 426 51156
1480 426 51198
1483 427 51231
1486 428 51284
1489 429 51307
1492 430 51336
1493 430 51358
1496 431 51383
1499 432 51408
1502 433 51457
1505 434 51494
1508 435 51527
1511 436 51596
1514 437 51627
1517 438 51674
1519 438 51750
1520 438 51768
1523 439 51795
1526 440 51848
1529 441 51873
1532 442 51898
1535 443 51923
1538 444 51948
1541 445 51981
1543 445 52027
1546 446 52070
1547 446 52124
1550 447 52167
1553 448 52204
1554 448 52258
1557 449 52285
1560 450 52328
1563 451 52449
1566 452 52496
1569 453 52539
1572 454 52588
1574 454 52624
1577 455 52673
1580 456 52700
1583 457 52765
1586 458 52798
1589 459 52819
1592 460 52852
1595 461 52877
1598 462 52934
1601 463 52955
1604 464 53014
1606 464 53078
1609 465 53125
1612 466 53170
1612 466 53218
1615 467 53247
1618 468 53282
1621 469 53331
1624 470 53378
1627 471 53403
1630 472 53468
1633 473 53503
1636 474 53556
1639 475 53587
1642 476 5

In [3]:
# Switch matplotlib to the interactive notebook backend so the animation
# created below can update its figure in place.
%matplotlib notebook

# Build a fresh environment from the board, starting state and agents defined
# in the previous cell, then play it as an animation.
env = Env(board, initial_state, striga, witcher)
# Keep the returned animation in a notebook-level variable so it stays
# referenced while the figure is displayed.
ani = visualize(env)
# Prints the animation object's repr (shown below the figure output).
print(ani)

<IPython.core.display.Javascript object>

<matplotlib.animation.FuncAnimation object at 0x12ba8e400>
