<a href="https://colab.research.google.com/github/ashegde/notebooks/blob/main/nano_nGPT_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The material in this notebook follows Andrej Karpathy's nanoGPT tutorial (https://www.youtube.com/watch?v=kCc8FmEb1nY).

Edit (10/11/2024): modified to implement the normalized Transformer (nGPT) described in:

Loshchilov, I., Hsieh, C. P., Sun, S., & Ginsburg, B. (2024). nGPT: Normalized Transformer with Representation Learning on the Hypersphere. arXiv preprint arXiv:2410.01131.


In [None]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math
%matplotlib inline

In [None]:
# Download the TinyShakespeare dataset; wget saves it as input.txt in the
# current working directory, which is where the next cell reads it from.
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [None]:
# read the whole corpus into a single UTF-8 string
with open('input.txt', 'r', encoding='utf-8') as handle:
  dataset = handle.read()

In [None]:
num_characters = len(dataset)
print(f'The dataset contains {num_characters} characters')

In [None]:
# `dataset` is one long str: the file was read whole, never split into lines
print(type(dataset))
preview = dataset[:1000]
print(preview)

In [None]:
# the vocabulary is the sorted set of distinct characters occurring in the corpus
vocab = sorted(set(dataset))
vocab_size = len(vocab)
print(''.join(vocab))
print(vocab_size)

In [None]:
# CHARACTER-LEVEL TOKENIZER

class Tokenizer:
  """
  Character-level tokenizer over a fixed vocabulary.

  Each character is assigned the integer id equal to its position in
  `unique_characters`. Encoding and decoding are plain dictionary lookups.
  """
  def __init__(self, unique_characters: str):
    self.vocab = unique_characters
    # forward (char -> id) and inverse (id -> char) lookup tables
    self.character_to_index = {}
    self.index_to_character = {}
    for position, symbol in enumerate(self.vocab):
      self.character_to_index[symbol] = position
      self.index_to_character[position] = symbol

  def encode(self, text: str) -> list[int]:
    """
    Map every character of `text` to its token id.

    Args:
      text (str): string to tokenize

    Returns:
      list[int]: token ids, one per character (KeyError for unknown characters)
    """
    lookup = self.character_to_index
    return [lookup[symbol] for symbol in text]

  def decode(self, indices: list[int]) -> str:
    """
    Map token ids back to characters and concatenate them.

    Args:
      indices (list[int]): token ids

    Returns:
      str: the decoded text
    """
    return "".join(self.index_to_character[token] for token in indices)

  def get_vocab_size(self):
    """Number of entries in the vocabulary."""
    return len(self.vocab)

tokenizer = Tokenizer(vocab)
# round-trip a sample string through the tokenizer
sample_ids = tokenizer.encode('hello there')
print(sample_ids)
print(tokenizer.decode(sample_ids))


In [None]:
def prepare_data(
    dataset: str,
    tokenizer: "Tokenizer",
    train_frac: float = 0.9,
) -> tuple[torch.Tensor, torch.Tensor]:
  """
  Prepares dataset for model training

  Encodes the whole dataset string into a 1d tensor of token indices and
  splits it into a leading training part and a trailing validation part.
  The split is contiguous, with no shuffling.

  Args:
    dataset (str): dataset stored as a single string
    tokenizer (Tokenizer): object whose `encode` maps the string to token indices
    train_frac (float): fraction of tokens, in [0, 1], used for training

  Returns:
    tuple(
      train_data (torch.Tensor): 1d int64 tensor holding the first int(train_frac * N) tokens
      val_data (torch.Tensor): 1d int64 tensor holding the remaining tokens
    )

  Raises:
    ValueError: if train_frac lies outside [0, 1].
  """
  if not 0.0 <= train_frac <= 1.0:
    raise ValueError(f'train_frac must be in [0, 1], got {train_frac}')

  data = torch.tensor(
      tokenizer.encode(dataset),
      dtype=torch.long,
  )

  num_train = int(train_frac * len(data))
  return data[:num_train], data[num_train:]


