In [1]:
from TicTacToe import TicTacToe
from ConnectFour import ConnectFour
from ResNet import ResNet
from AlphaZero import AlphaZero

import torch
import numpy as np
import matplotlib.pyplot as plt

In [2]:
def train():
    """Train an AlphaZero agent on TicTacToe.

    Builds the game, a ResNet on the best available device, and an Adam
    optimizer, then hands everything to ``AlphaZero`` and runs ``learn()``.
    Returns nothing.
    """
    game = TicTacToe()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The network must be built from the same game instance that AlphaZero
    # receives below (the previous code referenced an undefined name here).
    # NOTE(review): 4 and 64 are presumably the residual-block count and the
    # hidden channel width - confirm against ResNet's signature.
    model = ResNet(game, 4, 64, device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

    # Hyperparameters consumed by AlphaZero; their exact semantics are defined
    # in the AlphaZero module, not in this notebook.
    args = {
        'C': 2,
        'num_searches': 60,
        'num_iterations': 3,
        'num_selfPlay_iterations': 500,
        'num_epochs': 4,
        'batch_size': 64,
        'temperature': 1.25,
        'dirichlet_epsilon': 0.25,
        'dirichlet_alpha': 0.3
    }

    alphazero = AlphaZero(model, optimizer, game, args)
    alphazero.learn()


def test_state(filepath):
    """Show a saved model's evaluation of one fixed TicTacToe position.

    The position is built from the initial state by applying actions 2 and 4
    for player -1 and actions 6 and 8 for player 1. The function prints the
    predicted value, the raw state and the encoded input tensor, then draws
    a bar chart of the softmaxed policy over all actions.

    Args:
        filepath: Path to a saved ``state_dict`` that can be loaded into
            ``ResNet(game, 4, 64)``.
    """
    game = TicTacToe()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Use the local game instance throughout (the previous code referenced an
    # undefined name, which raised NameError on a fresh kernel).
    state = game.get_initial_state()
    for action, player in ((2, -1), (4, -1), (6, 1), (8, 1)):
        state = game.get_next_state(state, action, player)

    encoded_state = game.get_encoded_state(state)

    # unsqueeze(0) adds the leading batch dimension for a single position.
    tensor_state = torch.tensor(encoded_state, device=device).unsqueeze(0)

    model = ResNet(game, 4, 64, device=device)
    model.load_state_dict(torch.load(filepath, map_location=device))
    model.eval()

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        policy, value = model(tensor_state)
        policy = torch.softmax(policy, dim=1).squeeze(0).cpu().numpy()
    value = value.item()

    print(value)

    print(state)
    print(tensor_state)

    plt.bar(range(game.action_size), policy)
    plt.xlabel("action")
    plt.ylabel("policy probability")
    plt.show()


def play_against(filepath):
    """Play an interactive TicTacToe game against a saved model.

    The model moves as player 1 and takes its highest-probability legal
    action; the human moves as player -1 by typing an action index. The board
    is printed before every turn and once more when the game ends, followed
    by either "<player> won" or "draw".

    Args:
        filepath: Path to a saved ``state_dict`` that can be loaded into
            ``ResNet(game, 4, 64)``.
    """
    game = TicTacToe()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = ResNet(game, 4, 64, device=device)
    model.load_state_dict(torch.load(filepath, map_location=device))
    model.eval()

    player = 1
    state = game.get_initial_state()

    while True:
        print(state)

        valid = game.get_valid_moves(state)

        if player == 1:
            encoded_state = game.get_encoded_state(state)
            tensor_state = torch.tensor(
                encoded_state, device=device).unsqueeze(0)
            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                policy, _ = model(tensor_state)
                policy = torch.softmax(policy, dim=1).squeeze(0).cpu().numpy()
            # The raw network output can put mass on occupied squares, so
            # restrict the argmax to legal moves. Using -inf (rather than
            # multiplying by the mask) keeps the choice legal even if every
            # legal probability underflows to zero.
            legal = np.asarray(valid).astype(bool)
            action = int(np.argmax(np.where(legal, policy, -np.inf)))
        else:
            print("Moves:", [i for i in range(game.action_size) if valid[i]])
            try:
                action = int(input(f"{player}: "))
            except ValueError:
                # Non-numeric input: ask again instead of crashing the cell.
                print("Invalid action!")
                continue

            # Explicit range check: a negative index would otherwise wrap
            # around silently and an oversized one would raise IndexError.
            if not 0 <= action < game.action_size or not valid[action]:
                print("Invalid action!")
                continue

        state = game.get_next_state(state, action, player)
        val, ter = game.get_value_and_terminated(state, action)

        if ter:
            print(state)
            if val == 1:
                print(player, "won")
            else:
                print("draw")
            break

        player = game.get_opponent(player)

In [3]:
# Raw string: "\T", "\m" and "\[" are not valid escape sequences, so the
# non-raw literal triggers an invalid-escape warning. The runtime value of
# the path is unchanged.
play_against(r".\TicTacToe\models\[4+64+60+500]@2.pt")

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
[[0. 0. 0.]
 [0. 0. 0.]
 [1. 0. 0.]]
Moves: [0, 1, 2, 3, 4, 5, 7, 8]
-1: 1
[[ 0. -1.  0.]
 [ 0.  0.  0.]
 [ 1.  0.  0.]]
[[ 1. -1.  0.]
 [ 0.  0.  0.]
 [ 1.  0.  0.]]
Moves: [2, 3, 4, 5, 7, 8]
-1: 3
[[ 1. -1.  0.]
 [-1.  0.  0.]
 [ 1.  0.  0.]]
[[ 1. -1.  0.]
 [-1.  1.  0.]
 [ 1.  0.  0.]]
Moves: [2, 5, 7, 8]
-1: 8
[[ 1. -1.  0.]
 [-1.  1.  0.]
 [ 1.  0. -1.]]
[[ 1. -1.  1.]
 [-1.  1.  0.]
 [ 1.  0. -1.]]
1 won
