In [93]:
# Import Statements:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Device Management:
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS is available and set as device.")
else:
    print("MPS is not available on this system.")

MPS is available and set as device.


In [94]:
# Creating Training Data:
with open('chess-moves.txt', 'r') as f:
    text = f.read()

text = text.split()
len(text)

161611731

In [95]:
# Creating Vocabulary and Move Dicts:
moves = set(text)
vocab_size = len(moves)

itom = {i:m for i,m in enumerate(moves)}
mtoi = {m:i for i,m in enumerate(moves)}

itom[2]

'Reg7+'

In [96]:
# Creating Tokenizers:
encode = lambda l: [mtoi[m] for m in l]
decode = lambda l: [itom[n] for n in l]

encode(['Bd4+', 'Rd4', 'd4'])

[1871, 8233, 9584]

In [97]:
# Creating Data Tensor:
data = torch.tensor(encode(text), dtype=torch.long)
data = data.to(device)

data[:20]

tensor([10242,  6666,  9584, 12089,  3217,  5093,  6189,  7502, 10869,  6300,
         6288, 11108,  2861,  3224,  2861, 10693, 10242,  5308,  6654,  6654],
       device='mps:0')

In [98]:
# Creating Hyperparameters:
vocab_size = 12356
n_embd = 384
head_size = 16
n_layer = 4
n_head = 4
batch_size = 16
block_size = 128
dropout = 0.2

# Single Head of Attention:
class Head(nn.Module):

    def __init__(self, head_size):
        super().__init__()

        # K,Q,V Matrices:
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)

        # Buffer Matrix and Dropout Layer:
        self.register_buffer('tril', torch.tril(torch.ones([block_size, block_size])))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        
        k = self.key(x)
        q = self.query(x)

        # Determining Affinities with Weighted Sum:
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        # Adjusting Embedding With Value Matrix:
        v = self.value(x)
        out = wei @ v
        return out

# Parralelization of Attention Heads:
class MultiHeadedAttention(nn.Module):

    def __init__(self, head_size, n_head):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(n_head)])

        # Projection and Dropout Layers:
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

# Multi-Layer Perceptron:
class FeedForward(nn.Module):

    def __init__(self, n_embd):
        super().__init__()

        # Linear Layers:
        self.net = nn.Sequential(
            nn.Linear(n_embd, n_embd * 4),
            nn.GELU(),
            nn.Linear(n_embd * 4, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

# Self-Attention/MLP Block:
class Block(nn.Module):

    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head

        # Self-Attention/MLP:
        self.sa = MultiHeadedAttention(head_size, n_head)
        self.ffwd = FeedForward(n_embd)

        # Layer Normalization:
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    # Residual Blocks:
    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

# AMP Transformer Model:
class ChessGPT(nn.Module):

    def __init__(self):
        super().__init__()

        # Token and Positional Embedding Tables:
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)

        # Block Layers:
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])

        # Layer Normalization and Unembedding:
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B,T = idx.shape

        # Embedding:
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb

        # Creating Logits after Forward Pass:
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        # Determining Loss via Cross Entropy
        if targets == None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx):

        # Generate New Move:
        idx_cond = idx[:, -block_size:]
        logits, loss = self(idx_cond)
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)
        idx_new = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, idx_new], dim=1)
        return idx
        
# Initializing Model
m = ChessGPT()
m = m.to(device)
m = torch.compile(m)

# Creating Optimizer:
optimizer = torch.optim.AdamW(m.parameters(), lr=3e-4)

list(m.parameters())[0]

Parameter containing:
tensor([[-1.1402, -1.3340, -0.8495,  ...,  1.2330, -0.9163, -1.3038],
        [-0.2831,  0.2614, -1.6091,  ...,  0.4453,  0.0858, -0.7578],
        [-1.7222, -0.4916,  0.1578,  ..., -0.3442, -2.2328, -0.2045],
        ...,
        [-0.0936, -1.1812,  1.8925,  ..., -0.6833, -1.2142,  0.4356],
        [-1.8069, -1.9777,  0.9817,  ..., -0.2323,  0.1621, -0.9676],
        [ 0.3115,  0.3212, -0.1303,  ...,  1.9315, -1.1168, -0.5739]],
       device='mps:0', requires_grad=True)

In [99]:
# Untrained Example:
game = ['e4']

idx = torch.tensor([encode(game)], dtype=torch.long).to(device)

decode(m.generate(idx)[0].tolist())

['e4', 'Rbxa2']

In [100]:
# Batching Data:
def get_batch():
    ix = torch.randint(len(data) - block_size, (batch_size,))
    xb = torch.stack([data[i:i+block_size] for i in ix])
    yb = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return xb, yb

xb, yb = get_batch()
xb.shape, yb.shape

(torch.Size([16, 128]), torch.Size([16, 128]))

In [108]:
import time
start_time = time.time()

steps = 10000

