<a href="https://colab.research.google.com/github/ashegde/notebooks/blob/main/LinearLM_and_tiny_shakespeare.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This is a scratch notebook for implementing a linear language model -- for example, as discussed in:


*Malach, E. (2023). Auto-regressive next-token predictors are universal learners. arXiv preprint arXiv:2309.06979.*


Note that the principal aim here is just to show that simple auto-regressive models can achieve seemingly non-trivial performance. Such performance can serve as a baseline against which to compare and interpret more sophisticated models.

At some level, this should not be a surprise. Linear(ized) systems and linear(ized) state space models have formed the foundation of many disciplines -- e.g., control, optimization, time series, econometrics, etc. -- and their usage is widespread in practice.

What we show here is that a simple linear model containing just a few thousand parameters can rapidly extract general structure and style from the toy "Tiny Shakespeare" dataset. Naturally, this model fails to learn any real language capability. Transformer models -- which contain many more parameters -- of course perform much better, as we observed in a previous notebook, but it is still remarkable that structure emerges from such a simple formulation.

In [None]:
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
%matplotlib inline

In [None]:
# Load the TinyShakespeare dataset: fetch the raw text file (input.txt) with wget
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [None]:
# Read the whole corpus into memory as a single UTF-8 decoded string
with open('input.txt', 'r', encoding='utf-8') as f:
  dataset = f.read()

In [None]:
# Report the corpus length; dataset is a str, so len() counts characters
print(f'The dataset contains {len(dataset)} characters')

In [None]:
# The dataset is one big string (it was read whole and never split into a
# list of strings), so the slice below shows its first 1000 characters.
print(type(dataset))
print(dataset[:1000])

In [None]:
# find unique characters to use as our tokens
# the distinct characters of the corpus, in sorted order, serve as the tokens
vocab = sorted(set(dataset))
vocab_size = len(vocab)
print(''.join(vocab))
print(vocab_size)

In [None]:
# CHARACTER-LEVEL TOKENIZER

class Tokenizer:
  """
  Character-level tokenizer built from a fixed vocabulary

  Every character of the vocabulary is assigned the integer given by
  its position in that vocabulary. The class converts text to lists
  of those integers and back again.
  """
  def __init__(self, unique_characters: str):
    # Any ordered collection of single characters works here (a str or
    # a list of characters); it is only ever enumerated and measured.
    self.vocab = unique_characters
    # index -> character table, and its inverse character -> index table
    self.index_to_character = dict(enumerate(self.vocab))
    self.character_to_index = {
        ch: i for i, ch in self.index_to_character.items()
    }

  def encode(self, text: str) -> list[int]:
    """
    Map each character of the text to its token index

    Args:
      text (str): string to tokenize

    Returns:
      list: token indices, one per character of the input

    Raises:
      KeyError: if the text contains a character outside the vocabulary
    """
    lookup = self.character_to_index
    return [lookup[ch] for ch in text]

  def decode(self, indices: list[int]) -> str:
    """
    Map token indices back to the string they represent

    Args:
      indices (list[int]): list of integer token indices

    Returns:
      str: concatenation of the corresponding characters

    Raises:
      KeyError: if an index does not belong to the vocabulary
    """
    lookup = self.index_to_character
    return "".join(lookup[i] for i in indices)

  def get_vocab_size(self):
    """Returns the number of entries in the vocabulary"""
    return len(self.vocab)

# Build the tokenizer from the corpus vocabulary and sanity-check it:
# the first print shows token indices, the second shows that decoding
# the encoded text gives back the original string.
tokenizer = Tokenizer(vocab)
print(tokenizer.encode('hello there'))
print(tokenizer.decode(tokenizer.encode('hello there')))


