In [2]:
import math
import torch
from torch import nn
import torch.nn.functional as F

In [30]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across several heads.

    Queries, keys and values are each projected to ``hidden_dim`` features,
    split into ``num_head`` heads of ``hidden_dim // num_head`` features,
    attended per head, merged back together and passed through an output
    projection.

    Args:
        q_dim: Feature size of the incoming queries.
        k_dim: Feature size of the incoming keys.
        v_dim: Feature size of the incoming values.
        hidden_dim: Total projected feature size (all heads combined).
        num_head: Number of attention heads; must divide ``hidden_dim``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_head``.
    """
    def __init__(self, q_dim, k_dim, v_dim, hidden_dim, num_head, dropout):
        super(MultiHeadAttention, self).__init__()
        if hidden_dim % num_head != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_head ({num_head})"
            )
        self.num_head = num_head
        self.hidden_dim = hidden_dim
        self.W_q = nn.Linear(q_dim, hidden_dim)
        self.W_k = nn.Linear(k_dim, hidden_dim)
        self.W_v = nn.Linear(v_dim, hidden_dim)
        self.W_o = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, valid_len = None, causal = False):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: Tensor of shape (B, Lq, q_dim).
            key: Tensor of shape (B, Lk, k_dim).
            value: Tensor of shape (B, Lk, v_dim).
            valid_len: Optional 1-D tensor of shape (B,). Key positions at
                index >= valid_len[b] are masked out for batch element b.
            causal: If True, query position i may only attend to key
                positions j <= i.

        Returns:
            Tensor of shape (B, Lq, hidden_dim).
        """
        head_dim = self.hidden_dim // self.num_head
        B, Lq, _ = query.shape
        Lk = key.shape[1]
        Lv = value.shape[1]

        # Project, then split the feature axis into heads: (B, L, H*D) -> (B, H, L, D).
        Q = self.W_q(query).reshape(B, Lq, self.num_head, head_dim).permute(0, 2, 1, 3)
        K = self.W_k(key).reshape(B, Lk, self.num_head, head_dim).permute(0, 2, 1, 3)
        V = self.W_v(value).reshape(B, Lv, self.num_head, head_dim).permute(0, 2, 1, 3)

        # Scaled dot-product scores, shape (B, H, Lq, Lk).
        scores = Q @ K.transpose(2, 3) / (head_dim ** 0.5)

        if valid_len is not None:
            # Build the padding mask on the same device as the scores so the
            # module also works after being moved off the CPU.
            key_pos = torch.arange(Lk, device=scores.device)
            lengths = valid_len.to(scores.device)
            pad_mask = key_pos[None, None, None, :] >= lengths[:, None, None, None]
            scores = scores.masked_fill(pad_mask, -1e6)
        if causal:
            # The strict upper triangle of an all-ones matrix marks "future"
            # keys (j > i). Starting from zeros would yield an empty mask.
            causal_mask = torch.ones(Lq, Lk, device=scores.device).triu(1).bool()
            scores = scores.masked_fill(causal_mask[None, None, :, :], -1e6)

        weight = F.softmax(scores, dim = -1)
        weight = self.dropout(weight)
        out = weight @ V  # (B, H, Lq, D)

        # Move the head axis back next to the feature axis before merging;
        # reshaping (B, H, Lq, D) directly would interleave heads and positions.
        merged = out.permute(0, 2, 1, 3).reshape(B, Lq, self.hidden_dim)
        return self.W_o(merged)

In [31]:
# Smoke test: self-attention over an all-ones batch with per-sample key lengths.
batchsize, num_queries = 2, 4
hidden_dim, num_head = 100, 5
attention = MultiHeadAttention(
    q_dim=hidden_dim,
    k_dim=hidden_dim,
    v_dim=hidden_dim,
    hidden_dim=hidden_dim,
    num_head=num_head,
    dropout=0.5,
)
attention.eval()  # switch dropout off for the shape check
X = torch.ones(batchsize, num_queries, hidden_dim)
valid_len = torch.tensor([3, 2])  # one key length per batch element
attention(X, X, X, valid_len).shape


torch.Size([2, 4, 100])

