### GPT for Learning how to play 2048

Inspired by [GPT from scratch](https://www.youtube.com/watch?v=kCc8FmEb1nY)

#### Explore the data

In [54]:
with open('2048_move_data.txt', 'r') as f:
    text = f.read()

print(text[:1000])
print(len(text))

00000100000002040100020000010000;R;01000001000002040000010200000001
00000000000001040201000000000202;U;02010104000002020001000000000000
00000001000000010003010000000101;R;01000001000000010000030100000002
00000000010000000100000201000106;R;00000000000001010000010200000206
00010000040100010001000300010001;L;01000000040200000103000002000100
00020100000101000101010304000401;U;01020203040201010000040000010000
00000101000001020100030400030101;U;01030201000003020000010400000101
01020003050000000500000002010201;R;00010203010000050000000502010201
00020000010400000103020000040400;D;00020100000400000003020002040400
00000001000000000000000100020000;U;00020002020000000000000000000000
04010709040203000000000002000002;U;05010709020203020000000000020000
00000201080000030001050101000400;L;02010001080300000105010001040000
00010200000106000000000400020000;R;00000102000001060100000400000002
01010000000100000400000301000200;D;00000000010100000400000001020203
00010201000000020100040001000002;D;0200000000000

#### Tokenize

Each tile is represented by 2 digit, 0-led int. Will convert tile into regular int. 

Each state is separated by comma, each game by ';\n'. Will convert these into int as well

In [55]:
import re
vocab = ['0'+str(i) for i in range(10)] + [str(i) for i in range(10, 18)]  + [';','\n','L','R','U','D']
mapping = {k: v for v, k in enumerate(vocab)}
inv_mapping = {v: k for k, v in mapping.items()}

def encode(s:str)-> list[str]:
    out = []
    pattern = r'\d{2}|[\n;RLUD]'
    out = re.findall(pattern, s)
    out = list(map(lambda k: mapping[k], out))
    return out

def decode(l:list[int]) -> str:
    s = ''.join(list(map(lambda v: str(inv_mapping[v]), l)))
    return s        


print(encode('01010001000000000003010202000001;R;00000102010000000003010200000201\n06000200010006000003000001000005;R;01000602000001060000000300000105'))
print(decode(encode('01010001000000000003010202000001;R;00000102010000000003010200000201\n06000200010006000003000001000005;R;01000602000001060000000300000105')))


[1, 1, 0, 1, 0, 0, 0, 0, 0, 3, 1, 2, 2, 0, 0, 1, 18, 21, 18, 0, 0, 1, 2, 1, 0, 0, 0, 0, 3, 1, 2, 0, 0, 2, 1, 19, 6, 0, 2, 0, 1, 0, 6, 0, 0, 3, 0, 0, 1, 0, 0, 5, 18, 21, 18, 1, 0, 6, 2, 0, 0, 1, 6, 0, 0, 0, 3, 0, 0, 1, 5]
01010001000000000003010202000001;R;00000102010000000003010200000201
06000200010006000003000001000005;R;01000602000001060000000300000105


#### Load data

Get train test split

Set up batches

In [56]:
import torch

torch.manual_seed(1748)

data = torch.tensor(encode(text), dtype=torch.long)
vocab_size = len(vocab)
batch_size = 64
line_size = 36
x_size = 18 # two boards * n
y_size = 18
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_iters = 5000
eval_iters = 200
eval_interval = num_iters//10
n_embed = 384
n_head = 6
n_layer = 6
dropout= 0.1
learning_rate = 3e-4

print(device)

# get first game in last 20% of data
n = int(0.8*len(data))
while n % line_size != 0:
    n += 1

train_data = data[:n]
test_data = data[n:]
print(n, len(data), len(train_data), len(test_data))

def get_batch(split:bool=0)-> list[torch.Tensor]:
    # split == 0: train, 1: test
    data = train_data if split == 0 else test_data
    ix = torch.randint(len(data)//line_size, (batch_size,))
    # print((46*line_size+line_size-1) - (46*line_size+x_size))
    x = torch.stack([data[i*line_size:i*line_size+x_size] for i in ix])
    y = torch.stack([data[i*line_size+x_size:i*line_size+line_size-1] for i in ix])
    target = torch.stack([data[i*line_size+x_size+1:i*line_size+line_size] for i in ix])
    x = x.to(device)
    y = y.to(device)
    target = target.to(device)
    return x, y, target

x, y, target = get_batch()

print(x)

print(y)

print(target)



cuda
8660844 10826028 8660844 2165184
tensor([[ 0,  0,  3,  ...,  1, 18, 23],
        [ 3,  2,  1,  ...,  0, 18, 20],
        [ 5,  0,  1,  ...,  0, 18, 20],
        ...,
        [ 1,  0,  3,  ...,  2, 18, 23],
        [ 9,  5,  0,  ...,  0, 18, 23],
        [ 0,  1,  1,  ...,  0, 18, 23]], device='cuda:0')
tensor([[18,  0,  0,  ...,  2,  2,  1],
        [18,  3,  2,  ...,  1,  0,  0],
        [18,  5,  1,  ...,  1,  0,  0],
        ...,
        [18,  1,  1,  ...,  2,  2,  2],
        [18,  1,  5,  ...,  7, 10,  1],
        [18,  0,  0,  ...,  3,  5,  4]], device='cuda:0')
tensor([[ 0,  0,  0,  ...,  2,  1, 19],
        [ 3,  2,  1,  ...,  0,  0, 19],
        [ 5,  1,  0,  ...,  0,  0, 19],
        ...,
        [ 1,  1,  0,  ...,  2,  2, 19],
        [ 1,  5,  0,  ..., 10,  1, 19],
        [ 0,  0,  0,  ...,  5,  4, 19]], device='cuda:0')


#### Bigram Model

Loss With multi-self attention: 4900: Train loss: 1.2557, Val loss: 1.1633

Residual Blocks: 4900: Train loss: 1.2156, Val loss: 1.1171
Seems to have learned state size

Scaling up model 4500: Train loss: 0.4378, Val loss: 0.4096

In [57]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1748)


class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

class MaskedHead(nn.Module):

    def __init__(self, head_size, dim):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(dim, dim)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)

        wei = q @ k.transpose(-2,-1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v = self.value(x)
        out = wei @ v
        return out
    
class Head(nn.Module):

    def __init__(self, head_size, dim):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(dim, dim)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, K=None, V=None):
        if K is None:
            k = self.key(x)
        else:
            k = self.key(K)
        if V is None:
            v = self.value(x)
        else:
            v = self.value(V)

        B, T, C = x.shape
        q = self.query(x)
        

        wei = q @ k.transpose(-2,-1) * C**-0.5
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        
        out = wei @ v
        return out


