In [None]:
#importing libraries
!pip install -U datasets tiktoken tqdm numpy torch matplotlib huggingface_hub
import os
import math
import time
from contextlib import nullcontext
from dataclasses import dataclass, field
from tqdm.auto import tqdm
import json
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR
import tiktoken
from datasets import load_dataset
import matplotlib.pyplot as plt

In [None]:
#tokenization implementation
from datasets import load_dataset
import numpy as np
from tqdm.auto import tqdm
import os
import tiktoken 
import math


# GPT-2 BPE tokenizer shared by the tokenization cells below.
enc = tiktoken.get_encoding("gpt2")
# End-of-text id; processing_bpe appends it after every document so documents
# remain delimited once all token ids are concatenated into one stream.
eot_token = enc.eot_token

In [None]:
def processing_bpe(sample):
    """Tokenize one dataset row with GPT-2 BPE and terminate it with ``eot_token``.

    Args:
        sample: mapping holding an optional ``'text'`` entry. A missing or
            non-string value is treated as an empty document.

    Returns:
        dict with ``'ids'`` (token ids, always ending in ``eot_token``) and
        ``'len'`` (number of ids, including the terminator).
    """
    raw_text = sample.get('text', '')
    body_ids = enc.encode_ordinary(raw_text if isinstance(raw_text, str) else "")
    return {'ids': body_ids + [eot_token], 'len': len(body_ids) + 1}
# Output location for the combined TinyStories + BookCorpus token files.
combined_data_dir = config.data_dir
os.makedirs(combined_data_dir, exist_ok=True)
train_filename, val_filename = (
    os.path.join(combined_data_dir, bin_name)
    for bin_name in ('train_combined.bin', 'val_combined.bin')
)

In [None]:
# Tokenize TinyStories + BookCorpus and write the combined train/validation token streams to .bin files
def _write_token_stream(arr, tokenized, start_idx, label, split, write_batch_size=1000):
    """Copy a tokenized dataset's ``'ids'`` column into memmap ``arr`` from ``start_idx``.

    Documents are copied ``write_batch_size`` at a time; empty id lists are skipped.

    Returns:
        The token offset one past the last token written.

    Raises:
        RuntimeError: if the tokens would run past the end of ``arr`` (the
            pre-computed length disagrees with the data).
    """
    idx = start_idx
    total_samples = len(tokenized)
    num_write_batches = math.ceil(total_samples / write_batch_size)
    for i in tqdm(range(num_write_batches), desc=f"Writing {label} {split}"):
        start = i * write_batch_size
        end = min(start + write_batch_size, total_samples)
        # Slicing a Dataset yields {column: list}; avoids building one dict per row.
        all_ids_in_chunk = [ids for ids in tokenized[start:end]['ids'] if ids]
        if not all_ids_in_chunk:
            continue
        arr_batch = np.concatenate(all_ids_in_chunk).astype(arr.dtype)
        end_idx = idx + len(arr_batch)
        if end_idx > len(arr):
            raise RuntimeError(
                f"Bounds overflow while writing {label} {split}: {end_idx:,} > {len(arr):,}"
            )
        arr[idx:end_idx] = arr_batch
        idx = end_idx
    return idx


if os.path.exists(train_filename) and os.path.exists(val_filename):
    print(f"Combined .bin files already exist in {combined_data_dir}. Skipping processing.")

