In [None]:
# Download the lyrics corpus once; later runs reuse the local copy.
import wget
from pathlib import Path

url = 'https://github.com/fpaupier/RapLyrics-Scraper/raw/master/lyrics_US/Nas_lyrics.txt'
lyrics_path = Path('Nas_lyrics.txt')
# Skipping the download when the file already exists keeps this cell idempotent:
# calling wget.download again would fetch the corpus a second time and may store
# it under a different name instead of replacing the file the next cell opens.
if not lyrics_path.exists():
    wget.download(url, str(lyrics_path))

In [None]:
# Load the whole lyrics file into memory as one string.
with open('Nas_lyrics.txt', mode='r', encoding='utf-8') as lyrics_file:
    text = lyrics_file.read()

# number of characters in the corpus
print("dataset length", len(text))

In [None]:
# Peek at the first 1000 characters of the corpus.
print(text[0:1000])

In [None]:
# Character-level vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

print(''.join(chars))
print('Number of unique characters: ', vocab_size)

In [None]:
# Lookup tables between characters and integer ids; a character's id is its
# position in the sorted `chars` list.
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}


def encode(s):
    """Encoder: take a string, return the list of integer ids of its characters."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integer ids, return the string they spell."""
    return ''.join(itos[i] for i in l)


# round-trip check
print(encode("revolution"))
print(decode(encode("revolution")))

In [None]:
# Encode the entire corpus as a single 1-D tensor of character ids.
# !pip install torch
import torch
data = torch.tensor(encode(text), dtype=torch.long)
# `data.dtype` is the element type; `data.type` without parentheses would only
# print the repr of the bound method.
print(data.shape, data.dtype)
print(data[:1000])

In [7]:
# The first 80% of the encoded corpus is used for training, the remainder for validation.
n = int(len(data) * 0.8)
train_data, val_data = data[:n], data[n:]

In [None]:
# block_size is the context length: the model is trained on chunks of block_size
# tokens and learns to predict the next token from every prefix of a chunk.
# A smaller block size makes the computation less expensive.
block_size = 256
# A slice of block_size + 1 tokens packs block_size training examples, e.g.:
# token 0 is followed by token 1
# tokens 0, 1 are followed by token 2
# tokens 0, 1, 2 are followed by token 3, and so on
train_data[:block_size + 1]

