<a href="https://colab.research.google.com/github/Kshitij-Ambilduke/NLP/blob/master/Attention_is_all_you_need_2_0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
!pip install torchtext==0.6.0
import torch
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator
import torch.nn as nn
import torch.nn.functional as F
import spacy
import time
from copy import deepcopy as dcy



In [20]:
!python -m spacy download en --quiet
!python -m spacy download de --quiet

[38;5;2m✔ Download and installation successful[0m
You can now load the model via spacy.load('en_core_web_sm')
[38;5;2m✔ Linking successful[0m
/usr/local/lib/python3.7/dist-packages/en_core_web_sm -->
/usr/local/lib/python3.7/dist-packages/spacy/data/en
You can now load the model via spacy.load('en')
[38;5;2m✔ Download and installation successful[0m
You can now load the model via spacy.load('de_core_news_sm')
[38;5;2m✔ Linking successful[0m
/usr/local/lib/python3.7/dist-packages/de_core_news_sm -->
/usr/local/lib/python3.7/dist-packages/spacy/data/de
You can now load the model via spacy.load('de')


In [21]:
spacy_german = spacy.load("de")
spacy_english = spacy.load("en")
device = 'cuda'

In [22]:
def en_tokenizer(sen):
    tokens = []
    for token in spacy_english.tokenizer(sen):
        tokens.append(token.text)
    return tokens

def de_tokenizer(sen):
    tokens = []
    for token in spacy_german.tokenizer(sen):
        tokens.append(token.text)
    return tokens

In [23]:
# SOURCE_field = Field(eos_token='<src_eos>', init_token='<src_sos>',lower=True , tokenize = de_tokenizer)
# TARGET_field = Field(eos_token='<trg_eos>', init_token='<trg_sos>',lower=True , tokenize = en_tokenizer)

SOURCE_Field = Field(eos_token = '<src_eos>', init_token = '<src_sos>', lower = True, tokenize = de_tokenizer, batch_first = True )
TARGET_Field = Field(eos_token = '<trg_eos>', init_token = '<trg_sos>', lower = True, tokenize = en_tokenizer, batch_first = True )

train_data, valid_data, test_data = Multi30k.splits(exts=(".de",".en"), fields=(SOURCE_Field, TARGET_Field)) 

In [24]:
print("train length = "+str(len(train_data)))
print("test length = "+str(len(test_data)))
print("validation length = "+str(len(valid_data)))

train length = 29000
test length = 1000
validation length = 1014


In [25]:
print(vars(train_data[6])['src'])
print(vars(train_data[6])['trg'])

['ein', 'mann', 'lächelt', 'einen', 'ausgestopften', 'löwen', 'an', '.']
['a', 'man', 'is', 'smiling', 'at', 'a', 'stuffed', 'lion']


In [26]:
SOURCE_Field.build_vocab(train_data, min_freq=3)
TARGET_Field.build_vocab(train_data, min_freq=3)

In [27]:
BATCH_SIZE = 128
# train_iterator, valid_iterator, test_iterator = BucketIterator.splits((train_data, valid_data, test_data),batch_size=BATCH_SIZE, device=device)
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data), 
    batch_size = BATCH_SIZE, device = device)

In [28]:
class Encoder(nn.Module):
  def __init__(self, vocab_size, enc_hidden_dim, query_dim, value_dim, src_pad_idx, start=False, final=False, max_len=100):
    super().__init__()

    self.src_padding_idx = src_pad_idx
    self.start = start
    self.final = final

    # if start:
    self.token_embed_enc = nn.Embedding(vocab_size, enc_hidden_dim)
    self.pos_embed_enc = nn.Embedding(max_len, enc_hidden_dim)

    #LAYER1
    self.x2Q1_enc = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V1_enc = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K1_enc = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q2_enc = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V2_enc = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K2_enc = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q3_enc = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V3_enc = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K3_enc = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q4_enc = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V4_enc = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K4_enc = nn.Linear(enc_hidden_dim, query_dim)
    
    self.combine_enc = nn.Linear(4*value_dim, enc_hidden_dim)

    self.norm1_enc = nn.LayerNorm(enc_hidden_dim)
    self.dp = nn.Dropout(0.2)

    self.fc1_enc = nn.Linear(enc_hidden_dim, 2*enc_hidden_dim)
    self.fc2_enc = nn.Linear(2*enc_hidden_dim, enc_hidden_dim)

    self.norm2_enc = nn.LayerNorm(enc_hidden_dim)
    self.dp2_enc = nn.Dropout(0.2)

    #LAYER2
    self.x2Q1_2_enc = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V1_2_enc = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K1_2_enc = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q2_2_enc = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V2_2_enc = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K2_2_enc = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q3_2_enc = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V3_2_enc = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K3_2_enc = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q4_2_enc = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V4_2_enc = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K4_2_enc = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q4_2_enc = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V4_2_enc = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K4_2_enc = nn.Linear(enc_hidden_dim, query_dim)

    
    self.combine_enc_2 = nn.Linear(4*value_dim, enc_hidden_dim)

    self.norm1_enc_2 = nn.LayerNorm(enc_hidden_dim)
    self.dp1_enc_2 = nn.Dropout(0.2)

    self.fc1_2_enc = nn.Linear(enc_hidden_dim, 2*enc_hidden_dim)
    self.fc2_2_enc = nn.Linear(2*enc_hidden_dim, enc_hidden_dim)

    self.norm2_enc_2 = nn.LayerNorm(enc_hidden_dim)
    self.dp2_enc_2 = nn.Dropout(0.2)

    #LAYER3
    self.x2Q1_3_enc = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V1_3_enc = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K1_3_enc = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q2_3_enc = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V2_3_enc = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K2_3_enc = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q3_3_enc = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V3_3_enc = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K3_3_enc = nn.Linear(enc_hidden_dim, query_dim)
    
    self.combine_enc_3 = nn.Linear(3*value_dim, enc_hidden_dim)

    self.norm1_enc_3 = nn.LayerNorm(enc_hidden_dim)
    self.dp1_enc_2 = nn.Dropout(0.2)

    self.fc1_3_enc = nn.Linear(enc_hidden_dim, 2*enc_hidden_dim)
    self.fc2_3_enc = nn.Linear(2*enc_hidden_dim, enc_hidden_dim)

    self.norm2_enc_3 = nn.LayerNorm(enc_hidden_dim)
    self.dp2_enc_3 = nn.Dropout(0.2)

    # if self.final:
    #   FOR ENCODER FINAL STATE TO KEY AND VALUE
    # self.enc_op2V = nn.Linear(enc_hidden_dim, enc_hidden_dim)
    # self.enc_op2K = nn.Linear(enc_hidden_dim, enc_hidden_dim)
    

  def forward(self, x):
    # if self.start:
    mask = self.make_src_mask(x)

    # encop2V = None
    # encop2K = None
    
    #x = [ batch, seq_len ]

    device='cuda'
    x = x.to(device)

    seq_len = x.shape[1]
    batch_size = x.shape[0]

    # if self.start:
    temp = []
    for i in range(seq_len):
      temp.append(i)
    pos = []
    for i in range(batch_size):
      pos.append(temp)
    pos = torch.tensor(pos).to(device)

    token_embedding = self.token_embed_enc(x).to(device) #token_embedding = [batch, seq_len, embed_dim]
    position_embedding = self.pos_embed_enc(pos).to(device)
  
    x = position_embedding + token_embedding

    q1 = self.x2Q1_enc(x) #q = [batch, src_seq_len, query_dim]
    q2 = self.x2Q2_enc(x)
    q3 = self.x2Q3_enc(x)
    q4 = self.x2Q4_enc(x)

    v1 = self.x2V1_enc(x) #v = [batch, src_seq_len, value_dim]
    v2 = self.x2V2_enc(x)
    v3 = self.x2V3_enc(x)
    v4 = self.x2V4_enc(x)
  
    k1 = self.x2K1_enc(x) #k = [batch, src_seq_len, query_dim]
    k2 = self.x2K2_enc(x)
    k3 = self.x2K3_enc(x)
    k4 = self.x2K4_enc(x)

    attention1 = torch.matmul(q1, k1.permute(0,2,1)) #attention1 = [batch, seq_len, seq_len]
    attention2 = torch.matmul(q2, k2.permute(0,2,1)) #attention2 = [batch, seq_len, seq_len]
    attention3 = torch.matmul(q3, k3.permute(0,2,1)) #attention3 = [batch, seq_len, seq_len]
    attention4 = torch.matmul(q4, k4.permute(0,2,1)) #attention3 = [batch, seq_len, seq_len]


    attention1/=8
    attention2/=8
    attention3/=8
    attention4/=8

    # if self.start:
    # print(mask.shape)
    attention1 = attention1.masked_fill(mask==False, -1e10)
    attention2 = attention2.masked_fill(mask==False, -1e10)
    attention3 = attention3.masked_fill(mask==False, -1e10)
    attention4 = attention4.masked_fill(mask==False, -1e10)

    attention1 = torch.softmax(attention1, dim=2)
    attention2 = torch.softmax(attention2, dim=2)
    attention3 = torch.softmax(attention3, dim=2)
    attention4 = torch.softmax(attention4, dim=2)

    weigthed1 = torch.matmul(attention1, v1) #weigthed = [batch, seq, value_dim]
    weigthed2 = torch.matmul(attention2, v2)
    weigthed3 = torch.matmul(attention3, v3)
    weigthed4 = torch.matmul(attention4, v4)

    weighted = torch.cat((weigthed1, weigthed2, weigthed3, weigthed4),dim=-1)
    combined = self.combine_enc(weighted)

    combined = self.norm1_enc(self.dp(combined)+x)

    op1 = self.fc1_enc(combined)
    op1 = F.relu(op1)
    op1 = self.fc2_enc(op1)
    op1 = F.relu(op1)

    op1 = self.norm2_enc(self.dp(op1)+combined)
    x=op1.clone().detach() 

    #LAYER2
    q1 = self.x2Q1_2_enc(op1) #q = [batch, src_seq_len, query_dim]
    q2 = self.x2Q2_2_enc(op1)
    q3 = self.x2Q3_2_enc(op1)
    q4 = self.x2Q4_2_enc(op1)


    v1 = self.x2V1_2_enc(op1) #v = [batch, src_seq_len, value_dim]
    v2 = self.x2V2_2_enc(op1)
    v3 = self.x2V3_2_enc(op1)
    v4 = self.x2V4_2_enc(op1)
  
    k1 = self.x2K1_2_enc(op1) #k = [batch, src_seq_len, query_dim]
    k2 = self.x2K2_2_enc(op1)
    k3 = self.x2K3_2_enc(op1)
    k4 = self.x2K4_2_enc(op1)

    attention1 = torch.matmul(q1, k1.permute(0,2,1)) #attention1 = [batch, seq_len, seq_len]
    attention2 = torch.matmul(q2, k2.permute(0,2,1)) #attention2 = [batch, seq_len, seq_len]
    attention3 = torch.matmul(q3, k3.permute(0,2,1)) #attention3 = [batch, seq_len, seq_len]
    attention4 = torch.matmul(q4, k4.permute(0,2,1)) #attention3 = [batch, seq_len, seq_len]


    attention1/=8
    attention2/=8
    attention3/=8
    attention4/=8

    attention1 = attention1.masked_fill(mask==False, -1e10)
    attention2 = attention2.masked_fill(mask==False, -1e10)
    attention3 = attention3.masked_fill(mask==False, -1e10)
    attention4 = attention4.masked_fill(mask==False, -1e10)

    attention1 = torch.softmax(attention1, dim=2)
    attention2 = torch.softmax(attention2, dim=2)
    attention3 = torch.softmax(attention3, dim=2)
    attention4 = torch.softmax(attention4, dim=2)

    weigthed1 = torch.matmul(attention1, v1) #weigthed = [batch, seq, value_dim]
    weigthed2 = torch.matmul(attention2, v2)
    weigthed3 = torch.matmul(attention3, v3)
    weigthed4 = torch.matmul(attention4, v4)

    weighted = torch.cat((weigthed1, weigthed2, weigthed3, weigthed4),dim=-1)
    combined = self.combine_enc_2(weighted)

    combined=self.norm1_enc_2(self.dp(combined)+x)

    op1 = self.fc1_2_enc(combined)
    op1 = F.relu(op1)
    op1 = self.fc2_2_enc(op1)
    op1 = F.relu(op1)

    op1 = self.norm2_enc_2(self.dp(op1)+combined)

    #LAYER3
    q1 = self.x2Q1_3_enc(op1) #q = [batch, src_seq_len, query_dim]
    q2 = self.x2Q2_3_enc(op1)
    q3 = self.x2Q3_3_enc(op1)

    v1 = self.x2V1_3_enc(op1) #v = [batch, src_seq_len, value_dim]
    v2 = self.x2V2_3_enc(op1)
    v3 = self.x2V3_3_enc(op1)

    k1 = self.x2K1_3_enc(op1) #k = [batch, src_seq_len, query_dim]
    k2 = self.x2K2_3_enc(op1)
    k3 = self.x2K3_3_enc(op1)

    attention1 = torch.matmul(q1, k1.permute(0,2,1)) #attention1 = [batch, seq_len, seq_len]
    attention2 = torch.matmul(q2, k2.permute(0,2,1)) #attention2 = [batch, seq_len, seq_len]
    attention3 = torch.matmul(q3, k3.permute(0,2,1)) #attention3 = [batch, seq_len, seq_len]

    attention1/=8
    attention2/=8
    attention3/=8

    attention1 = attention1.masked_fill(mask==False, -1e10)
    attention2 = attention2.masked_fill(mask==False, -1e10)
    attention3 = attention3.masked_fill(mask==False, -1e10)


    attention1 = torch.softmax(attention1, dim=2)
    attention2 = torch.softmax(attention2, dim=2)
    attention3 = torch.softmax(attention3, dim=2)

    weigthed1 = torch.matmul(attention1, v1) #weigthed = [batch, seq, value_dim]
    weigthed2 = torch.matmul(attention2, v2)
    weigthed3 = torch.matmul(attention3, v3)

    weighted = torch.cat((weigthed1, weigthed2, weigthed3),dim=-1)
    combined = self.combine_enc_3(weighted)

    combined=self.norm1_enc_3(self.dp(combined)+op1)

    op1 = self.fc1_3_enc(combined)
    op1 = F.relu(op1)
    op1 = self.fc2_3_enc(op1)
    op1 = F.relu(op1)

    op1 = self.norm2_enc_3(self.dp(op1)+combined)

    

    # if self.final:
    encop2V = op1
    encop2K = op1

    return mask,op1

  def make_src_mask(self, src):                                                       # src = [batch_size, src_len]
        src_mask = (src != self.src_padding_idx).unsqueeze(1).to(device)   # src_mask = [batch_size, src_len, 1]
        return src_mask


In [29]:
class Decoder(nn.Module):
  def __init__(self, target_vocab_len, enc_hidden_dim, query_dim, value_dim, trg_pad_idx,start=False, final=False ,max_len=100):

    super().__init__()
    self.trg_pad_idx = trg_pad_idx
    
    self.final = final
    self.start = start

    # self.norm_enc_dec1 = nn.LayerNorm(enc_hidden_dim)
    # self.norm_enc_dec2 = nn.LayerNorm(enc_hidden_dim)
    # self.norm_enc_dec3 = nn.LayerNorm(enc_hidden_dim)

    # self.dp_enc_dec1 = nn.Dropout(0.4)
    # self.dp_enc_dec2 = nn.Dropout(0.4)
    # self.dp_enc_dec3 = nn.Dropout(0.4)

    # if start:
    self.token_embed_dec = nn.Embedding(target_vocab_len, enc_hidden_dim)
    self.pos_embed_dec = nn.Embedding(max_len, enc_hidden_dim)

    self.x2Q1_dec = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V1_dec = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K1_dec = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q2_dec = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V2_dec = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K2_dec = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q3_dec = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V3_dec = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K3_dec = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q4_dec = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V4_dec = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K4_dec = nn.Linear(enc_hidden_dim, query_dim)

    self.combine_dec = nn.Linear(4*value_dim, enc_hidden_dim)

    self.norm1_dec = nn.LayerNorm(enc_hidden_dim)
    self.dp1 = nn.Dropout(0.2)

    self.x2Q1_enc_dec = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V1_enc_dec = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K1_enc_dec = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q2_enc_dec = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V2_enc_dec = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K2_enc_dec = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q3_enc_dec = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V3_enc_dec = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K3_enc_dec = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q4_enc_dec = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V4_enc_dec = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K4_enc_dec = nn.Linear(enc_hidden_dim, query_dim)

    self.combine_enc_dec = nn.Linear(4*value_dim, enc_hidden_dim)

    self.norm2_dec = nn.LayerNorm(enc_hidden_dim)
    self.dp2_dec = nn.Dropout(0.2)

    self.fc1_dec = nn.Linear(enc_hidden_dim, 2*enc_hidden_dim)
    self.fc2_dec = nn.Linear(enc_hidden_dim*2, enc_hidden_dim)

    self.norm3_dec = nn.LayerNorm(enc_hidden_dim)
    self.dp3_dec = nn.Dropout(0.2)

    #LAYER2
    self.x2Q1_dec_2 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V1_dec_2 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K1_dec_2 = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q2_dec_2 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V2_dec_2 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K2_dec_2 = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q3_dec_2 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V3_dec_2 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K3_dec_2 = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q4_dec_2 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V4_dec_2 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K4_dec_2 = nn.Linear(enc_hidden_dim, query_dim)

    self.combine_dec_2 = nn.Linear(4*value_dim, enc_hidden_dim)

    self.norm1_dec_2 = nn.LayerNorm(enc_hidden_dim)
    self.dp1_dec_2 = nn.Dropout(0.2)

    self.x2Q1_enc_dec_2 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V1_enc_dec_2 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K1_enc_dec_2 = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q2_enc_dec_2 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V2_enc_dec_2 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K2_enc_dec_2 = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q3_enc_dec_2 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V3_enc_dec_2 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K3_enc_dec_2 = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q4_enc_dec_2 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V4_enc_dec_2 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K4_enc_dec_2 = nn.Linear(enc_hidden_dim, query_dim)

    self.combine_enc_dec_2 = nn.Linear(4*value_dim, enc_hidden_dim)

    self.norm2_dec_2 = nn.LayerNorm(enc_hidden_dim)
    self.dp2_dec_2 = nn.Dropout(0.2)

    self.fc1_dec_2 = nn.Linear(enc_hidden_dim, 2*enc_hidden_dim)
    self.fc2_dec_2 = nn.Linear(enc_hidden_dim*2, enc_hidden_dim)

    self.norm3_dec_2 = nn.LayerNorm(enc_hidden_dim)
    self.dp3_dec_2 = nn.Dropout(0.2)

    #LAYER3
    self.x2Q1_dec_3 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V1_dec_3 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K1_dec_3 = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q2_dec_3 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V2_dec_3 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K2_dec_3 = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q3_dec_3 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V3_dec_3 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K3_dec_3 = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q4_dec_3 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V4_dec_3 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K4_dec_3 = nn.Linear(enc_hidden_dim, query_dim)

    self.combine_dec_3 = nn.Linear(4*value_dim, enc_hidden_dim)

    self.norm1_dec_3 = nn.LayerNorm(enc_hidden_dim)
    self.dp1_dec_3 = nn.Dropout(0.2)

    self.x2Q1_enc_dec_3 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V1_enc_dec_3 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K1_enc_dec_3 = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q2_enc_dec_3 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V2_enc_dec_3 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K2_enc_dec_3 = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q3_enc_dec_3 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V3_enc_dec_3 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K3_enc_dec_3 = nn.Linear(enc_hidden_dim, query_dim)

    self.x2Q4_enc_dec_3 = nn.Linear(enc_hidden_dim, query_dim) 
    self.x2V4_enc_dec_3 = nn.Linear(enc_hidden_dim, value_dim)
    self.x2K4_enc_dec_3 = nn.Linear(enc_hidden_dim, query_dim)

    self.combine_enc_dec_3 = nn.Linear(4*value_dim, enc_hidden_dim)

    self.norm2_dec_3 = nn.LayerNorm(enc_hidden_dim)
    self.dp2_dec_3 = nn.Dropout(0.2)

    self.fc1_dec_3 = nn.Linear(enc_hidden_dim, 2*enc_hidden_dim)
    self.fc2_dec_3 = nn.Linear(enc_hidden_dim*2, enc_hidden_dim)

    self.norm3_dec_3 = nn.LayerNorm(enc_hidden_dim)

    # if final:
    self.output_layer = nn.Linear(dec_hidden_dim, target_vocab_len)

  def forward(self, encopV, encopK, x, enc_mask):


      
      # if self.start:
      mask = self.make_trg_mask(x)

      device='cuda'
      x = x.to(device)
      seq_len = x.shape[1]
      batch_size = x.shape[0]

      # if self.start:
      temp = []
      for i in range(seq_len):
        temp.append(i)
      pos = []
      for i in range(batch_size):
        pos.append(temp)
      pos = torch.tensor(pos).to(device)

      token_embedding = self.token_embed_dec(x)
      pos_embedding = self.pos_embed_dec(pos)

      x = pos_embedding + token_embedding

      q1 = self.x2Q1_dec(x)
      v1 = self.x2V1_dec(x)
      k1 = self.x2K1_dec(x) 

      q2 = self.x2Q2_dec(x)
      v2 = self.x2V2_dec(x) 
      k2 = self.x2K2_dec(x) 

      q3 = self.x2Q3_dec(x) 
      v3 = self.x2V3_dec(x) 
      k3 = self.x2K3_dec(x)

      q4 = self.x2Q4_dec(x) 
      v4 = self.x2V4_dec(x) 
      k4 = self.x2K4_dec(x)

      attention1 = torch.matmul(q1, k1.permute(0,2,1)) #attention1 = [batch, seq_len, seq_len]
      attention2 = torch.matmul(q2, k2.permute(0,2,1)) #attention2 = [batch, seq_len, seq_len]
      attention3 = torch.matmul(q3, k3.permute(0,2,1)) #attention3 = [batch, seq_len, seq_len]
      attention4 = torch.matmul(q4, k4.permute(0,2,1)) #attention3 = [batch, seq_len, seq_len]

      attention1/=8
      attention2/=8
      attention3/=8
      attention4/=8

      # if self.start:
      attention1 = attention1.masked_fill(mask==False, -1e10)
      attention2 = attention2.masked_fill(mask==False, -1e10)
      attention3 = attention3.masked_fill(mask==False, -1e10)
      attention4 = attention4.masked_fill(mask==False, -1e10)

      attention1 = torch.softmax(attention1, dim=2)
      attention2 = torch.softmax(attention2, dim=2)
      attention3 = torch.softmax(attention3, dim=2)
      attention4 = torch.softmax(attention4, dim=2)

      weigthed1 = torch.matmul(attention1, v1) #weigthed = [batch, seq, value_dim]
      weigthed2 = torch.matmul(attention2, v2)
      weigthed3 = torch.matmul(attention3, v3)
      weigthed4 = torch.matmul(attention4, v4)


      weighted = torch.cat((weigthed1, weigthed2, weigthed3, weigthed4),dim=-1)
      combined = self.combine_dec(weighted)

      combined = self.norm1_dec(self.dp1(combined)+x)
      tba = combined.clone().detach()

      q1 = self.x2Q1_enc_dec(combined)
      v1 = self.x2V1_enc_dec(encopV)
      k1 = self.x2K1_enc_dec(encopK)

      q2 = self.x2Q2_enc_dec(combined) 
      v2 = self.x2V2_enc_dec(encopV)
      k2 = self.x2K2_enc_dec(encopK)

      q3 = self.x2Q3_enc_dec(combined)
      v3 = self.x2V3_enc_dec(encopV)
      k3 = self.x2K3_enc_dec(encopK)

      q4 = self.x2Q4_enc_dec(combined)
      v4 = self.x2V4_enc_dec(encopV)
      k4 = self.x2K4_enc_dec(encopK)


      attention1 = torch.matmul(q1, k1.permute(0,2,1)) #attention1 = [batch, trg_seq_len, src_seq_len]
      attention2 = torch.matmul(q2, k2.permute(0,2,1)) #attention2 = [batch, trg_seq_len, src_seq_len]
      attention3 = torch.matmul(q3, k3.permute(0,2,1)) #attention3 = [batch, trg_seq_len, src_seq_len]
      attention4 = torch.matmul(q4, k4.permute(0,2,1)) #attention3 = [batch, trg_seq_len, src_seq_len]

      attention1/=8
      attention2/=8
      attention3/=8
      attention4/=8

      attention1 = attention1.masked_fill(enc_mask==False, -1e10)
      attention2 = attention2.masked_fill(enc_mask==False, -1e10)
      attention3 = attention3.masked_fill(enc_mask==False, -1e10)
      attention4 = attention4.masked_fill(enc_mask==False, -1e10)


      attention1 = torch.softmax(attention1, dim=2)
      attention2 = torch.softmax(attention2, dim=2)
      attention3 = torch.softmax(attention3, dim=2)
      attention4 = torch.softmax(attention4, dim=2)

      weigthed1 = torch.matmul(attention1, v1) #weigthed = [batch, seq, value_dim]
      weigthed2 = torch.matmul(attention2, v2)
      weigthed3 = torch.matmul(attention3, v3)
      weigthed4 = torch.matmul(attention4, v4)

      weighted = torch.cat((weigthed1, weigthed2, weigthed3, weigthed4),dim=-1)
      combined = self.combine_enc_dec(weighted)
      
      combined = self.norm2_dec(self.dp1(combined)+tba)
      # combined+=tba

      op1 = self.fc1_dec(combined)
      op1 = F.relu(op1)
      op1 = self.fc2_dec(op1)
      op1 = F.relu(op1)
      
      op1 = self.norm3_dec(self.dp1(op1)+combined)
      x = op1.clone().detach()

      #LAYER 2
      q1 = self.x2Q1_dec_2(op1)
      v1 = self.x2V1_dec_2(op1)
      k1 = self.x2K1_dec_2(op1) 

      q2 = self.x2Q2_dec_2(op1)
      v2 = self.x2V2_dec_2(op1) 
      k2 = self.x2K2_dec_2(op1) 

      q3 = self.x2Q3_dec_2(op1) 
      v3 = self.x2V3_dec_2(op1) 
      k3 = self.x2K3_dec_2(op1)

      q4 = self.x2Q4_dec_2(op1) 
      v4 = self.x2V4_dec_2(op1) 
      k4 = self.x2K4_dec_2(op1)

      attention1 = torch.matmul(q1, k1.permute(0,2,1)) #attention1 = [batch, seq_len, seq_len]
      attention2 = torch.matmul(q2, k2.permute(0,2,1)) #attention2 = [batch, seq_len, seq_len]
      attention3 = torch.matmul(q3, k3.permute(0,2,1)) #attention3 = [batch, seq_len, seq_len]
      attention4 = torch.matmul(q4, k4.permute(0,2,1)) #attention3 = [batch, seq_len, seq_len]

      attention1/=8
      attention2/=8
      attention3/=8
      attention4/=8

      attention1 = attention1.masked_fill(mask==False, -1e10)
      attention2 = attention2.masked_fill(mask==False, -1e10)
      attention3 = attention3.masked_fill(mask==False, -1e10)
      attention4 = attention4.masked_fill(mask==False, -1e10)

      attention1 = torch.softmax(attention1, dim=2)
      attention2 = torch.softmax(attention2, dim=2)
      attention3 = torch.softmax(attention3, dim=2)
      attention4 = torch.softmax(attention4, dim=2)
      
    

      weigthed1 = torch.matmul(attention1, v1) #weigthed = [batch, seq, value_dim]
      weigthed2 = torch.matmul(attention2, v2)
      weigthed3 = torch.matmul(attention3, v3)
      weigthed4 = torch.matmul(attention4, v4)      

      weighted = torch.cat((weigthed1, weigthed2, weigthed3, weigthed4),dim=-1)
      combined_layer1 = self.combine_dec_2(weighted)

      combined_layer1 = self.norm1_dec_2(self.dp1(combined_layer1)+op1)
      tba = combined_layer1.clone().detach()

      q1 = self.x2Q1_enc_dec_2(combined_layer1)
      v1 = self.x2V1_enc_dec_2(encopV)
      k1 = self.x2K1_enc_dec_2(encopV)

      q2 = self.x2Q2_enc_dec_2(combined_layer1) 
      v2 = self.x2V2_enc_dec_2(encopV)
      k2 = self.x2K2_enc_dec_2(encopV)

      q3 = self.x2Q3_enc_dec_2(combined_layer1)
      v3 = self.x2V3_enc_dec_2(encopV)
      k3 = self.x2K3_enc_dec_2(encopV)

      q4 = self.x2Q4_enc_dec_2(combined_layer1)
      v4 = self.x2V4_enc_dec_2(encopV)
      k4 = self.x2K4_enc_dec_2(encopV)

      attention1 = torch.matmul(q1, k1.permute(0,2,1)) #attention1 = [batch, seq_len, seq_len]
      attention2 = torch.matmul(q2, k2.permute(0,2,1)) #attention2 = [batch, seq_len, seq_len]
      attention3 = torch.matmul(q3, k3.permute(0,2,1)) #attention3 = [batch, seq_len, seq_len]
      attention4 = torch.matmul(q4, k4.permute(0,2,1)) #attention3 = [batch, seq_len, seq_len]


      attention1/=8
      attention2/=8
      attention3/=8
      attention4/=8


      attention1 = attention1.masked_fill(enc_mask==False, -1e10)
      attention2 = attention2.masked_fill(enc_mask==False, -1e10)
      attention3 = attention3.masked_fill(enc_mask==False, -1e10)
      attention4 = attention4.masked_fill(enc_mask==False, -1e10)



      attention1 = torch.softmax(attention1, dim=2)
      attention2 = torch.softmax(attention2, dim=2)
      attention3 = torch.softmax(attention3, dim=2)
      attention4 = torch.softmax(attention4, dim=2)

      weigthed1 = torch.matmul(attention1, v1) #weigthed = [batch, seq, value_dim]
      weigthed2 = torch.matmul(attention2, v2)
      weigthed3 = torch.matmul(attention3, v3)
      weigthed4 = torch.matmul(attention4, v4)

      weighted = torch.cat((weigthed1, weigthed2, weigthed3, weigthed4),dim=-1)
      combined = self.combine_enc_dec_2(weighted)
      


      combined = self.norm2_dec_2(self.dp1(combined)+tba)
      # combined+=tba



      op1 = self.fc1_dec_2(combined)
      op1 = F.relu(op1)
      op1 = self.fc2_dec_2(op1)
      op1 = F.relu(op1)
      
      op1 = self.norm3_dec_2(self.dp1(op1)+combined)

      #LAYER 3
      
      q1 = self.x2Q1_dec_3(op1)
      v1 = self.x2V1_dec_3(op1)
      k1 = self.x2K1_dec_3(op1) 

      q2 = self.x2Q2_dec_3(op1)
      v2 = self.x2V2_dec_3(op1) 
      k2 = self.x2K2_dec_3(op1) 

      q3 = self.x2Q3_dec_3(op1) 
      v3 = self.x2V3_dec_3(op1) 
      k3 = self.x2K3_dec_3(op1)

      q4 = self.x2Q4_dec_3(op1)
      v4 = self.x2V4_dec_3(op1)
      k4 = self.x2K4_dec_3(op1) 


      attention1 = torch.matmul(q1, k1.permute(0,2,1)) #attention1 = [batch, seq_len, seq_len]
      attention2 = torch.matmul(q2, k2.permute(0,2,1)) #attention2 = [batch, seq_len, seq_len]
      attention3 = torch.matmul(q3, k3.permute(0,2,1)) #attention3 = [batch, seq_len, seq_len]
      attention4 = torch.matmul(q4, k4.permute(0,2,1)) #attention3 = [batch, seq_len, seq_len]

      attention1/=8
      attention2/=8
      attention3/=8
      attention4/=8

      attention1 = attention1.masked_fill(mask==False, -1e10)
      attention2 = attention2.masked_fill(mask==False, -1e10)
      attention3 = attention3.masked_fill(mask==False, -1e10)
      attention4 = attention4.masked_fill(mask==False, -1e10)

      attention1 = torch.softmax(attention1, dim=2)
      attention2 = torch.softmax(attention2, dim=2)
      attention3 = torch.softmax(attention3, dim=2)
      attention4 = torch.softmax(attention4, dim=2)

      weigthed1 = torch.matmul(attention1, v1) #weigthed = [batch, seq, value_dim]
      weigthed2 = torch.matmul(attention2, v2)
      weigthed3 = torch.matmul(attention3, v3)
      weigthed4 = torch.matmul(attention4, v4)

      weighted = torch.cat((weigthed1, weigthed2, weigthed3, weigthed4),dim=-1)
      combined_layer2 = self.combine_dec_2(weighted)

      combined_layer2 = self.norm1_dec_2(self.dp1(combined_layer2)+op1)
      tba = combined_layer2.clone().detach()

      q1 = self.x2Q1_enc_dec_2(combined_layer2)
      v1 = self.x2V1_enc_dec_2(encopV)
      k1 = self.x2K1_enc_dec_2(encopV)

      q2 = self.x2Q2_enc_dec_2(combined_layer2) 
      v2 = self.x2V2_enc_dec_2(encopV)
      k2 = self.x2K2_enc_dec_2(encopV)

      q3 = self.x2Q3_enc_dec_2(combined_layer2)
      v3 = self.x2V3_enc_dec_2(encopV)
      k3 = self.x2K3_enc_dec_2(encopV)

      q4 = self.x2Q4_enc_dec_2(combined_layer2)
      v4 = self.x2V4_enc_dec_2(encopV)
      k4 = self.x2K4_enc_dec_2(encopV)

      attention1 = torch.matmul(q1, k1.permute(0,2,1)) #attention1 = [batch, seq_len, seq_len]
      attention2 = torch.matmul(q2, k2.permute(0,2,1)) #attention2 = [batch, seq_len, seq_len]
      attention3 = torch.matmul(q3, k3.permute(0,2,1)) #attention3 = [batch, seq_len, seq_len]
      attention4 = torch.matmul(q4, k4.permute(0,2,1)) #attention3 = [batch, seq_len, seq_len]

      attention1/=8
      attention2/=8
      attention3/=8
      attention4/=8

      attention1 = attention1.masked_fill(enc_mask==False, -1e10)
      attention2 = attention2.masked_fill(enc_mask==False, -1e10)
      attention3 = attention3.masked_fill(enc_mask==False, -1e10)
      attention4 = attention4.masked_fill(enc_mask==False, -1e10)

      attention1 = torch.softmax(attention1, dim=2)
      attention2 = torch.softmax(attention2, dim=2)
      attention3 = torch.softmax(attention3, dim=2)
      attention4 = torch.softmax(attention4, dim=2)

      weigthed1 = torch.matmul(attention1, v1) #weigthed = [batch, seq, value_dim]
      weigthed2 = torch.matmul(attention2, v2)
      weigthed3 = torch.matmul(attention3, v3)
      weigthed4 = torch.matmul(attention3, v4)

      weighted = torch.cat((weigthed1, weigthed2, weigthed3, weigthed4),dim=-1)
      combined = self.combine_enc_dec_3(weighted)
      
      combined = self.norm2_dec_3(self.dp1(combined_layer1)+tba)
      # combined+=tba

      op1 = self.fc1_dec_3(combined)
      op1 = F.relu(op1)
      op1 = self.fc2_dec_3(op1)
      op1 = F.relu(op1)
      
      op1 = self.norm3_dec_3(self.dp1(op1)+combined)

      # if self.final:
      output = self.output_layer(op1)
      # else:
      #   output = op1
      
      return output



  def make_trg_mask(self, trg):
        #trg = [batch size, trg len]
        
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1)
        #trg_pad_mask = [batch size, trg len, 1]
        
        trg_len = trg.shape[1]
        
        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = 'cuda')).bool()
        #trg_sub_mask = [trg len, trg len]
        
        trg_mask = trg_pad_mask & trg_sub_mask
        #trg_mask = [batch size, trg len, trg len]                              
        
        return trg_mask

  # def make_src_mask(self, src):                                                       # src = [batch_size, src_len]
  #       src_mask = (src != self.src_padding_idx).unsqueeze(2).to(device)   # src_mask = [batch_size, src_len, 1]
  #       return src_mask


In [30]:
class Transformer(nn.Module):
  def __init__(self, enc0, dec0):
    super().__init__()
    self.enc0 = enc0
    # self.enc1 = enc1

    self.dec0 = dec0
    # self.dec1 = dec1

  def forward(self, source, target ):
    target = target[:,:-1]
    mask_of_enc, enc_op = self.enc0(source)
    # _,enc_op_V,enc_op_K = self.enc1(enc_op)

    output = self.dec0(enc_op,enc_op, target, mask_of_enc)
    # output = self.dec1(enc_op_V,enc_op_K, output)
    
    return output


In [31]:
target_padding_index = TARGET_Field.vocab.stoi[TARGET_Field.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index = target_padding_index)


SRC_PAD_IDX = SOURCE_Field.vocab.stoi[SOURCE_Field.pad_token]
TRG_PAD_IDX = TARGET_Field.vocab.stoi[TARGET_Field.pad_token]

source_vocab_len = len(SOURCE_Field.vocab)
target_vocab_len = len(TARGET_Field.vocab)
enc_hidden_dim = 512
dec_hidden_dim = 512
query_len = 64
value_len = 64

enc0 = Encoder(source_vocab_len, enc_hidden_dim, query_len, value_len, SRC_PAD_IDX, start=True, final=True)
# enc1 = Encoder(source_vocab_len, enc_hidden_dim, query_len, value_len, SRC_PAD_IDX, start=False, final=True)

dec0 = Decoder(target_vocab_len, dec_hidden_dim, query_len, value_len, TRG_PAD_IDX,start=True, final=True)
# dec1 = Decoder(target_vocab_len, dec_hidden_dim, query_len, value_len, TRG_PAD_IDX,start=False, final=True)

translator = Transformer(enc0, dec0).to(device)

# optimizer = torch.optim.Adam(translator.parameters())
# print(translator)

In [32]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(translator):,} trainable parameters')

The model has 18,439,692 trainable parameters


In [39]:
target_padding_index = TARGET_Field.vocab.stoi[TARGET_Field.pad_token]
criterion = nn.CrossEntropyLoss(ignore_index = target_padding_index)
optimizer = torch.optim.Adam(translator.parameters(),lr=0.00005)

In [40]:
def valid(model, iterator, criterion, optimizer): 
  model.eval()                               #just tells pytorch that we are in training phase
  epoch_loss = 0
  with torch.no_grad():
    for i, batch in enumerate(iterator):

        source = batch.src  #source = [batch, seq_len_source]
        target = batch.trg  #target = [batch, seq_len_target]
        # print("source",source.shape)
        # print("target",target.shape)
        # source = source.to(device)
        # target = target.to(device)
        source.to(device)
        target.to(device)

        # optimizer.zero_grad()
        output = model.forward(source,target)        #output = [ batch, seq_len, output_dim]
        # output = output[:,,:]
        output = output.reshape(-1, output.shape[-1])
        # print("op",output.shape)
        target = target[:,1:].to(device)
        target = target.reshape(-1)

        # print(target.shape)
        loss = criterion(output, target)
        # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        # loss.backward()
        # optimizer.step()
        epoch_loss += loss.item()

  return epoch_loss / len(iterator)



In [41]:
def train(model, iterator, criterion, optimizer): 
  model.train()                               #just tells pytorch that we are in training phase
  epoch_loss = 0

  for i, batch in enumerate(iterator):

      source = batch.src  #source = [batch, seq_len_source]
      target = batch.trg  #target = [batch, seq_len_target]
      # print("source",source.shape)
      # print("target",target.shape)
      # source = source.to(device)
      # target = target.to(device)
      source.to(device)
      target.to(device)

      # optimizer.zero_grad()        
      # loss, hidden = model(data, hidden, targets)
      # loss.backward()

      # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
      # optimizer.step()

      optimizer.zero_grad()
      output = model.forward(source,target)        #output = [ batch, seq_len, output_dim]
      # output = output[:,,:]
      output = output.reshape(-1, output.shape[-1])
      # print("op",output.shape)
      target = target[:,1:].to(device)
      target = target.reshape(-1)

      # print(target.shape)
      loss = criterion(output, target)

      loss.backward()
      torch.nn.utils.clip_grad_norm_(model.parameters(), 0.4)
      optimizer.step()
      epoch_loss += loss.item()

  return epoch_loss / len(iterator)



In [42]:
for epoch in range(30):
    start = time.time()
    train_loss = train(translator, train_iterator, criterion,  optimizer)
    valid_loss = valid(translator, valid_iterator, criterion,  optimizer)
    stop = time.time()
    print("train_loss ",train_loss)
    print("valid_loss ",valid_loss)
    print("time: ", stop-start)
    print()

train_loss  0.9893843515854049
valid_loss  1.7034222483634949
time:  62.826109886169434

train_loss  0.9522323718680159
valid_loss  1.7053358852863312
time:  62.84400272369385



KeyboardInterrupt: ignored

In [43]:

def translate_sentence(sentence, src_field, trg_field, model, device, enc0, dec0, max_len = 50):
    
    model.eval()
        
    if isinstance(sentence, str):
        nlp = spacy.load('de_core_news_sm')
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    tokens = [src_field.init_token] + tokens + [src_field.eos_token]
        
    src_indexes = [src_field.vocab.stoi[token] for token in tokens]

    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
    
    # src_mask = enc0.make_src_mask(src_tensor)
    
    with torch.no_grad():
        mask_of_enc, enc_src = model.enc0(src_tensor)
        # _,encV,encK = model.enc1(enc_src)
        # enc_src = model.enc2(enc_src)
        # enc_src = model.enc3(enc_src)

    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]

    for i in range(max_len):

        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)

        # trg_mask = dec0.make_trg_mask(trg_tensor)
        
        with torch.no_grad():
            output = model.dec0(enc_src,enc_src, trg_tensor,mask_of_enc)
            # output = model.dec1(encV,encK, output)
            # output = model.dec2(enc_src, output)
            # output = model.dec3(enc_src, output)

        
        pred_token = output.argmax(2)[:,-1].item()
        
        trg_indexes.append(pred_token)

        if pred_token == trg_field.vocab.stoi[trg_field.eos_token]:
            break
    
    trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes]
    
    return trg_tokens[1:]

