# Getting started

In [None]:
import torch
import torch.nn as nn
import math
import nltk
from nltk.tokenize import word_tokenize
from torch.utils.data import Dataset, DataLoader

In [4]:
# nltk.download('punkt_tab') -> Already executed once

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device}!")

Using cuda!


# Prepare data

In [None]:
path_to_data = "data/moby_dick.txt"
all_data = []
with open(path_to_data, 'r', encoding="utf-8") as f:
    curr = []
    for line in f:
        line = line.strip()
        if line:
            curr.append(line)
        else:
            all_data.append(' '.join(curr))
            curr = []

sentence_word = [word_tokenize(text.lower()) for text in all_data]
sentence_word.sort(key=lambda x: len(x))
sentence_word = [sentence for sentence in sentence_word if sentence]

# Get maximum length of sentence
max_seq_len = max(len(sentence) for sentence in sentence_word)
print(max_seq_len)

# Get smallest length of sentence
print(min(len(sentence) for sentence in sentence_word))

944
1


In [None]:
word2idx = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
curr = 4
for sentence in sentence_word:
    for word in sentence:
        if word not in word2idx:
            word2idx[word] = curr
            curr += 1

idx2word = [''] * len(word2idx)
for token, idx in word2idx.items():
    idx2word[idx] = token

def process_sentence(sentence):
    input = [word2idx['<SOS>']]
    for word in sentence:
        if word not in word2idx:
            word = '<UNK>'
        input.append(word2idx[word])

    input.append(word2idx['<EOS>'])
    output = input[1:] 
    return input, output

def decode(sentence):
    res = []
    for idx in sentence:
        res.append(idx2word[idx])
    
    return ' '.join(res)


def add_padding(X, y, batch_size=32):
    num_sentences = len(X)
    pad_idx = word2idx['<PAD>'] 
    for i in range(0, num_sentences, batch_size):
        idx = min(num_sentences-1, i + batch_size-1)
        max_length = len(X[idx])
        for j in range(i, idx+1):
            missing_X = max_length - len(X[j])
            
            X[j].extend([pad_idx]*missing_X)
            y[j].extend([pad_idx]*(missing_X + 1))

    return X, y

class LMDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return torch.tensor(self.X[idx], dtype=torch.long), torch.tensor(self.y[idx], dtype=torch.long)

batch_size = 16

# Train dataloader
batch_size = 16
data = [process_sentence(sentence) for sentence in sentence_word]
X = [sentence[0] for sentence in data]
y = [sentence[1] for sentence in data]

X_pad, y_pad = add_padding(X, y, batch_size)
dataset = LMDataset(X_pad, y_pad)
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size)

In [None]:
vocabulary_size = len(word2idx)
print(f"Size of vocabulary: {len(word2idx)}")

for X, y in dataloader:
    print(X.shape)
    print(y.shape)
    break

Size of vocabulary: 19625

torch.Size([16, 6])
torch.Size([16, 6])

torch.Size([16, 138])
torch.Size([16, 138])


# Architecture

