# Imports

In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import joblib
import time
import matplotlib.pyplot as plt
from flair.data import Sentence
from flair.data import Token
from flair.embeddings import WordEmbeddings
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
%matplotlib inline

device = torch.device("cuda:0")
dtype = torch.float

# embedding size of fasttext models
d_model = 300 

batch_size = 128

# select language
language_in = 'de'
language_out = 'en' 

# DataSet

In [2]:
class Language_DataSet(torch.utils.data.Dataset):

    def __init__(self, path, language_in, language_out):#D:\Transformer\dataset.data
        super().__init__()
        self.language_in = language_in
        self.language_out = language_out

        # load dat and vocab
        self.data = joblib.load(path+f'{language_in}_to_{language_out}.data')
        self.vocab_in = joblib.load(path+f'vocab_{language_in}.data')
        self.vocab_out = joblib.load(path+f'vocab_{language_out}.data')

        # zero padding info
        self.seq_len_in = self.vocab_in["max_sentence_len"] + 1 # additional <SOS>/<EOS> token
        self.seq_len_out = self.vocab_out["max_sentence_len"] + 1 # additional <SOS>/<EOS> token

        # precompute padding
        self.precompute_padding = torch.zeros(1, d_model).to(device)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # return format: input_encoding, output_encoding, target

        x_in, x_out, target = self.data[index]
        x_in = torch.tensor(x_in).to(device)
        x_out = torch.tensor(x_out).to(device)
        # zero padding

        seq_len_in = x_in.shape[0]
        padding_length_in = self.seq_len_in - seq_len_in

        seq_len_out = x_out.shape[0]
        padding_length_out = self.seq_len_out - seq_len_out

        #padding_in = torch.zeros(padding_length_in, d_model).to(device)
        #padding_out = torch.zeros(padding_length_out, d_model).to(device)
        padding_in = self.precompute_padding.repeat([padding_length_in,1])
        padding_out = self.precompute_padding.repeat([padding_length_out,1])

        x_in = torch.cat([x_in, padding_in], dim=0)
        x_out = torch.cat([x_out, padding_out], dim=0)

        # padding target <EOS> as target for padding?
        target_len = len(target)
        padding_length = self.seq_len_out - target_len
        padding_index = target[-1] # <EOS>

        target = target + [ padding_index for i in range(padding_length)]
        target = torch.tensor(target).to(device)

        return x_in, x_out, target

# Test DataSet

path = "D:\Transformer\\"
dataset = Language_DataSet(path, language_in, language_out)
dl = DataLoader(dataset, batch_size=32,num_workers=4, shuffle=True)
if __name__ == '__main__':
    for eps_data in dl:
        x_in, x_out, target = eps_data
        print(x_in)
        print(x_out)
        print(target)

# Transformer Implementation
## Multi-Head-Attention

