In [1]:
import torch 
import pandas as pd
from torchtext.data.utils import get_tokenizer
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
import tqdm
import sys
sys.path.append('./data_prep')
from sentence_dataset_class import ProcessedSentences
from sentence_processing import build_vocab,sentence_processing
sys.path.append('./transformer_testing')
from tomislav_transformer import Seq2SeqTransformer
import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def stringify_series(df):
    """Cast the ``input_data`` and ``output_data`` columns to the ``string`` dtype.

    The frame is modified in place; the same object is returned so the call
    can be used inline.
    """
    for column in ('input_data', 'output_data'):
        df[column] = df[column].astype('string')
    return df

def generate_square_subsequent_mask(sz):
    """Return an ``(sz, sz)`` float mask that hides later positions.

    Entries on and below the main diagonal are ``0.0``; entries strictly
    above it are ``-inf``. The tensor is created on the notebook-level
    ``device``.
    """
    blocked = torch.full((sz, sz), float('-inf'), dtype=torch.float, device=device)
    # diagonal=1 keeps only the strictly-upper triangle; everything else becomes 0.
    return torch.triu(blocked, diagonal=1)

def create_mask(src, tgt):
    """Build the attention and padding masks for a source/target pair.

    Both inputs are indexed with the sequence length on axis 0
    (presumably ``(seq_len, batch)`` -- confirm against the caller).

    Returns:
        ``(src_mask, tgt_mask, src_padding_mask, tgt_padding_mask)`` where
        ``src_mask`` is an all-``False`` boolean square, ``tgt_mask`` comes
        from ``generate_square_subsequent_mask`` and the padding masks mark
        entries equal to ``PAD_IDX``, with axes 0 and 1 swapped.
    """
    n_src, n_tgt = src.shape[0], tgt.shape[0]

    causal_tgt = generate_square_subsequent_mask(n_tgt)
    # All-False mask for the source side.
    open_src = torch.zeros((n_src, n_src), dtype=torch.bool, device=device)

    src_pad, tgt_pad = [(seq == PAD_IDX).transpose(0, 1) for seq in (src, tgt)]
    return open_src, causal_tgt, src_pad, tgt_pad

In [3]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Read the train and test splits from their JSON files.
df_train, df_test = map(pd.read_json, ('data/train_data.json', 'data/test_data.json'))

In [4]:
# Special tokens and the integer ids used for them throughout the notebook.
# NOTE(review): presumably build_vocab places these symbols at ids 0-3 in
# this order -- confirm in sentence_processing.py.
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = range(len(special_symbols))

token_transform = get_tokenizer('basic_english')

# Separate vocabularies for the input and output columns, both built from the training split.
train_input_vocab, train_output_vocab = (
    build_vocab(df_train[column], token_transform, special_symbols)
    for column in ('input_data', 'output_data')
)

