In [1]:
import torch
print(torch.__version__)

2.4.1+cu121


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
# fetching data from hf datasets: Salesforce/wikitext
# same as the one used by Graphcore/gpt2-wikitext-103
import datasets
from datasets import load_dataset

ds = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1", split='train')
print(type(ds), len(ds))

  from .autonotebook import tqdm as notebook_tqdm


<class 'datasets.arrow_dataset.Dataset'> 1801350


In [6]:
ds[100]

{'text': ' 96 ammunition packing boxes \n'}

In [7]:
import tiktoken

# yes, i'm using gpt2's tokenizer here 
# it is bpe, but not the version llama used, but should be okay ig
enc = tiktoken.get_encoding("gpt2")

encoded = enc.encode("hello world, let's goo")
decoded = enc.decode(encoded)
print(encoded, '\n', decoded)

[31373, 995, 11, 1309, 338, 467, 78] 
 hello world, let's goo


In [12]:
vocab_sz = enc.n_vocab
print(f"VOCAB SIZE: {vocab_sz}")

VOCAB SIZE: 50257


In [8]:
len(enc.encode(ds[3]['text']))

166

In [16]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import tiktoken

class WikiText103Dataset(Dataset):
    def __init__(self, split='train', seq_len=128):
        dataset = load_dataset("Salesforce/wikitext", "wikitext-103-v1", split=split)
        self.tokenizer = tiktoken.get_encoding("gpt2")
        self.seq_len = seq_len
        
        # Filter out empty entries
        self.dataset = [item for item in dataset if item['text'].strip()]

    def __len__(self):
        return len(self.dataset)

    # TODO: not the most optimized, but simple and works, can use pad_sequence
    # TODO: also needs attention mask
    def __getitem__(self, idx):
        text = self.dataset[idx]['text']
        tokens = self.tokenizer.encode(text)
        
        # we need to pull 1 more token above seq_len 
        # reason: [a,b,c] -> [b,c,d] for seq [a,b,c,d]

        # clip tokens if exceeding seq_lim + 1
        if len(tokens) > (self.seq_len + 1):
            tokens = tokens[:(self.seq_len + 1)] 
        # pad with zeros if short of seq_len + 1
        else:
            tokens = tokens + [0] * ((self.seq_len + 1) - len(tokens)) 

        x = torch.tensor(tokens[:-1], dtype=torch.long)
        y = torch.tensor(tokens[1:], dtype=torch.long)
        
        return x, y

# create dataset and dataloader
dataset = WikiText103Dataset(split='train')
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

In [17]:
# sample
for batch in dataloader:
    x, y = batch
    print(x.shape, y.shape)
    print()
    for input, output in zip(x.tolist()[:5], y.tolist()[:5]):
        print("[INPUT]", enc.decode(input[:10]))
        print("[OUTPUT]", enc.decode(output[:10]))
        print("-"*10)
    break

torch.Size([16, 128]) torch.Size([16, 128])

[INPUT]  As a writer and director of a number of science
[OUTPUT]  a writer and director of a number of science fiction
----------
[INPUT]  In the meantime , at about 11 : 00 am
[OUTPUT]  the meantime , at about 11 : 00 am ,
----------
[INPUT]  In September 2010 , Entertainment Weekly reported that based on
[OUTPUT]  September 2010 , Entertainment Weekly reported that based on reviews
----------
[INPUT]  Schmerber submitted an appeal to the Supreme Court
[OUTPUT] merber submitted an appeal to the Supreme Court of
----------
[INPUT]  Meanwhile , she had begun a friendship with Berkman
[OUTPUT]  , she had begun a friendship with Berkman ,
----------


In [22]:
# copping from @bkitano xd
# TODO: i'm not using an eval set for now, but when i do, below function should be modded
# check @bkitano's implem

import numpy as np

@torch.no_grad()
def evaluate_loss(model):
    out = {}
    model.eval()
    losses = []
    for batch in dataloader:
        xb, yb = batch
        xb, yb = xb.to(device), yb.to(device)
        _, loss = model(xb, yb)
        losses.append(loss.item())
    out["train"] = np.mean(losses)
    model.train()
    return out

In [24]:
from torch import nn
from torch.nn import functional as F

class DummyModel(nn.Module):
    def __init__(self, vocab_sz, d_model):  # Changed __init to __init__
        # d_model: embedding dim
        super().__init__()
        self.vocab_sz = vocab_sz  # Store vocab_sz as an attribute
        self.embedding = nn.Embedding(vocab_sz, d_model)
        self.linear = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, vocab_sz),
        )

        print("model params:", sum([m.numel() for m in self.parameters()]))
    
    def forward(self, idx, targets=None):
        x = self.embedding(idx)
        logits = self.linear(x)

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, self.vocab_sz), targets.view(-1))
            return logits, loss
        else:
            return logits

# Now you can create the model
d_model = 64
model = DummyModel(vocab_sz, d_model).to(device)

model params: 6487313


In [25]:
optimizer = torch.optim.Adam(model.parameters())

import time
import pandas as pd
from tqdm import tqdm

epochs = 1
log_interval = 1

def train(model, optimizer, dataloader, epochs=10, log_interval=100, print_logs=False):
    losses = []
    for epoch in tqdm(range(epochs), desc="Epochs"):
        start_time = time.time()
        for batch_idx, (xs, ys) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}", leave=False)):
            optimizer.zero_grad()

            xs, ys = xs.to(device), ys.to(device)
            
            logits, loss = model(xs, targets=ys)
            loss.backward()
            optimizer.step()
            
            if batch_idx % log_interval == 0:
                batch_time = time.time() - start_time
                x = evaluate_loss(model)
                losses.append(x)
                if print_logs:
                    print(f"Epoch {epoch+1}, Batch {batch_idx} | train loss {x['train']:.3f} | Time {batch_time:.3f}")
                start_time = time.time()

    print("Final train loss: ", losses[-1]['train'])
    return pd.DataFrame(losses).plot()

train(model, optimizer, dataloader, epochs=10, log_interval=100, print_logs=True)

Epochs:   0%|          | 0/10 [01:53<?, ?it/s]


KeyboardInterrupt: 