In [None]:
def prepare_data(
    dataset: str,
    tokenizer: Tokenizer,
    train_frac: float = 0.9,
) -> tuple[torch.Tensor, torch.Tensor]:
  """
  Prepares dataset for model training

  Converts the original dataset in the form of a string
  into two sequences of token indices -- one for training
  and one for validation. The split is contiguous: the first
  `train_frac` of the tokens are used for training and the
  remainder for validation.

  Args:
    dataset (str): dataset stored as a single string
    tokenizer (Tokenizer): tokenizer to encode the dataset
    train_frac (float): fraction of dataset for training, in [0, 1]

  Returns:
    tuple(
      train_data (torch.Tensor): 1d tensor of token indices for training
      val_data (torch.Tensor): 1d tensor of token indices for validation
    ):

  Raises:
    ValueError: if train_frac lies outside [0, 1]
  """
  # An out-of-range fraction would otherwise yield a silently wrong split
  # (e.g. a negative value slices from the end of the data).
  if not 0.0 <= train_frac <= 1.0:
    raise ValueError(f'train_frac must be in [0, 1], got {train_frac}')

  data = torch.tensor(
      tokenizer.encode(dataset),
      dtype=torch.long,
  )

  num_train = int(train_frac * len(data))
  return (data[:num_train], data[num_train:])


In [None]:
# Encode the corpus and split it: the first 90% of tokens for training, the rest for validation
train_data, val_data = prepare_data(dataset, tokenizer, 0.9)

In [None]:
# Number of tokens in the training split
len(train_data)

In [None]:
# model configuration
@dataclass
class Config:
  """
  Model and data-sampling settings

  Every attribute is an annotated dataclass field, so each one can be
  overridden through the constructor (e.g. Config(device='cpu')) and
  shows up in the generated repr and equality comparison.
  """
  batch_size: int = 64    # number of sequences per minibatch
  context_size: int = 8   # number of tokens in each sampled sequence / model context
  eps: float = 1e-10      # small constant; not referenced by the code visible in this notebook
  n_embd: int = 32        # embedding (channel) dimension
  random_seed: int = 1337 # seed passed to torch.manual_seed
  # default device is chosen once, when the class is defined
  device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

# Instantiate the configuration with its default values
config = Config()

In [None]:
def get_batch(
    data: torch.Tensor,
    config: Config
) -> tuple[torch.Tensor, torch.Tensor]:
  """
  Extracts a random minibatch from the input dataset

  Draws `config.batch_size` random starting positions and, for each,
  takes a window of `config.context_size` consecutive tokens as the
  context. The targets are the same window shifted one position to
  the right, i.e. the next token for every context position.

  Args:
    data (torch.Tensor): 1d tensor of token indices
    config (Config): settings providing context_size, batch_size and device

  Returns:
    tuple(
      context (torch.Tensor): (batch_size, context_size) input tokens
      targets (torch.Tensor): (batch_size, context_size) next tokens
    )
    Both tensors are moved to `config.device`.

  Raises:
    ValueError: if data has fewer than context_size + 1 tokens
  """
  context_size = config.context_size
  batch_size = config.batch_size

  # A start index i needs tokens i .. i + context_size (inclusive) to
  # exist, because the targets reach one position past the context.
  num_starts = len(data) - context_size
  if num_starts <= 0:
    raise ValueError(
        f'data has {len(data)} tokens but at least {context_size + 1} '
        'are needed to form one (context, target) pair'
    )

  ix = torch.randint(num_starts, (batch_size,))

  # Gather all windows with one indexing operation instead of slicing
  # the data once per batch element.
  window = ix.unsqueeze(1) + torch.arange(context_size)  # (B,T)
  context = data[window]  # (B,T)
  targets = data[window + 1]  # (B,T)
  return context.to(config.device), targets.to(config.device)

In [None]:
# Trial: sample a minibatch
# Trial: sample one minibatch and inspect the shapes of its inputs and targets
xb, yb = get_batch(train_data, config)

print('inputs:')
print(xb.shape)
# print(xb)
print('targets:')
print(yb.shape)
# print(yb)

# print('----')

# # Optional: unpack every (context, target) example stored in the minibatch
# for b in range(config.batch_size): # b = index within the batch
#   for t in range(config.context_size): # t = position in time
#     context = xb[b,:t+1]
#     target = yb[b,t]
#     print(f'When context is {context.tolist()}, the target is: {target}')