class MaskedMultiHeadAttention(nn.Module):

    def __init__(self, num_heads, head_size, dim):
        super().__init__()
        self.heads = nn.ModuleList([MaskedHead(head_size, dim) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
    
class MultiHeadAttention(nn.Module):

    def __init__(self, num_heads, head_size, dim):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, dim) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, K=None, V=None):
        out = torch.cat([h(x, K, V) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
    
class FeedForward(nn.Module):

    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4* n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)
    
class DecoderBlock(nn.Module):

    def __init__(self, n_embed, n_head, dim):
        super().__init__()
        head_size = n_embed//n_head
        self.msa = MaskedMultiHeadAttention(n_head, head_size, dim)
        self.sa = MultiHeadAttention(n_head, head_size, dim)
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)
        self.ln3 = nn.LayerNorm(n_embed)

    def forward(self, x, K, V):
        x = x + self.msa(self.ln1(x))
        x = x + self.sa(self.ln2(x), K, V)
        x = x + self.ffwd(self.ln3(x))
        return x
    
class EncoderBlock(nn.Module):

    def __init__(self, n_embed, n_head, dim):
        super().__init__()
        head_size = n_embed//n_head
        self.sa = MultiHeadAttention(n_head, head_size, dim)
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x
    
class EncoderTransformer(nn.Module):
    def __init__(self, vocab_size=vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(x_size, n_embed)
        self.blocks = nn.Sequential(*[EncoderBlock(n_embed, n_head, x_size) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)


    def forward(self, idx):
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T, C)
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)

        return x