In [44]:
example_idx = 3

src = vars(test_data.examples[example_idx])['src']
trg = vars(test_data.examples[example_idx])['trg']

print(f'src = {src}')
print(f'trg = {trg}')

translation = translate_sentence(src, SOURCE_Field, TARGET_Field, translator, device, enc0, dec0)

print(f'predicted trg = {translation}')

src = ['fünf', 'leute', 'in', 'winterjacken', 'und', 'mit', 'helmen', 'stehen', 'im', 'schnee', 'mit', 'schneemobilen', 'im', 'hintergrund', '.']
trg = ['five', 'people', 'wearing', 'winter', 'jackets', 'and', 'helmets', 'stand', 'in', 'the', 'snow', ',', 'with', 'snowmobiles', 'in', 'the', 'background', '.']
predicted trg = ['five', 'people', 'in', 'winter', 'jackets', 'and', 'helmets', 'are', 'standing', 'in', 'the', 'snow', 'with', '<unk>', 'in', 'the', 'background', '.', '<trg_eos>']


In [45]:
from torchtext.data.metrics import bleu_score
def Calculate_BLEU(data, SOURCE_Field, TARGET_Field, translator,device,enc0,dec0):
    trgs = []
    predicted_trgs = []
    for i in range(len(data.examples)):
        # print(predicted_trgs)
        # print(trgs)
        src_sentence = vars(data[i])['src']
        trg_sentence = vars(data[i])['trg']
        try:                                # Sometimes(rarely) CUDA throws a "Device side assert triggered" error. So, just to avoid restarting runtime.
            translation = translate_sentence(src_sentence, SOURCE_Field, TARGET_Field, translator, device, enc0, dec0)
            predicted_trgs.append(translation[:-1])
            trgs.append([trg_sentence])
            # print(predicted_trgs)
            # print(trgs)
        except:
            pass
    print(predicted_trgs)
    print(trgs)
    return bleu_score(predicted_trgs, trgs)