In [32]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to embeddings, then apply dropout.

    A table ``P`` of shape (1, max_len, hidden_dim) is precomputed once: even
    feature indices hold ``sin(pos / 10000^(i / hidden_dim))`` and odd feature
    indices hold the matching ``cos`` for ``i = 0, 2, 4, ...``.

    Args:
        hidden_dim: Feature size of the embeddings (even or odd).
        dropout: Dropout probability applied after adding the encoding.
        max_len: Number of positions precomputed in the table.
    """
    def __init__(self, hidden_dim, dropout, max_len = 1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        table = torch.zeros((1, max_len, hidden_dim))
        position = torch.arange(max_len).reshape(-1, 1)
        divisor = torch.pow(10000, torch.arange(0, hidden_dim, 2) / hidden_dim)
        angles = position / divisor  # (max_len, ceil(hidden_dim / 2))
        table[:, :, 0::2] = torch.sin(angles)
        # For odd hidden_dim there is one fewer odd-indexed column than
        # even-indexed ones, so trim the cosine columns to fit.
        table[:, :, 1::2] = torch.cos(angles[:, : hidden_dim // 2])
        # Register as a buffer so the table follows the module across devices
        # and dtypes; non-persistent keeps state_dict keys unchanged.
        self.register_buffer("P", table, persistent=False)

    def forward(self, X):
        """Return ``dropout(X + P[:, :X.shape[1]])``.

        Args:
            X: Tensor of shape (B, L, hidden_dim); L must not exceed the
                ``max_len`` the table was built with.
        """
        X = X + self.P[:, :X.shape[1], :]
        return self.dropout(X)

In [33]:
# Run the positional encoding over a single all-ones sequence of `tokens` positions.
encoder_dim, tokens = 100, 50
dropout = 0.5
posEncoding = PositionalEncoding(hidden_dim=encoder_dim, dropout=dropout)
X = posEncoding(torch.ones((1, tokens, encoder_dim)))


In [34]:
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last axis: Linear -> ReLU -> Dropout -> Linear.

    Args:
        hidden_dim: Input and output feature size.
        ff_dim: Feature size of the inner layer.
        dropout: Dropout probability applied after the ReLU.
    """
    def __init__(self, hidden_dim, ff_dim, dropout):
        # Must be `__init__`; `__init` does not exist on nn.Module and would
        # raise AttributeError before any submodule could be registered.
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, ff_dim)
        self.fc2 = nn.Linear(ff_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map ``x`` of shape (..., hidden_dim) to a tensor of the same shape."""
        return self.fc2(self.dropout(F.relu(self.fc1(x))))

In [35]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention and an MLP, each with residual + LayerNorm.

    Args:
        hidden_dim: Feature size of the input and output.
        num_head: Number of attention heads.
        ff_dim: Inner feature size of the feed-forward MLP.
        dropout: Dropout probability forwarded to the attention and MLP.
    """
    def __init__(self, hidden_dim, num_head, ff_dim, dropout):
        # Must be `__init__`; the misspelled `__init` raises AttributeError.
        super().__init__()
        self.attention = MultiHeadAttention(hidden_dim, hidden_dim, hidden_dim, hidden_dim, num_head, dropout)
        self.mlp = FeedForward(hidden_dim, ff_dim, dropout)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x, valid_len = None):
        """Apply the layer to ``x`` of shape (B, L, hidden_dim).

        Args:
            x: Input sequence.
            valid_len: Optional per-sample key lengths passed through to the
                self-attention so positions beyond them are masked; ``None``
                (the default) attends to every position.

        Returns:
            Tensor with the same shape as ``x``.
        """
        att = self.attention(x, x, x, valid_len)
        x = self.norm1(x + att)  # post-norm residual around attention
        m = self.mlp(x)
        x = self.norm2(x + m)    # post-norm residual around the MLP
        return x

In [36]:
class DecoderBlock(nn.Module):
    """One decoder layer: causal self-attention, cross-attention, then an MLP.

    Every sub-layer is wrapped in a residual connection followed by LayerNorm.
    """
    def __init__(self, hidden_dim, num_head, ff_dim, dropout):
        super().__init__()

        def build_attention():
            # Both attention sub-layers share the same configuration.
            return MultiHeadAttention(hidden_dim, hidden_dim, hidden_dim, hidden_dim, num_head, dropout)

        # Registration order is kept: attention1, attention2, mlp, norm1..norm3.
        self.attention1 = build_attention()
        self.attention2 = build_attention()
        self.mlp = FeedForward(hidden_dim, ff_dim, dropout)
        self.norm1, self.norm2, self.norm3 = [nn.LayerNorm(hidden_dim) for _ in range(3)]

    def forward(self, x, input_embed):
        """Decode ``x`` while attending to ``input_embed``.

        ``x`` is used as query, key and value of the first (causal) attention;
        ``input_embed`` supplies keys and values for the second attention.
        """
        x = self.norm1(x + self.attention1(x, x, x, causal = True))
        x = self.norm2(x + self.attention2(x, input_embed, input_embed))
        return self.norm3(x + self.mlp(x))

In [37]:
class Transformer(nn.Module):
    """Encoder-decoder stack over token ids, producing per-position vocabulary scores.

    The same embedding table and positional-encoding module are used for both
    the encoder input and the decoder input.
    """
    def __init__(self, vocab_size, hidden_dim, num_layer, num_head, ff_dim, dropout):
        super().__init__()
        self.num_layer = num_layer
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.PE = PositionalEncoding(hidden_dim, dropout)

        # Build all encoder layers first, then all decoder layers.
        self.encoder = nn.ModuleList()
        for _ in range(num_layer):
            self.encoder.append(EncoderBlock(hidden_dim, num_head, ff_dim, dropout))
        self.decoder = nn.ModuleList()
        for _ in range(num_layer):
            self.decoder.append(DecoderBlock(hidden_dim, num_head, ff_dim, dropout))

        self.fc_out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, inputs, outputs):
        """Encode ``inputs``, decode ``outputs`` against them, and project to vocab size.

        Args:
            inputs: Token ids fed to the encoder.
            outputs: Token ids fed to the decoder.

        Returns:
            Tensor whose last axis has ``vocab_size`` entries.
        """
        # Encoder pass.
        memory = self.PE(self.embed(inputs))
        for enc_layer in self.encoder:
            memory = enc_layer(memory)

        # Decoder pass; every layer sees the final encoder output.
        target = self.PE(self.embed(outputs))
        for dec_layer in self.decoder:
            target = dec_layer(target, memory)

        return self.fc_out(target)