In [3]:
class Multi_Head_Attention(nn.Module):
    
    def __init__(self, heads, seq_len, dimensions):
        super().__init__()
        
        self.heads = heads
        self.dimensions = dimensions
        self.seq_len = seq_len
        
        self.qkv = nn.Linear(d_model, dimensions * heads * 3)
        self.final_linear = nn.Linear(self.heads * self.dimensions, d_model)
    
    def forward(self, x, mask=None):
        # x shape: batch,seq_len,d_model
        batch_size, _, _ = x.shape
        
        # reshape for linear qkv layer
        x = torch.reshape(x, (batch_size*self.seq_len, d_model))
        
        # compute q,v,k for every head
        qkv = self.qkv(x) # (seq * batch, heads * 3 * dimensions)

        # reshape into (batch_size,seq_len,...)
        qkv = torch.reshape(qkv, (batch_size, self.seq_len, self.heads * 3 * self.dimensions))
        # split into heads and seperate q, k, v in different dims
        qkv = torch.reshape(qkv, (batch_size, self.seq_len, self.heads, 3, self.dimensions))
        
        # permute head to front for parallel processing
        qkv = qkv.permute(0,2,1,3,4)
        
        # extract q, k, v
        q = qkv[:,:,:,0,:]
        k = qkv[:,:,:,1,:]
        v = qkv[:,:,:,2,:]
        
        # fuse batch_size and head dim for parallel processing
        q = torch.reshape(q, (batch_size * self.heads, self.seq_len, self.dimensions))
        k = torch.reshape(k, (batch_size * self.heads, self.seq_len, self.dimensions))
        v = torch.reshape(v, (batch_size * self.heads, self.seq_len, self.dimensions))
        
        # transpose k
        k = torch.transpose(k, 1, 2)
        
        # multiply q and k
        qk = torch.bmm(q,k)
        # scale
        qk = qk / torch.sqrt(torch.tensor(self.dimensions).to(device).to(dtype))
        # optional masking
        if mask is not None:
            qk[mask == 1] = float('-inf')
        # softmax
        qk = F.softmax(qk, dim=2)     
        
        # multiply with v
        qkv = torch.bmm(qk, v)
        
        # reshape to cat heads
        qkv = torch.reshape(qkv, (batch_size, self.heads, self.seq_len, self.dimensions))
        # cat all heads
        qkv = qkv.permute(0,2,1,3)
        qkv = torch.reshape(qkv, (batch_size, self.seq_len, self.heads * self.dimensions))
        
        # reshape to multiply with final linear
        qkv = torch.reshape(qkv, (batch_size * self.seq_len, self.heads * self.dimensions))
        # multiply with final linear
        z = self.final_linear(qkv)
        
        # reshape to input format
        z = torch.reshape(z, (batch_size, self.seq_len, d_model))
        
        return z

## Encoder Cell

In [4]:
class Encoder_Cell(nn.Module):
    
    def __init__(self, heads, seq_len, attention_dimension, ff_inner=1024):
        super().__init__()
        self.seq_len = seq_len
        
        self.self_attention = Multi_Head_Attention(heads, seq_len, attention_dimension)
        self.layer_norm_1 = nn.LayerNorm([seq_len,d_model])
        self.layer_norm_2 = nn.LayerNorm([seq_len,d_model])
        
        ff_network = [
            nn.Linear(d_model, ff_inner),
            nn.ReLU(),
            nn.Linear(ff_inner, d_model),
        ]
        self.feed_forward_net = nn.Sequential(*ff_network)
        
    def forward(self, x):
        # x shape: batch,seq_len,d_model
        batch_size, _, _ = x.shape
        
        # self attention
        z = self.self_attention(x)
        
        # 1st residual
        residual_1 = x + z
        # 1st norm
        norm_1 = self.layer_norm_1(residual_1)
        
        # reshape norm for feed forward network
        ff_in = torch.reshape(norm_1, (batch_size*self.seq_len, d_model))
        # feed forward
        ff_out = self.feed_forward_net(ff_in)
        # reshape back
        ff_out = torch.reshape(ff_out, (batch_size, self.seq_len, d_model))
        
        # 2nd residual
        residual_2 = norm_1 + ff_out
        # 2nd norm
        norm_2 = self.layer_norm_1(residual_2)
        
        return norm_2

## Encoder

In [5]:
class Encoder(nn.Module):
    
    def __init__(self, cells, heads, seq_len, attention_dimensions):
        super().__init__()
        self.heads = heads
        self.seq_len = seq_len
        self.attention_dimensions = attention_dimensions
        
        # stacked encoder cells
        encoder_cells = [ Encoder_Cell(heads, seq_len, attention_dimensions).to(device) for i in range(cells)]
        self.encode = nn.Sequential(*encoder_cells)
        
        # key and value output of encoder
        self.kv = nn.Linear(d_model, attention_dimensions * heads * 2)
    
    def forward(self, x):
        # encoding shape: batch_size, seq_len, d_model
        encoding = self.encode(x) 
        
        # reshape to feed into linear kv layer
        encoding = torch.reshape(encoding, (batch_size * self.seq_len, d_model))
        
        # apply linear
        kv = self.kv(encoding)
        # reshape back
        kv = torch.reshape(kv, (batch_size, self.seq_len, self.attention_dimensions * self.heads * 2))
        
        # seperate k and v
        kv = torch.reshape(kv, (batch_size, self.seq_len, self.heads, 2, self.attention_dimensions))
        
        # permute head to front for parallel processing
        kv = kv.permute(0,2,1,3,4)
        
        # split k, v
        k = kv[:,:,:,0,:]
        v = kv[:,:,:,1,:]
        
        # fuse batch_size and head dim for parallel processing
        k = torch.reshape(k, (batch_size * self.heads, self.seq_len, self.attention_dimensions))
        v = torch.reshape(v, (batch_size * self.heads, self.seq_len, self.attention_dimensions))
        
        return k, v