In [16]:
class MultiAttentionHead(nn.Module):
    def __init__(self, num_heads, hidden_size):
        super().__init__()
        head_dim = hidden_size // num_heads

        # Weight matrices for queries, keys, values
        self.W_q = nn.Parameter(torch.randn(num_heads, hidden_size, head_dim))
        self.W_k = nn.Parameter(torch.randn(num_heads, hidden_size, head_dim))
        self.W_v = nn.Parameter(torch.randn(num_heads, hidden_size, head_dim))

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.num_heads = num_heads

        # Store the key and values
        self.past_key = None 
        self.past_value = None 

        # Output linear projection
        self.output_proj = nn.Parameter(torch.randn(hidden_size, hidden_size))

        # Normalization layer
        self.norm = nn.LayerNorm(hidden_size)

        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):        
        std = 0.02  
        
        with torch.no_grad():
            self.W_q.normal_(0, std)
            self.W_k.normal_(0, std) 
            self.W_v.normal_(0, std)
            self.output_proj.normal_(0, std)

    def forward(self, X, padding_mask, auto_reg=False):
        """
        Forward pass for multi-head attention.

        Inputs:
            X            : (batch_size, seq_len, hidden_size)
            padding_mask : (batch_size, seq_len)  -> 1 for real tokens, 0 for padding
        """
        batch_size, seq_len, _ = X.shape
        device = X.device 
        X_heads = X.unsqueeze(1)

        if not auto_reg:
            # Create padding mask
            mask_matrix = padding_mask.unsqueeze(-1) * padding_mask.unsqueeze(-2)
            pad_mask = mask_matrix.unsqueeze(1).expand(-1, self.num_heads, -1, -1)

            # Create causal mask
            causal_mask = torch.tril(torch.ones((seq_len, seq_len), device=device))
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, self.num_heads, -1, -1)

            # Combine masks
            att_mask = pad_mask * causal_mask

            # Add head dimension to input
            X_heads = X.unsqueeze(1)  # (batch, 1, seq_len, hidden_size)

            # Compute Q, K, V for full sequence
            Q = torch.matmul(X_heads, self.W_q)  # (batch, num_heads, seq_len, head_dim)
            K = torch.matmul(X_heads, self.W_k)  # (batch, num_heads, seq_len, head_dim)
            V = torch.matmul(X_heads, self.W_v)  # (batch, num_heads, seq_len, head_dim)
            
        else:
    
            # Compute Q for current token only
            Q = torch.matmul(X_heads, self.W_q)  # (batch, num_heads, 1, head_dim)
            
            if self.past_key is None:
                # First token - compute K,V for the entire sequence so far
                K = torch.matmul(X_heads, self.W_k)
                V = torch.matmul(X_heads, self.W_v)
            
            else:
                # Subsequent tokens - compute K,V only for current token and concatenate
                curr_K = torch.matmul(X_heads, self.W_k)  # (batch, num_heads, 1, head_dim)
                curr_V = torch.matmul(X_heads, self.W_v)
                
                # Concatenate with cached values
                K = torch.cat([self.past_key, curr_K], dim=2)
                V = torch.cat([self.past_value, curr_V], dim=2)

            self.past_key = K
            self.past_value = V


        # Compute attention scores
        K_transposed = K.transpose(-2, -1)  # (batch, num_heads, head_dim, kv_seq_len)
        scores = torch.matmul(Q, K_transposed)  # (batch, num_heads, q_seq_len, kv_seq_len)
        scaled_scores = scores / (self.head_dim ** 0.5)

        # Apply mask and softmax
        LARGE_NEG = -1e9
        if not auto_reg:
            scaled_scores = scaled_scores.masked_fill(~(att_mask == 1), LARGE_NEG)
        
        att_weights = torch.softmax(scaled_scores, dim=-1)

        # Attention output
        att_output = torch.matmul(att_weights, V)  # (batch, num_heads, q_seq_len, head_dim)
        
        if auto_reg:
            combined_heads = att_output.reshape(batch_size, 1, self.hidden_size)
        else:
            combined_heads = att_output.reshape(batch_size, seq_len, self.hidden_size)

        # Apply layer normalization and the final projection
        norm_output = self.norm(combined_heads)
        projected_output = torch.matmul(norm_output, self.output_proj)

        return projected_output
    
    def clear_cache(self):
        self.past_key, self.past_value = None, None

class FFN(nn.Module):
    def __init__(self, hidden_layers=2, act='relu', hidden_size=[32, 32], input_size=32):
        super().__init__()
        layers = [nn.Linear(input_size, hidden_size[0])]

        for i in range(1, hidden_layers):
            if act == 'relu':
                layers.append(nn.ReLU())

            elif act == 'tanh':
                layers.append(nn.Tanh())
            layers.append(nn.Linear(hidden_size[i-1], hidden_size[i]))
            
                
        self.layers = nn.Sequential(*layers)
        

    def forward(self, X):
        for layer in self.layers:
            X = layer(X)

        return X

class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_seq_len, dropout=0.3):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
    
        position = torch.arange(max_seq_len).unsqueeze(1) # (seq_len, 1)
        div_term = torch.exp(torch.arange(0, embed_size, 2) * (-math.log(10000.0) / embed_size))
        pe = torch.zeros(1, max_seq_len, embed_size)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term) # (1, max_seq_len, embed_size)
        self.register_buffer('pe', pe) # So it goes to CPU/GPU as well

    def forward(self, X):
        """
        Arguments:
            X (batch_size, seq_len, embed_size)
        """
        X = X + self.pe[:, :X.shape[1],:]
        return self.dropout(X)  

