In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
import wandb
from tqdm import tqdm
import time
import json
torch.manual_seed(1337)

<torch._C.Generator at 0x118a1a0f0>

In [2]:
# # initialize wandb
# wandb.init(project="GPT 2 848K")
# wandb.run.tags = ['GPT 1', 'test run']

In [3]:
# pull from local folder
filename = 'tinyshakespeare.txt'
with open(filename, 'r') as f:
    text = f.read()

In [4]:
# TODO: count how many params you're using in this code, and implement chinchilla law to understand how much data you need to ensure you aren't under training
# get vocab
vocab = list(sorted(set(text)))
vocab_size = len(vocab)
# embedding dimensions 
n_emb = 32
learning_rate = 1e-4
block_size = 8
epochs = 5000
batch_size = 4 # how many sequences we will process in parallel, each of these sequences is block_size long
block_size = 8 # the length of each sequence
# how often to evaluate loss
eval_iter = 200
# number of blocks in the transformer
n_layer = 2
# number of heads in the transformer
n_heads = 2
# each head size is n_emb // n_heads = 32 // 2 = 16
dropout = 0.2 # 20% will be zeroed out
train_test_split = 0.9 # 85% of data will be used for training
device = 'mps' if torch.backends.mps.is_available() else 'cpu'

In [5]:
# character level encoding and decoding
stoi = {c: i for i, c in enumerate(vocab)}
# itos = {i: c for i, c in enumerate(vocab)}
# alternate way of creating decoder func
itos = {i: c for c, i in stoi.items()}
encode = lambda x: [stoi[c] for c in x]
decode = lambda x: ''.join([itos[i] for i in x])

In [6]:
# encode full dataset
data = torch.tensor(encode(text), dtype=torch.long)

# train test split
train_size = int(train_test_split * len(data))
train_data = data[:train_size]
test_data = data[train_size:]

In [7]:
def get_batch(split):
    data = train_data if split == 'train' else test_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [8]:
-torch.arange(0, n_emb, 2)

tensor([  0,  -2,  -4,  -6,  -8, -10, -12, -14, -16, -18, -20, -22, -24, -26,
        -28, -30])