## Encoder-Decoder Attention

In [6]:
class Encoder_Decoder_Attention(nn.Module):

    def __init__(self, heads, seq_len, attention_dimensions):
        super().__init__()
        self.heads = heads
        self.seq_len = seq_len
        self.attention_dimensions = attention_dimensions
        
        self.q = nn.Linear(d_model, attention_dimensions * heads).to(device)
        self.final_linear = nn.Linear(heads * attention_dimensions, d_model).to(device)
        
    def forward(self, x, encoder_k, encoder_v):
        # x shape: batch_size, seq_len, d_model
        # encoder k/v shape: batch_size*heads, seq_len, attention_dimensions
        
        # reshape for linear q layer
        x = torch.reshape(x, (batch_size*self.seq_len, d_model))
        
        # compute q for every head
        q = self.q(x) # (seq * batch, heads * attention_dimensions)

        # reshape into (batch_size,seq_len,...)
        q = torch.reshape(q, (batch_size, self.seq_len, self.heads * self.attention_dimensions))
        # split into heads 
        q = torch.reshape(q, (batch_size, self.seq_len, self.heads, self.attention_dimensions))
        
        # permute head to front for parallel processing
        q = q.permute(0,2,1,3)
        
        # fuse batch_size and head dim for parallel processing
        q = torch.reshape(q, (batch_size * self.heads, self.seq_len, self.attention_dimensions))
        
        # transpose k
        k = torch.transpose(encoder_k, 1, 2)
        
        # multiply q and k
        qk = torch.bmm(q,k)
        # scale
        qk = qk / torch.sqrt(torch.tensor(self.attention_dimensions).to(device).to(dtype))
        # softmax
        qk = F.softmax(qk, dim=2)     
        
        # multiply with v
        qkv = torch.bmm(qk, encoder_v)
        
        # reshape to cat heads
        qkv = torch.reshape(qkv, (batch_size, self.heads, self.seq_len, self.attention_dimensions))
        # cat all heads
        qkv = qkv.permute(0,2,1,3)
        qkv = torch.reshape(qkv, (batch_size, self.seq_len, self.heads * self.attention_dimensions))
        
        # reshape to multiply with final linear
        qkv = torch.reshape(qkv, (batch_size * self.seq_len, self.heads * self.attention_dimensions))
        # multiply with final linear
        z = self.final_linear(qkv)
        
        # reshape to input format
        z = torch.reshape(z, (batch_size, self.seq_len, d_model))
        
        return z     

## Decoder Cell

In [7]:
class Decoder_Cell(nn.Module):

    def __init__(self, heads, seq_len, attention_dimension, ff_inner=1024):
        super().__init__()
        self.seq_len = seq_len
        
        # construct decoder mask # MASK CORRECT?????
        a = np.triu(np.ones((seq_len,seq_len)), k=1) 
        mask = torch.unsqueeze(torch.tensor(a).to(device).long(),dim=0)
        self.mask = mask.repeat(batch_size*heads,1,1)
        
        self.self_attention = Multi_Head_Attention(heads, seq_len, attention_dimension).to(device)
        self.enc_dec_attention = Encoder_Decoder_Attention(heads, seq_len, attention_dimension).to(device)
        
        self.layer_norm_1 = nn.LayerNorm([seq_len,d_model])
        self.layer_norm_2 = nn.LayerNorm([seq_len,d_model])
        self.layer_norm_3 = nn.LayerNorm([seq_len,d_model])
        
        ff_network = [
            nn.Linear(d_model, ff_inner),
            nn.ReLU(),
            nn.Linear(ff_inner, d_model),
        ]
        self.feed_forward_net = nn.Sequential(*ff_network)
        
    def forward(self, x, encoder_k, encoder_v):
        # x shape: batch,seq_len,d_model
        batch_size, _, _ = x.shape
        
        # self attention
        z_1 = self.self_attention(x, self.mask)
        
        # 1st residual
        residual_1 = x + z_1
        # 1st norm
        norm_1 = self.layer_norm_1(residual_1)
        
        # encoder-decoder attention
        z_2 = self.enc_dec_attention(norm_1, encoder_k, encoder_v)
        
        # 2nd residual
        residual_2 = norm_1 + z_2
        # 2nd norm
        norm_2 = self.layer_norm_2(residual_2)
        
        # reshape norm for feed forward network
        ff_in = torch.reshape(norm_2, (batch_size*self.seq_len, d_model))
        # feed forward
        ff_out = self.feed_forward_net(ff_in)
        # reshape back
        ff_out = torch.reshape(ff_out, (batch_size, self.seq_len, d_model))
        
        # 3rd residual
        residual_3 = norm_2 + ff_out
        # 3rd norm
        norm_3 = self.layer_norm_3(residual_3)
        
        return norm_3

