In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [2]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        d_k: Per-head feature size. Raw scores are divided by ``d_k ** 0.5``.
    """

    def __init__(self, d_k):
        super().__init__()
        self.scale = d_k ** 0.5

    def forward(self, Q, K, V, mask=None):
        """Weight the rows of ``V`` by the similarity between ``Q`` and ``K``.

        Args:
            Q, K, V: Tensors whose last two dimensions are (seq_len, d_k);
                callers in this notebook pass (batch, heads, seq_len, d_k).
            mask: Optional tensor broadcastable to the score matrix. Positions
                where ``mask == 0`` are set to ``-inf`` before the softmax and
                therefore receive zero attention weight.

        Returns:
            The attention-weighted combination of ``V``.
        """
        # Similarity of every query with every key, tempered by sqrt(d_k).
        logits = Q @ K.transpose(-2, -1)
        logits = logits / self.scale

        if mask is not None:
            blocked = mask == 0
            logits = logits.masked_fill(blocked, float('-inf'))

        # Normalise over the key axis so each query's weights sum to one.
        weights = F.softmax(logits, dim=-1)
        return weights @ V

In [3]:
class MultiHeadAttention(nn.Module):
    """Pre-norm multi-head self-attention with a residual connection.

    The input is layer-normalised, projected to queries/keys/values, split
    into ``num_heads`` heads of width ``d_k = d_embedding // num_heads``,
    attended, merged back, projected by ``W_o`` and added to the
    un-normalised input.

    Args:
        d_embedding: Width of the embedding (last) dimension.
        num_heads: Number of attention heads; must divide ``d_embedding``.
    """

    def __init__(self, d_embedding, num_heads):
        super().__init__()
        assert d_embedding % num_heads == 0
        self.d_k = d_embedding // num_heads
        self.num_heads = num_heads

        self.W_q = nn.Linear(d_embedding, d_embedding)
        self.W_k = nn.Linear(d_embedding, d_embedding)
        self.W_v = nn.Linear(d_embedding, d_embedding)
        self.W_o = nn.Linear(d_embedding, d_embedding)

        self.attention = ScaledDotProductAttention(self.d_k)

        self.norm = nn.LayerNorm(d_embedding)

    def _split_heads(self, t):
        """Reshape (B, L, d_embedding) into (B, H, L, d_k).

        The feature axis is split first and the head axis is then moved in
        front of the sequence axis. Viewing directly as (B, H, L, d_k) would
        only reinterpret memory and mix different sequence positions into the
        same head.
        """
        B, L, _ = t.size()
        return t.view(B, L, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (B, L, d_embedding).
            mask: Optional tensor broadcastable to (B, H, L, L); positions
                equal to 0 are not attended to.

        Returns:
            Tensor of shape (B, L, d_embedding): ``x`` plus the attention output.
        """
        x_input = x
        x = self.norm(x)

        B, L, d_embedding = x.size()  # Batch, Sequence Length, Embedding Dim

        # Linear projections, each split into heads: (B, H, L, d_k)
        Q = self._split_heads(self.W_q(x))
        K = self._split_heads(self.W_k(x))
        V = self._split_heads(self.W_v(x))

        # Apply attention
        context = self.attention(Q, K, V, mask)  # (B, H, L, d_k)

        # Concatenate heads: (B, H, L, d_k) -> (B, L, H, d_k) -> (B, L, d_embedding)
        context = context.transpose(1, 2).contiguous().view(B, L, d_embedding)

        # Final linear projection
        output = self.W_o(context)  # (B, L, d_embedding)

        # Add (& pre-norm)
        # My preference is pre-norm for better stability, even though the
        # original paper used post-norm.
        output = x_input + output
        return output


