In [1]:
from AlphaZeroParallel import AlphaZero
from ConnectFour import ConnectFour
from ResNet import ResNet

import os
import torch
import numpy as np

In [2]:
def train(filepath=None):
    """Set up a ConnectFour network and run AlphaZero training on it.

    A ``ResNet`` and an Adam optimizer are created on the GPU when CUDA is
    available, otherwise on the CPU. The network is switched to training
    mode and handed, together with the optimizer, the game and the
    hyperparameters below, to ``AlphaZero``, whose ``learn()`` is then called.

    Args:
        filepath: Optional checkpoint name (without extension). When truthy,
            the network state is read from ``{game}/models/{filepath}.pt``
            and the optimizer state from ``{game}/optimizers/{filepath}.pt``
            before training starts. When falsy, training starts from the
            freshly constructed network and optimizer.
    """
    # Search / self-play / optimisation settings consumed by AlphaZero.
    hyperparameters = {
        'C': 2,
        'num_searches': 500,
        'num_iterations': 10,
        'num_selfPlay_iterations': 100,
        'num_parallel_games': 20,
        'num_epochs': 10,
        'batch_size': 32,
        'temperature': 1.25,
        'dirichlet_epsilon': 0.25,
        'dirichlet_alpha': 0.3
    }

    game = ConnectFour()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): 9 and 128 are presumably the ResNet depth and width;
    # confirm against the ResNet constructor.
    network = ResNet(game, 9, 128, device)
    adam = torch.optim.Adam(
        network.parameters(), lr=1e-3, weight_decay=1e-4)

    if filepath:
        # Restore the network first, then the optimizer, both mapped onto
        # the device selected above.
        network_state = torch.load(
            f"{game}/models/{filepath}.pt", map_location=device)
        network.load_state_dict(network_state)

        optimizer_state = torch.load(
            f"{game}/optimizers/{filepath}.pt", map_location=device)
        adam.load_state_dict(optimizer_state)

    network.train()

    AlphaZero(network, adam, game, hyperparameters).learn()
    

def show_state(state):
    """Print a board to stdout, one row per line, followed by a blank line.

    Cells equal to 1 are drawn as 'X', cells equal to -1 as 'O', and every
    other value as '-'. Each symbol is followed by a single space, including
    the last symbol of a row.

    The dimensions are taken from ``state`` itself instead of being fixed at
    6 rows by 7 columns, so boards of any size are printed in full. Output
    for a 6x7 board is unchanged.

    Args:
        state: An iterable of rows, each an iterable of cell values
            (for example a list of lists or a 2-D numpy array).
    """
    for row in state:
        for cell in row:
            if cell == 1:
                symbol = 'X'
            elif cell == -1:
                symbol = 'O'
            else:
                symbol = '-'
            print(symbol, end=' ')
        # Terminate the current row.
        print()
    # Blank separator line after the whole board.
    print()


def play_against(filepath):
    """Play an interactive ConnectFour game: the network (X) vs. a human (O).

    The network's weights are loaded from ``{game}/models/{filepath}.pt``.
    The computer is player 1 and moves first; the human is player -1 and
    types a move number at the prompt. The board is printed before every
    move and once more when the game ends, followed by the result.

    Args:
        filepath: Checkpoint name (without extension) of the model to load.
    """
    game = ConnectFour()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = ResNet(game, 9, 128, device=device)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source.
    model.load_state_dict(torch.load(
        f"{game}/models/{filepath}.pt", map_location=device))

    model.eval()

    # 'cls' exists only on Windows; other platforms use 'clear'.
    clear_command = 'cls' if os.name == 'nt' else 'clear'

    player = 1
    state = game.get_initial_state()

    while True:
        os.system(clear_command)
        show_state(state)

        valid = game.get_valid_moves(state)

        if player == 1:
            encoded_state = game.get_encoded_state(state)
            tensor_state = torch.tensor(
                encoded_state, device=device).unsqueeze(0)
            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                policy, _ = model(tensor_state)
            policy = torch.softmax(policy, dim=1).squeeze(0).cpu().numpy()
            # Exclude illegal moves before taking the argmax; the raw policy
            # may put its highest probability on a move that is not valid.
            # Assumes `valid` has one entry per policy output — TODO confirm.
            policy = np.where(np.asarray(valid, dtype=bool), policy, -np.inf)
            action = int(np.argmax(policy))
        else:
            legal_moves = [i for i in range(game.action_size) if valid[i]]
            print("Moves:", legal_moves)

            # Reject non-numeric input instead of crashing on int().
            try:
                action = int(input(f"{player}: "))
            except ValueError:
                print("Invalid action!")
                continue

            # Membership test also rejects out-of-range and negative numbers,
            # which would otherwise raise IndexError or wrap around.
            if action not in legal_moves:
                print("Invalid action!")
                continue

        state = game.get_next_state(state, action, player)
        val, ter = game.get_value_and_terminated(state, action)

        if ter:
            show_state(state)
            if val == 1:
                # The player who just moved is the winner.
                if player == -1:
                    print("human won")
                if player == 1:
                    print("computer won")
            else:
                print("draw")
            break

        player = game.get_opponent(player)

In [3]:
# Resume training: a truthy filepath makes train() load the "model" checkpoint
# for both the network and the optimizer before it calls AlphaZero.learn().
# NOTE(review): the output saved under this cell consists of board printouts
# and "Moves:" prompts, which is what play_against() prints, not train() —
# presumably the output is stale from an earlier run; confirm and re-run.
train("model")

- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 

- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - X - - - 

Moves: [0, 1, 2, 3, 4, 5, 6]
-1: 2
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - O X - - - 

- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - O X X - - 

Moves: [0, 1, 2, 3, 4, 5, 6]
-1: 5
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - O X X O - 

- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - X - - - 
- - O X X O - 

Moves: [0, 1, 2, 3, 4, 5, 6]
-1: 3
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - O - - - 
- - - X - - - 
- - O X X O - 

- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - O - - - 
- - - X - X - 
- - O X X O - 

Moves: [0, 1, 2, 3, 4, 5, 6]
-1: 2
- - - - - - - 
- - - - - - - 
- - - - - - - 
- - - O - - - 
- - O X - X - 
- - O X X O - 

- - - - - - - 
- - - - - - - 
- - - - - -