In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
import wandb
from tqdm import tqdm
import time
import json
torch.manual_seed(1337)

<torch._C.Generator at 0x14f31a0f0>

In [None]:
# # initialize wandb
# wandb.init(project="GPT 2 848K")
# wandb.run.tags = ['GPT 1', 'test run']

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


In [3]:
# pull from local folder
filename = 'tinyshakespeare.txt'
with open(filename, 'r') as f:
    text = f.read()

In [4]:
# TODO: count how many params you're using in this code, and implement chinchilla law to understand how much data you need to ensure you aren't under training
# get vocab
vocab = list(sorted(set(text)))
vocab_size = len(vocab)
# embedding dimensions 
n_emb = 32
learning_rate = 1e-4
block_size = 8
epochs = 5000
# how often to evaluate loss
eval_iter = 200
# number of blocks in the transformer
n_layer = 2
# number of heads in the transformer
n_heads = 2
# each head size is n_emb // n_heads = 32 // 2 = 16
dropout = 0.2 # 20% will be zeroed out
train_test_split = 0.9 # 85% of data will be used for training
device = 'mps' if torch.backends.mps.is_available() else 'cpu'

In [5]:
# character level encoding and decoding
stoi = {c: i for i, c in enumerate(vocab)}
# itos = {i: c for i, c in enumerate(vocab)}
# alternate way of creating decoder func
itos = {i: c for c, i in stoi.items()}
encode = lambda x: [stoi[c] for c in x]
decode = lambda x: ''.join([itos[i] for i in x])

In [6]:
# encode full dataset
data = torch.tensor(encode(text), dtype=torch.long)

# train test split
train_size = int(train_test_split * len(data))
train_data = data[:train_size]
test_data = data[train_size:]

In [7]:
torch.manual_seed(1337)
batch_size = 4 # how many sequences we will process in parallel, each of these sequences is block_size long
block_size = 8 # the length of each sequence

In [8]:
def get_batch(split):
    data = train_data if split == 'train' else test_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [9]:
class AttentionHead(nn.Module):
    '''one head of self-attention'''

    def __init__(self, head_size):
        super().__init__()
        # usually bias is not used in self-attention TODO: understand better why
        self.key = nn.Linear(n_emb, head_size, bias=False)
        self.query = nn.Linear(n_emb, head_size, bias=False)
        self.value = nn.Linear(n_emb, head_size, bias=False)
        # triangular mask to prevent attending to future tokens
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        # using register buffer ensures that tril is not initialized as a param, so it won't be optimized during training
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x) # BxTxC
        q = self.query(x) # BxTxC
        v = self.value(x) # BxTxC
        # compute attention scores
        # could potentially be optimized by using einsum? TODO: understand how
        # could potentially use lora's code to optimize this
        wei = q @ k.transpose(-2, -1) * C ** -0.5 # BxTxC @ BxCxT (because of transposing second last and last dim of k) --> BxTxT
        # BxTxT: the TxT part of this attention matrix is where the quadratic complexity dependent on context length comes from
        # * C ** -0.5 is the one over root dk scaling factor in the attention formula
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # wherever tril is 0, in that position of wei, replace existing value with -inf
        # :T, :T is sliced to prevent index out of bounds error (for the case where block_size is not equal to T)
        wei = torch.softmax(wei, dim=-1) # TODO: understand why we softmax on the last dim
        wei = self.dropout(wei) # dropout on attention scores, randomly set some of them to 0
        # perform aggregation of values with attention scores
        out = wei @ v # BxTxT @ BxTxC --> BxTxC
        # out = F.scaled_dot_product_attention(q, k, v, is_causal=True) # BxTxC
        # back to the dims we started with
        return out