In [5]:
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Generate an output sequence one token at a time, always taking the arg-max.

    Args:
        model: Object exposing ``encode(src, src_mask)``,
            ``decode(ys, memory, tgt_mask)`` and ``generator(features)``.
        src: Source token ids. Assumes a single sentence shaped
            ``(src_len, 1)``, since ``.item()`` is called on each
            prediction -- confirm against the caller.
        src_mask: Attention mask handed to ``model.encode``.
        max_len: Upper bound on the number of rows in the returned tensor,
            start symbol included.
        start_symbol: Token id the output sequence is seeded with.

    Returns:
        An ``(n, 1)`` tensor of generated ids with ``start_symbol`` first.
        Generation stops early once ``EOS_IDX`` is produced (it is kept in
        the output).
    """
    src = src.to(device)
    src_mask = src_mask.to(device)

    # Pure inference: disable autograd so no graph is recorded for the
    # encoder pass or for any of the decoding steps.
    with torch.no_grad():
        # The encoder output does not change while decoding, so it is moved
        # to the device once, not on every step.
        memory = model.encode(src, src_mask).to(device)
        ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=device)

        for _ in range(max_len - 1):
            tgt_mask = (generate_square_subsequent_mask(ys.size(0))
                        .type(torch.bool)).to(device)
            out = model.decode(ys, memory, tgt_mask).transpose(0, 1)
            # Only the features of the last generated position are scored.
            prob = model.generator(out[:, -1])
            next_word = torch.max(prob, dim=1)[1].item()

            # The appended token uses src's dtype and device, as before.
            next_token = torch.full((1, 1), next_word,
                                    dtype=src.dtype, device=src.device)
            ys = torch.cat([ys, next_token], dim=0)
            if next_word == EOS_IDX:
                break
    return ys
# Translate an input sentence into the target language.
def translate(model: torch.nn.Module, input_sentence: str) -> str:
    """Translate ``input_sentence`` with ``model`` using greedy decoding.

    The sentence is converted to token ids with ``sentence_processing``
    (input vocabulary, tokenizer, BOS/EOS ids), decoded greedily for at
    most ``num_tokens + 5`` steps, and mapped back to text through the
    output vocabulary.

    Args:
        model: Model to decode with; it is switched to eval mode.
        input_sentence: Raw source sentence.

    Returns:
        The predicted tokens joined by single spaces. The ``<bos>`` and
        ``<eos>`` marker tokens are dropped, so the result has no stray
        leading or trailing whitespace.
    """
    model.eval()
    src = sentence_processing(input_sentence,
                              train_input_vocab,
                              token_transform,
                              BOS_IDX,
                              EOS_IDX).view(-1, 1)
    num_tokens = src.shape[0]
    # All-False mask for the source side.
    src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)
    tgt_tokens = greedy_decode(
        model, src, src_mask, max_len=num_tokens + 5, start_symbol=BOS_IDX).flatten()

    # tolist() yields plain Python ints for the vocabulary lookup.
    decoded = train_output_vocab.lookup_tokens(tgt_tokens.tolist())
    # Drop the markers as whole tokens *before* joining; stripping them from
    # the joined string would leave the separating spaces behind.
    return " ".join(tok for tok in decoded if tok not in ("<bos>", "<eos>"))

In [6]:
# Seed torch's RNG.
torch.manual_seed(0)

# Vocabulary sizes passed to the model constructor.
input_vocab_size, output_vocab_size = len(train_input_vocab), len(train_output_vocab)

# Read the hyper-parameters stored alongside the saved run.
hyper_params = {}
with open('results2_pad3/11_08-2022_19_50_28_hyperparameters.json','r') as f:
    hyper_params = json.load(f)

# Unpack the individual settings; a missing key raises KeyError.
(emb_size,
 n_head,
 ffn_hid_dim,
 batch_size,
 num_encoder_layers,
 num_decoder_layers) = (
    hyper_params[key]
    for key in ('emb_size', 'n_head', 'ffn_hid_dim', 'batch_size',
                'num_encoder_layers', 'num_decoder_layers')
)

In [7]:
# Rebuild the model with the stored hyper-parameters so the checkpoint's
# state dict can be loaded into it.
transformer = Seq2SeqTransformer(
    num_encoder_layers,
    num_decoder_layers,
    emb_size,
    n_head,
    input_vocab_size,
    output_vocab_size,
    ffn_hid_dim)
# map_location remaps the stored tensors onto the active device; without it a
# checkpoint holding CUDA tensors cannot be loaded on a CPU-only machine.
transformer.load_state_dict(
    torch.load('results2_pad3/11_08-2022_19_50_28_model.pt', map_location=device))
# Left as the last expression so the cell displays the model's repr.
transformer.to(device)

Seq2SeqTransformer(
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
        (1): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_feature

In [8]:
# Quick check on a single hand-written sentence.
sample_sentence = "Bring me the document"
translate(transformer, sample_sentence)

' please let me know if you would like to bring'

In [9]:
# Translate every test input and store the result in a new `predicted` column.
df_test['predicted'] = df_test['input_data'].map(
    lambda sentence: translate(transformer, sentence))

In [11]:
# Show the first five rows of the test frame, including the new predictions.
df_test.head(n=5)

Unnamed: 0,input_data,output_data,predicted
0,i do not think there will be any issues and or...,i do not think there will be any issues and sh...,i do not think there will be any late afterno...
1,"it through concord is 90,000dth between niagar...","it through concord is available for 90,000dth ...","it is 100 , 000 dth through concord is availa..."
2,"we are posting 50,000 dth excess injection on ...","we are posting 50,000 dth excess injection on ...","we are already seen 50 , 000 dth on either si..."
3,i to remind you that our firm transport open s...,i also want to remind you that our firm transp...,i want to remind you that our firm transport ...
4,sales representative detail our website .,please call your firm sales representative for...,please contact your client representative if ...


In [12]:
# Write the test frame, with its `predicted` column, to the run's results folder.
predictions_path = 'results2_pad3/11_08-2022_19_50_28_predictions.json'
df_test.to_json(predictions_path)