# text2rdf

## Imports

In [1]:
# General purpose
import os
import glob
import random
import xml.etree.ElementTree as ET

# PyTorch 
import torch
import torch.nn as nn
from torch import optim

# Bert
from transformers import BertTokenizer, BertModel, BertConfig
#Bleu score
from nltk.translate.bleu_score import corpus_bleu

In [2]:
# CUDA related
device = torch.device('cuda')
print("Device:", device)

Device: cuda


## Preprocessing

### Select Data-sub-set (given restricted nr of triples per sentence)

In [3]:
# How many triples to train and test system on (min: 1, max: 7)
MIN_NUM_TRIPLES = 1
MAX_NUM_TRIPLES = 1

In [4]:
# Set paths where to retrieve data from
DS_BASE_PATH = './WebNLG/'

TRAIN_PATH = DS_BASE_PATH + 'train/'
TEST_PATH = DS_BASE_PATH + 'dev/'

TRAIN_DIRS = [ TRAIN_PATH + str(i) + 'triples/' for i in range(MIN_NUM_TRIPLES, MAX_NUM_TRIPLES+1) ]
TEST_DIRS  = [ TEST_PATH  + str(i) + 'triples/' for i in range(MIN_NUM_TRIPLES, MAX_NUM_TRIPLES+1) ]

In [5]:
# Print selected directories
print('Train dirs:', TRAIN_DIRS)
print('Test  dirs:', TEST_DIRS)

Train dirs: ['./WebNLG/train/1triples/']
Test  dirs: ['./WebNLG/dev/1triples/']


### Load Data

#### Load Settings (do not touch)

In [6]:
originaltripleset_index = 0  # Index of 'originaltripleset' attribute in XML entry
modifiedtripleset_index = 1  # Index of 'modifiedtripleset' attribute in XML entry
first_lexical_index = 2      # Index as of which verbalizations of RDF triples start in entry

#### Train Data

In [7]:
# Usage of train: train[target_nr_triples][element_id]['target_attribute']
train = [[] for i in range(MIN_NUM_TRIPLES, MAX_NUM_TRIPLES+1)]

# Documents how many entries there are per number of triples
train_stats = [0 for i in range(MIN_NUM_TRIPLES, MAX_NUM_TRIPLES+1)]

# Iterate through all files per number of triples and per category and load data
for i, d in enumerate(TRAIN_DIRS):
    nr_triples = list(range(MIN_NUM_TRIPLES, MAX_NUM_TRIPLES+1))[i]
    
    for filename in glob.iglob(d+'/**', recursive=False):
        if os.path.isfile(filename): # Filter dirs
            #print('File:', filename)
            
            tree = ET.parse(filename)
            root = tree.getroot()
            
            entries = root[0]
            train_stats[nr_triples-MIN_NUM_TRIPLES] += len(entries)
            
            for entry in entries:
                #print('Original triple set: ', entry[originaltripleset_index])
                #print('Modified triple set: ', entry[modifiedtripleset_index])
                
                modified_triple_set = entry[modifiedtripleset_index]
                unified_triple_set = []
                
                for triple in modified_triple_set:
                    # Make a list containing a conjunction of all individual triples
                    triple_list = [x.strip() for x in triple.text.split('|')]
                    unified_triple_set += triple_list
                    
                verbalizations = entry[first_lexical_index:]

                for verbalization in verbalizations:
                    if verbalization.text.strip() == '':
                        continue
                    #print('Text:', verbalization, verbalization.tag, verbalization.attrib, verbalization.text)
                    #print('Trip:', triple, triple.tag, triple.attrib, triple.text)

                    train[i].append({ 'category': entry.attrib['category'],
                                      'id': entry.attrib['eid'],
                                      'triple_cnt': nr_triples,
                                      'text': verbalization.text,
                                      'triple': unified_triple_set,
                                    })
                        
print(train)
print(train_stats)

