In [1]:
!pip install transformers wandb datasets



In [2]:
import math

import torch

import torch.nn as nn

from torch.nn import functional as F

from dataclasses import dataclass

import numpy as np

from tqdm.auto import tqdm

from contextlib import nullcontext

import wandb

import os

from torch.utils.data import Dataset

from transformers import AutoTokenizer

import datasets

from datasets import load_dataset

In [3]:
# Define LayerNorm class

class LayerNorm(nn.Module):

    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):

        super().__init__()

        self.weight = nn.Parameter(torch.ones(ndim))

        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None



    def forward(self, input):

        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

In [4]:
# Define CausalSelfAttention class

class CausalSelfAttention(nn.Module):

    def __init__(self, config):

        super().__init__()

        assert config.n_embd % config.n_head == 0

        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)

        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)

        self.resid_dropout = nn.Dropout(config.dropout)

        self.n_head = config.n_head

        self.n_embd = config.n_embd

        self.dropout = config.dropout

        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

        if not self.flash:

            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")

            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))



    def forward(self, x):

        B, T, C = x.size()

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        if self.flash:

            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)

        else:

            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))

            att = F.softmax(att, dim=-1)

            att = self.attn_dropout(att)

            y = att @ v

        y = y.transpose(1, 2).contiguous().view(B, T, C)

        y = self.resid_dropout(self.c_proj(y))

        return y

In [5]:
# Define LocalSelfAttention class

class LocalSelfAttention(nn.Module):

    def __init__(self, config, window_size = 128):

        super().__init__()

        assert config.n_embd % config.n_head == 0

        self.window_size = window_size

        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)

        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.attn_dropout = nn.Dropout(config.dropout)

        self.resid_dropout = nn.Dropout(config.dropout)

        self.n_head = config.n_head

        self.n_embd = config.n_embd

        self.dropout = config.dropout



    def forward(self, x):

        B, T, C = x.size()

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        y = torch.zeros_like(q)

        for i in range(0, T, self.window_size):

            end = min(i + self.window_size, T)

            q_window = q[:, :, i:end, :]

            k_window = k[:, :, i:end, :]

            v_window = v[:, :, i:end, :]

            att = (q_window @ k_window.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))

            att = F.softmax(att, dim=-1)

            att = self.attn_dropout(att)

            y[:, :, i:end, :] = att @ v_window

        y = y.transpose(1, 2).contiguous().view(B, T, C)

        y = self.resid_dropout(self.c_proj(y))

        return y


In [6]:
# Define MLP class

class MLP(nn.Module):

    def __init__(self, config):

        super().__init__()

        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)

        self.gelu = nn.GELU()

        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)

        self.dropout = nn.Dropout(config.dropout)



    def forward(self, x):

        x = self.c_fc(x)

        x = self.gelu(x)

        x = self.c_proj(x)

        x = self.dropout(x)

        return x

In [7]:
# Define Block class

class Block(nn.Module):

    def __init__(self, config):

        super().__init__()

        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)

        self.attn = CausalSelfAttention(config)

        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)

        self.mlp = MLP(config)



    def forward(self, x):

        x = x + self.attn(self.ln_1(x))

        x = x + self.mlp(self.ln_2(x))

        return x


In [8]:
# Define LocalAttentionBlock class

class LocalAttentionBlock(nn.Module):

    def __init__(self, config):

        super().__init__()

        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)

        self.attn = LocalSelfAttention(config)

        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)

        self.mlp = MLP(config)



    def forward(self, x):

        x = x + self.attn(self.ln_1(x))

        x = x + self.mlp(self.ln_2(x))

        return x

In [9]:
# Define GPTConfig class

@dataclass

class GPTConfig:

    block_size: int = 2048

    vocab_size: int = 50257

    n_layer: int = 4

    n_head: int = 12

    n_embd: int = 768

    dropout: float = 0.0

    bias: bool = True

    window_size: int = 128

In [10]:
# Define GPT class

