<a href="https://colab.research.google.com/github/PARTHIBAN-007/SLM-From-Scratch/blob/main/SLM_From_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets tiktoken

In [None]:
from datasets import load_dataset

# Load the TinyStories corpus from the Hugging Face Hub; no split is requested,
# so the result is indexed by split name.
ds = load_dataset(path="roneneldan/TinyStories")

In [None]:
import tiktoken
import os
import numpy as np
from tqdm.auto import tqdm


# GPT-2 byte-pair encoder shared by every call to process().
enc = tiktoken.get_encoding("gpt2")


def process(example):
  """Tokenize one dataset record.

  Args:
    example: mapping whose 'text' entry holds the raw string to encode.

  Returns:
    dict with 'ids' (the token ids produced by ``enc.encode_ordinary``) and
    'len' (how many ids were produced).
  """
  ids = enc.encode_ordinary(example['text'])
  return {'ids': ids, 'len': len(ids)}


# Tokenize and serialize only once: everything is skipped when train.bin exists.
if not os.path.exists('train.bin'):
  tokenized = ds.map(
      process,
      remove_columns=['text'],
      desc="tokenizing the splits",
      num_proc=8,
  )

  # Write each split to '<split>.bin' as one flat stream of uint16 token ids.
  for split, dset in tokenized.items():
    # Total token count of the split; a plain int keeps the memmap shape exact.
    arr_len = int(np.sum(dset['len'], dtype=np.uint64))
    filename = f'{split}.bin'
    dtype = np.uint16  # the model's vocab_size (50257) fits in 16 bits
    arr = np.memmap(filename, dtype=dtype, mode="w+", shape=(arr_len,))
    total_batches = 1024

    # Copy the split shard by shard so only one shard's ids are concatenated
    # in memory at a time.
    idx = 0
    for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
      batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
      arr_batch = np.concatenate(batch['ids'])
      arr[idx:idx + len(arr_batch)] = arr_batch
      idx += len(arr_batch)
    arr.flush()


In [None]:
def get_batch(split):
  """Sample a random batch of (input, target) token windows from disk.

  Reads the module-level ``block_size``, ``batch_size``, ``device`` and
  ``device_type`` settings.

  Args:
    split: "train" reads 'train.bin'; any other value reads 'validation.bin'.

  Returns:
    Tuple ``(x, y)`` of int64 tensors with shape (batch_size, block_size),
    moved to ``device``. ``y`` holds the same windows as ``x`` shifted one
    token forward, i.e. the next-token targets.
  """
  # The memmap is re-opened on every call instead of being kept alive.
  path = 'train.bin' if split == "train" else 'validation.bin'
  data = np.memmap(path, dtype=np.uint16, mode='r')

  # Random window starts; the upper bound leaves room for the shifted target.
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in ix])
  y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in ix])

  if device_type == "cuda":
    # Pin host memory so the copy to the GPU can be issued non-blocking.
    x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
  else:
    x, y = x.to(device), y.to(device)
  return x, y


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
import numpy as np
from tqdm.auto import tqdm
from contextlib import nullcontext
import os

class LayerNorm(nn.Module):
  """Layer normalization over the last dimension with an optional bias.

  Args:
    ndim: size of the trailing dimension being normalized.
    bias: when truthy, a learnable bias initialised to zeros is created;
      otherwise ``self.bias`` is ``None`` and no shift is applied.
  """

  def __init__(self, ndim, bias):
    # nn.Module must be initialised before any Parameter can be assigned.
    super().__init__()
    self.weight = nn.Parameter(torch.ones(ndim))
    self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

  def forward(self, x):
    """Return ``x`` normalized over its last ``ndim`` features (eps = 1e-5)."""
    return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5)


