In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from tokenizer import Tokenizer

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
device

device(type='cuda')

In [5]:
with open("data.txt", "r") as f:
    data = f.read()

In [6]:
data = data.replace("\n", " ").replace("===", "").replace("==", "").replace("\\n", " ")

In [7]:
tokenizer = Tokenizer(5,data)

In [8]:
data[:100]

'Rafael Nadal Parera (born 3 June 1986) is a Spanish professional tennis player. Nadal has been ranke'

In [9]:
encoded_text = tokenizer.encode(data)

In [10]:
vocab_size = len(tokenizer.vocab)

In [11]:
vocab_size

261

In [12]:
# create a tensor of the encoded text
data = torch.tensor(encoded_text, dtype=torch.long).to(device)
data.shape, data.dtype

(torch.Size([124433]), torch.int64)

In [13]:
# 90% of the data will be used for training
n = int(0.9*len(data))
train_data = data[:n]
test_data = data[n:]

In [14]:
batch_size = 8
context_size = 32

In [15]:
train_data[:context_size+1]

tensor([ 82,  97, 102,  97, 101, 108,  32,  78,  97, 100,  97, 108,  32,  80,
         97, 114, 101, 114,  97,  32,  40,  98, 111, 114, 110,  32,  51,  32,
         74, 117, 110, 256,  49], device='cuda:0')

In [16]:
# create data loader
from torch.utils.data import DataLoader, Dataset

class TextDataset(Dataset):
    def __init__(self, data, context_size):
        self.data = data
        self.context_size = context_size

    def __len__(self):
        return len(self.data) - self.context_size

    def __getitem__(self, idx):
        return self.data[idx:idx+self.context_size+1]
    