In [9]:
-torch.arange(0, n_emb // 2)

tensor([  0,  -1,  -2,  -3,  -4,  -5,  -6,  -7,  -8,  -9, -10, -11, -12, -13,
        -14, -15])

In [10]:
class AttentionHead(nn.Module):
    '''one head of self-attention'''

    def __init__(self, head_size):
        super().__init__()
        # usually bias is not used in self-attention TODO: understand better why
        self.key = nn.Linear(n_emb, head_size, bias=False)
        self.query = nn.Linear(n_emb, head_size, bias=False)
        self.value = nn.Linear(n_emb, head_size, bias=False)
        # triangular mask to prevent attending to future tokens
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        # using register buffer ensures that tril is not initialized as a param, so it won't be optimized during training
        self.dropout = nn.Dropout(dropout)

    def create_rotary_embeddings(self):
        '''rotary position embedding, as described in the RoPE paper'''
        # N is typically set to 10,000 in RoPE implementations
        N = 10_000
        # Generate base theta values, each value corresponds to a "pair" and repeats twice
        theta_base = torch.pow(N, -torch.arange(0, n_emb // 2).float() / n_emb)
        # Repeat each element in theta_base twice to create the desired pattern
        theta = torch.repeat_interleave(theta_base, repeats=2)
        pos = torch.arange(block_size).float()
        idx_theta = pos[:, None] * theta[None, :]
        cache = torch.stack((torch.cos(idx_theta), torch.sin(idx_theta)), dim=-1)
        self.register_buffer('cache', cache, persistent=False) # TODO: understand why persistent=False

        return theta, pos, idx_theta
    
    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x) # BxTxC
        q = self.query(x) # BxTxC
        v = self.value(x) # BxTxC
        
        # compute attention scores
        # could potentially be optimized by using einsum? TODO: understand how
        # could potentially use lora's code to optimize this
        wei = q @ k.transpose(-2, -1) * C ** -0.5 # BxTxC @ BxCxT (because of transposing second last and last dim of k) --> BxTxT
        # BxTxT: the TxT part of this attention matrix is where the quadratic complexity dependent on context length comes from
        # * C ** -0.5 is the one over root dk scaling factor in the attention formula
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # wherever tril is 0, in that position of wei, replace existing value with -inf
        # :T, :T is sliced to prevent index out of bounds error (for the case where block_size is not equal to T)
        wei = torch.softmax(wei, dim=-1) # TODO: understand why we softmax on the last dim
        wei = self.dropout(wei) # dropout on attention scores, randomly set some of them to 0
        # perform aggregation of values with attention scores
        out = wei @ v # BxTxT @ BxTxC --> BxTxC
        # out = F.scaled_dot_product_attention(q, k, v, is_causal=True) # BxTxC
        # back to the dims we started with
        return out

In [11]:
x = AttentionHead(n_emb // n_heads)
theta, pos, idx = x.create_rotary_embeddings()

In [12]:
# theta.shape, pos.shape, idx.shape
theta

tensor([1.0000, 1.0000, 0.7499, 0.7499, 0.5623, 0.5623, 0.4217, 0.4217, 0.3162,
        0.3162, 0.2371, 0.2371, 0.1778, 0.1778, 0.1334, 0.1334, 0.1000, 0.1000,
        0.0750, 0.0750, 0.0562, 0.0562, 0.0422, 0.0422, 0.0316, 0.0316, 0.0237,
        0.0237, 0.0178, 0.0178, 0.0133, 0.0133])

In [13]:
theta

tensor([1.0000, 1.0000, 0.7499, 0.7499, 0.5623, 0.5623, 0.4217, 0.4217, 0.3162,
        0.3162, 0.2371, 0.2371, 0.1778, 0.1778, 0.1334, 0.1334, 0.1000, 0.1000,
        0.0750, 0.0750, 0.0562, 0.0562, 0.0422, 0.0422, 0.0316, 0.0316, 0.0237,
        0.0237, 0.0178, 0.0178, 0.0133, 0.0133])

In [14]:
class MultiHeadAttention(nn.Module):
    '''multi headed self attention'''

    def __init__(self, num_heads, head_size):
        super().__init__() # This initializes nn.Module (parent class from which MultiHeadAttention inherits from) before 
        # initializing anything in this child class
        self.heads = nn.ModuleList([AttentionHead(head_size) for _ in range(num_heads)])
        self.projection = nn.Linear(n_emb, n_emb) # linear layer to project concatenated heads output back to n_emb
        # project back into the residual pathway
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1) # BxTxC
        out = self.projection(out)
        return self.dropout(out)

In [15]:
class FeedForwardNN(nn.Module):
    '''simple one layer linear nn'''

    def __init__(self, n_emb):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_emb, 4 * n_emb), # add a factor of 4 to n_emb as per GPT-2, just to make it more expressive, increasing complexity and computation
            nn.ReLU(), # TODO: use GELU instead of ReLU
            nn.Linear(4 * n_emb, n_emb), # linear projection back into the residual pathway
            nn.Dropout(dropout) # add right before connetion before residual connection
        )
    
    def forward(self, x):
        return self.net(x)

In [16]:
class Block(nn.Module):
    '''transformer block: create multiple blocks and concatenate them'''

    def __init__(self, n_emb, num_heads):
        super().__init__()
        head_size = n_emb // num_heads
        self.sa = MultiHeadAttention(num_heads, head_size)
        self.ffn = FeedForwardNN(n_emb)
        self.ln1 = nn.LayerNorm(n_emb)
        self.ln2 = nn.LayerNorm(n_emb)

    def forward(self, x):
        x = x + self.sa(self.ln1(x)) # residual connection # TODO: test using layer norm after sa and ffn as in original transformer paper 
        # and understand why there was an improvement in the new method
        x = x + self.ffn(self.ln2(x)) # residual connection (damn that was a very easy change to make)
        return x

