# Getting started

In [2]:
import torch
import torch.nn as nn
import math
import nltk
from nltk.tokenize import word_tokenize
from torch.utils.data import Dataset, DataLoader

In [2]:
# Make sure the Punkt tables that word_tokenize relies on are present; download them only
# when missing, so a fresh environment can run this notebook top to bottom.
try:
    nltk.data.find('tokenizers/punkt_tab/english/')
except LookupError:
    nltk.download('punkt_tab')

In [4]:
# Seed torch's RNG so weight initialisation and dropout are reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device}!")

Using cpu!


# Prepare data

In [5]:
path_to_data = "data/moby_dick.txt"

# One string per blank-line-separated paragraph, with its lines joined by single spaces.
all_data = []
with open(path_to_data, 'r', encoding="utf-8") as f:
    curr = []
    for line in f:
        line = line.strip()
        if line:
            curr.append(line)
        elif curr:
            # A blank line closes the current paragraph; extra blank lines add nothing.
            all_data.append(' '.join(curr))
            curr = []
    if curr:
        # Flush the final paragraph when the file does not end with a blank line.
        all_data.append(' '.join(curr))

In [6]:
# Tokenize each lower-cased paragraph, drop those that yield no tokens, and order the rest
# from shortest to longest so rows that land in the same batch have similar lengths.
tokenized_paragraphs = [word_tokenize(text.lower()) for text in all_data]
sentence_word = sorted((tokens for tokens in tokenized_paragraphs if tokens), key=len)

# Longest, then shortest, tokenized paragraph
paragraph_lengths = [len(tokens) for tokens in sentence_word]
print(max(paragraph_lengths))
print(min(paragraph_lengths))

944
1


In [7]:
# Peek at the five shortest tokenized paragraphs.
shortest_five = sentence_word[:5]
print(shortest_five)

[['epilogue'], ['fore-top', '.'], ['sir', '?'], ['ahab', 'turned', '.'], ['chapter', '1.', 'loomings', '.']]


In [8]:
# Special tokens take ids 0-3; each unseen corpus token gets the next free id.
word2idx = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
for tokens in sentence_word:
    for token in tokens:
        if token not in word2idx:
            word2idx[token] = len(word2idx)

# Ids are dense and handed out in insertion order, so the key order *is* the id order.
idx2word = list(word2idx)

def process_sentence(sentence):
    """Encode a tokenized sentence as ``(input_ids, target_ids)`` for next-token prediction.

    ``input_ids``  is ``[<SOS>, w_1, ..., w_n, <EOS>]``; words missing from ``word2idx`` become ``<UNK>``.
    ``target_ids`` is ``input_ids`` shifted left by one: ``[w_1, ..., w_n, <EOS>]``.
    """
    unk_id = word2idx['<UNK>']
    token_ids = [word2idx['<SOS>']]
    token_ids += [word2idx.get(word, unk_id) for word in sentence]
    token_ids.append(word2idx['<EOS>'])
    return token_ids, token_ids[1:]

def decode(sentence):
    """Map a sequence of token ids back to their words, joined by single spaces."""
    return ' '.join(idx2word[idx] for idx in sentence)


def add_padding(X, y, batch_size=32, pad_idx=None):
    """Right-pad inputs and targets so every chunk of ``batch_size`` rows shares one length.

    Rows are grouped into consecutive chunks of ``batch_size``, the same grouping that
    ``DataLoader(shuffle=False, batch_size=batch_size)`` uses. Every row of ``X`` and ``y``
    in a chunk is padded with ``pad_idx`` up to the longest ``X`` row of that chunk. The
    chunk maximum is computed explicitly, so the rows need not be sorted by length.

    Args:
        X: list of input id lists (e.g. ``[<SOS>, w_1, ..., w_n, <EOS>]``).
        y: list of target id lists, parallel to ``X``.
        batch_size: number of consecutive rows that share one padded length.
        pad_idx: id used for padding; defaults to ``word2idx['<PAD>']``.

    Returns:
        ``(X_padded, y_padded)``: new lists of new lists. ``X`` and ``y`` are not modified,
        so calling this twice cannot over-pad the targets.
    """
    if pad_idx is None:
        pad_idx = word2idx['<PAD>']

    X_padded, y_padded = [], []
    for start in range(0, len(X), batch_size):
        x_chunk = X[start:start + batch_size]
        y_chunk = y[start:start + batch_size]
        target_len = max(len(seq) for seq in x_chunk)
        X_padded.extend(seq + [pad_idx] * (target_len - len(seq)) for seq in x_chunk)
        # Targets are padded to the same width as inputs so each (X, y) pair lines up position-wise.
        y_padded.extend(seq + [pad_idx] * (target_len - len(seq)) for seq in y_chunk)

    return X_padded, y_padded

