In [8]:
!pip install nbimporter



In [None]:
import chess
import random
import numpy as np
import model_training
import nbimporter

# Dictionary to represent chess pieces in a 12-dimensional one-hot format
chess_dict = {
    'p': [1,0,0,0,0,0,0,0,0,0,0,0], 'P': [0,0,0,0,0,0,1,0,0,0,0,0],
    'n': [0,1,0,0,0,0,0,0,0,0,0,0], 'N': [0,0,0,0,0,0,0,1,0,0,0,0],
    'b': [0,0,1,0,0,0,0,0,0,0,0,0], 'B': [0,0,0,0,0,0,0,0,1,0,0,0],
    'r': [0,0,0,1,0,0,0,0,0,0,0,0], 'R': [0,0,0,0,0,0,0,0,0,1,0,0],
    'q': [0,0,0,0,1,0,0,0,0,0,0,0], 'Q': [0,0,0,0,0,0,0,0,0,0,1,0],
    'k': [0,0,0,0,0,1,0,0,0,0,0,0], 'K': [0,0,0,0,0,0,0,0,0,0,0,1],
    '.': [0,0,0,0,0,0,0,0,0,0,0,0],
}

def make_matrix(board):
    pgn = board.epd().split(" ", 1)[0]
    rows = pgn.split("/")
    matrix = []

    for row in rows:
        expanded_row = []
        for char in row:
            if char.isdigit():
                expanded_row.extend(['.'] * int(char))
            else:
                expanded_row.append(char)
        matrix.append(expanded_row)
    
    return matrix

def translate(matrix, chess_dict):
    translated = []
    for row in matrix:
        translated_row = [chess_dict[piece] for piece in row]
        translated.append(translated_row)
    return translated

def calculate_move(depth, board, epochs, model, minimum=0, maximum=1):
    legal_moves = list(board.legal_moves)
    scores = [0] * len(legal_moves)

    for epoch in range(epochs):
        for i, move in enumerate(legal_moves):
            try:
                temp_board = board.copy()
                temp_board.push(move)
                
                for _ in range(depth):
                    inner_legal_moves = list(temp_board.legal_moves)
                    if not inner_legal_moves:
                        scores[i] *= 1000  # Penalize dead ends
                        break
                    random_move = random.choice(inner_legal_moves)
                    temp_board.push(random_move)
                
                matrix = make_matrix(temp_board)
                translated = np.array(translate(matrix, chess_dict))
                prediction = model.predict(translated.reshape(1, 8, 8, 12))
                scores[i] += prediction * (maximum - minimum) + minimum

            except Exception as e:
                print(f"Error with move {move}: {e}")
                continue

        print(f"Epoch {epoch + 1}/{epochs}")
    
    best_move_index = np.argmax(scores)
    return legal_moves[best_move_index]

def play_game(board, model, minimum=0, maximum=1):
    while True:
        print(board)
        print("Legal moves:", list(board.legal_moves))
        
        user_move = input("Which move do you want to play? ")
        try:
            board.push_san(user_move)
        except:
            print("Invalid move format or illegal move. Try again.")
            continue

        print("Board after your move:")
        print(board)

        # Show model prediction for your move
        matrix = make_matrix(board)
        translated = np.array(translate(matrix, chess_dict))
        prediction = model.predict(translated.reshape(1, 8, 8, 12))
        print("Model prediction (after your move):", prediction)

        # Model makes its move
        ai_move = calculate_move(10, board, 7, model, minimum, maximum)
        print("AI plays:", ai_move)
        board.push(ai_move)

# Example usage:

model = model_training.load_model_from_files()
board = chess.Board()
play_game(board, model)

r n b q k b n r
p p p p p p p p
. . . . . . . .
. . . . . . . .
. . . . . . . .
. . . . . . . .
P P P P P P P P
R N B Q K B N R
Legal moves: [Move.from_uci('g1h3'), Move.from_uci('g1f3'), Move.from_uci('b1c3'), Move.from_uci('b1a3'), Move.from_uci('h2h3'), Move.from_uci('g2g3'), Move.from_uci('f2f3'), Move.from_uci('e2e3'), Move.from_uci('d2d3'), Move.from_uci('c2c3'), Move.from_uci('b2b3'), Move.from_uci('a2a3'), Move.from_uci('h2h4'), Move.from_uci('g2g4'), Move.from_uci('f2f4'), Move.from_uci('e2e4'), Move.from_uci('d2d4'), Move.from_uci('c2c4'), Move.from_uci('b2b4'), Move.from_uci('a2a4')]
