In [1]:
from torch import nn
import torch
import torch.nn.functional as F
from math import sqrt

In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, embed_dim, input_dim, dropout=0.1):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dim_heads = embed_dim // num_heads     # dim_heads aka d_k

        self.q_lin = nn.Linear(input_dim, embed_dim)
        self.k_lin = nn.Linear(input_dim, embed_dim)
        self.v_lin = nn.Linear(input_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        num_heads, dim_heads = self.num_heads, self.dim_heads

        q = self.q_lin(q).reshape(batch_size, -1, num_heads, dim_heads).transpose(1, 2)
        k = self.k_lin(k).reshape(batch_size, -1, num_heads, dim_heads).transpose(1, 2)
        v = self.v_lin(v).reshape(batch_size, -1, num_heads, dim_heads).transpose(1, 2)

        scores = attention(q, k, v, dim_heads, mask=mask, dropout=self.dropout)
        
        scores = scores.transpose(1, 2).contiguous().reshape(batch_size, -1, self.embed_dim)

        output = self.out_proj(scores)

        return output

In [3]:
def attention(q, k, v, d_k, mask=None, dropout=None):
    scaled_dot = torch.matmul(q, k.transpose(-2, -1)) / sqrt(d_k)
    if mask is not None:
        scaled_dot = scaled_dot.masked_fill(mask == 0, -1e9)
    scaled_dot = F.softmax(scaled_dot, dim=-1)
    if dropout is not None:
        scaled_dot = dropout(scaled_dot)
    output =  torch.matmul(scaled_dot, v)
    return output

In [4]:
class PositionwiseFeedForward(nn.Module):

    def __init__(self, embed_dim, input_dim, dropout_rate=0.1):
        """
        embed_dim: num of expected features in input (same as d_model)
        input_dim: length of sequence
        """
        super(PositionwiseFeedForward, self).__init__()
        self.embed_dim = embed_dim
        self.input_dim = input_dim
        self.dropout_rate = dropout_rate
        self.w_1 = nn.Linear(embed_dim, input_dim)
        self.w_2 = nn.Linear(input_dim, embed_dim)
        self.dropout = nn.Dropout(dropout_rate)
    
    def forward(self, x):
        # x = (batch_size, input_dim, embed_dim)
        x = self.dropout(F.relu(self.w_1(x))) 
        x = self.w_2(x)  
        return x

In [5]:
class WordEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_dim=512):
        super().__init__()
        # embed_dim: embedding dimension (usually 1024 or 512)
        self.embed_dim = embed_dim
        self.embed_matrix = torch.empty([vocab_size, embed_dim])

        nn.init.xavier_normal_(self.embed_matrix)
        self.embed_matrix = nn.Parameter(self.embed_matrix)
        self.embed_matrix = self.embed_matrix.to(torch.float)
        # seq len x vocab_size, vocab_size x embed_dim
        # embedding matrix dimensions: number of words in vocab x embed_dim (usually 1024 or 512)

    def forward(self, x):
        # x: embedding tensor (batch_size by seq_len by vocab_size)
        return torch.matmul(x, self.embed_matrix)

In [6]:
class PositionalEncoding(nn.Module):

    def __init__(self, embed_dim, input_dim):
        """
        embed_dim: num of expected features in input (same as d_model)
        input_dim: length of sequence
        """
        super().__init__()

        encod = torch.zeros(input_dim, embed_dim)

        position = torch.arange(0, input_dim, dtype=torch.float).unsqueeze(1)   # numerator

        i = torch.arange(0, embed_dim, 2, dtype=torch.float)

        denom = torch.exp(log(10000.0) * i / embed_dim)

        encod[ : , 0::2] = torch.sin(position / denom)
        encod[ : , 1::2] = torch.cos(position / denom)
        encod.unsqueeze(0)

        self.pe = encod


    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]
        return x

In [7]:
class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, input_dim, num_heads):
        """
        embed_dim: num of expected features in input (same as d_model)
        input_dim: length of sequence
        num_heads: num of heads
        """
        super().__init__()

        self.attention1 = MultiHeadAttention(num_heads=num_heads, embed_dim=embed_dim, input_dim=input_dim, dropout=0.1)
        self.attention2 = MultiHeadAttention(num_heads=num_heads, embed_dim=embed_dim, input_dim=input_dim, dropout=0.1)
        self.feedforward = PositionwiseFeedForward(embed_dim=embed_dim, input_dim=input_dim)

        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.norm3 = nn.LayerNorm(input_dim)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)
        self.dropout3 = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        # masked attention output
        attn_1_out = self.attention1(q=x, k=x, v=x, mask=mask)
        x = x + self.dropout1(attn_1_out)
        x = self.norm1(x)

        # unmasked attention output
        attn_2_out = self.attention2(q=x, k=x, v=x, mask=None)
        x = x + self.dropout2(attn_2_out)
        x = self.norm2(x)

        # feedforward output
        ff_out = self.feedforward(x)
        x = x + self.dropout3(ff_out)
        x = self.norm3(x)

        return x

In [8]:
class EncoderLayer(nn.Module):
    def __init__(self, embed_dim, input_dim, num_heads):
        """
        embed_dim: num of expected features in input (same as d_model)
        input_dim: length of sequence
        num_heads: num of heads
        """
        super().__init__()

        self.attention = MultiHeadAttention(num_heads=num_heads, embed_dim=embed_dim, input_dim=input_dim, dropout=0.1)
        self.feedforward = PositionwiseFeedForward(embed_dim, input_dim)

        self.norm1 = nn.LayerNorm(input_dim)
        self.norm2 = nn.LayerNorm(input_dim)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        # attention output
        attn_out = self.attention(q=x, k=x, v=x, mask=mask)
        x = x + self.dropout1(attn_out)
        x = self.norm1(x)

        # feedforward output
        ff_out = self.feedforward(x)
        x = x + self.dropout2(ff_out)
        x = self.norm2(x)

        return x

In [10]:
# TESTING

embed_dim = 3
num_heads = 1

x = torch.tensor([[0, 10, 0]], dtype=torch.float32)
input_dim = 3

encoder = EncoderLayer(embed_dim=embed_dim, input_dim=input_dim, num_heads=num_heads)
output = encoder.forward(x)
print(output)

TypeError: __init__() missing 2 required positional arguments: 'embed_dim' and 'input_dim'