In [10]:
class MultiHeadAttention(nn.Module):
    '''multi headed self attention'''

    def __init__(self, num_heads, head_size):
        super().__init__() # This initializes nn.Module (parent class from which MultiHeadAttention inherits from) before 
        # initializing anything in this child class
        self.heads = nn.ModuleList([AttentionHead(head_size) for _ in range(num_heads)])
        self.projection = nn.Linear(n_emb, n_emb) # linear layer to project concatenated heads output back to n_emb
        # project back into the residual pathway
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1) # BxTxC
        out = self.projection(out)
        return self.dropout(out)

In [11]:
class FeedForwardNN(nn.Module):
    '''simple one layer linear nn'''

    def __init__(self, n_emb):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_emb, 4 * n_emb), # add a factor of 4 to n_emb as per GPT-2, just to make it more expressive, increasing complexity and computation
            nn.ReLU(), # TODO: use GELU instead of ReLU
            nn.Linear(4 * n_emb, n_emb), # linear projection back into the residual pathway
            nn.Dropout(dropout) # add right before connetion before residual connection
        )
    
    def forward(self, x):
        return self.net(x)

In [12]:
class Block(nn.Module):
    '''transformer block: create multiple blocks and concatenate them'''

    def __init__(self, n_emb, num_heads):
        super().__init__()
        head_size = n_emb // num_heads
        self.sa = MultiHeadAttention(num_heads, head_size)
        self.ffn = FeedForwardNN(n_emb)
        self.ln1 = nn.LayerNorm(n_emb)
        self.ln2 = nn.LayerNorm(n_emb)

    def forward(self, x):
        x = x + self.sa(self.ln1(x)) # residual connection # TODO: test using layer norm after sa and ffn as in original transformer paper 
        # and understand why there was an improvement in the new method
        x = x + self.ffn(self.ln2(x)) # residual connection (damn that was a very easy change to make)
        return x

In [13]:
class BigramLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token in the lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_emb) # W_E in GPT-2
        self.positional_embedding_table = nn.Embedding(block_size, n_emb) # W_P in GPT-2
        self.blocks = nn.Sequential(*[Block(n_emb, num_heads=n_heads) for _ in range(n_layer)]) # 4 blocks as per GPT-2 
        # asterisk is used here to unpack the list of blocks so it can be passed as individual elements to nn.Sequential and not as one big list
        # also this is just a simpler representation of the previous thing we did, where we had a list of blocks and we individually called them
        self.lm_head = nn.Linear(n_emb, vocab_size) # W_o in GPT-2

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both of shape (batch_size, block_size) aka (B, T)
        token_emb = self.token_embedding_table(idx) # Batch x time x channel (here channel is now n_emb)
        pos_emb = self.positional_embedding_table(torch.arange(T)) # time x channel
        x = token_emb + pos_emb  # add positional embedding to token embedding
        x = self.blocks(x)
        logits = self.lm_head(x) # B, T, vocab size

        if targets is None:
            loss = None
        else:
            # loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1)) # we could do this, but its hard to understand, so
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets) 

        return logits, loss

    # auto regressive generation
    def generate(self, idx, max_new_tokens):
        # idx is BxT
        for _ in range(max_new_tokens):
            # get the last block_size tokens of the idx
            idx_cond = idx[:, -block_size:] # BxT
            logits, loss = self(idx_cond)
            # pluck out last column in time dimension, because this is the generated predictions for what comes next
            logits = logits[:, -1, :] # keep only the last token for each sequence in the batch aka BxC
            probs = F.softmax(logits, dim=-1) # BxC
            # sample from the distribution
            next_token = torch.multinomial(probs, num_samples=1) # Bx1
            # append newly generated token to input idx to obtain new input for next generation iteration
            idx = torch.cat([idx, next_token], dim=-1) # Bx(T+1) # TODO: understand why this is dim=-1
        return idx

In [14]:
model = BigramLanguageModel()

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) # TODO: try adding a lr schedule

In [15]:
# Track best losses and store losses for plotting
best_train_loss = float('inf')
best_val_loss = float('inf')
train_losses = []
val_losses = []