[[{'category': 'Food', 'id': 'Id1', 'triple_cnt': 1, 'text': 'Ajoblanco originates from the country of Spain.', 'triple': ['Ajoblanco', 'country', 'Spain']}, {'category': 'Food', 'id': 'Id1', 'triple_cnt': 1, 'text': 'Ajoblanco is from Spain.', 'triple': ['Ajoblanco', 'country', 'Spain']}, {'category': 'Food', 'id': 'Id2', 'triple_cnt': 1, 'text': 'Ajoblanco has almond as one of its ingredients.', 'triple': ['Ajoblanco', 'ingredient', 'Almond']}, {'category': 'Food', 'id': 'Id2', 'triple_cnt': 1, 'text': 'Almond is an ingredient in ajoblanco.', 'triple': ['Ajoblanco', 'ingredient', 'Almond']}, {'category': 'Food', 'id': 'Id3', 'triple_cnt': 1, 'text': 'Bread is an ingredient of Ajoblanco.', 'triple': ['Ajoblanco', 'ingredient', 'Bread']}, {'category': 'Food', 'id': 'Id4', 'triple_cnt': 1, 'text': 'An ingredient of Ajoblanco is garlic.', 'triple': ['Ajoblanco', 'ingredient', 'Garlic']}, {'category': 'Food', 'id': 'Id4', 'triple_cnt': 1, 'text': 'Garlic is an ingredient used in Ajoblanco

#### Test Data

In [8]:
# Usage of test: test[target_nr_triples][element_id]['target_attribute']
test = [[] for i in range(MIN_NUM_TRIPLES, MAX_NUM_TRIPLES+1)]

# Documents how many entries there are per number of triples
test_stats = [0 for i in range(MIN_NUM_TRIPLES, MAX_NUM_TRIPLES+1)]

# Iterate through all files per number of triples and per category and load data
for i, d in enumerate(TEST_DIRS):
    nr_triples = list(range(MIN_NUM_TRIPLES, MAX_NUM_TRIPLES+1))[i]
    
    for filename in glob.iglob(d+'/**', recursive=False):
        if os.path.isfile(filename): # Filter dirs
            #print('File:', filename)
            
            tree = ET.parse(filename)
            root = tree.getroot()
            
            entries = root[0]
            test_stats[nr_triples-MIN_NUM_TRIPLES] += len(entries)
            
            for entry in entries:
                #print('Original triple set: ', entry[originaltripleset_index])
                #print('Modified triple set: ', entry[modifiedtripleset_index])
                
                modified_triple_set = entry[modifiedtripleset_index]
                unified_triple_set = []
                
                for triple in modified_triple_set:
                    # Make a list containing a conjunction of all individual triples
                    triple_list = [x.strip() for x in triple.text.split('|')]
                    unified_triple_set += triple_list
                    
                verbalizations = entry[first_lexical_index:]

                for verbalization in verbalizations:
                    if verbalization.text.strip() == '':
                        continue
                    #print('Text:', verbalization, verbalization.tag, verbalization.attrib, verbalization.text)
                    #print('Trip:', triple, triple.tag, triple.attrib, triple.text)

                    test[i].append({ 'category': entry.attrib['category'],
                                      'id': entry.attrib['eid'],
                                      'triple_cnt': nr_triples,
                                      'text': verbalization.text,
                                      'triple': unified_triple_set,
                                    })
                        
print(test)
print(test_stats)

[[{'category': 'ComicsCharacter', 'id': 'Id1', 'triple_cnt': 1, 'text': 'Asterix was created by Rene Goscinny.', 'triple': ['Asterix_(character)', 'creators', 'René_Goscinny']}, {'category': 'ComicsCharacter', 'id': 'Id1', 'triple_cnt': 1, 'text': 'The creator of Asterix (comics character) is René Goscinny.', 'triple': ['Asterix_(character)', 'creators', 'René_Goscinny']}, {'category': 'ComicsCharacter', 'id': 'Id1', 'triple_cnt': 1, 'text': 'The comic character Asterix, was created by René Goscinny.', 'triple': ['Asterix_(character)', 'creators', 'René_Goscinny']}, {'category': 'ComicsCharacter', 'id': 'Id2', 'triple_cnt': 1, 'text': 'The comic character Auron was created by Marv Wolfman.', 'triple': ['Auron_(comics)', 'creator', 'Marv_Wolfman']}, {'category': 'ComicsCharacter', 'id': 'Id2', 'triple_cnt': 1, 'text': 'Auron was created by Marv Wolfman.', 'triple': ['Auron_(comics)', 'creator', 'Marv_Wolfman']}, {'category': 'ComicsCharacter', 'id': 'Id2', 'triple_cnt': 1, 'text': 'The 

#### Spilt Train Data into Train and Dev (for intermindiate validation throughout training)

In [9]:
# Percentage of train data reserved for validation throughout training
dev_percentage = 0.15

In [10]:
# Init dev dataset
dev = [[] for i in range(MIN_NUM_TRIPLES, MAX_NUM_TRIPLES+1)]

# Sample number of dev instances per number of triples
dev_stats = [int(dev_percentage * train_stats[i]) for i in range(0, MAX_NUM_TRIPLES+1-MIN_NUM_TRIPLES)]

print('Samples per nr of triples:', dev_stats)

# Sample indices to be reserved for dev dataset for each nr of triples
dev_indices = [random.sample(range(0, len(train[i])), dev_stats[i]) for i in range(0, MAX_NUM_TRIPLES+1-MIN_NUM_TRIPLES)]
for i in range(len(dev_indices)):
    dev_indices[i].sort(reverse=True)

# Copy selected dev-entries into dev & delete all duplicates/related entries from train dataset
for nr_triples in range(0, MAX_NUM_TRIPLES+1-MIN_NUM_TRIPLES):
    
    # Iterate through all indices reserved for validation set (per nr of triples)
    for index in dev_indices[nr_triples]:
        
        # Select index'th train entry (to become dev/validation data)
        selected_entry = train[nr_triples][index]
        
        # Extract indentifying attributes
        entrys_category = selected_entry['category']
        entrys_idx = selected_entry['id']
        
        # Put selected entry into dev set
        dev[nr_triples].append(selected_entry)
        
        # Find all entries of matching index & category and remove them from train data
        for entry in train[nr_triples]:
            if entry['id'] == entrys_idx and entry['category'] == entrys_category:
                train[nr_triples].remove(entry)
                
print(dev)
print(dev_stats)

Samples per nr of triples: [268]
[[{'category': 'University', 'id': 'Id46', 'triple_cnt': 1, 'text': 'Romania has many ethnic groups one of which are Germans.', 'triple': ['Romania', 'ethnicGroups', 'Germans_of_Romania']}, {'category': 'University', 'id': 'Id42', 'triple_cnt': 1, 'text': 'To the northeast of Karnataka is Telagana.', 'triple': ['Karnataka', 'has to its northeast', 'Telangana']}, {'category': 'University', 'id': 'Id36', 'triple_cnt': 1, 'text': "One of Denmark's leaders is Lars Lokke Rasmussen.", 'triple': ['Denmark', 'leader', 'Lars_Løkke_Rasmussen']}, {'category': 'University', 'id': 'Id34', 'triple_cnt': 1, 'text': 'All India Council for Technical Education is located in Mumbai.', 'triple': ['All_India_Council_for_Technical_Education', 'location', 'Mumbai']}, {'category': 'University', 'id': 'Id28', 'triple_cnt': 1, 'text': 'There are around 10000 undergraduate students at the Acharya Institute of Technology.', 'triple': ['Acharya_Institute_of_Technology', 'numberOfUn

#### Print Stats

In [11]:
print('Minimal number of triples:', MIN_NUM_TRIPLES)
print('Maximal number of triples:', MAX_NUM_TRIPLES)

print()

print('Training: ')
for nr_triples in range(MIN_NUM_TRIPLES, MAX_NUM_TRIPLES+1):
    print('Given %i triples per sentence: ' % nr_triples)
    print('Number of combinations of triples and verbalizations:', len(train[nr_triples-MIN_NUM_TRIPLES]))

print()

print('Dev: ')
for nr_triples in range(MIN_NUM_TRIPLES, MAX_NUM_TRIPLES+1):
    print('Given %i triples per sentence: ' % nr_triples)
    print('Number of combinations of triples and verbalizations:', len(dev[nr_triples-MIN_NUM_TRIPLES]))

print()

print('Testing: ')
for nr_triples in range(MIN_NUM_TRIPLES, MAX_NUM_TRIPLES+1):
    print('Given %i triples per sentence: ' % nr_triples)
    print('Number of combinations of triples and verbalizations:', len(test[nr_triples-MIN_NUM_TRIPLES]))

Minimal number of triples: 1
Maximal number of triples: 1

Training: 
Given 1 triples per sentence: 
Number of combinations of triples and verbalizations: 3906

Dev: 
Given 1 triples per sentence: 
Number of combinations of triples and verbalizations: 268

Testing: 
Given 1 triples per sentence: 
Number of combinations of triples and verbalizations: 532


# Neural Machine Translation (NMT) Model Definition 

## TODO: needs updating

## Idea:
1. Encoder: 
1.1 Input==Word Embedding; 
1.2 Output==Context Vector (that is: Encoding of sentence; contained in hidden state after having observed last embedding)

2. Decoder:
2.1 Input==Context Vector
2.2 Output==Probability distribution over output vocab

3. Seq2Seq model: Combining the two

## BERT Encoder

#### Tokenizer

In [12]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', device=device)

#### Model

In [13]:
bert_model = BertModel.from_pretrained('bert-base-uncased', return_dict=True).to(device)

## Decoder

### Soft Attention Model
This model implements the Soft Attention model presented in http://proceedings.mlr.press/v37/xuc15.pdf. 
1. Attention energies (i.e. energy per annotation vector) get computed: $e_{ti}=f_{att}(a_i,h_{t−1})$. Note that this formula implies that the Decoder's previous hidden state $h_{t-1}$ needs to be appended to each individual annotation vector $a_i$ before feeding their concatenation through a fully-connected layer $f_{att}$. 
2. Attention weights $\alpha$ get computed from the aforementioned energies: $\alpha_t = softmax(e_t)$, where $\alpha_{ti} = \frac{exp(e_{ti})}{\sum^L_{k=1} exp(e_{tk})}$.

Note: $t$ stands for time, while $i$ identifies the particular annotation vector currently under consideration.

In [14]:
class SoftAttention(nn.Module):
    
    def __init__(self, 
                 annotation_size,  # Tuple: (num_annotations, num_features_per_annotation)
                 hidden_len        # Number of nodes in Decoder's hidden state weight matrix
                ):
        
        super(SoftAttention, self).__init__()
        #print('SA INIT')
        # Variables
        self.num_annotations = annotation_size[0]
        self.annotation_features = annotation_size[1]
        self.hidden_size = hidden_len
        
        # Layers
        self.attn = nn.Linear(self.annotation_features + self.hidden_size, 1, bias=True)
        self.softmax = nn.Softmax(dim=1)
        #print('an size:', annotation_size) # 8x96
        #print('an features (96?):', self.annotation_features) # 96
        #print('hid size:', hidden_len)     # 96
        
    def forward(self, annotations, prev_hidden):
        
        # Repeat prev_hidded X times to append it to each of the annotation vectors (per batch element)
        repeated_hidden = torch.cat(
            [
                torch.repeat_interleave(hid, repeats=self.num_annotations, dim=0).unsqueeze(0)
                for hid in prev_hidden.split(1)
            ]
        )
        
        # Append previous hidden state to all annotation vectors (for each individual batch element)
        # Input to attention weight calculation
        #print('SA:', annotations.size(), repeated_hidden.size())
        input = torch.cat((annotations, repeated_hidden), dim=2)
        #print('Input size:', input.size())
        #print(self.attn)
        
        # Compute the relative attention scores per feaure (e_{ti}=f_{att}(a_i,h_{t−1}) from paper)
        energies = self.attn(input)
        
        #print('energies...')
        
        # Compute final attention weights (i.e. alpha)
        attn_weights = self.softmax(energies)
        #print('attn_weights:', attn_weights.size())
        return attn_weights


### Decoder itself (employing Soft Attention)

In [15]:
class Decoder(nn.Module):
    
    def __init__(self, 
                 annotation_size,      # Size of annotation vectors produced by Encoder
                 out_vocab_size,       # How many words there are in the RDF-output language
                 embedding_dim,        # Length of a word embedding
                 hidden_dim,           # Nr hidden nodes
                 output_dim,           # Vocab size
                 bidirectional=False,  # Whether to have bidirectional GRU
                 n_layers=1,           # Nr layers in GRU
                 drop_prob=0.2         # Percent of node-dropouts
                ):
        
        super(Decoder, self).__init__()
        
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.n_directions = 1 if not bidirectional else 2  # TODO: make use of it...
        
        self.attn = SoftAttention(annotation_size=annotation_size, hidden_len=hidden_dim)
        self.gru = nn.GRU(annotation_size[1]+embedding_dim, hidden_dim, n_layers, batch_first=True, dropout=drop_prob)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        
        
    def forward(self, 
                annotations,  # Static annotation vectors (for each batch element)
                embeddings,   # Word embeddings of most recently generated word (per batch element)
                h_old         # Previous hidden state per batch element
               ):
        #print('Decoder forward:')
        #print('embeddings:\t', embeddings.size())
        #print('h_old:\t\t', h_old.size())
        
        annotation_weights = self.attn(annotations, h_old.squeeze())#.unsqueeze(2)
        #print('annotations:', annotations.size())
        #print('annotation_weights:', annotation_weights.size())
        weighted_annotations = annotations * annotation_weights
        #print('weighted_annotations:', weighted_annotations.size())
        context_vectors = torch.sum(weighted_annotations, dim=1)
        #print('context_vectors:', context_vectors.size())
        
        x = torch.cat((context_vectors, embeddings), dim=1)
        #print('x:', x.size())
        x = x.unsqueeze(1) # Add une dimension for 'sequence'
        
        #print('Decoder x:', x.size(), 'h_old:', h_old.size())
        #print(self.gru)
        out, h = self.gru(x, h_old)
        out = out.squeeze()
        out = self.softmax(self.fc(self.relu(out)))
        #print('h:', h.size())
        return out, h
    
    
    def init_hidden(self, annotation_vectors):
        # Mean of annotation vector per batch element
        # Assumes that number of hidden nodes == number annotation features
        hidden = torch.mean(annotation_vectors, dim=1)#.to(device)
        return hidden

## Encoder for seq2seq

In [16]:
# use it only if required otherwise can be removed.


class Encoder(nn.Module):
    
    def __init__(self, 
                 input_dim,
                 embedding_dim,        # Length of a word embedding
                 hidden_dim,           # Nr hidden nodes
                 n_layers = 1,         # Nr layers in GRU
                 dropout =0.2
                ):
        super().__init__()
        
        self.hid_dim = hid_dim
        self.n_layers = n_layers
        
        self.embedding = nn.Embedding(input_dim, emb_dim)
        
        self.rnn = nn.LSTM(emb_dim, hid_dim, n_layers, dropout = dropout)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, src):        
        embedded = self.dropout(self.embedding(src))
        #embedded = [src len, batch size, emb dim] 
        
        outputs, (hidden, cell) = self.rnn(embedded)
        
        #outputs =
        #hidden =
        #cell =
        
        #outputs are always from the top hidden layer
        
        return hidden, cell

## Train loop

In [17]:
def predict(x, 
            word_embeddings,        # Decoder's word embeddings
            word2idx,               # 
            idx2word,               # 
            encoder,                # 
            decoder,                # 
            tokenizer,              # 
            loss_fn,                # 
            max_len=7,              # 
            batch_size=32,          # 
            compute_grads=False,    # 
            targets=None,           # 
            return_textual=False    # Whether to return predictions in index-form (default) or as textual strings
           ):
    
    print('In predict:')
    
    accumulated_loss = 0.
    
    # Init documentation of predictions
    predicted_indices = torch.zeros([batch_size, max_len]).to(device) # Numeric
    if return_textual:
        predicted_words = ['']*batch_size
    
    # Tokenize sampled minibatch sentences
    inputs = tokenizer(x, 
                       return_tensors="pt",     # Return tensors in pt=PyTorch format
                       padding=True,            # Pad all sentences in mini-batch to have the same length
                       add_special_tokens=True).to(device) # Add "Start of sequence", "End of sequence", ... tokens. 
    
    #print('Tokenized Inputs:', inputs)
    
    # Encode sentences: Pass tokenization output-dict-contents to model
    outputs = encoder(**inputs)
    #print('Got outputs:', outputs)

    # Retrieve hidden state to be passed into Decoder as annotation vectors
    # Reshape to get a set of 8 feature vectors from last hidden state
    annotations = outputs.last_hidden_state[:, -1, :].reshape(batch_size,8,-1).to(device)
    #print('Annotations size after cropping & reshape:', annotations.size())

    # Init Decoder's hidden state
    hidden = decoder.init_hidden(annotations).unsqueeze(0).to(device)
    #print('Initial hidden size:', hidden.size(), 'given annotations:', annotations.size())
    
    # Construct initial embeddings (start tokens)
    embeddings = word_embeddings(torch.zeros([batch_size], dtype=int).to(device)).to(device)
    
    for t in range(max_len):
        #print('START OF ITERATION', t)
        # Get decodings (aka prob distrib. over output vocab per batch element) for time step t
        prob_dist, hidden = decoder(annotations, # Static vector containing annotations per batch element 
                                    embeddings,  # Word embedding predicted last iteration (per batch element)
                                    hidden       # Decoder's hidden state of last iteratipn per batch element
                                    )

        # Get predicted word index from predicted probability distribution (per batch element)
        word_index = torch.max(prob_dist, dim=1).indices
        #print('Predicted word indices batch:', word_index)
        
        # Get corresponding word embedding (by index; per batch element)
        embedding = word_embeddings(word_index.to(device))
        
        # TODO: optional teacher forcing?

        # Record predicted words
        predicted_indices[:, t] = word_index
        #print('Predicted indices:', predicted_indices)
        
        # Record textual words if required
        if return_textual:
            
            # Get predicted word (per batch element)
            predicted_word = [idx2word[batch_element.item()] for batch_element in word_index]
        
            for e in range(batch_size):
                predicted_words[e] += (predicted_word[e] + ' ')

        if compute_grads:
            
            #print('prob_dist:', prob_dist.size())
            #print('targets:', targets[:, t].size(), targets[:, t])
            
            # Compute (averaged over all batch elements given current time step t)
            loss = loss_fn(prob_dist, targets[:, t]).to(device)

            # Compute & back-propagate gradients
            loss.backward(retain_graph=True)
            
            # Document loss
            accumulated_loss += loss.item()
        #print('END OF ITERATION', t)
            
    ret_object = {
        'predicted_indices': predicted_indices,
    }
    
    print('Targets:\n', targets)
    print('Predicted idxs:\n', predicted_indices)
    
    if compute_grads:
        ret_object['loss'] = accumulated_loss
        #print('Accumulated loss:', accumulated_loss)
    if return_textual:
        ret_object['predicted_words'] = predicted_words
        #print("Predicted words:", predicted_words)
    #print("Returning from predict")
    return ret_object 

In [18]:
def rdf_vocab_constructor(raw_vocab):
    #print(raw_vocab)
    vocab_count, word2idx, idx2word = 3, {'START': 0, 'PAD': 1, 'END': 2}, {0: 'START', 1: 'PAD', 2: 'END'}
    
    for partition in raw_vocab: # Different partitions with respect to nr or triples per sentence
        for train_instance in partition:
            triple = train_instance['triple']
            for token in triple:
                if token not in word2idx:
                    word2idx[token] = vocab_count
                    idx2word[vocab_count] = token
                    vocab_count += 1
    return vocab_count, word2idx, idx2word

In [28]:
def training(train_data, 
          val_data,  
          epochs, 
          minibatch_size=32,
          embedding_dim=300,
          eval_frequency=10, # Every how many epochs to run intermediate evaluation
          learning_rate_en=0.00001,
          learning_rate_de=0.00001
         ):
    
    # Construct RDF vocab
    vocab_count, word2idx, idx2word = rdf_vocab_constructor(train_data)
    
    # Construct embeddings
    rdf_vocab = nn.Embedding(num_embeddings=vocab_count, embedding_dim=embedding_dim, padding_idx=0).to(device)
    
    # Define model
    encoder = bert_model.to(device)
    decoder = Decoder(
        annotation_size=(8,96),    # Size of annotation vectors produced by Encoder
        out_vocab_size=vocab_count, # How many words there are in the RDF-output language
        embedding_dim=300,          # Length of a word embedding
        hidden_dim=96,             # Nr hidden nodes
        output_dim=vocab_count,             # Vocab size
    ).to(device)
    
    # Optimizer
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate_en)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate_de)

    loss = nn.CrossEntropyLoss()
    
    # For both train and validation data & for all number of tuples per sentence 
    # (in [MIN_NUM_TRIPLES, MAX_NUM_TRIPLES]), get the nr of train-/test instances
    len_x_train = [len(train_set) for train_set in train]
    len_x_val = [len(val_set) for val_set in dev]
    
    # Development of both train- and validation losses over course of training
    train_losses, val_losses = [0.]*epochs, [0.]*epochs
    
    print('Starting training.')
    
    # Train
    for epoch in range(epochs):
        print('Epoch:', epoch)
        
        train_loss, eval_loss = 0., 0.
        
        # Reset gradients
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        
        # Perform own train step for each nr of triples per sentence separately
        for i, nt in enumerate(range(MIN_NUM_TRIPLES, MAX_NUM_TRIPLES+1)):
            print(str(i) + '. Condition:', nt, 'triples per sentence.')
            
            # Sample minibatch indices
            minibatch_idx = random.sample(population=range(len_x_train[i]), k=minibatch_size)
            #print('MB indices:', minibatch_idx)
            
            # Number of tokens to be predicted (per batch element)
            num_preds = nt*3+1 # = nr triples * 3 + stop_token 
            #print('Number of predictions:', num_preds)
            
            # Construct proper minibatch
            inputs = [train_data[i][idx]['text'] for idx in minibatch_idx]
            targets = torch.ones([minibatch_size, num_preds], dtype=int).to(device)
            
            #print('Inputs:', inputs)
            #print('Targets:', targets)
            
            for mb_i, idx in enumerate(minibatch_idx):
                #print('Text:', train_data[i][idx]['text'])
                #print('Triple:', train_data[i][idx]['triple'])
                for t, token in enumerate(train_data[i][idx]['triple']):
                    targets[mb_i, t] = word2idx[token]
            targets[:, -1] = 2  # 2 = Stop word index
            
            #print('Processed targets:', targets)
            #print('Predicting:')
            
            # Predict
            ret_object = predict(inputs,
                                 rdf_vocab,              # Decoder's word embeddings
                                 word2idx,               # 
                                 idx2word,               # 
                                 encoder,                # 
                                 decoder,                # 
                                 tokenizer,              # 
                                 loss,                   # 
                                 max_len=num_preds,      # Nr of tokens to be predicted
                                 batch_size=32,          # 
                                 compute_grads=True,     # 
                                 targets=targets,        # 
                                 return_textual=True     # Whether to return predictions in index-form (default) or as textual strings
                                )
            
            print('Return object:', ret_object)
            print("Predicted texts:", ret_object['predicted_words'])
            train_loss += ret_object['loss']
            #print("Returned loss:", ret_object['loss'])
            
        # Apply gradients
        encoder_optimizer.step()
        decoder_optimizer.step()
        #print('Optimizations performed.')
        
        # Intermediate evaluation
        
        # Save losses
        train_losses[epoch] = train_loss
        
    return train_losses, val_losses, encoder, decoder


# Train

In [29]:
# Free CUDA memory
if str(device) == 'cuda':
    torch.cuda.empty_cache()
    
# Train
train_losses, val_losses, encoder, decoder = training(train, dev, epochs=100)
print('Train losses:', train_losses)

Starting training.
Epoch: 0
0. Condition: 1 triples per sentence.
In predict:
Targets:
 tensor([[1738,  209, 1801,    2],
        [1598,  201, 1732,    2],
        [ 984,   55,  985,    2],
        [ 667,  608,  609,    2],
        [1461, 1355, 1492,    2],
        [ 308,  201,  310,    2],
        [ 367,  548,  549,    2],
        [ 905,  927,  928,    2],
        [1175,  297, 1179,    2],
        [ 656,  608,  666,    2],
        [1635, 1560, 1636,    2],
        [ 154,    4,  156,    2],
        [1725, 1708, 1616,    2],
        [1461, 1352, 1364,    2],
        [1042, 1003, 1043,    2],
        [  38,  204,  222,    2],
        [ 122,    6,  128,    2],
        [ 213,  197,  214,    2],
        [1688,  440, 1712,    2],
        [ 709,  220,  864,    2],
        [  86,   33,  278,    2],
        [ 725,  297,  528,    2],
        [1191,    4, 1192,    2],
        [ 874,  880,  881,    2],
        [ 486,  299,  491,    2],
        [ 908,  913,  914,    2],
        [ 331,  314,  332,  

Targets:
 tensor([[1361, 1350, 1359,    2],
        [1629,  201, 1722,    2],
        [ 510,  287, 1674,    2],
        [1416, 1536, 1538,    2],
        [ 721,  612,  726,    2],
        [ 701,  608,  609,    2],
        [1762, 1765, 1766,    2],
        [ 172,    6,  178,    2],
        [  68,    6,   70,    2],
        [ 531,  197,  533,    2],
        [ 963,  973,  974,    2],
        [1312, 1314, 1317,    2],
        [1621,  287, 1632,    2],
        [ 580, 1564, 1698,    2],
        [ 262,   75,  145,    2],
        [ 378,   33,  557,    2],
        [  45,    6,   10,    2],
        [  92,    6,   99,    2],
        [ 645,  606,  650,    2],
        [  84,    6,   89,    2],
        [1165, 1163, 1170,    2],
        [1426,  254, 1531,    2],
        [1216, 1203, 1217,    2],
        [1232,    4,   79,    2],
        [1644,  201,   79,    2],
        [ 874,  254,  877,    2],
        [ 181,    6,  186,    2],
        [1654, 1657, 1658,    2],
        [1222, 1203, 1223,    2],
    

Targets:
 tensor([[1248,  201, 1252,    2],
        [  79,  220,  279,    2],
        [1202,  297, 1211,    2],
        [ 651,  612,  653,    2],
        [ 312,  197,  265,    2],
        [ 639,  606,  644,    2],
        [ 291,  299,  301,    2],
        [ 154,    4,  156,    2],
        [  73,   33,  242,    2],
        [1494, 1497, 1498,    2],
        [ 468,  302,  471,    2],
        [ 986,  105,  988,    2],
        [  78,    4,   79,    2],
        [ 334,  201,  102,    2],
        [1744,  824, 1749,    2],
        [1181,  297, 1189,    2],
        [1575,    4,   79,    2],
        [  94,  197,  198,    2],
        [1084, 1003, 1086,    2],
        [ 977,  919,  983,    2],
        [ 685,  742,  801,    2],
        [ 531,  209,  535,    2],
        [ 103,   25,  107,    2],
        [  86,  204, 1340,    2],
        [ 672,  612,  622,    2],
        [ 349,  302,  355,    2],
        [1607,  201, 1605,    2],
        [  86,  197,  265,    2],
        [ 654,  750,  826,    2],
    

Targets:
 tensor([[ 725,  297,  760,    2],
        [1449, 1383, 1454,    2],
        [1233,  797,  874,    2],
        [1142, 1113, 1146,    2],
        [1621, 1558, 1626,    2],
        [ 452,  297,  458,    2],
        [ 442,  324,  446,    2],
        [1222, 1122, 1228,    2],
        [ 134,  140,  141,    2],
        [ 731,    4,   79,    2],
        [ 389,  197,  265,    2],
        [ 331,  302,  337,    2],
        [1084, 1003, 1086,    2],
        [1076,  105, 1078,    2],
        [ 334,  201,  102,    2],
        [ 163, 1793, 1794,    2],
        [ 718,  750,  696,    2],
        [ 977,  915,   79,    2],
        [1424, 1352, 1428,    2],
        [1762, 1781, 1782,    2],
        [ 605,  606,  607,    2],
        [ 703,  821,  696,    2],
        [1642, 1705, 1706,    2],
        [ 353,  309,  579,    2],
        [1481, 1407, 1484,    2],
        [1292,  204,  205,    2],
        [1372, 1375, 1376,    2],
        [1605,  197, 1711,    2],
        [1635,  309, 1605,    2],
    

Targets:
 tensor([[  84,   52,   91,    2],
        [ 277,   33, 1320,    2],
        [1084, 1003, 1086,    2],
        [ 502,  586,  588,    2],
        [ 625,  608,  632,    2],
        [ 508,  456,  509,    2],
        [1153, 1120, 1155,    2],
        [ 313,  302,  328,    2],
        [ 130,   16,  234,    2],
        [1008, 1087, 1097,    2],
        [ 718,  749,  783,    2],
        [1753,    4, 1755,    2],
        [ 331,  302,  338,    2],
        [1663, 1708, 1729,    2],
        [1462, 1377, 1466,    2],
        [  79,  209,  284,    2],
        [1359, 1357, 1361,    2],
        [   3,   12,   13,    2],
        [ 725,  297,  760,    2],
        [1262,  297, 1266,    2],
        [1481, 1375, 1483,    2],
        [1753, 1756, 1757,    2],
        [ 508,  456,  509,    2],
        [  79,  216, 1089,    2],
        [ 389,  197,  265,    2],
        [ 977,  889,  978,    2],
        [1239,  824, 1244,    2],
        [ 580,  201, 1697,    2],
        [ 605,  608,  609,    2],
    

Targets:
 tensor([[1416,  771, 1533,    2],
        [1447, 1383, 1448,    2],
        [ 311,  440,  308,    2],
        [1085, 1087, 1095,    2],
        [1412, 1383, 1423,    2],
        [1696,  440,  580,    2],
        [1443, 1420, 1447,    2],
        [ 633,  614,  637,    2],
        [1488,  971, 1517,    2],
        [1635,  201, 1638,    2],
        [1635, 1562, 1637,    2],
        [ 531,  197,  525,    2],
        [ 664,  750,  755,    2],
        [1804, 1739, 1806,    2],
        [ 513,  299,  518,    2],
        [ 103,  110,  111,    2],
        [ 531,   33, 1291,    2],
        [1481, 1407, 1484,    2],
        [1359, 1355, 1365,    2],
        [ 122,    6,  126,    2],
        [ 804,  750,  807,    2],
        [1222,  297, 1231,    2],
        [ 442,  324,  446,    2],
        [ 835,  841,  842,    2],
        [1140,  201,  825,    2],
        [1654,  287, 1572,    2],
        [1361,  197,  265,    2],
        [ 656,  626,  657,    2],
        [ 391,  316,  317,    2],
    

Targets:
 tensor([[1762, 1772, 1773,    2],
        [1026,  105, 1027,    2],
        [ 633,  619,   22,    2],
        [  73,  231,  240,    2],
        [ 888,  898,  899,    2],
        [ 888,  896,  897,    2],
        [1033,  105, 1034,    2],
        [1394, 1541, 1542,    2],
        [ 407,  297,  410,    2],
        [ 419,  302,  345,    2],
        [ 908,  904,  921,    2],
        [ 977,  904,  931,    2],
        [ 181,   12,  189,    2],
        [  84,    6,    8,    2],
        [ 154,   12,  161,    2],
        [1762, 1765, 1766,    2],
        [ 937,  771,  944,    2],
        [ 163, 1793, 1794,    2],
        [1664,  201, 1669,    2],
        [1635,  201, 1638,    2],
        [  49,    6,   50,    2],
        [ 398,  299,  402,    2],
        [ 477,  324,  481,    2],
        [ 962,  906,  965,    2],
        [ 664,  749,  758,    2],
        [ 993,  105,  994,    2],
        [ 954,  955,  956,    2],
        [ 291,  302,  304,    2],
        [1426,  254, 1531,    2],
    

Targets:
 tensor([[1639,    4,   79,    2],
        [ 308,  201,  311,    2],
        [ 761,  749,  763,    2],
        [   3,    4,    5,    2],
        [1055, 1087, 1097,    2],
        [  86,  220,  277,    2],
        [ 439,  440,  441,    2],
        [ 473,  309,  474,    2],
        [1609,    4,   79,    2],
        [1755,  865, 1810,    2],
        [  65,   16,   66,    2],
        [ 637,  798,  796,    2],
        [ 398,  299,  402,    2],
        [  78,    6,   80,    2],
        [  92,   12,  102,    2],
        [1191,    4, 1192,    2],
        [1631,  297, 1702,    2],
        [ 250,  201, 1280,    2],
        [ 696,  413,  699,    2],
        [ 680,  608,  609,    2],
        [1601, 1564, 1608,    2],
        [ 721,  705,  727,    2],
        [ 181,    6,  184,    2],
        [1437, 1355, 1441,    2],
        [1632,  201, 1701,    2],
        [1443, 1355, 1446,    2],
        [  84,    4,   86,    2],
        [ 654,  748,  651,    2],
        [1009, 1003, 1010,    2],
    

Targets:
 tensor([[ 378,   33,  557,    2],
        [ 331,  302,  338,    2],
        [ 213,  209,  215,    2],
        [ 839,  854,  855,    2],
        [ 419,  316,  317,    2],
        [ 172,   52,  176,    2],
        [ 654,  750,  831,    2],
        [ 107,   12,  109,    2],
        [  54,   58,   59,    2],
        [ 419,  302,  429,    2],
        [ 398,  297,  401,    2],
        [  21,   25,   26,    2],
        [ 398,  302,  404,    2],
        [ 339,  302,  346,    2],
        [ 473,  201,  476,    2],
        [1462, 1383, 1454,    2],
        [1461, 1417, 1455,    2],
        [ 107,    6,  112,    2],
        [  79,   33,  283,    2],
        [1598,  297,   79,    2],
        [ 374,  302,  380,    2],
        [ 844,  836,  845,    2],
        [ 367,  550,  551,    2],
        [ 872,  868,  873,    2],
        [ 908,  906,  924,    2],
        [ 331,  324,  336,    2],
        [ 122,    6,  130,    2],
        [ 452,  305,  461,    2],
        [   5,  209,  270,    2],
    

Targets:
 tensor([[1644,    4,   79,    2],
        [ 580,    4,   79,    2],
        [1158, 1163, 1164,    2],
        [ 761,  750,  696,    2],
        [ 313,  305,  329,    2],
        [1586,  440, 1581,    2],
        [ 374,  305,  385,    2],
        [1214,   33, 1309,    2],
        [1202, 1209, 1210,    2],
        [ 937,  938,  939,    2],
        [1015, 1022, 1023,    2],
        [1687,  201, 1690,    2],
        [1028,  105, 1030,    2],
        [  21,    6,   29,    2],
        [ 664,  748,  756,    2],
        [ 173,  254,  257,    2],
        [ 905,  927,  928,    2],
        [  79,  231,  282,    2],
        [ 718,  750,  779,    2],
        [ 663,  742,  799,    2],
        [1325,  201,  102,    2],
        [ 173,  254,  257,    2],
        [1395, 1375, 1397,    2],
        [ 400,   33,  562,    2],
        [  78,   52,   82,    2],
        [1093,  915, 1089,    2],
        [ 708,  297,  709,    2],
        [   3,    4,    5,    2],
        [1015, 1022, 1023,    2],
    

Targets:
 tensor([[1616, 1567, 1617,    2],
        [1688,    4,   79,    2],
        [1589,  201, 1594,    2],
        [ 711,  606,  713,    2],
        [1502, 1407, 1507,    2],
        [1345, 1355, 1356,    2],
        [ 874,  220,  277,    2],
        [ 715,  413,  719,    2],
        [  68,   42,   71,    2],
        [1681,    4,   79,    2],
        [1412, 1343, 1416,    2],
        [1042,  105, 1044,    2],
        [1093,  915, 1089,    2],
        [ 339,  297,  342,    2],
        [1437, 1383, 1442,    2],
        [ 867,  870,  871,    2],
        [1191,  201, 1193,    2],
        [1172,  201, 1292,    2],
        [ 513,  302,  519,    2],
        [ 148,    6,  116,    2],
        [ 172,    6,  176,    2],
        [ 538,  539,  470,    2],
        [ 835,  836,  837,    2],
        [ 888,  902,  903,    2],
        [ 937,  938,  939,    2],
        [1239,  824, 1244,    2],
        [1738, 1798, 1799,    2],
        [ 615,  752,  774,    2],
        [ 672,  608,  679,    2],
    

Targets:
 tensor([[1385, 1389, 1390,    2],
        [1621, 1560, 1627,    2],
        [ 977,  919,  983,    2],
        [ 908,  916,  907,    2],
        [  54,    4,   38,    2],
        [1202, 1108, 1205,    2],
        [ 701,  626,  702,    2],
        [1372, 1383, 1384,    2],
        [ 664,  748,  756,    2],
        [ 486,  299,  491,    2],
        [1239, 1245, 1246,    2],
        [ 486,  318,  488,    2],
        [  22,  209,  226,    2],
        [1654, 1633, 1656,    2],
        [1689,    4,   79,    2],
        [  92,    6,   97,    2],
        [1475, 1383, 1480,    2],
        [ 507,  209,  543,    2],
        [1432, 1343, 1434,    2],
        [ 163, 1793, 1794,    2],
        [1654,  287, 1598,    2],
        [ 685,  742,  801,    2],
        [1803, 1765, 1791,    2],
        [1803, 1765, 1791,    2],
        [1216, 1113, 1220,    2],
        [ 496,  302,  504,    2],
        [ 419,  305,  436,    2],
        [1605,    4,   79,    2],
        [  68,    6,   70,    2],
    

Targets:
 tensor([[  21,   30,   31,    2],
        [ 667,  608,  671,    2],
        [1052,  105, 1053,    2],
        [1664, 1591, 1667,    2],
        [ 502,  586,  587,    2],
        [ 172,    6,  175,    2],
        [1098,  915, 1099,    2],
        [1292,  197, 1329,    2],
        [ 291,  302,  303,    2],
        [ 804,  749,  806,    2],
        [1351, 1357, 1345,    2],
        [ 678,  752,  753,    2],
        [1192,  204, 1297,    2],
        [1395, 1379, 1398,    2],
        [1222,  861, 1229,    2],
        [1707,    4,   79,    2],
        [ 148,   30,  152,    2],
        [  84,    6,    8,    2],
        [1687,    4,   79,    2],
        [ 709,  865,  866,    2],
        [1589,  309, 1595,    2],
        [1248,  201,   79,    2],
        [1632,  201, 1701,    2],
        [ 804,  750,  810,    2],
        [ 633,  614,  637,    2],
        [ 605,  608,  609,    2],
        [1579,  220, 1664,    2],
        [1079,  915,   79,    2],
        [1181, 1120, 1187,    2],
    

Targets:
 tensor([[ 213,  197,  214,    2],
        [ 688,  614,  690,    2],
        [1681, 1560, 1684,    2],
        [1461, 1352, 1364,    2],
        [1598,  201, 1732,    2],
        [ 419,  302,  431,    2],
        [ 949,  197,  214,    2],
        [1253, 1254, 1255,    2],
        [ 452,  302,  433,    2],
        [ 358,  302,  369,    2],
        [ 637,  750,  795,    2],
        [ 134,  142,  143,    2],
        [1253, 1108, 1259,    2],
        [1762, 1769, 1770,    2],
        [ 672,  608,  609,    2],
        [ 688,  606,  691,    2],
        [1437, 1352, 1354,    2],
        [1621,  201,  949,    2],
        [1726,    4,   79,    2],
        [ 398,  292,  399,    2],
        [1128,  413, 1133,    2],
        [ 986,  105,  987,    2],
        [ 804,  750,  813,    2],
        [1644,  201, 1605,    2],
        [ 667,  612,  670,    2],
        [1762, 1785, 1786,    2],
        [ 277,  287, 1321,    2],
        [ 835,  824,  843,    2],
        [1394, 1541, 1542,    2],
    

Targets:
 tensor([[1459, 1350, 1530,    2],
        [1681, 1560, 1684,    2],
        [ 313,  318,  319,    2],
        [ 391,  305,  383,    2],
        [1762, 1775, 1776,    2],
        [ 162,    6,  165,    2],
        [1325,    4,   79,    2],
        [1784, 1811, 1812,    2],
        [1432, 1343, 1434,    2],
        [1461,    4, 1458,    2],
        [1677,  201,   79,    2],
        [ 937,  891,  942,    2],
        [1202,  297, 1210,    2],
        [1050,  915, 1089,    2],
        [ 358,  364,  365,    2],
        [ 547,  201,  102,    2],
        [ 452,  456,  457,    2],
        [ 835,  297,  839,    2],
        [  54,    4,   38,    2],
        [1786,  297, 1787,    2],
        [ 442,  302,  447,    2],
        [ 411,    4,   79,    2],
        [1664,    4,   79,    2],
        [ 559,    4,  560,    2],
        [ 114,    4,   38,    2],
        [1644,  201, 1645,    2],
        [1192,  197, 1298,    2],
        [ 353,  309,  529,    2],
        [1437, 1352, 1353,    2],
    

Targets:
 tensor([[  79,  197,  265,    2],
        [ 419,  302,  431,    2],
        [ 925,  821,  926,    2],
        [1481, 1383, 1488,    2],
        [1331,  254, 1332,    2],
        [1312,  771,  949,    2],
        [1214,   33, 1308,    2],
        [1502, 1511, 1512,    2],
        [1526, 1357, 1527,    2],
        [1192,  197, 1298,    2],
        [1635,  201, 1641,    2],
        [ 496,  305,  505,    2],
        [ 130,  194,  235,    2],
        [ 678,  750,  751,    2],
        [ 984,   55,  985,    2],
        [ 625,  619,  629,    2],
        [ 725,  297,  759,    2],
        [ 538,  539,  470,    2],
        [ 580, 1564, 1698,    2],
        [ 419,  302,  433,    2],
        [1031, 1003, 1032,    2],
        [ 477,  305,  485,    2],
        [1165, 1163, 1170,    2],
        [ 711,  626,  712,    2],
        [ 442,  422,  423,    2],
        [1165, 1167, 1168,    2],
        [1664, 1591, 1667,    2],
        [ 486,  299,  491,    2],
        [1635, 1564, 1643,    2],
    

Targets:
 tensor([[1598,  297,   79,    2],
        [ 518,  572,  102,    2],
        [ 911,  967,  968,    2],
        [1726,    4,   79,    2],
        [ 486,  305,  495,    2],
        [1026,  105, 1027,    2],
        [1175, 1113, 1114,    2],
        [   7,   18,   20,    2],
        [1345, 1357, 1358,    2],
        [1755,  209, 1808,    2],
        [1786,  297, 1787,    2],
        [ 308, 1558, 1559,    2],
        [ 400,   33,  562,    2],
        [ 908,  896,  912,    2],
        [ 486,  299,  491,    2],
        [ 844,  846,  847,    2],
        [1803, 1765, 1791,    2],
        [ 308, 1558, 1559,    2],
        [1116,   33, 1279,    2],
        [1489, 1547, 1548,    2],
        [ 277,  287, 1294,    2],
        [ 696,  608,  609,    2],
        [ 173,  231,  253,    2],
        [ 686,  748,  680,    2],
        [1689,    4,   79,    2],
        [1579,  220, 1664,    2],
        [ 639,  614,  643,    2],
        [ 908,  771,  911,    2],
        [1269,   33, 1323,    2],
    

Targets:
 tensor([[  95,   33,  208,    2],
        [ 643,  750,  785,    2],
        [1214,  231, 1306,    2],
        [1214,  231, 1307,    2],
        [1635,  201, 1641,    2],
        [1644,  201, 1645,    2],
        [1055, 1087, 1095,    2],
        [1202, 1203, 1204,    2],
        [ 908,  922,  923,    2],
        [ 419,  302,  432,    2],
        [1443, 1352, 1353,    2],
        [ 341,  201,  386,    2],
        [ 344,  536,  342,    2],
        [ 181,   12,  189,    2],
        [  79,  287,  288,    2],
        [1475, 1373, 1476,    2],
        [ 888,  891,  892,    2],
        [ 250,   33, 1282,    2],
        [ 643,  750,  784,    2],
        [ 997,  105,  999,    2],
        [ 605,  606,  607,    2],
        [ 863,  204,  885,    2],
        [ 688,  606,  691,    2],
        [ 310,    4,   79,    2],
        [1394,    4,   79,    2],
        [ 374,  299,  379,    2],
        [1639,    4,   79,    2],
        [1581, 1558, 1583,    2],
        [ 937,  896,  945,    2],
    

Targets:
 tensor([[1134,  297, 1139,    2],
        [1502, 1375, 1506,    2],
        [ 172,   52,  178,    2],
        [ 358,  297,  363,    2],
        [1069,  105, 1070,    2],
        [ 531,  880, 1790,    2],
        [ 308, 1560, 1561,    2],
        [1677, 1560, 1678,    2],
        [ 863,  209,  887,    2],
        [ 107,   12,  109,    2],
        [1134, 1135, 1136,    2],
        [ 313,  302,  326,    2],
        [1416, 1536, 1537,    2],
        [1736,  705, 1741,    2],
        [ 378,  204,  554,    2],
        [ 867,  868,  869,    2],
        [1005,   55, 1006,    2],
        [1664,  201, 1579,    2],
        [ 911,  967,  968,    2],
        [ 468,  299,  469,    2],
        [1632,  297, 1702,    2],
        [1475, 1379, 1479,    2],
        [1488,  975, 1520,    2],
        [  22,   33,  227,    2],
        [1116,   33, 1276,    2],
        [ 977,  915,   79,    2],
        [ 392,  566,  567,    2],
        [ 358,  359,  360,    2],
        [  78,   52,   82,    2],
    

Targets:
 tensor([[ 531,   33,  535,    2],
        [1654, 1622, 1624,    2],
        [1501, 1343, 1496,    2],
        [1066,  940, 1097,    2],
        [1601, 1564, 1608,    2],
        [ 452,  305,  461,    2],
        [ 531,  209, 1291,    2],
        [1210,   33, 1339,    2],
        [  79,  231,  282,    2],
        [  79,  209,  284,    2],
        [1738,  254, 1800,    2],
        [ 419,  297,  427,    2],
        [1607,  201, 1606,    2],
        [ 312,  216,  582,    2],
        [1345,  197,  265,    2],
        [1345, 1352, 1354,    2],
        [1041,  771,   86,    2],
        [  54,    4,   38,    2],
        [ 468,  302,  472,    2],
        [1210,  204, 1337,    2],
        [ 835,  297,  840,    2],
        [1032, 1087, 1088,    2],
        [  45,   25,   46,    2],
        [ 442,  302,  447,    2],
        [1635,  309, 1605,    2],
        [ 162,   12,  167,    2],
        [1262, 1108, 1265,    2],
        [ 122,   12,   35,    2],
        [1140,  309, 1284,    2],
    

Targets:
 tensor([[ 103,   12,  109,    2],
        [1681, 1564, 1686,    2],
        [ 172,   52,  176,    2],
        [ 908,  904,  921,    2],
        [ 462,  463,  464,    2],
        [  54,    4,   38,    2],
        [ 701,  626,  702,    2],
        [1581, 1564, 1588,    2],
        [ 734,  612,  736,    2],
        [  79,  216, 1089,    2],
        [ 688,  612,  685,    2],
        [ 507,   33,  543,    2],
        [ 962,  889,  963,    2],
        [  79,  231,  280,    2],
        [ 510,  201,  537,    2],
        [ 695,  749,  769,    2],
        [ 358,  364,  365,    2],
        [ 507,   33,  542,    2],
        [ 888,  902,  903,    2],
        [ 507,   33,  542,    2],
        [ 334,    4,   79,    2],
        [1361, 1350, 1359,    2],
        [1753,  297, 1759,    2],
        [1670,    4,   79,    2],
        [ 310,    4,   79,    2],
        [1214,   33, 1308,    2],
        [1026,  105, 1027,    2],
        [ 908,  922,  923,    2],
        [ 701,  705,  706,    2],
    

Targets:
 tensor([[ 703,  413,  699,    2],
        [1609, 1564, 1615,    2],
        [1214,   33, 1308,    2],
        [ 860,  297,  863,    2],
        [  65,   16,   66,    2],
        [1621,  287, 1632,    2],
        [ 312,  216,  581,    2],
        [ 977,  919,  983,    2],
        [1635,  201, 1606,    2],
        [ 721,  608,  730,    2],
        [ 911,  967,  968,    2],
        [1359, 1357, 1361,    2],
        [1461, 1355, 1492,    2],
        [ 374,  299,  379,    2],
        [ 374,  297,  377,    2],
        [ 398,  299,  402,    2],
        [ 400,   33,  563,    2],
        [1385, 1389, 1391,    2],
        [  78,    6,   81,    2],
        [ 173,  254,  257,    2],
        [ 711,  608,  609,    2],
        [1412, 1383, 1423,    2],
        [1677,  201,   79,    2],
        [   3,    6,   10,    2],
        [1222, 1112, 1227,    2],
        [ 458,  309,  604,    2],
        [ 172,   12,  180,    2],
        [1654, 1633, 1656,    2],
        [1384,   33, 1515,    2],
    

Targets:
 tensor([[1066,  915, 1089,    2],
        [ 262,   75,  145,    2],
        [ 181,   52,  188,    2],
        [1762, 1783, 1784,    2],
        [ 363,   33,  603,    2],
        [1616, 1564, 1620,    2],
        [ 308,  201,  311,    2],
        [  68,   42,   71,    2],
        [ 680,  608,  609,    2],
        [1253, 1256, 1257,    2],
        [ 931,  933,  932,    2],
        [1640,    4,   79,    2],
        [ 688,  608,  632,    2],
        [1424, 1357, 1430,    2],
        [ 162,   12,  167,    2],
        [1157,    4,   79,    2],
        [ 531,  880, 1790,    2],
        [ 962,  915,   79,    2],
        [ 407,  302,  415,    2],
        [1589,  201, 1594,    2],
        [ 398,  305,  405,    2],
        [1116,   33,  746,    2],
        [1644,  201, 1645,    2],
        [ 502,  585,  496,    2],
        [1222,  297, 1230,    2],
        [1005,   55, 1006,    2],
        [ 477,  292,  473,    2],
        [1345, 1355, 1356,    2],
        [  79,   33,  285,    2],
    

Targets:
 tensor([[1791, 1705, 1792,    2],
        [  79,  209,  284,    2],
        [1416, 1534, 1535,    2],
        [1253, 1254, 1255,    2],
        [1005,   55, 1006,    2],
        [ 949,  952,  953,    2],
        [ 486,  305,  493,    2],
        [1762, 1779, 1780,    2],
        [1677, 1564, 1680,    2],
        [ 167,  858, 1796,    2],
        [   7,   18,   19,    2],
        [ 353,  309,  529,    2],
        [ 486,  305,  493,    2],
        [1664,  201, 1579,    2],
        [  37,   23,   39,    2],
        [1437,  197,  265,    2],
        [ 491,  297,  526,    2],
        [ 419,  302,  429,    2],
        [1475, 1379, 1479,    2],
        [ 623,  750,  794,    2],
        [  92,    6,   98,    2],
        [1192,  204, 1297,    2],
        [1101,   55, 1103,    2],
        [ 122,    6,  131,    2],
        [1762, 1775, 1776,    2],
        [1084, 1003, 1085,    2],
        [1481, 1379, 1485,    2],
        [ 313,  320,  321,    2],
        [ 513,  299,  518,    2],
    

Targets:
 tensor([[ 358,  320,  361,    2],
        [1718,  209, 1719,    2],
        [ 291,  299,  300,    2],
        [  22,  220,   48,    2],
        [ 701,  705,  706,    2],
        [ 656,  608,  666,    2],
        [1502, 1407, 1507,    2],
        [ 618,  619,  621,    2],
        [ 344,  536,  342,    2],
        [1140,  309, 1284,    2],
        [   3,    6,    7,    2],
        [ 664,  750,  755,    2],
        [  78,    6,   80,    2],
        [1046, 1038, 1062,    2],
        [1142, 1120, 1144,    2],
        [1325,    4,   79,    2],
        [ 844,  413,  851,    2],
        [   3,    6,   11,    2],
        [1181, 1120, 1187,    2],
        [ 452,  302,  433,    2],
        [ 814,  752,  816,    2],
        [1069,  105, 1072,    2],
        [1467, 1383, 1473,    2],
        [ 339,  297,  343,    2],
        [1147,  297, 1152,    2],
        [ 392,  570,  571,    2],
        [1175, 1135, 1176,    2],
        [  21,   25,    9,    2],
        [ 513,  521,  522,    2],
    

Targets:
 tensor([[ 908,  904,  921,    2],
        [ 358,  320,  361,    2],
        [ 696,  413,  699,    2],
        [1664, 1560, 1668,    2],
        [ 419,  305,  434,    2],
        [1681, 1564, 1686,    2],
        [1677, 1564, 1680,    2],
        [ 374,  350,  376,    2],
        [1736, 1742, 1743,    2],
        [ 181,    6,  184,    2],
        [ 339,  295,  340,    2],
        [1738,  254, 1800,    2],
        [ 654,  749,  832,    2],
        [1762, 1777, 1778,    2],
        [ 623,  750,  790,    2],
        [1359,  197,  265,    2],
        [ 163,  209,  218,    2],
        [1117,  297, 1127,    2],
        [1082,   55, 1083,    2],
        [1677,  309, 1679,    2],
        [ 114,    4,   38,    2],
        [1292,  197, 1329,    2],
        [ 715,  619,  717,    2],
        [ 977,  904,  931,    2],
        [ 312,  216,  581,    2],
        [ 715,  413,  719,    2],
        [ 172,   52,  179,    2],
        [ 692,  612,  694,    2],
        [ 398,  297,  400,    2],
    

Targets:
 tensor([[1372, 1373, 1374,    2],
        [ 577,   33,  578,    2],
        [ 874,  254,  877,    2],
        [1165,  297, 1171,    2],
        [ 844,  413,  851,    2],
        [1601,  309, 1605,    2],
        [ 411,  201,   79,    2],
        [ 695,  750,  768,    2],
        [ 502,  586,  587,    2],
        [  37,   23,   39,    2],
        [ 114,    6,  119,    2],
        [ 452,  456,  457,    2],
        [1359, 1352, 1364,    2],
        [ 547,  201,  102,    2],
        [ 419,  302,  345,    2],
        [ 507,  216,  541,    2],
        [ 289,   33,  290,    2],
        [ 835,  297,  839,    2],
        [1117, 1120, 1121,    2],
        [ 401,   33,  573,    2],
        [  73,  231,  240,    2],
        [1395, 1383, 1404,    2],
        [ 339,  302,  345,    2],
        [1043,  915, 1089,    2],
        [ 190,   33,  192,    2],
        [1180,  297, 1324,    2],
        [1803, 1765, 1791,    2],
        [ 867,  868,  869,    2],
        [ 977,  900,  901,    2],
    

Targets:
 tensor([[ 656,  612,  663,    2],
        [ 173,  231,  253,    2],
        [ 656,  619,  658,    2],
        [ 656,  619,  662,    2],
        [ 374,  320,  375,    2],
        [1222, 1110, 1226,    2],
        [1192,  197, 1298,    2],
        [1738, 1798, 1799,    2],
        [ 391,  302,  395,    2],
        [ 531,  880, 1790,    2],
        [ 401,   33,  574,    2],
        [1236, 1112, 1237,    2],
        [ 605,  606,  607,    2],
        [1621,  201,  949,    2],
        [1635,  201, 1640,    2],
        [ 349,  297,  353,    2],
        [1222, 1108, 1225,    2],
        [  79,  231,  280,    2],
        [ 181,    6,  186,    2],
        [ 468,  302,  472,    2],
        [ 672,  608,  609,    2],
        [ 468,  302,  472,    2],
        [ 863,   33,  887,    2],
        [  64,  254,  272,    2],
        [ 452,  456,  457,    2],
        [1621, 1622, 1623,    2],
        [1607,  201, 1605,    2],
        [ 508,  302,  433,    2],
        [1041,  771,   86,    2],
    

Targets:
 tensor([[1076,  105, 1077,    2],
        [ 701,  626,  702,    2],
        [1601,  201, 1607,    2],
        [1175, 1135, 1176,    2],
        [1345, 1352, 1353,    2],
        [1331,  231, 1333,    2],
        [ 680,  619,  683,    2],
        [ 725,  297,  759,    2],
        [ 615,  752,  774,    2],
        [ 374,  299,  379,    2],
        [1601, 1560, 1603,    2],
        [   5,  209,  270,    2],
        [1474,  572, 1546,    2],
        [1647,  209, 1652,    2],
        [ 411,    4,   79,    2],
        [ 611,  612,  613,    2],
        [ 639,  606,  644,    2],
        [1416, 1534, 1535,    2],
        [ 122,    6,  129,    2],
        [1426,  254, 1531,    2],
        [ 531,  880, 1790,    2],
        [ 656,  612,  663,    2],
        [ 358,  299,  366,    2],
        [ 651,  608,  655,    2],
        [ 949,  950,  951,    2],
        [1050,  915, 1089,    2],
        [1581,  201, 1587,    2],
        [ 148,   23,  149,    2],
        [1629,  201, 1722,    2],
    

Targets:
 tensor([[  64,  254,  272,    2],
        [ 908,  904,  921,    2],
        [ 190,   33,  191,    2],
        [ 997,  105,  999,    2],
        [ 908,  919,  920,    2],
        [ 835,  824,  843,    2],
        [  45,   12,   48,    2],
        [ 654,  750,  830,    2],
        [1107, 1113, 1114,    2],
        [1421, 1350, 1412,    2],
        [  92,    4,   79,    2],
        [1107, 1112,  981,    2],
        [ 531,  197,  532,    2],
        [ 507,  197,  765,    2],
        [1609,  309, 1614,    2],
        [ 888,  893,  894,    2],
        [ 122,    6,  126,    2],
        [ 508,  292,  506,    2],
        [ 452,  302,  459,    2],
        [ 633,  626,  634,    2],
        [ 378,  209,  556,    2],
        [1681, 1562, 1685,    2],
        [1449, 1373, 1450,    2],
        [ 122,   12,   35,    2],
        [ 378,   33,  557,    2],
        [1153, 1130, 1156,    2],
        [1107,  297, 1115,    2],
        [ 312,  220,  580,    2],
        [ 692,  614,  695,    2],
    

Targets:
 tensor([[ 291,  302,  304,    2],
        [ 761,  752,  762,    2],
        [ 379,  589,  590,    2],
        [  79,  220,  279,    2],
        [ 984,   55,  985,    2],
        [ 162,  168,  169,    2],
        [   3,    6,    8,    2],
        [ 835,  841,  842,    2],
        [ 419,  424,  425,    2],
        [1412, 1343, 1416,    2],
        [1635,  201, 1638,    2],
        [1009,  105, 1011,    2],
        [1718,  209, 1719,    2],
        [ 335,  309,  547,    2],
        [ 997,  105,  999,    2],
        [1202, 1108, 1205,    2],
        [ 173,  231,  253,    2],
        [1697,  440, 1735,    2],
        [1736, 1739, 1740,    2],
        [1481, 1383, 1487,    2],
        [1443,    4,   79,    2],
        [1384,  564, 1325,    2],
        [1575,  309, 1579,    2],
        [1454, 1547, 1549,    2],
        [1066,  940, 1097,    2],
        [ 615,  750,  611,    2],
        [1140,    4,   79,    2],
        [ 363,   33,  603,    2],
        [1513,    4,   79,    2],
    

Targets:
 tensor([[ 844,  824,  853,    2],
        [  79,  231,  280,    2],
        [   3,    6,    7,    2],
        [ 358,  295,  362,    2],
        [ 419,  422,  423,    2],
        [1598,  201, 1732,    2],
        [1215,    4, 1214,    2],
        [1077,  915, 1090,    2],
        [ 618,  614,  623,    2],
        [ 804,  750,  813,    2],
        [  79,   33,  284,    2],
        [ 715,  413,  719,    2],
        [ 908,  898,  899,    2],
        [  73,   33,  244,    2],
        [ 517,  309,  537,    2],
        [ 331,  314,  332,    2],
        [ 173,  231,  258,    2],
        [ 196,    4,    5,    2],
        [1427, 1357, 1424,    2],
        [1359, 1355, 1365,    2],
        [ 442,  324,  446,    2],
        [ 835,  297,  839,    2],
        [ 637,  797,   22,    2],
        [1449, 1379, 1452,    2],
        [1140,   33, 1285,    2],
        [1598,  201, 1732,    2],
        [1345, 1350, 1351,    2],
        [1681, 1567, 1682,    2],
        [ 667,  612,  670,    2],
    

Targets:
 tensor([[1523, 1524, 1525,    2],
        [1647,  287, 1653,    2],
        [ 580,  201, 1696,    2],
        [1359,    4,   79,    2],
        [ 507,   33,  543,    2],
        [1248,    4,   79,    2],
        [1202, 1108, 1205,    2],
        [1424,    4, 1426,    2],
        [ 615,  748,  777,    2],
        [1331,   33, 1335,    2],
        [ 664,  750,  755,    2],
        [1474,  564, 1545,    2],
        [1581,  201, 1587,    2],
        [1026,  105, 1027,    2],
        [  21,   12,   32,    2],
        [ 701,  608,  609,    2],
        [  68,    6,   70,    2],
        [ 633,  608,  609,    2],
        [ 167,  858, 1796,    2],
        [ 190,   33,  192,    2],
        [ 442,  302,  449,    2],
        [ 531,  197,  532,    2],
        [1762,    4, 1768,    2],
        [ 867,  868,  869,    2],
        [1244,   33, 1296,    2],
        [ 298,    4,  531,    2],
        [1426,  440, 1172,    2],
        [ 358,  297,  163,    2],
        [ 117,   16,  193,    2],
    

Targets:
 tensor([[1744, 1745, 1746,    2],
        [1749,   33, 1797,    2],
        [ 908,  891,  910,    2],
        [ 847,  882,  883,    2],
        [ 991,  105,  992,    2],
        [1165,  297, 1171,    2],
        [ 937,  896,  945,    2],
        [ 962,  889,  963,    2],
        [ 312,  216,  581,    2],
        [ 148,    6,  151,    2],
        [ 419,  302,  430,    2],
        [  21,   30,   31,    2],
        [1443, 1346, 1444,    2],
        [1607,  201, 1606,    2],
        [1647,  201, 1651,    2],
        [ 496,  456,  500,    2],
        [ 374,  305,  385,    2],
        [ 419,  302,  429,    2],
        [  92,    4,   94,    2],
        [ 144,    4,   38,    2],
        [1738,  287, 1802,    2],
        [  54,   58,   60,    2],
        [1331,  216, 1331,    2],
        [1664, 1564, 1672,    2],
        [1502, 1511, 1512,    2],
        [ 458,    4,  560,    2],
        [ 367,  550,  551,    2],
        [1432, 1343, 1434,    2],
        [ 349,  305,  357,    2],
    

Targets:
 tensor([[1475, 1375, 1477,    2],
        [1395, 1381, 1399,    2],
        [1142, 1135, 1143,    2],
        [ 517,    4,   79,    2],
        [1804, 1739, 1806,    2],
        [1210,  204, 1337,    2],
        [1214,   33, 1308,    2],
        [1654, 1657, 1658,    2],
        [1202,  297, 1211,    2],
        [ 379,  597,  598,    2],
        [ 452,  314,  453,    2],
        [1755,  209, 1808,    2],
        [ 963,  971,  972,    2],
        [1687, 1564, 1691,    2],
        [ 538,  539,  540,    2],
        [1449, 1373, 1450,    2],
        [1194, 1110, 1195,    2],
        [  79,  287,  288,    2],
        [ 547,    4,   79,    2],
        [1005,   55, 1006,    2],
        [ 615,  750,  611,    2],
        [1631,  201, 1701,    2],
        [1175,  413, 1178,    2],
        [ 331,  302,  338,    2],
        [   3,   12,   13,    2],
        [ 508,  456,  509,    2],
        [ 167,  858, 1796,    2],
        [ 692,  608,  609,    2],
        [ 513,  295,  516,    2],
    

Targets:
 tensor([[ 908,  896,  912,    2],
        [ 937,  938,  939,    2],
        [ 398,  297,  401,    2],
        [  95,  197,  206,    2],
        [  22,  220,   48,    2],
        [   5,  254,  268,    2],
        [1181,  413, 1190,    2],
        [1086, 1087, 1096,    2],
        [  64,  254,  272,    2],
        [1372, 1375, 1376,    2],
        [1762, 1781, 1782,    2],
        [1009, 1003, 1010,    2],
        [1165, 1163, 1170,    2],
        [ 645,  606,  650,    2],
        [1481, 1383, 1488,    2],
        [ 391,  305,  397,    2],
        [ 908,  919,  920,    2],
        [ 963,  973,  974,    2],
        [ 804,  752,  805,    2],
        [  22,  209,  226,    2],
        [1575, 1560, 1576,    2],
        [  86,   33,  278,    2],
        [1395, 1381, 1399,    2],
        [  21,    4,   22,    2],
        [ 654,  750,  831,    2],
        [ 664,  748,  756,    2],
        [1202, 1108, 1206,    2],
        [1647, 1560, 1649,    2],
        [1175, 1113, 1114,    2],
    

Targets:
 tensor([[ 477,  297,  480,    2],
        [ 477,  297,  480,    2],
        [ 341,  201,  386,    2],
        [ 615,  752,  774,    2],
        [1726,    4,   79,    2],
        [1158, 1163, 1164,    2],
        [  79,   33,  283,    2],
        [1292,  197, 1329,    2],
        [ 398,  297,  400,    2],
        [1412, 1346, 1413,    2],
        [1214,  231, 1307,    2],
        [ 378,  209,  556,    2],
        [1601, 1560, 1603,    2],
        [ 577,   33,  578,    2],
        [1687,  287, 1572,    2],
        [ 477,  322,  479,    2],
        [1359, 1352, 1364,    2],
        [ 391,  316,  317,    2],
        [1762, 1783, 1784,    2],
        [1467, 1377, 1469,    2],
        [ 398,  302,  403,    2],
        [1753,  333, 1754,    2],
        [ 963,  973,  974,    2],
        [1043,  915, 1089,    2],
        [ 844,  850,  849,    2],
        [ 686,  749,  788,    2],
        [ 690,  749,  801,    2],
        [ 997,  105,  999,    2],
        [ 331,  333,  334,    2],
    

KeyboardInterrupt: 

# Test
### Used exclusively for evaluation on test data after training is fuly finished

## Evaluation - Bleu Score

In [None]:
# Function for calculating the BLEU score for multiple sentence.
def calculate_bleu(data, train, dev, model, max_len = 7):
    
    trgs = []
    pred_trgs = []       
    src = dev
    trg = test
    # Get the data and feed it into pred_trg after Seq2seq
    #pred_trg = pred_trg[:-1]      
    #pred_trgs.append(pred_trg)
    #trgs.append([trg])
        
    return corpus_bleu(pred_trgs, trgs)