<a href="https://colab.research.google.com/github/Aditya-Shandilya1182/neo/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch.nn import functional as F

In [None]:
#embd -> DIM
#hidden_dim -> FFN_DIM
#n_head -> N_HEADS
#head_dim -> HEAD_DIM

In [None]:
class RMSNorm(nn.Module):
  """Root-mean-square normalisation over the last dimension with a learned gain.

  Args:
    dim: length of the learnable gain vector (initialised to ones).
    norm_eps: constant added to the mean square before the reciprocal
      square root.
  """

  def __init__(self, dim, norm_eps):
    super().__init__()
    self.weight = nn.Parameter(torch.ones(dim))
    self.norm_eps = norm_eps

  def _norm(self, x):
    """Scale x by the reciprocal RMS of its last dimension."""
    mean_square = x.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + self.norm_eps)
    return x * inv_rms

  def forward(self, x):
    """Normalise in float32, cast back to x's dtype, then apply the gain."""
    normed = self._norm(x.float())
    return normed.type_as(x) * self.weight

In [None]:
class MLP(nn.Module):
  """Gated feed-forward layer computing w2(silu(w1(x)) * w3(x)).

  Reads `config.embd` (input/output width) and `config.hidden_dim`
  (inner width). All three projections are bias-free.
  """

  def __init__(self, config):
    super().__init__()
    model_dim, inner_dim = config.embd, config.hidden_dim
    self.w1 = nn.Linear(model_dim, inner_dim, bias=False)
    self.w2 = nn.Linear(inner_dim, model_dim, bias=False)
    self.w3 = nn.Linear(model_dim, inner_dim, bias=False)

  def forward(self, x):
    """Gate the w3 projection with silu(w1(x)) and project back with w2."""
    gate = F.silu(self.w1(x))
    return self.w2(gate * self.w3(x))

In [None]:
def apply_rotary_emb(xq, xk, freqs_cis):
    """Rotate query and key tensors by complex rotary factors.

    The last dimension of each input is regrouped into pairs and viewed as
    complex numbers, multiplied by ``freqs_cis`` (reshaped for broadcasting
    against the complex query), converted back to real pairs and flattened
    from dimension 3 onward. Each result is cast back to its input's dtype.

    Returns:
        Tuple ``(rotated_xq, rotated_xk)``.
    """
    def as_complex(t):
        # Pair up the last dimension: (..., d) -> (..., d/2, 2) -> complex (..., d/2).
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(pairs)

    q_complex = as_complex(xq)
    k_complex = as_complex(xk)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)