In [None]:
# Training loop
start_time = time.time()
for iter in tqdm(range(epochs), desc="Training Epochs"):
    # Training phase
    model.train()  # Set model to training mode
    xb, yb = get_batch('train')
    logits, train_loss = model(xb, yb)

    # Zero gradients, backward pass, and optimizer step
    optimizer.zero_grad(set_to_none=True)
    train_loss.backward()
    optimizer.step()
    train_losses.append(train_loss.item())

    # Evaluation phase every eval_iter
    if iter % eval_iter == 0:
        model.eval()  # Set model to evaluation mode
        val_losses_list = []

        for _ in range(eval_iter):
            with torch.no_grad():  # Disable gradient calculation
                X_val, Y_val = get_batch('val')
                logits, val_loss = model(X_val, Y_val)
                val_losses_list.append(val_loss.item())
        
        # Calculate mean of validation losses
        avg_val_loss = sum(val_losses_list) / len(val_losses_list)

        # Log and print average train and validation losses
        print(f"Epoch: {iter}, Train Loss: {train_loss.item()}, Val Loss: {avg_val_loss}")
        # wandb.log({
        #     'train_loss': train_loss.item(),
        #     'val_loss': avg_val_loss
        # })

        # Track best losses
        if train_loss.item() < best_train_loss:
            best_train_loss = train_loss.item()
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
        val_losses.append(avg_val_loss)

end_time = time.time()
train_time = end_time - start_time

Training Epochs:   0%|          | 22/5000 [00:00<01:25, 58.14it/s]

Epoch: 0, Train Loss: 4.441923141479492, Val Loss: 4.58493901014328


Training Epochs:   5%|▍         | 234/5000 [00:01<00:28, 168.53it/s]

Epoch: 200, Train Loss: 4.216477394104004, Val Loss: 3.957894432544708


Training Epochs:   9%|▊         | 426/5000 [00:02<00:27, 163.74it/s]

Epoch: 400, Train Loss: 3.482996702194214, Val Loss: 3.5243161797523497


Training Epochs:  13%|█▎        | 641/5000 [00:03<00:22, 193.74it/s]

Epoch: 600, Train Loss: 3.58988094329834, Val Loss: 3.3569255125522615


Training Epochs:  17%|█▋        | 834/5000 [00:04<00:20, 204.33it/s]

Epoch: 800, Train Loss: 3.233412981033325, Val Loss: 3.244213374853134


Training Epochs:  21%|██        | 1054/5000 [00:05<00:19, 202.41it/s]

Epoch: 1000, Train Loss: 3.2395029067993164, Val Loss: 3.215423300266266


Training Epochs:  25%|██▍       | 1246/5000 [00:06<00:18, 208.06it/s]

Epoch: 1200, Train Loss: 3.294506788253784, Val Loss: 3.1433665442466734


Training Epochs:  29%|██▉       | 1441/5000 [00:07<00:17, 209.21it/s]

Epoch: 1400, Train Loss: 3.244929552078247, Val Loss: 3.057328428030014


Training Epochs:  33%|███▎      | 1638/5000 [00:07<00:16, 208.97it/s]

Epoch: 1600, Train Loss: 3.363179922103882, Val Loss: 3.0281841170787813


Training Epochs:  37%|███▋      | 1832/5000 [00:08<00:16, 195.40it/s]

Epoch: 1800, Train Loss: 3.440908432006836, Val Loss: 2.993301634788513


Training Epochs:  41%|████      | 2033/5000 [00:10<00:24, 122.35it/s]

Epoch: 2000, Train Loss: 3.348334550857544, Val Loss: 2.9892898738384246


Training Epochs:  45%|████▍     | 2230/5000 [00:11<00:23, 117.73it/s]

Epoch: 2200, Train Loss: 2.984619140625, Val Loss: 2.908329746723175


Training Epochs:  49%|████▊     | 2427/5000 [00:12<00:21, 119.82it/s]

