In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [4]:
class AddNorm(nn.Module):
    """Residual "Add & Norm" wrapper used around each Transformer sublayer.

    Computes ``LayerNorm(x + Dropout(sublayer_output))`` (post-norm variant).

    Args:
        d_model: Size of the last dimension, normalized by ``nn.LayerNorm``.
        dropout_rate: Dropout probability applied to the sublayer output
            before it is added to the residual input.
    """

    def __init__(self, d_model, dropout_rate=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, sublayer_output):
        """Add the (dropped-out) sublayer output to ``x`` and layer-normalize.

        Args:
            x: Residual input; its last dimension must equal ``d_model``.
            sublayer_output: Output of the wrapped sublayer, same shape as ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        return self.norm(x + self.dropout(sublayer_output))
    






In [5]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied independently at every position.

    Computes ``w_2(Dropout(ReLU(w_1(x))))``, expanding to ``d_ff`` hidden
    units and projecting back to ``d_model``.

    Args:
        d_model: Input and output feature size.
        d_ff: Hidden (inner) feature size.
        dropout_rate: Dropout probability applied to the hidden activations.
    """

    def __init__(self, d_model, d_ff, dropout_rate=0.1):
        super().__init__()
        # Construction order (w_1, then w_2) determines the seeded init draws.
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``."""
        hidden = self.dropout(F.relu(self.w_1(x)))
        return self.w_2(hidden)
    

In [6]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Projects queries, keys and values with separate linear layers, splits them
    into ``num_heads`` heads of size ``d_k = d_model // num_heads``, applies
    scaled dot-product attention per head, then concatenates the heads and
    applies an output projection.

    Args:
        d_model: Model feature size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout_rate=0.1):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_k = d_model // num_heads
        # Take these from the constructor arguments; reading self.num_heads /
        # self.d_model here would raise AttributeError (they don't exist yet).
        self.num_heads = num_heads
        self.d_model = d_model

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

        Args:
            Q: Query heads, ``(batch, num_heads, len_q, d_k)`` as built by ``forward``.
            K: Key heads, ``(batch, num_heads, len_k, d_k)``.
            V: Value heads, ``(batch, num_heads, len_k, d_k)``.
            mask: Optional tensor broadcastable to
                ``(batch, num_heads, len_q, len_k)``. Positions where
                ``mask == 0`` receive zero attention weight.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_weights`` has shape
            ``(batch, num_heads, len_q, len_k)`` and is returned *after* dropout.
        """
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Masked-out (mask == 0) positions get -inf, so softmax gives them 0.
            # NOTE: a query row whose keys are *all* masked yields NaN weights.
            attn_scores = attn_scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, V)
        return attn_output, attn_weights

    def forward(self, query_input, key_input, value_input, mask=None):
        """Attend from ``query_input`` over ``key_input`` / ``value_input``.

        Args:
            query_input: ``(batch, len_q, d_model)``.
            key_input: ``(batch, len_k, d_model)``.
            value_input: ``(batch, len_k, d_model)``.
            mask: Optional mask, see ``scaled_dot_product_attention``.

        Returns:
            ``(output, attn_weights)`` where ``output`` is
            ``(batch, len_q, d_model)`` and ``attn_weights`` is
            ``(batch, num_heads, len_q, len_k)``.
        """
        batch_size = query_input.size(0)

        Q = self.W_q(query_input)
        K = self.W_k(key_input)
        V = self.W_v(value_input)

        # (batch, seq, d_model) -> (batch, num_heads, seq, d_k)
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        attn_output, attn_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # Merge heads back: (batch, num_heads, len_q, d_k) -> (batch, len_q, d_model).
        # contiguous() is required because transpose() yields a non-contiguous view.
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.W_o(attn_output)

        return output, attn_weights

In [7]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal positional encodings to the input.

    For position ``pos`` and even column index ``2i``::

        PE[pos, 2i]     = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i + 1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: Feature size of the input. Odd values are supported; the
            final (even-indexed) column is then a sine column with no cosine partner.
        max_len: Number of positions precomputed. Inputs longer than this
            fail to broadcast in ``forward``.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        # 1 / 10000^(2i / d_model), computed in log space; ceil(d_model / 2) entries.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are only d_model // 2 odd columns, one fewer than div_term
        # entries when d_model is odd, so trim to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model) so it broadcasts over batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)`` with ``seq_len <= max_len``.
        """
        x = x + self.pe[:, :x.size(1), :]
        return x

In [None]:
class EncoderLayer(nn.Module):
    """A single Transformer encoder layer.

    Two sublayers, each followed by a residual Add & Norm:
      1. multi-head self-attention over ``x``;
      2. a position-wise feed-forward network.

    Args:
        d_model: Model feature size.
        num_heads: Number of self-attention heads.
        d_ff: Hidden size of the feed-forward network.
        dropout_rate: Dropout probability passed to every sublayer.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout_rate=0.1):
        super().__init__()
        # Sublayer 1: self-attention + Add & Norm.
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout_rate)
        self.addnorm_1 = AddNorm(d_model, dropout_rate)
        # Sublayer 2: position-wise feed-forward + Add & Norm.
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout_rate)
        self.addnorm_2 = AddNorm(d_model, dropout_rate)

    def forward(self, x, src_mask):
        """Run both sublayers on ``x``.

        Args:
            x: Layer input; it is used as query, key and value for self-attention.
            src_mask: Mask forwarded unchanged to the self-attention sublayer.

        Returns:
            ``(output, self_attn_weights)``, where the weights come straight
            from the self-attention sublayer.
        """
        attended, attn_weights = self.self_attention(x, x, x, mask=src_mask)
        hidden = self.addnorm_1(x, attended)
        output = self.addnorm_2(hidden, self.ffn(hidden))
        return output, attn_weights
    
    