In [3]:
import torch
import torch.nn as nn
import math
import os
import sys

ImportError: cannot import name 'scaling_layer' from 'src.scaling_layer' (/home/jason/Documents/DownScaleTransformerEncoder/src/scaling_layer.py)

In [10]:
# Reference encoder layer from PyTorch: model width (d_model) 512, 8 attention heads.
encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)

In [11]:
# Build a 6-layer nn.TransformerEncoder from the reference encoder_layer.
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)

In [132]:
# Random tensor of shape (32, 40, 768), sampled uniformly from [0, 1).
# NOTE(review): the last dimension (768) differs from d_model=512 used for
# encoder_layer, and src is not referenced by any later visible cell - confirm
# whether this cell is still needed.
src = torch.rand(32, 40, 768)

In [366]:
class MultiHeadSelfAttn(nn.Module):
    '''
    Multi-head scaled dot-product self-attention.

    The embeddings are projected to queries, keys and values, each projection is
    split per token into `heads` chunks of size `dk`, attention is computed
    independently for every head, and the re-joined heads go through a final
    linear layer that maps in_features to out_features.
    '''

    def __init__(self, in_features, out_features, heads=8):
        '''
        in_features: should be equal to the embedding_dimension; must be divisible by heads
        out_features: what the size of the output embedding should be
        heads: number of attention heads
        '''

        super(MultiHeadSelfAttn, self).__init__()
        self.in_features = in_features
        self.heads = heads
        self.out_features = out_features

        # Make sure that in_features is compatible with the number of heads
        assert in_features % heads == 0, 'in_features must be divisible by heads'

        # dk is the size of each of the linear projections of the embedding.
        # It has to follow `heads`; a hard-coded 8 breaks every other head count.
        self.dk = in_features // heads

        # These are the parameters to project the matrix to the amount of heads
        self.key_projections = nn.Linear(self.in_features, self.in_features)
        self.value_projections = nn.Linear(self.in_features, self.in_features)
        self.query_projections = nn.Linear(self.in_features, self.in_features)

        # The final linear layer
        self.end_linear = nn.Linear(self.in_features, self.out_features)

        # Softmax over the last axis (the key positions), so the attention
        # weights of every query sum to one
        self.softmax = nn.Softmax(dim=-1)

    def _split_heads(self, projected):
        '''
        Reshape (batch, sequence_length, in_features) into
        (batch, heads, sequence_length, dk).
        '''
        batches, sequence_length, _ = projected.size()

        # Split the feature axis first and only then move the heads axis forward.
        # Viewing straight to (batch, heads, sequence_length, dk) would mix
        # features of different tokens into the same head.
        projected = projected.view(batches, sequence_length, self.heads, self.dk)
        return projected.transpose(1, 2)

    def scaled_attention(self, head_query, head_keys, head_values):
        '''
        Scaled dot-product attention.
        All three inputs are expected as (..., sequence_length, dk); the result
        has the same dimensions as head_values.
        '''
        # Calculate the scaled dot-product
        attn = torch.matmul(head_query, head_keys.transpose(-2, -1)) / math.sqrt(self.dk)

        # Get the softmax
        attn = self.softmax(attn)

        # Multiply the softmax output by the values
        return torch.matmul(attn, head_values)

    def forward(self, embeddings):
        '''
        Forward propagate through the multi-head attention
        embeddings: should be of dimensions (batch, sequence_length, embedding_dimension)
        returns: tensor of dimensions (batch, sequence_length, out_features)
        '''

        batches, sequence_length, _ = embeddings.size()

        # Project and split into heads: (batch, heads, sequence_length, dk)
        query = self._split_heads(self.query_projections(embeddings))
        keys = self._split_heads(self.key_projections(embeddings))
        values = self._split_heads(self.value_projections(embeddings))

        # Calculated the scaled dot-product attention
        attn_out = self.scaled_attention(query, keys, values)

        # Undo the head split: back to (batches, sequence_length, in_features).
        # transpose() returns a non-contiguous tensor, so make it contiguous
        # before view().
        attn_out = attn_out.transpose(1, 2).contiguous()
        attn_out = attn_out.view(batches, sequence_length, self.in_features)

        # Apply the final linear layer
        return self.end_linear(attn_out)
        

In [367]:
# Smoke test: attention with equal input and output width on a random
# (batch=32, sequence_length=40, embedding=768) tensor; the cell displays
# the size of the result.
multi_attn = MultiHeadSelfAttn(in_features=768, out_features=768)
x = torch.randn(32, 40, 768)
multi_attn(x).shape

torch.Size([32, 40, 768])

