In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader, Subset
import tqdm
import json
import datasets
from typing import List
import os
import tiktoken
import inspect
from sentencepiece import SentencePieceProcessor

In [2]:
os.makedirs('data/', exist_ok=True)

In [3]:
# encoding = tiktoken.get_encoding("gpt2")

In [4]:
# encoding.n_vocab

In [5]:
# !wget https://huggingface.co/OmAlve/TinyStories-SmolGPT/resolve/main/tok4096.model -P data/

In [6]:
class Tokenizer:
    def __init__(self, tokenizer_model="gpt2"):
        self.enc = tiktoken.get_encoding(tokenizer_model)
        self.tokenizer_model = tokenizer_model

        self.n_words = self.enc.n_vocab
        self.bos_id = None
        self.eos_id = self.enc.eot_token
        self.pad_id = None

    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        t = self.enc.encode(s)
        if bos and self.bos_id is not None:
            t = [self.bos_id] + t
        if eos and self.eos_id is not None:
            t = t + [self.eos_id]
        return t

    def decode(self, tokens: List[int]) -> str:
        return self.enc.decode(tokens)

In [7]:
# TOKENIZER_MODEL = "./data/tok4096.model"
TOKENIZER_MODEL = "gpt2"

In [8]:
# class Tokenizer:
#     def __init__(self, tokenizer_model):
#         model_path = tokenizer_model if tokenizer_model else TOKENIZER_MODEL
#         self.sp_model = SentencePieceProcessor(model_file=model_path)
#         self.model_path = model_path

#         self.n_words = self.sp_model.vocab_size()
#         self.bos_id = self.sp_model.bos_id()
#         self.eos_id = self.sp_model.eos_id()
#         self.pad_id = self.sp_model.pad_id()

#     def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
#         t = self.sp_model.encode(s)
#         if bos:
#             t = [self.bos_id] + t
#         if eos:
#             t = t + [self.eos_id]
#         return t

#     def decode(self, tokens: List[int]) -> str:
#         return self.sp_model.decode(tokens)

In [9]:
tokenizer = Tokenizer(tokenizer_model=TOKENIZER_MODEL)

In [10]:
tokenizer.n_words

50257

In [11]:
vocab_size = tokenizer.n_words
batch_size = 12
block_size = 512
max_iters = 1
eval_interval = 1000
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 256
n_embd = 512
n_head = 12
n_layer = 12
dropout = 0.3

target_batch_size = 8192
gradient_accumulation_steps = target_batch_size // batch_size
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95

In [12]:
gradient_accumulation_steps

682

In [13]:
torch.set_float32_matmul_precision('high')

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")

In [15]:
def encode(s): return tokenizer.encode(s, bos=False, eos=False)

def decode(l):
	try:
		return tokenizer.decode(l)
	except:
		return ""

In [16]:
# ds = datasets.load_dataset("allenai/c4", "realnewslike")

ds = datasets.load_dataset("roneneldan/TinyStories")

In [17]:
ds = ds.with_format("torch")

In [18]:
ds['train'][1]

{'text': 'Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong.\n\nOne day, Beep was driving in the park when he saw a big tree. The tree had many leaves that were falling. Beep liked how the leaves fall and wanted to play with them. Beep drove under the tree and watched the leaves fall on him. He laughed and beeped his horn.\n\nBeep played with the falling leaves all day. When it was time to go home, Beep knew he needed more fuel. He went to the fuel place and got more healthy fuel. Now, Beep was ready to go fast and play again the next day. And Beep lived happily ever after.'}

In [19]:
def collate_fn(batch):
    texts = [encode(item['text'])[:block_size] for item in batch]
    padded_texts = [t + [0] * (block_size - len(t)) for t in texts]
    return {
        'text': torch.tensor(padded_texts, dtype=torch.long)
    }

In [20]:
eval_iters

256

In [21]:
len(ds['train'])

2119719

In [22]:
subset_indices = list(range(eval_iters))
# train_indices = list(range(8000000))
# dataset_train = Subset(ds['train'], train_indices)
dataset_valid = Subset(ds['validation'], subset_indices)

In [23]:
train_dataloader = DataLoader(ds['train'], batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
valid_dataloader = DataLoader(dataset_valid, batch_size=batch_size, collate_fn=collate_fn)

In [24]:
class Head(nn.Module):
    """ one head of self-attention """
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(
            torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B, T, C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x)  # (B,T,hs)
        # compute attention scores ("affinities")
        # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5
        wei = wei.masked_fill(
            self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x)  # (B,T,hs)
        out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out


class RMSNorm(nn.Module):
    """ Root Mean Square Normalization """
    def __init__(self, embed_dim, epsilon=1e-8):
        super().__init__()
        self.embed_dim = embed_dim
        self.epsilon = epsilon
        self.gamma = nn.Parameter(torch.ones(embed_dim))

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.epsilon)
        out = x / rms
        out = self.gamma * out
        return out