In [17]:
class NanoGPT(nn.Module):
    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token in the lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_emb) # W_E in GPT-2
        self.positional_embedding_table = nn.Embedding(block_size, n_emb) # W_P in GPT-2
        self.blocks = nn.Sequential(*[Block(n_emb, num_heads=n_heads) for _ in range(n_layer)]) # 4 blocks as per GPT-2 
        # asterisk is used here to unpack the list of blocks so it can be passed as individual elements to nn.Sequential and not as one big list
        # also this is just a simpler representation of the previous thing we did, where we had a list of blocks and we individually called them
        self.lm_head = nn.Linear(n_emb, vocab_size) # W_o in GPT-2

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both of shape (batch_size, block_size) aka (B, T)
        token_emb = self.token_embedding_table(idx) # Batch x time x channel (here channel is now n_emb)
        pos_emb = self.positional_embedding_table(torch.arange(T)) # time x channel
        # x = token_emb + pos_emb  # add positional embedding to token embedding
        x = token_emb
        x = self.blocks(x)
        logits = self.lm_head(x) # B, T, vocab size

        if targets is None:
            loss = None
        else:
            # loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1)) # we could do this, but its hard to understand, so
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets) 

        return logits, loss

    # auto regressive generation
    def generate(self, idx, max_new_tokens=100):
        # idx is BxT
        for _ in range(max_new_tokens):
            # get the last block_size tokens of the idx
            idx_cond = idx[:, -block_size:] # BxT
            logits, loss = self(idx_cond)
            # pluck out last column in time dimension, because this is the generated predictions for what comes next
            logits = logits[:, -1, :] # keep only the last token for each sequence in the batch aka BxC
            probs = F.softmax(logits, dim=-1) # BxC
            # sample from the distribution
            next_token = torch.multinomial(probs, num_samples=1) # Bx1
            # append newly generated token to input idx to obtain new input for next generation iteration
            idx = torch.cat([idx, next_token], dim=-1) # Bx(T+1) # TODO: understand why this is dim=-1
        return idx

In [18]:
x = NanoGPT()
idx = torch.randint(vocab_size, (2,4))
logits, loss = x(idx=idx)

In [19]:
model = NanoGPT()

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) # TODO: try adding a lr schedule

In [20]:
# Track best losses and store losses for plotting
best_train_loss = float('inf')
best_val_loss = float('inf')
train_losses = []
val_losses = []

In [21]:
# Training loop
start_time = time.time()
for iter in tqdm(range(epochs), desc="Training Epochs"):
    # Training phase
    model.train()  # Set model to training mode
    xb, yb = get_batch('train')
    logits, train_loss = model(xb, yb)

    # Zero gradients, backward pass, and optimizer step
    optimizer.zero_grad(set_to_none=True)
    train_loss.backward()
    optimizer.step()
    train_losses.append(train_loss.item())

    # Evaluation phase every eval_iter
    if iter % eval_iter == 0:
        model.eval()  # Set model to evaluation mode
        val_losses_list = []

        for _ in range(eval_iter):
            with torch.no_grad():  # Disable gradient calculation
                X_val, Y_val = get_batch('val')
                logits, val_loss = model(X_val, Y_val)
                val_losses_list.append(val_loss.item())
        
        # Calculate mean of validation losses
        avg_val_loss = sum(val_losses_list) / len(val_losses_list)

        # Log and print average train and validation losses
        print(f"Epoch: {iter}, Train Loss: {train_loss.item()}, Val Loss: {avg_val_loss}")
        # wandb.log({
        #     'train_loss': train_loss.item(),
        #     'val_loss': avg_val_loss
        # })

        # Track best losses
        if train_loss.item() < best_train_loss:
            best_train_loss = train_loss.item()
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
        val_losses.append(avg_val_loss)

end_time = time.time()
train_time = end_time - start_time

Training Epochs:   1%|          | 27/5000 [00:00<00:43, 113.71it/s]

Epoch: 0, Train Loss: 4.302455425262451, Val Loss: 4.40168883562088


Training Epochs:   5%|▍         | 233/5000 [00:01<00:26, 178.27it/s]

Epoch: 200, Train Loss: 3.3392529487609863, Val Loss: 3.713762534856796


Training Epochs:   9%|▊         | 433/5000 [00:02<00:23, 191.87it/s]

Epoch: 400, Train Loss: 3.524874210357666, Val Loss: 3.321220185756683


Training Epochs:  12%|█▏        | 621/5000 [00:03<00:28, 155.48it/s]

Epoch: 600, Train Loss: 2.595158100128174, Val Loss: 3.2182006192207337


Training Epochs:  17%|█▋        | 832/5000 [00:04<00:33, 123.94it/s]

Epoch: 800, Train Loss: 2.9629464149475098, Val Loss: 3.104701887369156


Training Epochs:  20%|██        | 1023/5000 [00:05<00:31, 125.49it/s]

Epoch: 1000, Train Loss: 2.5057387351989746, Val Loss: 2.9971798968315126


Training Epochs:  25%|██▍       | 1231/5000 [00:07<00:30, 125.04it/s]