else:
    ts_df = load_dataset("roneneldan/TinyStories")
    bc_df = load_dataset("rojagtap/bookcorpus", split='train')

    # Hold out 1% of each corpus for validation; the fixed seed keeps the split reproducible.
    ts_split = ts_df['train'].train_test_split(test_size=0.01, seed=42)
    ts_train_df = ts_split['train']
    ts_val_df = ts_split['test']

    bc_split = bc_df.train_test_split(test_size=0.01, seed=42)
    bc_train_df = bc_split['train']
    bc_val_df = bc_split['test']

    dataset_splits = {
        'train': {'ts': ts_train_df, 'bc': bc_train_df},
        'validation': {'ts': ts_val_df, 'bc': bc_val_df}
    }

    # Token ids are stored as uint16, which is only valid while every id fits in 16 bits.
    dtype_np = np.uint16
    assert enc.n_vocab <= np.iinfo(dtype_np).max + 1, "vocabulary too large for uint16 token storage"

    for split in ['train', 'validation']:
        filename = train_filename if split == 'train' else val_filename
        # Build the file under a temporary name and rename it only once it is fully
        # written. A failed or interrupted run therefore never leaves a truncated file
        # under the final name, which the existence check above would accept as complete.
        tmp_filename = filename + '.tmp'
        for stale_path in (filename, tmp_filename):
            if os.path.exists(stale_path):
                os.remove(stale_path)

        tokenized_ts = dataset_splits[split]['ts'].map(
            processing_bpe, remove_columns=['text'],
            desc=f"Tokenizing TinyStories {split}", num_proc=config.num_proc
        )

        tokenized_bc = dataset_splits[split]['bc'].map(
            processing_bpe, remove_columns=['text'],
            desc=f"Tokenizing BookCorpus {split}", num_proc=config.num_proc
        )

        ts_len = int(np.sum(tokenized_ts['len'], dtype=np.uint64))
        bc_len = int(np.sum(tokenized_bc['len'], dtype=np.uint64))
        arr_len = ts_len + bc_len

        if arr_len == 0:
            print(f"Warning: No tokens found for combined {split} split. Skipping.")
            continue

        print(f"Total tokens for {split}: {arr_len:,} (TinyStories: {ts_len:,}, BookCorpus: {bc_len:,})")
        arr = np.memmap(tmp_filename, dtype=dtype_np, mode='w+', shape=(arr_len,))

        print(f"Writing {ts_len:,} TinyStories tokens to {filename}...")
        try:
            # TinyStories first, then BookCorpus, as one contiguous token stream.
            idx = _write_token_stream(arr, tokenized_ts, 0, "TinyStories", split)
            idx = _write_token_stream(arr, tokenized_bc, idx, "BookCorpus", split)
            arr.flush()
        finally:
            # Release the mapping before renaming/removing the file; an open mapping
            # can block that on some platforms (e.g. Windows).
            del arr

        if idx != arr_len:
            os.remove(tmp_filename)
            raise RuntimeError(
                f"Final index {idx} doesn't match expected {arr_len}; discarded {tmp_filename}."
            )
        os.replace(tmp_filename, filename)
        print(f"Successfully wrote {idx} total tokens.")
        print(f"Finished writing {filename}.")

In [None]:
# memmap the tokenized .bin files
import numpy as np
import torch
import os
train_data_file, val_data_file = (
    os.path.join(config.data_dir, bin_name)
    for bin_name in ('train_combined.bin', 'val_combined.bin')
)
# Fail early, pointing at the producing cell, instead of surfacing a cryptic memmap error.
if not (os.path.exists(train_data_file) and os.path.exists(val_data_file)):
    raise FileNotFoundError(f"Combined .bin files not found in {config.data_dir}. Did Cell 3 complete?")

# Read-only memmaps: tokens stay on disk and are paged in on access.
# The dtype must match the writer (uint16).
train_data, val_data = (
    np.memmap(bin_path, dtype=np.uint16, mode='r')
    for bin_path in (train_data_file, val_data_file)
)

print(f"Train data loaded: {len(train_data):,} tokens")
print(f"Validation data loaded: {len(val_data):,} tokens")

