# GPT2 Pre training from scratch

In [1]:
# imports
import torch
import torch.nn as nn
from gpt2 import GPT2
from utils import generate_text_simple, text_to_token_ids, token_ids_to_text, create_dataloader
# device = torch.device('cuda') if torch.cuda.is_available else torch.device('cpu')
device = torch.device('cpu')
print('default device:', device)

# tokenizer
import tiktoken
tokenizer = tiktoken.get_encoding('gpt2')

default device: cpu


### Model Initialized

In [2]:
# reduced orignal GPT2 context_len 1024 to 256, to fit in my laptop memeory
my_model_conf = {
    'vocab_size': 50257,
    'context_len': 256,
    'emb_dim': 768,
    'n_heads': 12,
    'n_layers': 12,
    'drop_rate': 0.1,
    'qkv_bias': False,
}
model = GPT2(my_model_conf)
model.to(device)
model.eval()

GPT2(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(256, 768)
  (drop_emb): Dropout(p=0.1, inplace=False)
  (trf_block): Sequential(
    (0): TransformerLayer(
      (attn): MHSA(
        (W_query): Linear(in_features=768, out_features=768, bias=False)
        (W_key): Linear(in_features=768, out_features=768, bias=False)
        (W_value): Linear(in_features=768, out_features=768, bias=False)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (norm1): LayerNorm()
      (norm2): LayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerLayer(
      (attn): MHSA(
        (W_query): Linear(in_features=768, out_features=768, bias=False)
 

### Download small datasets

In [3]:
import os
import urllib.request

file_path = "the-verdict.txt"
url = "https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt"

if not os.path.exists(file_path):
    with urllib.request.urlopen(url) as response:
        text_data = response.read().decode('utf-8')
    with open(file_path, "w", encoding="utf-8") as file:
        file.write(text_data)
else:
    with open(file_path, "r", encoding="utf-8") as file:
        text_data = file.read()

In [4]:
ratio = 0.9
split_idx = int(len(text_data.split()) * ratio)
train_data, val_data = text_data.split()[:split_idx], text_data.split()[split_idx:]
print(len(train_data), len(val_data))

train_data = [' '.join(train_data)] # encapsulated in list since it's a single sentence text
val_data = [' '.join(val_data)]

3270 364


In [5]:
# pytorch dataloader
train_dataloader = create_dataloader(train_data, 
                                     tokenizer, 
                                     batch_size = 2, 
                                     max_len = my_model_conf['context_len'], 
                                     stride = 1, 
                                     shuffle = True, 
                                     drop_last=True, 
                                     num_workers = 0)

val_dataloader = create_dataloader(val_data, 
                                     tokenizer, 
                                     batch_size = 2, 
                                     max_len = my_model_conf['context_len'], 
                                     stride = 1, 
                                     shuffle = False, 
                                     drop_last=True, 
                                     num_workers = 0)

In [6]:
for x, y in val_dataloader:
    print(token_ids_to_text(x[0], tokenizer))
    print(token_ids_to_text(y[0], tokenizer))
    print(x.shape, y.shape)
    break

for x, y in train_dataloader:
    print(x.shape, y.shape)
    break

his lips, through the gray beard, I seemed to hear the question: 'Are you sure you know where you're coming out?' "If I could have painted that face, with that question on it, I should have done a great thing. The next greatest thing was to see that I couldn't--and that grace was given me. But, oh, at that minute, Rickham, was there anything on earth I wouldn't have given to have Stroud alive before me, and to hear him say: 'It's not too late--I'll show you how'? "It _was_ too late--it would have been, even if he'd been alive. I packed up my traps, and went down and told Mrs. Stroud. Of course I didn't tell her _that_--it would have been Greek to her. I simply said I couldn't paint him, that I was too moved. She rather liked the idea--she's so romantic! It was that that made her give me the donkey. But she was terribly upset at not getting the portrait--she did so want him 'done' by some one showy! At first I was afraid she wouldn't let me off--and at my wits' end I suggested Grind
 li

In [7]:
def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch = input_batch.to(device)
    target_batch = target_batch.to(device)
    logits = model(input_batch)
    loss = nn.functional.cross_entropy(logits.flatten(0,1), target_batch.flatten())
    return loss


In [8]:
def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0
    
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(len(data_loader), num_batches)
        
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches
        
    

In [9]:
# # loss calculation
# with torch.no_grad():
#     train_loss = calc_loss_loader(train_dataloader, model, device)
#     val_loss = calc_loss_loader(val_dataloader, model, device)
# print(f'Training Loss: {train_loss}')
# print(f'Validation Loss: {val_loss}')

In [10]:
def train_model_simple(model, train_dataloader, val_dataloader,
                       optimizer, device, num_epoch, eval_freq,
                       eval_iter, start_context, tokenizer):
    train_losses, val_losses, track_token_seen =[], [], []
    token_seen, global_setp = 0, -1
    
    for epoch in range(num_epoch):
        model.train()
        for i, (input_batch, target_batch) in enumerate(train_dataloader):
            optimizer.zero_grad()
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()
            optimizer.step()
            
            token_seen += input_batch.numel()
            global_setp += 1
            
            if global_setp % eval_freq == 0:
                train_loss, val_loss = evaluate_model(model, train_dataloader, val_dataloader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_token_seen.append(token_seen)
                print(f"Epoch {epoch+1} (step {global_setp:06d}) : "
                      f"Train Loss:{train_loss:.3f}, "
                      f"Val Loss:{val_loss:.3f}")
        
                generate_and_print_samples(model, tokenizer, device, start_context)
    
    return train_losses, val_losses, track_token_seen


def evaluate_model(model, train_dataloader, val_dataloader, device, eval_iter):
    model.eval()
    with torch.no_grad():
        train_loss  = calc_loss_loader(train_dataloader, model, device, eval_iter)
        val_loss    = calc_loss_loader(val_dataloader, model, device, eval_iter)
    model.train()
    return train_loss, val_loss

def generate_and_print_samples(model, tokenizer, device, start_context):
    model.eval()
    context_size = model.pos_emb.weight.shape[0]
    encoded = text_to_token_ids(start_context, tokenizer).to(device)
    with torch.no_grad():
        token_ids = generate_text_simple(model=model, idx=encoded, max_new_tokens=50, context_size=context_size)
    decoded_text = token_ids_to_text(token_ids, tokenizer)
    print(decoded_text.strip())
    model.train()

In [11]:
len(train_dataloader)

2117

In [17]:
# model training
LR = 0.0004
MAX_EPOCHS = 20

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.1)
train_model_simple(model, train_dataloader, val_dataloader, optimizer, device, MAX_EPOCHS, eval_freq=100, eval_iter=5, start_context="Every Effort moves you", tokenizer=tokenizer)

Epoch 1 (step 000000) : Train Loss:5.957, Val Loss:7.090


KeyboardInterrupt: 