In [1]:
import typing
import math
import random

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import Tensor

from model.transformer import Transformer

In [2]:
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)
print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else "N/A")
print("Device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A")

Torch version: 2.5.1
CUDA available: True
CUDA version: 12.4
Current device: 0
Device name: NVIDIA GeForce RTX 3060


In [3]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    print('Using CUDA')
else:
    device = torch.device('cpu')
    print('Using CPU')

Using CUDA


In [4]:
torch.manual_seed(1337)

<torch._C.Generator at 0x79c96cdd1d90>

In [5]:
with open('tiny-shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [6]:
len(text)

1115393

In [7]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [8]:
chars = sorted(list(set(text)))
print(str().join(chars))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [9]:
vocab_size = len(chars)
vocab_size

65

In [10]:
stoi = {ch: i for i, ch in enumerate(chars)}
print(stoi)

{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}


In [11]:
itos = {i: ch for ch, i in stoi.items()}
print(itos)

{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: 'Z', 39: 'a', 40: 'b', 41: 'c', 42: 'd', 43: 'e', 44: 'f', 45: 'g', 46: 'h', 47: 'i', 48: 'j', 49: 'k', 50: 'l', 51: 'm', 52: 'n', 53: 'o', 54: 'p', 55: 'q', 56: 'r', 57: 's', 58: 't', 59: 'u', 60: 'v', 61: 'w', 62: 'x', 63: 'y', 64: 'z'}


In [12]:
def encode(s: str) -> typing.List[int]:
    return [stoi[c] for c in s]

def decode(ints: typing.List[int]) -> str:
    return str().join(itos[i] for i in ints)

In [13]:
print(encode('hello world'))
print(decode(encode('hello world')))

[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
hello world


In [14]:
encoded_text = encode(text)
len(encoded_text)

1115393

In [15]:
data = torch.tensor(encoded_text, dtype=torch.long, device=device)
data.shape

torch.Size([1115393])

In [16]:
data[:100]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59], device='cuda:0')

In [17]:
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
print(f'Training size: {len(train_data)}, Validation size: {len(val_data)}')

Training size: 1003853, Validation size: 111540


In [18]:
block_size = 8 # Also called "context length"

In [19]:
xb = train_data[:block_size]
yb = train_data[1:block_size+1]
print('--- As characters ---')
for t in range(block_size):
    context = xb[:t+1]
    target = yb[t]
    print(f'When the input is {decode(context.tolist())} the next character is {itos[target.item()]}')
print('--- Encoded ---')
for t in range(block_size):
    context = xb[:t+1]
    target = yb[t]
    print(f'When the input is {context} the next character is {target}')

--- As characters ---
When the input is F the next character is i
When the input is Fi the next character is r
When the input is Fir the next character is s
When the input is Firs the next character is t
When the input is First the next character is  
When the input is First  the next character is C
When the input is First C the next character is i
When the input is First Ci the next character is t
--- Encoded ---
When the input is tensor([18], device='cuda:0') the next character is 47
When the input is tensor([18, 47], device='cuda:0') the next character is 56
When the input is tensor([18, 47, 56], device='cuda:0') the next character is 57
When the input is tensor([18, 47, 56, 57], device='cuda:0') the next character is 58
When the input is tensor([18, 47, 56, 57, 58], device='cuda:0') the next character is 1
When the input is tensor([18, 47, 56, 57, 58,  1], device='cuda:0') the next character is 15
When the input is tensor([18, 47, 56, 57, 58,  1, 15], device='cuda:0') the next char

In [20]:
def get_batch(dataset: Tensor, batch_size: int, block_size: int, device=None) -> typing.Tuple[Tensor, Tensor]:
    '''
    Gets a batch of `batch_size` examples from `dataset`. Each example will
    consist of `block_size` characters. The inputs and labels will both be
    returned, both of which will be of size `(batch_size, block_size)`.
    '''

    ix = torch.randint(low=0, high=len(dataset)-block_size, size=(batch_size,), device=device)
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [21]:
batch_size = 4
xb, yb = get_batch(train_data, batch_size=batch_size, block_size=block_size, device=device)
print(xb.shape)
print(xb)
print(yb.shape)
print(yb)

torch.Size([4, 8])
tensor([[35, 56, 43, 52, 41, 46,  1, 59],
        [56, 50, 47, 49, 43,  1, 44, 39],
        [13,  1, 50, 47, 58, 58, 50, 43],
        [51,  6,  1, 47, 44,  1, 51, 63]], device='cuda:0')
torch.Size([4, 8])
tensor([[56, 43, 52, 41, 46,  1, 59, 54],
        [50, 47, 49, 43,  1, 44, 39, 58],
        [ 1, 50, 47, 58, 58, 50, 43,  1],
        [ 6,  1, 47, 44,  1, 51, 63,  1]], device='cuda:0')


In [22]:
for b in range(batch_size):
    print(f'Example {b}')
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f'Block {t}: When the input is {context} the next character is {target}')