class CausalSelfAttention(nn.Module):
  """Multi-head self-attention in which a position only attends to itself and earlier positions.

  Args:
    config: object exposing ``n_embd``, ``n_head``, ``bias``, ``dropout`` and
      ``block_size``. ``n_embd`` must be divisible by ``n_head``.

  When ``torch.nn.functional.scaled_dot_product_attention`` exists it is used
  with ``is_causal=True``; otherwise attention is computed explicitly with a
  lower-triangular mask stored in the ``bias`` buffer.
  """

  def __init__(self, config):
    super().__init__()
    assert config.n_embd % config.n_head == 0
    # One projection yields query, key and value for all heads at once.
    self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
    self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
    self.attn_dropout = nn.Dropout(config.dropout)
    self.resid_dropout = nn.Dropout(config.dropout)
    self.n_head = config.n_head
    # The attribute keeps its existing spelling; the config field is n_embd.
    self.n_embed = config.n_embd
    self.flash = hasattr(F, 'scaled_dot_product_attention')
    if not self.flash:
      # Lower-triangular mask, only needed by the explicit fallback path.
      self.register_buffer(
          "bias",
          torch.tril(torch.ones(config.block_size, config.block_size))
          .view(1, 1, config.block_size, config.block_size),
      )

  def forward(self, x):
    """Apply causal self-attention.

    Args:
      x: tensor of shape (B, T, C) with C equal to the embedding width.

    Returns:
      Tensor of shape (B, T, C).
    """
    B, T, C = x.size()
    head_dim = C // self.n_head
    q, k, v = self.c_attn(x).split(self.n_embed, dim=2)
    # Reshape each of q, k, v to (B, n_head, T, head_dim).
    k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
    q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
    v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

    if self.flash:
      y = F.scaled_dot_product_attention(
          q, k, v,
          attn_mask=None,
          # Attention dropout must be disabled outside of training.
          dropout_p=self.attn_dropout.p if self.training else 0.0,
          is_causal=True,
      )
    else:
      att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
      # Future positions get -inf so softmax assigns them zero weight.
      att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
      att = F.softmax(att, dim=-1)
      att = self.attn_dropout(att)
      y = att @ v

    # Merge the heads back into a single (B, T, C) tensor.
    y = y.transpose(1, 2).contiguous().view(B, T, C)
    y = self.resid_dropout(self.c_proj(y))
    return y