train_dataset = TextDataset(train_data, context_size)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = TextDataset(test_data, context_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [17]:
try:
    del model
except:
    pass

In [71]:
# Create the model class
class simplemodel(nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.embedding = nn.Embedding(vocab_size, vocab_size)

    def forward(self, x, targets=None):
        logits = self.embedding(x)

        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape # batch, time, channels (vocab size)
            logits = logits.view(B*T, C) # flatten the batch and time dimensions to make it easier to calculate the loss
            targets = targets.view(-1)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, x, n, temperature=1.0):
        # temperature is used to smooth the distribution
        # higher temperature will make the distribution more uniform
        # lower temperature will make the distribution more peaky
        # how to use temperature: pass in the logits and divide by the temperature before applying softmax
        with torch.no_grad():
            for i in range(n):
                logits, _ = self.forward(x)
                logits = logits[:, -1, :] # get the last time step of the logits (becomes (B,C))
                logits = logits / temperature
                probs = F.softmax(logits, dim=-1) # convert logits to probabilities (B,C)
                next_char = torch.multinomial(probs, 1) # sample from the distribution (B,1)
                x = torch.cat([x, next_char], dim=1) # append the new character to the input (B,T+1)

        return x
    

model = simplemodel().to(device)

In [72]:
data = model.generate(torch.zeros((1, 1), dtype=torch.long).to(device), 100).reshape(-1).tolist()

In [73]:
print(tokenizer.decode(data))

 ��D��j[�,��I�(�a����dQ�`���Џ��(��"��bϘi^�����頻k�~�f�1�e��w�l|��j"��*�!e'}]R��7��0#


In [74]:
model.state_dict

<bound method Module.state_dict of simplemodel(
  (embedding): Embedding(261, 261)
)>

In [75]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [78]:
for i in range(2):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        x = batch[:, :-1].to(device)
        targets = batch[:, 1:].to(device).contiguous()
        logits, loss = model(x, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        print(f"epoch {i} loss {total_loss/len(train_loader)}")
    
    model.eval()
    with torch.inference_mode():
        # test loop
        total_loss = 0
        for batch in tqdm(test_loader):
            x = batch[:, :-1].to(device)
            targets = batch[:, 1:].to(device).contiguous()
            _, loss = model(x, targets)
            total_loss += loss.item()
        print(f"test loss {total_loss/len(test_loader)}")            

100%|██████████| 13995/13995 [00:36<00:00, 385.44it/s]


epoch 0 loss 3.050558552600606


100%|██████████| 1552/1552 [00:00<00:00, 2036.88it/s]


test loss 2.767654577942239


100%|██████████| 13995/13995 [00:35<00:00, 389.00it/s]
100%|██████████| 1552/1552 [00:00<00:00, 1945.68it/s]

test loss 2.778070725884634





In [79]:
data = model.generate(torch.zeros((1,1), dtype=torch.long).to(device), 1000, 0.8).reshape(-1).tolist()
print(tokenizer.decode(data))

 ��7 an 2022000131 anthalas r m ing tom Set secupeis tare misealecor Nalerdon inset at inscunth as extchert time adee bed Wighit mo He t fou00 d arch fiarerenct h 200135 th Bhit alllainnd in Narendet me 24-chount Nafourenndostostiasurice ougre t Opis al stededete, sbemampled  Togloromeathis arlo d checonao. d inaly. Nad 3 Namete suradad wond Nal ingimakagr astoutiz fing ondadourth Nanitilen serd o oraclo. Fon s, Wichd f as fin the P wh chendeetinerajous ad d Th sed the mefexi as fo me a witis wat alid rerad Ju20 ingedary ind o as hthict ad touron 204 tial mpepon the fichon allladr pal. Th ateer 2000000099800 chim beachitin Nadiond, cand ourend As adiz and. \ur onsepo Pwin the tor he malot sst adase t aurexter Opes oviovis botha Nis 8 ay nd the tifon the 200 Bou20 arin Jupiobotr Mas osqurnt 2010099997 tir wad the f the stwaisteblive ingl fime wir Rorncr that oualyefond wn th blis tical 48 opr P 15 fofes, patm mplynar bathe Sle t wer) Non, the winante ch in h fet stevi ar w Op rdourolat 

In [24]:
embeddings = nn.Embedding(5, 5)

In [25]:
logits = embeddings(torch.tensor([0,0,0,0,0], dtype=torch.long))
logits.shape
probs = F.softmax(logits, dim=-1)
probs.shape
out = torch.multinomial(probs, 1)
out.shape
out

tensor([[2],
        [2],
        [2],
        [4],
        [3]])

### Explanation
Embedding layer is a lookup table that stores a vector for each index in the input tensor. 
It stores random weights for each index in the input tensor.
These weights are learned during training to learn the best representation for each index.
The output spits out the weights for each index given in the input.
These weights are basically the logits which are converted to probabilities using softmax 
These probabilities are sampled to get the next index

We tell the model that we expect the next token given the input and the model should learn to predict the next token given the input by adjusting the weights in the embedding layer.

In [26]:
# create self attention mechanism

In [18]:
def train(model, optimizer, train_loader):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        x = batch[:, :-1].to(device)
        targets = batch[:, 1:].to(device).contiguous()
        logits, loss = model(x, targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"epoch {i} loss {total_loss/len(train_loader)}")

    model.eval()
    
def test(model, test_loader):
    with torch.inference_mode():
        total_loss = 0
        for batch in tqdm(test_loader):
            x = batch[:, :-1].to(device)
            targets = batch[:, 1:].to(device).contiguous()
            _, loss = model(x, targets)
            total_loss += loss.item()
    print(f"test loss {total_loss/len(test_loader)}")

In [19]:
class LayerNorm1d:
    # Layer Normalization for 1D data
    # eps is the epsilon value to prevent division by zero
    # momentum is the momentum value for the running mean and variance
    # gamma is the scaling parameter
    # beta is the shifting parameter
    def __init__(self, dim, eps=1e-5, momentum=0.1, training=True, device=device):
        self.eps = eps
        self.gamma = torch.ones(dim).to(device)
        self.beta = torch.zeros(dim).to(device)
        self.momentum = momentum
        self.training = training
        self.mean = torch.zeros(dim).to(device)
        self.var = torch.ones(dim).to(device)

    def __call__(self, x):
        if self.training:
            mean = x.mean(dim=-1, keepdim=True)
            var = x.var(dim=-1, keepdim=True)
            self.mean = self.momentum * mean + (1-self.momentum) * self.mean
            self.var = self.momentum * var + (1-self.momentum) * self.var
        else:
            mean = self.mean
            var = self.var

        x = (x - mean) / torch.sqrt(var + self.eps) # normalize to zero mean and unit variance
        x = x * self.gamma + self.beta
        return x
    
    def parameters(self):
        return [self.gamma, self.beta]
    

In [20]:
# create a multi-head self attention mechanism

class Head(nn.Module):
    # a single head of the multi-head self attention mechanism
    def __init__(self, embedding_size, head_size, dropout=0.1):
        super().__init__()
        self.head_size = head_size
        self.key = nn.Linear(embedding_size, head_size, bias=False)
        self.query = nn.Linear(embedding_size, head_size, bias=False)
        self.value = nn.Linear(embedding_size, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        key = self.key(x)
        query = self.query(x)
        value = self.value(x)

        weights = query @ key.transpose(-2, -1) / np.sqrt(C)
        weights = weights.masked_fill(torch.triu(torch.ones(T, T), diagonal=1).to(device) == 1, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)

        out = weights @ value
        return out
    
class MultiHeadAttention(nn.Module):
    # combine multiple heads into a single multi-head self attention mechanism
    def __init__(self, embedding_size, head_size, num_heads, dropout=0.1):
        super().__init__()
        self.heads = nn.ModuleList([Head(embedding_size, head_size, dropout) for _ in range(num_heads)])
        self.linear = nn.Linear(embedding_size, embedding_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.linear(out)
        out = self.dropout(out)
        return out


class FeedForward(nn.Module):
    # a simple feed forward network
    def __init__(self, embedding_size, hidden_size, dropout=0.1):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(embedding_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embedding_size),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.ff(x)


class Block(nn.Module):
    # a single block of the transformer
    def __init__(self, embedding_size, num_heads, dropout=0.1):
        super().__init__()
        head_size = embedding_size // num_heads
        assert head_size * num_heads == embedding_size, "head_size * num_heads must equal vocab_size"
        self.attention = MultiHeadAttention(
            embedding_size, head_size, num_heads, dropout)
        self.feedforward = FeedForward(
            embedding_size, embedding_size * 4, dropout)
        # Layer norm 1 is for the attention mechanism
        self.norm1 = LayerNorm1d(embedding_size)
        # Layer norm 2 is for the feed forward network
        self.norm2 = LayerNorm1d(embedding_size)
        
    def forward(self, x):
        x = self.norm1(x)
        x = x + self.attention(x) # add the residual connection
        x = self.norm2(x)
        x = x + self.feedforward(x) # add the residual connection
        return x

class Transformer(nn.Module):
    # the full transformer model
    def __init__(self, vocab_size, embedding_size, context_size, num_heads, num_blocks, dropout=0.1):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, embedding_size)
        self.position_embedding = nn.Embedding(context_size, embedding_size)   
        self.blocks = nn.Sequential(*[Block(embedding_size, num_heads, dropout) for _ in range(num_blocks)])
        self.layer_norm = LayerNorm1d(embedding_size) # Layer norm for the output of the transformer
        self.linear = nn.Linear(embedding_size, vocab_size)

    def forward(self, x, targets=None):
        B,T = x.shape
        positions = torch.arange(T).to(x.device)
        x = self.token_embedding(x) + self.position_embedding(positions)
        x = self.blocks(x)
        x = self.layer_norm(x)
        logits = self.linear(x)

        if targets is None:
            loss = None
        else:
            B,T,C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(-1)
            loss = F.cross_entropy(logits, targets)

        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_context = idx[:, -context_size:]
            logits, _ = self.forward(idx_context)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            idx = torch.cat([idx, next_token], dim=-1)
        return idx


In [21]:
# set the hyperparameters
embedding_size = 64
num_heads = 4
num_blocks = 4
vocab_size = len(tokenizer.vocab)
context_size = 32
dropout = 0.1

# create the model
model = Transformer(vocab_size, embedding_size, context_size, num_heads, num_blocks, dropout).to(device)
print(sum(p.numel() for p in model.parameters())/1e3, 'k parameters')

233.861 k parameters


In [22]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [23]:
for i in range(2):
    train(model, optimizer, train_loader)
    test(model, test_loader)

  0%|          | 0/13995 [00:00<?, ?it/s]

 45%|████▌     | 6321/13995 [06:23<09:35, 13.34it/s]