In [11]:
import torch
from torch import nn
from copy import copy

## BatchNorm : 
- Normalizes the activations of a layer by shifting and scaling them within a mini-batch; it works on each feature independently.
- The mean and standard deviation are computed per feature, across all examples in the batch.

## LayerNorm: 
- Normalizes the activations of one layer across the features; it works on each training example separately and normalizes along the feature axis.
- The mean and standard deviation are computed per example, across all features of the layer.

In [12]:
def clones(module, N):
    '''Produce N identical, independent layers.

    Args:
        module: the prototype ``nn.Module`` to replicate.
        N: number of copies to create.

    Returns:
        An ``nn.ModuleList`` holding N deep copies of ``module``. Each copy
        owns its own parameters; the prototype itself is not included.
    '''
    # Import locally: at file level the name `copy` is bound to the function
    # `copy.copy` (via `from copy import copy`), so `copy.deepcopy` would
    # raise AttributeError.
    from copy import deepcopy

    return nn.ModuleList([deepcopy(module) for _ in range(N)])

In [13]:
class LayerNorm(nn.Module):
    '''
    Layer normalisation over the last dimension of the input.

    Each slice along the last axis is centred on its mean, divided by
    (std + eps), then scaled by the learnable gain ``a_2`` and shifted by the
    learnable bias ``b_2`` (both of length ``features``).

    Note: eps is added to the standard deviation (not to the variance), and
    ``x.std`` is called with torch's default settings.
    '''

    def __init__(self, features, eps = 1e-6):
        super().__init__()
        # Gain starts at one and bias at zero, so the initial affine step is
        # the identity.
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        '''Normalise ``x`` along its last axis and apply the gain and bias.'''
        centered = x - x.mean(dim=-1, keepdim=True)
        # eps keeps the denominator away from zero for constant inputs.
        spread = x.std(dim=-1, keepdim=True) + self.eps
        return self.a_2 * centered / spread + self.b_2

Each encoder layer has 2 sub-layers: 1 multi-head self-attention sub-layer and 1 feed-forward network.

In [14]:
class SubLayerConnection(nn.Module):
    '''
    Residual wrapper shared by all sub-layers.

    The input is normalised, passed through the given sub-layer, run through
    dropout, and finally added back onto the un-normalised input.
    '''
    def __init__(self,size,dropout):
        super().__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        '''Return ``x`` plus the dropped-out output of ``sublayer`` applied to the normalised ``x``.'''
        normed = self.norm(x)
        transformed = sublayer(normed)
        return x + self.dropout(transformed)

In [15]:
class EncoderLayer(nn.Module):
    '''
    One encoder layer made of two sub-layers: self-attention followed by a
    feed-forward network, each wrapped in its own SubLayerConnection.

    Args:
        size: feature size handed to each SubLayerConnection; also exposed
            as ``self.size``.
        self_attn: attention module, called as ``self_attn(x, x, x, mask)``
            (presumably query, key, value, mask - confirm against the
            attention implementation).
        feed_forward: module called with the output of the attention
            sub-layer.
        dropout: dropout rate handed to each SubLayerConnection.
    '''
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        # Must be stored under `self_attn`: forward() reads self.self_attn.
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        # Two independent residual wrappers, one per sub-layer.
        self.sublayer = clones(SubLayerConnection(size, dropout),2)
        self.size = size

    def forward(self, x, mask):
        '''Apply the self-attention sub-layer, then the feed-forward sub-layer.'''
        # The lambda adapts self_attn's 4-argument call to the single-argument
        # callable that SubLayerConnection expects.
        x = self.sublayer[0](x, lambda x: self.self_attn(x,x,x,mask))
        return self.sublayer[1](x, self.feed_forward)

In [16]:
class Encoder(nn.Module):
    '''Encoder core: N stacked copies of one layer followed by a final LayerNorm.'''

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        # The normalisation width is taken from the prototype layer's `size`.
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        '''Feed the input (with its mask) through every layer in order, then normalise.'''
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask)
        return self.norm(hidden)

# Decoder

In [None]:
class Decoder(nn.Module):
    '''Decoder core: N stacked copies of one decoder layer followed by a final LayerNorm.'''
    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        # The normalisation width is taken from the prototype layer's `size`.
        self.norm = LayerNorm(layer.size)

    def forward(self, x , memory , src_mask, tgt_mask):
        '''Run x through every layer with the same memory and masks, then normalise.'''
        hidden = x
        for block in self.layers:
            hidden = block(hidden, memory, src_mask, tgt_mask)
        return self.norm(hidden)

In [None]:
class DecoderLayer(nn.Module):
    '''
    One decoder layer with three sub-layers, applied in order: self-attention
    over x, attention over ``memory``, and a feed-forward network. Each is
    wrapped in its own SubLayerConnection.
    '''

    def __init__(self, size, self_attn , src_attn , feed_forward, dropout):
        super().__init__()
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # One residual wrapper per sub-layer (three in total).
        self.sublayer = clones(SubLayerConnection(size, dropout), 3)
        self.size = size

    def forward (self, x, memory, src_mask, tgt_mask):
        '''Apply self-attention (tgt_mask), attention over memory (src_mask), then feed-forward.'''

        def attend_to_self(h):
            # The same tensor is passed as the first three arguments.
            return self.self_attn(h, h, h, tgt_mask)

        def attend_to_memory(h):
            # h is the first argument; memory fills the second and third.
            return self.src_attn(h, memory, memory, src_mask)

        x = self.sublayer[0](x, attend_to_self)
        x = self.sublayer[1](x, attend_to_memory)
        return self.sublayer[2](x, self.feed_forward)