In [1]:
import os
from pathlib import Path

import numpy as np
import torch

from seq2seq_translation.datasets.datasets import LanguagePairsDatasets
from seq2seq_translation.sentence_pairs_dataset import SentencePairsDataset
from seq2seq_translation.tokenization.sentencepiece_tokenizer import SentencePieceTokenizer
from seq2seq_translation.run import _fix_model_state_dict
from seq2seq_translation.attention import AttentionType
from seq2seq_translation.rnn import EncoderRNN, AttnDecoderRNN
from seq2seq_translation.train_evaluate import inference



In [2]:
# One shared seed for both RNGs so the random draws made later in the notebook are repeatable.
SEED = 1234
torch.random.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
# Directory holding the trained SentencePiece model files. Set the
# SEQ2SEQ_TOKENIZER_DIR environment variable to run the notebook on another
# machine; the fallback is the location that was previously hard-coded.
TOKENIZER_DIR = os.environ.get(
	'SEQ2SEQ_TOKENIZER_DIR',
	'/Users/adam.amster/seq2seq_translation/tokenizer/30000',
)

# model_prefix is still passed as a plain string, exactly as before.
source_tokenizer = SentencePieceTokenizer(model_prefix=os.path.join(TOKENIZER_DIR, 'en30000'))
target_tokenizer = SentencePieceTokenizer(model_prefix=os.path.join(TOKENIZER_DIR, 'fr30000'))

In [4]:
# Set the DEVICE environment variable to 'cpu' for the rest of this kernel session.
# NOTE(review): presumably seq2seq_translation reads DEVICE to pick the torch device.
# This assignment runs after that package has been imported, so confirm the variable
# is read lazily (at call time) rather than at import time - TODO confirm.
os.environ['DEVICE'] = 'cpu'

In [5]:
def construct_test_dset(
	data_path: Path = Path('/Users/adam.amster/seq2seq_translation/datasets/wmt14_test'),
) -> 'SentencePairsDataset':
	"""Build the en->fr test dataset, tokenized with the notebook's global tokenizers.

	Args:
		data_path: Location handed to ``LanguagePairsDatasets``. Defaults to the
			path that was previously hard-coded, so existing no-argument calls
			behave as before.

	Returns:
		A ``SentencePairsDataset`` covering every index of the loaded test data,
		with ``max_length=None``.
	"""
	test_datasets = LanguagePairsDatasets(
		data_path=Path(data_path),
		source_lang='en',
		target_lang='fr',
		is_test=True,
	)

	# Use every example: idxs spans the full length of the loaded datasets.
	return SentencePairsDataset(
		datasets=test_datasets,
		idxs=np.arange(len(test_datasets)),
		source_tokenizer=source_tokenizer,
		target_tokenizer=target_tokenizer,
		max_length=None,
	)

In [6]:
# Build the test dataset once at cell level; check_padding reads `test_dset` as a global.
test_dset = construct_test_dset()

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

In [7]:
def construct_model_attention(
	weights_dir: Path = Path('/Users/adam.amster/seq2seq_translation/weights/attention'),
):
	"""Build the attention encoder/decoder pair, load saved weights and put both in eval mode.

	Args:
		weights_dir: Directory containing ``encoder.pt`` and ``decoder.pt``.
			Defaults to the location that was previously hard-coded, so existing
			no-argument calls behave as before.

	Returns:
		Tuple ``(encoder, decoder)``.
	"""
	weights_dir = Path(weights_dir)

	# Sizes shared by the encoder and decoder configuration below.
	hidden_size = 1000
	embedding_dim = 1000
	num_layers = 4

	encoder = EncoderRNN(
		input_size=source_tokenizer.processor.vocab_size(),
		hidden_size=hidden_size,
		bidirectional=True,
		pad_idx=source_tokenizer.processor.pad_id(),
		embedding_dim=embedding_dim,
		num_layers=num_layers,
	)

	# NOTE(review): pad_idx and sos_token_id are taken from the *source* tokenizer while
	# eos_token_id comes from the *target* tokenizer. This is harmless only if both
	# tokenizers assign the same special-token ids - confirm against the training
	# configuration before changing it, since the saved weights depend on that setup.
	decoder = AttnDecoderRNN(
		hidden_size=hidden_size,
		attention_size=1000,
		output_size=target_tokenizer.processor.vocab_size(),
		encoder_bidirectional=True,
		max_len=200,
		attention_type=AttentionType.CosineSimilarityAttention,
		encoder_output_size=encoder.output_size,
		pad_idx=source_tokenizer.processor.pad_id(),
		num_embeddings=target_tokenizer.processor.vocab_size(),
		sos_token_id=source_tokenizer.processor.bos_id(),
		embedding_dim=embedding_dim,
		num_layers=num_layers,
		eos_token_id=target_tokenizer.processor.eos_id()
	)

	# torch.load unpickles the checkpoint file, so only point weights_dir at trusted files.
	# Each loaded state dict is passed through _fix_model_state_dict before being applied.
	encoder.load_state_dict(
		_fix_model_state_dict(torch.load(str(weights_dir / 'encoder.pt'), map_location='cpu'))
	)
	decoder.load_state_dict(
		_fix_model_state_dict(torch.load(str(weights_dir / 'decoder.pt'), map_location='cpu'))
	)

	# Inference only: switch both modules to eval mode.
	encoder.eval()
	decoder.eval()

	return encoder, decoder

