diff --git a/configs/default/components/models/attn_decoder_rnn.yml b/configs/default/components/models/attn_decoder_rnn.yml
new file mode 100644
index 0000000..e2e372b
--- /dev/null
+++ b/configs/default/components/models/attn_decoder_rnn.yml
@@ -0,0 +1,78 @@
+# This file defines the default values for the attention-based RNN decoder model.
+
+####################################################################
+# 1. CONFIGURATION PARAMETERS that will be LOADED by the component.
+####################################################################
+
+# Size of the hidden state (LOADED)
+hidden_size: 100
+
+# Whether to include the last hidden state in the outputs
+output_last_state: False
+
+# Type of recurrent cell (LOADED)
+# -> Only GRU is supported
+
+# Number of "stacked" layers (LOADED)
+# -> Only a single layer is supported
+
+# Dropout rate (LOADED)
+# Default: 0 (means that it is turned off)
+dropout_rate: 0
+
+# Prediction mode (LOADED)
+# Options:
+# * Dense (passes every activation through the output layer)
+# * Last (passes only the last activation through the output layer)
+# * None (all outputs are discarded)
+prediction_mode: Dense
+
+# Enable FFN layer at the output of the RNN (applied before the output is fed back as input during autoregression).
+# Disabling it is useful when the raw outputs of the RNN are needed, e.g. for an attention-based encoder-decoder.
+ffn_output: True
+
+# Length of generated output sequence (LOADED)
+# User must set it per task, as it is task specific.
+autoregression_length: 10
+
+# If true, output of the last layer will be additionally processed with Log Softmax (LOADED)
+use_logsoftmax: True
+
+streams:
+  ####################################################################
+  # 2. Keymappings associated with INPUT and OUTPUT streams.
+  ####################################################################
+
+  # Stream containing batch of encoder outputs (INPUT)
+  inputs: inputs
+
+  # Stream containing the initial state of the RNN (INPUT)
+  # The stream will actually be created only if `initial_state: Input`
+  input_state: input_state
+
+  # Stream containing predictions (OUTPUT)
+  predictions: predictions
+
+  # Stream containing the final output state of the RNN (OUTPUT)
+  # The stream will actually be created only if `output_last_state: True`
+  output_state: output_state
+
+globals:
+  ####################################################################
+  # 3. Keymappings of variables that will be RETRIEVED from GLOBALS.
+  ####################################################################
+
+  # Size of the input (RETRIEVED)
+  input_size: input_size
+
+  # Size of the prediction (RETRIEVED)
+  prediction_size: prediction_size
+
+  ####################################################################
+  # 4. Keymappings associated with GLOBAL variables that will be SET.
+  ####################################################################
+
+  ####################################################################
+  # 5. Keymappings associated with statistics that will be ADDED.
+  ####################################################################
+
diff --git a/configs/default/components/problems/text_to_text/translation_pairs.yml b/configs/default/components/problems/text_to_text/translation_pairs.yml
new file mode 100644
index 0000000..f48f650
--- /dev/null
+++ b/configs/default/components/problems/text_to_text/translation_pairs.yml
@@ -0,0 +1,49 @@
+# This file defines the default values for the translation pairs problem.
+
+####################################################################
+# 1. CONFIGURATION PARAMETERS that will be LOADED by the component.
+####################################################################
+
+# Folder where problem will store data (LOADED)
+data_folder: ~/data/language_modeling/translation_pairs
+
+# Defines the dataset that will be used (LOADED)
+# Options: eng-fra, eng-pol
+dataset: eng-fra
+
+# Defines the used subset (LOADED)
+# Options: train | valid | test
+subset: train
+
+# Length limit of source and target sentences
+# if < 0, no limit
+sentence_length: 10
+
+streams:
+  ####################################################################
+  # 2. Keymappings associated with INPUT and OUTPUT streams.
+  ####################################################################
+
+  # Stream containing batch of indices (OUTPUT)
+  # Every problem MUST return that stream.
+  indices: indices
+
+  # Stream containing batch of tokenized source sentences (OUTPUT)
+  sources: sources
+
+  # Stream containing batch of tokenized target sentences (OUTPUT)
+  targets: targets
+
+globals:
+  ####################################################################
+  # 3. Keymappings of variables that will be RETRIEVED from GLOBALS.
+  ####################################################################
+
+  ####################################################################
+  # 4. Keymappings associated with GLOBAL variables that will be SET.
+  ####################################################################
+
+  ####################################################################
+  # 5. Keymappings associated with statistics that will be ADDED.
+  ####################################################################
+
diff --git a/configs/translation/eng_fra_translation_enc_attndec.yml b/configs/translation/eng_fra_translation_enc_attndec.yml
new file mode 100644
index 0000000..7eab08f
--- /dev/null
+++ b/configs/translation/eng_fra_translation_enc_attndec.yml
@@ -0,0 +1,172 @@
+# This pipeline applies an encoder-decoder GRU with attention to the open Tatoeba translation sentence pairs.
+# Inspired by https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html .
+# Note that training will be slower than in the tutorial, as teacher forcing is not implemented here.
+
+# Training parameters:
+training:
+  problem:
+    type: &p_type TranslationPairs
+    data_folder: &data_folder ~/data/language_modeling/translation_pairs
+    dataset: &dataset eng-fra
+    subset: train
+    sentence_length: 10
+    batch_size: 64
+
+  # optimizer parameters:
+  optimizer:
+    name: Adam
+    lr: 1.0e-3
+
+  # settings parameters
+  terminal_conditions:
+    loss_stop: 1.0e-2
+    episode_limit: 1000000
+    epoch_limit: 100
+
+# Validation parameters:
+validation:
+  partial_validation_interval: 100
+  problem:
+    type: *p_type
+    data_folder: *data_folder
+    dataset: *dataset
+    subset: valid
+    sentence_length: 10
+    batch_size: 64
+
+# Testing parameters:
+testing:
+  problem:
+    type: *p_type
+    data_folder: *data_folder
+    dataset: *dataset
+    subset: test
+    sentence_length: 10
+    batch_size: 64
+
+pipeline:
+  name: eng_fra_translation_enc_attndec
+
+  # Source encoding - model 1.
+  source_sentence_embedding:
+    type: SentenceEmbeddings
+    priority: 1.1
+    embeddings_size: 50
+    pretrained_embeddings: glove.6B.50d.txt
+    data_folder: *data_folder
+    source_vocabulary_files: eng-fra/eng.train.txt,eng-fra/eng.valid.txt,eng-fra/eng.test.txt
+    vocabulary_mappings_file: eng-fra/eng.all.tokenized_words
+    regenerate: True
+    additional_tokens:
+    import_word_mappings_from_globals: False
+    export_word_mappings_to_globals: False
+    fixed_padding: 10
+    streams:
+      inputs: sources
+      outputs: embedded_sources
+
+  # Target encoding.
+  target_indexer:
+    type: SentenceIndexer
+    priority: 2.1
+    data_folder: *data_folder
+    source_vocabulary_files: eng-fra/fra.train.txt,eng-fra/fra.valid.txt,eng-fra/fra.test.txt
+    import_word_mappings_from_globals: False
+    export_word_mappings_to_globals: True
+    fixed_padding: 10
+    regenerate: True
+    streams:
+      inputs: targets
+      outputs: indexed_targets
+
+  # Single layer GRU Encoder
+  encoder:
+    type: RecurrentNeuralNetwork
+    cell_type: GRU
+    priority: 3
+    initial_state: Trainable
+    hidden_size: 50
+    num_layers: 1
+    use_logsoftmax: False
+    output_last_state: True
+    prediction_mode: Dense
+    ffn_output: False
+    streams:
+      inputs: embedded_sources
+      predictions: s2s_encoder_output
+      output_state: s2s_state_output
+    globals:
+      input_size: embeddings_size
+      prediction_size: embeddings_size
+
+  # Single layer GRU Decoder with attention
+  decoder:
+    type: Attn_Decoder_RNN
+    priority: 4
+    hidden_size: 50
+    use_logsoftmax: False
+    autoregression_length: 10
+    prediction_mode: Dense
+    streams:
+      inputs: s2s_encoder_output
+      predictions: s2s_decoder_output
+      input_state: s2s_state_output
+    globals:
+      input_size: embeddings_size
+      prediction_size: embeddings_size
+
+  # FF, to resize from the output size of the seq2seq to the size of the target vector
+  ff_resize_s2s_output:
+    type: FeedForwardNetwork
+    use_logsoftmax: True
+    dimensions: 3
+    priority: 5
+    streams:
+      inputs: s2s_decoder_output
+    globals:
+      input_size: embeddings_size
+      prediction_size: vocabulary_size
+
+  # Loss
+  nllloss:
+    type: NLLLoss
+    priority: 6
+    num_targets_dims: 2
+    streams:
+      targets: indexed_targets
+      loss: loss
+
+  # Prediction decoding.
+  prediction_decoder:
+    type: SentenceIndexer
+    priority: 10
+    # Reverse mode.
+    reverse: True
+    # Use distributions as inputs.
+    use_input_distributions: True
+    data_folder: *data_folder
+    import_word_mappings_from_globals: True
+    streams:
+      inputs: predictions
+      outputs: prediction_sentences
+
+  # Statistics.
+  batch_size:
+    type: BatchSizeStatistics
+    priority: 100.0
+
+  bleu:
+    type: BLEUStatistics
+    priority: 100.2
+    streams:
+      targets: indexed_targets
+
+  # Viewers.
+  viewer:
+    type: StreamViewer
+    priority: 100.3
+    input_streams: sources,targets,indexed_targets,prediction_sentences
+
+#: pipeline
diff --git a/configs/wikitext/wikitext_language_modeling_encoder_attndecoder.yml b/configs/wikitext/wikitext_language_modeling_encoder_attndecoder.yml
new file mode 100644
index 0000000..244fc4a
--- /dev/null
+++ b/configs/wikitext/wikitext_language_modeling_encoder_attndecoder.yml
@@ -0,0 +1,174 @@
+# This pipeline applies seq2seq on wikitext-2 to make word-level predictions.
+# It's been made for test purposes only, as it is doing:
+# [word 0 , ... , word 49] -> [word 1 , ... , word 50] (basically copying most of the input)
+#
+# The seq2seq here is implemented through the use of two recurrent models:
+# a `RecurrentNeuralNetwork` encoder and an `Attn_Decoder_RNN` decoder.
+
+# Training parameters:
+training:
+  problem:
+    type: &p_type WikiTextLanguageModeling
+    data_folder: &data_folder ~/data/language_modeling/wikitext-2
+    dataset: &dataset wikitext-2
+    subset: train
+    sentence_length: 42
+    batch_size: 64
+
+  # optimizer parameters:
+  optimizer:
+    name: SGD
+    lr: 1.0e-2
+
+  # settings parameters
+  terminal_conditions:
+    loss_stop: 1.0e-2
+    episode_limit: 1000000
+    epoch_limit: 100
+
+# Validation parameters:
+validation:
+  partial_validation_interval: 100
+  problem:
+    type: *p_type
+    data_folder: *data_folder
+    dataset: *dataset
+    subset: valid
+    sentence_length: 42
+    batch_size: 64
+
+# Testing parameters:
+testing:
+  problem:
+    type: *p_type
+    data_folder: *data_folder
+    dataset: *dataset
+    subset: test
+    sentence_length: 42
+    batch_size: 64
+
+pipeline:
+  name: wikitext_language_modeling_seq2seq
+
+  # Source encoding - model 1.
+  source_sentence_embedding:
+    type: SentenceEmbeddings
+    priority: 1.1
+    embeddings_size: 50
+    pretrained_embeddings: glove.6B.50d.txt
+    data_folder: *data_folder
+    source_vocabulary_files: wiki.train.tokens,wiki.valid.tokens,wiki.test.tokens
+    vocabulary_mappings_file: wiki.all.tokenized_words
+    additional_tokens:
+    export_word_mappings_to_globals: True
+    streams:
+      inputs: sources
+      outputs: embedded_sources
+
+  # Target encoding.
+  target_indexer:
+    type: SentenceIndexer
+    priority: 2.1
+    data_folder: *data_folder
+    import_word_mappings_from_globals: True
+    streams:
+      inputs: targets
+      outputs: indexed_targets
+
+  # Single layer GRU Encoder
+  encoder:
+    type: RecurrentNeuralNetwork
+    cell_type: GRU
+    priority: 3
+    initial_state: Trainable
+    hidden_size: 50
+    num_layers: 1
+    use_logsoftmax: False
+    output_last_state: True
+    prediction_mode: Dense
+    ffn_output: False
+    streams:
+      inputs: embedded_sources
+      predictions: s2s_encoder_output
+      output_state: s2s_state_output
+    globals:
+      input_size: embeddings_size
+      prediction_size: embeddings_size
+
+  # Single layer GRU Decoder with attention
+  decoder:
+    type: Attn_Decoder_RNN
+    priority: 4
+    hidden_size: 50
+    num_layers: 1
+    use_logsoftmax: False
+    autoregression_length: 42
+    prediction_mode: Dense
+    streams:
+      inputs: s2s_encoder_output
+      predictions: s2s_decoder_output
+      input_state: s2s_state_output
+    globals:
+      input_size: embeddings_size
+      prediction_size: embeddings_size
+
+  # FF, to resize from the output size of the seq2seq to the size of the target vector
+  ff_resize_s2s_output:
+    type: FeedForwardNetwork
+    use_logsoftmax: True
+    dimensions: 3
+    priority: 5
+    streams:
+      inputs: s2s_decoder_output
+    globals:
+      input_size: embeddings_size
+      prediction_size: vocabulary_size
+
+  # Loss
+  nllloss:
+    type: NLLLoss
+    priority: 6
+    num_targets_dims: 2
+    streams:
+      targets: indexed_targets
+      loss: loss
+
+  # Prediction decoding.
+  prediction_decoder:
+    type: SentenceIndexer
+    priority: 10
+    # Reverse mode.
+    reverse: True
+    # Use distributions as inputs.
+    use_input_distributions: True
+    data_folder: *data_folder
+    import_word_mappings_from_globals: True
+    streams:
+      inputs: predictions
+      outputs: prediction_sentences
+
+  # Statistics.
+  batch_size:
+    type: BatchSizeStatistics
+    priority: 100.0
+
+  #accuracy:
+  #  type: AccuracyStatistics
+  #  priority: 100.1
+  #  streams:
+  #    targets: indexed_targets
+
+  bleu:
+    type: BLEUStatistics
+    priority: 100.2
+    streams:
+      targets: indexed_targets
+
+  # Viewers.
+  viewer:
+    type: StreamViewer
+    priority: 100.3
+    input_streams: sources,targets,indexed_targets,prediction_sentences
+
+#: pipeline
diff --git a/ptp/components/models/__init__.py b/ptp/components/models/__init__.py
index 8d04135..20c2841 100644
--- a/ptp/components/models/__init__.py
+++ b/ptp/components/models/__init__.py
@@ -7,6 +7,7 @@
 from .recurrent_neural_network import RecurrentNeuralNetwork
 from .sentence_embeddings import SentenceEmbeddings
 from .seq2seq_rnn import Seq2Seq_RNN
+from .attn_decoder_rnn import Attn_Decoder_RNN
 from .vqa.element_wise_multiplication import ElementWiseMultiplication
 from .vqa.multimodal_compact_bilinear_pooling import MultimodalCompactBilinearPooling
@@ -25,4 +26,5 @@
     'ElementWiseMultiplication',
     'MultimodalCompactBilinearPooling',
     'RelationalNetwork',
+    'Attn_Decoder_RNN'
 ]
diff --git a/ptp/components/models/attn_decoder_rnn.py b/ptp/components/models/attn_decoder_rnn.py
new file mode 100644
index 0000000..4d558ed
--- /dev/null
+++ b/ptp/components/models/attn_decoder_rnn.py
@@ -0,0 +1,235 @@
+# Copyright (C) Alexis Asseman, IBM Corporation 2019
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__author__ = "Alexis Asseman"
+
+import torch
+
+from ptp.configuration.configuration_error import ConfigurationError
+from ptp.components.models.model import Model
+from ptp.data_types.data_definition import DataDefinition
+
+
+class Attn_Decoder_RNN(Model):
+    """
+    Single layer GRU decoder with attention:
+    Bahdanau, D., Cho, K., & Bengio, Y. (2014). Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.
+
+    Needs the full sequence of hidden states from the encoder as input, as well as the last hidden state from the encoder as input state.
+
+    Code is based on https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html.
+    """
+    def __init__(self, name, config):
+        """
+        Initializes the model.
+
+        :param name: Name of the component.
+
+        :param config: Dictionary of parameters (read from configuration ``.yaml`` file).
+        :type config: ``ptp.configuration.ConfigInterface``
+        """
+        # Call constructors of parent classes.
+        Model.__init__(self, name, Attn_Decoder_RNN, config)
+
+        # Get input/output mode.
+        self.output_last_state = self.config["output_last_state"]
+        self.ffn_output = self.config["ffn_output"]
+
+        # Get prediction mode from configuration.
+        self.prediction_mode = self.config["prediction_mode"]
+        if self.prediction_mode not in ['Dense', 'Last', 'None']:
+            raise ConfigurationError("Invalid 'prediction_mode' (current {}, available {})".format(self.prediction_mode, ['Dense', 'Last', 'None']))
+
+        self.autoregression_length = self.config["autoregression_length"]
+
+        # Retrieve input size from global variables.
+        self.key_input_size = self.global_keys["input_size"]
+        self.input_size = self.globals["input_size"]
+        if type(self.input_size) == list:
+            if len(self.input_size) == 1:
+                self.input_size = self.input_size[0]
+            else:
+                raise ConfigurationError("RNN input size '{}' must be a single dimension (current {})".format(self.key_input_size, self.input_size))
+
+        # Retrieve output (prediction) size from global params.
+        self.key_prediction_size = self.global_keys["prediction_size"]
+        self.prediction_size = self.globals["prediction_size"]
+        if type(self.prediction_size) == list:
+            if len(self.prediction_size) == 1:
+                self.prediction_size = self.prediction_size[0]
+            else:
+                raise ConfigurationError("RNN prediction size '{}' must be a single dimension (current {})".format(self.key_prediction_size, self.prediction_size))
+
+        # Retrieve hidden size from configuration.
+        self.hidden_size = self.config["hidden_size"]
+        if type(self.hidden_size) == list:
+            if len(self.hidden_size) == 1:
+                self.hidden_size = self.hidden_size[0]
+            else:
+                raise ConfigurationError("RNN hidden_size must be a single dimension (current {})".format(self.hidden_size))
+
+        # Get dropout rate value from config.
+        dropout_rate = self.config["dropout_rate"]
+
+        # Create dropout layer.
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+        # Create rnn cell (single layer GRU).
+        # Note: the GRU's own `dropout` argument only applies between stacked layers,
+        # so it is a no-op for a single layer; the separate dropout layer above is the
+        # one that actually regularizes the outputs.
+        self.rnn_cell = torch.nn.GRU(self.input_size, self.hidden_size, 1, dropout=dropout_rate, batch_first=True)
+
+        # Create layers for the attention.
+        self.attn = torch.nn.Linear(self.hidden_size * 2, self.autoregression_length)
+        self.attn_combine = torch.nn.Linear(self.hidden_size * 2, self.hidden_size)
+
+        # Create the trainable initial input for the decoder (a trained "start of sequence" token of sorts).
+        self.sos_token = torch.zeros(1, self.input_size)
+        torch.nn.init.xavier_uniform_(self.sos_token)
+        self.sos_token = torch.nn.Parameter(self.sos_token, requires_grad=True)
+
+        # Get key mappings.
+        self.key_inputs = self.stream_keys["inputs"]
+        self.key_predictions = self.stream_keys["predictions"]
+        self.key_input_state = self.stream_keys["input_state"]
+        if self.output_last_state:
+            self.key_output_state = self.stream_keys["output_state"]
+
+        self.logger.info("Initializing RNN with input size = {}, hidden size = {} and prediction size = {}".format(self.input_size, self.hidden_size, self.prediction_size))
+
+        # Create the output layer.
+        self.activation2output_layer = None
+        if self.ffn_output:
+            self.activation2output_layer = torch.nn.Linear(self.hidden_size, self.prediction_size)
+
+        # Create the final non-linearity.
+        self.use_logsoftmax = self.config["use_logsoftmax"]
+        if self.use_logsoftmax:
+            if self.prediction_mode == "Dense":
+                # Used when returning dense predictions, i.e. every output of the unfolded model.
+                self.log_softmax = torch.nn.LogSoftmax(dim=2)
+            else:
+                # Used when returning only the last output.
+                self.log_softmax = torch.nn.LogSoftmax(dim=1)
+
+    def activation2output(self, activations):
+        output = self.dropout(activations)
+
+        if self.ffn_output:
+            shape = activations.shape
+
+            # Reshape to 2D tensor [BATCH_SIZE * SEQ_LEN x HIDDEN_SIZE]
+            output = output.contiguous().view(-1, shape[2])
+
+            # Propagate data through the output layer [BATCH_SIZE * SEQ_LEN x PREDICTION_SIZE]
+            output = self.activation2output_layer(output)
+
+            # Reshape back to 3D tensor [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE]
+            output = output.view(shape[0], shape[1], output.size(1))
+
+        return output
+
+
+    def input_data_definitions(self):
+        """
+        Function returns a dictionary with definitions of input data that are required by the component.
+
+        :return: dictionary containing input data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
+        """
+        d = {}
+
+        d[self.key_inputs] = DataDefinition([-1, -1, self.hidden_size], [torch.Tensor], "Batch of encoder outputs [BATCH_SIZE x SEQ_LEN x INPUT_SIZE]")
+
+        # Input hidden state
+        d[self.key_input_state] = DataDefinition([1, -1, self.hidden_size], [torch.Tensor], "Batch of RNN last states")
+
+        return d
+
+    def output_data_definitions(self):
+        """
+        Function returns a dictionary with definitions of output data produced by the component.
+
+        :return: dictionary containing output data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
+        """
+        d = {}
+
+        if self.prediction_mode == "Dense":
+            d[self.key_predictions] = DataDefinition([-1, -1, self.prediction_size], [torch.Tensor], "Batch of predictions, each represented as probability distribution over classes [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE]")
+        elif self.prediction_mode == "Last":
+            # Only last prediction.
+            d[self.key_predictions] = DataDefinition([-1, self.prediction_size], [torch.Tensor], "Batch of predictions, each represented as probability distribution over classes [BATCH_SIZE x PREDICTION_SIZE]")
+
+        # Output hidden state stream
+        if self.output_last_state:
+            d[self.key_output_state] = DataDefinition([1, -1, self.hidden_size], [torch.Tensor], "Batch of RNN last states")
+
+        return d
+
+    def forward(self, data_dict):
+        """
+        Forward pass of the model.
+
+        :param data_dict: DataDict({'inputs', 'predictions', ...}), where:
+
+            - inputs: expected inputs [BATCH_SIZE x SEQ_LEN x INPUT_SIZE],
+            - predictions: returned output with predictions (log_probs) [BATCH_SIZE x SEQ_LEN x PREDICTION_SIZE]
+        """
+
+        inputs = data_dict[self.key_inputs]
+        batch_size = inputs.shape[0]
+
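+        # Note: the decoding loop below implements Bahdanau-style attention: at each
+        # step, the previous output and the current hidden state are used to score the
+        # encoder outputs, the resulting context vector is combined with the previous
+        # output, and the GRU is advanced by one step.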
+        # Initialize hidden state from the encoder's last hidden state.
+        hidden = data_dict[self.key_input_state]
+
+        # List that will contain the output sequence.
+        activations = []
+
+        # First input to the decoder - trainable "start of sequence" token.
+        activations_partial = self.sos_token.expand(batch_size, -1).unsqueeze(1)
+
+        # Feed back the outputs iteratively.
+        for i in range(self.autoregression_length):
+
+            # Compute the attention weights from the previous output and the current hidden state.
+            attn_weights = torch.nn.functional.softmax(
+                self.attn(torch.cat((activations_partial.transpose(0, 1), hidden), 2)),
+                dim=2
+            )
+            # Apply them to the encoder outputs and combine with the previous output.
+            attn_applied = torch.bmm(attn_weights.transpose(0, 1), inputs)
+            activations_partial = torch.cat((activations_partial, attn_applied), 2)
+            activations_partial = self.attn_combine(activations_partial)
+            activations_partial = torch.nn.functional.relu(activations_partial)
+
+            # Feed through the RNN.
+            activations_partial, hidden = self.rnn_cell(activations_partial, hidden)
+            activations_partial = self.activation2output(activations_partial)
+
+            # Add the single step output into list.
+            if self.prediction_mode == "Dense":
+                activations += [activations_partial]
+
+        # Reassemble all the outputs from list into an output tensor.
+        if self.prediction_mode == "Dense":
+            outputs = torch.cat(activations, 1)
+            # Log softmax - along PREDICTION dim.
+            if self.use_logsoftmax:
+                outputs = self.log_softmax(outputs)
+            # Add predictions to datadict.
+            data_dict.extend({self.key_predictions: outputs})
+        elif self.prediction_mode == "Last":
+            outputs = activations_partial.squeeze(1)
+            # Log softmax - along PREDICTION dim.
+            if self.use_logsoftmax:
+                outputs = self.log_softmax(outputs)
+            # Add predictions to datadict.
+            data_dict.extend({self.key_predictions: outputs})
+
+        # Output last hidden state, if requested.
+        if self.output_last_state:
+            data_dict.extend({self.key_output_state: hidden})
diff --git a/ptp/components/problems/text_to_text/__init__.py b/ptp/components/problems/text_to_text/__init__.py
index be7cc00..804ae58 100644
--- a/ptp/components/problems/text_to_text/__init__.py
+++ b/ptp/components/problems/text_to_text/__init__.py
@@ -1,5 +1,7 @@
 from .wikitext_language_modeling import WikiTextLanguageModeling
+from .translation_pairs import TranslationPairs
 
 __all__ = [
     'WikiTextLanguageModeling',
+    'TranslationPairs'
 ]
diff --git a/ptp/components/problems/text_to_text/translation_pairs.py b/ptp/components/problems/text_to_text/translation_pairs.py
new file mode 100644
index 0000000..3f72b87
--- /dev/null
+++ b/ptp/components/problems/text_to_text/translation_pairs.py
@@ -0,0 +1,231 @@
+# Copyright (C) tkornuta, IBM Corporation 2019
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__author__ = "Alexis Asseman"
+
+import os
+import random
+import tempfile
+import unicodedata
+import re
+
+from nltk.tokenize import WhitespaceTokenizer
+
+import ptp.components.utils.io as io
+from ptp.configuration import ConfigurationError
+from ptp.components.problems.problem import Problem
+from ptp.data_types.data_definition import DataDefinition
+
+
+class TranslationPairs(Problem):
+    """
+    Bilingual sentence pairs from http://www.manythings.org/anki/.
+    Only some pairs are included here, but many more are available on the website.
+    Will download the requested language pair if necessary, normalize and tokenize the sentences, and cut the data into train, valid and test sets.
+
+    Only the pairs whose source and target sentences are no longer than the specified length (set by the user in the configuration file) are kept;
+    each sample contains the source and target sentences as lists of tokens.
+    """
+    def __init__(self, name, config):
+        """
+        The init method downloads the required files, loads the files associated with a given subset (train/valid/test)
+        and tokenizes the sentences using NLTK's WhitespaceTokenizer.
+
+        :param name: Name of the component.
+
+        :param config: Dictionary of parameters (read from configuration ``.yaml`` file).
+        """
+        # Call constructor of parent classes.
+        Problem.__init__(self, name, TranslationPairs, config)
+
+        # Set streams key mappings.
+        self.key_sources = self.stream_keys["sources"]
+        self.key_targets = self.stream_keys["targets"]
+
+        # Get absolute path to data folder.
+        self.data_folder = os.path.expanduser(self.config['data_folder'])
+
+        # Get dataset.
+        if (self.config['dataset'] is None) or (self.config['dataset'] not in ["eng-fra", "eng-pol"]):
+            raise ConfigurationError("Problem supports only 'dataset' options: 'eng-fra', 'eng-pol'")
+        dataset = self.config['dataset']
+
+        # Get (sub)set: train/valid/test.
+        if (self.config['subset'] is None) or (self.config['subset'] not in ['train', 'valid', 'test']):
+            raise ConfigurationError("Problem supports only 'subset' options: 'train', 'valid', 'test'")
+        subset = self.config['subset']
+
+        # Extract source and target language names.
+        self.lang_source = self.config['dataset'].split('-')[0]
+        self.lang_target = self.config['dataset'].split('-')[1]
+
+        # Names of files used by this problem.
+        filenames = [
+            self.lang_source + ".train.txt",
+            self.lang_target + ".train.txt",
+            self.lang_source + ".valid.txt",
+            self.lang_target + ".valid.txt",
+            self.lang_source + ".test.txt",
+            self.lang_target + ".test.txt"
+        ]
+
+        # Initialize dataset if files do not exist.
+        if not io.check_files_existence(os.path.join(self.data_folder, dataset), filenames):
+            # Set url and source filename depending on dataset.
+            url = "https://www.manythings.org/anki/" + self.lang_target + "-" + self.lang_source + ".zip"
+            zipfile_name = "translate_" + self.lang_target + "_" + self.lang_source + ".zip"
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                # Download and extract the translation pairs zip.
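+                # (Assumption, inferred from the code below: the extracted zip contains a
+                # single tab-separated file named after the target language, e.g. fra.txt,
+                # with one "source<TAB>target" pair per line.)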
+                io.download_extract_zip_file(self.logger, tmpdirname, url, zipfile_name)
+
+                # Create train, valid, test files from the downloaded file.
+                lines = io.load_string_list_from_txt_file(tmpdirname, self.lang_target + ".txt")
+
+                # Shuffle the lines.
+                random.seed(42)
+                random.shuffle(lines)
+
+                # Split the source and target sentences.
+                lines_source = [self.normalizeString(l.split('\t')[0]) for l in lines]
+                lines_target = [self.normalizeString(l.split('\t')[1]) for l in lines]
+
+                # Cut dataset into train (90%), valid (5%), test (5%) files.
+                test_index = len(lines) // 20
+                valid_index = test_index + (len(lines) // 20)
+
+                os.makedirs(os.path.join(self.data_folder, dataset), exist_ok=True)
+
+                with open(os.path.join(os.path.join(self.data_folder, dataset), self.lang_source + ".test.txt"), mode='w+') as f:
+                    f.write('\n'.join(lines_source[0:test_index]))
+                with open(os.path.join(os.path.join(self.data_folder, dataset), self.lang_target + ".test.txt"), mode='w+') as f:
+                    f.write('\n'.join(lines_target[0:test_index]))
+
+                with open(os.path.join(os.path.join(self.data_folder, dataset), self.lang_source + ".valid.txt"), mode='w+') as f:
+                    f.write('\n'.join(lines_source[test_index:valid_index]))
+                with open(os.path.join(os.path.join(self.data_folder, dataset), self.lang_target + ".valid.txt"), mode='w+') as f:
+                    f.write('\n'.join(lines_target[test_index:valid_index]))
+
+                with open(os.path.join(os.path.join(self.data_folder, dataset), self.lang_source + ".train.txt"), mode='w+') as f:
+                    f.write('\n'.join(lines_source[valid_index:]))
+                with open(os.path.join(os.path.join(self.data_folder, dataset), self.lang_target + ".train.txt"), mode='w+') as f:
+                    f.write('\n'.join(lines_target[valid_index:]))
+
+        else:
+            self.logger.info("Files {} found in folder '{}'".format(filenames, self.data_folder))
+
+        # Load the lines.
+        lines_source = io.load_string_list_from_txt_file(os.path.join(self.data_folder, dataset), self.lang_source + "." + subset + ".txt")
+        lines_target = io.load_string_list_from_txt_file(os.path.join(self.data_folder, dataset), self.lang_target + "." + subset + ".txt")
+
+        # Get the required sample length.
+        self.sentence_length = self.config['sentence_length']
+
+        # Separate into src - tgt sentence pairs + tokenize.
+        tokenizer = WhitespaceTokenizer()
+        self.sentences_source = []
+        self.sentences_target = []
+        for s_src, s_tgt in zip(lines_source, lines_target):
+            src = tokenizer.tokenize(s_src)
+            tgt = tokenizer.tokenize(s_tgt)
+            # Keep only the pairs that are shorter or equal to the requested length.
+            # If self.sentence_length < 0, then keep all the pairs regardless of length.
+            if (len(src) <= self.sentence_length and len(tgt) <= self.sentence_length) \
+                    or self.sentence_length < 0:
+                self.sentences_source += [src]
+                self.sentences_target += [tgt]
+
+        self.logger.info("Loaded text consisting of {} sentence pairs".format(len(self.sentences_source)))
+
+        # Calculate the size of dataset.
+        self.dataset_length = len(self.sentences_source)
+
+        # Display exemplary sample.
+        self.logger.info("Exemplary sample:\n  source: {}\n  target: {}".format(self.sentences_source[0], self.sentences_target[0]))
+
+
+    def output_data_definitions(self):
+        """
+        Function returns a dictionary with definitions of output data produced by the component.
+
+        :return: dictionary containing output data definitions (each of type :py:class:`ptp.utils.DataDefinition`).
+ """ + return { + self.key_indices: DataDefinition([-1, 1], [list, int], "Batch of sample indices [BATCH_SIZE] x [1]"), + self.key_sources: DataDefinition([-1, self.sentence_length, 1], [list, list, str], "Batch of input sentences, each consisting of several words [BATCH_SIZE] x [SENTENCE_LENGTH] x [string]"), + self.key_targets: DataDefinition([-1, self.sentence_length, 1], [list, list, str], "Batch of target sentences, each consisting of several words [BATCH_SIZE] x [SENTENCE_LENGTH] x [string]") + } + + # Turn a Unicode string to plain ASCII, thanks to + # https://stackoverflow.com/a/518232/2809427 + @staticmethod + def unicodeToAscii(s): + return ''.join( + c for c in unicodedata.normalize('NFD', s) + if unicodedata.category(c) != 'Mn' + ) + + # Lowercase, trim, and remove non-letter characters + # https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html + def normalizeString(self, s): + s = self.unicodeToAscii(s.lower().strip()) + s = re.sub(r"([.!?])", r" \1", s) + s = re.sub(r"[^a-zA-Z.!?]+", r" ", s) + return s + + def __len__(self): + """ + Returns the "size" of the "problem" (total number of samples). + + :return: The size of the problem. + """ + return self.dataset_length + + + def __getitem__(self, index): + """ + Getter method to access the dataset and return a sample. + + :param index: index of the sample to return. + :type index: int + + :return: ``DataDict({'indices', sources','targets'})`` + + """ + # Return data_dict. + data_dict = self.create_data_dict(index) + data_dict[self.key_sources] = self.sentences_source[index] + data_dict[self.key_targets] = self.sentences_target[index] + return data_dict + + def collate_fn(self, batch): + """ + Generates a batch of samples from a list of individuals samples retrieved by :py:func:`__getitem__`. + + :param batch: List of :py:class:`ptp.utils.DataDict` retrieved by :py:func:`__getitem__` + :type batch: list + + :return: DataDict containing the created batch. + + """ + # Collate indices. + data_dict = self.create_data_dict([sample[self.key_indices] for sample in batch]) + # Collate sources. + data_dict[self.key_sources] = [sample[self.key_sources] for sample in batch] + data_dict[self.key_targets] = [sample[self.key_targets] for sample in batch] + return data_dict +