In [372]:
class PositionWiseFeedForward(nn.Module):
    '''
    Two-layer feed-forward network applied to the last dimension of its input:
    Linear -> ReLU -> Dropout -> Linear -> Dropout.
    '''

    def __init__(self, in_features, inner_features, out_features, dropout=0.1):
        '''
        in_features: size of the incoming embeddings
        inner_features: size of the hidden layer
        out_features: size of the returned embeddings
        dropout: dropout probability used after the activation and after the output layer
        '''
        super(PositionWiseFeedForward, self).__init__()
        self.in_features = in_features
        self.inner_features = inner_features
        self.out_features = out_features

        # First linear layer goes from the in_features to inner_features dimension
        self.fc1 = nn.Linear(in_features, inner_features)

        # Second linear layer goes from inner_features to out_features
        self.fc2 = nn.Linear(inner_features, out_features)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, embeddings):
        '''
        Forward propagate through the Feed-Forward network
        embeddings: tensor whose last dimension is in_features
        returns: tensor with the same leading dimensions and out_features as last dimension
        '''

        # The non-linearity has to sit between the two linear layers. Without it
        # fc1 and fc2 collapse into a single affine map, and a ReLU on the output
        # would only clamp the result to non-negative values.
        hidden = self.dropout(self.relu(self.fc1(embeddings)))
        return self.dropout(self.fc2(hidden))

In [373]:
# Smoke test: the feed-forward block (768 -> 2560 -> 768) on a random
# (32, 40, 768) tensor; the cell displays the size of the result.
feed_forward = PositionWiseFeedForward(
    in_features=768,
    inner_features=2560,
    out_features=768,
)
x = torch.randn(32, 40, 768)
feed_forward(x).shape

torch.Size([32, 40, 768])

In [374]:
# Layer normalization over the last (768-wide) dimension of the feed-forward
# output; the cell displays the size of the result.
ln = nn.LayerNorm(normalized_shape=768)
ln(feed_forward(x)).shape

torch.Size([32, 40, 768])

In [383]:
class ScalingLayer(nn.Module):
    '''
    Encoder-style block that can change the embedding width.

    Multi-head self-attention maps in_features to out_features, followed by a
    position-wise feed-forward network. Both sub-layers have a residual
    connection and Layer Normalization; the first residual is projected with a
    linear layer so it matches out_features.
    '''

    def __init__(self, in_features, out_features, pwff_inner_features):
        '''
        in_features: size of the incoming embeddings
        out_features: size of the returned embeddings
        pwff_inner_features: hidden size of the position-wise feed-forward network
        '''
        super(ScalingLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.pwff_inner_features = pwff_inner_features

        # Multi-Head Self Attention
        self.multihead = MultiHeadSelfAttn(in_features=self.in_features,
                                           out_features=self.out_features)

        # Position-Wise Feed Forward
        self.pwff = PositionWiseFeedForward(in_features=self.out_features,
                                            inner_features=self.pwff_inner_features,
                                            out_features=self.out_features)

        # This is used to scale the original embedding to make a residual connection if in_features != out_features
        self.residual_scale = nn.Linear(in_features=self.in_features,
                                        out_features=self.out_features)

        # The Layer Normalization layers
        self.multihead_ln = nn.LayerNorm(out_features)
        self.pwff_ln = nn.LayerNorm(out_features)

    def forward(self, embeddings):
        '''
        embeddings: tensor whose last dimension is in_features
        returns: tensor with the same leading dimensions and out_features as last dimension
        '''
        # Forward propagate through the multihead attention layer
        out = self.multihead(embeddings)

        # Scale the original embeddings to out_features to add element-wise.
        # No clone is needed: none of the sub-layers modifies its input in place.
        residual = self.residual_scale(embeddings)
        out = self.multihead_ln(out + residual)

        # Keep the attention sub-layer output for the second residual connection
        residual = out

        # Forward propagate through the Position-Wise Feed Forward
        # (this step used to call pwff_ln, so the feed-forward was never applied)
        out = self.pwff(out)

        # Complete the residual connection and apply Layer Normalization
        out = self.pwff_ln(out + residual)
        return out

In [387]:
# Two scaling layers that shrink the embedding width step by step:
# 768 -> 512 (feed-forward hidden size 2048), then 512 -> 256 (hidden size 1024).
x = torch.randn(32, 40, 768)
scale = ScalingLayer(in_features=768, out_features=512, pwff_inner_features=2048)
scale2 = ScalingLayer(in_features=512, out_features=256, pwff_inner_features=1024)

In [388]:
# Chain both scaling layers and display the size of the final output.
scale2(scale(x)).shape

torch.Size([32, 40, 256])