In [35]:
%matplotlib inline
import matplotlib.pyplot as plt

In [8]:
import torch
print(torch.__version__)

2.4.1+cu121


In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [10]:
# fetching data from hf datasets: Salesforce/wikitext
# same as the one used by Graphcore/gpt2-wikitext-103
import datasets
from datasets import load_dataset

ds = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
print(type(ds), len(ds))

<class 'datasets.arrow_dataset.Dataset'> 1801350


In [11]:
ds[100]

{'text': ' 96 ammunition packing boxes \n'}

In [12]:
import tiktoken

# yes, i'm using gpt2's tokenizer here 
# it is bpe, but not the version llama used, but should be okay ig
enc = tiktoken.get_encoding("gpt2")

encoded = enc.encode("hello world, let's goo")
decoded = enc.decode(encoded)
print(encoded, '\n', decoded)

[31373, 995, 11, 1309, 338, 467, 78] 
 hello world, let's goo


In [13]:
vocab_sz = enc.n_vocab
print(f"VOCAB SIZE: {vocab_sz}")

VOCAB SIZE: 50257


In [14]:
len(enc.encode(ds[3]['text']))

166

In [16]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import tiktoken

class WikiText103Dataset(Dataset):
    def __init__(self, split='train', seq_len=128):
        self.dataset = load_dataset("Salesforce/wikitext", "wikitext-103-v1", split=split)
        self.tokenizer = tiktoken.get_encoding("gpt2")
        self.seq_len = seq_len

    def __len__(self):
        return len(self.dataset)

    # TODO: not the most optimized, but simple and works, can use pad_sequence
    # TODO: also needs attention mask
    # TODO: also if say tokens is of len 100, and seq_len is 40, other 60 tokens are thrown away
    # TODO: use sliding window to use other tokens
    def __getitem__(self, idx):
        text = self.dataset[idx]['text']

        # skip empty entries
        # wraps to beginning for last value
        if not text.strip():
            return self. __getitem__((idx + 1) % len(self))
        
        tokens = self.tokenizer.encode(text)
        
        # we need to pull 1 more token above seq_len 
        # reason: [a,b,c] -> [b,c,d] for seq [a,b,c,d]

        # clip tokens if exceeding seq_lim + 1
        if len(tokens) > (self.seq_len + 1):
            tokens = tokens[:(self.seq_len + 1)]
        # pad with zeros if short of seq_len + 1
        else:
            tokens = tokens + [0] * ((self.seq_len + 1) - len(tokens)) 

        x = torch.tensor(tokens[:-1], dtype=torch.long)
        y = torch.tensor(tokens[1:], dtype=torch.long)
        
        return x, y

# create dataset and dataloader
dataset = WikiText103Dataset(split='train')
#dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

In [18]:
# as the original data is too big, we make a smaller subset for
# quicker iteration

subset_size = 10000
subset_indices = list(range(subset_size))
subset_dataset = torch.utils.data.Subset(dataset, subset_indices)

# Create dataloader
dataloader = DataLoader(subset_dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)

In [21]:
# sample
for batch in dataloader:
    x, y = batch
    print(x.shape, y.shape)
    print()
    for input, output in zip(x.tolist()[:5], y.tolist()[:5]):
        print("[INPUT]", enc.decode(input[:10]))
        print("[OUTPUT]", enc.decode(output[:10]))
        print("-"*10)
    break

torch.Size([32, 128]) torch.Size([32, 128])

[INPUT]  The regiment was split into three groups for the attack
[OUTPUT]  regiment was split into three groups for the attack .
----------
[INPUT]  On May 17 , 2016 , Nathan signed with the
[OUTPUT]  May 17 , 2016 , Nathan signed with the Chicago
----------
[INPUT]  = = French recovery = = 
!!
[OUTPUT]  = French recovery = = 
!!!
----------
[INPUT]  Nathan graduated from Pine Bush High School in Pine Bush
[OUTPUT]  graduated from Pine Bush High School in Pine Bush ,
----------
[INPUT]  Proper discipline for children while being careful not to provoke
[OUTPUT]  discipline for children while being careful not to provoke them
----------


In [27]:
# copping from @bkitano xd
# TODO: i'm not using an eval set for now, but when i do, below function should be modded
# check @bkitano's implem

import numpy as np

@torch.no_grad()
def evaluate_loss(model, eval_dataloader):
    out = {}
    model.eval()
    losses = []
    for batch in eval_dataloader:
        xb, yb = batch
        xb, yb = xb.to(device), yb.to(device)
        _, loss = model(xb, yb)
        losses.append(loss.item())
    out["train"] = np.mean(losses)
    model.train()
    return out

In [23]:
from torch import nn
from torch.nn import functional as F

class DummyModel(nn.Module):
    def __init__(self, vocab_sz, d_model):  # Changed __init to __init__
        # d_model: embedding dim
        super().__init__()
        self.vocab_sz = vocab_sz  # Store vocab_sz as an attribute
        self.embedding = nn.Embedding(vocab_sz, d_model)
        self.linear = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, vocab_sz),
        )

        print("model params:", sum([m.numel() for m in self.parameters()]))
    
    def forward(self, idx, targets=None):
        x = self.embedding(idx)
        logits = self.linear(x)

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, self.vocab_sz), targets.view(-1))
            return logits, loss
        else:
            return logits

