This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

corpus reader
szha committed Mar 10, 2018
1 parent a61fd99 commit 8b37495
Showing 4 changed files with 76 additions and 46 deletions.
2 changes: 1 addition & 1 deletion python/mxnet/gluon/data/datareader.py
@@ -31,4 +31,4 @@ def read(self):
raise NotImplementedError

def read_iter(self):
return self.read()
return iter(self.read())
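With this change, read_iter returns a proper iterator over whatever read produces, rather than the raw return value. A minimal sketch of the contract, using a hypothetical ListReader subclass that is not part of this commit and assuming the module path from the file header above:

from mxnet.gluon.data.datareader import DataReader

class ListReader(DataReader):
    # Toy reader that "reads" an in-memory list of samples.
    def __init__(self, samples):
        self._samples = samples

    def read(self):
        # Returns the full dataset in one shot, as DataReader.read is expected to.
        return self._samples

reader = ListReader([['hello', 'world'], ['good', 'morning']])
print(reader.read())       # the whole list: [['hello', 'world'], ['good', 'morning']]
it = reader.read_iter()    # now an iterator, thanks to iter(self.read())
print(next(it))            # ['hello', 'world']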
91 changes: 59 additions & 32 deletions python/mxnet/gluon/data/text/base.py
@@ -29,10 +29,54 @@
from ..datareader import DataReader
from .utils import flatten_samples, collate, pair

class WordLanguageReader(DataReader):
class CorpusReader(DataReader):
"""Text reader that reads a whole corpus and produces a dataset based on provided
sample splitter and word tokenizer.
The returned dataset includes samples, each of which is either a list of tokens, if a tokenizer
is specified, or a single string segment produced by sample_splitter.
Parameters
----------
filename : str
Path to the input text file.
encoding : str, default 'utf8'
File encoding format.
flatten : bool, default False
Whether to return all samples as flattened tokens. If True, each sample is a token.
sample_splitter : function, default str.splitlines
A function that splits the dataset string into samples.
tokenizer : function or None, default str.split
A function that splits each sample string into a list of tokens. If None, raw samples are
returned according to `sample_splitter`.
"""
def __init__(self, filename, encoding='utf8', flatten=False,
sample_splitter=lambda s: s.splitlines(),
tokenizer=lambda s: s.split()):
assert sample_splitter, 'sample_splitter must be specified.'
self._filename = os.path.expanduser(filename)
self._encoding = encoding
self._flatten = flatten
self._sample_splitter = sample_splitter
self._tokenizer = tokenizer

def read(self):
with io.open(self._filename, 'r', encoding=self._encoding) as fin:
content = fin.read()
samples = (s.strip() for s in self._sample_splitter(content))
if self._tokenizer:
samples = [self._tokenizer(s) for s in samples if s]
else:
samples = [s for s in samples if s]
if self._flatten:
samples = flatten_samples(samples)
return samples
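A hedged usage sketch of the new CorpusReader, assuming a hypothetical corpus.txt and the module path from the file header above:

from mxnet.gluon.data.text.base import CorpusReader

# corpus.txt (hypothetical contents):
#   the quick brown fox
#   jumps over the lazy dog
reader = CorpusReader('corpus.txt')
print(reader.read())
# [['the', 'quick', 'brown', 'fox'], ['jumps', 'over', 'the', 'lazy', 'dog']]

# tokenizer=None keeps each non-empty line as a raw string segment.
print(CorpusReader('corpus.txt', tokenizer=None).read())
# ['the quick brown fox', 'jumps over the lazy dog']

# flatten=True turns every token into its own sample (via flatten_samples).
print(CorpusReader('corpus.txt', flatten=True).read())
# ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'the', 'lazy', 'dog']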


class WordLanguageReader(CorpusReader):
"""Text reader that reads a whole corpus and produces a language modeling dataset based
on provided sample splitter and word tokenizer.
The returned dataset includes data (current word) and label (next word).
Parameters
@@ -61,43 +61,26 @@ class WordLanguageReader(DataReader):
"""
def __init__(self, filename, encoding='utf8', sample_splitter=lambda s: s.splitlines(),
tokenizer=lambda s: s.split(), seq_len=None, bos=None, eos=None, pad=None):
self._filename = os.path.expanduser(filename)
self._encoding = encoding
self._sample_splitter = sample_splitter
self._tokenizer = tokenizer

if bos and eos:
def process(s):
out = [bos]
out.extend(s)
out.append(eos)
return pair(out)
elif bos:
def process(s):
out = [bos]
out.extend(s)
return pair(out)
elif eos:
def process(s):
s.append(eos)
return pair(s)
else:
def process(s):
return pair(s)
self._process = process
assert tokenizer, "Tokenizer must be specified for reading word language model corpus."
super(WordLanguageReader, self).__init__(filename, encoding, False, sample_splitter, tokenizer)
def process(s):
tokens = [bos] if bos else []
tokens.extend(s)
if eos:
tokens.append(eos)
return tokens
self._seq_len = seq_len
self._process = process
self._pad = pad

def read(self):
with io.open(self._filename, 'r', encoding=self._encoding) as fin:
content = fin.read()
samples = [s.strip() for s in self._sample_splitter(content)]
samples = [self._process(self._tokenizer(s)) for s in samples if s]
samples = super(WordLanguageReader, self).read()
samples = [self._process(s) for s in samples]
if self._seq_len:
samples = flatten_samples(samples)
if self._pad and len(samples) % self._seq_len:
pad_len = self._seq_len - len(samples) % self._seq_len
samples.extend([self._pad] * pad_len)
samples = collate(samples, self._seq_len)
samples = [list(zip(*s)) for s in samples]
return SimpleDataset(samples)
samples = collate(samples, self._seq_len, 1)

return SimpleDataset(samples).transform(lambda x: (x[:-1], x[1:]))
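A hedged sketch of how the reworked WordLanguageReader might be used, again assuming a hypothetical corpus.txt; the (data, label) split comes from the final transform shown above:

from mxnet.gluon.data.text.base import WordLanguageReader

# corpus.txt (hypothetical contents):
#   the quick brown fox
#   jumps over the lazy dog
reader = WordLanguageReader('corpus.txt', seq_len=4,
                            bos='<bos>', eos='<eos>', pad='<pad>')
dataset = reader.read()

# The tokens are flattened to
#   ['<bos>', 'the', 'quick', 'brown', 'fox', '<eos>',
#    '<bos>', 'jumps', 'over', 'the', 'lazy', 'dog', '<eos>'],
# padded to a multiple of seq_len, then collated with overlap=1 so each chunk
# holds seq_len + 1 tokens before the (x[:-1], x[1:]) transform.
data, label = dataset[0]
print(data)    # ['<bos>', 'the', 'quick', 'brown']
print(label)   # ['the', 'quick', 'brown', 'fox']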
7 changes: 5 additions & 2 deletions python/mxnet/gluon/data/text/utils.py
@@ -38,7 +38,7 @@ def flatten_samples(samples):
"""
return [token for sample in samples for token in sample if token]

def collate(flat_sample, seq_len):
def collate(flat_sample, seq_len, overlap=0):
Collate a flat list of tokens into a list of lists of tokens, with each
inner list's length equal to the specified `seq_len`.
@@ -48,13 +48,16 @@ def collate(flat_sample, seq_len):
A flat list of tokens.
seq_len : int
The length of each of the samples.
overlap : int, default 0
The number of extra items in the current sample that should overlap with the
next sample.
Returns
-------
List of samples, each of which has length equal to `seq_len`.
"""
num_samples = len(flat_sample) // seq_len
return [flat_sample[i*seq_len:(i+1)*seq_len] for i in range(num_samples)]
return [flat_sample[i*seq_len:((i+1)*seq_len+overlap)] for i in range(num_samples)]
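A small worked example of the two helpers, following the implementations shown in this file; the token values are made up:

from mxnet.gluon.data.text.utils import flatten_samples, collate

print(flatten_samples([['a', 'b'], ['c', '', 'd']]))
# ['a', 'b', 'c', 'd']  -- empty tokens are dropped while flattening

flat = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']
print(collate(flat, 3, 1))
# num_samples = 8 // 3 = 2, so:
# [['a', 'b', 'c', 'd'], ['d', 'e', 'f', 'g']]
# Each chunk carries one extra token that overlaps with the next chunk, which is
# what WordLanguageReader uses to build (data, label) = (x[:-1], x[1:]) pairs.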

def pair(sample):
"""Produce tuples of tokens from a list of tokens, with current token as the first
22 changes: 11 additions & 11 deletions tests/python/unittest/test_gluon_data_text.py
@@ -29,20 +29,20 @@ def test_wikitext2():
val = data.text.lm.WikiText2(root='data/wikitext-2', segment='val')
test = data.text.lm.WikiText2(root='data/wikitext-2', segment='test')
train_freq, val_freq, test_freq = [get_frequencies(x) for x in [train, val, test]]
assert len(train) == 58626
assert len(train_freq) == 33278
assert len(val) == 6112
assert len(val_freq) == 13778
assert len(test) == 6892
assert len(test_freq) == 14144
assert test_freq['English'] == 35
assert len(train[0][0]) == 35
assert len(train) == 59306, len(train)
assert len(train_freq) == 33279, len(train_freq)
assert len(val) == 6182, len(val)
assert len(val_freq) == 13778, len(val_freq)
assert len(test) == 6975, len(test)
assert len(test_freq) == 14144, len(test_freq)
assert test_freq['English'] == 33, test_freq['English']
assert len(train[0][0]) == 35, len(train[0][0])
test_no_pad = data.text.lm.WikiText2(root='data/wikitext-2', segment='test', pad=None)
assert len(test_no_pad) == 6891
assert len(test_no_pad) == 6974, len(test_no_pad)

train_paragraphs = data.text.lm.WikiText2(root='data/wikitext-2', segment='train', seq_len=None)
assert len(train_paragraphs) == 23767
assert len(train_paragraphs[0][0]) != 35
assert len(train_paragraphs) == 23767, len(train_paragraphs)
assert len(train_paragraphs[0][0]) != 35, len(train_paragraphs[0][0])


if __name__ == '__main__':
