# Implementation of the Transformer architecture

### This is the notebook used to train and test the SPARQL NMT Transformer model for the Knowledge Base-aware SPARQL Query Translation from Natural Language article


Here are some interesting references that helped us in our implementation:
- https://github.com/bentrevett/pytorch-seq2seq
- https://huggingface.co/spaces/gradio/HuBERT/blob/main/fairseq/models/transformer.py
- https://github.com/cestwc/pointer-transformer-model-pytorch

## Setup

Please note that using [wandb](https://wandb.ai/site) is not required, but it is recommended, as it provides a great way to track model performance during training. Install the package and set the constant USE_WANDBAI to True if you wish to use it!

In [None]:
# Install the runtime dependencies (notebook shell magics, not plain Python).
# torchtext is pinned to 0.11.0 because the code below imports `torchtext.legacy`,
# which later torchtext releases no longer ship. wandb is optional (see USE_WANDBAI).
!pip install transformers
!pip install --upgrade spacy
!python -m spacy download en_core_web_sm
!pip install SPARQLWrapper
!pip install torchtext==0.11.0
!pip install wandb

In [None]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import torch
from torch import tensor
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F

import torchtext
from torchtext.legacy.data import Field, BucketIterator
from torchtext.legacy.data.dataset import TabularDataset

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

import spacy
import numpy as np

import random
import math
import time
import unicodedata
import re

import pandas as pd
from sklearn.model_selection import train_test_split
from torchtext.data.metrics import bleu_score
from transformers import BertModel, AutoModel
from torchtext.data.utils import ngrams_iterator

from transformers import EncoderDecoderModel, AutoTokenizer, BertTokenizer

import json
from nltk import ngrams

from SPARQLWrapper import SPARQLWrapper, JSON
from collections import Counter, defaultdict
from google.colab import files, drive
from typing import List, Dict, Tuple, DefaultDict, Union, Optional
import copy

In [None]:
# Mount Google Drive into this Colab runtime at /content/gdrive
# (requires Google Colab: `drive` comes from `google.colab`).
drive.mount('/content/gdrive')

## Consts

In [None]:
# Model hyper-parameters
BATCH_SIZE = 128
HID_DIM    = 1024
# HID_DIM must be divisible by ENC_HEADS and DEC_HEADS (asserted in MultiHeadAttentionLayer).
ENC_LAYERS = 6
DEC_LAYERS = 6
ENC_HEADS  = 4
DEC_HEADS  = 4
# NOTE(review): the copy layer reads decoder attention head 3, so keep DEC_HEADS >= 4 when USE_COPY is on.
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.5
DEC_DROPOUT = 0.5
LEARNING_RATE = 0.0005
CLIP = 1
# CLIP: presumably passed to train() as `clip`, the max gradient norm for clip_grad_norm_.
MAX_LENGTH = 50          # Decoding budget: the translate functions generate at most MAX_LENGTH - 2 tokens
                         # when <eos> is never predicted. Also the default size of the positional
                         # embeddings in Encoder/Decoder, so longer sequences cannot be embedded.
N_EPOCHS = 150            # Number of training epochs

DATASET = 'dataset.json' # Name of the dataset to use, found in the Data repo under out_data
USE_WANDBAI = False       # Log training with wandb (optional, see Setup)
USE_COPY = True          # Train/predict with the copy layer (CopyLayerVocabExtend)

RANDOM = False           # Randomize the order of the entries
LOWERCASE = False        # Force lowercase for queries and questions -
                         # not recommended: it makes it very hard to go back to working SPARQL queries

# Token prefixes that mark a token as a knowledge-base (KB) element; used by
# hide_KB_elems, extract_KB_elems, abstract_KB_elems and calculate_bleu_syntax.
TAGS = ['dbr:', 'dbo:', 'dbp:', 'dbc:', 'dct:', 'geo:', 'georss:']

# Alias used in the type hints of the training/metric functions.
DatasetIterator = torchtext.legacy.data.Dataset

# Run identifiers; presumably used to name outputs/checkpoints elsewhere in the notebook.
MODEL_TYPE = "transformer"
COPY_FLAG = "copy" if USE_COPY else "no_copy"
DATASET_FAMILY = "Monument"
DATASET_NAME = "mon_base_tagged_all_no_resources" # Don't forget to set this before each run

## Utils

### Metrics

In [None]:
# Translate a question into a sparql query
def translate_sentence(tokens: Union[str, List[str]],
                       src_field: Field,
                       trg_field: Field,
                       model: nn.Module,
                       device: torch.device,
                       max_len=MAX_LENGTH,
                       predict_with_copy=USE_COPY,
                       verbose: bool = False) -> Tuple[List[str], Optional[torch.Tensor]]:
    """Greedily translate one question into a list of SPARQL tokens.

    Args:
        tokens: the question, either as a raw string (tokenized with
            ``tokenize_en``) or as an already tokenized list of strings.
        src_field: field holding the question vocabulary.
        trg_field: field holding the query vocabulary.
            When ``predict_with_copy`` is true, BOTH fields' vocabularies are
            extended IN PLACE with the KB elements found in ``tokens``.
        model: seq2seq model exposing ``encoder``, ``decoder``,
            ``make_src_mask``, ``make_trg_mask`` and, with copy, ``copy_layer``.
        device: device the input ids are moved to.
        max_len: decoding budget; at most ``max_len - 2`` tokens are generated.
        predict_with_copy: route decoder output through ``model.copy_layer``.
        verbose: print the step-by-step debugging trace. It used to be printed
            unconditionally, flooding the output of every metric loop.

    Returns:
        The predicted tokens without the leading ``<sos>``. A trailing
        ``<eos>`` is kept when it was generated. Ids outside the target vocab
        become ``'<unk>'``. The second item is the attention returned by the
        last decoding step, or ``None`` if no step ran (``max_len <= 2``).
    """
    log = print if verbose else (lambda *args, **kwargs: None)

    model.eval()

    # format as a list of strs
    if isinstance(tokens, str):
        tokens = tokenize_en(tokens)
    if LOWERCASE:
        tokens = [token.lower() for token in tokens]
    log('1-TOKENS:', tokens)

    # extend vocab with KB elems so that they get their own (copyable) ids
    if predict_with_copy:
        resources_to_extend = extract_KB_elems(tokens)
        log('2-RES TO EXTEND:', resources_to_extend)
        KB_vocab = VocabDup(resources_to_extend, padding=0)
        log('3-KB VOCAB:', KB_vocab.itos)

        log('SRC FIELD BEFORE:', len(src_field.vocab))
        log('TRG FIELD BEFORE:', len(trg_field.vocab))

        src_field = extend_vocabulary(src_field, KB_vocab)
        trg_field = extend_vocabulary(trg_field, KB_vocab)

        log('SRC FIELD AFTER:', len(src_field.vocab))
        log('TRG FIELD AFTER:', len(trg_field.vocab))

    # add <sos> and <eos> delimiters
    tokens = [src_field.init_token] + tokens + [src_field.eos_token]
    log('4-TOKENS+SOS+EOS:', tokens)

    # index sentence with extended src vocab
    src_indexes = [src_field.vocab.stoi[token] for token in tokens]
    log('5-SRC_INDEXES:', src_indexes)
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)

    src_mask = model.make_src_mask(src_tensor)

    # Ids beyond the embedding tables (extended KB ids) are fed to the
    # encoder/decoder as 0; only the copy layer sees the real ids.
    enc_vocab_size = model.encoder.tok_embedding.num_embeddings
    dec_vocab_size = model.decoder.tok_embedding.num_embeddings

    with torch.no_grad():
        log('... encoder')
        enc_src = model.encoder(src_tensor.masked_fill(src_tensor >= enc_vocab_size, 0), src_mask)

    eos_idx = trg_field.vocab.stoi[trg_field.eos_token]
    # init empty query sentence with only <sos> token
    trg_indexes = [trg_field.vocab.stoi[trg_field.init_token]]
    attention = None  # stays None when the loop below does not run
    for _ in range(max_len - 2):
        log('6-TRG INDEXES:', trg_indexes)
        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
        trg_mask = model.make_trg_mask(trg_tensor)

        with torch.no_grad():
            log('\t... decoder')
            output, attention = model.decoder(trg_tensor.masked_fill(trg_tensor >= dec_vocab_size, 0), enc_src, trg_mask, src_mask)
            if predict_with_copy:
                log('\t... copy')
                output, attention = model.copy_layer(src_tensor, output, attention)

        # greedy choice for the last position only
        pred_token = output.argmax(2)[:, -1].item()
        trg_indexes.append(pred_token)

        if pred_token == eos_idx:
            break

    log('7-FINAL TRG INDEXES:', trg_indexes)
    trg_vocab_size = len(trg_field.vocab)
    trg_tokens = [trg_field.vocab.itos[i] if i < trg_vocab_size else '<unk>' for i in trg_indexes]
    return trg_tokens[1:], attention

