# Imports

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import joblib
import time
import matplotlib.pyplot as plt
from flair.data import Sentence
from flair.data import Token
from flair.embeddings import WordEmbeddings
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
%matplotlib inline

device = torch.device("cuda:0")
dtype = torch.float

d_model = 300 # embedding size of fasttext models
batch_size = 128
seq_len_encoder = 32
seq_len_decoder = 36
output_lang = 'en' # select language

# Select translation scheme in -> out

embedding_out = None

if output_lang == 'en'
    embedding_out = WordEmbeddings('en')
else:
    embedding_out = WordEmbeddings('de)

In [2]:
class Language_DataSet(torch.utils.data.Dataset):
    
    def __init__(self, path):#D:\Transformer\dataset.data
        super().__init__()
        
        self.data = joblib.load(path)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        # return format: input_encoding, labels
        
        sample = self.data[index]
        print(sample[1].shape)
        
        item = None
        if output_lang == 'en':
            item = (sample[1],sample[2])
        else:
            item = (sample[0],sample[3])
            
        return sample

In [3]:
#dataset = Language_DataSet('D:\Transformer\dataset.data')
#dl = DataLoader(dataset, batch_size=64, shuffle=True)
#for eps_data in dl:
    #print(i)
    #print(eps_data[1])
    #break

# Transformer Implementation
## Multi-Head-Attention

In [4]:
class Multi_Head_Attention(nn.Module):
    
    def __init__(self, heads, seq_len, dimensions):
        super().__init__()
        
        self.heads = heads
        self.dimensions = dimensions
        self.seq_len = seq_len
        
        self.qkv = nn.Linear(d_model, dimensions * heads * 3)
        self.final_linear = nn.Linear(self.heads * self.dimensions, d_model)
    
    def forward(self, x, mask=None):
        # x shape: batch,seq_len,d_model
        batch_size, _, _ = x.shape
        
        # reshape for linear qkv layer
        x = torch.reshape(x, (batch_size*self.seq_len, d_model))
        
        # compute q,v,k for every head
        qkv = self.qkv(x) # (seq * batch, heads * 3 * dimensions)

        # reshape into (batch_size,seq_len,...)
        qkv = torch.reshape(qkv, (batch_size, self.seq_len, self.heads * 3 * self.dimensions))
        # split into heads and seperate q, k, v in different dims
        qkv = torch.reshape(qkv, (batch_size, self.seq_len, self.heads, 3, self.dimensions))
        
        # permute head to front for parallel processing
        qkv = qkv.permute(0,2,1,3,4)
        
        # extract q, k, v
        q = qkv[:,:,:,0,:]
        k = qkv[:,:,:,1,:]
        v = qkv[:,:,:,2,:]
        
        # fuse batch_size and head dim for parallel processing
        q = torch.reshape(q, (batch_size * self.heads, self.seq_len, self.dimensions))
        k = torch.reshape(k, (batch_size * self.heads, self.seq_len, self.dimensions))
        v = torch.reshape(v, (batch_size * self.heads, self.seq_len, self.dimensions))
        
        # transpose k
        k = torch.transpose(k, 1, 2)
        
        # multiply q and k
        qk = torch.bmm(q,k)
        # scale
        qk = qk / torch.sqrt(torch.tensor(self.dimensions).to(device).to(dtype))
        # optional masking
        if mask is not None:
            qk[mask == 1] = float('-inf')
        # softmax
        qk = F.softmax(qk, dim=2)     
        
        # multiply with v
        qkv = torch.bmm(qk, v)
        
        # reshape to cat heads
        qkv = torch.reshape(qkv, (batch_size, self.heads, self.seq_len, self.dimensions))
        # cat all heads
        qkv = qkv.permute(0,2,1,3)
        qkv = torch.reshape(qkv, (batch_size, self.seq_len, self.heads * self.dimensions))
        
        # reshape to multiply with final linear
        qkv = torch.reshape(qkv, (batch_size * self.seq_len, self.heads * self.dimensions))
        # multiply with final linear
        z = self.final_linear(qkv)
        
        # reshape to input format
        z = torch.reshape(z, (batch_size, self.seq_len, d_model))
        
        return z

## Encoder Cell

In [5]:
class Encoder_Cell(nn.Module):
    
    def __init__(self, heads, seq_len, attention_dimension, ff_inner=1024):
        super().__init__()
        self.seq_len = seq_len
        
        self.self_attention = Multi_Head_Attention(heads, seq_len, attention_dimension)
        self.layer_norm_1 = nn.LayerNorm([d_model])
        self.layer_norm_2 = nn.LayerNorm([d_model])
        
        ff_network = [
            nn.Linear(d_model, ff_inner),
            nn.ReLU(),
            nn.Linear(ff_inner, d_model),
        ]
        self.feed_forward_net = nn.Sequential(*ff_network)
        
    def forward(self, x):
        # x shape: batch,seq_len,d_model
        batch_size, _, _ = x.shape
        
        # self attention
        z = self.self_attention(x)
        
        # 1st residual
        residual_1 = x + z
        # 1st norm
        norm_1 = self.layer_norm_1(residual_1)
        
        # reshape norm for feed forward network
        ff_in = torch.reshape(norm_1, (batch_size*self.seq_len, d_model))
        # feed forward
        ff_out = self.feed_forward_net(ff_in)
        # reshape back
        ff_out = torch.reshape(ff_out, (batch_size, self.seq_len, d_model))
        
        # 2nd residual
        residual_2 = norm_1 + ff_out
        # 2nd norm
        norm_2 = self.layer_norm_1(residual_2)
        
        return norm_2

## Encoder

In [6]:
class Encoder(nn.Module):
    
    def __init__(self, cells, heads, seq_len, attention_dimensions):
        super().__init__()
        self.heads = heads
        self.seq_len = seq_len
        self.attention_dimensions = attention_dimensions
        
        # stacked encoder cells
        encoder_cells = [ Encoder_Cell(heads, seq_len, attention_dimensions).to(device) for i in range(cells)]
        self.encode = nn.Sequential(*encoder_cells)
        
        # key and value output of encoder
        self.kv = nn.Linear(d_model, attention_dimensions * heads * 2)
    
    def forward(self, x):
        # encoding shape: batch_size, seq_len, d_model
        encoding = self.encode(x) 
        
        # reshape to feed into linear kv layer
        encoding = torch.reshape(encoding, (batch_size * self.seq_len, d_model))
        
        # apply linear
        kv = self.kv(encoding)
        # reshape back
        kv = torch.reshape(kv, (batch_size, self.seq_len, self.attention_dimensions * self.heads * 2))
        
        # seperate k and v
        kv = torch.reshape(kv, (batch_size, self.seq_len, self.heads, 2, self.attention_dimensions))
        
        # permute head to front for parallel processing
        kv = kv.permute(0,2,1,3,4)
        
        # split k, v
        k = kv[:,:,:,0,:]
        v = kv[:,:,:,1,:]
        
        # fuse batch_size and head dim for parallel processing
        k = torch.reshape(k, (batch_size * self.heads, self.seq_len, self.attention_dimensions))
        v = torch.reshape(v, (batch_size * self.heads, self.seq_len, self.attention_dimensions))
        
        return k, v

## Encoder-Decoder Attention

In [7]:
class Encoder_Decoder_Attention(nn.Module):

    def __init__(self, heads, seq_len, attention_dimensions):
        super().__init__()
        self.heads = heads
        self.seq_len = seq_len
        self.attention_dimensions = attention_dimensions
        
        self.q = nn.Linear(d_model, attention_dimensions * heads).to(device)
        self.final_linear = nn.Linear(heads * attention_dimensions, d_model).to(device)
        
    def forward(self, x, encoder_k, encoder_v):
        # x shape: batch_size, seq_len, d_model
        # encoder k/v shape: batch_size*heads, seq_len, attention_dimensions
        
        # reshape for linear q layer
        x = torch.reshape(x, (batch_size*self.seq_len, d_model))
        
        # compute q for every head
        q = self.q(x) # (seq * batch, heads * attention_dimensions)

        # reshape into (batch_size,seq_len,...)
        q = torch.reshape(q, (batch_size, self.seq_len, self.heads * self.attention_dimensions))
        # split into heads 
        q = torch.reshape(q, (batch_size, self.seq_len, self.heads, self.attention_dimensions))
        
        # permute head to front for parallel processing
        q = q.permute(0,2,1,3)
        
        # fuse batch_size and head dim for parallel processing
        q = torch.reshape(q, (batch_size * self.heads, self.seq_len, self.attention_dimensions))
        
        # transpose k
        k = torch.transpose(encoder_k, 1, 2)
        
        # multiply q and k
        qk = torch.bmm(q,k)
        # scale
        qk = qk / torch.sqrt(torch.tensor(self.attention_dimensions).to(device).to(dtype))
        # softmax
        qk = F.softmax(qk, dim=2)     
        
        # multiply with v
        qkv = torch.bmm(qk, encoder_v)
        
        # reshape to cat heads
        qkv = torch.reshape(qkv, (batch_size, self.heads, self.seq_len, self.attention_dimensions))
        # cat all heads
        qkv = qkv.permute(0,2,1,3)
        qkv = torch.reshape(qkv, (batch_size, self.seq_len, self.heads * self.attention_dimensions))
        
        # reshape to multiply with final linear
        qkv = torch.reshape(qkv, (batch_size * self.seq_len, self.heads * self.attention_dimensions))
        # multiply with final linear
        z = self.final_linear(qkv)
        
        # reshape to input format
        z = torch.reshape(z, (batch_size, self.seq_len, d_model))
        
        return z     

## Decoder Cell

In [8]:
class Decoder_Cell(nn.Module):

    def __init__(self, heads, seq_len, attention_dimension, ff_inner=1024):
        super().__init__()
        self.seq_len = seq_len
        
        # construct decoder mask # MASK CORRECT?????
        a = np.triu(np.ones((seq_len,seq_len)), k=1) 
        mask = torch.unsqueeze(torch.tensor(a).to(device).long(),dim=0)
        self.mask = mask.repeat(batch_size*heads,1,1)
        
        self.self_attention = Multi_Head_Attention(heads, seq_len, attention_dimension).to(device)
        self.enc_dec_attention = Encoder_Decoder_Attention(heads, seq_len, attention_dimension).to(device)
        
        self.layer_norm_1 = nn.LayerNorm([d_model])
        self.layer_norm_2 = nn.LayerNorm([d_model])
        self.layer_norm_3 = nn.LayerNorm([d_model])
        
        ff_network = [
            nn.Linear(d_model, ff_inner),
            nn.ReLU(),
            nn.Linear(ff_inner, d_model),
        ]
        self.feed_forward_net = nn.Sequential(*ff_network)
        
    def forward(self, x, encoder_k, encoder_v):
        # x shape: batch,seq_len,d_model
        batch_size, _, _ = x.shape
        
        # self attention
        z_1 = self.self_attention(x, self.mask)
        
        # 1st residual
        residual_1 = x + z_1
        # 1st norm
        norm_1 = self.layer_norm_1(residual_1)
        
        # encoder-decoder attention
        z_2 = self.enc_dec_attention(norm_1, encoder_k, encoder_v)
        
        # 2nd residual
        residual_2 = norm_1 + z_2
        # 2nd norm
        norm_2 = self.layer_norm_2(residual_2)
        
        # reshape norm for feed forward network
        ff_in = torch.reshape(norm_2, (batch_size*self.seq_len, d_model))
        # feed forward
        ff_out = self.feed_forward_net(ff_in)
        # reshape back
        ff_out = torch.reshape(ff_out, (batch_size, self.seq_len, d_model))
        
        # 3rd residual
        residual_3 = norm_2 + ff_out
        # 3rd norm
        norm_3 = self.layer_norm_3(residual_3)
        
        return norm_3

In [9]:
class Decoder(nn.Module):
    
    def __init__(self, cells, heads, seq_len, attention_dimensions, vocab_size):
        super().__init__()
        self.heads = heads
        self.seq_len = seq_len
        self.attention_dimensions = attention_dimensions
        self.vocab_size = vocab_size
        
        # stacked encoder cells
        self.decoder_cells = [ Decoder_Cell(heads, seq_len, attention_dimensions).to(device) for i in range(cells)]
        
        # output layer and then softmax
        self.final_linear = nn.Linear(d_model, vocab_size)
        
    def forward(self, x, encoder_k, encoder_v):
        
        for decoder_cell in self.decoder_cells:
            x = decoder_cell(x,encoder_k, encoder_v)
            
        # reshape for linear
        x = torch.reshape(x, (batch_size*self.seq_len, d_model))
        
        # feed in laysr
        x = self.final_linear(x)
        
        # softmax over vocab_size
        softmax = F.softmax(x, dim=1)
        
        return softmax

## Transformer

In [10]:
class Transformer(nn.Module):
    
    def __init__(self, cells, heads, seq_len_enc, seq_len_dec, attention_dimensions, vocab_size):
        super().__init__()
        
        self.encoder = Encoder(cells, heads, seq_len_enc, attention_dimensions).to(device)
        self.decoder = Decoder(cells, heads, seq_len_dec, attention_dimensions, vocab_size).to(device)
    
    def forward(self, x_encoder, x_decoder):
        # x_in shape: batch_size, seq_len_in, d_model
        # x_out shape: batch_size, seq_len_out, d_model
        
        encoder_k, encoder_v = self.encoder(x_encoder)
        out = self.decoder(x_decoder, encoder_k, encoder_v)
        
        # maybe reshape? to (batch_size, seq_len_out, vocab_size)
        return out

decoder_in = torch.randn(batch_size,seq_len,d_model).to(device).to(dtype)
decoder = Decoder_Cell(4,64).to(device)
out = decoder(decoder_in)
out.shape

In [11]:
encoder_in = torch.randn(batch_size,seq_len_encoder,d_model).to(device).to(dtype)
decoder_in = torch.randn(batch_size,seq_len_decoder,d_model).to(device).to(dtype)

transformer = Transformer(3,4,seq_len_encoder,seq_len_decoder,64,40000).to(device)

In [23]:
start = time.time()
out = transformer(encoder_in,decoder_in)
end = time.time()
print(f'time: {end - start}')
out.shape
del out

time: 0.03393268585205078


In [13]:
print("weights:",sum(p.numel() for p in transformer.parameters() if p.requires_grad))

weights: 14969688