In [86]:
import torch
# Seed the RNG, then set the batching configuration used by get_batch.
torch.manual_seed(1234)
batch_size = 64   # sequences per batch
block_size = 256  # tokens per sequence
# use the GPU when one is available
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def get_batch(split):
    """Sample a random batch of (input, target) sequences.

    split: "train" reads from train_data; any other value reads from val_data.

    Returns (x, y), both of shape (batch_size, block_size) and moved to `device`.
    Row i of y is row i of x shifted one position to the right in the source
    data, i.e. y[i, t] is the token that follows x[i, t].
    """
    source = train_data if split == "train" else val_data

    # one random start position per sequence in the batch
    starts = torch.randint(0, len(source) - block_size, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        stop = start + block_size
        inputs.append(source[start:stop])
        targets.append(source[start + 1:stop + 1])

    return torch.stack(inputs).to(device), torch.stack(targets).to(device)

# xb, yb = get_batch("train")
# print('inputs:')
# print(xb.shape)
# print(xb)
# print('targets:')
# print(yb.shape)
# print(yb)

In [None]:
# xb is the input batch fed to the transformer; every row is processed in parallel.
# Each row packs block_size training examples: the prefix xb[b, :t + 1] is the
# context and yb[b, t] is the token the model should predict after it.
xb, yb = get_batch("train")  # sampled here so the cell also runs on a fresh kernel

# Show only a few rows and positions; printing the full batch_size x block_size
# grid would flood the output.
for batch in range(min(batch_size, 2)):
    for t in range(min(block_size, 8)):
        context = xb[batch, :t + 1]
        target = yb[batch, t]
        print(f"context: {context.tolist()} target: {target}")

In [124]:
# Model hyperparameters.
n_embed = 8    # width of the token / position embeddings
n_head = 4     # attention heads per block; must divide n_embed evenly
n_layer = 3    # number of stacked transformer blocks
dropout = 0.2  # dropout probability
# Block gives every head n_embed // n_head channels and the heads' outputs are
# concatenated before an n_embed-wide projection. With a remainder (e.g. 8 and 6)
# the concatenation is narrower than n_embed and that projection fails.
assert n_embed % n_head == 0, "n_embed must be divisible by n_head"

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F
# Re-seed PyTorch's RNG before the model is defined and instantiated below.
torch.manual_seed(1234)

class BigramLanguageModel(nn.Module):
    """Character-level transformer language model.

    Despite its name this is not a bigram model: token and position embeddings
    are summed, passed through ``n_layer`` transformer ``Block``s and a final
    LayerNorm, and projected to vocabulary logits.

    Reads the notebook-level settings ``vocab_size``, ``n_embed``,
    ``block_size``, ``n_head`` and ``n_layer``.
    """

    def __init__(self):
        super().__init__()
        # learned embedding per token id and per position within the context window
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)  # final layer norm
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        idx: (B, T) tensor of token ids with T <= block_size.
        targets: optional (B, T) tensor of token ids.

        Returns (logits, loss). Without targets, logits is (B, T, vocab_size)
        and loss is None. With targets, logits is flattened to
        (B*T, vocab_size) and loss is the cross-entropy against the targets.
        """
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx)  # (B, T, C)
        # Build the positions on the same device as the input rather than on a
        # notebook-level `device`, so the forward pass works wherever the model
        # and its inputs live.
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))  # (T, C)
        x = tok_emb + pos_emb  # (B, T, C), pos_emb broadcasts over the batch
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            # cross_entropy wants (N, C) logits and (N,) targets
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()  # sampling never needs gradients
    def generate(self, idx, max_new_tokens):
        """Extend each row of idx by max_new_tokens sampled tokens.

        idx: (B, T) tensor of token ids used as the starting context.
        Returns a (B, T + max_new_tokens) tensor.
        """
        for _ in range(max_new_tokens):
            # the position embedding table only covers block_size positions
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)  # calls forward
            # keep only the last time step: its logits predict the next token
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1): one sample per row
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

# NOTE(review): BigramLanguageModel.__init__ instantiates Block, which is defined
# in a later cell of this notebook; on a fresh kernel those class cells have to
# be run before this line.
model = BigramLanguageModel()
model.to(device)
# logits, loss = model(xb, yb)


# idx is a 1x1 tensor holding a single zero token; it seeds the generation
# print(decode(model.generate(idx = torch.zeros((1,1), dtype = torch.long), max_new_tokens = 100)[0].tolist()))

# sampling at this point yields garbage because the model has not been trained yet

## Model Training

In [152]:
# Adam optimizer over every parameter of the model
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=3e-4,
)

In [None]:
# batch_size = 32
for steps in range(10000):
    # clear the gradients of the previous step first, otherwise they would
    # accumulate into the gradients of this step
    optimizer.zero_grad()

    # forward pass on a fresh random training batch
    xb, yb = get_batch("train")
    logits, loss = model(xb, yb)

    # backward pass and parameter update
    loss.backward()
    optimizer.step()

# loss of the final training step
print(loss.item())

In [None]:
# Sample 500 characters from the trained model, starting from a single zero token.
model.eval()  # switch dropout off while sampling
# The start token must live on the same device as the model, otherwise the
# embedding lookup fails when the model has been moved to the GPU.
start = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(idx=start, max_new_tokens=500)[0].tolist()))
model.train()  # back to training mode in case training is resumed afterwards

Up until now, all the generation that is being done is by keeping in consideration only the most recent/previous token. 
The contextual history of all tokens that have come so far is not being considered.

Now we are going to use the context of the tokens occurring earlier in the sequence to make predictions about what comes next.

## Writing Self-Attention:

In the case of a masked self attention block, information only flows from the previous context to the current timestep. A current token will only interact with the tokens that came before it and not after it. 

Consider token 6, for example: it sits at the 6th position along the time dimension of the tensor, and it only needs to interact with the tokens from the 5th, 4th, 3rd, ... steps. Averaging those results in 'the 6th token in the context of its history'.

We use a lower-triangular matrix so that a single matrix multiplication computes, for each token, an average over itself and the tokens before it; the weights in that average express how much each earlier token influences the current one.

In case of self attention, the key, query and values are coming from the same source. In other cases the keys and queries might be coming from different sources.

In [None]:
# Single-head masked self-attention on a random toy batch.
torch.manual_seed(1234)
B, T, C = 4, 8 ,32
x = torch.randn(B,T,C)

head_size = 16
# bias=False: each projection is a pure matrix multiplication with no additive bias term
key = nn.Linear(C, head_size, bias = False)
query = nn.Linear(C, head_size, bias = False)
value = nn.Linear(C, head_size, bias = False) 

k = key(x) # B,T,16
q = query(x) # B,T,16

# raw affinities between every query position and every key position
wei = q @ k.transpose(-2,-1) # B,T,16 @ B,16,T == B,T,T

# causal mask: positions above the diagonal (future tokens) are set to -inf,
# so they receive zero weight after the softmax
tril = torch.tril(torch.ones(T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

# the output is the attention-weighted sum of the value vectors
v = value(x)
out = wei @ v

out.shape

In [161]:
class Head(nn.Module):
    """One head of masked (causal) self-attention.

    head_size: width of the key/query/value projections of this head.

    Reads the notebook-level settings ``n_embed`` (input width), ``block_size``
    (largest sequence length the causal mask covers) and ``dropout``.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        # lower-triangular mask; a buffer (not a parameter) so that it follows
        # the module across devices without being trained
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map x of shape (B, T, n_embed) to (B, T, head_size); T <= block_size."""
        _, T, _ = x.shape
        k = self.key(x)    # (B, T, hs)
        q = self.query(x)  # (B, T, hs)
        # Attention scores ("affinities"), scaled by 1/sqrt(head_size): the dot
        # products are taken over head_size channels, so that is the dimension
        # that sets their scale, not the embedding width of x.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        # causal mask: a position may not attend to later positions
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        wei = self.dropout(wei)
        # weighted aggregation of the values
        v = self.value(x)  # (B, T, hs)
        out = wei @ v      # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

In [136]:
class MultiHeadAttention(nn.Module):
    """Multiple heads of self-attention in parallel.

    num_heads: number of Head modules run side by side.
    head_size: output width of each head.

    The heads' outputs are concatenated along the channel dimension and
    projected to ``n_embed`` channels, followed by dropout.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The projection consumes the concatenated head outputs, which are
        # num_heads * head_size wide. That equals n_embed only when n_embed is
        # divisible by num_heads, so the input width is derived, not assumed.
        self.proj = nn.Linear(num_heads * head_size, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map x of shape (B, T, n_embed) to a tensor of the same shape."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # (B, T, num_heads * head_size)
        out = self.dropout(self.proj(out))                   # (B, T, n_embed)
        return out

In [144]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: linear -> ReLU -> linear -> dropout.

    n_embed: input and output width; the hidden layer is 4 * n_embed wide.

    Subclasses nn.Module rather than nn.ModuleList: this is a layer with its own
    forward pass, not a list container of sub-modules.
    """

    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the network to the last dimension of x; the shape is preserved."""
        return self.net(x)

Unlike the implementation in the paper, the layer norm here is applied to the input before it is sent into the self-attention and feed-forward sub-layers.

This is the ***pre-norm formulation***.

In [145]:
class Block(nn.Module):
    """Pre-norm transformer block.

    Self-attention ("communication") followed by a feed-forward network
    ("computation"). Each sub-layer sees a layer-normalised input and its
    output is added back onto the residual stream.
    """

    def __init__(self, n_embd, n_head):
        """
        n_embd: embedding dimension
        n_head: number of attention heads; each gets n_embd // n_head channels
        """
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)  # applied before self-attention
        self.ln2 = nn.LayerNorm(n_embd)  # applied before the feed-forward network

    def forward(self, x):
        """Run both residual sub-layers on x and return the result."""
        attended = self.sa(self.ln1(x))
        x = x + attended
        transformed = self.ffwd(self.ln2(x))
        return x + transformed

In [None]:
# Sample 1000 characters from the trained model.
# `model` is the network trained above; `context` is a (1, 1) start tensor holding
# a single zero token, created on the model's device so the embedding lookup works.
model.eval()  # switch dropout off while sampling
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=1000)[0].tolist()))
model.train()  # restore training mode