In [None]:
class UpperTriLinear(nn.Linear):
  '''
  Upper Triangular Linear Layer

  Implements an upper triangular linear layer of the form:

  x W + b

  where W is upper triangular. Note, the implementation below
  is based on nn.Linear(in_features, out_features), which has
  the following implementation:

  x A.T + b

  where:

  nn.Linear.weight = A, and is (out_features, in_features)
  nn.Linear.bias = b, and is (out_features,)

  Hence, to apply an upper triangular mask on the matrix multiply,
  we must apply a lower triangular mask to A.

  The mask is applied inside forward(), so the constraint holds no
  matter how the stored weights were produced (re-initialisation,
  loading a checkpoint, manual edits). Because the masked entries do
  not take part in the forward computation, they also receive zero
  gradient.

  Args:
    in_features: size of the last dimension of the input
    out_features: size of the last dimension of the output
    bias: whether to learn the additive bias b (default True)
  '''
  def __init__(self, in_features, out_features, bias=True):
    super().__init__(in_features, out_features, bias=bias)

    # Zero the masked entries once so the stored parameter matches the
    # weights that are actually used.
    with torch.no_grad():
      self.weight.copy_(torch.tril(self.weight))

  def forward(self, x):
    # Only the lower triangle of A contributes: output feature t depends
    # on input features s <= t.
    return F.linear(x, torch.tril(self.weight), self.bias)

class CausalLinearBlock(nn.Module):
  '''
  Causal Linear Block

  Two linear maps applied in sequence. The first mixes token
  embeddings *across time* through a triangular layer, so each
  position only receives contributions from itself and earlier
  positions. The second mixes the dimensions/channels of each
  resulting token independently.
  '''
  def __init__(self, config: Config):
    super().__init__()

    self.n_embd = config.n_embd
    self.context_size = config.context_size

    # time mixer acts on the token axis, channel mixer on the embedding axis
    self.triu_linear = UpperTriLinear(self.context_size, self.context_size)
    self.channel_mixer = nn.Linear(self.n_embd, self.n_embd)

  def forward(self, x: torch.tensor) -> torch.tensor:
    # x is (B, T, C),
    # [B]atches
    # [T]okens <= context_size
    # [C]hannels := n_embd

    _, T, _ = x.size()

    # The time mixer expects exactly context_size positions, so shorter
    # sequences are zero-padded at the end of the token axis.
    missing = self.context_size - T
    if missing > 0:
      x = F.pad(x, (0, 0, 0, missing), 'constant', 0)  # (B, context_size, C)

    time_last = x.transpose(1, 2)  # (B, C, context_size)
    mixed = self.triu_linear(time_last)  # (B, C, context_size)
    # drop the padded positions and restore the (B, T, C) layout
    y = mixed[:, :, :T].transpose(1, 2)
    return self.channel_mixer(y)


class LinearLM(nn.Module):
    '''
    Linear auto-regressive language model

    Pipeline: token embedding -> causal linear block -> language model
    head. The embedding table and the head share one weight matrix.
    '''

    def __init__(
        self,
        config: Config,
        tokenizer: Tokenizer,
    ):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer

        self.input_embedding = nn.Embedding(
            tokenizer.get_vocab_size(),
            config.n_embd,
        )
        self.causal_linear_block = CausalLinearBlock(config)
        self.lm_head = nn.Linear(
            config.n_embd,
            tokenizer.get_vocab_size(),
            bias=False,
        )    # language model head, final classifier

        # weight sharing scheme: the embedding and the head both use the
        # head's (vocab_size, n_embd) parameter
        self.input_embedding.weight = self.lm_head.weight

    def forward(self, idx, targets=None):
        '''
        Compute next-token logits and, optionally, the training loss

        Args:
          idx: (B, T) tensor of token indices, with T <= context_size
          targets: optional (B, T) tensor of next-token indices

        Returns:
          tuple(
            nt_logits: (B, T, vocab_size) next-token logits
            loss: scalar mean cross-entropy over all B*T positions,
              or None when no targets are given
          )
        '''
        # idx is (B,T)
        B, T = idx.size()
        assert T <= self.config.context_size, f'input sequence length ({T}) exceeds model context size ({self.config.context_size})'
        x = self.input_embedding(idx) # (B, T, n_embd) token embedding for each sequence element
        x = self.causal_linear_block(x) #(B, T, n_embd)

        # next_token_logits
        nt_logits = self.lm_head(x) # (B, T, vocab_size)

        # teacher-forcing supervision
        # targets is (B,T)
        loss = None
        if targets is not None:
            # flatten to (B*T, vocab_size) vs (B*T,); the result is a scalar
            loss = F.cross_entropy(
                nt_logits.view(-1, nt_logits.size(-1)),
                targets.view(-1)
            )

        return nt_logits, loss

    @torch.no_grad()  # sampling needs no autograd graph
    def generate(self, idx, max_new_tokens):
        '''
        Sample a continuation of the given token sequences

        Args:
          idx: (B, T) tensor of token indices used as the prompt
          max_new_tokens: number of tokens to append to each sequence

        Returns:
          (B, T + max_new_tokens) tensor holding the prompt followed by
          the sampled tokens
        '''
        for _ in range(max_new_tokens):
            # condition on at most the context_size most recent tokens
            idx_cond = idx[:, -self.config.context_size:]
            logits, _ = self(idx_cond) # logits is (B,T,vocab_size)
            logits = logits[:, -1, :] # (B,vocab_size), prediction for the next token
            probs = F.softmax(logits, dim=-1) # (B,vocab_size)
            idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
            idx = torch.cat((idx, idx_next), dim=1) # (B,T+1)
        return idx



