## Importing modules

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [3]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-07-02 04:27:53--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.108.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-07-02 04:27:53 (20.0 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [4]:
class LayerNorm1d:

  """
  Layernorm class.
  Inspired from : 'https://github.com/karpathy/ng-video-lecture/blob/master/gpt.py'
  Args:
    dim = dimension of residual stream.
    eps = adding small constant for numerical stability.

  """

  def __init__(self, dim, eps = 1e-5):
    self.eps = eps
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)

  def __call__(self, x):
    # Calculate the forward pass
    xmean = x.mean(1, keepdim = True) # batch mean
    xvar = x.var(1, keepdim = True) # batch variance
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalization
    self.out = self.gamma * xhat + self.beta
    return self.out

  def parameters(self):
    return [self.gamma, self.beta]


In [5]:


class Config:

  """
  Hyper-parameters for the model
  """

  batch_size = 16
  block_size = 32   # Context length
  max_iters = 5000
  eval_interval = 100
  learning_rate = 1e-3
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  eval_iters = 200
  n_embed = 64 # Dimension of token embedding
  n_head = 4  # Number od heads
  head_size = int(n_embed / n_head) # Input space of each attention head for first attention mechanism
  qk_dim = 8  # Dimension vector for secondary attention mechanism
  n_rules = 2 # Number of attributes
  n_layer = 4 # Number of layers
  dropout = 0.0


torch.manual_seed(1337)


<torch._C.Generator at 0x7f86441b5070>

## Data preprocessing

In [6]:
'''
Data preprocessing, Inspired from : 'https://github.com/karpathy/ng-video-lecture/blob/master/gpt.py'
'''

with open('input.txt', 'r', encoding = 'utf-8') as f:
  text = f.read()

# Building a vocab of all characters.
vocab = sorted(list(set(text)))
vocab_size = len(vocab)

# Create a tokenizer to map chars to numbers and numbers to chars
strtoint = {ch : i for i, ch in enumerate(vocab)}
inttostr = {i : ch for i, ch in enumerate(vocab)}

# Encoder takes a string and outputs a list of integers.
encoder = lambda s : [strtoint[c] for c in s]
decoder = lambda l: ''.join([inttostr[i] for i in l])