def precompute_freqs_cis(dim, end, theta = 10000.0):
    """Build the table of unit-magnitude complex rotary factors.

    Args:
        dim: feature width; ``dim // 2`` frequencies are produced.
        end: number of positions (rows) in the table.
        theta: base of the geometric frequency progression.

    Returns:
        Complex64 tensor of shape ``(end, dim // 2)`` whose entry ``[t, j]``
        is ``exp(i * t * theta ** (-2j / dim))``.
    """
    even_idx = torch.arange(0, dim, 2)[: (dim // 2)].float()
    inv_freq = 1.0 / (theta ** (even_idx / dim))
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    # Magnitude 1, phase = angle.
    return torch.polar(torch.ones_like(angles), angles)


def reshape_for_broadcast(freqs_cis, x):
    """View ``freqs_cis`` so it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The result
    keeps those two sizes at dimensions 1 and -1 and uses size 1 everywhere
    else.

    Raises:
        AssertionError: if ``x`` has fewer than two dimensions or the shapes
            do not line up.
    """
    rank = x.ndim
    assert rank > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    target = [1] * rank
    target[1] = x.shape[1]
    target[-1] = x.shape[-1]
    return freqs_cis.view(*target)

In [None]:
_UNSET = object()


def _config_value(config, names, default=_UNSET):
  """Return the first attribute of `config` found among `names`.

  Falls back to `default` when one is given; otherwise raises
  AttributeError naming the accepted spellings.
  """
  for name in names:
    if hasattr(config, name):
      return getattr(config, name)
  if default is _UNSET:
    raise AttributeError(f"config must define one of: {', '.join(names)}")
  return default


class Attention(nn.Module):
  """Grouped-query self-attention with rotary embeddings and a key/value cache.

  Config attributes read:
    embd, n_heads, head_dim, seq_len: required.
    n_kv_head or n_kv_heads: number of key/value heads. Both spellings are
      accepted because ModelConfig in this file uses the plural form.
    n_kv_head_rep or n_kv_heads_rep: how many query heads share one
      key/value head; defaults to n_heads // n_kv_head when absent.
    batch_size: number of sequences the cache is pre-allocated for;
      defaults to 1 when absent (the cache grows on demand in forward).

  Attributes:
    cache_k, cache_v: tensors of shape
      (batch, seq_len, n_kv_head, head_dim) holding the keys/values written
      by previous forward calls.
  """

  def __init__(self, config):
    super().__init__()

    self.n_heads = config.n_heads
    self.head_dim = config.head_dim
    self.n_kv_head = _config_value(config, ("n_kv_head", "n_kv_heads"))
    self.n_kv_head_rep = _config_value(
        config, ("n_kv_head_rep", "n_kv_heads_rep"), self.n_heads // self.n_kv_head)
    cache_batch = _config_value(config, ("batch_size",), 1)

    self.wq = nn.Linear(config.embd, self.n_heads * self.head_dim, bias=False)
    self.wk = nn.Linear(config.embd, self.n_kv_head * self.head_dim, bias=False)
    self.wv = nn.Linear(config.embd, self.n_kv_head * self.head_dim, bias=False)
    self.wo = nn.Linear(self.n_heads * self.head_dim, config.embd, bias=False)

    cache_shape = (cache_batch, config.seq_len, self.n_kv_head, self.head_dim)
    self.cache_k = torch.zeros(cache_shape)
    self.cache_v = torch.zeros(cache_shape)

  @staticmethod
  def _prepare_cache(cache, batch_sz, like):
    """Return `cache` ready for an in-place write of `batch_sz` rows.

    The cache is detached so the write does not chain onto the autograd
    graph of an earlier step (which has already been freed by backward()),
    matched to `like`'s device and dtype, and zero-extended along the batch
    dimension when it holds fewer than `batch_sz` rows.
    """
    cache = cache.detach().to(like)
    missing = batch_sz - cache.shape[0]
    if missing > 0:
      padding = cache.new_zeros((missing, *cache.shape[1:]))
      cache = torch.cat([cache, padding], dim=0)
    return cache

  def forward(self, x, start_pos, freqs_cis, mask):
    """Attend over the cached prefix plus the current tokens.

    Args:
      x: 3-D input; the first two dimensions are read as (batch, seqlen).
      start_pos: cache offset at which the current keys/values are written.
      freqs_cis: rotary factors forwarded to apply_rotary_emb.
      mask: forwarded as attn_mask to scaled_dot_product_attention
        (None for no mask).

    Returns:
      Tensor of shape (batch, seqlen, embd) produced by the wo projection.
    """
    batch_sz, seqlen, _ = x.shape

    queries = self.wq(x).view(batch_sz, seqlen, self.n_heads, self.head_dim)
    keys = self.wk(x).view(batch_sz, seqlen, self.n_kv_head, self.head_dim)
    values = self.wv(x).view(batch_sz, seqlen, self.n_kv_head, self.head_dim)

    queries, keys = apply_rotary_emb(queries, keys, freqs_cis)

    self.cache_k = self._prepare_cache(self.cache_k, batch_sz, queries)
    self.cache_v = self._prepare_cache(self.cache_v, batch_sz, queries)

    end_pos = start_pos + seqlen
    self.cache_k[:batch_sz, start_pos:end_pos] = keys
    self.cache_v[:batch_sz, start_pos:end_pos] = values

    keys = self.cache_k[:batch_sz, :end_pos]
    values = self.cache_v[:batch_sz, :end_pos]

    # Expand each key/value head so the head count matches the queries.
    keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_head_rep)
    values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_head_rep)

    # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
    queries = queries.transpose(1, 2)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)

    out = F.scaled_dot_product_attention(queries, keys, values, attn_mask=mask)
    out = out.transpose(1, 2).contiguous().view(batch_sz, seqlen, -1)
    return self.wo(out)

In [None]:
class Block(nn.Module):
  """One transformer layer: normalised attention, then normalised feed-forward.

  Both sub-layers are added back onto their input (residual connections).
  """

  def __init__(self, config):
    super().__init__()
    self.attention = Attention(config)
    self.feed_forward = MLP(config)
    self.attention_norm = RMSNorm(config.embd, config.norm_eps)
    self.ffn_norm = RMSNorm(config.embd, config.norm_eps)

  def forward(self, x, start_pos, freqs_cis, mask):
    """Return x after the attention and feed-forward residual updates."""
    attn_out = self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
    hidden = x + attn_out
    return hidden + self.feed_forward(self.ffn_norm(hidden))

