In [1]:
import numpy as np
from games.tictactoe import TicTacToe
from games.connectfour import ConnectFour
from models.mcts import MCTS
from models.resnet import ResNet
from models.deepzero import DeepZero
import torch
from tqdm import tqdm
from tqdm import trange
import random
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt

# tictactoe game

In [None]:
game = TicTacToe()

device = torch.device("cpu")

model = ResNet(game, 4, 32, device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

player = 1

args = {
    'C': 2,
    'num_search': 100,
    'num_iterations': 3,
    'batch_size': 16,
    'num_selfplay_iterations': 350,
    'num_epochs': 4,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.3
}

deepzero = DeepZero(model, optimizer, game, args)
deepzero.learn()


In [7]:
game = TicTacToe()
player = 1
device = torch.device("cpu")
args = {
    'C': 2,
    'num_search': 100,
    'num_iterations': 3,
    'batch_size': 16,
    'num_selfplay_iterations': 350,
    'num_epochs': 4,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.3
}

model = ResNet(game, 4, 32, device=device)
model.load_state_dict(torch.load("weights/model_2.pt", map_location=device))
model.eval()
mcts = MCTS(game, args, model)
state = game.get_initial_state()

while True:
    print(state)
    if player == 1:
        valid_moves = game.get_valid_moves(state)
        print("val_movies", [i for i in range(game.action_size) if valid_moves[i] == 1])
        action = int(input(f"{player}: "))
        if valid_moves[action] == 0:
            print("action not val")
            continue
    else:
        valid_moves = game.get_valid_moves(state)
        neutral_state = game.change_perspective(state, player)
        mcts_probs = mcts.search(neutral_state)
        mcts_probs = mcts_probs * valid_moves  # Mask invalid moves to zero
        action = np.argmax(mcts_probs)
        # Optional: Add a check for no valid moves, though this should not occur in a proper game state
        if valid_moves[action] == 0:
            raise ValueError("No valid moves available; game state may be invalid.")

    state = game.get_next_state(state, action, player)
    value, is_terminate = game.get_value_and_terminated(state, action)
    if is_terminate:
        if value == 1:
            print(player, "win")
        else:
            print(player, "lose")
        break
    player = game.get_opponent(player)

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
val_movies [0, 1, 2, 3, 4, 5, 6, 7, 8]
[[0. 0. 0.]
 [0. 1. 0.]
 [0. 0. 0.]]
[[ 0.  0.  0.]
 [-1.  1.  0.]
 [ 0.  0.  0.]]
val_movies [0, 1, 2, 5, 6, 7, 8]
[[ 0.  0.  1.]
 [-1.  1.  0.]
 [ 0.  0.  0.]]
[[ 0.  0.  1.]
 [-1.  1. -1.]
 [ 0.  0.  0.]]
val_movies [0, 1, 6, 7, 8]
1 win


In [8]:
game = ConnectFour()

device = torch.device("cpu")

model = ResNet(game, 9, 32, device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

player = 1

args = {
    'C': 2,
    'num_search': 600,
    'num_iterations': 8,
    'batch_size': 64,
    'num_selfplay_iterations': 500,
    'num_epochs': 4,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.3
}

deepzero = DeepZero(model, optimizer, game, args)
deepzero.learn()


  0%|          | 0/500 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [9]:
game = ConnectFour()
player = 1
device = torch.device("cpu")
args = {
    'C': 2,
    'num_search': 600,
    'num_iterations': 8,
    'batch_size': 64,
    'num_selfplay_iterations': 500,
    'num_epochs': 4,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.3
}
model = ResNet(game, 9, 32, device=device)
model.eval()
mcts = MCTS(game, args, model)
state = game.get_initial_state()

while True:
    print(state)
    if player == 1:
        valid_moves = game.get_valid_moves(state)
        print("val_movies", [i for i in range(game.action_size) if valid_moves[i] == 1])
        action = int(input(f"{player}: "))

        if valid_moves[action] == 0:
            print("action not val")
            continue
    else:
        neutral_state = game.change_perspective(state, player)
        mcts_probs = mcts.search(neutral_state)
        action = np.argmax(mcts_probs)

    state = game.get_next_state(state, action, player)

    value, is_terminate = game.get_value_and_terminated(state, action)

    if is_terminate:
        if value == 1:
            print(player, "win")
        else:
             print(player, "lose")
        break

    player = game.get_opponent(player)

[[0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]]
val_movies [0, 1, 2, 3, 4, 5, 6]
[[0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1.]]
[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0. -1.  0.  0.  1.]]
val_movies [0, 1, 2, 3, 4, 5, 6]


KeyboardInterrupt: Interrupted by user

In [None]:
while True:
    print(state)
    if player == 1:
        valid_moves = tictactoe.get_valid_moves(state)
        neutral_state = tictactoe.change_perspective(state, player)
        mcts_probs = mcts.search(neutral_state)
        action = np.argmax(mcts_probs)

        if valid_moves[action] == 0:
            print("action not val")
            continue
    else:
        neutral_state = tictactoe.change_perspective(state, player)
        mcts_probs = mcts.search(neutral_state)
        action = np.argmax(mcts_probs)

    state = tictactoe.get_next_state(state, action, player)

    value, is_terminate = tictactoe.get_value_and_terminated(state, action)

    if is_terminate:
        if value == 1:
            print(player, "win")
        else:
            print(player, "lose")
        break

    player = tictactoe.get_opponent(player)

In [3]:
game.get_next_state(state, player)

TypeError: ConnectFour.get_next_state() missing 1 required positional argument: 'player'