class DecoderTransformer(nn.Module):

    def __init__(self, vocab_size=vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(y_size, n_embed)
        self.blocks = nn.ModuleList([DecoderBlock(n_embed, n_head, y_size) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)
        self.encoder = EncoderTransformer(vocab_size=vocab_size)

    def forward(self, idx, idy=None, targets=None ):
        # Bx, Tx = idx.shape
        By, Ty = idy.shape

        x = self.encoder(idx)

        tok_emb = self.token_embedding_table(idy) # (B, T, C)
        pos_emb = self.position_embedding_table(torch.arange(Ty, device=device)) # (T, C)
        y = tok_emb + pos_emb
        for block in self.blocks:
            y = block(y, x, x)
        y = self.ln_f(y)
        logits = self.lm_head(y) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, idy, max_new_tokens):
        # idx is (B,T)
        for  _ in range(max_new_tokens):
            idy_cond = idy[:, -y_size:]
            logits, loss = self(idx, idy_cond)
            logits = logits[:,-1,:] # last time step, (B,C)
            probs = F.softmax(logits, dim=-1)
            idy_next = torch.multinomial(probs, num_samples=1) # (B,1)
            idy = torch.cat((idy, idy_next), dim=1) #(B, T+1)
        return idy

print(x.device, y.device, target.device)
    
model = DecoderTransformer()
m = model.to(device)

logits, loss = m(x,y, target)
print(logits.shape)
print(loss)

test_input = torch.tensor([encode('01010001000000000003010202000001;R'),], dtype=torch.long, device=device)
print(decode(m.generate(test_input, torch.zeros((1,1), dtype=torch.long, device=device), max_new_tokens=y_size)[0].tolist()))

cuda:0 cuda:0 cuda:0
torch.Size([1088, 24])
tensor(3.3329, device='cuda:0', grad_fn=<NllLossBackward0>)
00111116160114D07D06140808D1616D14


In [58]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in [0, 1]:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X,Y, target = get_batch(split)
            logits, loss = model(X,Y, target)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


optimizer = torch.optim.Adam(m.parameters(), lr=learning_rate)



for iter in range(num_iters):

    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f'{iter}: Train loss: {losses[0]:.4f}, Val loss: {losses[1]:.4f}')
    xb, yb, target = get_batch(0)

    logits, loss = m(xb, yb, target)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

test_input = torch.tensor([encode('01010001000000000003010202000001;R'),], dtype=torch.long, device=device)
print(decode(m.generate(test_input, torch.zeros((1,1), dtype=torch.long, device=device), max_new_tokens=y_size)[0].tolist()))



0: Train loss: 3.3274, Val loss: 3.3532
500: Train loss: 0.3128, Val loss: 0.3783
1000: Train loss: 0.1917, Val loss: 0.1698
1500: Train loss: 0.1435, Val loss: 0.0859
2000: Train loss: 0.1340, Val loss: 0.0748
2500: Train loss: 0.1293, Val loss: 0.0656
3000: Train loss: 0.1275, Val loss: 0.0629
3500: Train loss: 0.1247, Val loss: 0.0599
4000: Train loss: 0.1264, Val loss: 0.0614
4500: Train loss: 0.1236, Val loss: 0.0571
0000000102000100000003010200000201
01


In [1]:
def count_digits(n):
    count = 1
    while n // 10 >= 1:
        n /= 10 
        count += 1
    return count

def print_states(s):
    for i, game in enumerate(s.split(f'\n')): 
        for state in game.split(';'):
            if len(state) == 1:
                print(state)
            else:
                enc_state = [2**int(''.join(state[i:i+2])) for i in range(0, len(state), 2)]
                for i, tile in enumerate(enc_state):
                    print(' ' * max(0, 5 - count_digits(tile)) + (str(tile) if tile != 1 else '0'), end=('\n' if (i+1) % 4 == 0 else ''))

            print('\n')

