## GomokuHDModel

In [41]:
import torch.nn as nn

class GomokuHDModel(nn.Module):
    def __init__(self, D=10000):
        super().__init__()

        device = 'cuda' if torch.cuda.is_available else 'cpu'

        self.D = D
        self.initial_state = torchhd.random(1, D, device=device).squeeze()

        """
        C[0] -> my turn
        C[1] -> opponent turn
        """
        self.C = torchhd.random(1, D, device=device).squeeze()
        self.C = torch.stack([self.C, -self.C], dim=0)

        """
        rwd[0] -> Reward
        rwd[1] -> Punishment
        """
        self.rwd = torchhd.random(1, D, device=device).squeeze()
        self.rwd = torch.stack([self.rwd, -self.rwd], dim=0)

        self.model = self.C[0].clone()

    def forward(self, c, rwd, HDboard, encoded_pos):
        stack = torch.stack([self.model, c, rwd, HDboard])
        en_pos = torchhd.multibind(stack)
        index = torch.argmax(torch.cosine_similarity(en_pos, encoded_pos.T))
        return index, encoded_pos[:, index]
    
    def update(self, ctx):
        self.model = torchhd.normalize(torchhd.bundle(self.model, ctx))

## Parameters

In [42]:
D = int(1e4)
EPOCH = 10000
MODEL_NAME = 'gomoku_model.pth'
BOARD_NAME = 'gomoku_board.pth'

## Training

In [43]:
import os

if os.path.exists(MODEL_NAME):
    model = torch.load(MODEL_NAME)
else:
    model = GomokuHDModel(D)
if os.path.exists(BOARD_NAME):
    board = torch.load(BOARD_NAME)
else:
    board = GomokuHDBoard(D=D)

model.to('cuda')
# Training
for i in range(EPOCH):
    board.start_board(model.initial_state)
    # context for training
    white_ctx = board.HDboard.clone()
    black_ctx = board.HDboard.clone()
    # A game loop
    turn_count = 0
    while True:
        for j, c in enumerate(model.C):
            turn_count += 1

            turn = 1 if j == 0 else -1
            index, en_pos = model(c, model.rwd[0], turn * board.HDboard, board.t_encoded_pos)

            if turn == 1:
                white_ctx = white_ctx.bundle(torchhd.bind(board.HDboard, en_pos))
            if turn == -1:
                black_ctx = black_ctx.bundle(torchhd.bind(board.HDboard, en_pos))

            S = en_pos.bind(c)
            game_state = board.update_board(index, turn, S)
            if game_state != 0: break
        if game_state != 0: break
    
    # Create context with score and update model
    if game_state == 1:
        ctx = torchhd.bundle(model.rwd[0].bind(white_ctx), model.rwd[1].bind(-black_ctx))
        print(i, 'white win', turn_count)
    if game_state == -1:
        ctx = torchhd.bundle(model.rwd[1].bind(white_ctx), model.rwd[0].bind(-black_ctx))
        print(i, 'black win', turn_count)
    if game_state == 2:
        ctx = torchhd.bundle(model.rwd[2].bind(white_ctx), model.rwd[2].bind(-black_ctx))
        print(i, 'draw', turn)
    model.update(ctx)

torch.save(model, 'gomoku_model.pth')
torch.save(board, 'gomoku_board.pth')

0 black win 46
1 black win 20
2 black win 42
3 black win 58
4 white win 67
5 black win 32
6 black win 16
7 black win 68
8 black win 42
9 white win 28
10 white win 75
11 black win 28
12 white win 33
13 white win 29
14 white win 47
15 black win 62
16 white win 45
17 white win 57
18 white win 15
19 black win 32
20 black win 24
21 white win 21
22 black win 36
23 white win 27
24 white win 35
25 black win 46
26 black win 42
27 white win 77
28 white win 31
29 white win 55
30 black win 32
31 white win 21
32 black win 50
33 black win 32
34 white win 83
35 white win 55
36 black win 30
37 white win 27
38 white win 53
39 black win 34
40 white win 29
41 black win 32
42 white win 37
43 white win 34
44 white win 29
45 black win 26
46 white win 27
47 black win 22
48 black win 18
49 white win 31
50 black win 20
51 white win 55
52 black win 20
53 white win 43
54 white win 33
55 white win 50
56 black win 28
57 black win 20
58 black win 34
59 black win 32
60 black win 44
61 white win 25
62 white win 49
63

## Testing (PVB)

In [45]:
model = torch.load(MODEL_NAME)
board = torch.load(BOARD_NAME)

board.start_board(model.initial_state)
while True:
    for i, c in enumerate(model.C):
        turn = 1 if i == 0 else -1
        if turn == 1:
            index, en_pos = model(c, model.rwd[0], board.HDboard, board.t_encoded_pos)
            print(board.t_positions[index])
        elif turn == -1:
            print(board.board)
            pos = input('Enter position (x y, 0-12)')
            print('Enter', pos)
            pos = tuple([int(x) for x in pos.split()])
            index = board.t_positions.index(pos)
            en_pos = board.t_encoded_pos[:, index]
        S = en_pos.bind(c)
        game_state = board.update_board(index, turn, S)
        if game_state != 0: break
    if game_state != 0: break

  model = torch.load(MODEL_NAME)
  board = torch.load(BOARD_NAME)


(2, 0)
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0')
Enter 2 2
(10, 0)
tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1,  0, -1,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 

ValueError: (1, 1) is not in list