In [None]:
# Translate a batch of questions into sparql queries
def batch_translate(batch: torch.Tensor,
                    trg_field: Field,
                    model: nn.Module,
                    device: torch.device,
                    max_len=MAX_LENGTH,
                    predict_with_copy=USE_COPY) -> Tuple[List[List[str]], Optional[torch.Tensor]]:
    """Greedily decode every question of a torchtext batch.

    Args:
        batch: torchtext batch whose ``English`` attribute holds the
            numericalized questions (batch-first, as built by gen_data_field).
        trg_field: field holding the SPARQL vocabulary and special tokens.
        model: seq2seq model exposing encoder/decoder/make_*_mask/copy_layer.
        device: device on which the decoded ids are accumulated.
        max_len: decoding budget; at most ``max_len - 2`` tokens per query.
        predict_with_copy: route decoder output through ``model.copy_layer``.

    Returns:
        One token list per question. The leading ``<sos>`` is dropped, the
        list is cut before the first ``<eos>``, and ids outside the vocab
        become ``'<unk>'``. The second item is the attention of the last
        decoding step, or ``None`` if no step ran.
    """
    model.eval()

    src_tensor = batch.English
    src_mask = model.make_src_mask(src_tensor)

    with torch.no_grad():
        enc_src = model.encoder(src_tensor.masked_fill(src_tensor >= model.encoder.tok_embedding.num_embeddings, 0), src_mask)

    sos_idx = trg_field.vocab.stoi[trg_field.init_token]
    eos_idx = trg_field.vocab.stoi[trg_field.eos_token]
    dec_vocab_size = model.decoder.tok_embedding.num_embeddings

    # every query starts with <sos>, read from the field instead of hard-coding 2
    trg_tensor = torch.full((src_tensor.shape[0], 1), sos_idx, dtype=torch.long, device=device)
    finished = torch.zeros(src_tensor.shape[0], dtype=torch.bool, device=device)
    attention = None

    for _ in range(max_len - 2):
        trg_mask = model.make_trg_mask(trg_tensor)
        with torch.no_grad():
            output, attention = model.decoder(trg_tensor.masked_fill(trg_tensor >= dec_vocab_size, 0), enc_src, trg_mask, src_mask)
            if predict_with_copy:
                output, attention = model.copy_layer(src_tensor, output, attention)

        # greedy choice for the last position of every sentence
        pred_token = output[:, -1, :].argmax(1).unsqueeze(1)
        trg_tensor = torch.cat((trg_tensor, pred_token), dim=1)

        # everything after a sentence's first <eos> is discarded below, so
        # stop as soon as every sentence has produced one
        finished |= pred_token.squeeze(1) == eos_idx
        if bool(finished.all()):
            break

    # remove everything from the first <eos> on, then drop <sos>
    trg_vocab_size = len(trg_field.vocab)
    out = []
    for sent in trg_tensor.cpu().tolist():
        if eos_idx in sent:
            sent = sent[:sent.index(eos_idx)]
        out.append([trg_field.vocab.itos[i] if i < trg_vocab_size else '<unk>' for i in sent][1:])

    return out, attention

In [None]:
# Calculate the BLEU score of our test set, one batch at a time
def batch_bleu(iterator: DatasetIterator, 
               trg_field: Field, # use BASE_TRG for syntax!
               model: nn.Module,
               device: torch.device,
               use_copy=USE_COPY) -> float:
    """Corpus BLEU of greedy batch translations against the gold queries.

    Each batch is decoded with ``batch_translate``. The references are the
    batch's ``SPARQL`` ids with the first and last columns sliced off,
    turned into tokens by ``get_batch_tokens``. Each prediction is scored
    against its single reference.
    """
    predictions: List[List[str]] = []
    references: List[List[List[str]]] = []

    for batch in iterator:
        decoded, _ = batch_translate(batch, trg_field, model, device, predict_with_copy=use_copy)
        predictions += decoded
        gold = get_batch_tokens(batch.SPARQL[:, 1:-1], trg_field)
        references += [[sentence] for sentence in gold]

    return bleu_score(predictions, references)

In [None]:
# Convert a batch of token ids to a batch of tokens (by Samuel)
def get_batch_tokens(batch: torch.Tensor, field: "Field") -> List[List[str]]:
    """Convert a 2-D tensor of token ids (one sentence per row) into tokens.

    Each row is cut before its first ``<eos>`` id, if any. Ids outside
    ``field.vocab`` (e.g. extended KB ids) become ``'<unk>'``.

    Args:
        batch: 2-D tensor of ids.
        field: field whose ``vocab`` maps ids back to tokens.

    Returns:
        One list of token strings per row.
    """
    itos = field.vocab.itos
    vocab_size = len(field.vocab)
    eos_idx = field.vocab.stoi[field.eos_token]

    output_tokens = []
    for row in batch.tolist():
        if eos_idx in row:
            row = row[:row.index(eos_idx)]
        # strict '<': an id equal to vocab_size is already out of range
        # (the previous '<=' raised IndexError for it)
        output_tokens.append([itos[tok] if tok < vocab_size else '<unk>' for tok in row])
    return output_tokens