class Transformer(nn.Module):
    def __init__(self, voc_size, embed_size, num_heads, depth, pad_idx, p=0.3):
        super().__init__()
        self.voc_size = voc_size
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.depth = depth
        self.pad_idx = pad_idx
        
        # Embeddings layer
        self.embed = nn.Embedding(voc_size, embed_size)
        nn.init.xavier_uniform_(self.embed.weight)

        # Positional Encoding
        self.pe = PositionalEncoding(embed_size=self.embed_size, max_seq_len=1000).to(device)

        # Dropout layer
        self.dropout = nn.Dropout(p)

        # Attention Heads
        self.heads = nn.ModuleList([
            nn.ModuleList([
                MultiAttentionHead(num_heads, embed_size),
                nn.LayerNorm(embed_size),
                FFN(hidden_layers=2, act="relu",
                    hidden_size=[4 * embed_size, embed_size],
                    input_size=embed_size),
                nn.LayerNorm(embed_size)
            ])
            for _ in range(depth)
        ])

        # Last linear layer
        self.linear = nn.Linear(embed_size, voc_size, bias=False)
    

    def forward(self, X, auto_reg=False):
        '''
        Inputs:
            X (batch_size, seq_length)
        '''
        _, seq_len = X.shape
        padding_mask = X != self.pad_idx

        # Positional Encoding

        X = self.embed(X) # (batch_size, seq_length, embedding_size)
        X = self.pe(X)


        for head, norm1, ffn, norm2 in self.heads:
            # Go through the multi heads
            X = self.dropout(norm1(head(X, padding_mask, auto_reg) + X))

            # Do the same for the FFN
            X = self.dropout(norm2(ffn(X) + X))
        
        output = self.linear(X)
        
        return output
    
    def generate_from_prompt(self, prompt, sample='p', max_size=1000):
        self.eval() # We aren't training anymore
        self.clear_memo() # Clear hashes
        inputs, _ = process_sentence(prompt)
        inputs = inputs[1:-1] # Remove <SOS> and <EOS>
        print([idx2word[input_] for input_ in inputs])
        curr = len(inputs)
        while curr < max_size:
            X = torch.tensor(inputs[-1], dtype=torch.long).unsqueeze(0).unsqueeze(1).to(device) # Add batch size 1
            with torch.no_grad():
                logits = self(X, auto_reg=True)
        
            # Sample from logits
            match sample:
                case 'k':
                    token = self.top_k_sample(logits)[0, 0].item()
                
                case 'greedy':
                    token = self.greedy_sample(logits)[0, 0].item()
                
                case 'p':
                    token = self.top_p_sample(logits)[0, 0].item()
        

            # See if end of sentence
            if idx2word[token] == '<EOS>':
                break
            
            curr += 1
            inputs.append(token)
        
        return decode(inputs)

    def clear_memo(self):
        for module in self.modules():
            if isinstance(module, MultiAttentionHead):
                module.clear_cache()


    def greedy_sample(self, output):
        with torch.no_grad():
            return torch.argmax(output, dim=-1)


    def top_k_sample(self, output, k=50):
        '''
        Inputs:
            output (batch_size, seq_len, voc_size)
        '''
        batch_size, seq_len, _ = output.shape 
        res = torch.zeros((batch_size, seq_len), dtype=torch.long).to(device)
        with torch.no_grad():
            probabilities = torch.softmax(output, dim=-1)
            values, indices = torch.topk(probabilities, k=k, dim=-1)
            batch_idx = torch.arange(batch_size)
            for pos in range(seq_len): 
                sample = torch.multinomial(values[:, pos, :], num_samples=1).squeeze(1) # (batch_size)
                res[:, pos] = indices[batch_idx, pos, sample]
        
        return res

    def top_p_sample(self, output, p=0.9):
        '''
        Inputs:
            output (batch_size, seq_len, voc_size)
        '''
        batch_size, seq_len, _ = output.shape
        res = torch.zeros((batch_size, seq_len), dtype=torch.long).to(device)

        with torch.no_grad():
            probs = torch.softmax(output, dim=-1)  # (batch, seq_len, vocab)
            batch_idx = torch.arange(batch_size)
            sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
            cum_probs = sorted_probs.cumsum(dim=-1)
            mask = cum_probs > p
            mask[..., 1:] = mask[..., :-1].clone()  # keep at least first token
            sorted_probs[mask] = 0.0
            sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
            for pos in range(seq_len):
                sample = torch.multinomial(sorted_probs[:, pos, :], num_samples=1).squeeze(-1)
                res[:, pos] = sorted_idx[batch_idx, pos, sample]

        return res