In [None]:
#batch function
def get_batch(split):
    """Sample a random batch of (input, next-token target) windows from a token memmap.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects ``val_data``.

    Returns:
        ``(x, y)``: int64 tensors of shape ``(config.batch_size, config.block_size)``
        on ``config.device``, where ``y`` is ``x`` shifted left by one token.

    Raises:
        ValueError: if the split has too few tokens for a single window.
    """
    data = train_data if split == 'train' else val_data
    block_size = config.block_size
    # Every window needs block_size + 1 tokens: the inputs plus the final target.
    if len(data) <= block_size:
        raise ValueError(
            f"'{split}' split has {len(data)} tokens; need more than block_size={block_size}."
        )
    ix = torch.randint(len(data) - block_size, (config.batch_size,))

    # Read each window from the memmap once and derive both x and y from it.
    windows = torch.stack([
        torch.from_numpy(data[i:i + block_size + 1].astype(np.int64)) for i in ix.tolist()
    ])
    # .contiguous(): the loss calls targets.view(-1), which needs contiguous memory.
    x = windows[:, :-1].contiguous()
    y = windows[:, 1:].contiguous()

    # torch.device(...).type also matches indexed devices such as 'cuda:0'.
    if torch.device(config.device).type == 'cuda':
        x = x.pin_memory().to(config.device, non_blocking=True)
        y = y.pin_memory().to(config.device, non_blocking=True)
    else:
        x, y = x.to(config.device), y.to(config.device)
    return x, y

# Best-effort smoke test of get_batch: failures are reported, not raised.
try:
    x_batch, y_batch = get_batch('train')
    print(f"Train batch x shape: {x_batch.shape}")
    print(f"Train batch y shape: {y_batch.shape}")
    assert tuple(x_batch.shape) == (config.batch_size, config.block_size), "unexpected train batch shape"
    # Targets must be the inputs shifted left by one token.
    assert torch.equal(x_batch[:, 1:], y_batch[:, :-1]), "train batch: y is not x shifted by one token"

    x_val, y_val = get_batch('val')
    print(f"Validation batch x shape: {x_val.shape}")
    print(f"Validation batch y shape: {y_val.shape}")
    assert tuple(x_val.shape) == (config.batch_size, config.block_size), "unexpected validation batch shape"
    assert torch.equal(x_val[:, 1:], y_val[:, :-1]), "validation batch: y is not x shifted by one token"
except Exception as e:
    print(f"Error testing get_batch: {e}")

In [None]:
#model setup

import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from contextlib import nullcontext
from dataclasses import dataclass, field # Import field

@dataclass
class SLMConfig:
    """Hyperparameters and paths for the small GPT language model and its training run.

    Field order defines the positional constructor signature; keep it stable.
    """

    # --- model architecture ---
    block_size: int = 256            # context length: tokens per training window / positional table size
    vocab_size: int = 50257          # GPT-2 BPE vocabulary size
    n_layer: int = 12                # number of transformer blocks
    n_head: int = 12                 # attention heads per block (must divide n_embd)
    n_embd: int = 768                # embedding / residual stream width
    dropout: float = 0.1
    bias: bool = False               # whether Linear and LayerNorm layers carry a bias

    # --- batching & training loop (several fields are consumed by cells outside this view) ---
    batch_size: int = 8
    gradient_accumulation_steps: int = 4
    max_iters: int = 100000
    eval_interval: int = 1000
    eval_iters: int = 200

    # --- AdamW optimizer ---
    learning_rate: float = 3e-4      # peak LR, reached at the end of warmup
    weight_decay: float = 0.1        # applied only to parameters with dim >= 2
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0

    # --- LR schedule: linear warmup, then cosine decay to min_lr ---
    warmup_iters: int = 2000
    lr_decay_iters: int = 100000
    min_lr: float = 3e-5

    # --- runtime ---
    device: str = 'cuda'
    dtype: str = 'bfloat16'          # 'float32' | 'bfloat16' | 'float16' (autocast dtype)
    compile: bool = True             # wrap the model with torch.compile

    # --- data & outputs ---
    data_dir: str = 'data_combined'  # holds train_combined.bin / val_combined.bin
    num_proc: int = 8                # worker processes for datasets .map() tokenization
    total_batches: int = 1024
    out_dir: str = 'out_v3'
    best_model_name: str = 'best_model_v3.pt'
    local_pretrained_path: str = ""