In [None]:
# Calculate the BLEU score of our test set
def calculate_bleu(data: DatasetIterator, 
                   src_field: Field, 
                   trg_field: Field, 
                   model: nn.Module, 
                   device: torch.device,
                   predict_with_copy=USE_COPY, 
                   max_len=MAX_LENGTH) -> float:
    """Corpus BLEU of sentence-by-sentence greedy translations.

    Each example's ``English`` tokens are translated with
    ``translate_sentence``. A trailing ``<eos>`` is dropped from the
    prediction before scoring it against the example's ``SPARQL`` tokens.
    With copy enabled, ``translate_sentence`` extends the fields'
    vocabularies in place.
    """
    print("Calculating BLEU score...")
    expected_trgs = []
    pred_trgs = []

    for datum in data:
        src = vars(datum)['English']
        trg = vars(datum)['SPARQL']

        pred_trg, _ = translate_sentence(src, src_field, trg_field, model, device, max_len, predict_with_copy=predict_with_copy)
        # compare by value (`is` on strings only works by interning accident)
        # and guard against an empty prediction
        if pred_trg and pred_trg[-1] == trg_field.eos_token:
            pred_trg = pred_trg[:-1]
        pred_trgs.append(pred_trg)
        expected_trgs.append([trg])

    return bleu_score(pred_trgs, expected_trgs)

In [None]:
# Calculate the BLEU score of our test set - when using copy we want syntax
def calculate_bleu_syntax(data: DatasetIterator, 
                          src_field: Field, 
                          trg_field: Field, 
                          model: nn.Module, 
                          device: torch.device, 
                          max_len=MAX_LENGTH) -> float:
    """BLEU on query *syntax*: KB elements are masked on both sides.

    Predictions come from ``translate_sentence`` with copy enabled, and a
    trailing ``<eos>`` is dropped. Every token starting with one of ``TAGS``
    is replaced by ``'<unk>'`` in both the prediction and the reference
    (via ``hide_KB_elems``), so only the structure of the query is scored.
    """
    print("Calculating BLEU score of syntax...")
    expected_syntax = []
    pred_trgs_syntax = []

    for datum in data:
        src = vars(datum)['English']
        trg = vars(datum)['SPARQL']
        pred_syntax, _ = translate_sentence(src, src_field, trg_field, model, device, max_len, predict_with_copy=True)
        if pred_syntax and pred_syntax[-1] == trg_field.eos_token:
            pred_syntax = pred_syntax[:-1]

        # Mask the prediction too. Otherwise a correctly copied KB element
        # (e.g. a 'dbr:' token) never matches the masked reference.
        pred_trgs_syntax.append(hide_KB_elems(pred_syntax))
        expected_syntax.append([hide_KB_elems(trg)])

    return bleu_score(pred_trgs_syntax, expected_syntax)

In [None]:
# calculates some metrics on the test set
def get_metrics(data: DatasetIterator,
                test_entries: List[Dict], 
                src_field: Field, 
                trg_field: Field, 
                model: nn.Module,
                device: torch.device, 
                max_len=MAX_LENGTH, 
                predict_with_copy=USE_COPY) -> Dict[str, float]:
    """Translate every example of ``data`` and compute evaluation metrics.

    ``test_entries[i]`` must describe the i-th example of ``data``; its
    ``_id`` and ``template_id`` go into a per-example error report written
    to ``out/error_report.json`` (the directory is created if needed).

    Returns:
        A dict with these keys:
        - 'bleu': corpus BLEU.
        - 'accuracy': exact-match rate.
        - 'macro recall', 'macro precision': compare the set of all n-grams
          (n = 1..length) of prediction and reference. The denominators count
          n-grams with repetition.
        - 'f1 score': computed from the two macro averages.
        A ratio with a zero denominator (e.g. an empty prediction) counts as 0.
    """
    def _safe_div(num: float, den: float) -> float:
        return num / den if den else 0.0

    print("Computing evaluation metrics...")
    expected_trgs = []
    pred_trgs = []
    error_report = []

    for i, datum in enumerate(data):
        src = vars(datum)['English']
        trg = vars(datum)['SPARQL']

        pred_trg, _ = translate_sentence(src, src_field, trg_field, model, device, max_len, predict_with_copy)

        # Drop the trailing <eos> only when one was generated. Previously the
        # last token was always dropped, truncating a real token whenever
        # decoding hit max_len.
        if pred_trg and pred_trg[-1] == trg_field.eos_token:
            pred_trg = pred_trg[:-1]

        pred_trgs.append(pred_trg)
        expected_trgs.append([trg])

        error_report.append({
            'id': test_entries[i]['_id'],
            'template_id': test_entries[i]['template_id'],
            'src': ' '.join(src),
            'trg': ' '.join(trg),
            'predicted': ' '.join(pred_trg),
            'correct': trg == pred_trg
        })

    metrics = {}
    nb_examples = len(expected_trgs)
    metrics['bleu'] = bleu_score(pred_trgs, expected_trgs)
    metrics['accuracy'] = sum(int(pred == exp[0]) for pred, exp in zip(pred_trgs, expected_trgs)) / nb_examples

    #https://towardsdatascience.com/the-ultimate-performance-metric-in-nlp-111df6c64460
    pred_ngrams = [list(ngrams_iterator(pred, len(pred))) for pred in pred_trgs]
    exp_ngrams = [list(ngrams_iterator(exp[0], len(exp[0]))) for exp in expected_trgs]
    n_commons = [len(set(p) & set(e)) for p, e in zip(pred_ngrams, exp_ngrams)]

    recalls = [_safe_div(c, len(e)) for c, e in zip(n_commons, exp_ngrams)]
    metrics['macro recall'] = sum(recalls) / len(recalls)

    precisions = [_safe_div(c, len(p)) for c, p in zip(n_commons, pred_ngrams)]
    metrics['macro precision'] = sum(precisions) / len(precisions)

    metrics['f1 score'] = _safe_div(2 * (metrics['macro precision'] * metrics['macro recall']),
                                    metrics['macro precision'] + metrics['macro recall'])

    os.makedirs('out', exist_ok=True)
    with open('out/error_report.json', 'w', encoding='utf-8') as f:
        json.dump(error_report, f, indent=4)

    return metrics

### Vocab