In [17]:
# Example dimensions
batch_size = 2    # number of sequences
seq_len = 5       # tokens per sequence
hidden_size = 512   # embedding dimension
h = 8 # number of heads
head = MultiAttentionHead(h, hidden_size).to(device)

# Create random input
X = torch.randn(batch_size, seq_len, hidden_size).to(device)
att_mask = (torch.randint(0, 2, (batch_size, seq_len)) == 1).to(device)
output = head(X, att_mask)

print("Shape of X:", X.shape)
print(output.shape)

Shape of X: torch.Size([2, 5, 512])
torch.Size([2, 5, 512])


In [37]:
transformer = Transformer(voc_size=len(word2idx), embed_size=256, num_heads=4, depth=6,\
                    pad_idx=word2idx['<PAD>'], p=0.35).to(device)

state = torch.load("models/overall_best.pth", map_location=device, weights_only=True)
transformer.load_state_dict(state)

<All keys matched successfully>

In [38]:
prompt = "Call me"
ans = transformer.generate_from_prompt(prompt.lower().split(), sample='p')
print(ans)

['call', 'me']
call me by alone called are of , ; this , the , particulars all ” “ with upon but another over fellow be great on other the the , having that , in ? , bamboozingly be more most s , ‘ ? way been of the ain instant turned : he saw monsieur , turned the now and upon the sea , didn , the still , them is quite , chase sailed the , shall he swings landlady the and ! service the began complimented noticed perhaps “ such and him at christianity assist . ship of . heartily some into never shocks over if thunder get nothing being trouble as took at it ahab . swelling tell much good hemp to beneath his levelled will new comprehend and afore ; they but that on half and , this , men that sir the corners , lad caulk ‘ him one to against lower the the to his an the if the matter the the . the can immediately and _that_ fast-fish his said . moment is for i have rising were no , , towards friend bowsman at beams ; are but , blinding all “ ’ i , his , not “ astern , we born end harpooneer

In [39]:
X_test = torch.tensor(X_pad[:10]).to(device)
#transformer = Transformer(voc_size=len(word2idx), embed_size=512, num_heads=8, depth=12,\
                          #pad_idx=word2idx['<PAD>'], p=0.3).to(device)

output = transformer(X_test)

prediction_k = transformer.top_k_sample(output)
prediction_greedy = transformer.greedy_sample(output)
prediction_p = transformer.top_p_sample(output)
print(f"Output shape of k sampling: {prediction_k.shape}")
print(f"Output shape of greedy sampling: {prediction_greedy.shape}")
print(f"Output shape of greedy sampling: {prediction_p.shape}")

Output shape of k sampling: torch.Size([10, 6])
Output shape of greedy sampling: torch.Size([10, 6])
Output shape of greedy sampling: torch.Size([10, 6])


In [40]:
print("Text for p sampling:")
for batch in prediction_p:
    print(decode(batch))

print("\nText for k sampling:")
for batch in prediction_k:
    print(decode(batch))

print("\nText for greedy sampling:")

for batch in prediction_greedy:
    print(decode(batch))

Text for p sampling:
s is , ship whale one
and , sir from talk seeva
moment advanced that white . his
chapter funeral azores ; . saw
chapter , dick representing certain ,
chapter he and here in tangles
was could , . disobey you
manx low middle beheld frock.—ahab then
“ for then could of to
as and night and , of

Text for k sampling:
the , in ; the ,
the and was no all in
, that ! , but and
“ the of or and that
“ man ’ ! . them
chapter i and with all ,
chapter the and , out the
“ , a of , the
“ and , from in of
chapter a ; at and and

