## train transformer

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import requests
from torch.nn.functional import layer_norm

In [3]:
# Download the training corpus once; later runs reuse the local copy.
if not os.path.exists('sales_textbook.txt'):
    url = 'https://huggingface.co/datasets/goendalf666/sales-textbook_for_convincing_and_selling/resolve/main/sales_textbook.txt?download=true'
    # Fetch and validate BEFORE opening the output file. Opening first would
    # leave an empty 'sales_textbook.txt' behind if the request failed, and the
    # existence check above would then skip the download on every later run.
    response = requests.get(url, timeout=60)
    # Fail loudly on 4xx/5xx instead of saving an error page as the corpus.
    response.raise_for_status()
    with open('sales_textbook.txt', 'wb') as f:
        f.write(response.content)

In [4]:
# Load the whole corpus into memory as a single string.
# The encoding is pinned to UTF-8 so decoding does not depend on the platform's
# locale default (e.g. cp1252 on Windows), which could raise or silently
# produce different text and therefore different tokens.
with open('sales_textbook.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [5]:
# hyper parameters
context_length = 16  # number of tokens in context
batch_size = 4       # number of sequences sampled per batch
d_model = 64         # width of the embedding / model dimension

# Seed PyTorch's global RNG so the random batch indices and the randomly
# initialised layers below are identical on every Restart & Run All.
seed = 1337
torch.manual_seed(seed)

In [6]:
import tiktoken
# Look up tiktoken's "cl100k_base" encoding and tokenize the full corpus with it.
# `encoding` is kept around because it is used again later to decode a
# predicted token id back into text.
encoding = tiktoken.get_encoding("cl100k_base")
# `tokens` is the encoded corpus (a sequence of integer token ids; it is turned
# into a LongTensor in the next step).
tokens = encoding.encode(text)

In [7]:
# Largest token id that actually occurs in the corpus.
max_token_value = max(tokens)

# Sequential (non-shuffled) split: the first 90% of the token stream is used
# for training, the remaining 10% for testing.
train_idx = int(len(tokens) * 0.9)
train_part, test_part = tokens[:train_idx], tokens[train_idx:]
train_tokens = torch.as_tensor(train_part, dtype=torch.long)
test_tokens = torch.as_tensor(test_part, dtype=torch.long)

In [24]:
# Draw `batch_size` random window start positions. The upper bound leaves room
# for the target window, which is shifted one token to the right.
max_start = len(train_tokens) - context_length - 1
idxs = torch.randint(high=max_start, size=(batch_size,))
starts = idxs.tolist()

# Inputs are the window itself; targets are the same window shifted by one token.
x_batch = torch.stack([train_tokens[s:s + context_length] for s in starts])
y_batch = torch.stack([train_tokens[s + 1:s + 1 + context_length] for s in starts])

x_batch.shape

torch.Size([4, 16])

In [9]:
import pandas as pd
import numpy as np

# Record the library versions this notebook was run with, so the saved outputs
# can be traced back to a specific environment.
print(f"PyTorch version: {torch.__version__}")
print(f"NumPy version: {np.__version__}")

PyTorch version: 2.2.2
NumPy version: 1.26.4


In [10]:
# Token-embedding table: one d_model-wide row per token id in
# [0, max_token_value], hence max_token_value + 1 rows.
input_embedding = nn.Embedding(
    num_embeddings=max_token_value + 1,
    embedding_dim=d_model,
)

# Look up the embedding vectors for both the input and the shifted target batch.
x_batch_embedding, y_batch_embedding = input_embedding(x_batch), input_embedding(y_batch)

x_batch_embedding.shape, y_batch_embedding.shape

(torch.Size([4, 16, 64]), torch.Size([4, 16, 64]))

In [26]:
# Sinusoidal position encoding: even feature columns carry sin(pos * freq),
# odd feature columns carry cos(pos * freq), with frequencies that decay
# geometrically across the feature dimension.
position = torch.arange(0, context_length, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
# (context_length, 1) * (d_model / 2,) broadcasts to (context_length, d_model / 2).
angles = position * div_term

position_encoding = torch.zeros(context_length, d_model)
position_encoding[:, 0::2] = angles.sin()
position_encoding[:, 1::2] = angles.cos()

# Add a leading batch axis and broadcast it (as a view) to the batch size.
position_encoding = position_encoding[None].expand(batch_size, -1, -1)
position_encoding.shape

torch.Size([4, 16, 64])

In [12]:
# Inject position information by adding the encoding to both embedded batches.
x, y = (
    embedded + position_encoding
    for embedded in (x_batch_embedding, y_batch_embedding)
)

x.shape

torch.Size([4, 16, 64])

In [13]:
# Three independent d_model -> d_model projections for queries, keys and values.
# They are created in Wq, Wk, Wv order, so weight initialisation consumes the
# RNG in the same sequence as three separate statements would.
Wq, Wk, Wv = (nn.Linear(d_model, d_model) for _ in range(3))

# Project the position-aware input into query / key / value space.
Q, K, V = Wq(x), Wk(x), Wv(x)

Q.shape, K.shape, V.shape

(torch.Size([4, 16, 64]), torch.Size([4, 16, 64]), torch.Size([4, 16, 64]))

In [14]:
# Split the feature dimension into `num_head` heads of size d_model // num_head.
# NOTE: this cell overwrites Q, K and V in place, so it can only be run once
# per set of fresh projections.
num_head = 4
head_dim = d_model // num_head
split_shape = (batch_size, context_length, num_head, head_dim)

# Q and V become (batch, head, token, head_dim).
Q = Q.reshape(split_shape).transpose(1, 2)
# K is laid out as (batch, head, head_dim, token), i.e. already transposed,
# so that `Q @ K` yields token-by-token scores directly.
K = K.reshape(split_shape).permute(0, 2, 3, 1)
V = V.reshape(split_shape).transpose(1, 2)

In [15]:
# Scaled dot-product scores: divide by sqrt(head_dim) to keep their magnitude
# independent of the head size. K was already transposed when the heads were split.
scale = np.sqrt(d_model // num_head)
output = torch.matmul(Q, K) / scale
output.shape

torch.Size([4, 4, 16, 16])

In [16]:
# Causal mask: True strictly above the diagonal, i.e. at every (query, key)
# pair where the key lies in the query's future.
mask = torch.ones(context_length, context_length).triu(diagonal=1).bool()
# Masked scores become -inf so that softmax assigns them zero weight.
output = output.masked_fill(mask, float("-inf"))

In [17]:
# Normalise each query's scores over the key axis to get attention weights.
attn = output.softmax(dim=-1)
# Weighted sum of the value vectors, computed per head.
A = torch.matmul(attn, V)
A.shape

torch.Size([4, 4, 16, 16])

In [18]:
# Concatenate the heads: move the head axis next to the per-head features,
# then merge the two back into a single d_model-wide feature axis.
A = A.permute(0, 2, 1, 3).reshape(batch_size, -1, d_model)
# Output projection that mixes information across heads.
Wo = nn.Linear(in_features=d_model, out_features=d_model)
output = Wo(A)
output.shape

torch.Size([4, 16, 64])

In [19]:
# Residual connection: add the attention block's input back onto its output.
output = output + x
# Layer normalisation over the feature (d_model) axis.
ln = nn.LayerNorm(normalized_shape=d_model)
layer_norm_output = ln(output)

In [20]:
# Position-wise feed-forward network: expand to 4 * d_model, apply ReLU,
# then project back down to d_model.
hidden_dim = d_model * 4
expand = nn.Linear(d_model, hidden_dim)
contract = nn.Linear(hidden_dim, d_model)
FF = nn.Sequential(expand, nn.ReLU(), contract)

output = FF(layer_norm_output)
output.shape

torch.Size([4, 16, 64])

In [21]:
# Residual connection around the feed-forward block.
output = layer_norm_output + output
# Layer normalisation with its OWN module. Reusing `ln` from the attention
# sub-layer would tie both normalisation sites to a single weight/bias pair;
# each sub-layer is meant to learn its own scale and shift. At initialisation
# (weight = 1, bias = 0) the numerical result is the same as before.
ln2 = nn.LayerNorm(d_model)
output = ln2(output)

In [22]:
# Project every position from d_model features to one score per token id
# (ids run from 0 to max_token_value, hence the + 1).
vocab_size = max_token_value + 1
final_linear = nn.Linear(in_features=d_model, out_features=vocab_size)
output = final_linear(output)
output.shape

torch.Size([4, 16, 100070])

In [23]:
 # apply softmax
# NOTE: despite its name, `logits` holds probabilities here (softmax output).
logits = output.softmax(dim=-1)
# Most likely token id for the first position of the first sequence in the batch.
predicate_index = logits[0, 0].argmax().item()
# Turn that single id back into text.
encoding.decode([predicate_index])

' FileWriter'