bleu_score_test = Calculate_BLEU(test_data, SOURCE_Field, TARGET_Field, translator,device,enc0,dec0)
print(f"BLEU score on Testing Data: {bleu_score_test*100:.2f}")

[['a', 'man', 'in', 'an', 'orange', 'hat', 'is', 'cooking', 'something', '.'], ['a', 'boston', 'terrier', 'runs', 'over', 'grass', 'in', 'front', 'of', 'a', 'white', 'fence', '.'], ['a', 'girl', 'in', 'a', 'karate', 'outfit', 'is', 'taking', 'a', 'picture', 'with', 'a', 'hula', 'hoop', '.'], ['five', 'people', 'in', 'winter', 'jackets', 'and', 'helmets', 'are', 'standing', 'in', 'the', 'snow', 'with', '<unk>', 'in', 'the', 'background', '.'], ['people', 'repair', 'the', 'roof', 'of', 'a', 'house', '.'], ['a', 'man', 'dressed', 'in', 'colorful', 'clothing', 'and', 'a', 'group', 'of', 'men', 'in', 'dark', 'suits', 'and', 'hats', 'are', 'standing', 'around', 'a', 'woman', 'in', 'a', 'white', 'dress', '.'], ['a', 'group', 'of', 'people', 'standing', 'in', 'front', 'of', 'a', '<unk>', '.'], ['a', 'boy', 'in', 'a', 'red', 'jersey', 'trying', 'to', 'reach', 'the', 'right', 'as', 'the', 'kicker', 'passes', 'in', 'the', 'blue', 'uniform', 'is', 'trying', 'to', 'pass', 'the', 'ball', '.'], ['a',

In [None]:
a = torch.randn(128,20).unsqueeze(2)
a1 = torch.randn(128,20,64)
a1 = a1.masked_fill(a==0,0)
print(a.shape)
print(a1.shape)

In [None]:
from copy import deepcopy as dcy
a = torch.ones(2,2)
b = a.clone().detach()
a += torch.ones(2,2)
a+= torch.ones(2,2)
print(a)
print(b)