# Custom Transformer-from-Scratch Project

"Infinite Diplomacy": training a simple transformer to generate lines from the game diplomacy.

Sourced from: https://github.com/niderhoff/nlp-datasets?tab=readme-ov-file

In [1]:
# First load in the JSON dataset

import json

train = ''
test = ''

with open('train.jsonl', 'r') as trfile:
    for line in trfile:
        json_obj = json.loads(line)
        train += ' '.join(json_obj['messages'])

with open('test.jsonl', 'r') as testfile:
    for line in testfile:
        json_obj = json.loads(line)
        test += ' '.join(json_obj['messages'])



# Step 1: Tokenize and Encode/Decode

In [2]:
vocab = sorted(list(set(train).union(set(test))))

vocab_size = len(vocab)

print(''.join(vocab))
print(vocab_size)


 !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~¯°àéХаезнрт‍—‘’“”→☺♂❤ツ️，🇫🇬🇷🌊🎆🎉🏻🏼🏽🐍👀👌👍👎👏👺💀💯💰💴😀😁😂😃😄😅😆😈😉😊😍😎😔😘😛😜😝😞😟😤😦😩😫😬😭😮😰😱😲😳😴😵😺😼🙂🙃🙄🙏🤔🤗🤞🤣🤦🤨🤪🤫🤭🤷🥂🥳🥵🥺🦃🧐🧙🧨
196


In [3]:
# provided encode and decode, to convert from char <-> int_token

stoi = { c:i for i, c in enumerate(vocab) }
itos = { i:c for i, c in enumerate(vocab) }

encode = lambda some_string : [stoi[character] for character in some_string]
decode = lambda some_list_of_ints : [itos[int_token] for int_token in some_list_of_ints]

print(encode("Hello, I am Willy"))
print(''.join(decode(encode("Hello, I am Willy"))))

[41, 70, 77, 77, 80, 13, 1, 42, 1, 66, 78, 1, 56, 74, 77, 77, 90]
Hello, I am Willy


# Step 2: Convert to Embeddings

In [4]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, input_sequence, targets = None):
        logits = self.token_embedding_table(input_sequence)   # need to pass in a BATCH of input_sequences!!!!
        
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None
            
        return logits, loss
        
    def generate(self, starter_sequence, max_new_tokens):
        
        my_sequence = starter_sequence[:, -max_seq_length:]

        for _ in range(max_new_tokens):
            logits, loss = self.forward(starter_sequence)
            
            logits = logits[:, -1, :]   # only look at the very last logit!!!!
            
            # note this softmaxxing and sampling!!!!
            softmax_output = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(softmax_output, num_samples=1)
            
            my_sequence = torch.cat((my_sequence, next_token), dim=1)
            
        return my_sequence
        
        
        
batch_size = 4
max_seq_length = 8

def get_batch(split):
    data = train if split == 'train' else val_data
    starting_indices = torch.randint(len(data) - max_seq_length, (batch_size,))
    
    x = torch.tensor([encode(s) for s in [data[idx:idx + max_seq_length] for idx in starting_indices]])
    y = torch.tensor([encode(s) for s in [data[idx + 1:idx + max_seq_length + 1] for idx in starting_indices]])
    return x, y
    
xb, yb = get_batch('train')

my_model = BigramLanguageModel(vocab_size)
logits, loss = my_model.forward(xb, yb)
print(loss)

tensor(6.0601, grad_fn=<NllLossBackward0>)


In [10]:
optimizer = torch.optim.AdamW(my_model.parameters(), lr=1e-3)

for epoch in range(10000):
    xb, yb = get_batch('train')
    logits, loss = my_model.forward(xb, yb)
    
    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()
    
    if epoch % 2000 == 0:
        print(f"epoch {epoch}: loss is {loss}")

epoch 0: loss is 2.5054776668548584
epoch 2000: loss is 2.4881591796875
epoch 4000: loss is 2.900830030441284
epoch 6000: loss is 2.4526865482330322
epoch 8000: loss is 2.6962499618530273


In [11]:
seed = torch.ones(1,1).long()
print(seed)

generated = my_model.generate(seed, 100)[0]
print(''.join(decode(generated.tolist())))

tensor([[1]])
 AttaaiwwRIFittysfagWausasIaiboyo1toolPmtIhGShjgistPdw bgctetiiabIamcjaah gpttYt yfnmHaiewnastyaktWwa


# Step 3: Adding Self-Attention

In [95]:
# Hyperparameters
batch_size = 4
max_seq_length = 8

# head_dim = 16

vocab_size = len(vocab)
num_heads = 2     # MUST BE A MULTIPLE OF VOCAB_SIZE!!!
learning_rate = 1e-3
dropout = 0.2

In [107]:
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)