In [None]:
# imitation of a torchtext.vocab.Vocab, basic structure needed to extend a torchtext Vocab
class VocabDup:
  """Minimal stand-in for a torchtext Vocab: exposes ``itos``, ``stoi`` and ``freq``.

  It is built from one of two inputs:
  - a list of words (typically the KB elements used to extend a base vocab);
  - an ``{index: word}`` dict (typically read back by ``read_vocab``).

  ``stoi`` is a ``defaultdict(int)``, so unknown words map to index 0.

  Raises:
    ValueError: if ``vocab`` is neither a list nor a dict.
  """
  def __init__(self, vocab: Union[Dict[int, str], List[str]], padding=0, base_vocab_size=0):
    if isinstance(vocab, list):
      self.make_vocab_from_list(vocab, padding)

    elif isinstance(vocab, dict):
      self.make_vocab_from_dict(vocab, base_vocab_size)

    else:
      raise ValueError("Could not make a vocab from this structure")


  # Make vocab from a list (usually KB elem list) to use it to extend base vocabs
  def make_vocab_from_list(self, word_list: List[str], padding=0) -> None:
      """Index ``padding`` placeholder words first, then each distinct word of ``word_list``.

      Distinct words keep their first-occurrence order. The previous
      ``list(set(...))`` produced a hash-seed dependent order, so the same KB
      elements could get different indexes from one run to the next.
      """
      unique_words = list(dict.fromkeys(word_list))
      pad_words = [f'not_a_resource_{i}' for i in range(padding)]

      itos = pad_words + unique_words
      stoi = defaultdict(int)
      for idx, word in enumerate(itos):
          stoi[word] = idx

      self.freq = Counter(unique_words)  # every distinct word counted once
      self.itos = itos
      self.stoi = stoi

  # Make vocab from a dict (usually when loading the vocab files)
  def make_vocab_from_dict(self, word_dict: Dict[int, str], base_vocab_size: int=0) -> None:
      """Keep the entries of ``word_dict`` whose index is below ``base_vocab_size``.

      A ``base_vocab_size`` below 1 means "use the size of ``word_dict``".
      Indexes with no entry stay ``None`` in ``itos``.
      """
      stoi = defaultdict(int)
      base_vocab_size = len(word_dict.values()) if base_vocab_size < 1 else base_vocab_size
      itos = [None for _ in range(base_vocab_size)]

      for idx, word in word_dict.items():
          if idx < base_vocab_size:
              stoi[word] = idx
              itos[idx] = word

      word_counter = Counter(itos)

      self.freq = word_counter
      self.itos = itos
      self.stoi = stoi

In [None]:
# Mask KB elements of a tokenized sentence (by samuel)
def hide_KB_elems(tokens: List[str], unk_token = '<unk>') -> List[str]:
  """Return a new list in which every token starting with one of ``TAGS`` is
  replaced by ``unk_token``; other tokens are kept in place."""
  kb_prefixes = tuple(TAGS)
  return [unk_token if tok.startswith(kb_prefixes) else tok for tok in tokens]

In [None]:
# Extract KB elements from a tokenized sentence
def extract_KB_elems(tokens: List[str]) -> List[str]:
  """Return the tokens that start with one of ``TAGS``, in order and with duplicates."""
  kb_prefixes = tuple(TAGS)
  return list(filter(lambda tok: tok.startswith(kb_prefixes), tokens))

In [None]:
# Equivalent of torchtext's Vocab.extend; the torchtext version caused weird bugs here
# (the original guess was key collisions in the dicts).
def extend_vocabulary(field: "Field", extension: "VocabDup") -> "Field":
    """Append to ``field.vocab`` every word of ``extension.itos`` it does not contain yet.

    Each new word is appended to ``itos`` and mapped to its new index in
    ``stoi``. The field is modified in place and returned for convenience.

    Membership is checked against ``itos``, not ``stoi``. Presumably this is
    because ``stoi`` is a defaultdict: looking up an unknown word (e.g.
    ``stoi[token]`` while numericalizing) inserts a spurious entry, so
    ``w in stoi`` can be true for words that were never added.
    """
    itos = field.vocab.itos
    known = set(itos)  # O(1) membership instead of scanning itos for every word
    for w in extension.itos:
        if w not in known:
            itos.append(w)
            known.add(w)
            field.vocab.stoi[w] = len(itos) - 1

    return field

In [None]:
# A query can contain a KB elem that is in the KB vocab but not in the question
# (for example, LC-QuAD template ID 7). Such ids cannot be copied, so extended
# ids of the query that do not also appear in the question become unknown (0).
def fix_extended_vocab(src: List[List[int]], trg: List[List[int]], base_voc_limit_trg: int, unk_token = 0) -> List[List[int]]:
  """Replace, in place, extended target ids that are absent from their source row.

  Args:
    src: source ids, one row per example (2-D tensor or list of lists).
    trg: target ids with the same number of rows; modified in place.
    base_voc_limit_trg: ids >= this value are extended (KB) ids.
    unk_token: replacement id.

  Returns:
    ``trg`` (the same object).
  """
  if isinstance(src, torch.Tensor) and isinstance(trg, torch.Tensor):
    # Vectorized: compare every target id with every source id of its row.
    # The previous element-by-element loop forced one device sync per token.
    is_extended = trg >= base_voc_limit_trg
    in_src = (trg.unsqueeze(-1) == src.unsqueeze(-2)).any(dim=-1)
    trg[is_extended & ~in_src] = unk_token
    return trg

  for i_s, sentence in enumerate(trg):
    src_row = src[i_s]
    for i_t, token_idx in enumerate(sentence):
      if token_idx >= base_voc_limit_trg and token_idx not in src_row:
        sentence[i_t] = unk_token

  return trg

In [None]:
# Save vocab to reuse for inference
def save_vocab(vocab: "torchtext.vocab.Vocab", path: str) -> None:
    """Write ``vocab`` to ``path``, one ``index<TAB>token`` line per entry.

    Lines are written from ``itos`` in index order, so line ``i`` always holds
    the token with index ``i``; ``read_vocab`` numbers tokens by line
    position. Iterating ``stoi`` instead could emit spurious lines: it is a
    defaultdict that gains an entry, mapped to the unk index, for every
    unknown word ever looked up.
    """
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(f'{index}\t{token}\n' for index, token in enumerate(vocab.itos))

In [None]:
# Read vocab files written by save_vocab
def read_vocab(path: str) -> Dict[int, str]:
    """Load a vocab file made of ``index<TAB>token`` lines.

    The index stored in the file is ignored: tokens are numbered by their line
    position (0, 1, 2, ...). A line that does not contain exactly one tab
    raises ValueError.
    """
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.read().splitlines()

    voc = {}
    for position, line in enumerate(lines):
        _stored_index, token = line.split('\t')
        voc[position] = token
    return voc

