# Imports

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import joblib
import time
import matplotlib.pyplot as plt
from flair.data import Sentence
from flair.embeddings import WordEmbeddings
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
%matplotlib inline

device = torch.device("cuda:0")
dtype = torch.float

d_model = 300 # embedding size of fasttext models
batch_size = 64
seq_len = 32
output_lang = 'en' # select language

# Select translation scheme in -> out

embedding_out = None

if output_lang == 'en'
    embedding_out = WordEmbeddings('en')
else:
    embedding_out = WordEmbeddings('de)

In [2]:
class Language_DataSet(torch.utils.data.Dataset):
    
    def __init__(self, path):#D:\Transformer\dataset.data
        super().__init__()
        
        self.data = joblib.load(path)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        # return format: input_encoding, labels
        
        sample = self.data[index]
        print(sample[1].shape)
        
        item = None
        if output_lang == 'en':
            item = (sample[1],sample[2])
        else:
            item = (sample[0],sample[3])
            
        return sample

In [3]:
#dataset = Language_DataSet('D:\Transformer\dataset.data')
#dl = DataLoader(dataset, batch_size=64, shuffle=True)
#for eps_data in dl:
    #print(i)
    #print(eps_data[1])
    #break

# Transformer Implementation
## Multi-Head-Attention

In [4]:
class Multi_Head_Attention(nn.Module):
    
    def __init__(self, heads, dimensions):
        super().__init__()
        
        self.heads = heads
        self.dimensions = dimensions
        
        self.qkv = nn.Linear(d_model, dimensions * heads * 3)
        self.final_linear = nn.Linear(self.heads * self.dimensions, d_model)
    
    def forward(self, x, mask=None):
        # x shape: batch,seq_len,d_model
        batch_size, _, _ = x.shape
        
        # reshape for linear qkv layer
        x = torch.reshape(x, (batch_size*seq_len, d_model))
        
        # compute q,v,k for every head
        qkv = self.qkv(x) # (seq * batch, heads * 3 * dimensions)

        # reshape into (batch_size,seq_len,...)
        qkv = torch.reshape(qkv, (batch_size, seq_len, self.heads * 3 * self.dimensions))
        # split into heads and seperate q, k, v in different dims
        qkv = torch.reshape(qkv, (batch_size, seq_len, self.heads, 3, self.dimensions))
        
        # permute head to front for parallel processing
        qkv = qkv.permute(0,2,1,3,4)
        
        # extract q, k, v
        q = qkv[:,:,:,0,:]
        k = qkv[:,:,:,1,:]
        v = qkv[:,:,:,2,:]
        
        # fuse batch_size and head dim for parallel processing
        q = torch.reshape(q, (batch_size * self.heads, seq_len, self.dimensions))
        k = torch.reshape(k, (batch_size * self.heads, seq_len, self.dimensions))
        v = torch.reshape(v, (batch_size * self.heads, seq_len, self.dimensions))
        
        # transpose k
        k = torch.transpose(k, 1, 2)
        
        # multiply q and k
        qk = torch.bmm(q,k)
        # scale
        qk = qk / torch.sqrt(torch.tensor(self.dimensions).to(device).to(dtype))
        # optional masking
        if mask is not None:
            qk[mask == 1] = float('-inf')
        # softmax
        qk = F.softmax(qk, dim=2)
        return qk        
    
        # multiply with v
        qkv = torch.bmm(qk, v)
        
        # reshape to cat heads
        qkv = torch.reshape(qkv, (batch_size, self.heads, seq_len, self.dimensions))
        # cat all heads
        qkv = qkv.permute(0,2,1,3)
        qkv = torch.reshape(qkv, (batch_size, seq_len, self.heads * self.dimensions))
        
        # reshape to multiply with final linear
        qkv = torch.reshape(qkv, (batch_size * seq_len, self.heads * self.dimensions))
        # multiply with final linear
        z = self.final_linear(qkv)
        
        # reshape to input format
        z = torch.reshape(z, (batch_size, seq_len, d_model))
        
        return z

## Encoder Cell

In [5]:
class Encoder_Cell(nn.Module):
    
    def __init__(self, heads, attention_dimension, ff_inner=1024):
        super().__init__()
        
        self.self_attention = Multi_Head_Attention(heads, attention_dimension)
        self.layer_norm_1 = nn.LayerNorm([d_model])
        self.layer_norm_2 = nn.LayerNorm([d_model])
        
        ff_network = [
            nn.Linear(d_model, ff_inner),
            nn.ReLU(),
            nn.Linear(ff_inner, d_model),
        ]
        self.feed_forward_net = nn.Sequential(*ff_network)
        
    def forward(self, x):
        # x shape: batch,seq_len,d_model
        batch_size, _, _ = x.shape
        
        # self attention
        z = self.self_attention(x)
        
        # 1st residual
        residual_1 = x + z
        # 1st norm
        norm_1 = self.layer_norm_1(residual_1)
        
        # reshape norm for feed forward network
        ff_in = torch.reshape(norm_1, (batch_size*seq_len, d_model))
        # feed forward
        ff_out = self.feed_forward_net(ff_in)
        # reshape back
        ff_out = torch.reshape(ff_out, (batch_size, seq_len, d_model))
        
        # 2nd residual
        residual_2 = norm_1 + ff_out
        # 2nd norm
        norm_2 = self.layer_norm_1(residual_2)
        
        return norm_2

## Encoder

In [6]:
class Encoder(nn.Module):
    
    def __init__(self, cells, heads, attention_dimensions):
        super().__init__()
        
        encoder_cells = [ Encoder_Cell(heads, attention_dimensions) for i in range(cells)]
        self.encode = nn.Sequential(*encoder_cells)
    
    def forward(self, x):
        encoding = self.encode(x)
        return encoding

## Encoder-Decoder Attention

In [None]:
class Encoder_Decoder_Attention(nn.Module):

    def __init__(self, heads, attention_dimension):
        super().__init__()
        
        self.q = nn.Linear(d_model, dimensions * heads)
        self.final_linear = nn.Linear(self.heads * self.dimensions, d_model) 
        
    def forward(self, x, encoder_k, encoder_v):
        # TODO
        

## Decoder Cell

In [7]:
class Decoder_Cell(nn.Module):

    def __init__(self, heads, attention_dimension, , ff_inner=1024):
        super().__init__()
        
        # construct decoder mask # MASK CORRECT?????
        a = np.triu(np.ones((seq_len,seq_len)), k=1) 
        mask = torch.unsqueeze(torch.tensor(a).to(device).long(),dim=0)
        self.mask = mask.repeat(batch_size*heads,1,1)
        
        self.self_attention = Multi_Head_Attention(heads, attention_dimension)
        self.enc_dec_attention = Encoder_Decoder_Attention(heads, attention_dimension)
        
        self.layer_norm_1 = nn.LayerNorm([d_model])
        self.layer_norm_2 = nn.LayerNorm([d_model])
        self.layer_norm_3 = nn.LayerNorm([d_model])
        
        ff_network = [
            nn.Linear(d_model, ff_inner),
            nn.ReLU(),
            nn.Linear(ff_inner, d_model),
        ]
        self.feed_forward_net = nn.Sequential(*ff_network)
        
    def forward(self, x, encoder_k, encoder_v):
        # x shape: batch,seq_len,d_model
        batch_size, _, _ = x.shape
        
        # self attention
        z_1 = self.self_attention(x)
        
        # 1st residual
        residual_1 = x + z_1
        # 1st norm
        norm_1 = self.layer_norm_1(residual_1)
        
        # encoder-decoder attention
        z_2 = self.enc_dec_attention(norm_1, encoder_k, encoder_v)
        
        # 2nd residual
        residual_2 = norm_1 + z_2
        # 2nd norm
        norm_2 = self.layer_norm_2(residual_2)
        
        # reshape norm for feed forward network
        ff_in = torch.reshape(norm_2, (batch_size*seq_len, d_model))
        # feed forward
        ff_out = self.feed_forward_net(ff_in)
        # reshape back
        ff_out = torch.reshape(ff_out, (batch_size, seq_len, d_model))
        
        # 3rd residual
        residual_3 = norm_2 + ff_out
        # 3rd norm
        norm_3 = self.layer_norm_3(residual_3)
        
        return norm_3

In [8]:
decoder_in = torch.randn(batch_size,seq_len,d_model).to(device).to(dtype)
decoder = Decoder_Cell(4,64).to(device)
out = decoder(decoder_in)
out.shape

torch.Size([256, 32, 32])

In [9]:
out

tensor([[[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.5607, 0.4393, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.3693, 0.2933, 0.3375,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0496, 0.0253, 0.0292,  ..., 0.0292, 0.0000, 0.0000],
         [0.0625, 0.0624, 0.0211,  ..., 0.0262, 0.0349, 0.0000],
         [0.0228, 0.0234, 0.0183,  ..., 0.0464, 0.0207, 0.0172]],

        [[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.6350, 0.3650, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.3662, 0.2677, 0.3661,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0316, 0.0240, 0.0366,  ..., 0.0524, 0.0000, 0.0000],
         [0.0265, 0.0173, 0.0492,  ..., 0.0335, 0.0365, 0.0000],
         [0.0173, 0.0188, 0.0330,  ..., 0.0229, 0.0258, 0.0283]],

        [[1.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.3467, 0.6533, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.4725, 0.2850, 0.2426,  ..., 0.0000, 0.0000, 0.

encoder_in = torch.randn(64,seq_len,d_model).to(device).to(dtype)
encoder = Encoder(4,4,64).to(device)
print("weights:",sum(p.numel() for p in encoder.parameters() if p.requires_grad))
start = time.time()
out = encoder(encoder_in)
end = time.time()
print(f'time: {end - start}')
print(out.shape)
print(out)
del out

a = np.array([[1,0],[1,1]])
b = a * np.array([[-np.inf,-np.inf],[-np.inf,-np.inf]])
b[b == float('nan')] = 0
torch.tensor(b)

In [10]:
a = np.ones((5,5))
b = np.triu(a, k=1) 
#b[b==1] = 0
torch.tensor(b)

tensor([[0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0.]], dtype=torch.float64)

In [12]:
b = np.triu(a, k=1)
b

array([[0., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1.],
       [0., 0., 0., 1., 1.],
       [0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0.]])