In [None]:
class Neo(nn.Module):
  """Decoder-only transformer: embedding, stacked Blocks, final norm, vocab projection.

  Config attributes read: vocab_size, embd, n_layers, norm_eps, head_dim,
  seq_len, theta (plus whatever Block reads).

  Attributes:
    freqs_cis: rotary factor table covering 2 * seq_len positions.
  """

  def __init__(self, config):
    super().__init__()
    self.tok_embeddings = nn.Embedding(config.vocab_size, config.embd)
    self.layers = nn.ModuleList()
    for _ in range(config.n_layers):
      self.layers.append(Block(config))
    self.norm = RMSNorm(config.embd, config.norm_eps)
    self.output = nn.Linear(config.embd, config.vocab_size, bias=False)
    # Stored under the name forward() reads; the earlier `freq_cis` spelling
    # made every forward pass fail with AttributeError.
    self.freqs_cis = precompute_freqs_cis(config.head_dim, config.seq_len * 2, config.theta)

  def forward(self, tokens, start_pos=0):
    """Compute next-token logits.

    Gradient tracking follows the caller's context: the method is not
    wrapped in inference mode, so it can be used for training. Wrap calls in
    `torch.no_grad()` or `torch.inference_mode()` for generation.

    Args:
      tokens: 2-D tensor of token ids, read as (batch, seqlen).
      start_pos: position of the first token of `tokens` within the
        sequence; selects the rotary factors and the cache offset.

    Returns:
      Float32 logits of shape (batch, seqlen, vocab_size).
    """
    _, seqlen = tokens.shape
    h = self.tok_embeddings(tokens)
    self.freqs_cis = self.freqs_cis.to(tokens.device)
    freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

    mask = None
    if seqlen > 1:
      # Causal part: -inf strictly above the diagonal, 0 elsewhere.
      mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
      mask = torch.triu(mask, diagonal=1)
      # Keys span start_pos cached positions plus the current tokens, so the
      # mask needs start_pos extra all-visible columns on the left. It is cast
      # to the activation dtype so it matches the queries.
      visible_prefix = torch.zeros((seqlen, start_pos), device=tokens.device)
      mask = torch.hstack([visible_prefix, mask]).type_as(h)

    for layer in self.layers:
      h = layer(h, start_pos, freqs_cis, mask)
    h = self.norm(h)
    return self.output(h).float()

In [None]:
@dataclass
class ModelConfig:
  """Hyper-parameters for the model defined in this file.

  `head_dim` and `n_kv_heads_rep` are derived from the other fields when
  left as None, so overriding `embd`, `n_heads` or `n_kv_heads` keeps them
  consistent. Passing them explicitly still works and wins over derivation.
  """
  embd: int = 4096              # model (embedding) width
  hidden_dim: int = 14336       # feed-forward inner width
  n_heads: int = 32             # number of query heads
  head_dim: Optional[int] = None  # per-head width; None -> embd // n_heads
  n_kv_heads: int = 8           # number of key/value heads
  vocab_size: int = 128256
  n_layers: int = 32
  norm_eps: float = 1e-5        # epsilon passed to RMSNorm
  theta: float = 500000.0       # rotary frequency base
  seq_len: int = 128
  n_kv_heads_rep: Optional[int] = None  # None -> n_heads // n_kv_heads
  batch_size: int = 32          # sequences the attention KV cache is pre-allocated for

  def __post_init__(self):
    # Derive dependent values from the instance's own fields rather than
    # from class-level defaults frozen at definition time.
    if self.head_dim is None:
      self.head_dim = self.embd // self.n_heads
    if self.n_kv_heads_rep is None:
      self.n_kv_heads_rep = self.n_heads // self.n_kv_heads

  @property
  def n_kv_head(self):
    """Read-only alias of `n_kv_heads` for code using the singular spelling."""
    return self.n_kv_heads

  @property
  def n_kv_head_rep(self):
    """Read-only alias of `n_kv_heads_rep` for code using the singular spelling."""
    return self.n_kv_heads_rep

In [None]:
# Choose the accelerator when one is available, then build the model from the
# default hyper-parameters.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
config = ModelConfig()
model = Neo(config)

In [None]:
def train_model(model, dataloader, epochs, learning_rate, device, compile_model=False):
  """Train `model` for next-token prediction with Adam and cross-entropy.

  Each batch is fed to `model(batch, start_pos=0)`; the logits at every
  position but the last are scored against the batch shifted left by one.

  Args:
    model: module called as `model(tokens, start_pos=0)` and returning
      logits whose last dimension is the number of classes.
    dataloader: iterable of 2-D integer token tensors.
    epochs: number of passes over `dataloader`.
    learning_rate: Adam learning rate.
    device: device the model and every batch are moved to.
    compile_model: when True, wrap the model with `torch.compile`.

  Returns:
    List with the mean batch loss of each epoch, in order (NaN for an
    epoch in which the dataloader yielded no batches).
  """
  model.to(device)
  if compile_model:
    model = torch.compile(model)
  # `torch.optim` is reachable through the existing `import torch`.
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  criterion = nn.CrossEntropyLoss()

  epoch_losses = []
  for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in dataloader:
      batch = batch.to(device)
      optimizer.zero_grad()
      output = model(batch, start_pos=0)
      logits = output[:, :-1]
      targets = batch[:, 1:]
      # Take the class count from the logits themselves instead of a global
      # config object.
      loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
      loss.backward()
      optimizer.step()
      total_loss += loss.item()
      num_batches += 1
    avg_loss = total_loss / num_batches if num_batches else float("nan")
    print(f"Epoch: {epoch + 1}/{epochs}, Avg Loss: {avg_loss:.4f}")
    epoch_losses.append(avg_loss)
  return epoch_losses