In [8]:
class Decoder(nn.Module):
    
    def __init__(self, cells, heads, seq_len, attention_dimensions, vocab_size):
        super().__init__()
        self.heads = heads
        self.seq_len = seq_len
        self.attention_dimensions = attention_dimensions
        self.vocab_size = vocab_size
        
        # stacked encoder cells
        self.decoder_cells = [ Decoder_Cell(heads, seq_len, attention_dimensions).to(device) for i in range(cells)]
        
        # output layer and then softmax
        self.final_linear = nn.Linear(d_model, vocab_size)
        
    def forward(self, x, encoder_k, encoder_v):
        
        for decoder_cell in self.decoder_cells:
            x = decoder_cell(x,encoder_k, encoder_v)
            
        # reshape for linear
        x = torch.reshape(x, (batch_size*self.seq_len, d_model))
        
        # feed in final layer
        x = self.final_linear(x)
        
        # softmax over vocab_size
        #softmax = F.softmax(x, dim=1) maybe do, when input to cross entropy loss
        #return softmax
        
        return x

## Transformer

In [9]:
class Transformer(nn.Module):
    
    def __init__(self, cells, heads, seq_len_enc, seq_len_dec, attention_dimensions, vocab_size):
        super().__init__()
        
        self.encoder = Encoder(cells, heads, seq_len_enc, attention_dimensions).to(device)
        self.decoder = Decoder(cells, heads, seq_len_dec, attention_dimensions, vocab_size).to(device)
    
    def forward(self, x_encoder, x_decoder):
        # x_in shape: batch_size, seq_len_in, d_model
        # x_out shape: batch_size, seq_len_out, d_model
        
        encoder_k, encoder_v = self.encoder(x_encoder)
        out = self.decoder(x_decoder, encoder_k, encoder_v)
        
        # maybe reshape? to (batch_size, seq_len_out, vocab_size)
        return out

encoder_in = torch.randn(batch_size,seq_len_encoder,d_model).to(device).to(dtype)
decoder_in = torch.randn(batch_size,seq_len_decoder,d_model).to(device).to(dtype)
transformer = Transformer(3,4,seq_len_encoder,seq_len_decoder,64,40000).to(device)

start = time.time()
out = transformer(encoder_in,decoder_in)
end = time.time()
print(f'time: {end - start}')
out.shape
del out

print("weights:",sum(p.numel() for p in transformer.parameters() if p.requires_grad))

In [10]:
path = "D:\Transformer\\"
vocab_in = joblib.load(path+f'vocab_{language_in}.data')
vocab_out = joblib.load(path+f'vocab_{language_out}.data')

num_cells = 3
cell_embedding_size = 64
num_heads = 4
seq_len_encoder = vocab_in["max_sentence_len"] + 1
seq_len_decoder = vocab_out["max_sentence_len"] + 1
vocab_size = vocab_out["vocab_size"] + 1 # + <EOS>

transformer = Transformer(num_cells,num_heads,
                          seq_len_encoder,seq_len_decoder,
                          cell_embedding_size,vocab_size).to(device)

