In [1]:
import chess_game_tracker

In [2]:
import os
import time

import torch
import torch.nn as nn
from torch.nn import functional as F

# Data and checkpoint locations, relative to the notebook's working directory.
# Built with os.path.join so the separators match the host OS: on Windows the
# resulting strings are identical to the previous "..\\data_v2\\..." literals,
# while on POSIX hosts a literal backslash would be part of a single file name.
data_dir = os.path.join("..", "data_v2")
tokens_filepath = os.path.join(data_dir, "tokens.txt")
intermediate_filepath = os.path.join(data_dir, "intermediate_non_bullet.txt")
# Base name for checkpoints; the saving cells append ".pt" / "_state.pt".
model_filepath = os.path.join(data_dir, "model_2014_non_bullet")

# hyperparameters
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 64 # what is the maximum context length for predictions?
eval_interval = 400 # training steps between loss estimates
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 100 # batches averaged per split when estimating the loss
n_embd = 800 # embedding width
n_head = 8 # attention heads per transformer block
n_layer = 8 # number of transformer blocks
dropout = 0.0
# ------------

# Fix the RNG so parameter initialisation and batch sampling are repeatable.
torch.manual_seed(1337)

print (device)

cuda


In [3]:
# Read the vocabulary file: one token per line.
with open(tokens_filepath, "r") as f:
    tokens = f.read().splitlines()

# The two special tokens are placed first, so they get ids 0 and 1.
tokens = ["ZERO", "START"] + tokens

# Two-way lookup between token strings and integer ids (id == list position).
# If a token string appears more than once, its last position wins in
# token_to_int, exactly as with sequential assignment.
token_to_int = {token: i for i, token in enumerate(tokens)}
int_to_token = dict(enumerate(tokens))
    
def encode(lst):
    """Map an iterable of token strings to a list of integer ids.

    Raises KeyError if a token is not in ``token_to_int``.
    """
    ids = []
    for token in lst:
        ids.append(token_to_int[token])
    return ids
def decode(lst):
    """Map an iterable of integer ids back to a list of token strings.

    Raises KeyError if an id is not in ``int_to_token``.
    """
    names = []
    for token_id in lst:
        names.append(int_to_token[token_id])
    return names

# Number of distinct token ids; this count includes the "ZERO" and "START"
# entries that were prepended to the list read from the tokens file.
vocab_size = len(tokens)
print (vocab_size)

598


In [4]:
# Each line of the intermediate file is one game: whitespace-separated tokens.
# Every game is prefixed with START and encoded to integer ids. Games shorter
# than block_size + 1 tokens are dropped, because a training sample needs
# block_size inputs plus one extra token for the shifted targets.
min_game_len = block_size + 1
with open(intermediate_filepath, "r") as f:
    candidates = (encode(["START"] + line.split()) for line in f)
    encoded_games = [g for g in candidates if len(g) >= min_game_len]

In [5]:
# Quick look at what was loaded: number of games, two sample lengths,
# one full encoded game, and the shortest game that survived the filter.
print(len(encoded_games))
for sample_idx in (100, 200):
    print(len(encoded_games[sample_idx]))
print(encoded_games[300])
print(min(map(len, encoded_games)))

33206
143
73
[1, 335, 311, 148, 151, 314, 350, 374, 330, 355, 239, 378, 404, 361, 346, 379, 21, 394, 49, 63, 270, 448, 428, 23, 589, 591, 391, 397, 401, 19, 384, 59, 585, 597, 589, 589, 249, 158, 437, 457, 239, 175, 24, 461, 104, 468, 469, 464, 464, 252, 141, 82, 131, 74, 141, 294, 151, 237, 166, 222, 97, 239, 291, 318, 303, 15, 172, 8, 104, 222, 113, 216, 485, 4, 44, 7, 51, 66, 483, 16, 120, 233, 127, 526, 118, 12, 531, 531, 127, 364, 268, 25, 119, 29, 284, 296, 265, 223, 48, 206, 281, 196, 182, 206, 165, 212, 148, 273, 110, 75, 101, 33, 92, 29, 43, 33, 91, 29, 100, 33, 109, 29, 118, 33, 119, 29, 128, 74, 119, 75]
65