Epoch: 2400, Train Loss: 2.676140069961548, Val Loss: 2.8756549632549286


Training Epochs:  52%|█████▏    | 2622/5000 [00:14<00:20, 118.71it/s]

Epoch: 2600, Train Loss: 3.526101589202881, Val Loss: 2.8939671194553376


Training Epochs:  57%|█████▋    | 2836/5000 [00:15<00:18, 117.52it/s]

Epoch: 2800, Train Loss: 2.8766255378723145, Val Loss: 2.825735219717026


Training Epochs:  61%|██████    | 3032/5000 [00:17<00:16, 118.96it/s]

Epoch: 3000, Train Loss: 2.503922462463379, Val Loss: 2.835417102575302


Training Epochs:  64%|██████▍   | 3225/5000 [00:18<00:16, 106.45it/s]

Epoch: 3200, Train Loss: 2.4428529739379883, Val Loss: 2.8332872414588928


Training Epochs:  68%|██████▊   | 3425/5000 [00:20<00:14, 106.14it/s]

Epoch: 3400, Train Loss: 2.410163164138794, Val Loss: 2.781707580089569


Training Epochs:  73%|███████▎  | 3631/5000 [00:21<00:11, 114.17it/s]

Epoch: 3600, Train Loss: 2.4952361583709717, Val Loss: 2.813198951482773


Training Epochs:  76%|███████▋  | 3823/5000 [00:22<00:10, 108.75it/s]

Epoch: 3800, Train Loss: 2.573543071746826, Val Loss: 2.788221287727356


Training Epochs:  80%|████████  | 4023/5000 [00:24<00:08, 116.84it/s]

Epoch: 4000, Train Loss: 2.4551286697387695, Val Loss: 2.7386697661876678


Training Epochs:  84%|████████▍ | 4217/5000 [00:25<00:08, 96.93it/s] 

Epoch: 4200, Train Loss: 3.237701892852783, Val Loss: 2.6911581814289094


Training Epochs:  89%|████████▊ | 4428/5000 [00:27<00:04, 117.56it/s]

Epoch: 4400, Train Loss: 3.2079591751098633, Val Loss: 2.704334862232208


Training Epochs:  92%|█████████▏| 4616/5000 [00:29<00:05, 64.97it/s] 

Epoch: 4600, Train Loss: 2.630408763885498, Val Loss: 2.683697439432144


Training Epochs:  96%|█████████▋| 4814/5000 [00:31<00:02, 75.11it/s] 

Epoch: 4800, Train Loss: 2.4602599143981934, Val Loss: 2.7012743133306505


Training Epochs: 100%|██████████| 5000/5000 [00:32<00:00, 154.57it/s]


In [17]:
print(100*'*')
# Load best losses from JSON file if it exists
best_losses_file = 'best_losses.json'
try:
    with open(best_losses_file, 'r') as f:
        best_losses = json.load(f)
        best_train_loss = best_losses.get('best_train_loss', best_train_loss)
        best_val_loss = best_losses.get('best_val_loss', best_val_loss)
except FileNotFoundError:
    best_losses = {
        'best_train_loss': best_train_loss,
        'best_val_loss': best_val_loss
    }
    with open(best_losses_file, 'w') as f:
        json.dump(best_losses, f)

****************************************************************************************************


In [18]:
print(f"Generated Text:")
idx = torch.zeros((1,1), dtype=torch.long)
generated_text = decode(model.generate(idx, max_new_tokens=2000)[0].tolist())
print(generated_text)

Generated Text:

enNijde,

MrhE stCyo?th oth fany wmy ch mhastRI
I athethou'ersulthn fe Iorde 

VG,E
Sesilupy qP
Cigardy usoussthi'dou ist


Th'y
Limuchy ow:
itt:


I
S rlplLother hidse
-!oos.
Qe Upos b t diyound cacu, aennndle cosst karer dor  past osenothth an rEnl Censing acineen Ir d nyfearist.