class LMDataset(Dataset):
    """Map-style dataset pairing each input id list with its target id list as long tensors.

    Assumes rows that are batched together already have equal length (see ``add_padding``);
    the default DataLoader collate stacks them into one tensor.
    """

    def __init__(self, X, y):
        # Keep direct references to the caller's lists; nothing is copied here.
        self.X, self.y = X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        inputs, targets = self.X[idx], self.y[idx]
        return (
            torch.tensor(inputs, dtype=torch.long),
            torch.tensor(targets, dtype=torch.long),
        )

batch_size = 32

# (input_ids, target_ids) per paragraph, split into two parallel lists.
data = [process_sentence(sentence) for sentence in sentence_word]
X = [input_ids for input_ids, _ in data]
y = [target_ids for _, target_ids in data]

# Pad per batch-sized chunk; shuffle=False keeps the DataLoader's batches aligned with those chunks.
X_pad, y_pad = add_padding(X, y, batch_size)
dataset = LMDataset(X_pad, y_pad)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

In [41]:
vocabulary_size = len(word2idx)
print(f"Size of vocabulary: {vocabulary_size}")

# Look at the first batch only, to confirm inputs and targets share a shape.
for X, y in dataloader:
    print(X.shape, y.shape, sep='\n')
    break

Size of vocabulary: 19625
torch.Size([32, 6])
torch.Size([32, 6])


# Architecture