In [4]:
class MultiHeadCrossAttention(nn.Module):
    """Pre-norm multi-head cross-attention with a residual connection.

    Queries are projected from the (layer-normalised) decoder stream; keys and
    values are projected from the encoder stream. The attention output is
    projected by ``W_o`` and added to the un-normalised decoder input.

    Args:
        d_embedding: Width of the embedding (last) dimension.
        num_heads: Number of attention heads; must divide ``d_embedding``.
    """

    def __init__(self, d_embedding, num_heads):
        super().__init__()
        assert d_embedding % num_heads == 0
        self.d_k = d_embedding // num_heads
        self.num_heads = num_heads

        self.W_q = nn.Linear(d_embedding, d_embedding)
        self.W_k = nn.Linear(d_embedding, d_embedding)
        self.W_v = nn.Linear(d_embedding, d_embedding)
        self.W_o = nn.Linear(d_embedding, d_embedding)

        self.attention = ScaledDotProductAttention(self.d_k)

        self.norm = nn.LayerNorm(d_embedding)

    def forward(self, x_decoder, x_encoder, mask=None):
        """Let every decoder position attend over the encoder sequence.

        Args:
            x_decoder: Tensor of shape (B, L_dec, d_embedding).
            x_encoder: Tensor of shape (B, L_enc, d_embedding). ``L_enc`` may
                differ from ``L_dec``.
            mask: Optional tensor broadcastable to (B, H, L_dec, L_enc);
                positions equal to 0 are not attended to.

        Returns:
            Tensor of shape (B, L_dec, d_embedding).

        Raises:
            AssertionError: If the inputs are not 3-D or disagree on batch
                size or embedding width.
        """
        # Only batch size and embedding width have to agree; the two sequences
        # are allowed to have different lengths.
        assert x_decoder.dim() == 3 and x_encoder.dim() == 3
        assert x_decoder.size(0) == x_encoder.size(0)
        assert x_decoder.size(-1) == x_encoder.size(-1)

        x_input = x_decoder
        x_decoder = self.norm(x_decoder)

        B, L_dec, d_embedding = x_decoder.size()  # Batch, target length, Embedding Dim
        L_enc = x_encoder.size(1)                 # source length
        H = self.num_heads

        # Linear projections. Split the feature axis into heads first, then
        # move the head axis forward; a direct view to (B, H, L, d_k) would
        # scramble sequence positions across heads.
        Q = self.W_q(x_decoder).view(B, L_dec, H, self.d_k).transpose(1, 2)  # (B, H, L_dec, d_k)
        K = self.W_k(x_encoder).view(B, L_enc, H, self.d_k).transpose(1, 2)  # (B, H, L_enc, d_k)
        V = self.W_v(x_encoder).view(B, L_enc, H, self.d_k).transpose(1, 2)  # (B, H, L_enc, d_k)

        # Apply attention
        context = self.attention(Q, K, V, mask)  # (B, H, L_dec, d_k)

        # Concatenate heads
        context = context.transpose(1, 2).contiguous().view(B, L_dec, d_embedding)  # (B, L_dec, d_embedding)

        # Final linear projection
        output = self.W_o(context)  # (B, L_dec, d_embedding)

        # Residual connection around the pre-normalised attention sublayer.
        output = output + x_input
        return output


In [5]:
class FeedForwardNetwork(nn.Module):
    """Pre-norm position-wise feed-forward sublayer with a residual connection.

    Computes ``x + linear2(relu(linear1(layer_norm(x))))``.

    Args:
        d_embedding: Width of the input and output features.
        d_ff: Width of the hidden layer.
    """

    def __init__(self, d_embedding, d_ff):
        super().__init__()
        # Expand to the hidden width, then project back to the embedding width.
        self.linear1 = nn.Linear(d_embedding, d_ff)
        self.linear2 = nn.Linear(d_ff, d_embedding)
        self.activation = nn.ReLU()
        self.norm = nn.LayerNorm(d_embedding)

    def forward(self, x):
        """Return ``x`` plus the feed-forward transform of its normalised value."""
        hidden = self.linear1(self.norm(x))
        update = self.linear2(self.activation(hidden))
        return x + update

In [6]:
class Encoder(nn.Module):
    """Stack of ``num_layers`` encoder layers.

    Each layer is a self-attention sublayer followed by a feed-forward
    sublayer; both are stored flat in ``self.layers`` (attention at even
    indices, feed-forward at odd indices).

    Args:
        d_embedding: Width of the embedding dimension.
        num_heads: Number of attention heads per attention sublayer.
        d_ff: Hidden width of each feed-forward sublayer.
        num_layers: Number of (attention, feed-forward) pairs.
    """

    def __init__(self, d_embedding, num_heads, d_ff, num_layers):
        super().__init__()
        # Build fresh modules for every layer. Multiplying a list of modules
        # (``[attn, ffn] * n``) repeats the *same* two objects, so every layer
        # would share a single set of weights.
        sublayers = []
        for _ in range(num_layers):
            sublayers.append(MultiHeadAttention(d_embedding, num_heads))
            sublayers.append(FeedForwardNetwork(d_embedding, d_ff))
        self.layers = nn.ModuleList(sublayers)

    def forward(self, x, mask=None):
        """Run ``x`` through every sublayer in order.

        Args:
            x: Input tensor; the sublayers in this notebook expect
                (batch, seq_len, d_embedding).
            mask: Optional attention mask handed to every self-attention
                sublayer. Defaults to ``None`` (no masking).

        Returns:
            The output of the last sublayer.
        """
        for index, layer in enumerate(self.layers):
            if index % 2 == 0:
                # Even slots hold the self-attention sublayers.
                x = layer(x, mask)
            else:
                x = layer(x)
        return x