In [8]:
def check_padding():
	"""Check that appending pad tokens to an input leaves the decoded output unchanged.

	For five randomly chosen test examples, runs ``inference`` once on the raw
	source ids and once on the same ids with a random number (0-49) of pad tokens
	appended, then prints whether the two decoded id sequences are identical.

	Relies on the notebook globals ``test_dset`` and ``source_tokenizer``.
	"""
	encoder, decoder = construct_model_attention()
	pad_id = source_tokenizer.processor.pad_id()

	for _ in range(5):
		num_pad_tokens = torch.randint(size=(1, 1), low=0, high=50)[0].item()
		dset_idx = torch.randint(size=(1, 1), low=0, high=len(test_dset))[0].item()

		# First element of a dataset item is used as the source token ids.
		source_ids = test_dset[dset_idx][0]
		pad_tail = torch.full((num_pad_tokens,), pad_id, dtype=torch.long)
		source_ids_padded = torch.cat([source_ids, pad_tail])

		print(f'input: {source_tokenizer.decode(source_ids)}')

		_, _, _, decoded_ids_no_padding = inference(
			encoder=encoder,
			decoder=decoder,
			input_tensor=source_ids.unsqueeze(0),
			input_lengths=[len(source_ids)]
		)

		# NOTE(review): input_lengths here counts the appended pad tokens as part of
		# the sequence; confirm that this is the intended way to exercise padding.
		_, _, _, decoded_ids_padding = inference(
			encoder=encoder,
			decoder=decoder,
			input_tensor=source_ids_padded.unsqueeze(0),
			input_lengths=[len(source_ids_padded)]
		)

		# Compare the un-padded result against the padded one. The previous code
		# compared decoded_ids_padding with itself, which is always True. The shape
		# guard reports a mismatch instead of raising when the outputs differ in length.
		outputs_match = (
			decoded_ids_no_padding.shape == decoded_ids_padding.shape
			and bool((decoded_ids_no_padding == decoded_ids_padding).all().item())
		)

		print(f'output with no pad tokens appended to input == output with {num_pad_tokens} pad tokens appended to input', outputs_match)

		print('='*11)

In [9]:
check_padding()

input: Moore told reporters that the initial autopsy indicated Johnson died as a result of "positional asphyxia."
output with no pad tokens appended to input == output with 42 pad tokens appended to input True
input: "The alarms went off one after the other as it took reinforcements to deal with the building, because the buildings here are close together," explained France Loiselle, spokesperson for the Quebec fire service.
output with no pad tokens appended to input == output with 13 pad tokens appended to input True
input: Rehousing Due to Rats Causes Strife in La Seyne
output with no pad tokens appended to input == output with 29 pad tokens appended to input True
input: I see her eyes resting on me.
output with no pad tokens appended to input == output with 26 pad tokens appended to input True
input: Moreover, Kiev needs liquid assets to pay for its gas imports from Russia, which accuses it of not having paid a bill of 882 million dollars.
output with no pad tokens appended to input