### GPT for Learning how to play 2048

Inspired by [GPT from scratch](https://www.youtube.com/watch?v=kCc8FmEb1nY)

#### Explore the data

In [5]:
with open('games.txt', 'r') as f:
    text = f.read()

print(text[:1000])
print(len(text))

00000201000000000000000000000000,00000000000000000000000000010201,00010201000000000000000000000001,00010202010000000000000000000000,01030000010000010000000000000000,02030001000000000000000000010000,02030001000100000000000001000000,02030001010100000000000000010000,02030001010200000000000000020000,02030001010300000000000001000000,02040001020000000000010000000000,03040101000000000000000000000100,03040201000000000000000000010000,03040201010100000000000000000000,03040201020000000002000000000000,03040201020200010000000000000000,03040202020200010000000000000000,03040300030100000001000000000000,04040300000200000000000000020000,00000000000000000004010004030300,00000000000000000000040101000404,00000000000000010000040100000105,00000000000000000001040200000105,00000000000000000001040200010105,00000000010000000001040200000205,00000100000000000000040201010205,00010001000000000000040200020205,00000002010000000000040200000305,00000100000000000000040301000305,00000000000001000100040301000305,0000010000

#### Tokenize

Each tile is represented by 2 digit, 0-led int. Will convert tile into regular int. 

Each state is separated by comma, each game by ';\n'. Will convert these into int as well

In [43]:
mapping = {',':20, ';\n':21}
inv_mapping = {v: k for k, v in mapping.items()}

def encode(s:str)-> list[str]:
    out = []
    num_splits = len(s.split(f';\n'))
    for i, game in enumerate(s.split(f';\n')): 
        for state in game.split(',')[:-1]:
            enc_state = [int(''.join(state[i:i+2])) for i in range(0, len(state), 2)]
            enc_state.append(mapping[','])
            out += enc_state
        out.append(mapping[';\n']) if i < num_splits-1 else None
    
    return out

def decode(l:list[int]) -> str:
    s = ''
    for char in l:
        if char < 10:
            s += '0'+ str(char)
        elif char < 20: # hard coded, should fix
            s += str(char)
        else:
            s += inv_mapping[char]

    return s        


print(encode('02000000020303010607060507081011,01000000030303010607060507081011,;\n01000000030303010607060507081011,'))
print(decode(encode('02000000020303010607060507081011,01000000030303010607060507081011,;\n01000000030303010607060507081011,')))


[2, 0, 0, 0, 2, 3, 3, 1, 6, 7, 6, 5, 7, 8, 10, 11, 20, 1, 0, 0, 0, 3, 3, 3, 1, 6, 7, 6, 5, 7, 8, 10, 11, 20, 21, 1, 0, 0, 0, 3, 3, 3, 1, 6, 7, 6, 5, 7, 8, 10, 11, 20]
02000000020303010607060507081011,01000000030303010607060507081011,;
01000000030303010607060507081011,


#### Load data

Get train test split

Set up batches

In [49]:
import torch

torch.manual_seed(1748)

data = torch.tensor(encode(text), dtype=torch.short)

batch_size = 4
block_size = 34 # two boards

# get first game in last 20% of data
n = int(0.8*len(text))
while(text[n] != ';'):
    n += 1
n += 2

train_data = data[:n]
test_data = data[n:]

def get_batch(split:bool=0)-> list[torch.Tensor]:
    # split == 0: train, 1: test
    data = train_data if split == 0 else test_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

x, y = get_batch()

print(x)

print(y)



tensor([[ 1,  2,  9, 11, 20,  1,  0,  0,  2,  0,  1,  4,  3,  0,  8,  6,  6,  1,
          2,  9, 11, 20,  0,  0,  1,  2,  0,  1,  4,  3,  0,  8,  6,  6],
        [20,  0,  0,  2,  2,  1,  3,  4,  3,  6,  1,  2,  1,  4,  7,  8, 10, 20,
          0,  0,  1,  3,  1,  3,  4,  3,  6,  1,  2,  1,  4,  7,  8, 10],
        [ 7,  4,  0,  0,  6,  1, 20,  0,  1,  2,  2,  0,  4,  4,  5,  1,  2,  7,
          4,  1,  0,  6,  1, 20,  1,  3,  0,  0,  5,  5,  0,  0,  1,  2],
        [ 0,  0,  0,  1,  3,  2,  0,  0,  7,  7,  8,  9, 20,  0,  0,  1,  0,  0,
          0,  0,  1,  0,  0,  3,  2,  0,  8,  8,  9, 20,  0,  0,  0,  1]],
       dtype=torch.int16)
tensor([[ 2,  9, 11, 20,  1,  0,  0,  2,  0,  1,  4,  3,  0,  8,  6,  6,  1,  2,
          9, 11, 20,  0,  0,  1,  2,  0,  1,  4,  3,  0,  8,  6,  6,  2],
        [ 0,  0,  2,  2,  1,  3,  4,  3,  6,  1,  2,  1,  4,  7,  8, 10, 20,  0,
          0,  1,  3,  1,  3,  4,  3,  6,  1,  2,  1,  4,  7,  8, 10, 20],
        [ 4,  0,  0,  6,  1, 20,  0,  1,  2