In [None]:
@dataclass
class GPTConfig:
  """
  Hyperparameters for minibatch sampling, optimization and the model.
  """
  batch_size: int = 32          # sequences per minibatch (get_batch)
  context_size: int = 256       # tokens per sequence; also the number of position embeddings
  max_iters: int = 5000         # optimizer steps in the training loop
  learning_rate: float = 3e-4   # optimizer learning rate
  eval_iters: int = 200         # minibatches averaged per loss estimate (estimate_loss)
  vocab_size: int = 50257       # GPT-2 BPE size by default; pass the tokenizer's size for char-level data
  n_layer: int = 6              # number of transformer blocks
  n_head: int = 6               # attention heads; must divide n_embd
  n_embd: int = 60              # embedding / hidden dimension
  eps: float = 1e-6             # numerical epsilon (not read by the model code in this notebook)
  device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

# Size the vocabulary to the character tokenizer (vocab_size from the vocabulary
# cell) instead of the 50257 default. Otherwise the embedding and logit tables
# carry tens of thousands of unused rows, and generate() can sample ids that
# tokenizer.decode cannot map back to characters.
config = GPTConfig(vocab_size=vocab_size)

In [None]:
# display the active configuration (dataclass repr)
config

In [None]:
def get_batch(
    data: torch.Tensor,
    config: "GPTConfig",
) -> tuple[torch.Tensor, torch.Tensor]:
  """
  Extracts a random minibatch of (context, next-token target) sequences.

  Args:
    data (torch.Tensor): 1d tensor of token indices
    config (GPTConfig): provides batch_size, context_size and device

  Returns:
    tuple(
      context (torch.Tensor): (batch_size, context_size) windows data[i : i + context_size]
      targets (torch.Tensor): (batch_size, context_size) windows shifted by one, data[i + 1 : i + context_size + 1]
    )
    Both tensors are moved to config.device.

  Raises:
    ValueError: if data is not longer than config.context_size.
  """
  context_size = config.context_size
  batch_size = config.batch_size
  device = config.device

  if len(data) <= context_size:
    raise ValueError(
        f'data has {len(data)} tokens; need more than context_size={context_size}'
    )

  # Random start offsets; the upper bound leaves room for the one-token shift of the targets.
  ix = torch.randint(len(data) - context_size, (batch_size,))

  context = torch.stack([data[i:i + context_size] for i in ix])  # (B,T)
  targets = torch.stack([data[i + 1:i + context_size + 1] for i in ix])  # (B,T)
  return context.to(device), targets.to(device)

In [None]:
# 90% of the tokens for training, the trailing 10% for validation
train_data, val_data = prepare_data(dataset, tokenizer, train_frac=0.9)
print(len(train_data))

In [None]:
# Sanity check: draw one training minibatch and inspect the tensor shapes.
xb, yb = get_batch(train_data, config)

for label, batch_tensor in (('inputs:', xb), ('targets:', yb)):
  print(label)
  print(batch_tensor.shape)

# Each row b packs context_size training examples: for t in range(context_size),
# the context xb[b, :t+1] is paired with the next-token target yb[b, t].

## Full Transformer Model

In [None]:
class L2Norm(nn.Module):
  """
  Scale tensors to (approximately) unit Euclidean length along one dimension.

  `eps` is added to the norm in the denominator so that all-zero inputs do not
  divide by zero.
  """
  def __init__(self, eps = 1e-6):
    super().__init__()
    self.eps = eps

  def forward(self, x, dim=-1):
    magnitude = torch.linalg.norm(x, dim=dim, keepdim=True)
    denominator = magnitude + self.eps
    return x / denominator