In [6]:
# Split by game, in file order: the first 90% of games train, the rest validate.
n = int(0.9*len(encoded_games))
train_data, val_data = encoded_games[:n], encoded_games[n:]

In [7]:
def get_block_from_game(game, block_sz, rnd, offset):
    """Cut a fixed-length window of token ids out of one encoded game.

    The window start is ``rnd`` reduced modulo ``len(game) - block_sz`` and then
    rounded down to an even index; ``offset`` shifts the whole window (0 is used
    for model inputs, 1 for the next-token targets). The even alignment is
    presumably there so that position 0 of an unshifted window always has the
    same move parity, matching the alternating whose-turn embedding in the model.

    Args:
        game: list of integer token ids for a single game.
        block_sz: number of tokens to return.
        rnd: integer used to choose the window start via modulo.
        offset: shift applied to both ends of the window.

    Returns:
        1-D ``torch.long`` tensor with ``block_sz`` elements.

    Raises:
        AssertionError: if the slice comes back shorter than ``block_sz``
            (the game and indices are printed first to help debugging).
    """
    start = rnd % (len(game) - block_sz)
    start = start - (start % 2)  # snap down to an even index
    assert start >= 0
    lo = start + offset
    window = game[lo : lo + block_sz]
    if len(window) != block_sz:
        # Dump the offending inputs before failing.
        print (game)
        print (window)
        print (start, offset, block_sz)
        assert False
    return torch.tensor(window, dtype=torch.long)

# data loading
def get_batch(split):
    """Sample a random batch of input/target windows from one data split.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects
            ``val_data``.

    Returns:
        ``(x, y)``: two ``(batch_size, block_size)`` long tensors on ``device``.
        Row ``k`` of ``y`` is the same window as row ``k`` of ``x`` shifted one
        token ahead within the same game, i.e. the next-token targets.
    """
    data = train_data if split == 'train' else val_data
    # One row per sample: column 0 picks the game, column 1 seeds the window start.
    # torch.randint's upper bound is exclusive, so len(data) (not len(data) - 1)
    # is needed for the last game of the split to be reachable.
    ix = torch.randint(len(data), (batch_size, 2)).tolist()
    # Index into `data`, the selected split. Indexing the full `encoded_games`
    # list here would make 'val' batches come from the training portion.
    x = torch.stack([get_block_from_game(data[g], block_size, rnd, 0) for g, rnd in ix])
    y = torch.stack([get_block_from_game(data[g], block_size, rnd, 1) for g, rnd in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [8]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss on the 'train' and 'val' splits.

    Puts ``model`` in eval mode, averages the loss over ``eval_iters`` freshly
    sampled batches per split, then switches ``model`` back to train mode.

    Returns:
        dict mapping ``'train'`` and ``'val'`` to a 0-dim tensor holding the
        mean loss for that split.
    """
    mean_losses = {}
    model.eval()
    for split in ('train', 'val'):
        per_batch = torch.zeros(eval_iters)
        for step in range(eval_iters):
            inputs, targets = get_batch(split)
            _, batch_loss = model(inputs, targets)
            per_batch[step] = batch_loss.item()
        mean_losses[split] = per_batch.mean()
    model.train()
    return mean_losses

In [9]:
class Head(nn.Module):
    """A single causal self-attention head.

    Projects the input to keys, queries and values of width ``head_size`` and
    lets each position attend only to itself and earlier positions.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular ones matrix used as the causal mask; registered as a
        # buffer so it follows the module across devices without being trained.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply causal attention to ``x`` of shape (B, T, C); returns (B, T, head_size)."""
        _, T, C = x.shape
        q = self.query(x)  # (B, T, head_size)
        k = self.key(x)    # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # Pairwise affinities between positions.
        # NOTE(review): the scale uses C, the input width (n_embd), whereas
        # scaled dot-product attention is conventionally scaled by the key
        # width (head_size). Left as-is so existing checkpoints behave the
        # same; confirm whether this is intended.
        scores = (q @ k.transpose(-2, -1)) * C**-0.5  # (B, T, T)

        # Hide future positions, then normalise each row into attention weights.
        future = self.tril[:T, :T] == 0
        scores = scores.masked_fill(future, float('-inf'))
        weights = self.dropout(F.softmax(scores, dim=-1))  # (B, T, T)

        # Weighted sum of the values.
        return weights @ v

class MultiHeadAttention(nn.Module):
    """Several attention heads run side by side, concatenated and projected.

    The output projection is ``n_embd -> n_embd``, so ``num_heads * head_size``
    has to equal ``n_embd`` for the concatenated head outputs to fit.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run every head on ``x`` and merge the results along the feature axis."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

class FeedFoward(nn.Module):
    """Position-wise MLP: expand to 4x width, ReLU, project back, dropout."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        out = self.net(x)
        return out

class Block(nn.Module):
    """One transformer layer: pre-norm self-attention, then pre-norm MLP.

    Both sub-layers are wrapped in residual connections.
    """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding width; n_head: number of attention heads,
        # each of width n_embd // n_head.
        super().__init__()
        per_head_width = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, per_head_width)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))