Text for greedy sampling:
, , the , the ,
, , the , the ,
, , the , the ,
chapter , , , the ,
chapter , , , the ,
chapter , , , the ,
chapter , , , the ,
chapter , , , the ,
chapter , , , the ,
chapter , , , the ,


In [33]:
def train_one_epoch(model, train_loader, optimizer, scheduler, criterion):
    model.train()
    total_loss = 0.0
    for batch_X, batch_y in train_loader:
        
        optimizer.zero_grad()

        batch_X = batch_X.to(device)
        batch_y = batch_y.to(device) # (batch_size, seq_len)
        
        outputs = model(batch_X) # (batch_size, seq_len, voc_size)
        outputs = outputs.permute(0, 2, 1) # (batch_size, voc_size, seq_len) -> expected size for nn.CrossEntropy

        loss = criterion(outputs, batch_y)
        loss.backward()


        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        total_loss += loss.item() 

    scheduler.step()

    return total_loss / len(train_loader.dataset)

def evaluate(model, val_loader, criterion):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch_X, batch_y in val_loader:

            batch_X = batch_X.to(device)
            batch_y = batch_y.to(device) # (batch_size, seq_len)
            
            outputs = model(batch_X) # (batch_size, seq_len, voc_size)
            outputs = outputs.permute(0, 2, 1) # (batch_size, voc_size, seq_len) -> expected size for nn.CrossEntropy

            loss = criterion(outputs, batch_y)

            total_loss += loss.item() 

    return total_loss / len(val_loader.dataset)

In [34]:
def setup_transformer_training(model, total_steps, warmup_steps):
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=5e-4,             
        weight_decay=0.02 
    )

    # Learning rate schedule with warmup + linear decay to zero
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=lambda step: min((step + 1) / warmup_steps, 
                                   max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps)))
    )

    return optimizer, scheduler

In [None]:
num_epochs = 100
warmup_epochs = 10
patience = 15
pad_idx = word2idx['<PAD>']
model = Transformer(voc_size=len(word2idx), embed_size=256, num_heads=4, depth=6,\
                    pad_idx=word2idx['<PAD>'], p=0.35).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx, label_smoothing=0.1)

opt, scheduler = setup_transformer_training(model, num_epochs, warmup_epochs)

best_loss = float('inf')
best_model_state = None
curr = 0

for epoch in range(1, num_epochs+1):
    print("-"*50)
    print(f"Epoch {epoch}")
    train_loss = train_one_epoch(model, dataloader, opt, scheduler, criterion)
    val_loss = evaluate(model, dataloader, criterion)
    print(f"Train Loss: {train_loss:.5f}")
    print(f"Val Loss: {val_loss:.5f}")
    
    # Early stopping logic
    if val_loss < best_loss:
        best_loss = val_loss
        best_model_state = model.state_dict()
        print(f"Found better model at epoch {epoch}. Saving...")
        torch.save(best_model_state, "models/best_model.pth")
        curr = 0
    else:
        curr += 1
        if curr == patience:
            print("No more patience. Ending Training Loop...")
            break

    print(f"Best Val Loss: {best_loss:.5f}")
    print("-"*50)

# If it ends the training loop
torch.save(best_model_state, "models/overall_best.pth")

--------------------------------------------------
Epoch 1
Train Loss: 0.56649
Val Loss: 0.50278
Found better model at epoch 1. Saving...
Best Val Loss: 0.50278
--------------------------------------------------
--------------------------------------------------
Epoch 2
Train Loss: 0.45088
Val Loss: 0.45538
Found better model at epoch 2. Saving...
Best Val Loss: 0.45538
--------------------------------------------------
--------------------------------------------------
Epoch 3
Train Loss: 0.42906
Val Loss: 0.45486
Found better model at epoch 3. Saving...
Best Val Loss: 0.45486
--------------------------------------------------
--------------------------------------------------
Epoch 4
Train Loss: 0.42742
Val Loss: 0.45482
Found better model at epoch 4. Saving...
Best Val Loss: 0.45482
--------------------------------------------------
--------------------------------------------------
Epoch 5
Train Loss: 0.42770
Val Loss: 0.45471
Found better model at epoch 5. Saving...
Best Val Loss: