In [1]:
import torch

Note: importing `torch` can take a while.

In [2]:
from constants import DataSplit
from model import TransformerDecoder
from train import train_transformer

In [15]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

## Hyperparameters

In [3]:
# --------------------------------------------------------------------------------------------------- #

# Data
dataset_path = 'data/tinyshakespeare.txt'   # plain-text training corpus
percent_train = 0.90                        # fraction of tokens used for training; the rest is validation

# Batching
batch_size = 16                             # sequences per training batch
context_length = 64                         # context window length, passed to both model and trainer

# Model architecture (passed to TransformerDecoder)
n_embd = 64                                 # embedding dimension
n_layer = 4                                 # number of decoder layers
num_head = 4                                # attention heads; presumably must divide n_embd — TODO confirm
dropout = 0.2                               # dropout probability used by the model

# Optimisation
learning_rate = 3e-4                        # Adam learning rate
max_iters = 1000                            # total training iterations
device = 'cuda' if torch.cuda.is_available() else 'cpu'   # fall back to CPU when no GPU is available

# Reproducibility: dropout and default weight initialisation draw from torch's global RNG
# (and, presumably, so do batch sampling and text generation). Fixing the seed makes a
# Restart & Run All reproduce the same run on the same hardware/software stack.
seed = 1337
torch.manual_seed(seed)

# Evaluation
eval_intervals = 100                        # iterations between loss reports
eval_iters = 250                            # presumably batches averaged per loss estimate — TODO confirm

# --------------------------------------------------------------------------------------------------- #

In [4]:
# `torch.version.cuda` is the CUDA version this PyTorch build was compiled against; it does not
# mean a GPU is available (it can report a version while `device` is 'cpu').
(
    f'Running on device: {device}, '
    f'cuda version: {torch.version.cuda}'
)

'Running on device: cpu, cuda version: 11.8'

## Loading the Dataset

### Opening and Reading the File

In [5]:
# Read the whole corpus into one string. The encoding is explicit because open()'s default
# depends on the platform locale. UTF-8 is assumed here; plain ASCII text is a subset of it.
with open(dataset_path, 'r', encoding='utf-8') as fp:
    dataset = fp.read()

### Vocabulary
Create the vocabulary: every *token* the language model recognises. Here each token is a single character, so the vocabulary is the sorted set of distinct characters in the dataset.

In [6]:
# Every distinct character in the corpus. It is sorted because set iteration order for str
# varies between interpreter runs (hash randomisation); sorting keeps the char <-> id mapping
# deterministic.
vocabulary = sorted(set(dataset))
vocab_size = len(vocabulary)
assert vocab_size > 0, 'dataset is empty; cannot build a vocabulary'

### Encoder and Decoder
We use a simple character-level scheme: `encode` maps each character to its index in the sorted vocabulary, and `decode` maps indices back to characters.

In [7]:
# Character-level codec: a character's id is its position in the sorted vocabulary.
itos = dict(enumerate(vocabulary))
stoi = {char: token_id for token_id, char in itos.items()}


def encode(x):
    """Convert a string into a list of integer token ids.

    Raises KeyError for any character that is not in the vocabulary.
    """
    return list(map(stoi.__getitem__, x))


def decode(x):
    """Convert an iterable of integer token ids back into a string (inverse of `encode`)."""
    return ''.join(itos[token_id] for token_id in x)


# Sanity check: encoding then decoding must round-trip exactly.
test_text = 'This is a sample sentence.'
assert decode(encode(test_text)) == test_text

### Encoding the Entire Dataset
We encode the entire dataset into a 1-D tensor of integer token ids.

In [8]:
# One integer id per character of the corpus, stored as a 1-D int64 tensor.
encoded_dataset = torch.tensor(
    encode(dataset),
    dtype=torch.int64,
)

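### Split Dataset
We split the encoded dataset into a training set and a validation set. The training set is the first `percent_train` fraction of tokens (90% here), and the validation set is the remainder. Validation matters because we want the model to generalise: it should produce text that is "like" the training data without reproducing it verbatim. The loss on the held-out validation set shows whether the model is learning general patterns or just memorising the training text.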

In [9]:
# Contiguous (unshuffled) split: the first `percent_train` fraction of tokens is used for
# training and the tail of the corpus for validation. The boundary gets a descriptive name
# rather than `idx`, which a later cell reuses for an unrelated tensor.
n_train_tokens = int(len(encoded_dataset) * percent_train)
train_data, val_data = encoded_dataset[:n_train_tokens], encoded_dataset[n_train_tokens:]

# Each split must be longer than one context window to yield any training/validation sample.
# NOTE(review): assumes the trainer samples windows of `context_length` (+1 target) tokens — confirm.
assert len(train_data) > context_length and len(val_data) > context_length, (
    f'split too small: {len(train_data)} train / {len(val_data)} val tokens '
    f'for context_length={context_length}'
)

In [10]:
def get_data(split: DataSplit) -> torch.Tensor:
    """Return the encoded token tensor for `split`.

    DataSplit.TRAIN selects the training tensor. Any other value selects the validation
    tensor. This function is handed to `train_transformer` as its data source.
    """
    if split == DataSplit.TRAIN:
        return train_data
    return val_data

## Training and Testing the Model

In [11]:
# Build the decoder from the hyperparameters above and place it on `device`.
# NOTE(review): `device` is passed to the constructor *and* applied with `.to(device)`;
# presumably the constructor uses it for tensors it creates internally — confirm both are needed.
model = (
    TransformerDecoder(
        vocabulary_size=vocab_size,
        context_length=context_length,
        embedding_dim=n_embd,
        number_of_heads=num_head,
        number_of_layers=n_layer,
        dropout=dropout,
        device=device,
    )
    .to(device)
)

In [12]:
# Adam over all parameters returned by model.parameters(), at the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [13]:
# Run the training loop. Losses are logged every `eval_intervals` iterations. The result is
# presumably a mapping from DataSplit to recorded losses (it is indexed that way when plotted).
# Note that `device` is not passed here.
all_losses = train_transformer(
    model=model,
    optimizer=optimizer,
    get_data=get_data,
    # batching
    context_length=context_length,
    batch_size=batch_size,
    # schedule
    maximum_iterations=max_iters,
    eval_iterations=eval_iters,
    eval_intervals=eval_intervals,
)

  0%|          | 0/1000 [00:00<?, ?it/s]

Iteration[   1/1000], Training Loss:  4.359576, Validation Loss:  4.362085


 10%|█         | 102/1000 [00:20<14:48,  1.01it/s]

Iteration[ 101/1000], Training Loss:  3.114107, Validation Loss:  3.148436


 20%|██        | 203/1000 [00:33<10:44,  1.24it/s]

Iteration[ 201/1000], Training Loss:  2.788194, Validation Loss:  2.800481


 30%|███       | 301/1000 [00:46<11:58,  1.03s/it]

Iteration[ 301/1000], Training Loss:  2.644121, Validation Loss:  2.656755


 40%|████      | 400/1000 [00:51<00:31, 19.00it/s]

Iteration[ 401/1000], Training Loss:  2.586380, Validation Loss:  2.586755


 50%|█████     | 502/1000 [01:11<08:20,  1.00s/it]

Iteration[ 501/1000], Training Loss:  2.552685, Validation Loss:  2.551725


 60%|██████    | 603/1000 [01:23<05:12,  1.27it/s]

Iteration[ 601/1000], Training Loss:  2.521989, Validation Loss:  2.523201


 70%|███████   | 702/1000 [01:36<04:00,  1.24it/s]

Iteration[ 701/1000], Training Loss:  2.498847, Validation Loss:  2.498564


 80%|████████  | 802/1000 [01:48<03:38,  1.10s/it]

Iteration[ 801/1000], Training Loss:  2.483111, Validation Loss:  2.479725


 90%|█████████ | 902/1000 [02:01<01:24,  1.16it/s]

Iteration[ 901/1000], Training Loss:  2.459799, Validation Loss:  2.466308


100%|██████████| 1000/1000 [02:06<00:00,  7.93it/s]

Final Loss:
	Training:  2.459799
	Validation:  2.466308





In [None]:
generate_next_tokens = 100

# Sample text starting from token id 0 (the first character of the sorted vocabulary).
# eval() disables dropout so sampling uses the full network, and no_grad() avoids building an
# autograd graph that is never backpropagated. The model's previous train/eval mode is restored
# afterwards so later cells (e.g. further training) are unaffected.
# Assumes TransformerDecoder is a torch.nn.Module (it supports .to() and .parameters()).
start_tokens = torch.zeros((1, 1), dtype=torch.long, device=device)
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        generated_ids = model.generate(start_tokens, max_next_tokens=generate_next_tokens)
finally:
    model.train(was_training)

# print() renders the newlines in the sample; a bare expression would show its escaped repr.
generation = decode(generated_ids[0].tolist())
print(generation)


Mord bnoferry were shild tachirchnd y medy wheintollle,
Bag auncoou of se
Thy.
Fou, w; m;
HOr so mmom akthy inorshesenemenovelbee f gor difspr Bicdas at be?

Wht s os Loe hautost LEAveanord my sofloe bpsels batyi thits th asa tethaigtsint Pm r, t leeg ssaa-d,
Gethece want
Mad, toresg ber mareagald.
We, ament tousp todthaut r , heast walavershiwof sn gadathinnd s at! Vnnond,
Th
Angewnt, wouou,, at, texses tithave b llanfanCache, h ce cathss eap cnge,
Bnge, che sot aichotom strystwtor;ilde k nete. arencerd ssthe rr go beaireat.


Srend ano chak! hasr.
TUNUSThan f sfiewat tamarst t:


MKYVKosks:
Andind, the hiy willinthastors tat iciK:
Hingen ct d Dhik?
ERh gocmouretlis n3Lomes cofegeartres wamy ou; aych f hethoncou, Fth wilithighe
sito m.
Q  il sthe brstoveal hyo stinchecanetraNETon f aves gothtoorsiur teem.
W widez she ier, it ofsnd slo cendo thienk,
Anger f rd: t ad'pead hato ondse.
Thithes ghe IThind t
Tomiche thou t aanarshaprrs
Towe doune G Foof he b wicos w. thesor-iltoot dove che

In [16]:
# Plot the recorded training and validation losses. This needs the training cell to have run
# in the current kernel; otherwise `all_losses` is undefined.
# NOTE(review): the x values are positions in the recorded loss lists. If train_transformer
# records losses only every `eval_intervals` iterations (as its progress log suggests), the
# axis counts evaluations rather than raw iterations — confirm before reading it as iterations.
df = pd.DataFrame({
    'iter': range(len(all_losses[DataSplit.TRAIN])),
    'train': all_losses[DataSplit.TRAIN],
    'val': all_losses[DataSplit.VALIDATION],
})

# Long form: one row per (iter, loss type) pair, so seaborn can colour lines via `hue`.
df_melted = df.melt(
    id_vars=['iter'],
    value_vars=['train', 'val'],
    var_name='loss_type',
    value_name='loss',
)

sns.set_theme()
fig, ax = plt.subplots(figsize=(10, 6))
sns.lineplot(data=df_melted, x='iter', y='loss', hue='loss_type', ax=ax)
ax.set(title='Training and Validation Losses', xlabel='Iteration', ylabel='Loss')
ax.legend(title='Loss Type')
plt.show()

NameError: name 'all_losses' is not defined