class NormableLinear(nn.Module):
  """
  Bias-free linear map whose weight can be re-normalized to unit L2 norm.

  Computes x @ W.T, where W has shape (output_dimension, input_dimension).
  `normalize()` rescales W in place to unit norm along `norm_dim` (-1: each
  row; 0: each column). With `scale=True`, a learnable per-output log-scale s
  multiplies the result by exp(s). s is initialized to zeros, so exp(s) = 1.

  Args:
    input_dimension (int): size of the last input dimension (columns of W)
    output_dimension (int): number of outputs (rows of W)
    scale (bool): whether to learn the per-output log-scale
    norm_dim (int): dimension of W normalized by `normalize()`
  """
  def __init__(
      self,
      input_dimension: int,
      output_dimension: int,
      scale: bool = False,
      norm_dim: int = -1,
  ):
    super().__init__()

    self.normable_weight = nn.Parameter(
        1/math.sqrt(input_dimension) * torch.randn(output_dimension, input_dimension),
    )
    self.norm = L2Norm()
    self.norm_dim = norm_dim
    self.is_scale = scale
    self.scale = nn.Parameter(
        torch.zeros(output_dimension),
    ) if self.is_scale else None

  def forward(self, x):
    """x is (..., input_dimension); returns (..., output_dimension) = x @ W.T (times exp(scale))."""
    out = x @ self.normable_weight.transpose(0, 1)
    if self.is_scale:
      out = out * self.scale.exp()
    return out

  def get_indices(self, index: "torch.Tensor | list[int] | int"):
    """
    Gather rows of W, e.g. to use the weight as an embedding table.

    Args:
      index: integer index (int, list, or LongTensor of any shape)

    Returns:
      torch.Tensor of shape index.shape + (input_dimension,). When scaling is
      enabled, each gathered row i is multiplied by exp(scale[i]).
    """
    rows = self.normable_weight[index]
    if self.is_scale:
      # scale[index] has shape index.shape. Add a trailing axis so it scales whole
      # rows instead of broadcasting against the feature dimension.
      rows = rows * self.scale[index].exp().unsqueeze(-1)
    return rows

  @torch.no_grad()
  def normalize(self):
    """
    In-place normalization of the weight matrix along `norm_dim`.
    """
    self.normable_weight.copy_(
        self.norm(self.normable_weight, dim=self.norm_dim),
    )

class CausalSelfAttention(nn.Module):
    """
    Multi-head causal self-attention built from NormableLinear projections.

    A fused projection produces queries, keys and values. They are split into
    n_head heads of size n_embd // n_head, causally attended, merged back, and
    passed through an output projection.

    NOTE(review): the nGPT paper also L2-normalizes q and k per head before
    attention; this implementation does not. Confirm whether that is intended.
    """

    def __init__(self, config):
        super().__init__()
        # the embedding must split evenly across the heads
        assert config.n_embd % config.n_head == 0
        # one fused projection for query, key and value (hence 3 * n_embd outputs)
        self.c_attn = NormableLinear(config.n_embd, 3 * config.n_embd, scale=True)
        # projection of the merged heads back to n_embd
        self.c_proj = NormableLinear(config.n_embd, config.n_embd, scale=True)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        # x: (batch, seq_len, channels) with channels = n_embd = n_head * head_size
        batch, seq_len, channels = x.size()
        head_size = channels // self.n_head

        def to_heads(t):
            # (batch, seq_len, channels) -> (batch, n_head, seq_len, head_size)
            return t.view(batch, seq_len, self.n_head, head_size).transpose(1, 2)

        fused = self.c_attn(x)  # (batch, seq_len, 3 * n_embd)
        query, key, value = map(to_heads, fused.split(self.n_embd, dim=2))

        # SDPA divides q.k by sqrt(head_size) by default; pre-multiplying q by
        # head_size leaves a net sqrt(head_size) factor on the attention logits.
        query = query * head_size
        attended = F.scaled_dot_product_attention(query, key, value, is_causal=True)

        # merge heads: (batch, n_head, seq_len, head_size) -> (batch, seq_len, channels)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, channels)
        return self.c_proj(merged)

