In [None]:
# Run this code as a .py script rather than in a Jupyter notebook; otherwise multiprocessing will not work properly.
from agent import Agent, pit
from model_files.SLPolicyValueGPU import SLPolicyValueNetwork
import torch
import chess


# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model1 = SLPolicyValueNetwork().to(device)
# Map the checkpoint tensors onto the selected device. Hard-coding "cuda" here
# made loading fail on CPU-only machines even though `device` falls back to CPU.
model1.load_state_dict(torch.load("checkpoint_stockfish_only.pth", map_location=device)["model"])
# The network is only used for inference below: switch off training-mode
# behaviour (dropout, batch-norm statistic updates) in case it has such layers.
model1.eval()

# NOTE(review): dirichlet_epsilon > 0 presumably mixes exploration noise into the
# root priors, which would make the evaluation cells below non-deterministic —
# confirm against Agent and consider dirichlet_epsilon=0.0 for evaluation runs.
agent = Agent(policy_value_network=model1, c_puct=2.0, dirichlet_alpha=0.3, dirichlet_epsilon=0.25)

In [2]:
# Sanity-check the agent on a handful of known positions.
#
# Starting position: neither side has an advantage, so the evaluation
# should come out roughly equal.
board = chess.Board()
print(
    agent.select_move(
        game_state=board,
        temperature=0,
        num_simulations=100,
    )
)

[('e2e4', np.float64(17.0)), ('d2d4', np.float64(12.0)), ('g1f3', np.float64(10.0)), ('b1c3', np.float64(9.0)), ('b1a3', np.float64(9.0)), ('c2c4', np.float64(9.0)), ('a2a4', np.float64(8.0)), ('a2a3', np.float64(6.0)), ('f2f4', np.float64(6.0)), ('c2c3', np.float64(5.0)), ('g1h3', np.float64(2.0)), ('h2h4', np.float64(2.0)), ('b2b4', np.float64(2.0)), ('h2h3', np.float64(1.0)), ('g2g4', np.float64(1.0)), ('g2g3', np.float64(0.0)), ('f2f3', np.float64(0.0)), ('e2e3', np.float64(0.0)), ('d2d3', np.float64(0.0)), ('b2b3', np.float64(0.0))]
1.0
e2e4


In [7]:
# King + queen vs. lone king with White (the stronger side) to move:
# the evaluation should be positive, i.e. winning for the side to move.
board = chess.Board("6k1/8/8/8/8/3Q4/8/K7 w - - 0 1")
print(
    agent.select_move(
        game_state=board,
        temperature=0,
        num_simulations=100,
    )
)

[('d3b5', np.float64(95.0)), ('d3d8', np.float64(1.0)), ('d3h7', np.float64(1.0)), ('d3c4', np.float64(1.0)), ('a1b1', np.float64(1.0)), ('d3d7', np.float64(0.0)), ('d3g6', np.float64(0.0)), ('d3d6', np.float64(0.0)), ('d3a6', np.float64(0.0)), ('d3f5', np.float64(0.0)), ('d3d5', np.float64(0.0)), ('d3e4', np.float64(0.0)), ('d3d4', np.float64(0.0)), ('d3h3', np.float64(0.0)), ('d3g3', np.float64(0.0)), ('d3f3', np.float64(0.0)), ('d3e3', np.float64(0.0)), ('d3c3', np.float64(0.0)), ('d3b3', np.float64(0.0)), ('d3a3', np.float64(0.0)), ('d3e2', np.float64(0.0)), ('d3d2', np.float64(0.0)), ('d3c2', np.float64(0.0)), ('d3f1', np.float64(0.0)), ('d3d1', np.float64(0.0)), ('d3b1', np.float64(0.0)), ('a1b2', np.float64(0.0)), ('a1a2', np.float64(0.0))]
1.0
d3b5


