In [80]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# Input files.
# tokens_filepath: vocabulary file, read below as one token per line.
# intermediate_filepath: games file, read below as one game per line with
# whitespace-separated tokens.
# NOTE(review): these are absolute paths on one specific Windows machine; the
# notebook will not run elsewhere until they are made configurable/relative.
tokens_filepath = "C:\\Users\\anton\\Documents\\code\\chessGPT\\data\\tokens.txt"
intermediate_filepath = "C:\\Users\\anton\\Documents\\code\\chessGPT\\data\\intermediate.txt"

# Seed torch's global RNG so random sampling is repeatable across runs.
torch.manual_seed(1337)


# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
# NOTE(review): max_iters, eval_interval, eval_iters, n_embd, n_head, n_layer
# and dropout are not referenced by any code visible in this part of the
# notebook -- presumably kept for a larger model; confirm or remove.
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3 # intended optimizer step size -- TODO confirm it is the value actually passed to the optimizer
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu' # batches are moved to this device in get_batch
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0
# ------------

print (device)

cpu


In [81]:
# Load the vocabulary (one token per line) and put the "SPECIAL" placeholder
# in front of it, so that "SPECIAL" ends up at position 0 of `tokens`.
with open(tokens_filepath, "r") as token_file:
    tokens = ["SPECIAL", *token_file.read().splitlines()]

# print (tokens)

In [82]:
# Lookup tables between token strings and integer ids; a token's id is its
# position in `tokens` (for a repeated token the last position wins).
int_to_token = dict(enumerate(tokens))
token_to_int = {token: index for index, token in int_to_token.items()}
    
# print (token_to_int)

In [102]:
def encode(lst):
    """Translate a sequence of token strings into their integer ids.

    Each token is looked up in the module-level `token_to_int` table; a token
    that is not in the table raises KeyError.
    """
    ids = []
    for token in lst:
        ids.append(token_to_int[token])
    return ids
def decode(lst):
    """Translate a sequence of integer ids back into their token strings.

    Each id is looked up in the module-level `int_to_token` table; an id that
    is not in the table raises KeyError.
    """
    tokens_out = []
    for token_id in lst:
        tokens_out.append(int_to_token[token_id])
    return tokens_out

In [84]:
# Vocabulary size: number of entries in `tokens`, which includes the
# "SPECIAL" entry placed at the front of the list.
vocab_size = len(tokens)
vocab_size

359

In [85]:
# Sanity check: encode and display the first three lines (games) of the file.
with open(intermediate_filepath, "r") as games_file:
    for line_no in range(3):
        game_tokens = games_file.readline().split()
        print(torch.tensor(encode(game_tokens), dtype=torch.long))


tensor([167, 184,  80,  83,  42, 164,  38, 144,  34,  48,  48, 269, 148, 175,
        169, 136, 266, 276,  18,  83,  84,  85, 257,  79,  86,  79, 266, 104,
         92, 124,  74,  20, 326,  74, 267, 267, 320, 117, 318, 331,  94,   6,
        127,  20, 332,  99, 333, 237, 312,  23, 305, 313, 347, 224, 207, 235,
         70,  69, 227, 216, 210,  30, 358,  69, 108, 340,  14, 353,  67, 204,
        131, 338, 331, 196,  72,  69, 114, 233,  69, 231,  72,  71,  19,  67,
        117,  66,  26,  65, 120,  71])