class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class GPTLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.ModuleList([Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self.token_embedding_table.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx)  # (B,T,C)
        pos_emb = self.position_embedding_table(
            torch.arange(T, device=device))  # (T,C)
        x = tok_emb + pos_emb  # (B,T,C)
        for block in self.blocks:
            x = block(x)  # (B,T,C)
        x = self.ln_f(x)  # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=50, top_p=0.9, repetition_penalty=1.2):
        self.eval()
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]

            if repetition_penalty != 1.0:
                for i in range(idx.shape[0]):
                    for token in set(idx[i].tolist()):
                        if logits[i, token] < 0:  
                            logits[i, token] *= repetition_penalty  
                        else:  
                            logits[i, token] /= repetition_penalty  

            logits = logits / temperature
            probs = F.softmax(logits, dim=-1)

            if top_k is not None:
                top_k_values, top_k_indices = torch.topk(probs, top_k, dim=-1)
                mask = torch.zeros_like(probs, dtype=torch.bool).scatter_(1, top_k_indices, True)
                probs = torch.where(mask, probs, torch.tensor(0.0, device=probs.device))

            if top_p is not None:
                sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                sorted_indices_to_remove[:, 0] = 0
                for i in range(probs.size(0)):
                    probs[i][sorted_indices[i][sorted_indices_to_remove[i]]] = 0

            next_token = torch.multinomial(probs, num_samples=1).squeeze(-1)
            idx = torch.cat((idx, next_token.unsqueeze(1)), dim=-1)

        return idx
        

In [25]:
torch.cuda.empty_cache()

In [26]:
model = GPTLanguageModel()

if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)

model = model.to(device)
# model = torch.compile(model)
# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

63.60832 M parameters


In [27]:
# fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
# use_fused = fused_available and 'cuda' == str(device)
# print(f"{use_fused=}")

In [28]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2), eps=1e-8)

In [29]:
T_max = len(train_dataloader)
warmup_steps = 0.01 * T_max  # 1% of total training
scheduler = lr_scheduler.OneCycleLR(
    optimizer, max_lr=6e-4, total_steps=T_max, pct_start=0.01
)

# scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10000, eta_min=6e-6)

In [30]:
# eval_interval = len(train_dataloader) // 10
# eval_interval

In [31]:
os.makedirs("ckpt/", exist_ok=True)

In [32]:
str(device)

'cuda'

In [33]:
sample = tokenizer.decode(tokenizer.encode(ds["train"][0]["text"][:100], bos=True, eos=True))
sample

'One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with<|endoftext|>'

In [34]:
def generate(model, idx, max_new_tokens):
    for _ in range(max_new_tokens):
        # crop idx to the last block_size tokens
        idx_cond = idx[:, -block_size:]
        # get the predictions
        logits, loss = model(idx_cond)
        # focus only on the last time step
        logits = logits[:, -1, :]  # becomes (B, C)
        # apply softmax to get probabilities
        probs = F.softmax(logits, dim=-1)  # (B, C)
        # sample from the distribution
        idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
        # append sampled index to the running sequence
        idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
    return idx

In [None]:
gradient_accumulation_steps, batch_size, target_batch_size

(682, 12, 8192)

In [36]:
with open("losses.txt", "w") as f:
	f.write("Training Loss,Validation Loss,Output\n")

In [None]:
for iter, batch in enumerate(tqdm.notebook.tqdm(train_dataloader, total=len(train_dataloader))):
    inputs, targets = batch['text'], batch['text']
    inputs, targets = inputs.to(device), targets.to(device)

    with torch.autocast(device_type=str(device), dtype=torch.bfloat16):
        logits, loss = model(inputs, targets)

    loss = loss / gradient_accumulation_steps
    loss.backward()

    norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    
    if (iter + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

    if iter % (gradient_accumulation_steps * 2) == 0 or iter == max_iters - 1:
        print(f"Step {iter}: Performing validation")
        model.eval()
        with torch.no_grad():
            val_loss = 0
            train_loss = loss.item() * gradient_accumulation_steps
            for batch in tqdm.notebook.tqdm(valid_dataloader, total=len(valid_dataloader)):
                inputs, targets = batch['text'], batch['text']
                inputs, targets = inputs.to(device), targets.to(device)
                _, loss = model(inputs, targets)
                val_loss += loss.item()

            torch.save(model.state_dict(), f"ckpt/ckpt_{iter}.pt")
            print(f"Train loss: {train_loss:.4f}")
            print(f"Validation loss: {val_loss / len(valid_dataloader):.4f}")

            prompt = "Hello I am "
            prompt = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
            output = decode(generate(model, prompt, max_new_tokens=50)[0].tolist())
            print(output)
            with open("losses.txt", "a") as f:
                f.write(f"{train_loss},{val_loss / len(valid_dataloader)},\"{output}\"\n")
        model.train()

In [38]:
torch.save(model.state_dict(), "final_model_tiny_stories.pt")

In [39]:
# # model = GPTLanguageModel()
# # model = model.to(device)
# model.load_state_dict(torch.load("/kaggle/working/ckpt/ckpt_5625.pt", weights_only=True))

# # model.eval()
# # model.to('cpu')

In [40]:
model = GPTLanguageModel()
model.load_state_dict(torch.load("final_model_tiny_stories.pt", map_location=device))
model = model.to(device)

In [41]:
model = model.eval()

In [42]:
prompt = "There was a girl who"

prompt = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
print(decode(generate(model, prompt, max_new_tokens=50)[0].tolist()))

There was a girl who who who who Adobe plnex sore definite " "".!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!


In [43]:
prompt = "Today a town in"

prompt = torch.tensor([encode(prompt)], dtype=torch.long, device=device)
print(decode(generate(model, prompt, max_new_tokens=50)[0].tolist()))

Today a town in in inExperience MendNull flashlightPed Visit Thy dismasses SQLwinning shook Mitt Tem pilgrimage MOT Reggie tireRegistered kids drive even grass sounds sure but pid Sarah counties werenouemo thick SR OutlookdemocraticLTSHA conj juicyneyCredit LorenzoJÉ fundra constitutedワ


In [None]:
# model.to('cpu')

In [None]:
# torch.save(model.state_dict(), "final_model.pt")