In [8]:
# Same king + queen vs. king position, but with the lone Black king to move:
# the evaluation should be negative, i.e. losing for the side to move.
board = chess.Board("6k1/8/8/8/8/3Q4/8/K7 b - - 0 1")
print(
    agent.select_move(
        game_state=board,
        temperature=0,
        num_simulations=100,
    )
)

[('g8f8', np.float64(65.0)), ('g8g7', np.float64(33.0)), ('g8h8', np.float64(1.0)), ('g8f7', np.float64(0.0))]
1.0
g8f8


In [9]:
# White to move; the knight on c3 can capture Black's queen on d5 (c3d5),
# so the expected move is c3d5 and the evaluation should be positive
# (somewhat winning for White).
board = chess.Board("rnb1kb1r/ppp1pppp/5n2/3q4/8/2N5/PPPP1PPP/R1BQKBNR w KQkq - 2 4")
print(
    agent.select_move(
        game_state=board,
        temperature=0,
        num_simulations=100,
    )
)

[('d2d4', np.float64(14.0)), ('d1h5', np.float64(12.0)), ('g1f3', np.float64(11.0)), ('h2h4', np.float64(11.0)), ('f1e2', np.float64(10.0)), ('f1d3', np.float64(8.0)), ('a2a3', np.float64(8.0)), ('d1f3', np.float64(5.0)), ('f1a6', np.float64(4.0)), ('f1c4', np.float64(3.0)), ('d1e2', np.float64(3.0)), ('c3d5', np.float64(2.0)), ('f1b5', np.float64(2.0)), ('g2g4', np.float64(2.0)), ('a2a4', np.float64(2.0)), ('g1e2', np.float64(1.0)), ('h2h3', np.float64(1.0)), ('c3b5', np.float64(0.0)), ('c3e4', np.float64(0.0)), ('c3a4', np.float64(0.0)), ('c3e2', np.float64(0.0)), ('c3b1', np.float64(0.0)), ('g1h3', np.float64(0.0)), ('e1e2', np.float64(0.0)), ('d1g4', np.float64(0.0)), ('a1b1', np.float64(0.0)), ('g2g3', np.float64(0.0)), ('f2f3', np.float64(0.0)), ('d2d3', np.float64(0.0)), ('b2b3', np.float64(0.0)), ('f2f4', np.float64(0.0)), ('b2b4', np.float64(0.0))]
1.0
d2d4


In [10]:
# Black to move; the queen on d8 can capture White's queen on d4 (d8d4),
# so the expected move is d8d4 and the evaluation should be positive
# (somewhat winning for Black).
board = chess.Board("rnbqkb1r/ppp1pppp/5n2/8/3Q4/8/PPP2PPP/RNB1KBNR b KQkq - 2 4")
print(
    agent.select_move(
        game_state=board,
        temperature=0,
        num_simulations=100,
    )
)

[('g7g6', np.float64(11.0)), ('e7e6', np.float64(11.0)), ('b8a6', np.float64(10.0)), ('h7h6', np.float64(10.0)), ('c7c6', np.float64(10.0)), ('c7c5', np.float64(10.0)), ('a7a5', np.float64(10.0)), ('h8g8', np.float64(8.0)), ('c8d7', np.float64(5.0)), ('b7b6', np.float64(4.0)), ('f6h5', np.float64(3.0)), ('d8d7', np.float64(2.0)), ('c8e6', np.float64(2.0)), ('b8d7', np.float64(1.0)), ('h7h5', np.float64(1.0)), ('b7b5', np.float64(1.0)), ('d8d6', np.float64(0.0)), ('d8d5', np.float64(0.0)), ('d8d4', np.float64(0.0)), ('c8f5', np.float64(0.0)), ('c8g4', np.float64(0.0)), ('c8h3', np.float64(0.0)), ('b8c6', np.float64(0.0)), ('f6g8', np.float64(0.0)), ('f6d7', np.float64(0.0)), ('f6d5', np.float64(0.0)), ('f6g4', np.float64(0.0)), ('f6e4', np.float64(0.0)), ('a7a6', np.float64(0.0)), ('g7g5', np.float64(0.0)), ('e7e5', np.float64(0.0))]
1.0
g7g6


