In [9]:
# deep learning
import torch
import torch.nn.functional as F

# linear algebra
import numpy as np

In [None]:
def positional_encoding(L: int, d_model: int, N: int = 10000) -> np.ndarray:
    """Build the sinusoidal positional-encoding table.

    Entry ``[pos, i]`` is ``sin(pos / N ** (2 * (i // 2) / d_model))`` when
    ``i`` is even and ``cos`` of the same angle when ``i`` is odd.

    Args:
        L: number of positions (rows of the table).
        d_model: embedding width (columns of the table).
        N: base of the geometric progression of wavelengths.

    Returns:
        A float NumPy array of shape ``[L, d_model]``.
    """
    pos = np.arange(L)[:, np.newaxis]  # [L, 1]
    i = np.arange(d_model)[np.newaxis, :]  # [1, d_model]

    # Columns 2k and 2k+1 share the same frequency, hence the i // 2.
    angle_rates = 1 / np.power(N, (2 * (i // 2)) / d_model)
    angle_rads = pos * angle_rates  # [L, d_model]

    # Write into a separate array so the raw angles are not overwritten.
    encoding = np.empty_like(angle_rads)
    encoding[:, 0::2] = np.sin(angle_rads[:, 0::2])  # even columns
    encoding[:, 1::2] = np.cos(angle_rads[:, 1::2])  # odd columns

    return encoding

class PositionalEmbedding(torch.nn.Module):
    """Fixed (non-trainable) sinusoidal positional embedding.

    The table is computed once at construction time by
    ``positional_encoding`` and stored as a buffer, so it carries no
    gradient and follows the module through ``.to(...)`` / ``.cuda()``.

    Args:
        d_model: embedding width.
        max_len: number of positions in the table.
    """

    def __init__(self, d_model, max_len=64):
        super().__init__()

        # Explicit NumPy -> Torch conversion (the NumPy table is float64),
        # instead of relying on implicit tensor/ndarray arithmetic.
        table = torch.from_numpy(positional_encoding(max_len, d_model)).float()

        # Registered as a buffer rather than a plain attribute so that device
        # moves apply to it; persistent=False keeps it out of state_dict,
        # exactly as when it was a plain attribute.
        # Shape: [1, max_len, d_model] (leading batch dimension of 1).
        self.register_buffer("pe", table.unsqueeze(0), persistent=False)

    def forward(self):
        """Return the whole table, shape ``[1, max_len, d_model]``."""
        return self.pe
    

class BERTEmbedding(torch.nn.Module):
    """Sum of token, positional and segment embeddings, followed by dropout.

    Args:
        vocab_size: number of rows of the token embedding table.
        embed_size: width of every embedding.
        max_len: number of positions held by the positional table.
        dropout: dropout probability applied to the summed embeddings.
    """

    def __init__(self, vocab_size: int, embed_size: int, max_len: int, dropout: float):
        super().__init__()
        self.embed_size = embed_size

        # Index 0 is the padding index of both lookup tables.
        self.token = torch.nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.segment = torch.nn.Embedding(3, embed_size, padding_idx=0) # padding, seqA, seqB
        self.position = PositionalEmbedding(embed_size, max_len)
        self.dropout = torch.nn.Dropout(p=dropout)

    def forward(self, seq: torch.Tensor, segment_label: torch.Tensor) -> torch.Tensor:
        """Embed a batch of token ids.

        Args:
            seq: integer tensor of token ids, last dimension is the sequence
                length (e.g. ``[batch_size, seq_len]``).
            segment_label: integer tensor of segment ids (0, 1 or 2) with the
                same shape as ``seq``.

        Returns:
            Tensor of shape ``[..., seq_len, embed_size]``.
        """
        token_emb = self.token(seq)

        # The positional table holds max_len positions; keep only the first
        # seq_len so inputs shorter than max_len broadcast correctly, and put
        # it on the same device as the token embeddings.
        position_emb = self.position()[:, : seq.size(-1)].to(token_emb.device)

        embs = token_emb + position_emb + self.segment(segment_label)
        return self.dropout(embs)

In [3]:
class SingleHeadAttention(torch.nn.Module):
    """Scaled dot-product self-attention with a single head.

    Args:
        d_model: width of the incoming embeddings.
        d_k: width of the query / key / value projections.
        dropout_rate: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, d_k: int, dropout_rate=0.1):
        super().__init__()

        # Three independent projections from embedding space to attention space.
        self.query = torch.nn.Linear(in_features=d_model, out_features=d_k)
        self.key = torch.nn.Linear(in_features=d_model, out_features=d_k)
        self.values = torch.nn.Linear(in_features=d_model, out_features=d_k)
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def forward(self, E: torch.Tensor) -> torch.Tensor:
        """Attend over ``E`` of shape ``[batch_size, max_len, d_model]``.

        Returns a tensor of shape ``[batch_size, max_len, d_k]``.
        """
        # [batch_size, max_len, d_k] each
        queries, keys, vals = self.query(E), self.key(E), self.values(E)

        # Token-to-token similarities, scaled by sqrt(d_k).
        # [batch_size, max_len, max_len]
        scale = np.sqrt(self.query.out_features)
        similarity = torch.matmul(queries, keys.transpose(-2, -1)) / scale

        # Normalise each row over the keys, then drop some weights.
        weights = self.dropout(F.softmax(similarity, dim=-1))

        # Weighted sum of the value vectors: [batch_size, max_len, d_k]
        return torch.matmul(weights, vals)

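Let's improve the process with parallelization: instead of running several single-head attentions one after another, multi-head attention computes all heads at once with batched matrix multiplications.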

In [4]:
class MultiHeadAttention(torch.nn.Module):
    """Multi-head scaled dot-product self-attention.

    ``d_model`` is split evenly into ``n_heads`` heads of width
    ``d_k = d_model // n_heads``; all heads are computed in one batched
    matrix multiplication.

    Args:
        d_model: width of the embeddings (must be divisible by ``n_heads``).
        n_heads: number of attention heads.
        dropout_rate: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, n_heads, dropout_rate=0.1):
        super().__init__()

        assert d_model % n_heads == 0

        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.dropout = torch.nn.Dropout(p=dropout_rate)

        # Each projection maps d_model -> n_heads * d_k (== d_model).
        self.query = torch.nn.Linear(d_model, d_model)
        self.key = torch.nn.Linear(d_model, d_model)
        self.value = torch.nn.Linear(d_model, d_model)
        self.output_linear = torch.nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size):
        """[batch_size, max_len, d_model] -> [batch_size, n_heads, max_len, d_k]."""
        return projected.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

    def _merge_heads(self, per_head, batch_size):
        """[batch_size, n_heads, max_len, d_k] -> [batch_size, max_len, d_model]."""
        merged = per_head.transpose(1, 2).contiguous()
        return merged.view(batch_size, -1, self.n_heads * self.d_k)

    def forward(self, emb):
        """Attend over ``emb`` ([batch_size, max_len, d_model]); same shape out."""
        # Projections to attention space: [batch_size, max_len, d_model]
        q, k, v = self.query(emb), self.key(emb), self.value(emb)

        # One slice of width d_k per head.
        batch_size = q.shape[0]
        q, k, v = (self._split_heads(t, batch_size) for t in (q, k, v))

        # Similarity between every pair of tokens, per head.
        # [batch_size, n_heads, max_len, max_len]
        similarity = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(self.d_k)

        # How much each token contributes to the update of every other token.
        weights = self.dropout(F.softmax(similarity, dim=-1))

        # Per-head weighted sum of values: [batch_size, n_heads, max_len, d_k]
        per_head = torch.matmul(weights, v)

        # Concatenate the heads back and mix them with the output projection.
        return self.output_linear(self._merge_heads(per_head, batch_size))

In [5]:
class FeedForward(torch.nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear -> Dropout.

    Args:
        d_ff: width of the hidden layer.
        d_model: width of the input and output embeddings.
        dropout_rate: dropout probability applied to the block output.
    """

    def __init__(self, d_ff, d_model, dropout_rate=0.1):
        super().__init__()

        self.hidden_linear = torch.nn.Linear(in_features=d_model, out_features=d_ff)
        self.output_linear = torch.nn.Linear(in_features=d_ff, out_features=d_model)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = torch.nn.GELU()

    def forward(self, emb):
        """Map ``[batch_size, max_len, d_model]`` to the same shape."""
        # Expand to the hidden space and apply the non-linearity.
        # [batch_size, max_len, d_ff]
        expanded = self.activation(self.hidden_linear(emb))

        # Project back to the embedding space, then regularise.
        # [batch_size, max_len, d_model]
        return self.dropout(self.output_linear(expanded))

In [6]:
class EncoderLayer(torch.nn.Module):
    """One Transformer encoder layer (post-norm).

    Two sub-layers, each wrapped in a residual connection followed by its own
    layer normalisation:

        x -> LayerNorm(x + Dropout(MultiHeadAttention(x)))
          -> LayerNorm(x + FeedForward(x))

    Args:
        d_model: width of the embeddings.
        n_heads: number of attention heads.
        d_ff: hidden width of the feed-forward sub-layer.
        dropout_rate: dropout probability used by the sub-layers.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout_rate=0.1):
        super().__init__()

        self.mha = MultiHeadAttention(d_model, n_heads, dropout_rate)
        # Normalisation after the attention sub-layer.
        self.layernorm = torch.nn.LayerNorm(normalized_shape=d_model)
        self.ff = FeedForward(d_ff, d_model, dropout_rate)
        # Separate normalisation for the feed-forward sub-layer, so the two
        # sub-layers do not share scale/shift parameters.
        self.layernorm_ff = torch.nn.LayerNorm(normalized_shape=d_model)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, emb):
        """Map ``[batch_size, max_len, d_model]`` to the same shape."""
        # Attention sub-layer with residual connection.
        # [batch_size, max_len, d_model]
        attn_out = self.dropout(self.mha(emb))
        attended = self.layernorm(emb + attn_out)

        # Feed-forward sub-layer with residual connection. No extra dropout
        # here: FeedForward already applies dropout to its own output.
        # [batch_size, max_len, d_model]
        ff_out = self.ff(attended)
        return self.layernorm_ff(attended + ff_out)

In [12]:
class Encoder(torch.nn.Module):
    """Stack of ``n_layers`` encoder layers applied one after another.

    Args:
        d_model: width of the embeddings.
        n_heads: number of attention heads in every layer.
        d_ff: hidden width of every feed-forward sub-layer.
        n_layers: number of stacked encoder layers.
        dropout_rate: dropout probability forwarded to every layer.
    """

    def __init__(self, d_model, n_heads, d_ff, n_layers, dropout_rate=0.1):
        super().__init__()

        # Keep the hyper-parameters around for inspection.
        self.d_model, self.n_layers = d_model, n_layers
        self.n_heads, self.d_ff = n_heads, d_ff

        stack = [
            EncoderLayer(d_model, n_heads, d_ff, dropout_rate)
            for _ in range(n_layers)
        ]
        self.encoder_layers = torch.nn.ModuleList(stack)

    def forward(self, emb):
        """Feed ``emb`` through every layer in order and return the result."""
        hidden = emb
        for block in self.encoder_layers:
            hidden = block(hidden)
        return hidden

In [None]:
## Couple of tests
# Shape checks for every building block. The intermediate checks are real
# assertions, so a wrong shape fails the cell instead of being silently
# discarded; only the final comparison is left as the cell's displayed value.

batch_size = 3
n_heads = 4
max_len = 64
d_model = 32
d_k = d_model // n_heads # 8
d_ff = 128
n_layers = 2

embs = torch.randn(size=(batch_size,max_len,d_model))
expected_shape = torch.Size([batch_size, max_len, d_model])

# Single head: output width is d_k, not d_model.
sha = SingleHeadAttention(d_model,d_k)
output_sha = sha(embs)
assert output_sha.shape == torch.Size([batch_size, max_len, d_k])

# Multi-head attention keeps the embedding shape.
mha = MultiHeadAttention(d_model,n_heads)
output_mha = mha(embs)
assert output_mha.shape == expected_shape

# Feed-forward block keeps the embedding shape.
ff = FeedForward(d_ff, d_model)
output_ff = ff(output_mha)
assert output_ff.shape == expected_shape

# A single encoder layer keeps the embedding shape.
enc_l = EncoderLayer(d_model, n_heads, d_ff)
output_enc_l = enc_l(embs)
assert output_enc_l.shape == expected_shape

# The full encoder stack keeps the embedding shape.
enc = Encoder(d_model,n_heads,d_ff,n_layers)
enc_output = enc(embs)

enc_output.shape == expected_shape

True