# Preparing training and test splits.
data = torch.tensor(encoder(text), dtype = torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

# Data loading
def get_batch(split):

  # generate a small batch of data of inputs x and targets y
  data = train_data if split == 'train' else val_data
  ix = torch.randint(len(data) - Config.block_size, (Config.batch_size, ))
  x = torch.stack([data[i : i + Config.block_size] for i in ix])
  y = torch.stack([data[i + 1 : i + Config.block_size + 1] for i in ix])
  x, y = x.to(Config.device), y.to(Config.device)
  return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(Config.eval_iters)
        for k in range(Config.eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


class FCN(nn.Module):

  """a simple linear layer for attention block"""

  def __init__(self, n_embed, dropout):
    super().__init__()
    self.fcn_module = nn.Sequential(
        nn.Linear(n_embed, 4 * n_embed),
        nn.ReLU(),
        nn.Linear(4 * n_embed, n_embed),
        nn.Dropout(dropout),
    )

  def forward(self, x):
    return self.fcn_module(x)

class Compositional_attention(nn.Module):

  def __init__(self, dim, nheads, nrules, qk_dim, head_dim, context_length, dropout):

    super(Compositional_attention, self).__init__()
    self.dim = dim
    self.nheads = nheads
    self.nrules = nrules
    self.qk_dim = qk_dim
    self.head_dim = head_dim
    self.context_length = context_length

    # Defining Q, K, V for primary attention mechanism.
    self.query = nn.Linear(dim, dim)
    self.key = nn.Linear(dim, dim)
    self.value = nn.Linear(dim, self.head_dim * self.nrules)
    self.register_buffer('tril', torch.tril(torch.ones(self.context_length, self.context_length)))

    # Defining Q, k for secondary attention mechanism.
    self.query_value = nn.Linear(dim, self.qk_dim * nheads)
    self.key_value = nn.Linear(self.head_dim, self.qk_dim)

    # Final projection and dropout layer.
    self.projection = nn.Linear(dim, dim)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):

    batch_size, context_length, _= x.shape

    # Calculating Q, K, V from input as describe in equation 7 in paper.
    q = self.query(x).reshape(batch_size, self.nheads, context_length, self.head_dim)
    k = self.key(x).reshape(batch_size, self.nheads, self.head_dim, context_length)
    v = self.value(x).reshape(batch_size, self.nrules, context_length, self.head_dim).unsqueeze(1)

    # Calculating primary causal attention pattern(Search) as describe in equation 8 in paper.
    causal_attn_pattrn = torch.matmul(q, k) / (self.head_dim ** 0.5)
    causal_attn_pattrn = causal_attn_pattrn.masked_fill(self.tril[:context_length, :context_length] == 0, float('-inf'))
    causal_attn_pattrn = F.softmax(causal_attn_pattrn, dim = -1).unsqueeze(2)

    # Calculating output(Retrieval = Search * V)
    output = torch.matmul(causal_attn_pattrn, v).reshape(batch_size, context_length, self.nheads, self.nrules, self.head_dim)

    # Instantiation of Q and K for secondary attention pattern.
    q_v = self.query_value(x).reshape(batch_size, context_length, self.nheads, 1, self.qk_dim) / (self.qk_dim ** 0.5)
    k_v = self.key_value(output).reshape(batch_size, context_length, self.nheads, self.nrules, self.qk_dim)

    # Calculating value score as describe in equation 13 in paper.
    comp_score = F.softmax(torch.matmul(q_v, k_v.transpose(4, 3)), dim = -1).reshape(batch_size, context_length, self.nheads, self.nrules, 1)

    # Final Compositional score
    out = (comp_score * output).sum(dim = 3).reshape(batch_size, context_length, self.dim)
    out = self.dropout(self.projection(out))
    return out


class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, Config):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        self.sa = Compositional_attention(Config.n_embed, Config.n_head, Config.n_rules, Config.qk_dim, Config.head_size, Config.block_size, Config.dropout)
        self.ffwd = FCN(Config.n_embed, Config.dropout)
        self.ln1 = nn.LayerNorm(Config.n_embed)
        self.ln2 = nn.LayerNorm(Config.n_embed)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class LanguageModel(nn.Module):

    def __init__(self, Config):
        super().__init__()

        # Building token and positional embedding martix
        self.token_embedding_table = nn.Embedding(vocab_size, Config.n_embed)
        self.position_embedding_table = nn.Embedding(Config.block_size, Config.n_embed)


        self.blocks = nn.Sequential(*[Block(Config) for _ in range(Config.n_layer)])
        self.ln_f = nn.LayerNorm(Config.n_embed) # final layer norm
        self.lm_head = nn.Linear(Config.n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=Config.device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -Config.block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

model = LanguageModel(Config)
m = model.to(Config.device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=Config.learning_rate)

for iter in range(Config.max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % Config.eval_interval == 0 or iter == Config.max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=Config.device)
print(decoder(m.generate(context, max_new_tokens=2000)[0].tolist()))






0.211041 M parameters
step 0: train loss 4.3630, val loss 4.3649
step 100: train loss 2.4564, val loss 2.4700
step 200: train loss 1.3903, val loss 1.4081
step 300: train loss 0.6807, val loss 0.7069
step 400: train loss 0.3956, val loss 0.4124
step 500: train loss 0.2821, val loss 0.3015
step 600: train loss 0.2412, val loss 0.2493
step 700: train loss 0.2094, val loss 0.2128
step 800: train loss 0.1924, val loss 0.1985
step 900: train loss 0.1823, val loss 0.1886
step 1000: train loss 0.1759, val loss 0.1759
step 1100: train loss 0.1591, val loss 0.1648
step 1200: train loss 0.1413, val loss 0.1428
step 1300: train loss 0.1191, val loss 0.1283
step 1400: train loss 0.0992, val loss 0.1027
step 1500: train loss 0.0929, val loss 0.0976
step 1600: train loss 0.0866, val loss 0.0895
step 1700: train loss 0.0873, val loss 0.0908
step 1800: train loss 0.0838, val loss 0.0851
step 1900: train loss 0.0846, val loss 0.0837
step 2000: train loss 0.0800, val loss 0.0829
step 2100: train loss 0.