e g.
STSthesichootg f m Es: yoThest the rspk'lerd 

xty sree le 

N

asoGjmi y piregeit Ns H heU-o ca bnith!ere
Aiwe ad tou oOko wno!:
My:
C, ksowe the heel, t 'ar
N
Oo A!,
PDuraLLe hy-
3 thave.

Sik rl usive ranW wn hieanavery
PUs!Sornouwus wf veithfndo:
z cem:
SQerr, I,
Bdove ld sond!eAbyerd bowe 's.
kou.
Ycthe?UYivAlthiuchey winth,coy alelalles soriss


H blscyf'n aNs,
SD
A!t baadob wheo,

ThourenleH sftis t arithe rndower; bekenr; oredorlNmen t:
T
S:
T:
Iehe Lcerof'rt.

qP
WHimer:
c!COyamind chesth that-he Tag, I,
Sn?
T b se hllyesne,
nd.
J

xithom les
Hhin't AS ikav t akthal.

sow ai
C:
3hyN?
D$ YOG wyouranr hS:,
M aka?Nindon f ndart cal bt d, we ve

Iou an.

C:
S Liror won.
TtoraB
cE 

In [19]:
# Check if current run has better losses
if best_train_loss < best_losses.get('best_train_loss', float('inf')) or best_val_loss < best_losses.get('best_val_loss', float('inf')):
    # Save generated text to file
    with open('generated_shakespeare_text.txt', 'w') as f:
        f.write(generated_text)

    # Update best losses and save to JSON file
    best_losses['best_train_loss'] = best_train_loss
    best_losses['best_val_loss'] = best_val_loss
    with open(best_losses_file, 'w') as f:
        json.dump(best_losses, f)

    # Have wandb save the text file
    wandb.save('generated_shakespeare_text.txt')
    # also save an image of the training and validation loss curves
    plt.plot(train_losses, label='train loss')
    plt.plot(val_losses, label='val loss')
    plt.legend()
    plt.savefig('train_val_loss.png')
    wandb.save('train_val_loss.png')
    print("Current run beat the best losses. Generated text saved.")

else:
    print("Current run did not beat the best losses. Generated text not saved.")
print(100*'*')
print(100*'*')

Current run did not beat the best losses. Generated text not saved.
****************************************************************************************************
****************************************************************************************************


In [20]:
print(f"Best Train Loss: {best_train_loss}")
print(f"Best Validation Loss: {best_val_loss}")
# show total number of parameters in the model
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in the model: {total_params}")
# show toal number of tokens in the dataset
total_tokens = len(data)
print(f"Total number of tokens in the dataset: {total_tokens}")
print(f"According to Chinchilla Law, you need at least {total_params * 2} tokens to train this model.") # TODO: work on this

Best Train Loss: 2.410163164138794
Best Validation Loss: 2.683697439432144
Total number of parameters in the model: 29697
Total number of tokens in the dataset: 1115394
According to Chinchilla Law, you need at least 59394 tokens to train this model.


In [23]:
# Ensure train_time and other parameters are defined before logging
# wandb.log({
#     'epochs': epochs,
#     "learning_rate": learning_rate,
#     "block_size": block_size,
#     "batch_size": batch_size,
#     "embedding_size": n_emb,
#     "optimizer": "AdamW",
#     "device": device,
#     "vocab_size": vocab_size,
#     "best_train_loss": best_train_loss,
#     "best_val_loss": best_val_loss,
#     'Training Time': train_time, 
#     'dropout': dropout,
#     'n_layer': n_layer,
#     'n_heads': n_heads,
#     'train_test_split': train_test_split,
#     'total_params': total_params
# })

print(f"Total time to train model up to {epochs} epochs: {train_time:.2f} seconds")
wandb.finish()

Total time to train model up to 5000 epochs: 32.38 seconds