In [None]:
 ## Helper functions

@torch.no_grad()
def estimate_loss(
    model,
    data,
    config,
    num_batches=None,
):
  """
  Estimates the model loss on a dataset by averaging over random minibatches

  The model is put in evaluation mode for the duration of the estimate
  and returned to whichever mode it was in beforehand.

  Args:
    model: model called as model(X, Y) and returning (logits, loss)
    data: 1d tensor of token indices to sample minibatches from
    config: settings forwarded to get_batch
    num_batches: number of minibatches to average over; when None, the
      notebook-level setting `eval_iters` is used

  Returns:
    0-dimensional tensor holding the mean of the per-minibatch losses
  """
  if num_batches is None:
    num_batches = eval_iters  # global defined with the optimization settings

  was_training = model.training
  model.eval()
  try:
    losses = torch.zeros(num_batches)
    for k in range(num_batches):
      X, Y = get_batch(data, config)
      _, loss = model(X, Y)
      losses[k] = loss.item()
  finally:
    # restore the caller's mode even if a batch fails
    model.train(was_training)
  return losses.mean()

In [None]:
# Seed PyTorch's random number generator with the value stored in the config
torch.manual_seed(config.random_seed)

In [None]:
# Build the model and move its parameters to the configured device
model = LinearLM(config, tokenizer)
model.to(config.device)

In [None]:
# Total number of parameter entries; this counts the full weight matrix of the
# triangular time-mixing layer, including the entries that are masked to zero.
print(f'This model has {sum(p.numel() for p in model.parameters())} parameters')

In [None]:
# define optimization settings
# define optimization settings
max_iters = 10000  # number of optimizer steps in the training loop
eval_interval = 300  # intended number of steps between loss evaluations -- NOTE(review): confirm the training loop actually reads this value
learning_rate = 1e-3  # initial learning rate passed to the optimizer
eval_iters = 200  # number of minibatches averaged per loss estimate (read by estimate_loss)

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# The scheduler is stepped once per evaluation (see below), so the patience of
# roughly 1000 training iterations is expressed as a number of evaluations.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.1,
    patience=max(1, 1000 // eval_interval),
)
for step in range(max_iters):
  #minibatch
  xb, yb = get_batch(train_data, config)

  #loss
  logits, loss = model(xb,yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

  # evaluate every eval_interval steps, and once more on the final iterate
  if step % eval_interval == 0 or step == max_iters - 1:
    train_loss = estimate_loss(
        model,
        train_data,
        config,
    )
    val_loss = estimate_loss(
        model,
        val_data,
        config,
    )
    print(f'iter {step} | train: {train_loss} | test: {val_loss}')

    # Step the scheduler only on a fresh validation estimate; feeding it the
    # same stale value on every iteration would count each repeat as a
    # non-improving epoch.
    scheduler.step(val_loss)

# loss of the last training minibatch
print(loss.item())

In [None]:
# Prompt with a single token of index 0, created directly on the model's
# device, then sample 500 new tokens and decode them back to text.
idx = torch.zeros((1, 1), dtype=torch.long, device=config.device)
print(tokenizer.decode(model.generate(idx, max_new_tokens=500)[0].tolist()))