In [2]:
import chess
import chess_game_tracker

In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Paths to the token vocabulary and to the trained checkpoint prefix
# (the weights are read from model_filepath + "_state.pt").
tokens_filepath = "tokens.txt"
model_filepath = "v1_model"

# Model hyperparameters. They determine parameter shapes, so they must match the
# values the checkpoint was trained with or load_state_dict will fail.
n_embd = 400      # embedding / residual-stream width
n_head = 4        # attention heads per transformer block
n_layer = 4       # number of transformer blocks
block_size = 64   # maximum context length (tokens) the model can attend over
dropout = 0.0     # dropout probability; 0.0 makes the dropout layers no-ops
eval_iters = 100  # not referenced by the inference code visible in this notebook

# Tensors in this notebook are created on the CPU.
device = 'cpu'

In [4]:
# Read the move-token vocabulary, one token per line.
with open(tokens_filepath, "r") as vocab_file:
    file_tokens = vocab_file.read().splitlines()

# Prepend two special tokens so that "ZERO" gets id 0 and "START" gets id 1.
tokens = ["ZERO", "START"] + file_tokens

# Bidirectional lookup tables between token strings and integer ids. A token's id
# is its position in `tokens`; if a string occurs more than once, the last
# occurrence wins in token_to_int.
token_to_int = {token: index for index, token in enumerate(tokens)}
int_to_token = dict(enumerate(tokens))


def encode(lst):
    """Map a list of token strings to their integer ids (KeyError on unknown tokens)."""
    return list(map(token_to_int.__getitem__, lst))


def decode(lst):
    """Map a list of integer ids back to token strings (KeyError on unknown ids)."""
    return list(map(int_to_token.__getitem__, lst))


vocab_size = len(tokens)
print(vocab_size)

360