In [11]:
# The previous position after Black played d8d4: White has lost the queen,
# so with White to move the evaluation should be negative (somewhat losing).
board = chess.Board("rnb1kb1r/ppp1pppp/5n2/8/3q4/8/PPP2PPP/RNB1KBNR w KQkq - 0 5")
print(
    agent.select_move(
        game_state=board,
        temperature=0,
        num_simulations=100,
    )
)

[('b1c3', np.float64(12.0)), ('g1f3', np.float64(10.0)), ('f1e2', np.float64(9.0)), ('h2h4', np.float64(9.0)), ('c1h6', np.float64(8.0)), ('c1e3', np.float64(8.0)), ('f1d3', np.float64(7.0)), ('c1g5', np.float64(7.0)), ('c1f4', np.float64(7.0)), ('a2a3', np.float64(7.0)), ('f1a6', np.float64(3.0)), ('f1c4', np.float64(3.0)), ('c2c3', np.float64(3.0)), ('g1h3', np.float64(1.0)), ('g1e2', np.float64(1.0)), ('f1b5', np.float64(1.0)), ('h2h3', np.float64(1.0)), ('g2g4', np.float64(1.0)), ('a2a4', np.float64(1.0)), ('e1e2', np.float64(0.0)), ('c1d2', np.float64(0.0)), ('b1a3', np.float64(0.0)), ('b1d2', np.float64(0.0)), ('g2g3', np.float64(0.0)), ('f2f3', np.float64(0.0)), ('b2b3', np.float64(0.0)), ('f2f4', np.float64(0.0)), ('c2c4', np.float64(0.0)), ('b2b4', np.float64(0.0))]
1.0
b1c3


In [12]:
# White to move; the rook on e6 can capture Black's rook on e7 (e6e7),
# so the expected move is e6e7 and the evaluation should be positive
# (somewhat winning for White).
board = chess.Board("r6k/pp2r2p/4Rp1Q/3p4/8/1N1P2b1/PqP3PP/7K w - - 0 25")
print(
    agent.select_move(
        game_state=board,
        temperature=0,
        num_simulations=100,
    )
)

[('a2a3', np.float64(17.0)), ('h2h4', np.float64(17.0)), ('h2h3', np.float64(12.0)), ('a2a4', np.float64(12.0)), ('h6g5', np.float64(10.0)), ('c2c3', np.float64(10.0)), ('c2c4', np.float64(10.0)), ('h6e3', np.float64(9.0)), ('h6f8', np.float64(1.0)), ('h6g7', np.float64(1.0)), ('h6h7', np.float64(0.0)), ('h6g6', np.float64(0.0)), ('h6f6', np.float64(0.0)), ('h6h5', np.float64(0.0)), ('h6h4', np.float64(0.0)), ('h6f4', np.float64(0.0)), ('h6h3', np.float64(0.0)), ('h6d2', np.float64(0.0)), ('h6c1', np.float64(0.0)), ('e6e7', np.float64(0.0)), ('e6f6', np.float64(0.0)), ('e6d6', np.float64(0.0)), ('e6c6', np.float64(0.0)), ('e6b6', np.float64(0.0)), ('e6a6', np.float64(0.0)), ('e6e5', np.float64(0.0)), ('e6e4', np.float64(0.0)), ('e6e3', np.float64(0.0)), ('e6e2', np.float64(0.0)), ('e6e1', np.float64(0.0)), ('b3c5', np.float64(0.0)), ('b3a5', np.float64(0.0)), ('b3d4', np.float64(0.0)), ('b3d2', np.float64(0.0)), ('b3c1', np.float64(0.0)), ('b3a1', np.float64(0.0)), ('h1g1', np.float64(