class Head(nn.Module):
    def __init__(self, head_dim):
        super().__init__()
        self.query = nn.Linear(vocab_size, head_dim, bias=False)
        self.key = nn.Linear(vocab_size, head_dim, bias=False)
        self.value = nn.Linear(vocab_size, head_dim, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(max_seq_length, max_seq_length)))

    def forward(self, input_sequence_batched):
        B, T, C = input_sequence_batched.shape    # (B, T, C/Channels)
        
        q = self.query(input_sequence_batched)     # (B, T, head_dim)
        k = self.key(input_sequence_batched)
        v = self.value(input_sequence_batched)
        
        # Compute attention weights
        wei = q @ k.transpose(-2, -1) * (C ** -0.5)                        # Scale by sqrt(d_k)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))          # Apply mask
        wei = F.softmax(wei, dim=-1)                                         # Normalize weights
        
        out = wei @ v
        return out
    
    
    
    
    
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads):
        super().__init__()
        head_dim = vocab_size // num_heads
        
        self.heads = nn.ModuleList([Head(head_dim) for _ in range(num_heads)])
        self.proj = nn.Linear(vocab_size, vocab_size)
    
    def forward(self, input_sequence_batched):
        batched_attended_vecs = torch.cat([h.forward(input_sequence_batched) for h in self.heads], dim=-1)           # every head is embedding_size // num_heads, so we concat all at the end
        
        logits = self.proj(batched_attended_vecs)
        
        return logits
    

class FeedForward(nn.Module):
    def __init__(self, starting_dim):
        super().__init__()
        self.mlp_layer = nn.Sequential(
            nn.Linear(starting_dim, starting_dim * 4),
            nn.ReLU(),
            nn.Linear(4 * starting_dim, starting_dim),
            nn.Dropout(dropout),
        )
    
    def forward(self, input_sequence_batched):
        logits = self.mlp_layer(input_sequence_batched)  # (B, T, C)
        return logits


    
class Block(nn.Module):
    def __init__(self, starting_dim, num_heads):
        super().__init__()
        self.sa_heads = MultiHeadAttention(num_heads)
        self.ffw = FeedForward(starting_dim)
        self.ln1 = nn.LayerNorm(vocab_size)
        self.ln2 = nn.LayerNorm(vocab_size)
    
    def forward(self, input_sequence_batched):
#         print(input_sequence_batched)
        input_sequence_batched = input_sequence_batched + self.sa_heads(self.ln1(input_sequence_batched))   # "add and norm"
        logits = input_sequence_batched + self.ffw(self.ln2(input_sequence_batched))

        return logits



        
        
        
        
class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        
        self.blocks = nn.Sequential(
            Block(vocab_size, num_heads),
            Block(vocab_size, num_heads),
            Block(vocab_size, num_heads),        # RUN AND TRAIN ON 3 BLOCKS...
            nn.LayerNorm(vocab_size),            # AND LAYERNORM AT THE VERY END!
        )
        self.lm_head = nn.Linear(vocab_size, vocab_size)
        
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)        
        self.position_embedding_table = nn.Embedding(max_seq_length, vocab_size)

    
    def forward(self, input_sequence, targets = None):
#         print(input_sequence)
#         input_sequence = input_sequence.to(torch.float32) 
        B, T = input_sequence.shape
        sequence_token_embeddings = self.token_embedding_table(input_sequence)
        positional_embedding = self.position_embedding_table(torch.arange(T))   # 0....T-1 for a (T, embedding_dim) matrix
        
        sequence_token_embeddings = sequence_token_embeddings + positional_embedding 
        sequence_token_embeddings = self.blocks(sequence_token_embeddings)
        logits = self.lm_head(sequence_token_embeddings)

        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)
        else:
            loss = None
            
        return logits, loss
        
    def generate(self, starter_sequence, max_new_tokens):
        
        my_sequence = starter_sequence[:, -max_seq_length:]

        for _ in range(max_new_tokens):
            logits, loss = self.forward(starter_sequence)
            
            logits = logits[:, -1, :]   # only look at the very last logit!!!!
            
            # note this softmaxxing and sampling!!!!
            softmax_output = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(softmax_output, num_samples=1)
            
            my_sequence = torch.cat((my_sequence, next_token), dim=1)
            
        return my_sequence
        
        

In [108]:
def get_batch(split):
    data = train if split == 'train' else val_data
    starting_indices = torch.randint(len(data) - max_seq_length, (batch_size,))
    
    x = torch.tensor([encode(s) for s in [data[idx:idx + max_seq_length] for idx in starting_indices]])
    y = torch.tensor([encode(s) for s in [data[idx + 1:idx + max_seq_length + 1] for idx in starting_indices]])
    return x, y
    
xb, yb = get_batch('train')

my_model = BigramLanguageModel(vocab_size)
logits, loss = my_model.forward(xb, yb)
print(loss)

tensor(5.4491, grad_fn=<NllLossBackward0>)


In [111]:
optimizer = torch.optim.AdamW(my_model.parameters(), lr=learning_rate)

for epoch in range(1000):
    xb, yb = get_batch('train')
    logits, loss = my_model.forward(xb, yb)
    
    optimizer.zero_grad(set_to_none = True)
    loss.backward()
    optimizer.step()
    
    if epoch % 500 == 0:
        print(f"epoch {epoch}: loss is {loss}")

epoch 0: loss is 2.6685869693756104
epoch 500: loss is 2.250985860824585


In [112]:
seed = torch.ones(1,1).long()
print(seed)

generated = my_model.generate(seed, 100)[0]
print(''.join(decode(generated.tolist())))

tensor([[1]])
 EtwestFtiwfffwyuiebftitttobothm fIhcaiIwfukdycIfcatttwtfgoftiwaetyodRleRsotiwlfnittmtlfmusitTbnOnwgm