tensor([147, 164, 207, 183,  57,  83,  84,  20,  74,  74, 331, 103, 168, 144,
        174,  86, 151,  85, 152,  11,  80,  81,  93,  20,  97,  87, 107,  90,
         84,  79, 348,  15, 341,  48,  22,  43,  15, 268,  78, 326,  80, 331,
        319,  34, 109, 259, 239, 243, 266, 311,  79, 250,  80, 266,  80, 264,
        333, 256, 350, 242, 345,  66,  75,  65,  81,  71,  81, 263,  53, 266,
         70, 101, 338,  99, 355,  70,  77,  66,  76,  65,  79, 258,  78, 257,
         82,  71,  75,  7

In [86]:
# Encode every game (one per line) and keep only those with more than
# `block_size` tokens, i.e. long enough to cut an input window of
# `block_size` tokens plus the same window shifted by one.
with open(intermediate_filepath, "r") as games_file:
    candidate_games = (encode(row.split()) for row in games_file)
    encoded_games = [moves for moves in candidate_games if len(moves) > block_size]

In [87]:
# Quick look at the filtered dataset: number of games, the length of two
# sample games, one full encoded game, and the shortest game length.
print(len(encoded_games))
for sample_index in (100, 200):
    print(len(encoded_games[sample_index]))
print(encoded_games[300])
print(min(map(len, encoded_games)))

43978
54
55
[84, 237, 148, 183, 167, 163, 80, 143, 264, 160, 46, 140, 84, 136, 169, 20, 39, 48, 272, 79, 43, 122, 84, 269, 244, 83, 22, 74, 272, 140, 275, 87, 187, 23, 272, 136, 79, 180, 226, 176, 228, 331, 80, 153, 326, 5, 348, 10, 72, 157, 258, 325, 272, 71, 214, 223, 345, 66, 72, 354, 207, 193, 191, 339, 216, 72, 70, 283, 70, 303, 258, 315, 319, 324, 319, 17, 66, 318, 325, 285]
33


In [88]:
# Positional train/validation split over whole games.
# NOTE(review): the games are split in file order; no shuffling is visible
# before this point -- confirm that file order is not biased.
n = int(0.9*len(encoded_games)) # first 90% will be train, rest val
train_data = encoded_games[:n]
val_data = encoded_games[n:]

In [89]:
def get_block_from_game(game, block_sz, rnd, offset):
    """Cut a window of `block_sz` tokens out of `game` as a long tensor.

    The window start is `rnd` wrapped into range(len(game) - block_sz) and
    then shifted right by `offset`.

    Args:
        game: sequence of integer token ids; expected to be longer than
            `block_sz`.
        block_sz: number of tokens in the returned window.
        rnd: integer (or 0-dim integer tensor) used to choose the start.
        offset: extra shift applied to the start (callers use 0 for inputs
            and 1 for targets).

    Returns:
        1-D tensor of dtype torch.long holding the selected tokens.

    Raises:
        AssertionError: if the slice does not contain exactly `block_sz`
            tokens (diagnostics are printed first).
    """
    i = rnd % (len(game) - block_sz)
    begin = i + offset
    window = game[begin : begin + block_sz]
    if len(window) != block_sz:
        # Dump everything needed to reproduce the bad slice before failing.
        print (game)
        print (window)
        print (i, offset, block_sz)
        assert False
    return torch.tensor(window, dtype=torch.long)

# data loading
def get_batch(split, block_sz, batch_sz):
    """Sample a random batch of input/target windows from one data split.

    Args:
        split: 'train' selects `train_data`; any other value selects
            `val_data`.
        block_sz: number of tokens per window.
        batch_sz: number of windows in the batch.

    Returns:
        (x, y): two (batch_sz, block_sz) long tensors on `device`, where each
        row of y is the corresponding row of x shifted one token to the right
        within the same game.
    """
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    # Column 0 picks a game, column 1 is a raw random number that
    # get_block_from_game wraps into a valid window start.  randint's upper
    # bound is exclusive, so len(data) makes every game reachable.
    picks = torch.randint(len(data), (batch_sz, 2)).tolist()
    # Index into `data` (the requested split), not the full game list,
    # so validation batches really come from the validation games.
    x = torch.stack([get_block_from_game(data[g], block_sz, r, 0) for g, r in picks])
    y = torch.stack([get_block_from_game(data[g], block_sz, r, 1) for g, r in picks])
    x, y = x.to(device), y.to(device)
    return x, y

In [90]:
# Draw one training batch and print the tensor shapes as a sanity check;
# both should be (batch_size, block_size).
xb, yb = get_batch('train', block_sz=block_size, batch_sz=batch_size)

print('inputs:')
print(xb.shape)
print('targets:')
print(yb.shape)

print('----')


inputs:
torch.Size([16, 32])
targets:
torch.Size([16, 32])
----


In [116]:
import chess
import chess_encode
import chess_decode


In [138]:
class Game:
    """A chess game tracked in parallel by python-chess and the token codec.

    Holds a `chess.Board`, an `EncoderBoard` and a `DecoderBoard`, plus the
    list of moves played so far, and applies every move to all of them.
    """

    def __init__(self):
        """Start from the initial position with no moves played."""
        self.board = chess.Board()
        self.enc = chess_encode.EncoderBoard()
        self.dec = chess_decode.DecoderBoard()
        self.moves = []

    def encoded_legal_moves(self):
        """Return the encoder's token for every legal move on the board.

        Moves are encoded with `make_move=False` -- presumably so the encoder
        state is not advanced; confirm against chess_encode.
        """
        encoded = []
        for candidate in self.board.legal_moves:
            encoded.append(self.enc.EncodeMove(candidate, make_move=False))
        return encoded

    def make_move(self, enc_move):
        """Decode `enc_move`, play it on encoder and board, and return the move."""
        played = self.dec.DecodeMove(enc_move)
        # Keep the encoder in step with the board before pushing the move.
        self.enc.EncodeMove(played, make_move=True)
        self.board.push(played)
        self.moves.append(played)
        return played

    def get_game_pgn(self):
        """Return the moves played so far in SAN, replayed from a fresh board."""
        start_position = chess.Board()
        return start_position.variation_san(self.moves)

In [108]:
# Re-seed torch's global RNG with the same value used at the top of the notebook.
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    """Bigram model: next-token scores depend only on the current token."""

    def __init__(self, vocab_size):
        """Create a (vocab_size x vocab_size) table of next-token logits."""
        super().__init__()
        # Row i of the table holds the logits for whatever follows token i.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Look up next-token logits and, if targets are given, the loss.

        Args:
            idx: (B, T) tensor of integer token ids.
            targets: optional (B, T) tensor of integer token ids.

        Returns:
            (logits, loss).  Without targets, logits has shape (B, T, C) and
            loss is None.  With targets, logits is flattened to (B*T, C) and
            loss is the cross-entropy against the flattened targets.
        """
        scores = self.token_embedding_table(idx)  # (B, T, C)

        if targets is None:
            return scores, None

        # cross_entropy wants (N, C) scores and (N,) class ids.
        batch, time, channels = scores.shape
        flat_scores = scores.view(batch * time, channels)
        flat_targets = targets.view(batch * time)
        return flat_scores, F.cross_entropy(flat_scores, flat_targets)


In [92]:
# Build the bigram model and run the sample batch through it once.
# Because targets are passed, the returned logits are flattened to
# (batch * block, vocab_size) and a cross-entropy loss is returned.
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)


torch.Size([512, 359])
tensor(6.3635, grad_fn=<NllLossBackward0>)


In [None]:
# create a PyTorch optimizer; the step size comes from the `learning_rate`
# hyperparameter so that tuning it in the config cell actually takes effect
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)


In [145]:
# Training loop: one optimizer step per randomly sampled training batch.
# NOTE(review): the step count (50001) and the logging period (1000) are
# hard-coded here rather than taken from the hyperparameter cell.
for steps in range(50001): # increase number of steps for good results...
    # sample a batch of data
    xb, yb = get_batch('train', block_sz=block_size, batch_sz=batch_size)

    # evaluate the loss
    logits, loss = m(xb, yb)
    # clear old gradients, backpropagate, then update the parameters
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    # the printed value is the loss of this single training batch
    if steps % 1000 == 0:
        print(steps, loss.item())

0 4.798133850097656
1000 4.724464416503906
2000 4.956279754638672
3000 4.824542999267578
4000 4.892591953277588
5000 4.910529136657715
6000 4.863946437835693
7000 4.724982738494873
8000 4.740623950958252
9000 4.950060844421387
10000 4.908806800842285
11000 4.703491687774658
12000 4.864297866821289
13000 4.888840198516846
14000 4.853829383850098
15000 4.6967692375183105
16000 4.667713165283203
17000 4.667600154876709
18000 4.944421768188477
19000 4.769054889678955
20000 4.633237361907959
21000 4.810161590576172
22000 4.897270202636719
23000 4.900995254516602
24000 4.930427551269531
25000 4.825490951538086
26000 4.874955177307129
27000 4.862794876098633
28000 4.8221917152404785
29000 4.822365760803223
30000 4.739196300506592
31000 4.750077724456787
32000 4.767266750335693
33000 4.927359580993652
34000 4.75062894821167
35000 4.845069885253906
36000 4.9533820152282715
37000 4.647576332092285
38000 4.791723251342773
39000 4.736194133758545
40000 4.6152663230896
41000 4.74415922164917
42000 

In [142]:
def generate(model, idx, max_new_tokens, game):
    """Sample up to `max_new_tokens` legal moves from `model`, playing them on `game`.

    At every step the scores of all tokens that are not legal in the current
    position of `game` are masked out before sampling, so only legal moves
    can be drawn.  Generation stops early once `game` reports no legal moves.

    Args:
        model: callable returning `(logits, loss)` for a (B, T) index tensor,
            with logits of shape (B, T, C).
        idx: (B, T) long tensor with the context so far.  Only batch row 0 is
            masked and played on `game`, so B is expected to be 1.
        max_new_tokens: maximum number of tokens to append.
        game: object providing `encoded_legal_moves()` and
            `make_move(token)`; it is advanced by every sampled move.

    Returns:
        (B, T + k) tensor: `idx` with the k sampled tokens appended.
    """
    # Pure inference: skip building an autograd graph on every step.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # get the predictions
            logits, _loss = model(idx)
            # focus only on the last time step; clone so the model's output
            # is not modified by the masking below
            logits = logits[:, -1, :].clone()  # becomes (B, C)

            legal_moves = encode(game.encoded_legal_moves())
            if len(legal_moves) == 0:
                return idx

            # Mask every token that is not a legal move.  The mask is sized
            # from the logits themselves; -inf gives exactly zero probability.
            illegal = torch.ones(logits.size(-1), dtype=torch.bool, device=logits.device)
            legal_ids = torch.tensor(legal_moves, dtype=torch.long, device=logits.device)
            illegal[legal_ids] = False
            logits[0].masked_fill_(illegal, float("-inf"))

            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1)  # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)

            chosen = int(idx_next[0, 0])
            assert chosen in legal_moves

            game.make_move(decode([chosen])[0])

            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
    return idx

In [146]:
# Generate one game of up to 100 moves.  The context starts as a single
# token with id 0 (the "SPECIAL" token); the sampled ids are decoded back to
# token strings, and the moves actually played are printed in SAN.
game = Game()
print(decode(generate(m, idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100, game=game)[0].tolist()))
print (game.get_game_pgn())

['SPECIAL', 'Pawn:g2:TwoPush', 'N1:1:-2', 'N2:-1:2', 'N1:1:-2', 'N2:-1:2', 'Pawn:h7:Push', 'R2:-1:0', 'R1:1:0', 'N2:-2:-1', 'Pawn:f7:TwoPush', 'Pawn:h2:Push', 'Pawn:g7:TwoPush', 'Pawn:b2:TwoPush', 'N1:2:-1', 'Pawn:e2:CaptureRight', 'Pawn:e7:TwoPush', 'BL:g2', 'BD:c5', 'Pawn:h3:Push', 'Pawn:c7:Push', 'N1:1:2', 'N2:-2:-1', 'N1:1:2', 'Pawn:d7:Push', 'Pawn:d2:Push', 'N2:-1:-2', 'N2:2:1', 'N2:-1:2', 'BD:g5', 'N2:2:-1', 'K:0:1', 'R2:-2:0', 'R1:2:0', 'BL:d7', 'Q:d2', 'R1:1:0', 'R2:-2:0', 'R2:0:-1', 'R1:1:0', 'K:1:0', 'R2:2:0', 'R2:0:-1', 'N2:2:1', 'K:0:-1', 'N2:-2:-1', 'K:-1:1', 'BD:h6', 'R2:1:0', 'K:1:-1', 'R2:0:-1', 'Q:g5', 'Q:a5', 'Pawn:a2:Push', 'BD:d4', 'R1:1:0', 'Pawn:b7:TwoPush', 'Pawn:a3:Push', 'R1:1:0', 'R2:1:0', 'N2:1:-2', 'BD:g7', 'BD:b2', 'R1:0:1', 'Q:b6', 'N2:-1:2', 'BD:e5', 'N2:-2:-1', 'R1:0:-1', 'R1:-1:0', 'R1:-1:0', 'K:-1:0', 'R1:1:0', 'BD:f6', 'BD:f6', 'Pawn:h4:Push', 'R1:-2:0', 'N2:2:1', 'N2:-2:1', 'Q:f6', 'R1:2:0', 'Q:d6', 'R1:2:0', 'K:-1:0', 'Pawn:a7:Push', 'R2:0:1', 'R1:-