
Commit

update
astorfi committed Dec 13, 2018
1 parent e8be250 commit e546e78
Showing 3 changed files with 53 additions and 23 deletions.
35 changes: 34 additions & 1 deletion README.rst
@@ -13,7 +13,40 @@ Documentation
Dataset
============

**NOTE:** The dataset object is heavily inspired by the official PyTorch tutorial: [`TRANSLATION WITH A SEQUENCE TO SEQUENCE NETWORK AND ATTENTION <https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html/>`_]

The dataset is prepared using the ``data_loader.py`` script.

As a first step, we have to define the ``word indexing`` used for further processing. ``word2index`` is the dictionary that maps a word to its associated index, and ``index2word`` does the reverse:

.. code-block:: python

    SOS_token = 1
    EOS_token = 2

    class Lang:
        def __init__(self, name):
            self.name = name
            self.word2index = {}
            self.word2count = {}
            self.index2word = {0: "<pad>", SOS_token: "SOS", EOS_token: "EOS"}
            self.n_words = 3  # Count <pad>, SOS and EOS

        def addSentence(self, sentence):
            for word in sentence.split(' '):
                self.addWord(word)

        def addWord(self, word):
            if word not in self.word2index:
                self.word2index[word] = self.n_words
                self.word2count[word] = 1
                self.index2word[self.n_words] = word
                self.n_words += 1
            else:
                self.word2count[word] += 1
Unlike the [`PyTorch tutorial <https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html/>`_], we start the indexing from ``1`` (``SOS_token = 1``) so that index ``0`` stays reserved for the ``<pad>`` token.
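
As a quick illustration (a minimal sketch, not from the repository itself), indexing a single sentence fills all three dictionaries:

.. code-block:: python

    lang = Lang('eng')
    lang.addSentence('the cat sat')
    print(lang.word2index)     # {'the': 3, 'cat': 4, 'sat': 5}
    print(lang.index2word[3])  # 'the'
    print(lang.n_words)        # 6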


============
Binary file modified __pycache__/data_loader.cpython-35.pyc
Binary file not shown.
41 changes: 19 additions & 22 deletions data_loader.py
@@ -28,13 +28,11 @@ def str2bool(v):
### PREPROCESSING ######
########################

# Please refer to https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
# We have to define word indexing for further processing.
# word2index: Word to its associated index
# index2word: Index to the associated word.

SOS_token = 1
EOS_token = 2
@@ -75,8 +73,6 @@ def unicodeToAscii(s):
)

# Lowercase, trim, and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
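    # e.g., normalizeString("Go.") returns "go ." (illustrative)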
@@ -89,19 +85,18 @@ def normalizeString(s):
# lines into pairs. The files are all English → Other Language, so if we
# want to translate from Other Language → English I added the ``reverse``
# flag to reverse the pairs.


def readLangs(lang1, lang2, auto_encoder=False, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
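    # NOTE: the data file is always the eng-fra corpus; for the autoencoder only its English side will be used.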
    lines = open('data/%s-%s.txt' % ('eng', 'fra'), encoding='utf-8'). \
        read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # The autoencoder uses the same data for both input and output
    if auto_encoder:
        pairs = [[pair[0], pair[0]] for pair in pairs]
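        # e.g., the pair ['go .', 'va !'] becomes ['go .', 'go .'], so the model reconstructs its input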

@@ -176,20 +171,23 @@ def prepareData(lang1, lang2, max_input_length, auto_encoder=False, reverse=Fals


class Dataset():
    """Dataset object."""

    def __init__(self, phase, num_embeddings=None, max_input_length=None, transform=None, auto_encoder=False):
        """
        The initialization of the dataset object.

        :param phase: train/test.
        :param num_embeddings: The embedding dimensionality.
        :param max_input_length: The maximum enforced length of the sentences.
        :param transform: Post-processing, if necessary.
        :param auto_encoder: Whether we are training an autoencoder.
        """
        if auto_encoder:
            lang_in = 'eng'
            lang_out = 'eng'
        else:
            lang_in = 'eng'
            lang_out = 'fra'
        # Skip sentences with a length larger than max_input_length!
        input_lang, output_lang, pairs = prepareData(lang_in, lang_out, max_input_length, auto_encoder=auto_encoder, reverse=True)
        print(random.choice(pairs))
@@ -216,7 +214,6 @@ def __init__(self, phase, num_embeddings=None, max_input_length=None, transform=
    def langs(self):
        return self.input_lang, self.output_lang

    def __len__(self):
        return len(self.data)
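
    # A minimal usage sketch (illustrative, assuming the rest of the class as committed):
    #   dataset = Dataset(phase='train', max_input_length=10, auto_encoder=True)
    #   print(len(dataset))  # number of retained sentence pairs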