Epoch: 1200, Train Loss: 2.911221981048584, Val Loss: 2.9582037484645842


Training Epochs:  28%|██▊       | 1421/5000 [00:08<00:28, 126.85it/s]

Epoch: 1400, Train Loss: 2.732435464859009, Val Loss: 2.8974747002124785


Training Epochs:  32%|███▏      | 1619/5000 [00:09<00:30, 112.32it/s]

Epoch: 1600, Train Loss: 3.0408029556274414, Val Loss: 2.8946908688545228


Training Epochs:  36%|███▌      | 1809/5000 [00:11<00:28, 113.79it/s]

Epoch: 1800, Train Loss: 2.563445806503296, Val Loss: 2.8202089047431946


Training Epochs:  40%|████      | 2022/5000 [00:12<00:29, 101.13it/s]

Epoch: 2000, Train Loss: 3.221748113632202, Val Loss: 2.7686537432670595


Training Epochs:  45%|████▍     | 2227/5000 [00:14<00:24, 112.39it/s]

Epoch: 2200, Train Loss: 3.025428056716919, Val Loss: 2.7483159244060515


Training Epochs:  48%|████▊     | 2421/5000 [00:15<00:20, 128.58it/s]

Epoch: 2400, Train Loss: 3.222254514694214, Val Loss: 2.687283924818039


Training Epochs:  53%|█████▎    | 2629/5000 [00:16<00:18, 127.24it/s]

Epoch: 2600, Train Loss: 2.664768695831299, Val Loss: 2.705757224559784


Training Epochs:  56%|█████▋    | 2824/5000 [00:18<00:16, 133.59it/s]

Epoch: 2800, Train Loss: 2.6541106700897217, Val Loss: 2.70195885181427


Training Epochs:  61%|██████    | 3038/5000 [00:19<00:15, 127.67it/s]

Epoch: 3000, Train Loss: 2.1677467823028564, Val Loss: 2.65768724322319


Training Epochs:  65%|██████▍   | 3235/5000 [00:20<00:13, 132.50it/s]

Epoch: 3200, Train Loss: 3.0206761360168457, Val Loss: 2.640993736386299


Training Epochs:  69%|██████▊   | 3434/5000 [00:22<00:11, 135.94it/s]

Epoch: 3400, Train Loss: 2.817840814590454, Val Loss: 2.666358770132065


Training Epochs:  72%|███████▏  | 3614/5000 [00:23<00:11, 117.06it/s]

Epoch: 3600, Train Loss: 2.840182065963745, Val Loss: 2.634136815071106


Training Epochs:  76%|███████▋  | 3822/5000 [00:24<00:09, 118.45it/s]

Epoch: 3800, Train Loss: 2.2630693912506104, Val Loss: 2.6102201628684996


Training Epochs:  81%|████████  | 4035/5000 [00:26<00:07, 132.47it/s]

Epoch: 4000, Train Loss: 2.642688512802124, Val Loss: 2.6013960182666778


Training Epochs:  85%|████████▍ | 4231/5000 [00:27<00:05, 131.32it/s]

Epoch: 4200, Train Loss: 2.861732006072998, Val Loss: 2.593399704694748


Training Epochs:  89%|████████▊ | 4436/5000 [00:28<00:04, 127.38it/s]

Epoch: 4400, Train Loss: 2.387462854385376, Val Loss: 2.58214755654335


Training Epochs:  93%|█████████▎| 4636/5000 [00:30<00:02, 131.02it/s]

Epoch: 4600, Train Loss: 2.1232218742370605, Val Loss: 2.619614409208298


Training Epochs:  97%|█████████▋| 4828/5000 [00:31<00:01, 123.29it/s]

Epoch: 4800, Train Loss: 2.7502715587615967, Val Loss: 2.6101433753967287


Training Epochs: 100%|██████████| 5000/5000 [00:32<00:00, 154.55it/s]


In [22]:
print(100*'*')
# Load best losses from JSON file if it exists
best_losses_file = 'best_losses.json'
try:
    with open(best_losses_file, 'r') as f:
        best_losses = json.load(f)
        best_train_loss = best_losses.get('best_train_loss', best_train_loss)
        best_val_loss = best_losses.get('best_val_loss', best_val_loss)
except FileNotFoundError:
    best_losses = {
        'best_train_loss': best_train_loss,
        'best_val_loss': best_val_loss
    }
    with open(best_losses_file, 'w') as f:
        json.dump(best_losses, f)