Example 0
Block 0: When the input is tensor([35], device='cuda:0') the next character is 56
Block 1: When the input is tensor([35, 56], device='cuda:0') the next character is 43
Block 2: When the input is tensor([35, 56, 43], device='cuda:0') the next character is 52
Block 3: When the input is tensor([35, 56, 43, 52], device='cuda:0') the next character is 41
Block 4: When the input is tensor([35, 56, 43, 52, 41], device='cuda:0') the next character is 46
Block 5: When the input is tensor([35, 56, 43, 52, 41, 46], device='cuda:0') the next character is 1
Block 6: When the input is tensor([35, 56, 43, 52, 41, 46,  1], device='cuda:0') the next character is 59
Block 7: When the input is tensor([35, 56, 43, 52, 41, 46,  1, 59], device='cuda:0') the next character is 54
Example 1
Block 0: When the input is tensor([56], device='cuda:0') the next character is 50
Block 1: When the input is tensor([56, 50], device='cuda:0') the next character is 47
Block 2: When the input is tensor([56, 50, 47

In [23]:
@torch.no_grad()
def estimate_loss(model: Transformer, train_dataset: Tensor, val_dataset: Tensor, eval_iterations: int, batch_size: int, block_size: int, device = None) -> typing.Dict[str, torch.types.Number]:
    dataset_splits = {'train': train_dataset, 'val': val_dataset}
    out = dict()
    for split_name, split_dataset in dataset_splits.items():
        losses = torch.zeros(eval_iterations, device=device)
        for i in range(eval_iterations):
            xb, yb = get_batch(split_dataset, batch_size, block_size, device)
            logits, loss = model(xb, yb)
            losses[i] = loss.item()
        out[split_name] = losses.mean().item()
    return out

In [None]:
batch_size = 32
max_steps = 5_000
learning_rate = 1e-4
epochs = 30
eval_iterations = 300

model = Transformer(
    number_of_layers=4,
    vocab_size=vocab_size,
    number_of_heads=4,
    embedding_dimension=64,
    block_size=block_size,
    dropout_probability=0.2,
    device=device
)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

model.train()

for epoch in range(epochs):
    for step in range(max_steps):
        xb, yb = get_batch(train_data, batch_size=batch_size, block_size=block_size, device=device)

        logits, loss = model(xb, yb)

        logits = typing.cast(Tensor, logits)
        loss = typing.cast(Tensor, loss)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    model.eval()

    loss_dict = estimate_loss(model, train_data, val_data, eval_iterations, batch_size, block_size, device)
    print(f'Step: {step:<7}, estimated training loss: {loss_dict["train"]:.4f}, estimated validation loss: {loss_dict["val"]:.4f}')

    model.train()
    scheduler.step()

Step: 0      , last seen loss: 4.3402, estimated training loss: 4.1186, estimated validation loss: 4.1237
Step: 40     , last seen loss: 3.2581, estimated training loss: 3.3166, estimated validation loss: 3.3204
Step: 80     , last seen loss: 3.2803, estimated training loss: 3.3115, estimated validation loss: 3.3125
Step: 120    , last seen loss: 3.1984, estimated training loss: 3.3161, estimated validation loss: 3.3219
Step: 160    , last seen loss: 3.3043, estimated training loss: 3.3236, estimated validation loss: 3.3141
Step: 200    , last seen loss: 3.2348, estimated training loss: 3.3185, estimated validation loss: 3.3165
Step: 240    , last seen loss: 3.3689, estimated training loss: 3.3122, estimated validation loss: 3.3152
Step: 280    , last seen loss: 3.2934, estimated training loss: 3.3256, estimated validation loss: 3.3107
Step: 320    , last seen loss: 3.2544, estimated training loss: 3.3178, estimated validation loss: 3.3117
Step: 360    , last seen loss: 3.1222, estimat

In [25]:
model.eval()
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
next_idx = model.generate(idx, max_new_tokens=1000)[0].tolist()
next_str = decode(next_idx)
print(next_str)



N:SBsomd wed lnel?e god whea A gne, t blysunkagf,
Kkr reirwth,
:Hee ton ist tenn lhoyh;
Sdo,
THrune,
Tlssl wre s coqd buahannitas dg n sat nthtaue fiffinj prd gol toueeo tod sr f bocr mi dele uitls u snosoawids'reg, at fenund nttaor toicsad

wee naods a s hocg w mlawt st est,
Songo: theey

Hrns b.
usr toksrory foay vhciraraatine toid,
S:GKN:AZVURHaI
 Y'rr we bn

SRCodalyt te descte weai t tcdk! rons lt luws uawsltt shot'v tecd nh nrunasgent thst bipuge.
MAvnl mavssys moncy dee, utegrs rit. nr b fos hre feiaur neidt rds'ge msthe, ab sloneiiois.

Od bu,
Sod.Hets oneriodmy ugde m, mooulesr. fuhdsh goinn. tniepfn gn yoldc,gy Ghur  Fudmt basok bdosi, gheretd
 Eet yoI ref itrlett.
Tfris leslut pstepy wiv fc, aisudtd nnaIhf de kiw mosei ts fout
d wowt t,
OUTuo o! teim, iai yaasn to oslnll mr yde soect uhv. jh t rop sus vit wsd,
Aha Say it csiy


Hotett tiy tot Ih tyshe svoe umy sei ?emf.
TICL

KOCsw.

Hlhde  blaTd 'urir sn,

Tau?
Bnidd aues urh'srA -of.
 cl, sesd henor s ceontolt:
Aoln ile 