In [None]:
# Split vocabularies: keep only the base words of questions and queries, and collect all removed KB elems in another list
def abstract_KB_elems(data) -> Tuple[Dict, Dict]:
  """Separate base-vocabulary tokens from KB elements.

  Each example of ``data`` must expose ``English`` and ``SPARQL`` token
  lists. A token is a KB element when it starts with one of ``TAGS``.

  Returns:
    base_vocabs: ``{'English': [...], 'SPARQL': [...]}`` with one filtered
      token list per example (KB elements removed), in example order.
    kb_vocabs: ``{'English': [...], 'SPARQL': [...]}`` with all removed KB
      elements concatenated into one flat list per side.
  """
  kb_prefixes = tuple(TAGS)
  sides = ('English', 'SPARQL')
  base_vocabs = {side: [] for side in sides}
  kb_vocabs = {side: [] for side in sides}

  for example in data:
    sentences = {'English': example.English, 'SPARQL': example.SPARQL}
    for side in sides:
      kept, removed = [], []
      for token in sentences[side]:
        (removed if token.startswith(kb_prefixes) else kept).append(token)
      base_vocabs[side].append(kept)   # stays separated by sentence
      kb_vocabs[side].extend(removed)  # one flat list of KB elems

  return base_vocabs, kb_vocabs

### Data

In [None]:
# Tokenize a question by splitting at whitespace
def tokenize_en(text: str) -> List[str]:
    """Split ``text`` on runs of whitespace; no empty tokens are produced."""
    return text.split()

In [None]:
# Tokenize a query by splitting at whitespace
def tokenize_sparql(text: str) -> List[str]:
    """Split a SPARQL string on runs of whitespace; no empty tokens are produced."""
    return text.split()

In [None]:
# Generate train, val and test sets
def gen_train_test_val_sets(train_examples: List[Tuple[str, str]], 
                            valid_examples: List[Tuple[str, str]], 
                            test_examples: List[Tuple[str, str]], 
                            data_fields: List[Tuple[str, Field]],
                            path: str = './') -> Tuple[TabularDataset, TabularDataset, TabularDataset]:
    """Write the three splits to CSV files and load them back as TabularDatasets.

    Args:
        train_examples, valid_examples, test_examples: rows with two columns
            each (question, query), written as the "English" / "SPARQL"
            columns.
        data_fields: ``(column name, Field)`` pairs given to TabularDataset.
        path: directory where train.csv / valid.csv / test.csv are written
            and read back (the current directory by default, as before).

    Returns:
        The train, validation and test datasets.
    """
    columns = ["English", "SPARQL"]
    splits = {'train.csv': train_examples, 'valid.csv': valid_examples, 'test.csv': test_examples}
    for file_name, examples in splits.items():
        # no index column and no header row
        pd.DataFrame(examples, columns=columns).to_csv(os.path.join(path, file_name), index=False, header=None)

    train_data, valid_data, test_data = torchtext.legacy.data.TabularDataset.splits(
        path=path, train='train.csv', validation='valid.csv', test='test.csv', format='csv', fields=data_fields)

    return train_data, valid_data, test_data

In [None]:
# Generate the data fields used to encode the question-query pairs
def gen_data_field() -> Tuple[Field, Field]:
    """Build the (SRC, TRG) fields.

    Both fields wrap sentences in ``<sos>``/``<eos>``, lowercase according to
    LOWERCASE and produce batch-first tensors. SRC tokenizes with
    ``tokenize_en`` and TRG with ``tokenize_sparql``.
    """
    common = dict(init_token='<sos>',
                  eos_token='<eos>',
                  lower=LOWERCASE,
                  batch_first=True)
    question_field = Field(tokenize=tokenize_en, **common)
    query_field = Field(tokenize=tokenize_sparql, **common)
    return question_field, query_field

### Training

In [None]:
# Initialize model weights
def initialize_weights(m) -> None:
    """Xavier-uniform initialize ``m.weight`` in place when it has 2+ dimensions.

    Presumably meant for ``model.apply(initialize_weights)``. Modules are left
    untouched when they:
    - have no ``weight``;
    - have a 1-D weight (e.g. LayerNorm);
    - register ``weight`` as ``None`` (e.g. ``LayerNorm(elementwise_affine=False)``).
      This last case used to raise AttributeError.
    """
    weight = getattr(m, 'weight', None)
    if weight is not None and weight.dim() > 1:
        nn.init.xavier_uniform_(weight.data)