class MLP(nn.Module):
    """Position-wise feed-forward network: expand to 4x width, GELU, project back, dropout.

    Args:
        config: object exposing ``n_embd`` (model width), ``bias`` (whether the
            linear layers carry a bias) and ``dropout`` (dropout probability).
    """

    def __init__(self, config):
        super().__init__()
        # The config field is spelled n_embd (see GPTConfig).
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the transformed tensor; the last dimension keeps size ``n_embd``."""
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))

class Block(nn.Module):
    """One transformer layer: a normalized attention sub-layer followed by a
    normalized MLP sub-layer, each added back onto its input."""

    def __init__(self, config):
        super().__init__()
        width, use_bias = config.n_embd, config.bias
        self.ln1 = LayerNorm(width, use_bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(width, use_bias)
        self.mlp = MLP(config)

    def forward(self, x):
        """Run both residual sub-layers and return a tensor shaped like ``x``."""
        after_attention = x + self.attn(self.ln1(x))
        return after_attention + self.mlp(self.ln2(after_attention))


@dataclass
class GPTConfig:
    """Hyperparameters consumed by the GPT model and its sub-modules."""

    # --- required ---
    block_size: int  # number of position embeddings; longest sequence GPT.forward accepts
    vocab_size: int  # rows of the token embedding and width of the output logits
    n_layer: int     # how many Block instances are stacked
    n_head: int      # attention heads per block
    n_embd: int      # embedding / hidden width

    # --- optional ---
    dropout: float = 0.0  # probability handed to every nn.Dropout
    bias: bool = True     # whether Linear and LayerNorm layers get a bias term

class GPT(nn.Module):
  """Decoder-only transformer language model.

  Args:
    config: object exposing ``vocab_size``, ``block_size``, ``n_layer``,
      ``n_embd``, ``dropout`` and ``bias`` (plus whatever ``Block`` reads).

  The token embedding and the output projection share a single weight matrix.
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    self.transformer = nn.ModuleDict(dict(
        wte=nn.Embedding(config.vocab_size, config.n_embd),   # token embeddings
        wpe=nn.Embedding(config.block_size, config.n_embd),   # position embeddings
        drop=nn.Dropout(config.dropout),
        h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        ln_f=LayerNorm(config.n_embd, config.bias),           # final normalization
    ))
    self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
    # Weight tying: wte and lm_head point at the same Parameter.
    self.transformer.wte.weight = self.lm_head.weight

    self.apply(self._init_weights)
    # Output projections of each sub-layer get a std scaled by 1/sqrt(2 * n_layer).
    for pn, p in self.named_parameters():
      if pn.endswith("c_proj.weight"):
        nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

  def _init_weights(self, module):
    """Initialise Linear and Embedding weights from N(0, 0.02) and zero Linear biases."""
    if isinstance(module, nn.Linear):
      nn.init.normal_(module.weight, mean=0.0, std=0.02)
      if module.bias is not None:
        nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
      nn.init.normal_(module.weight, mean=0.0, std=0.02)

  def forward(self, idx, targets=None):
    """Compute next-token logits and, when targets are given, the loss.

    Args:
      idx: integer tensor of token ids with shape (b, t), t <= block_size.
      targets: optional integer tensor aligned with ``idx``; entries equal to
        -1 are ignored by the cross-entropy loss.

    Returns:
      ``(logits, loss)`` when ``targets`` is provided, otherwise ``logits``
      alone. ``logits`` has shape (b, t, vocab_size).

    Raises:
      AssertionError: if the sequence is longer than ``config.block_size``.
    """
    device = idx.device
    b, t = idx.size()
    assert t <= self.config.block_size
    pos = torch.arange(0, t, dtype=torch.long, device=device)

    tok_emb = self.transformer.wte(idx)
    pos_emb = self.transformer.wpe(pos)  # (t, n_embd), broadcast over the batch
    x = self.transformer.drop(tok_emb + pos_emb)
    for block in self.transformer.h:
      x = block(x)
    x = self.transformer.ln_f(x)

    logits = self.lm_head(x)
    if targets is None:
      return logits
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
    return logits, loss

  @torch.no_grad()
  def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

    Args:
      idx: integer tensor of shape (b, t) holding the prompt token ids.
      max_new_tokens: number of tokens to append.
      temperature: divisor applied to the final-position logits before
        sampling; 1.0 leaves them unchanged.
      top_k: when set, only the ``top_k`` highest logits remain candidates.

    Returns:
      Tensor of shape (b, t + max_new_tokens): the prompt plus sampled tokens.

    The method does not switch the module between train and eval mode.
    """
    block_size = self.config.block_size
    for _ in range(max_new_tokens):
      # Crop the context to the last block_size tokens if it has grown too long.
      idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
      logits = self(idx_cond)
      logits = logits[:, -1, :] / temperature
      if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        # Everything below the k-th largest logit is excluded from sampling.
        logits[logits < v[:, [-1]]] = -float('inf')
      probs = F.softmax(logits, dim=-1)
      idx_next = torch.multinomial(probs, num_samples=1)
      idx = torch.cat((idx, idx_next), dim=1)
    return idx

In [None]:
# Model hyperparameters, listed in GPTConfig's field order.
config = GPTConfig(
    block_size=128,
    vocab_size=50257,
    n_layer=6,
    n_head=6,
    n_embd=384,
    dropout=0.1,
    bias=True,
)

# Build the network from the configuration above.
model = GPT(config)

In [None]:
def estimate_loss(model):
  """Average the model's loss over ``eval_iters`` random batches per split.

  Uses the module-level ``eval_iters``, ``get_batch`` and ``ctx``. The model is
  put in eval mode for the measurement and back in train mode before returning.

  Args:
    model: callable returning ``(logits, loss)`` for ``model(inputs, targets)``.

  Returns:
    dict mapping 'train' and 'val' to the mean loss as a 0-d tensor.
  """
  model.eval()
  results = {}
  with torch.inference_mode():
    for split in ('train', 'val'):
      per_batch = torch.zeros(eval_iters)
      for step in range(eval_iters):
        inputs, targets = get_batch(split)
        with ctx:
          _, loss = model(inputs, targets)
        per_batch[step] = loss.item()
      results[split] = per_batch.mean()
  model.train()
  return results

In [None]:
import torch
from contextlib import nullcontext


# --- Optimisation settings ---
learning_rate = 1e-4
max_iters = 20000
warmup_steps = 1000
# NOTE(review): min_lr is larger than learning_rate; confirm this is intended
# before it is used as the floor of a decay schedule.
min_lr = 5e-4
eval_iters = 500  # batches averaged per split by estimate_loss()
batch_size = 32   # sequences per batch drawn by get_batch()
block_size = 128  # tokens per sequence drawn by get_batch()

gradient_accumulation_steps = 32

# --- Device and precision ---
device = "cuda" if torch.cuda.is_available() else "cpu"
device_type = "cuda" if "cuda" in device else "cpu"

# bfloat16 when CUDA reports support for it, float16 otherwise.
dtype = "bfloat16" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "float16"
ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]

# Autocast is only enabled off-CPU; it takes the torch.dtype object (ptdtype),
# not the string name held in dtype.
ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

torch.set_default_device(device)
torch.manual_seed(42)