def create_alternating_tensor(length):
    """Return a 1-D long tensor ``[1, 0, 1, 0, ...]`` of ``length`` elements.

    Even indices hold 1 and odd indices hold 0; the tensor is created on the
    module-level ``device``.
    """
    positions = torch.arange(length, device=device)
    return (positions + 1) % 2
    
class GPTModel(nn.Module):
    """Decoder-only transformer over move tokens.

    The input representation is the sum of three learned embeddings: the token
    id, the absolute position in the window, and a two-valued "whose turn"
    embedding that alternates with position (index 1 at even positions, index 0
    at odd positions).
    """

    def __init__(self):
        super().__init__()
        # Learned lookup tables for token identity, position and turn parity.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.whose_turn_embedding_table = nn.Embedding(2, n_embd)

        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: (B, T) long tensor of token ids, with T at most ``block_size``
                (the position table has ``block_size`` rows).
            targets: optional (B, T) long tensor of next-token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits has shape
            (B, T, vocab_size) and loss is None. With targets, logits is
            flattened to (B*T, vocab_size) and loss is the mean cross-entropy.
        """
        B, T = idx.shape

        # Build the index tensors on the input's own device rather than the
        # module-level `device`, so the forward pass works wherever the model
        # and its inputs live.
        positions = torch.arange(T, device=idx.device)
        turn_ids = (positions + 1) % 2  # 1 at even positions, 0 at odd ones

        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(positions) # (T,C)
        trn_emb = self.whose_turn_embedding_table(turn_ids) # (T,C)
        x = tok_emb + pos_emb + trn_emb # (B,T,C); the (T,C) terms broadcast over B
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            # cross_entropy expects (N, classes) logits and (N,) targets.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

In [10]:
# Scratch check: adding a length-T vector to a (B, T) tensor broadcasts it
# across rows, which is how the per-position embeddings are combined with the
# per-token embeddings in the model.
ara = torch.arange(4, device=device)  # not used below
xxx = torch.full((4,), 100, dtype=torch.long, device=device)
yyy = torch.full((4,), 200, dtype=torch.long, device=device)
zzz = torch.stack((xxx, yyy))
alt = create_alternating_tensor(4)
print(zzz + alt)

tensor([[101, 100, 101, 100],
        [201, 200, 201, 200]], device='cuda:0')


In [11]:
def generate(model, idx, max_new_tokens, game):
    """Sample up to ``max_new_tokens`` legal moves and play them on ``game``.

    Args:
        model: callable returning ``(logits, loss)`` for a (1, T) id tensor,
            with logits of shape (1, T, vocab).
        idx: (1, T) long tensor holding the context so far.
        max_new_tokens: maximum number of moves to sample.
        game: game tracker; this function calls ``encoded_legal_moves()`` to
            get the legal move tokens and ``make_move(token)`` to play one.

    Returns:
        (1, T + k) long tensor: the context followed by the ``k`` sampled ids.
        ``k`` is smaller than ``max_new_tokens`` if the game runs out of legal
        moves first.
    """
    assert idx.dim() == 2
    assert idx.size(0) == 1

    # Sampling needs no gradients; skipping the autograd graph saves memory.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            legal_moves = encode(game.encoded_legal_moves())
            if len(legal_moves) == 0:
                return idx

            # Keep at most the last block_size tokens. The cut is moved to an
            # even index so the window keeps the same turn parity as the full
            # sequence. Dropping a fixed 2 tokens is only enough while the
            # sequence is at most block_size + 2 long.
            T = idx.size(1)
            start = max(T - block_size, 0)
            start += start % 2
            idx_cond = idx[:, start:]
            assert idx_cond.size(1) <= block_size

            # Predictions for the last time step only.
            logits, _ = model(idx_cond)  # (1, T, C)
            last_logits = logits[:, -1, :]  # (1, C)

            # Give illegal moves zero probability: everything starts at -inf
            # and only the legal ids get their real logits copied in.
            legal = torch.tensor(legal_moves, dtype=torch.long, device=last_logits.device)
            masked = torch.full_like(last_logits, float('-inf'))
            masked[:, legal] = last_logits[:, legal]

            probs = F.softmax(masked, dim=-1)  # (1, C)
            idx_next = torch.multinomial(probs, num_samples=1)  # (1, 1)

            move_id = idx_next.item()
            assert move_id in legal_moves

            game.make_move(decode([move_id])[0])

            # Append the sampled id to the running sequence.
            idx = torch.cat((idx, idx_next), dim=1)  # (1, T+1)
    return idx

In [12]:
# Build the network and move it to the selected device. `model` and `m` are
# both used by later cells.
model = GPTModel()
m = model.to(device)
# Report the model size in millions of parameters.
n_params = sum(p.numel() for p in m.parameters())
print(n_params/1e6, 'M parameters')



62.515798 M parameters


In [13]:
# AdamW over all model parameters, using the learning rate from the config cell.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)


In [None]:
# Training loop. It effectively runs until the cell is interrupted: progress is
# printed every `train_print` steps, the loss is estimated every `eval_interval`
# steps and a checkpoint is written every `checkpoint_interval` steps.
train_time = time.time()
train_print = 10
checkpoint_interval = eval_interval * 10

# The loop variable keeps the name `iter` (shadowing the builtin) because the
# follow-up evaluation cell reads it.
for iter in range(10000000 + 1):
    if iter > 0:
        if iter % train_print == 0:
            elapsed = time.time() - train_time
            print(f"step {iter}")
            print(round(elapsed / train_print, 2), "sec per train step")
            train_time = time.time()

        if iter % eval_interval == 0:
            print("estimating loss...")
            eval_time = time.time()
            losses = estimate_loss()
            print(round((time.time() - eval_time) / eval_iters, 2), "sec per eval step")
            print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
            # Restart the step timer so evaluation time is not billed to training.
            train_time = time.time()

        if iter % checkpoint_interval == 0:
            print("saving")
            torch.save(m, model_filepath + ".pt")
            torch.save(m.state_dict(), model_filepath + "_state.pt")
            train_time = time.time()

    # One optimisation step on a fresh training batch.
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 10
1.22 sec per train step
step 20
1.18 sec per train step
step 30
1.19 sec per train step
step 40
1.19 sec per train step
step 50
1.19 sec per train step
step 60
1.19 sec per train step
step 70
1.19 sec per train step
step 80
1.19 sec per train step
step 90
1.19 sec per train step
step 100
1.19 sec per train step
step 110
1.19 sec per train step
step 120
1.19 sec per train step
step 130
1.19 sec per train step
step 140
1.19 sec per train step
step 150
1.19 sec per train step
step 160
1.19 sec per train step
step 170
1.19 sec per train step
step 180
1.2 sec per train step
step 190
1.19 sec per train step
step 200
1.19 sec per train step
step 210
1.19 sec per train step
step 220
1.19 sec per train step


In [None]:
# Re-estimate the losses once training has been stopped. `iter` is the loop
# variable left behind by the training cell, so that cell must have run first.
losses = estimate_loss()
print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

In [None]:
# Sample the first 10 moves of a new game. The context starts as the single
# id 1, which is where "START" sits in the token list.
game = chess_game_tracker.Game()
context = torch.tensor([[1]], dtype=torch.long, device=device)
generated = generate(m, idx = context, max_new_tokens=10, game=game)
print(decode(generated[0].tolist()))
print(game.get_game_pgn())

In [None]:
# Save the model in two forms: the whole pickled module, and only its
# state_dict. Loading the pickled module requires the class definitions from
# this notebook to be importable; the state_dict can be loaded into a freshly
# constructed GPTModel.
torch.save(m, model_filepath + ".pt")
torch.save(m.state_dict(), model_filepath + "_state.pt")