# Now you can create the model
d_model = 64
model = DummyModel(vocab_sz, d_model).to(device)

model params: 6487313


In [40]:
optimizer = torch.optim.Adam(model.parameters())

import time
import pandas as pd
from tqdm import tqdm

epochs = 1
log_interval = 1

def train(model, optimizer, dataloader, epochs=10, log_interval=10, print_logs=True):
    losses = []
    total_start_time = time.time()
    for epoch in range(epochs):
        epoch_start_time = time.time()
        data_load_time = 0
        forward_time = 0
        backward_time = 0
        
        for batch_idx, (xs, ys) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=False)):
            batch_start_time = time.time()
            
            data_load_end_time = time.time()
            data_load_time += data_load_end_time - batch_start_time
            
            optimizer.zero_grad()
            xs, ys = xs.to(device), ys.to(device)
            logits, loss = model(xs, targets=ys)
            
            forward_end_time = time.time()
            forward_time += forward_end_time - data_load_end_time
            
            loss.backward()
            optimizer.step()
            
            backward_end_time = time.time()
            backward_time += backward_end_time - forward_end_time
            
            if batch_idx % log_interval == 0 or batch_idx == len(dataloader) - 1:
                batch_time = time.time() - batch_start_time
                x = evaluate_loss(model, dataloader)
                losses.append(x)
                if print_logs:
                    print(f"Epoch {epoch+1}, Batch {batch_idx}/{len(dataloader)} | "
                          f"train loss {x['train']:.3f} | Batch Time {batch_time:.3f}s | "
                          f"Data Load Time {data_load_time:.3f}s | Forward Time {forward_time:.3f}s | "
                          f"Backward Time {backward_time:.3f}s")
                data_load_time, forward_time, backward_time = 0, 0, 0
        
        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch+1} completed in {epoch_time:.2f}s")
    
    total_time = time.time() - total_start_time
    print(f"Training completed in {total_time:.2f}s")
    print("Final train loss: ", losses[-1]['train'])
    return model

In [38]:
# Warmup
print("Running warmup batch...")
warmup_xs, warmup_ys = next(iter(dataloader))
warmup_xs, warmup_ys = warmup_xs.to(device), warmup_ys.to(device)
_, warmup_loss = model(warmup_xs, targets=warmup_ys)
warmup_loss.backward()
print("Warmup complete.")

Running warmup batch...
Warmup complete.


In [41]:
model = train(model, optimizer, dataloader, epochs=1, log_interval=100, print_logs=True)

Epoch 1:   3%|▎         | 8/313 [00:02<01:16,  3.99it/s]

Epoch 1, Batch 0/313 | train loss 2.551 | Batch Time 0.019s | Data Load Time 0.000s | Forward Time 0.017s | Backward Time 0.002s


Epoch 1:  35%|███▌      | 110/313 [00:07<00:21,  9.48it/s]

Epoch 1, Batch 100/313 | train loss 2.576 | Batch Time 0.019s | Data Load Time 0.000s | Forward Time 1.766s | Backward Time 0.070s


Epoch 1:  68%|██████▊   | 212/313 [00:11<00:10,  9.45it/s]

Epoch 1, Batch 200/313 | train loss 2.554 | Batch Time 0.018s | Data Load Time 0.000s | Forward Time 1.769s | Backward Time 0.069s


Epoch 1:  98%|█████████▊| 308/313 [00:15<00:00,  9.40it/s]

Epoch 1, Batch 300/313 | train loss 2.525 | Batch Time 0.019s | Data Load Time 0.000s | Forward Time 1.776s | Backward Time 0.061s


                                                          

Epoch 1, Batch 312/313 | train loss 2.522 | Batch Time 0.019s | Data Load Time 0.000s | Forward Time 0.193s | Backward Time 0.011s
Epoch 1 completed in 18.42s
Training completed in 18.42s
Final train loss:  2.522393282990867




In [42]:
def generate_text(model, tokenizer, prompt, max_length=100, temperature=1.0):
    model.eval()
    tokens = tokenizer.encode(prompt)
    input_ids = torch.tensor(tokens).unsqueeze(0).to(device)
    
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(input_ids)
            next_token_logits = outputs[:, -1, :] / temperature
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            
            if next_token.item() == tokenizer.eot_token:
                break
    
    generated_tokens = input_ids[0].tolist()
    generated_text = tokenizer.decode(generated_tokens)
    
    return generated_text

# Example usage
prompt = "We live in a world where"
generated_text = generate_text(model, dataset.tokenizer, prompt, max_length=100, temperature=0.8)
print(f"Prompt: {prompt}")
print(f"Generated text: {generated_text}")

Prompt: We live in a world where
Generated text: We live in a world where the first time , and the first time , and the first time , and the first time , and the first time , and the first time , and the first time , and the first time , and the first time , and the first time , and the first time , and the first time , and the first time , and the first time , and the first time , and the first time , and the first time , and the first time , and the first time , and the first time , and