class GPT(nn.Module):

    def __init__(self, config):

        super().__init__()

        assert config.vocab_size is not None

        assert config.block_size is not None

        self.config = config

        self.transformer = nn.ModuleDict(dict(

            wte = nn.Embedding(config.vocab_size, config.n_embd),

            wpe = nn.Embedding(config.block_size, config.n_embd),

            drop = nn.Dropout(config.dropout),

            h = nn.ModuleList([Block(config) if i % 2 == 0 else LocalAttentionBlock(config) for i in range(config.n_layer)]),

            ln_f = LayerNorm(config.n_embd, bias=config.bias),

        ))

        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

        for pn, p in self.named_parameters():

            if pn.endswith('c_proj.weight'):

                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))



    def get_num_params(self, non_embedding=True):

        n_params = sum(p.numel() for p in self.parameters())

        if non_embedding:

            n_params -= self.transformer.wpe.weight.numel()

        return n_params



    def _init_weights(self, module):

        if isinstance(module, nn.Linear):

            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

            if module.bias is not None:

                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):

            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)



    def forward(self, idx, targets=None):

        device = idx.device

        b, t = idx.size()

        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        pos = torch.arange(0, t, dtype=torch.long, device=device)

        tok_emb = self.transformer.wte(idx)

        pos_emb = self.transformer.wpe(pos)

        x = self.transformer.drop(tok_emb + pos_emb)

        for block in self.transformer.h:

            x = block(x)

        x = self.transformer.ln_f(x)

        if targets is not None:

            logits = self.lm_head(x)

            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        else:

            logits = self.lm_head(x[:, [-1], :])

            loss = None

        return logits, loss



    @torch.no_grad()

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):

        for _ in range(max_new_tokens):

            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]

            logits, _ = self(idx_cond)

            logits = logits[:, -1, :] / temperature

            if top_k is not None:

                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))

                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)

            idx_next = torch.multinomial(probs, num_samples=1)

            idx = torch.cat((idx, idx_next), dim=1)

        return idx

In [11]:
config = GPTConfig

In [12]:
from kaggle_secrets import UserSecretsClient

# Load API keys https://www.kaggle.com/discussions/product-feedback/114053

secret_label = "WANDB_API_KEY" 

my_secret = UserSecretsClient().get_secret(secret_label)

wandb_config = {k:v for k,v in vars(config).items() if not callable(getattr(config, k)) and not k.startswith("__")} # Creating Wandb hyperparameters config for tracking experiements

wandb.login(key=my_secret)

