In [15]:
import torch
from torch import nn
from torch import functional as F
from fancy_einsum import einsum
import torch.utils.data as data

In [10]:
!pip install rotary-embedding-torch
!pip install fancy_einsum

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting fancy_einsum
  Downloading fancy_einsum-0.0.3-py3-none-any.whl (6.2 kB)
Installing collected packages: fancy_einsum
Successfully installed fancy_einsum-0.0.3


In [13]:
# data

In [50]:
class PreviousTokenDataset(data.Dataset):
    def __init__(self, n_examples, n_tokens):
      self.samples = []
      for _ in range(n_examples):
        r = torch.randint(high=10, size=(n_tokens,))
        label = r[-2]
        self.samples.append((r, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]

def create_dataloaders(batch_size, n_examples, train_val_split, n_tokens, num_workers=1):
    train_dataset = PreviousTokenDataset(n_examples, n_tokens)
    val_dataset = PreviousTokenDataset(int(n_examples * train_val_split), n_tokens)

    train_loader = data.DataLoader(train_dataset, batch_size=batch_size,
                                   shuffle=True, num_workers=num_workers)
    val_loader = data.DataLoader(val_dataset, batch_size=batch_size,
                                 shuffle=False, num_workers=num_workers)

    return train_loader, val_loader

batch_size = 64
num_workers = 1
n_examples = 1000
train_val_split = 0.2
n_tokens = 3

train_loader, val_loader = create_dataloaders(batch_size, n_examples, train_val_split, n_tokens)


In [48]:
def test_previous_token_dataset():
  batch_size = 64
  num_workers = 1
  n_examples = 1000
  train_val_split = 0.2
  n_tokens = 3

  train_loader, val_loader = create_dataloaders(batch_size, n_examples, train_val_split, n_tokens)

  assert len(train_loader.dataset) == n_examples, "Training dataset size is incorrect."

  for inputs, labels in train_loader:
      assert inputs.size(1) == n_tokens, f"Input tensor has wrong number of tokens: {inputs.size(1)}"
      assert inputs.size(0) == batch_size or inputs.size(0) == len(train_loader.dataset) % batch_size, "Batch size is incorrect."
      print(labels.shape)
      print(inputs.shape)
      assert labels[0] == inputs[0][-2]
      break

  for inputs, labels in val_loader:
      assert inputs.size(1) == n_tokens, f"Input tensor has wrong number of tokens: {inputs.size(1)}"
      assert inputs.size(0) == batch_size or inputs.size(0) == len(val_loader.dataset) % batch_size, "Batch size is incorrect."
      break

test_previous_token_dataset()

torch.Size([64])
torch.Size([64, 3])


In [None]:

@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()
print(cfg)

In [9]:
n_embd = 64
num_heads = 4
head_size = n_embd / num_heads
dropout = 0.8
vocab_size = 128
block_size = 3

In [3]:
# GPT Neo-X Rotary Implementation
class Rotary(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[:, None, None, :]
            self.sin_cached = emb.sin()[:, None, None, :]
        return self.cos_cached, self.sin_cached


# rotary pos emb helpers:

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat(
        (-x2, x1), dim=x1.ndim - 1
    )  # dim=-1 triggers a bug in torch < 1.8.0


@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)



In [4]:
# Train a 1L attention-only transformer with rotary embeddings to predict the previous token

In [11]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        k = self.key(x)   # (B,T,hs)
        q = self.query(x) # (B,T,hs)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,hs)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class Unembed(nn.Module):
    def __init__(self, init_range=1):
        super().__init__()
        self.W_U = nn.Parameter(torch.empty((n_embd, vocab_size)))
        nn.init.normal_(self.W_U, std=init_range)
        self.b_U = nn.Parameter(torch.zeros((vocab_size), requires_grad=False))
    
    def forward(self, normalized_resid_final):
        # normalized_resid_final [batch, position, d_model]
        if self.cfg.debug: print("Normalized_resid_final:", normalized_resid_final.shape)
        logits = einsum("batch position d_model, d_model d_vocab -> batch position d_vocab", normalized_resid_final, self.W_U) + self.b_U
        return logits

In [53]:
from torch.nn.modules.activation import MultiheadAttention
# https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
class RotaryAttention(nn.Module):
  def __init__(self):
    super().__init__()
    self.token_embedding = nn.Embedding(vocab_size, n_embd)
    self.rotary_emb = Rotary(n_embd)
    self.attention = MultiheadAttention(num_heads, head_size, dropout)
    self.unembed = Unembed()

  def forward(self, idx):
    B, T = idx.shape
    tok_emb = self.token_embedding(idx) # (B,T,C)
    pos_emb = self.rotary_emb(idx)
    x = tok_emb + pos_emb
    x += self.attention(x)
    return self.unembed(x)


In [54]:
model = RotaryAttention()

AssertionError: ignored

In [51]:

for inputs, batch in train_loader:
  pass