In [None]:
# Count the trainable parameters of the model
def count_parameters(model: nn.Module) -> int:
    """Total number of scalar values in the parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

In [None]:
# Split an epoch duration into whole minutes and seconds
def epoch_time(start_time: float, end_time: float) -> Tuple[float, float]:
    """Return ``(minutes, seconds)`` elapsed between two timestamps, both truncated to ints."""
    duration = end_time - start_time
    minutes = int(duration / 60)
    return minutes, int(duration - 60 * minutes)

In [None]:
# Training function of the model
def train(model: nn.Module, 
          iterator: DatasetIterator, 
          optimizer: torch.optim.Optimizer, 
          criterion: nn.Module, clip: float, 
          use_copy=USE_COPY) -> float:
    """Run one training epoch with teacher forcing.

    Args:
        model: seq2seq model called as ``model(src, trg_input)``.
        iterator: yields batches with ``English`` / ``SPARQL`` id tensors.
        optimizer: optimizer stepping the model parameters.
        criterion: loss over ``[N, vocab]`` outputs and ``[N]`` targets.
        clip: max gradient norm for ``clip_grad_norm_``.
        use_copy: clean up extended target ids with ``fix_extended_vocab``
            (relies on the global ``OUT_TRG_DIM``, defined elsewhere).

    Returns:
        The mean of the per-batch losses, or NaN for an empty iterator. The
        previous version returned only the last batch's loss, which is noisy
        and comes from a possibly partial final batch.
    """
    model.train()
    batch_losses = []

    for batch in iterator:
        src = batch.English
        trg = batch.SPARQL

        if use_copy:
            # extended KB ids absent from the question cannot be copied -> unknown id
            trg = fix_extended_vocab(src, trg, OUT_TRG_DIM)

        optimizer.zero_grad()

        # input: target without its last token; prediction target: shifted by one
        output, _ = model(src, trg[:, :-1])
        output_dim = output.shape[-1]

        output = output.contiguous().view(-1, output_dim)
        trg = trg[:, 1:].contiguous().view(-1)

        loss = criterion(output, trg)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')

In [None]:
# Eval function of the model
def evaluate(model: nn.Module, 
             iterator: DatasetIterator, 
             criterion: nn.Module, 
             use_copy=USE_COPY) -> float:
    """Compute the teacher-forced loss of ``model`` over ``iterator`` without training.

    Args:
        model: seq2seq model called as ``model(src, trg_input)``.
        iterator: yields batches with ``English`` / ``SPARQL`` id tensors.
        criterion: loss over ``[N, vocab]`` outputs and ``[N]`` targets.
        use_copy: clean up extended target ids with ``fix_extended_vocab``
            (relies on the global ``OUT_TRG_DIM``, defined elsewhere).

    Returns:
        The mean of the per-batch losses, or NaN for an empty iterator. The
        previous version returned only the last batch's loss, which made
        validation-based model selection depend on a single, possibly
        partial, batch.
    """
    model.eval()
    batch_losses = []

    with torch.no_grad():

        for batch in iterator:

            src = batch.English
            trg = batch.SPARQL

            if use_copy:
                trg = fix_extended_vocab(src, trg, OUT_TRG_DIM)

            output, _ = model(src, trg[:, :-1])
            output_dim = output.shape[-1]

            output = output.contiguous().view(-1, output_dim)
            trg = trg[:, 1:].contiguous().view(-1)

            loss = criterion(output, trg)
            batch_losses.append(loss.item())

    return sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')

## Model

In [None]:
class MultiHeadAttentionLayer(nn.Module):
    """Scaled dot-product multi-head attention.

    Query, key and value are projected by separate linear layers and split
    into ``n_heads`` heads of size ``hid_dim // n_heads``. Each head attends
    independently, and the heads are merged back through ``fc_o``.

    ``forward`` returns the output ``[batch, query_len, hid_dim]`` and the
    attention weights ``[batch, n_heads, query_len, key_len]``.
    """
    def __init__(self, 
                 hid_dim:int, 
                 n_heads:int, 
                 dropout:float, 
                 device: torch.device):

        super().__init__()

        assert hid_dim % n_heads == 0

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads

        # one projection per role, then the output projection
        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)
        self.fc_o = nn.Linear(hid_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

        # sqrt(head_dim) as a 1-element tensor on `device`
        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def _split_heads(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
        # [batch, len, hid_dim] -> [batch, n_heads, len, head_dim]
        return x.view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)

    def forward(self, 
                query: torch.Tensor, 
                key: torch.Tensor, 
                value: torch.Tensor, 
                mask=None) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = query.shape[0]

        Q = self._split_heads(self.fc_q(query), batch_size)
        K = self._split_heads(self.fc_k(key), batch_size)
        V = self._split_heads(self.fc_v(value), batch_size)

        # [batch, n_heads, query_len, key_len]
        energy = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim=-1)

        # [batch, n_heads, query_len, head_dim] -> [batch, query_len, hid_dim]
        merged = torch.matmul(self.dropout(attention), V).transpose(1, 2).contiguous()
        merged = merged.view(batch_size, -1, self.hid_dim)

        return self.fc_o(merged), attention

In [None]:
class PositionwiseFeedforwardLayer(nn.Module):
    """Feed-forward block applied at every position independently.

    Computes ``fc_2(dropout(relu(fc_1(x))))``: hid_dim -> pf_dim -> hid_dim.
    """
    def __init__(self, 
                 hid_dim: int, 
                 pf_dim: int, 
                 dropout: float):
        super().__init__()

        self.fc_1 = nn.Linear(hid_dim, pf_dim)
        self.fc_2 = nn.Linear(pf_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = torch.relu(self.fc_1(x))
        return self.fc_2(self.dropout(hidden))

In [None]:
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Self-attention, then a position-wise feed-forward block. Each sub-layer
    is followed by dropout, a residual connection and LayerNorm.
    """
    def __init__(self,
                 hid_dim: int,
                 n_heads: int,
                 pf_dim: int,
                 dropout: float,
                 device: torch.device):
        super().__init__()

        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, 
                src: torch.Tensor,
                src_mask: torch.Tensor) -> torch.Tensor:
        # self-attention sub-layer
        attended, _ = self.self_attention(src, src, src, src_mask)
        src = self.self_attn_layer_norm(src + self.dropout(attended))

        # feed-forward sub-layer
        transformed = self.positionwise_feedforward(src)
        return self.ff_layer_norm(src + self.dropout(transformed))

In [None]:
class Encoder(nn.Module):
    """Transformer encoder: token and learned position embeddings, then ``n_layers`` EncoderLayers.

    Token embeddings are scaled by ``sqrt(hid_dim)`` before the position
    embeddings are added. Positions index a table of size ``max_length``, so
    inputs longer than that cannot be embedded.
    """
    def __init__(self, 
                 input_dim: int, 
                 hid_dim: int, 
                 n_layers: int,
                 n_heads: int, 
                 pf_dim: int, 
                 dropout: float, 
                 device: torch.device, 
                 max_length=MAX_LENGTH):
        super().__init__()

        self.device = device

        self.tok_embedding = nn.Embedding(input_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)

        self.layers = nn.ModuleList([
            EncoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
            for _ in range(n_layers)
        ])

        self.dropout = nn.Dropout(dropout)

        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, 
                src: torch.Tensor, 
                src_mask: torch.Tensor) -> torch.Tensor:
        batch_size, src_len = src.shape[0], src.shape[1]

        # position ids 0..src_len-1, repeated for every sentence of the batch
        positions = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)

        hidden = self.dropout(self.tok_embedding(src) * self.scale + self.pos_embedding(positions))

        for layer in self.layers:
            hidden = layer(hidden, src_mask)

        return hidden