In [5]:
class Head(nn.Module):
    """One causal self-attention head.

    Projects the input to keys, queries and values of width ``head_size`` and lets
    each position attend only to itself and earlier positions.
    """

    def __init__(self, head_size):
        super().__init__()
        # Attribute/buffer names below are state_dict keys; renaming them would
        # stop the saved checkpoint from loading.
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones: entry (i, j) is 1 when position i may attend to j.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` (B, T, n_embd) and return a (B, T, head_size) tensor."""
        _, seq_len, channels = x.shape
        keys = self.key(x)       # (B, T, head_size)
        queries = self.query(x)  # (B, T, head_size)

        # NOTE(review): scores are scaled by channels**-0.5 where `channels` is the
        # input width (n_embd), not head_size as in standard scaled dot-product
        # attention. The loaded checkpoint was presumably trained with this scaling,
        # so it is kept unchanged to avoid altering the model's outputs.
        scores = queries @ keys.transpose(-2, -1) * channels ** -0.5  # (B, T, T)

        # Causal mask: future positions get -inf so softmax assigns them zero weight.
        visible = self.tril[:seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))  # (B, T, T)

        values = self.value(x)  # (B, T, head_size)
        return weights @ values  # (B, T, head_size)

class MultiHeadAttention(nn.Module):
    """Several causal self-attention heads run in parallel.

    Head outputs are concatenated along the channel axis and projected back to
    ``n_embd`` so the result can be added to the residual stream.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_size wide. Sizing the
        # projection from that (instead of assuming it equals n_embd) keeps it valid
        # when n_embd is not divisible by the head count. For the configured values
        # (4 heads * 100 == 400) the weight shape is unchanged, so checkpoints load.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on ``x`` and return the projected concatenation (B, T, n_embd)."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, num_heads * head_size)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """Position-wise MLP: expand to 4 * n_embd, apply ReLU, project back, then dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        # The layer order fixes the state_dict keys (net.0.*, net.2.*), so it must
        # stay Linear, ReLU, Linear, Dropout for the saved checkpoint to load.
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP at every position; the output has the same shape as ``x``."""
        return self.net(x)

class Block(nn.Module):
    """Pre-LayerNorm transformer block: self-attention then MLP, each as a residual branch."""

    def __init__(self, n_embd, n_head):
        super().__init__()
        # Each head gets an equal (floor-divided) share of the embedding width.
        per_head = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, per_head)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return ``x`` updated by the attention and feed-forward sublayers (shape preserved)."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))

# Despite its (historical, kept for compatibility) name, this is a GPT-style
# decoder-only transformer, not a bigram model.
class BigramLanguageModel(nn.Module):
    """Decoder-only transformer over move tokens.

    Token embeddings and learned position embeddings are summed, passed through
    ``n_layer`` transformer blocks and a final LayerNorm, then projected to
    next-token logits over the vocabulary.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) tensor of token ids. T must be <= block_size, because the
                position embedding table has only block_size rows.
            targets: optional (B, T) tensor of target token ids.

        Returns:
            (logits, loss). Without targets, logits is (B, T, vocab_size) and loss
            is None. With targets, logits is flattened to (B*T, vocab_size) and
            loss is F.cross_entropy over all positions.
        """
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx)  # (B, T, C)
        # Build position ids on the input's device rather than the global `device`,
        # so the model keeps working after being moved with .to(...).
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)  # (T, C)
        x = tok_emb + pos_emb  # (B, T, C)
        x = self.blocks(x)  # (B, T, C)
        x = self.ln_f(x)  # (B, T, C)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            # reshape (unlike view) also accepts non-contiguous tensors.
            logits = logits.reshape(B * T, C)
            targets = targets.reshape(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

In [6]:
def generate(model, idx, max_new_tokens, game):
    """Sample up to ``max_new_tokens`` legal moves from ``model`` and play them on ``game``.

    At each step the next-token logits are restricted to the moves ``game``
    reports as legal. One move is sampled, applied to ``game``, and appended to
    the context.

    Args:
        model: the language model, called as ``model(idx)`` -> (logits, loss).
        idx: (B, T) tensor of token ids forming the current context. Legal-move
            masking is driven by the single ``game``, so B is expected to be 1.
        max_new_tokens: maximum number of moves to sample.
        game: the game tracker for the position; it is mutated in place via
            ``game.make_move``.

    Returns:
        The context extended with the sampled token ids. Generation stops early,
        returning what has been produced, once ``game`` reports no legal moves.
    """
    # Inference only: skip building an autograd graph on every forward pass.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens the model can attend to.
            idx_cond = idx[:, -block_size:]
            logits, _ = model(idx_cond)
            logits = logits[:, -1, :]  # (B, vocab_size): predictions for the next token

            legal_moves = encode(game.encoded_legal_moves())
            if not legal_moves:
                # No legal moves left: return the context generated so far.
                return idx

            # Give every illegal token a large negative logit so softmax assigns it
            # (effectively) zero probability. A boolean mask replaces a per-token
            # `in list` scan, which was O(vocab_size * len(legal_moves)).
            illegal = torch.ones(logits.shape[-1], dtype=torch.bool, device=logits.device)
            illegal[legal_moves] = False
            logits[:, illegal] = -100000

            probs = F.softmax(logits, dim=-1)  # (B, vocab_size)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)

            next_id = idx_next[0, 0].item()
            assert next_id in legal_moves

            game.make_move(decode([next_id])[0])

            # Append the sampled id to the running sequence.
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
    return idx

In [7]:
model = BigramLanguageModel()

# Report the model size in millions of parameters.
param_count = sum(param.numel() for param in model.parameters())
print(param_count / 1e6, 'M parameters')

8.01076 M parameters


In [8]:
# torch.load unpickles the file, so only load checkpoints from a trusted source.
# map_location lets a checkpoint saved from a GPU load onto the configured device.
checkpoint_state = torch.load(model_filepath + "_state.pt", map_location=device)
# Inference only: switch to evaluation mode (disables dropout).
model.eval()
model.load_state_dict(checkpoint_state)

<All keys matched successfully>

In [9]:
game = chess_game_tracker.Game()
# Seed the (1, 1) context with the START token, looked up rather than hard-coding its id.
context = torch.tensor([[token_to_int["START"]]], dtype=torch.long, device=device)
generated_ids = generate(model, idx=context, max_new_tokens=100, game=game)
print(decode(generated_ids[0].tolist()))
print(game.get_game_pgn())

['START', 'Pawn:e2:TwoPush', 'Pawn:g7:Push', 'Pawn:d2:TwoPush', 'N1:-1:-2', 'Pawn:c2:Push', 'Pawn:e7:Push', 'BD:g5', 'BD:e7', 'BD:f6', 'BD:f6', 'BL:c4', 'Pawn:c7:Push', 'N2:-2:1', 'Pawn:d7:TwoPush', 'Pawn:e4:CaptureLeft', 'Pawn:c6:CaptureRight', 'BL:b3', 'BL:d7', 'K:2:0', 'N2:-2:-1', 'N2:2:1', 'BD:g5', 'Pawn:h2:Push', 'BL:c6', 'Pawn:a2:Push', 'Pawn:b7:TwoPush', 'N1:2:1', 'K:2:0', 'Q:f3', 'N1:1:-2', 'Q:g4', 'BD:h6', 'N1:1:2', 'Pawn:f7:TwoPush', 'Q:d1', 'Pawn:f5:CaptureLeft', 'N2:-1:2', 'Q:e8', 'N2:1:-2', 'R1:2:0', 'N2:-1:2', 'Q:d7', 'Pawn:g2:Push', 'Pawn:e6:Push', 'N2:2:1', 'K:0:-1', 'N2:-2:1', 'R1:2:0', 'K:1:1', 'Pawn:h7:TwoPush', 'N2:-2:1', 'K:1:-1', 'Q:d3', 'Q:e6', 'Pawn:g3:Push', 'Q:f7', 'Q:g3', 'Pawn:a7:TwoPush', 'R1:1:0', 'N1:2:-1', 'Q:f3', 'R2:1:0', 'R2:2:0', 'Pawn:a5:Push', 'R2:-4:0', 'Pawn:a4:CaptureRight', 'R2:0:1', 'N1:-2:1', 'Q:d1', 'N2:1:-2', 'Pawn:h3:Push', 'R2:1:0', 'Q:f3', 'Pawn:h5:CaptureLeft', 'Q:e2', 'N2:-1:2', 'K:-1:0', 'Pawn:g4:Push', 'K:0:1', 'N2:1:-2', 'K:0:-1', '

In [25]:
def EncodeMoveSequence(move_list):
    """Replay ``move_list`` on a fresh game and return it along with its token ids.

    Args:
        move_list: moves accepted by ``Game.make_dec_move`` (the callers in this
            notebook pass ``chess.Move`` objects).

    Returns:
        (game, ids): the game after all moves are played, and the integer ids of
        ``['START'] + game.enc_moves``.
    """
    replayed = chess_game_tracker.Game()
    for move in move_list:
        replayed.make_dec_move(move)
    token_seq = ['START'] + replayed.enc_moves
    return replayed, encode(token_seq)

def GetWeightedMovesFromModel(mdl, move_list):
    """Score every vocabulary token as the next move after ``move_list``.

    Args:
        mdl: the language model, called as ``mdl(ids)`` -> (logits, loss).
        move_list: moves played so far; replayed via ``EncodeMoveSequence``.

    Returns:
        A list of ``(token, move_or_None, logit)`` tuples covering the whole
        vocabulary, sorted by descending logit. ``move_or_None`` is
        ``game.decode_possible_move(token)`` when the token is legal in the
        current position, otherwise ``None``.
    """
    game, enc_moves = EncodeMoveSequence(move_list)

    # Keep only the most recent block_size tokens (the model's context limit)
    # and add a batch dimension: shape (1, T).
    context = torch.tensor(enc_moves[-block_size:], dtype=torch.long, device=device)[None, :]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits, _ = mdl(context)
    next_logits = logits[0, -1, :].tolist()  # next-token logits, one per vocabulary id

    # A set gives O(1) membership checks for each of the vocab_size lookups below.
    legal = set(game.encoded_legal_moves())

    # Stable sort by descending logit.
    ranked = sorted(enumerate(next_logits), key=lambda item: -item[1])

    final_list = []
    for i, val in ranked:
        token = int_to_token[i]
        move = game.decode_possible_move(token) if token in legal else None
        final_list.append((token, move, val))
    return final_list

def GetTopLegalMove(moves_and_weights):
    """Return the first non-None move in ``moves_and_weights``, or ``None`` if there is none.

    ``moves_and_weights`` holds ``(token, move_or_None, weight)`` tuples. When it is
    ordered best-first (as ``GetWeightedMovesFromModel`` returns it), the result is
    the highest-scoring legal move.
    """
    candidates = (move for _, move, _ in moves_and_weights if move is not None)
    return next(candidates, None)

In [27]:
# Ask the model for its highest-scoring legal move after 1. e4 e5.
opening = [chess.Move.from_uci(uci) for uci in ('e2e4', 'e7e5')]
moves_and_weights = GetWeightedMovesFromModel(model, opening)
print(GetTopLegalMove(moves_and_weights))

g1f3


In [29]:
# Greedy self-play: repeatedly play the model's highest-scoring legal move.
max_plies = 100
moves = []

for _ in range(max_plies):
    moves_and_weights = GetWeightedMovesFromModel(model, moves)
    move = GetTopLegalMove(moves_and_weights)
    if move is None:
        # No legal move is available, so the game is over. Appending None would
        # pass it to make_dec_move on the next replay and to variation_san below.
        break
    moves.append(move)

print(chess.Board().variation_san(moves))
    

1. e4 e5 2. Nf3 Nc6 3. Bc4 Bc5 4. c3 Nf6 5. d4 exd4 6. cxd4 Bb4+ 7. Nc3 Nxe4 8. O-O Nxc3 9. bxc3 Be7 10. d5 Na5 11. Bd3 O-O 12. Qa4 b6 13. Bd2 Bf6 14. Rae1 d6 15. c4 Nb7 16. Re2 Nc5 17. Qc2 Nxd3 18. Qxd3 Bg4 19. Rfe1 Qd7 20. Bc3 Bxc3 21. Qxc3 Bxf3 22. Qxf3 Rae8 23. h3 Rxe2 24. Qxe2 g6 25. Qe4 Kg7 26. a4 a5 27. Qd4+ Kg8 28. Qd1 Re8 29. Rf1 Re4 30. Qd3 Qe7 31. Qc3 Qe8 32. Qe3 Rxe3 33. Kh2 Re6 34. f4 Qe7 35. f5 g5 36. Re1 Qd8 37. Rc1 Re3 38. Kg1 c6 39. c5 dxc5 40. Rc3 Rd3 41. Rxc5 Rc3 42. Kf2 Rc2+ 43. Ke3 Rc3+ 44. Ke4 Rc2 45. Rb5 Rc3 46. Ke5 Rc4 47. g4 h5 48. gxh5 Qd6+ 49. Kxd6 Kg7 50. h4 Kg8