****************************************************************************************************


In [23]:
print(f"Generated Text:")
idx = torch.zeros((1,1), dtype=torch.long)
generated_text = decode(model.generate(idx, max_new_tokens=2000)[0].tolist())
print(generated_text)

Generated Text:

Sroko tes nRhol:
S berme pllresepy? ollopy'de thount ary Thangic the, handes the maousd t save frkofEIs,
The nocore I witisohe l hainss r,
E flldiO,
Gomeshoulr th fimewenchdoly telelu n? cvourteve t Werotho y e laxr'nd cc.
 x.
A,n meno lobld:r ces berorele hiveho f t?
Hhe.
Jn hed b niyharo ave d pe
!anortinfory tedlhour s.
:

Nhe
 'gall t'd bef wanorehe,nngou th, iske ia
NreniSps heem, tindle ve, coanut,
Su a k bather canf, hde t fars, erobe,
TE.

DLYe.:
MBast the t ayetErene foes,h ie oser t thar nrgolt athe thoutowensurararof s Er.
Dzrve, pis, ives ngan tH:
Whoug.

D oIAC
Thos cG sheaket's nginn:
Ix
S:
OWand-- t ve anthat igt clealleD? ot we  anThecaresthses, a
KL-epo win pe ' s'dors Rurhek
Ado f Jturll y isiworr: chte berraany t singig Gos: s ausFo tall.
Ds pe, bolou r rend mit,
We f-?it!
T;

IKACO'do q'd asasthef :
Whesus; trirceali, t llaor, y t yot hou hod r qereat, r
ESdom:
L
?INGano f I thandiouc, te tharshok, be,
HB
Am theiy avem ash n, y LAgthiu tr'sd gredy f

In [24]:
# Check if current run has better losses
if best_train_loss < best_losses.get('best_train_loss', float('inf')) or best_val_loss < best_losses.get('best_val_loss', float('inf')):
    # Save generated text to file
    with open('generated_shakespeare_text.txt', 'w') as f:
        f.write(generated_text)

    # Update best losses and save to JSON file
    best_losses['best_train_loss'] = best_train_loss
    best_losses['best_val_loss'] = best_val_loss
    with open(best_losses_file, 'w') as f:
        json.dump(best_losses, f)

    # Have wandb save the text file
    wandb.save('generated_shakespeare_text.txt')
    # also save an image of the training and validation loss curves
    plt.plot(train_losses, label='train loss')
    plt.plot(val_losses, label='val loss')
    plt.legend()
    plt.savefig('train_val_loss.png')
    wandb.save('train_val_loss.png')
    print("Current run beat the best losses. Generated text saved.")

else:
    print("Current run did not beat the best losses. Generated text not saved.")
print(100*'*')
print(100*'*')

Current run did not beat the best losses. Generated text not saved.
****************************************************************************************************
****************************************************************************************************


In [25]:
print(f"Best Train Loss: {best_train_loss}")
print(f"Best Validation Loss: {best_val_loss}")
# show total number of parameters in the model
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters in the model: {total_params}")
# show toal number of tokens in the dataset
total_tokens = len(data)
print(f"Total number of tokens in the dataset: {total_tokens}")
print(f"According to Chinchilla Law, you need at least {total_params * 2} tokens to train this model.") # TODO: work on this

Best Train Loss: 2.1232218742370605
Best Validation Loss: 2.58214755654335
Total number of parameters in the model: 29697
Total number of tokens in the dataset: 1115394
According to Chinchilla Law, you need at least 59394 tokens to train this model.


In [27]:
# Ensure train_time and other parameters are defined before logging
# wandb.log({
#     'epochs': epochs,
#     "learning_rate": learning_rate,
#     "block_size": block_size,
#     "batch_size": batch_size,
#     "embedding_size": n_emb,
#     "optimizer": "AdamW",
#     "device": device,
#     "vocab_size": vocab_size,
#     "best_train_loss": best_train_loss,
#     "best_val_loss": best_val_loss,
#     'Training Time': train_time, 
#     'dropout': dropout,
#     'n_layer': n_layer,
#     'n_heads': n_heads,
#     'train_test_split': train_test_split,
#     'total_params': total_params
# })

print(f"Total time to train model up to {epochs} epochs: {train_time:.2f} seconds")
wandb.finish()

Total time to train model up to 5000 epochs: 32.37 seconds