# Fail fast on a fresh kernel where the config cell has not run.
if 'config' not in locals(): raise NameError("Config object 'config' not found. Please re-run Cell 2.")
# Keep a ptdtype already chosen by an earlier cell; otherwise derive it from config.dtype.
if 'ptdtype' not in locals(): ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[config.dtype]
# torch.amp.autocast needs a bare device type ('cuda', 'cpu', ...) and rejects an
# indexed device string such as 'cuda:0'; torch.device(...).type strips the index.
_amp_device_type = torch.device(config.device).type
ctx = nullcontext() if _amp_device_type == 'cpu' else torch.amp.autocast(device_type=_amp_device_type, dtype=ptdtype)


class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with an optional bias (eps = 1e-5).

    Args:
        ndim: size of the normalized (last) dimension.
        bias: if falsy, no bias parameter is created and ``self.bias`` is ``None``.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, x):
        normalized_shape = self.weight.shape
        return F.layer_norm(x, normalized_shape, self.weight, self.bias, eps=1e-5)

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with fused QKV projection.

    Uses ``F.scaled_dot_product_attention`` (with ``is_causal=True``) whenever it is
    available, in both training and evaluation. Otherwise it falls back to an
    explicit softmax(QK^T / sqrt(head_size)) V using the lower-triangular ``bias``
    mask buffer. Input and output are ``(B, T, n_embd)`` with ``T <= block_size``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.flash = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: Flash Attention 2.0 not available.")
        # Causal mask (1, 1, block_size, block_size) of lower-triangular ones; used
        # only by the manual path. The buffer name is kept for checkpoint compatibility.
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                    .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size()
        head_size = C // self.n_head
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)

        dropout_p = self.dropout if self.training else 0.0

        y = None
        if self.flash:
            try:
                # Fused kernel for training too: avoids materializing the (T, T) attention matrix.
                y = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                                   dropout_p=dropout_p, is_causal=True)
            except Exception as e:
                print(f"WARNING: scaled_dot_product_attention failed ({e}); using manual attention.")
                self.flash = False
        if y is None:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_size))
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)  # identity in eval mode
            y = att @ v

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble heads
        return self.resid_dropout(self.c_proj(y))  # identity dropout in eval mode

class MLP(nn.Module):
    """Position-wise feed-forward network.

    Computes Linear(n_embd -> 4*n_embd), then tanh-approximated GELU, then
    Linear(4*n_embd -> n_embd). Dropout is applied to the output only in training mode.
    """

    def __init__(self, config):
        super().__init__()
        hidden_dim = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        projected = self.c_proj(self.gelu(self.c_fc(x)))
        return self.dropout(projected) if self.training else projected

class Block(nn.Module):
    """Pre-norm transformer block.

    Applies ``x + attn(ln_1(x))``, then adds ``mlp(ln_2(...))`` as a second residual.
    """

    def __init__(self, config):
        super().__init__()
        # Registration order (ln_1, attn, ln_2, mlp) fixes init RNG order and state_dict keys.
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        after_attn = x + self.attn(self.ln_1(x))
        return after_attn + self.mlp(self.ln_2(after_attn))