class MLP(nn.Module):
  """
  Position-wise feed-forward network: expand to 4 * n_embd, GELU, project back.

  The up-projection normalizes the rows of its weight (default norm_dim=-1).
  The down-projection uses norm_dim=0, i.e. it normalizes the columns of its
  (n_embd, 4 * n_embd) weight.
  """

  def __init__(self, n_embd):
    super().__init__()
    hidden = 4 * n_embd  # expansion factor of 4 is an empirical choice
    up_proj = NormableLinear(n_embd, hidden, scale=True)
    down_proj = NormableLinear(hidden, n_embd, scale=True, norm_dim=0)
    self.net = nn.Sequential(up_proj, nn.GELU(), down_proj)

  def forward(self, x):
    return self.net(x)

class Block(nn.Module):
  """
  One nGPT layer: an attention step and an MLP step, each followed by a
  learnable interpolation toward the sub-layer output and re-normalization.

    hA = Norm(Attn(h));  h = Norm(h + alpha_A * (hA - h))
    hM = Norm(MLP(h));   h = Norm(h + alpha_M * (hM - h))

  The MLP reads the attention-updated h. alpha_A = exp(scale_attn) and
  alpha_M = exp(scale_mlp) are per-channel step sizes, initialized to
  1/sqrt(n_embd).
  """

  def __init__(self, config):
    super().__init__()
    n_embd = config.n_embd
    self.causal_attn = CausalSelfAttention(config)
    self.mlp = MLP(config.n_embd)
    self.norm = L2Norm()

    # log-parameterized step sizes: exp(-0.5 * log(n_embd)) = 1/sqrt(n_embd)
    self.scale_attn = nn.Parameter(-0.5*math.log(n_embd)*torch.ones(n_embd))
    self.scale_mlp = nn.Parameter(-0.5*math.log(n_embd)*torch.ones(n_embd))

  def forward(self, h):
    # attention step
    hA = self.norm(self.causal_attn(h))
    h = self.norm(h + self.scale_attn.exp() * (hA - h))
    # MLP step: applied sequentially to the attention-updated state
    hM = self.norm(self.mlp(h))
    h = self.norm(h + self.scale_mlp.exp() * (hM - h))
    return h

In [None]:
class MultiHeadSelfAttentionModel(nn.Module):
  """
  nGPT-style character language model.

  Token and position embeddings are rows of NormableLinear weights. Their sum
  is L2-normalized, passed through config.n_layer Blocks, and projected to
  vocabulary logits by a scaled NormableLinear head. Call `normalize()` after
  each optimizer step to re-normalize the weights.
  """

  def __init__(self, config):
    super().__init__()

    self.context_size = config.context_size
    self.token_embedding_table = NormableLinear(config.n_embd, config.vocab_size)
    self.position_embedding_table = NormableLinear(config.n_embd, config.context_size)
    self.norm = L2Norm()
    self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
    self.lm_head = NormableLinear(config.n_embd, config.vocab_size, scale = True)

  def forward(self, idx, targets=None):
    """
    Args:
      idx (torch.Tensor): (B, T) token indices with T <= context_size
      targets (torch.Tensor | None): (B, T) next-token indices

    Returns:
      (logits, loss). If targets is None, logits is (B, T, vocab_size) and loss
      is None. Otherwise logits is flattened to (B*T, vocab_size) and loss is
      the scalar mean cross-entropy.
    """
    B, T = idx.shape
    tok_embd = self.token_embedding_table.get_indices(idx)  # (B,T,n_embd)
    # positions are built on idx's device instead of reading a global config
    pos_embd = self.position_embedding_table.get_indices(
        torch.arange(T, device=idx.device),
    )  # (T,n_embd)
    x = tok_embd + pos_embd  # (B,T,n_embd) + (T,n_embd) = (B,T,n_embd)
    x = self.norm(x)
    x = self.blocks(x)
    logits = self.lm_head(x)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)  # scalar (mean over B*T)

    return logits, loss

  @torch.no_grad()
  def normalize(self):
    """Re-normalize every NormableLinear weight contained in this model."""
    for module in self.modules():
      if isinstance(module, NormableLinear):
        module.normalize()

  @torch.no_grad()
  def generate(self, idx, max_new_tokens):
    """
    Autoregressively sample `max_new_tokens` tokens.

    Args:
      idx (torch.Tensor): (B, T) token indices used as the prompt
      max_new_tokens (int): number of tokens to append

    Returns:
      torch.Tensor: (B, T + max_new_tokens), the prompt followed by the samples
    """
    for _ in range(max_new_tokens):
      # condition on at most the last context_size tokens (the position table's length)
      idx_cond = idx[:, -self.context_size:]
      logits, _ = self(idx_cond)  # logits is (B,T,C)
      logits = logits[:, -1, :]  # (B,C)
      probs = F.softmax(logits, dim=-1)  # (B,C)
      idx_next = torch.multinomial(probs, num_samples=1)  # (B,1)
      idx = torch.cat((idx, idx_next), dim=1)  # (B,T+1)
    return idx