In [215]:
class MultiAttentionHead(nn.Module):
    """Causal multi-head self-attention with a LayerNorm before the output projection.

    Each of ``num_heads`` heads projects the input to ``head_dim = hidden_size // num_heads``
    features with its own Q/K/V matrices. Head outputs are concatenated per position back to
    ``hidden_size``, layer-normalised, and multiplied by ``output_proj``.
    """

    def __init__(self, num_heads, hidden_size):
        super().__init__()
        if hidden_size % num_heads != 0:
            # Otherwise the concatenated heads cannot be reshaped back to hidden_size in forward().
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
            )
        head_dim = hidden_size // num_heads

        # Per-head weight matrices for queries, keys, values: (num_heads, hidden_size, head_dim)
        self.W_q = nn.Parameter(torch.randn(num_heads, hidden_size, head_dim))
        self.W_k = nn.Parameter(torch.randn(num_heads, hidden_size, head_dim))
        self.W_v = nn.Parameter(torch.randn(num_heads, hidden_size, head_dim))

        self.hidden_size = hidden_size
        self.head_dim = head_dim
        self.num_heads = num_heads

        # Output linear projection (no bias)
        self.output_proj = nn.Parameter(torch.randn(hidden_size, hidden_size))

        # Normalization applied to the concatenated heads before the projection
        self.norm = nn.LayerNorm(hidden_size)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Re-draw all attention weights from N(0, 0.02^2)."""
        std = 0.02

        with torch.no_grad():
            self.W_q.normal_(0, std)
            self.W_k.normal_(0, std)
            self.W_v.normal_(0, std)
            self.output_proj.normal_(0, std)

    def forward(self, X, padding_mask):
        """
        Forward pass for multi-head attention.

        Inputs:
            X            : (batch_size, seq_len, hidden_size)
            padding_mask : (batch_size, seq_len)  -> 1/True for real tokens, 0/False for padding

        Returns:
            (batch_size, seq_len, hidden_size)

        Query rows that belong to padding are fully masked, so their softmax is uniform.
        Presumably those positions are ignored downstream (e.g. by the loss's ignore_index).
        """
        batch_size, seq_len, _ = X.shape

        # Query i may attend key j only if both are real tokens...
        is_real = padding_mask.bool()
        pad_mask = is_real.unsqueeze(-1) & is_real.unsqueeze(-2)  # (batch, seq_len, seq_len)

        # ...and j <= i. Built on X's device so the layer works wherever it lives.
        causal_mask = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=X.device)
        )

        # (batch, 1, seq_len, seq_len): broadcasts over the head dimension
        att_mask = (pad_mask & causal_mask).unsqueeze(1)

        # (batch, 1, seq_len, hidden) @ (num_heads, hidden, head_dim) -> (batch, num_heads, seq_len, head_dim)
        X_heads = X.unsqueeze(1)
        Q = torch.matmul(X_heads, self.W_q)
        K = torch.matmul(X_heads, self.W_k)
        V = torch.matmul(X_heads, self.W_v)

        # Scaled dot-product scores: (batch, num_heads, seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1))
        scaled_scores = scores / (self.head_dim ** 0.5)

        # Apply mask and softmax
        LARGE_NEG = -1e9
        masked_scores = scaled_scores.masked_fill(~att_mask, LARGE_NEG)
        att_weights = torch.softmax(masked_scores, dim=-1)

        att_output = torch.matmul(att_weights, V)  # (batch, num_heads, seq_len, head_dim)
        # Put heads next to head_dim before merging, so each position concatenates *its own* heads.
        # Reshaping straight from (batch, num_heads, seq_len, head_dim) would mix positions
        # (and leak future tokens into earlier ones).
        combined_heads = att_output.transpose(1, 2).reshape(batch_size, seq_len, self.hidden_size)

        # Apply layer normalization and the final projection
        norm_output = self.norm(combined_heads)
        projected_output = torch.matmul(norm_output, self.output_proj)

        return projected_output  # (batch_size, seq_len, hidden_size)

class FFN(nn.Module):
    """Feed-forward stack: Linear layers with an activation between consecutive ones.

    Args:
        hidden_layers: number of Linear layers.
        act: activation placed between Linear layers: 'relu', 'tanh' or 'gelu'.
        hidden_size: output width of each Linear layer; needs at least ``hidden_layers`` entries
            (at least one).
        input_size: feature size of the input.

    Raises:
        ValueError: for an unknown ``act`` or too few ``hidden_size`` entries.
    """

    # Activation name -> module class. Previously an unknown name silently produced a purely linear stack.
    _ACTIVATIONS = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'gelu': nn.GELU}

    def __init__(self, hidden_layers=2, act='relu', hidden_size=(32, 32), input_size=32):
        super().__init__()
        if act not in self._ACTIVATIONS:
            raise ValueError(
                f"Unsupported activation {act!r}; expected one of {sorted(self._ACTIVATIONS)}"
            )
        if len(hidden_size) < max(hidden_layers, 1):
            raise ValueError(
                f"hidden_size has {len(hidden_size)} entries but hidden_layers={hidden_layers}"
            )

        # Same layout as before (Linear, act, Linear, ...) so parameter names stay layers.0, layers.2, ...
        layers = [nn.Linear(input_size, hidden_size[0])]
        for i in range(1, hidden_layers):
            layers.append(self._ACTIVATIONS[act]())
            layers.append(nn.Linear(hidden_size[i - 1], hidden_size[i]))

        self.layers = nn.Sequential(*layers)

    def forward(self, X):
        return self.layers(X)

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input, followed by dropout.

    Args:
        embed_size: feature size of the embeddings (odd sizes are supported).
        seq_len: maximum sequence length the table is built for.
        dropout: dropout probability applied after adding the encoding.
    """

    def __init__(self, embed_size, seq_len, dropout=0.3):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(seq_len).unsqueeze(1)  # (seq_len, 1)
        # Frequencies 10000^(-2i / embed_size) for the even feature indices 2i
        div_term = torch.exp(torch.arange(0, embed_size, 2) * (-math.log(10000.0) / embed_size))
        angles = position * div_term  # (seq_len, ceil(embed_size / 2))
        pe = torch.zeros(1, seq_len, embed_size)
        pe[0, :, 0::2] = torch.sin(angles)
        # There are only floor(embed_size / 2) odd columns; trim so an odd embed_size does not fail.
        pe[0, :, 1::2] = torch.cos(angles[:, : embed_size // 2])
        self.register_buffer('pe', pe)  # moves with .to(device) and is stored in state_dict

    def forward(self, X):
        """
        Arguments:
            X (batch_size, seq_len, embed_size), with seq_len <= the table length

        Returns:
            Dropout(X + encoding of positions 0..seq_len-1), same shape as X.
        """
        seq_len = X.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"Input length {seq_len} exceeds the encoding table length {self.pe.size(1)}"
            )
        X = X + self.pe[:, :seq_len]
        return self.dropout(X)

class Transformer(nn.Module):
    """Causal language model: embeddings + sinusoidal positions + ``depth`` pre-norm blocks.

    Each block does LayerNorm -> MultiAttentionHead -> dropout -> residual, then
    LayerNorm -> FFN -> dropout -> residual. A bias-free Linear layer maps to vocabulary logits.
    """

    def __init__(self, voc_size, embed_size, num_heads, depth, pad_idx, p=0.3):
        super().__init__()
        self.voc_size = voc_size
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.depth = depth
        self.pad_idx = pad_idx

        # Embeddings layer
        self.embed = nn.Embedding(voc_size, embed_size)
        nn.init.xavier_uniform_(self.embed.weight)

        # Dropout applied to each residual branch
        self.dropout = nn.Dropout(p)

        # One [attention, norm, feed-forward, norm] group per layer
        self.heads = nn.ModuleList([
            nn.ModuleList([
                MultiAttentionHead(num_heads, embed_size),
                nn.LayerNorm(embed_size),
                FFN(hidden_layers=2, act="relu",
                    hidden_size=[4 * embed_size, embed_size],
                    input_size=embed_size),
                nn.LayerNorm(embed_size)
            ])
            for _ in range(depth)
        ])

        # Last linear layer: hidden state -> vocabulary logits
        self.linear = nn.Linear(embed_size, voc_size, bias=False)

    def forward(self, X):
        '''
        Inputs:
            X (batch_size, seq_length): token ids; entries equal to pad_idx are padding.

        Returns:
            Logits of shape (batch_size, seq_length, voc_size).
        '''
        _, seq_len = X.shape
        padding_mask = X != self.pad_idx  # True for real tokens

        # The encoding table is rebuilt for this length on the input's device. It is not a
        # registered submodule, so model.train()/eval() never reach it; copy this model's
        # mode by hand, otherwise its dropout would stay active in eval mode.
        pe = PositionalEncoding(embed_size=self.embed_size, seq_len=seq_len).to(X.device)
        pe.train(self.training)

        X = self.embed(X)  # (batch_size, seq_length, embed_size)
        X = pe(X)

        for head, norm1, ffn, norm2 in self.heads:
            # Attention sub-layer (pre-norm, residual)
            X = X + self.dropout(head(norm1(X), padding_mask))

            # Feed-forward sub-layer (pre-norm, residual)
            X = X + self.dropout(ffn(norm2(X)))

        output = self.linear(X)

        return output

    def predict(self, output):
        """Greedy choice: the highest-scoring vocabulary id at every position."""
        return torch.argmax(output, dim=-1)

In [216]:
# Standalone shape check for a single attention layer. Demo-specific names keep the
# data pipeline's `batch_size` and `X` from being overwritten.
demo_batch_size = 2      # number of sequences
demo_seq_len = 5         # tokens per sequence
demo_hidden_size = 512   # embedding dimension
demo_num_heads = 8       # number of heads
demo_head = MultiAttentionHead(demo_num_heads, demo_hidden_size)

# Random input and a random real/padding pattern
demo_X = torch.randn(demo_batch_size, demo_seq_len, demo_hidden_size)
demo_padding_mask = torch.randint(0, 2, (demo_batch_size, demo_seq_len)) == 1
demo_output = demo_head(demo_X, demo_padding_mask)

print("Shape of X:", demo_X.shape)
print(demo_output.shape)

Shape of X: torch.Size([2, 5, 512])
torch.Size([2, 5, 512])


In [217]:
X_test = torch.tensor(X_pad[:10]).to(device)
transformer = Transformer(voc_size=len(word2idx), embed_size=512, num_heads=8, depth=12,
                          pad_idx=word2idx['<PAD>'], p=0.3).to(device)

# Inference-only smoke test: eval() switches the model's registered dropout layers off,
# and no_grad() avoids building an autograd graph for this 12-layer forward pass.
transformer.eval()
with torch.no_grad():
    output = transformer(X_test)

prediction = transformer.predict(output)
print(f"Output shape: {prediction.shape}")

Output shape: torch.Size([10, 6])


In [218]:
# One decoded line per sequence of predicted ids.
for predicted_ids in prediction:
    print(decode(predicted_ids))

rides thank i hear fond boarding
pulsations frail whirls enveloping clove penalties
wheezing lists retarding terribleness favourable iv
lathering skies transfixedly outward dined adventurously
puddings entablatures shrinked spirits independent life-buoy
unfort archangelical career plug contingency ineffably
disport mocks waked guineas flight battering-ram
asylum conquered buttress heel—it marvelled hasn
azores simoon fronting quahogs truths boarding
sea—mark hooroosh dance mounts raise cheeseries


In [220]:
def train_one_epoch(model, train_loader, optimizer, pad_idx, max_grad_norm=0.5):
    """Run one training pass over ``train_loader`` and return the mean per-batch loss.

    Args:
        model: maps token ids (batch, seq_len) to logits (batch, seq_len, voc_size).
        train_loader: sized iterable of ``(batch_X, batch_y)`` token-id tensors.
        optimizer: optimizer over ``model``'s parameters.
        pad_idx: target id ignored by the cross-entropy loss.
        max_grad_norm: global gradient-norm clipping threshold.

    Returns:
        Average of the per-batch cross-entropy losses.

    Raises:
        ValueError: if ``train_loader`` has no batches.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader yields no batches")

    model.train()
    # Send batches to wherever the model's weights live, so a model left on CPU still trains.
    model_device = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    total_loss = 0.0
    for i, (batch_X, batch_y) in enumerate(train_loader):
        if i % 30 == 0 or i == num_batches - 1:
            print(f"{(i / num_batches) * 100:.2f}% done")

        optimizer.zero_grad()

        batch_X = batch_X.to(model_device)
        batch_y = batch_y.to(model_device)  # (batch_size, seq_len)

        outputs = model(batch_X)  # (batch_size, seq_len, voc_size)
        # CrossEntropyLoss wants the class dimension second: (batch_size, voc_size, seq_len)
        outputs = outputs.permute(0, 2, 1)

        loss = criterion(outputs, batch_y)
        loss.backward()

        nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

        optimizer.step()

        total_loss += loss.item()

    return total_loss / num_batches

In [None]:
num_epochs = 10
# Adam step size. ~3e-4 is a conventional starting point for transformer LMs;
# something like 1e-10 leaves the weights effectively frozen.
learning_rate = 3e-4
pad_idx = word2idx['<PAD>']

# Build the model directly on the training device so its weights and the batches agree.
model = Transformer(voc_size=len(word2idx), embed_size=512, num_heads=8, depth=12,
                    pad_idx=pad_idx, p=0.3).to(device)
opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(1, num_epochs + 1):
    print(f"Epoch {epoch}")
    epoch_loss = train_one_epoch(model, dataloader, opt, pad_idx)
    print(f"Loss: {epoch_loss}")

Epoch 1
0.00% done
37.04% done


In [None]:
# Scratch check of the shapes nn.CrossEntropyLoss takes: logits (N, C) and class-index targets (N,).
# Named to avoid shadowing Python's built-in `input` in the notebook namespace.
example_logits = torch.randn(3, 5, requires_grad=True)
example_targets = torch.empty(3, dtype=torch.long).random_(5)

print(example_logits.shape)
print(example_targets.shape)

torch.Size([3, 5])
torch.Size([3])