run = wandb.init(project="gptneo-33M-tinystories", name="tinystories33M-3", config=wandb_config)
wandb_config

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33manirudhr1201[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112775377777476, max=1.0…

{'block_size': 2048,
 'vocab_size': 50257,
 'n_layer': 4,
 'n_head': 12,
 'n_embd': 768,
 'dropout': 0.0,
 'bias': True,
 'window_size': 128}

In [13]:
# Load dataset

import datasets

from datasets import load_dataset

ds = load_dataset("roneneldan/TinyStories")

README.md:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

(…)-00000-of-00004-2d5a1467fff1081b.parquet:   0%|          | 0.00/249M [00:00<?, ?B/s]

(…)-00001-of-00004-5852b56a2bd28fd9.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00002-of-00004-a26307300439e943.parquet:   0%|          | 0.00/246M [00:00<?, ?B/s]

(…)-00003-of-00004-d243063613e5a057.parquet:   0%|          | 0.00/248M [00:00<?, ?B/s]

(…)-00000-of-00001-869c898b519ad725.parquet:   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2119719 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/21990 [00:00<?, ? examples/s]

In [14]:
type(ds)

datasets.dataset_dict.DatasetDict

In [15]:
# Initialize tokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")

tokenizer_config.json:   0%|          | 0.00/727 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/357 [00:00<?, ?B/s]

In [16]:
def save_tokenized_dataset(tokenized_dataset, filename):

    with open(filename, 'wb') as f:

        torch.save(tokenized_dataset, f)

In [17]:
def load_tokenized_dataset(filename):

    with open(filename, 'rb') as f:

        return torch.load(f)

In [18]:
# Define TokenizedDataset class

import time

import os

start_time = time.time()

class TokenizedDataset(Dataset):

    def __init__(self, dataset):

        self.dataset = dataset.map(

            self.process,

            remove_columns=['text'],

            batched=True,

            batch_size=1000,

            num_proc=os.cpu_count()

        )

        self.data = []

        total_splits = len(self.dataset)

        for i, (split, dset) in enumerate(self.dataset.items()):

            self.data.extend(dset)

            #print(f"Tokenized {i+1}/{total_splits} splits ({(i+1)/total_splits*100:.2f}%)")



    def process(self, examples):

        total_examples = len(examples['text'])

        ids = []

        lengths = []

        for i, text in enumerate(examples['text']):

            ids.append(tokenizer.encode(text))

            lengths.append(len(ids[-1]))

            #if (i + 1) % 100 == 0 or (i + 1) == total_examples:

            #    print(f"Tokenized {i + 1}/{total_examples} texts. Estimated time remaining: {((total_examples - (i + 1)) / (i + 1)) * (time.time() - start_time):.2f} seconds.")

        out = {'ids': ids, 'len': lengths}

        return out



    def __len__(self):

        return len(self.data)



    def __getitem__(self, idx):

        if torch.is_tensor(idx):

            idx = idx.tolist()

        if idx < 0 or idx >= len(self.data):

            raise IndexError("Index out of range")

        return self.data[idx]



if not os.path.exists("train.bin"):

    tokenized_dataset = TokenizedDataset(ds)

    #save_tokenized_dataset(tokenized_dataset, "tokenized_dataset.pt")

    for split, dset in tokenized_dataset.dataset.items():

        arr_len = np.sum(dset['len'], dtype=np.uint64)

        filename = f'{split}.bin'

        dtype = np.uint16

        arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,))

        total_batches = 1024

        idx = 0

        batch_size = len(dset) // total_batches

        for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):

            start_idx = batch_idx * batch_size

            end_idx = start_idx + batch_size

            batch = dset.select(range(start_idx, end_idx)).with_format('numpy')

            arr_batch = np.concatenate(batch['ids'])

            arr[idx : idx + len(arr_batch)] = arr_batch

            idx += len(arr_batch)

        arr.flush()

#else:

#    tokenized_dataset = load_tokenized_dataset("tokenized_dataset.pt")

Map (num_proc=4):   0%|          | 0/2119719 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/21990 [00:00<?, ? examples/s]

writing train.bin:   0%|          | 0/1024 [00:00<?, ?it/s]

writing validation.bin:   0%|          | 0/1024 [00:00<?, ?it/s]

In [19]:
# Define get_batch function

def get_batch(split):

    if split == 'train':

        data = np.memmap('train.bin', dtype=np.uint16, mode='r')

    else:

        data = np.memmap('validation.bin', dtype=np.uint16, mode='r')

    ix = torch.randint(len(data) - block_size, (batch_size,))

    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])

    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])

    if device_type == 'cuda':

        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)

    else:

        x, y = x.to(device), y.to(device)

    return x, y

In [20]:
# Define estimate_loss function

def estimate_loss(model):

    out = {}

    model.eval()

    with torch.inference_mode():

        for split in ['train', 'val']:

            losses = torch.zeros(eval_iters)

            for k in range(eval_iters):

                X, Y = get_batch(split)

                with ctx:

                    logits, loss = model(X, Y)

                losses[k] = loss.item()

            out[split] = losses.mean()

    model.train()

    return out

In [21]:
# Training configuration

learning_rate = 1e-5

max_iters = 6000

warmup_steps = 300

min_lr = 5e-6

eval_iters = 100

batch_size = 8

block_size = 1024

gradient_accumulation_steps = 20

device = "cuda" if torch.cuda.is_available() else "cpu"

device_type = 'cuda' if 'cuda' in device else 'cpu'

dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'

ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]

ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

torch.set_default_device(device)

torch.manual_seed(42)


In [22]:
# Initialize model and optimizer
print(config)
model = GPT(config)


from torch.optim.lr_scheduler import LinearLR, SequentialLR, CosineAnnealingLR
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-9)

scheduler_warmup = LinearLR(optimizer, total_iters=warmup_steps)

scheduler_decay = CosineAnnealingLR(optimizer, T_max=max_iters - warmup_steps, eta_min=min_lr)

scheduler = SequentialLR(optimizer, schedulers=[scheduler_warmup, scheduler_decay], milestones=[warmup_steps])

scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

<class '__main__.GPTConfig'>
number of parameters: 66.95M


  scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))


In [23]:
# Training loop

best_val_loss = float('inf')

best_model_params_path = "best_model_params.pt"

final_model_params_path = "final_model_params.pt"

train_loss_list, validation_loss_list = [], []

In [24]:
import os
import torch

# Set PyTorch to use expandable segments to avoid fragmentation
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Example of reducing batch size and clearing cache
batch_size = 32  # Adjust batch size to fit your GPU memory
torch.cuda.empty_cache()  # Clear cache if needed

# Continue with your model training setup and execution

In [25]:
for epoch in tqdm(range(max_iters)):

    if epoch % eval_iters == 0 and epoch != 0:

        losses = estimate_loss(model)

        print(f"Epoch {epoch}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        print(f"The current learning rate: {optimizer.param_groups[0]['lr']:.4f}")

        train_loss_list += [losses['train']]

        validation_loss_list += [losses['val']]

        wandb.log({

            "epoch": epoch,

            "train/loss": losses['train'],

            "val/loss": losses['val'],

            "lr": optimizer.param_groups[0]['lr']

        })

        if losses['val'] < best_val_loss:

            best_val_loss = losses['val']

            torch.save(model.state_dict(), best_model_params_path)

            wandb.save(best_model_params_path)

    X, y = get_batch("train")

    with ctx:

        logits, loss = model(X, y)

        loss = loss / gradient_accumulation_steps

        scaler.scale(loss).backward()

    if ((epoch + 1) % gradient_accumulation_steps == 0) or (epoch + 1 == max_iters):

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

        scaler.step(optimizer)

        scaler.update()

        optimizer.zero_grad(set_to_none=True)

    scheduler.step()

  0%|          | 0/6000 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 6.14 GiB. GPU 0 has a total capacity of 15.89 GiB of which 3.27 GiB is free. Process 2363 has 12.62 GiB memory in use. Of the allocated memory 12.30 GiB is allocated by PyTorch, and 31.16 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Save the final model

torch.save(model.state_dict(), final_model_params_path)

wandb.save(final_model_params_path)

wandb.finish()


In [None]:
# Plot training and validation loss

import matplotlib.pyplot as plt

train_loss_list_converted = [i.cpu().detach() for i in train_loss_list]

validation_loss_list_converted = [i.cpu().detach() for i in validation_loss_list]

plt.plot(train_loss_list_converted, 'g', validation_loss_list_converted, 'r')

plt.xlabel("Steps - Every 100 epochs")

plt.ylabel("Loss")

plt.show()

In [None]:
# Load the model

bestModel = GPT(config)

device = "cuda" if torch.cuda.is_available() else "cpu"

best_model_params_path = "best_model_params.pt"

bestModel.load_state_dict(torch.load(best_model_params_path, map_location=torch.device(device)))


In [None]:
# Generate text
sentence = "Once upon a time, there lived a black cat. The cat belonged to a little girl called Katie. Every day, Katie would take her cat for a walk in the park. One day, as Katie and her cat were walking around, they saw a mean looking man. He said he wanted to take the cat, to which she replied 'This cat belongs"
context = tokenizer.encode(sentence, return_tensors='pt')
context = context.to(device)
y = bestModel.generate(context, max_new_tokens=150)
print(tokenizer.decode(y[0], skip_special_tokens=True))