# Training Loop:
for step in range(steps):
    xb, yb = get_batch()
    logits, loss = m(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 1000 == 0:
        print(f'step: {step} loss: {loss:.4f}')

end_time = time.time()
total = end_time - start_time

print(f'total loss: {loss:.4f} total time: {total:.4f} seconds')    

step: 0 loss: 3.3490
step: 1000 loss: 3.3577
step: 2000 loss: 3.2876
step: 3000 loss: 3.2888
step: 4000 loss: 3.2059
step: 5000 loss: 3.2145
step: 6000 loss: 3.2807
step: 7000 loss: 3.2221
step: 8000 loss: 3.1621
step: 9000 loss: 3.2097
total loss: 3.1522 total time: 2082.9466 seconds


In [109]:
game = ['e4',]
idx = torch.tensor([encode(game)], dtype=torch.long).to(device)

decode(m.generate(idx)[0].tolist())[-1]

'e5'

In [110]:
import chess
import chess.svg
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
import io
import base64

def get_move_history_cleaned(board):
    move_history = []
    temp_board = chess.Board()
    
    for move in board.move_stack:
        san_move = temp_board.san(move)
        cleaned_move = san_move.replace('+', '').replace('#', '').replace('x', '')
        move_history.append(cleaned_move)
        temp_board.push(move)
    
    return move_history

def clean_san_move(san_move):
    cleaned = san_move.replace('+', '').replace('#', '').replace('x', '')
    return cleaned

def new_move(moves):

    previous_moves = get_move_history_cleaned(board)

    legal_moves_san = []
    for move in board.legal_moves:
        san_move = board.san(move)
        legal_moves_san.append(san_move)

    while True:
        idx = torch.tensor([encode(previous_moves)], dtype=torch.long).to(device)
        move = decode(m.generate(idx)[0].tolist())[-1]
        print(move)
        move = clean_san_move(move)
        if move in legal_moves_san:
            print('correct!')
            break
        
    return move

# Initialize chess board
board = chess.Board()
move_history = []

# Create widgets
move_input = widgets.Text(
    value='',
    description='Your move:',
    style={'description_width': 'initial'}
)

submit_button = widgets.Button(
    description='Play Move',
    button_style='success'
)

reset_button = widgets.Button(
    description='Reset Game',
    button_style='warning'
)

output_area = widgets.Output()
status_label = widgets.HTML(value="<b>White to play. Enter your move above.</b>")

def display_board():
    """Display the current board position"""
    with output_area:
        clear_output(wait=True)
        
        # Create SVG of the board
        svg = chess.svg.board(board=board, size=400)
        
        # Display the SVG
        display(HTML(f'<div style="text-align: center;">{svg}</div>'))
        
        # Show move history
        if move_history:
            history_str = " ".join([f"{i//2 + 1}.{move}" if i % 2 == 0 else move 
                                  for i, move in enumerate(move_history)])
            print(f"\nMove history: {history_str}")

def play_move(button):
    """Handle user move and generate AI response"""
    global move_history
    
    user_move = move_input.value.strip()
    
    if not user_move:
        status_label.value = "<b style='color: red;'>Please enter a move!</b>"
        return
    
    try:
        # Parse and validate user move
        move = board.parse_san(user_move)
        
        # Make the user move
        board.push(move)
        move_history.append(user_move)
        
        # Check game state after user move
        if board.is_game_over():
            display_board()
            result = board.result()
            status_label.value = f"<b style='color: blue;'>Game Over! Result: {result}</b>"
            move_input.disabled = True
            submit_button.disabled = True
            return
        
        status_label.value = "<b style='color: orange;'>AI is thinking...</b>"
        display_board()
        
        # Generate AI move
        try:
            ai_move_san = new_move(move_history)  # Your transformer function
            ai_move = board.parse_san(ai_move_san)
            
            # Make AI move
            board.push(ai_move)
            move_history.append(ai_move_san)
            
            # Check game state after AI move
            if board.is_game_over():
                display_board()
                result = board.result()
                status_label.value = f"<b style='color: blue;'>Game Over! Result: {result}</b>"
                move_input.disabled = True
                submit_button.disabled = True
            else:
                status_label.value = f"<b>AI played: {ai_move_san}. Your turn!</b>"
                display_board()
            
        except Exception as e:
            status_label.value = f"<b style='color: red;'>AI error: {str(e)}</b>"
            # Undo user move if AI fails
            board.pop()
            move_history.pop()
            display_board()
    
    except ValueError as e:
        status_label.value = f"<b style='color: red;'>Invalid move: {str(e)}</b>"
    
    # Clear input
    move_input.value = ""

def reset_game(button):
    """Reset the game to starting position"""
    global move_history
    board.reset()
    move_history = []
    move_input.disabled = False
    submit_button.disabled = False
    move_input.value = ""
    status_label.value = "<b>White to play. Enter your move above.</b>"
    display_board()

def on_enter_key(change):
    """Handle Enter key press in text input"""
    if change['type'] == 'change' and change['name'] == 'value':
        # Small delay to allow the value to be processed
        import time
        time.sleep(0.1)
        if move_input.value.strip():
            play_move(None)

# Connect event handlers
submit_button.on_click(play_move)
reset_button.on_click(reset_game)

# Create the interface
interface = widgets.VBox([
    status_label,
    widgets.HBox([move_input, submit_button, reset_button]),
    output_area
])

# Initial board display
display_board()

# Display the interface
display(interface)


VBox(children=(HTML(value='<b>White to play. Enter your move above.</b>'), HBox(children=(Text(value='', descr…