In [None]:
class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    The layer applies three sub-layers in order:
    1. masked self-attention;
    2. attention over the encoder output;
    3. a position-wise feed-forward block.
    Each sub-layer is followed by dropout, a residual connection and
    LayerNorm. The encoder-attention weights are returned too.
    """
    def __init__(self,
                 hid_dim: int,
                 n_heads: int,
                 pf_dim: int,
                 dropout: float,
                 device: torch.device):
        super().__init__()

        self.self_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.enc_attn_layer_norm = nn.LayerNorm(hid_dim)
        self.ff_layer_norm = nn.LayerNorm(hid_dim)
        self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device)
        self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, pf_dim, dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, 
                trg: torch.Tensor, 
                enc_src: torch.Tensor, 
                trg_mask: torch.Tensor, 
                src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # masked self-attention sub-layer
        self_attended, _ = self.self_attention(trg, trg, trg, trg_mask)
        trg = self.self_attn_layer_norm(trg + self.dropout(self_attended))

        # encoder-attention sub-layer; its weights are returned to the caller
        cross_attended, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask)
        trg = self.enc_attn_layer_norm(trg + self.dropout(cross_attended))

        # feed-forward sub-layer
        transformed = self.positionwise_feedforward(trg)
        trg = self.ff_layer_norm(trg + self.dropout(transformed))

        return trg, attention


In [None]:
class Decoder(nn.Module):
    """Transformer decoder.

    The input passes through token and learned position embeddings, then
    ``n_layers`` DecoderLayers, then a linear projection to ``output_dim``
    logits. ``forward`` returns the logits and the encoder attention of the
    last layer.
    """
    def __init__(self,
                 output_dim: int,
                 hid_dim: int,
                 n_layers: int,
                 n_heads: int,
                 pf_dim: int,
                 dropout: float,
                 device: torch.device,
                 max_length=MAX_LENGTH):
        super().__init__()

        self.device = device
        self.hid_dim = hid_dim
        self.output_dim = output_dim

        self.tok_embedding = nn.Embedding(output_dim, hid_dim)
        self.pos_embedding = nn.Embedding(max_length, hid_dim)

        self.layers = nn.ModuleList([
            DecoderLayer(hid_dim, n_heads, pf_dim, dropout, device)
            for _ in range(n_layers)
        ])

        self.fc_out = nn.Linear(hid_dim, output_dim)

        self.dropout_val = dropout
        self.dropout = nn.Dropout(self.dropout_val)

        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, 
                trg: torch.Tensor, 
                enc_src: torch.Tensor, 
                trg_mask: torch.Tensor, 
                src_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, trg_len = trg.shape[0], trg.shape[1]

        # position ids 0..trg_len-1, repeated for every sentence of the batch
        positions = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)

        hidden = self.dropout(self.tok_embedding(trg) * self.scale + self.pos_embedding(positions))

        for layer in self.layers:
            hidden, attention = layer(hidden, enc_src, trg_mask, src_mask)

        return self.fc_out(hidden), attention

In [None]:
class CopyLayerVocabExtend(nn.Module):
  """Pointer/copy layer mixing a generation and a copy distribution.

  A learned switch ``p_pointer = sigmoid(switch(output))``, computed from the
  decoder logits, weights two distributions:
    * generation: ``softmax(output)`` over the decoder vocabulary. It is
      widened with zero-probability slots for extended ids that only exist
      in ``src``.
    * copy: the attention of one decoder head over the source positions,
      scattered onto the source token ids.
  The mixture is returned as log-probabilities over
  ``max(vocab, max(src) + 1)`` ids.
  """
  def __init__(self, decoder: "Decoder", pointer_head: int = 3):
    super().__init__()
    self.switch = nn.Linear(decoder.tok_embedding.num_embeddings, 1)
    # attention head supplying the copy distribution (3 = former hard-coded value)
    self.pointer_head = pointer_head

  def forward(self, 
              src: torch.Tensor, 
              output: torch.Tensor, 
              attention: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
      src: source ids ``[batch, src_len]``, possibly containing extended ids.
      output: decoder logits ``[batch, trg_len, vocab]``.
      attention: encoder-decoder attention ``[batch, n_heads, trg_len, src_len]``.

    Returns:
      Log-probabilities ``[batch, trg_len, max(vocab, max(src) + 1)]`` and
      ``attention`` unchanged.
    """
    p_pointer = torch.sigmoid(self.switch(output))  # [batch, trg_len, 1]

    n_extra = int(torch.max(src)) + 1 - output.shape[-1]
    if n_extra > 0:  # the source holds ids beyond the decoder vocabulary
      # -inf logits give exactly zero generation probability after softmax, so
      # these ids can only be produced by copying (zero logits used to give
      # them generation mass)
      extended = torch.full((output.shape[0], output.shape[1], n_extra), float('-inf'),
                            dtype=output.dtype, device=output.device)
      output = torch.cat((output, extended), dim=2)

    # getattr keeps modules pickled before `pointer_head` existed working
    head = getattr(self, 'pointer_head', 3)
    copy_index = src.unsqueeze(1).repeat(1, output.shape[1], 1)
    copy_probs = p_pointer * attention[:, head]
    output = ((1 - p_pointer) * F.softmax(output, dim=2)).scatter_add(2, copy_index, copy_probs) + 1e-10
    return torch.log(output), attention

