## Imports

In [None]:
import torch
import torch.nn as nn
from model import Neuro_gambit

## ELO init

In [None]:
elo = 2_000  # rating bucket; selects which ./data tensors and ./models checkpoint are used

## Device init

In [None]:
# Pick the accelerator PyTorch can see; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Always-available CPU handle, independent of the choice above.
cpu = torch.device('cpu')

## Loading Tensors

In [None]:
# Load the pre-built training tensors for this ELO bucket straight onto `device`.
# map_location is required so tensors that were saved from a GPU can still be
# loaded on a CPU-only machine (a plain torch.load would raise there).
X = torch.load('./data/X_tensor_'+str(elo)+'.pt', map_location=device)
Y = torch.load('./data/Y_tensor_'+str(elo)+'.pt', map_location=device)
print(X.shape)
print(Y.shape)

# separating the Y: columns [0:8], [8:16], [16:24], [24:32] and everything from
# column 32 onward form five target groups. Presumably there is one group per
# output head of the model (the loss cells pair them index-by-index with
# model(X)'s outputs) -- TODO confirm against Neuro_gambit.forward.
Y1 = Y[:, :8]
Y2 = Y[:, 8:16]
Y3 = Y[:, 16:24]
Y4 = Y[:, 24:32]
Y5 = Y[:, 32:]

Y = [Y1, Y2, Y3, Y4, Y5]

## Model class init

In [None]:
# Training hyper-parameters.
learning_rate = 1e-3
n_epochs = 1_000_000

# Build the network directly on the selected device.
model = Neuro_gambit().to(device)

# Mean-squared-error loss (applied per output group by the loss cells) and the
# AdamW optimizer, which performed far better than plain SGD for this model.
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

## Loading saved model

In [None]:
# Restore previously trained weights for this ELO bucket.
# load_state_dict takes the deserialized dict, not the file path. map_location
# lets a checkpoint that was saved on a GPU load on a CPU-only machine
# (and vice versa).
checkpoint = torch.load('./models/neuro_gambit_'+str(elo)+'.pt', map_location=device)
model.load_state_dict(checkpoint)
# Inference mode for the evaluation that follows; this affects layers such as
# dropout/batch-norm if Neuro_gambit has any.
model.eval()

## Current Loss

In [None]:
# Evaluate the current model on the whole dataset. The result becomes the
# baseline a training epoch must beat before a checkpoint is written.
min_loss = 1  # fallback value; replaced by the measured loss below
with torch.no_grad():
    y_preds = model(X)  # will output a tuple of 5 tensors, one per group in Y

    # Sum of the per-group MSE losses (prediction i is paired with Y[i]).
    total_loss = sum(criterion(y_pred, target) for y_pred, target in zip(y_preds, Y))

    print('Current loss:', f'{total_loss.item()*100:.3f}%')
    # Store a plain Python float, matching what the training loop stores, so
    # later comparisons never mix floats with (possibly GPU) tensors.
    min_loss = total_loss.item()


## Training

In [None]:
# Full-batch training: each epoch is one forward/backward pass over all of X.
# Whenever an epoch's loss beats the best seen so far, the weights are saved.
model.train()  # the checkpoint-loading cell leaves the model in eval mode
# min_loss comes from the evaluation cell and may be a float or a 0-d tensor;
# normalise it to a float so every comparison below is float < float.
min_loss = float(min_loss)

for epoch in range(n_epochs):
    # forward
    y_preds = model(X)  # will output a tuple of 5 tensors, one per group in Y
    total_loss = sum(criterion(y_pred, target) for y_pred, target in zip(y_preds, Y))
    loss_value = total_loss.item()

    # Checkpoint BEFORE the optimizer step, so the saved weights are exactly
    # the ones that produced loss_value (after step() they would already have
    # moved on).
    if loss_value < min_loss:
        min_loss = loss_value
        print('model saved, loss:', min_loss)
        torch.save(model.state_dict(), './models/neuro_gambit_'+str(elo)+'.pt')

    # backward
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if (epoch+1) % 5 == 0:
        print(f'Epoch [{epoch+1}/{n_epochs}], Loss: {loss_value:.4f}', end='\r')

## Saving the model