class GPT(nn.Module):
    """GPT-2-style decoder-only language model.

    Token and learned positional embeddings feed ``config.n_layer`` transformer
    ``Block``s, then a final LayerNorm and an ``lm_head`` whose weight is tied to
    the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight  # weight tying

        self.apply(self._init_weights)
        # Residual output projections get a smaller std, scaled by 1/sqrt(2 * n_layer).
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def get_num_params(self, non_embedding=True):
        """Return the number of parameters.

        The tied wte/lm_head weight is counted once. With ``non_embedding=True``
        the positional-embedding parameters are excluded. Token embeddings stay
        counted because they double as the output projection.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None: torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on token ids ``idx`` of shape (B, T).

        Sequences longer than ``block_size`` keep only their last ``block_size``
        tokens; ``targets`` is cropped identically.

        Returns:
            ``(logits, loss)``. With targets, logits cover every position and loss
            is the cross-entropy (target -1 ignored). Without targets, logits
            cover only the last position and loss is ``None``.
        """
        device = idx.device
        _, t = idx.size()
        if t > self.config.block_size:
            idx = idx[:, -self.config.block_size:]
            if targets is not None:
                targets = targets[:, -self.config.block_size:]
            t = self.config.block_size
        pos = torch.arange(0, t, dtype=torch.long, device=device)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)  # nn.Dropout is identity in eval mode
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-1)
        else:
            logits = self.lm_head(x[:, [-1], :])  # only the last position is needed for sampling
            loss = None
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None, stop_token=None):
        """Autoregressively sample up to ``max_new_tokens`` tokens after ``idx`` (B, T).

        Args:
            temperature: logits are divided by this before softmax (must be > 0).
            top_k: if given, sample only among the ``top_k`` most likely tokens.
            stop_token: token id that ends generation. Defaults to GPT-2's
                end-of-text token. Sampling stops early once every row draws it
                at the same step; that step's tokens are not appended.

        The module's train/eval mode is restored on return, so calling this
        mid-training does not leave dropout disabled.
        """
        if stop_token is None:
            stop_token = tiktoken.get_encoding("gpt2").eot_token
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :] / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                # .all() keeps batch-1 behaviour and avoids the ambiguous truth value for B > 1.
                if (idx_next == stop_token).all():
                    break
                idx = torch.cat((idx, idx_next), dim=1)
        finally:
            self.train(was_training)
        return idx

# Build the model, move it to the target device, and optionally wrap it with torch.compile.
model = GPT(config)
model.to(config.device)

if config.compile:
    print("compiling")
    try:
        # Note: torch.compile defers most compilation to the first forward call,
        # so some failures may only surface there rather than in this try block.
        model = torch.compile(model)
    except Exception as e:
        print(f"torch.compile failed: {e}. Proceeding without compilation.")
        config.compile = False
    else:
        print("Model compiled.")

In [None]:
# AdamW optimizer (decayed / non-decayed parameter groups) and LR scheduler

from torch.optim.lr_scheduler import LambdaLR

if 'model' not in locals():
    raise NameError("Model not defined. Please re-run Cell 5.")

# Weight decay applies to matrices/embeddings (dim >= 2) only. Vectors (biases,
# LayerNorm weights) are left undecayed. Frozen parameters are excluded entirely.
decay_params = [prm for prm in model.parameters() if prm.requires_grad and prm.dim() >= 2]
no_decay_params = [prm for prm in model.parameters() if prm.requires_grad and prm.dim() < 2]

param_groups = [
    {'params': decay_params, 'weight_decay': config.weight_decay},
    {'params': no_decay_params, 'weight_decay': 0.0},
]
num_decay_params = sum(prm.numel() for prm in decay_params)
num_nodecay_params = sum(prm.numel() for prm in no_decay_params)

optimizer_betas = (config.beta1, config.beta2)
optimizer = torch.optim.AdamW(param_groups, lr=config.learning_rate, betas=optimizer_betas, eps=1e-8)

def get_lr_multiplier(it):
    """Learning-rate multiplier for iteration ``it``, relative to ``config.learning_rate``.

    Schedule:
      * ``it < warmup_iters``: linear warmup ``it / warmup_iters`` (0 at it = 0).
      * ``warmup_iters <= it < lr_decay_iters``: cosine decay from
        ``learning_rate`` down to ``min_lr``.
      * ``it >= lr_decay_iters``: constant ``min_lr``.

    Used as ``LambdaLR``'s ``lr_lambda``, so the returned value scales the
    optimizer's initial lr (``config.learning_rate``).
    """
    if it < config.warmup_iters:
        return float(it) / float(max(1, config.warmup_iters))
    if it >= config.lr_decay_iters:
        # At it == lr_decay_iters the cosine term is already 0, so this matches the curve.
        # Taking this branch there also avoids a 0/0 when lr_decay_iters == warmup_iters.
        return config.min_lr / config.learning_rate
    decay_ratio = (it - config.warmup_iters) / (config.lr_decay_iters - config.warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes 1 -> 0 across the decay span
    return (config.min_lr + coeff * (config.learning_rate - config.min_lr)) / config.learning_rate

lr_scheduler = LambdaLR(optimizer, lr_lambda=get_lr_multiplier)