In [None]:
class TransfSeq2Seq(nn.Module):
    """Transformer encoder-decoder with an optional copy (pointer) layer.

    Args:
        encoder: called as ``encoder(src, src_mask)``. In copy mode its
            ``tok_embedding.num_embeddings`` bounds the source ids it can embed.
        decoder: called as ``decoder(trg, enc_src, trg_mask, src_mask)`` and must
            return ``(output, attention)``. In copy mode its
            ``tok_embedding.num_embeddings`` bounds the target ids it can embed.
        src_pad_idx: padding id in ``src``; masked out of attention.
        trg_pad_idx: padding id in ``trg``; masked out of attention.
        device: device on which the causal (lower-triangular) target mask is built.
        copy_layer: optional module called as ``copy_layer(src, output, attention)``.
            When set, ids too large for an embedding table are replaced by
            ``oov_fill_idx`` before embedding. The copy layer still receives the
            raw ids.
        oov_fill_idx: replacement id for such out-of-table ids. Defaults to 0, the
            value that was previously hard-coded.
    """

    def __init__(self,
                 encoder: 'Encoder',
                 decoder: 'Decoder',
                 src_pad_idx: int,
                 trg_pad_idx: int,
                 device: torch.device,
                 copy_layer: Optional['CopyLayerVocabExtend'] = None,
                 oov_fill_idx: int = 0):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

        self.copy_layer = copy_layer
        self.oov_fill_idx = oov_fill_idx

    def make_src_mask(self, src: torch.Tensor) -> torch.Tensor:
        """Return a boolean mask that is False at padding positions.

        For ``src`` of shape (batch, src_len) the mask is (batch, 1, 1, src_len).
        """
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)
        return src_mask

    def make_trg_mask(self, trg: torch.Tensor) -> torch.Tensor:
        """Return the target mask: non-padding AND causal (no attending to later steps).

        For ``trg`` of shape (batch, trg_len) the mask is
        (batch, 1, trg_len, trg_len).
        """
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)

        trg_len = trg.shape[1]
        trg_sub_mask = torch.tril(torch.ones(
            (trg_len, trg_len), device=self.device)).bool()

        trg_mask = trg_pad_mask & trg_sub_mask
        return trg_mask

    def forward(self, src: torch.Tensor, trg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the model and return ``(output, attention)``.

        Without a copy layer, ``output`` is the decoder's output. With one, it is
        whatever the copy layer returns (log-probabilities over the possibly
        extended vocabulary for ``CopyLayerVocabExtend``).
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        if self.copy_layer is None:
            enc_src = self.encoder(src, src_mask)
            output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
            return output, attention

        # Copy mode: ids beyond the embedding tables (extended-vocabulary OOV words)
        # cannot be embedded, so they are replaced before encoding/decoding. Example:
        #   src [2, 23, 1129, 40, 1083, 11, 3] -> [2, 23, 0, 40, 0, 11, 3]
        # Masks are built from the raw ids above; only the embedded ids change.
        enc_src = self.encoder(self._mask_oov_ids(src, self.encoder), src_mask)
        output, attention = self.decoder(
            self._mask_oov_ids(trg, self.decoder), enc_src, trg_mask, src_mask)

        # The copy layer gets the raw source ids so it can point at the OOV words.
        output, attention = self.copy_layer(src, output, attention)
        return output, attention

    def _mask_oov_ids(self, ids: torch.Tensor, module: nn.Module) -> torch.Tensor:
        """Replace ids that ``module.tok_embedding`` cannot embed with ``oov_fill_idx``."""
        return ids.masked_fill(
            ids >= module.tok_embedding.num_embeddings, self.oov_fill_idx)

# INFERENCE

In [None]:
# Folder of the training run to load. Its path is built from MODEL_TYPE, COPY_FLAG,
# DATASET_FAMILY and DATASET_NAME (defined elsewhere in the notebook) and ITERATION.
ITERATION = 1
OUT_DRIVE_FOLDER_BASE = "/".join(
    f"{part}" for part in (
        "/content/gdrive/MyDrive/PRETRAINED",
        MODEL_TYPE,
        COPY_FLAG,
        DATASET_FAMILY,
        DATASET_NAME,
        ITERATION,
    )
)

In [None]:
# Training artifacts, all stored directly under OUT_DRIVE_FOLDER_BASE.
CONFIG_PATH, SRC_VOCAB_PATH, TRG_VOCAB_PATH, MODEL_PATH = (
    f"{OUT_DRIVE_FOLDER_BASE}/{file_name}"
    for file_name in (
        "config.json",
        "src_vocab.field",
        "trg_vocab.field",
        "best-model-state-dict.pt",
    )
)
# Evaluation set. This is a relative path, so it is resolved against the working directory.
OOV_DATASET = 'oov_dataset.json'

In [None]:
# Translator class to facilitate inference and make the model easier to port.
# If you use it as a standalone script, make sure to also import the following utils elements:
# VocabDup, gen_data_field(), read_vocab(), extend_vocabulary(), translate_sentence()
# As well as the model architecture:
# Encoder, Decoder, CopyLayerVocabExtend and TransfSeq2Seq

class Translator:
    """Inference wrapper around a trained TransfSeq2Seq checkpoint.

    Loads the training config, the source/target vocabularies and the best model
    state from CONFIG_PATH, SRC_VOCAB_PATH, TRG_VOCAB_PATH and MODEL_PATH. It relies
    on helpers defined elsewhere in the notebook: gen_data_field, read_vocab,
    VocabDup, extend_vocabulary and translate_sentence.
    """

    def __init__(self):
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
            config = json.load(f)

        base_vocab_size = max(config['INPUT_DIM'], config['OUTPUT_DIM'])

        # Build empty fields, then populate them from the saved vocabularies.
        self.SRC, self.TRG = gen_data_field()
        self.SRC.build_vocab([], min_freq=1, max_size=None)
        self.TRG.build_vocab([], min_freq=1, max_size=None)

        src_vocab = VocabDup(read_vocab(SRC_VOCAB_PATH), base_vocab_size)
        trg_vocab = VocabDup(read_vocab(TRG_VOCAB_PATH), base_vocab_size)

        self.SRC = extend_vocabulary(self.SRC, src_vocab)
        self.TRG = extend_vocabulary(self.TRG, trg_vocab)

        # Rebuild the architecture with the hyper-parameters stored in the config.
        self.enc = Encoder(
            config['INPUT_DIM'], config['HID_DIM'],
            config['ENCODER']['ENC_LAYERS'], config['ENCODER']['ENC_HEADS'],
            config['ENCODER']['ENC_PF_DIM'], config['ENCODER']['ENC_DROPOUT'], self.device)

        self.dec = Decoder(
            config['OUTPUT_DIM'], config['HID_DIM'],
            config['DECODER']['DEC_LAYERS'], config['DECODER']['DEC_HEADS'],
            config['DECODER']['DEC_PF_DIM'], config['DECODER']['DEC_DROPOUT'], self.device)

        self.copy_layer = CopyLayerVocabExtend(self.dec) if config['USE_COPY'] else None
        self.model = TransfSeq2Seq(
            self.enc, self.dec, config['SRC_PAD_IDX'], config['TRG_PAD_IDX'],
            self.device, self.copy_layer).to(self.device)

        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        loaded = torch.load(MODEL_PATH, map_location=self.device)
        self.model.load_state_dict(loaded)
        self.model.eval()

    def translate(self, sentence: str) -> List[str]:
        """Translate a whitespace-tokenised sentence and return the predicted tokens."""
        translation, _ = translate_sentence(
            sentence.split(), self.SRC, self.TRG, self.model, self.device,
            predict_with_copy=self.model.copy_layer is not None)
        return translation

    def _strip_eos(self, tokens: List[str]) -> List[str]:
        """Drop a trailing end-of-sequence token, if one is present.

        Previously the last token was always removed. That also discarded a real
        token whenever decoding stopped without emitting EOS.
        """
        eos_token = getattr(self.TRG, 'eos_token', None)
        if tokens and tokens[-1] == eos_token:
            return tokens[:-1]
        return tokens

    def calculate_bleu(self, test_data: List[Dict]) -> float:
        """Translate every entry of ``test_data`` and return the corpus BLEU score.

        Each entry must provide ``question.uri_question_rest_no_resources`` (source),
        ``query.uri_interm_sparql_rest_no_resources`` (reference), ``_id`` and
        ``template_id``. A per-entry report is written to
        ``{OUT_DRIVE_FOLDER_BASE}/error_report_oov.json``.
        """
        print("Calculating BLEU score...")

        expected_trgs = []
        pred_trgs = []
        error_report = []

        for entry in test_data:
            src_sentence = entry['question']['uri_question_rest_no_resources']
            trg_sentence = entry['query']['uri_interm_sparql_rest_no_resources']
            trg_tokens = trg_sentence.split()

            pred_trg = self._strip_eos(self.translate(src_sentence))
            pred_trgs.append(pred_trg)
            expected_trgs.append([trg_tokens])

            error_report.append({
                'id': entry['_id'],
                'template_id': entry['template_id'],
                'src': src_sentence,
                'trg': trg_sentence,
                'predicted': ' '.join(pred_trg),
                # Compare token lists; comparing the raw string to a list was always False.
                'correct': trg_tokens == pred_trg,
            })

        bleu = bleu_score(pred_trgs, expected_trgs)

        with open(f'{OUT_DRIVE_FOLDER_BASE}/error_report_oov.json', 'w', encoding='utf-8') as f:
            json.dump(error_report, f, indent=4)

        return bleu

In [None]:
# Build the translator (config, vocabularies and checkpoint), then score it on the
# OOV evaluation set and print the resulting BLEU.
translator = Translator()

with open(OOV_DATASET, encoding='utf-8') as f:
    dataset = json.loads(f.read())

print(translator.calculate_bleu(dataset))