In [None]:
# Build the model and move its parameters to the configured device.
# nn.Module.to returns the module, so this cell displays the model structure.
model = MultiHeadSelfAttentionModel(config)
model.to(config.device)

In [None]:
model.normalize()
# sanity check: after normalize(), a token-embedding row should have L2 norm ~1
first_row = model.token_embedding_table.normable_weight[0]
with torch.no_grad():
  print(torch.linalg.norm(first_row))

In [None]:
num_params = sum(param.numel() for param in model.parameters())
print(f'This model has {num_params} parameters')

In [None]:
## Helper functions

@torch.no_grad()
def estimate_loss(
    model,
    data,
    config,
):
  """
  Estimate the model's mean loss on `data` from config.eval_iters random minibatches.

  The model is put in eval mode while estimating. Its previous train/eval mode
  is restored afterwards, even if an exception is raised.

  Args:
    model: module returning (logits, loss) when called as model(X, Y)
    data (torch.Tensor): 1d tensor of token indices to sample minibatches from
    config (GPTConfig): provides eval_iters (plus get_batch's batch settings)

  Returns:
    torch.Tensor: scalar mean of the sampled minibatch losses
  """
  was_training = model.training
  model.eval()
  try:
    losses = torch.zeros(config.eval_iters)
    for k in range(config.eval_iters):
      X, Y = get_batch(data, config)
      _, loss = model(X, Y)
      losses[k] = loss.item()
  finally:
    model.train(was_training)
  return losses.mean()

In [None]:
# Override the GPTConfig default learning rate on the shared config object.
config.learning_rate = 3e-2
# Weight decay is disabled; instead, the weights are re-projected to unit norm
# by model.normalize() after every optimizer step (below).
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=0.0)

# for simplicity, we will just use an iteration-based scheduler (as opposed to epoch-based)
# NOTE(review): scheduler.step() is only called once per 100 iterations (inside the
# `step % 100 == 0` branch below), so patience=100 means 100 evaluations, i.e.
# ~10,000 iterations without improvement. That exceeds the default max_iters
# (5000), so the LR is likely never reduced. Consider a patience of a few evaluations.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=100)

print(f'Device: {config.device}')
for step in range(config.max_iters):
  #minibatch
  xb, yb = get_batch(train_data, config)

  # train
  model.train()
  optimizer.zero_grad(set_to_none=True)
  logits, loss = model(xb,yb)

  #loss
  loss.backward()
  optimizer.step()
  # project the updated weights back onto the unit sphere (nGPT)
  model.normalize()

  # every 100 steps: estimate train/validation loss and feed the validation
  # loss to the plateau scheduler ("test" in the log line is the validation loss)
  if step % 100 == 0:
    train_loss = estimate_loss(
        model,
        train_data,
        config,
    )
    val_loss = estimate_loss(
        model,
        val_data,
        config,
    )
    print(f'iter {step} | train: {train_loss} | test: {val_loss}')
    scheduler.step(val_loss)


# loss of the last training minibatch
print(loss.item())

In [None]:
# generate 500 characters, starting from a single token with index 0
idx = torch.zeros((1, 1), dtype=torch.long, device=config.device)
generated = model.generate(idx, max_new_tokens=500)
print(tokenizer.decode(generated[0].tolist()))