test_input = '02120005000000000012050502120001;D'
res = decode(m.generate(torch.tensor([encode(test_input),], dtype=torch.long, device=device), torch.tensor([encode(';'),], dtype=torch.long, device=device), max_new_tokens=y_size-1)[0].tolist())
print(res)
print([len(state) for state in res.split(',')])
print_states(test_input)
print_states(res)

NameError: name 'decode' is not defined

In [None]:
import numpy as np

# Test Model
def apply_move(s, move):
        # plays move
        # 0:r,1:l,2:u,3:d
        s = [2**int(''.join(s[i:i+2])) for i in range(0, len(s), 2)]
        for i in range(16):
            if s[i] == 1:
                s[i] = 0

        if move == 0:
            for k in range(0, 3):
                for i in range(3 - k, 16 - k, 4):
                    for j in range(1, 4 - k):
                        if s[i] == 0 and s[i - j] != 0:
                            s[i] = s[i - j]
                            s[i - j] = 0
                        if s[i] != 0 and s[i - j] == s[i]:
                            s[i] *= 2
                            s[i - j] = 0
                            break
                        if s[i] != 0 and s[i - j] != 0 and s[i - j] != s[i]:
                            break
        if move == 1:
            for k in range(0, 3):
                for i in range(0 + k, 13 + k, 4):
                    for j in range(1, 4 - k):
                        if s[i] == 0 and s[i + j] != 0:
                            s[i] = s[i + j]
                            s[i + j] = 0
                        if s[i] != 0 and s[i + j] == s[i]:
                            s[i] *= 2
                            s[i + j] = 0
                            break
                        if s[i] != 0 and s[i + j] != 0 and s[i + j] != s[i]:
                            break
        if move == 2:
            for k in range(0, 3):
                for i in range(0 + 4 * k, 4 + 4 * k):
                    for j in range(1, 4 - k):
                        if s[i] == 0 and s[i + j * 4] != 0:
                            s[i] = s[i + j * 4]
                            s[i + j * 4] = 0
                        if s[i] != 0 and s[i + j * 4] == s[i]:
                            s[i] *= 2
                            s[i + j * 4] = 0
                            break
                        if s[i] != 0 and s[i + j * 4] != 0 and s[i + j * 4] != \
                                s[i]:
                            break
        if move == 3:
            for k in range(0, 3):
                for i in range(12 - 4 * k, 16 - 4 * k):
                    for j in range(1, 4 - k):
                        if s[i] == 0 and s[i - j * 4] != 0:
                            s[i] = s[i - j * 4]
                            s[i - j * 4] = 0
                        if s[i] != 0 and s[i - j * 4] == s[i]:
                            s[i] *= 2
                            s[i - j * 4] = 0
                            break
                        if s[i] != 0 and s[i - j * 4] != 0 and s[i - j * 4] != \
                                s[i]:
                            break

        for i in range(16):
            if s[i] == 0:
                s[i] = 1
        s = [np.log2(s[i]).astype(int) for i in range(0, 16)]
        out = ''
        for tile in s:
            if tile < 10:
                out += "0"+str(tile)
            else:
                out += str(tile)
        return out
    

def check_accuracy(s, move, out):
    # Check if generated state is feasible from input state, measured as a percentage of correct tiles
    num_extra_2s = 0 # counts extra 2s or 4s which could have been added as the 'random' tile after a move
    num_errors = 0

    s = apply_move(s, move)
    for i in range(0, len(s), 2):
        if s[i:i+2] == out[i:i+2]:
            continue
        if s[i:i+2] == 0 and (out[i:i+2] == 2 or out[i:i+2] == 4): # empty tile filled
            num_extra_2s += 1
            continue
        else:
            num_errors += 1

    if num_extra_2s == 0:
        num_errors += 1
    elif num_extra_2s > 1:
        num_errors += num_extra_2s-1

    return num_errors/16


In [None]:
test_input = '02120005000000000012050502120001;D'
res = decode(m.generate(torch.tensor([encode(test_input),], dtype=torch.long, device=device), torch.tensor([encode(';'),], dtype=torch.long, device=device), max_new_tokens=y_size)[0].tolist())
s, move = test_input.split(';')

print(check_accuracy)