In [7]:
class Decoder(nn.Module):
    """Stack of ``num_layers`` decoder layers.

    Each layer consists of three sublayers stored flat in ``self.layers``:
    self-attention (index ``3k``), cross-attention over the encoder output
    (index ``3k + 1``) and a feed-forward network (index ``3k + 2``).

    Args:
        encoder: Stored as ``self.encoder``. It is not called by ``forward``;
            the caller passes the encoder output explicitly.
        d_embedding: Width of the embedding dimension.
        num_heads: Number of attention heads per attention sublayer.
        d_ff: Hidden width of each feed-forward sublayer.
        num_layers: Number of (self-attn, cross-attn, feed-forward) triples.
    """

    def __init__(self, encoder, d_embedding, num_heads, d_ff, num_layers):
        super().__init__()
        self.encoder = encoder
        # NOTE(review): this LayerNorm is created but never applied in
        # forward(); kept so the module's parameters/state_dict do not change.
        self.norm = nn.LayerNorm(d_embedding)

        # Build fresh modules for every layer. Multiplying a list of modules
        # repeats the *same* three objects, so every layer would share a
        # single set of weights.
        sublayers = []
        for _ in range(num_layers):
            sublayers.append(MultiHeadAttention(d_embedding, num_heads))
            sublayers.append(MultiHeadCrossAttention(d_embedding, num_heads))
            sublayers.append(FeedForwardNetwork(d_embedding, d_ff))
        self.layers = nn.ModuleList(sublayers)

    def forward(self, x_decoder, x_encoder, self_mask=None, cross_mask=None):
        """Run the decoder stack.

        Args:
            x_decoder: Decoder-side input tensor.
            x_encoder: Encoder output handed to every cross-attention sublayer.
            self_mask: Optional mask for the self-attention sublayers (e.g. a
                lower-triangular causal mask). Defaults to ``None``, in which
                case no masking is applied.
            cross_mask: Optional mask for the cross-attention sublayers.
                Defaults to ``None``.

        Returns:
            The output of the last sublayer.
        """
        for index, layer in enumerate(self.layers):
            slot = index % 3
            if slot == 0:
                x_decoder = layer(x_decoder, self_mask)
            elif slot == 1:
                x_decoder = layer(x_decoder, x_encoder, cross_mask)
            else:
                x_decoder = layer(x_decoder)
        return x_decoder

In [8]:
# Dummy input
batch_size = 2
seq_len = 5
d_embedding = 64
num_heads = 8

# Seed the global RNG so the random input and the layers' initial weights are
# the same on every Restart & Run All.
torch.manual_seed(0)
x = torch.randn(batch_size, seq_len, d_embedding)

# Apply attention
mha = MultiHeadAttention(d_embedding, num_heads)
output = mha(x)
# Self-attention with a residual connection must preserve the input shape.
assert output.shape == x.shape
print(output.shape)  # (2, 5, 64)


torch.Size([2, 5, 64])


In [9]:
# Smoke-test the feed-forward sublayer on the dummy batch and show the output shape.
ffn = FeedForwardNetwork(d_embedding=d_embedding, d_ff=256)
output_2 = ffn(x)
print(output_2.size())

torch.Size([2, 5, 64])


In [10]:
# Smoke-test a six-layer encoder stack on the dummy batch and show the output shape.
d_ff = 256
num_layers = 6
encoder = Encoder(
    d_embedding=d_embedding,
    num_heads=num_heads,
    d_ff=d_ff,
    num_layers=num_layers,
)
output_3 = encoder(x)
print(output_3.size())

torch.Size([2, 5, 64])


**Project idea** \
The model predicts the next term of a number sequence, with the numbers written out in words. \
two four six eight ten - twelve \
three six nine twelve - fifteen \
The sequences could be arithmetic or geometric. I will generate them, code up a number-to-string mapper, and pass the sequences to the model mapped to strings.