dataset = Language_DataSet(path, language_in, language_out)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0, shuffle=True,drop_last=True)

cross_entropy_loss = nn.CrossEntropyLoss()
lr = 5e-4
optimizer = torch.optim.Adam(transformer.parameters(), lr=lr)

for epoch in range(25):
    print("epoch:",epoch)
    for i, batch in enumerate(dataloader):
        start = time.time()
        optimizer.zero_grad()
        
        # data
        x_in, x_out, target = batch
        
        # reshape target for cross entropy loss
        target = torch.flatten(target)
        
        # run transformer
        out = transformer(x_in, x_out)
        #print(out.shape)
        #print(target.shape)
        #print(torch.flatten(target).shape)
        
        loss = cross_entropy_loss(out, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(transformer.parameters(), 0.5)
        optimizer.step()
        end = time.time()
        
        if i%25 == 0:
            print(f'i:{i} | loss: {loss} time: {end - start}')
        

epoch: 0




i:0 | loss: 9.823493003845215 time: 0.4517970085144043
i:25 | loss: 0.815750002861023 time: 0.18344879150390625
i:50 | loss: 0.5972856283187866 time: 0.18168139457702637
i:75 | loss: 0.5429232716560364 time: 0.16954612731933594
i:100 | loss: 0.5321913361549377 time: 0.2876427173614502
epoch: 1
i:0 | loss: 0.4799748957157135 time: 0.18187642097473145
i:25 | loss: 0.4780837595462799 time: 0.1825394630432129
i:50 | loss: 0.4553200602531433 time: 0.17553114891052246
i:75 | loss: 0.46833452582359314 time: 0.18348145484924316
i:100 | loss: 0.4419066607952118 time: 0.1820204257965088
epoch: 2
i:0 | loss: 0.43334677815437317 time: 0.18553400039672852
i:25 | loss: 0.4146915674209595 time: 0.17900419235229492
i:50 | loss: 0.42326411604881287 time: 0.17054486274719238
i:75 | loss: 0.3904164135456085 time: 0.18601083755493164
i:100 | loss: 0.3928542137145996 time: 0.18249058723449707
epoch: 3
i:0 | loss: 0.3707960247993469 time: 0.18550419807434082
i:25 | loss: 0.3635280132293701 time: 0.180488348

# Show token predictions while training 
(not inference atm and only small dataset)

In [84]:
for j in range(50):
    for i, batch in enumerate(dataloader):
        # data
        x_in, x_out, _ = batch

        # run transformer
        out = transformer(x_in, x_out)
        #print(out.shape)
        # reshape to batch_size, seq, vocab
        out = torch.reshape(out, (batch_size, seq_len_decoder,vocab_size))
        #print(out.shape)
        out = F.softmax(out, dim=2) 
        s = out[0]
        max_index = torch.argmax(s, dim=1).cpu().numpy()
        #print(max_index)

        s = ""
        for index in max_index:
            token = vocab_out[index]
            if token != "<EOS>":
                s +=  token+" "
        print(s)
        break



Where 's Boston ? 
Tom is a psycho . 
Happy birthday ! 
Follow me . 
Do it your way . 
Why do you work ? 
They kissed . 
I 'm biased . 
Stay calm . 
I 'm so fat . 
Tom has changed . 
Tom felt tired . 
Tom teaches . 
I love jokes . 
Tom drinks beer . 
Do n't hurt him . 
Examine this . 
We need them . 
You were brave . 
Is dinner ready ? 
I do n't want it . 
Stay for supper . 
Who succeeded ? 
Tom is credible . 
I know Tom . 
Who are they ? 
We respect them . 
Just swim . 
Tom enlisted . 
Open the doors . 
It is n't new . 
They saw me . 
Tom loves pasta . 
We 're very poor . 
We 're rich . 
I felt left out . 
When do you going ? 
How about you ? 
Come back in . 
I was at home . 
Tom felt cold . 
Kill them . 
How do I I it ? 
Tom is glad . 
Go lost . 
Breathe deeply . 
I 'll take Tom . 
He ca n't do it . 
I ca n't do it . 
I 'll pay . 


torch.save(transformer,f"transformer.pt")