In [None]:
# Write the current weights to this ELO bucket's checkpoint, overwriting any
# existing file.
checkpoint_path = './models/neuro_gambit_' + str(elo) + '.pt'
weights = model.state_dict()
torch.save(weights, checkpoint_path)
print('Model saved')

In [None]:
# Playing a game
from model import get_best_move
import chess
import chess.svg
import matplotlib.pyplot as plt
from cairosvg import svg2png
import cv2
from IPython.display import clear_output

def draw_board(current_board, ai_col_chess):
    """Draw board

    Render `current_board` into a new matplotlib figure (call plt.show() to
    display it). The board is flipped when the AI plays White, so the human
    side is shown at the bottom.

    Keyword arguments:
    current_board -- chess.Board() to render
    ai_col_chess -- colour the AI plays (chess.WHITE or chess.BLACK)

    The SVG is rasterised to ./boards/board.png (the directory is created if
    missing) and read back with OpenCV.
    Raises FileNotFoundError if the rendered PNG cannot be read back.
    from https://colab.research.google.com/github/iAmEthanMai/chess-engine-model/blob/main/python_chess_engine.ipynb#scrollTo=yveIUxzjUr2b
    """
    import os

    board_img = chess.svg.board(current_board, flipped=ai_col_chess == chess.WHITE)
    os.makedirs('./boards', exist_ok=True)
    svg2png(bytestring=board_img, write_to='./boards/board.png')

    img = cv2.imread('./boards/board.png', cv2.IMREAD_COLOR)
    if img is None:  # cv2.imread signals failure by returning None, not raising
        raise FileNotFoundError('could not read rendered board image ./boards/board.png')
    # OpenCV loads pixels in BGR order while matplotlib expects RGB; without
    # this conversion the red and blue channels (square colours) are swapped.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.axis('off')
    ax.imshow(img)


def get_algebraic_notation(move_stack):
    """Translate a sequence of chess.Move objects into SAN strings.

    SAN depends on the position a move is played from, so the moves are
    replayed one by one on a fresh chess.Board(). This assumes move_stack
    starts from the standard initial position.

    Keyword arguments:
    move_stack -- iterable of chess.Move, e.g. board.move_stack
    Returns a list with one SAN string per move, in order.
    """
    replay = chess.Board()

    def to_san(move):
        # SAN has to be computed before the move is applied to the board.
        notation = replay.san(move)
        replay.push(move)
        return notation

    return [to_san(move) for move in move_stack]


# Interactive game: the model plays `ai_col`, the human types moves in SAN.
board = chess.Board()
ai_col = 'black'
ai_col_chess = chess.BLACK if ai_col == 'black' else chess.WHITE
last = None
# Play in inference mode; the training loop may have left the model in train mode.
model.eval()
# clear_output() wipes anything printed before it, so a rejected move is
# remembered here and reported after the next redraw instead.
bad_input = None

while not board.is_game_over():
    # render game
    clear_output()
    alg_move_stack = get_algebraic_notation(board.move_stack)
    print(" ".join(alg_move_stack))
    print("Last AI move:", alg_move_stack[-1] if alg_move_stack != [] else '', last)
    if bad_input is not None:
        print('Illegal move:', bad_input)
        bad_input = None
    draw_board(board, ai_col_chess)
    plt.show()

    # handle moves
    if board.turn == ai_col_chess:
        move_prob = get_best_move(model, board, ai_col)
        last = move_prob
        board.push_uci(move_prob['move'])
    else:
        user_input = input()  # your input in algebraic
        try:
            board.push_san(user_input)
        except ValueError:
            # push_san raises IllegalMoveError, InvalidMoveError (unparseable
            # text) or AmbiguousMoveError; all derive from ValueError, and a
            # typo must not end the game.
            bad_input = user_input

# Show the final position, then report the result. Checkmate is not the only
# way is_game_over() ends the loop: stalemate, insufficient material, the
# 75-move rule and fivefold repetition are draws, so "side to move loses" is
# not a valid test.
clear_output()
print(" ".join(get_algebraic_notation(board.move_stack)))
draw_board(board, ai_col_chess)
plt.show()
outcome = board.outcome()
if outcome.winner is None:
    print('Draw:', outcome.termination.name)
else:
    print('Winner:', 'white' if outcome.winner == chess.WHITE else 'black')
