Commit

training
matt-peters committed Apr 23, 2018
1 parent e6463cc commit 5de3ca9
Showing 6 changed files with 1,667 additions and 4 deletions.
204 changes: 204 additions & 0 deletions bilm/data.py
@@ -1,9 +1,13 @@
# originally based on https://github.com/tensorflow/models/tree/master/lm_1b
import glob
import random

import numpy as np

from typing import List



class Vocabulary(object):
    '''
    A token vocabulary. Holds a map from token to ids and provides
@@ -249,3 +253,203 @@ def batch_sentences(self, sentences: List[List[str]]):

        return X_ids


##### for training
def _get_batch(generator, batch_size, num_steps, max_word_length):
    """Read batches of input."""
    cur_stream = [None] * batch_size

    no_more_data = False
    while True:
        inputs = np.zeros([batch_size, num_steps], np.int32)
        if max_word_length is not None:
            char_inputs = np.zeros([batch_size, num_steps, max_word_length],
                                   np.int32)
        else:
            char_inputs = None
        targets = np.zeros([batch_size, num_steps], np.int32)

        for i in range(batch_size):
            cur_pos = 0

            while cur_pos < num_steps:
                if cur_stream[i] is None or len(cur_stream[i][0]) <= 1:
                    try:
                        cur_stream[i] = list(next(generator))
                    except StopIteration:
                        # No more data, exhaust current streams and quit
                        no_more_data = True
                        break

                how_many = min(len(cur_stream[i][0]) - 1, num_steps - cur_pos)
                next_pos = cur_pos + how_many

                inputs[i, cur_pos:next_pos] = cur_stream[i][0][:how_many]
                if max_word_length is not None:
                    char_inputs[i, cur_pos:next_pos] = cur_stream[i][1][
                        :how_many]
                targets[i, cur_pos:next_pos] = cur_stream[i][0][1:how_many+1]

                cur_pos = next_pos

                cur_stream[i][0] = cur_stream[i][0][how_many:]
                if max_word_length is not None:
                    cur_stream[i][1] = cur_stream[i][1][how_many:]

        if no_more_data:
            # There is no more data. Note: this will not return data
            # for the incomplete batch
            break

        X = {'token_ids': inputs, 'tokens_characters': char_inputs,
             'next_token_id': targets}

        yield X
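
# Illustrative sketch, not part of the original file: drives _get_batch with a
# toy generator that mimics LMDataset.get_sentence(), i.e. it yields
# (token_ids, char_ids) pairs. With max_word_length=None the character inputs
# are skipped and 'tokens_characters' is None.
def _demo_get_batch():
    def toy_sentences():
        for ids in ([1, 2, 3, 4], [5, 6, 7]):
            yield np.array(ids, dtype=np.int32), None

    batch = next(_get_batch(toy_sentences(), batch_size=1, num_steps=3,
                            max_word_length=None))
    # batch['token_ids']     -> [[1 2 3]]
    # batch['next_token_id'] -> [[2 3 4]]  (the inputs shifted by one position)
    # batch['tokens_characters'] is None because max_word_length is None
    return batch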

class LMDataset(object):
    """
    Hold a language model dataset.
    A dataset is a list of tokenized files. Each file contains one sentence
    per line. Each sentence is pre-tokenized and white space joined.
    """
    def __init__(self, filepattern, vocab, reverse=False, test=False,
                 shuffle_on_load=False):
        '''
        filepattern = a glob string that specifies the list of files.
        vocab = an instance of Vocabulary or UnicodeCharsVocabulary
        reverse = if True, then iterate over tokens in each sentence in reverse
        test = if True, then iterate through all data once then stop.
            Otherwise, iterate forever.
        shuffle_on_load = if True, then shuffle the sentences after loading.
        '''
        self._vocab = vocab
        self._all_shards = glob.glob(filepattern)
        print('Found %d shards at %s' % (len(self._all_shards), filepattern))
        self._shards_to_choose = []

        self._reverse = reverse
        self._test = test
        self._shuffle_on_load = shuffle_on_load
        self._use_char_inputs = hasattr(vocab, 'encode_chars')

        self._ids = self._load_random_shard()

    def _choose_random_shard(self):
        if len(self._shards_to_choose) == 0:
            self._shards_to_choose = list(self._all_shards)
            random.shuffle(self._shards_to_choose)
        shard_name = self._shards_to_choose.pop()
        return shard_name

    def _load_random_shard(self):
        """Randomly select a file and read it."""
        if self._test:
            if len(self._all_shards) == 0:
                # we've loaded all the data
                # this will propagate up to the generator in _get_batch
                # and stop iterating
                raise StopIteration
            else:
                shard_name = self._all_shards.pop()
        else:
            # just pick a random shard
            shard_name = self._choose_random_shard()

        ids = self._load_shard(shard_name)
        self._i = 0
        self._nids = len(ids)
        return ids

    def _load_shard(self, shard_name):
        """Read one file and convert to ids.

        Args:
            shard_name: file path.

        Returns:
            list of (id, char_id) tuples.
        """
        print('Loading data from: %s' % shard_name)
        with open(shard_name) as f:
            sentences_raw = f.readlines()

        if self._reverse:
            sentences = []
            for sentence in sentences_raw:
                splitted = sentence.split()
                splitted.reverse()
                sentences.append(' '.join(splitted))
        else:
            sentences = sentences_raw

        if self._shuffle_on_load:
            random.shuffle(sentences)

        ids = [self.vocab.encode(sentence, self._reverse)
               for sentence in sentences]
        if self._use_char_inputs:
            chars_ids = [self.vocab.encode_chars(sentence, self._reverse)
                         for sentence in sentences]
        else:
            chars_ids = [None] * len(ids)

        print('Loaded %d sentences.' % len(ids))
        print('Finished loading')
        return list(zip(ids, chars_ids))

    def get_sentence(self):
        while True:
            if self._i == self._nids:
                self._ids = self._load_random_shard()
            ret = self._ids[self._i]
            self._i += 1
            yield ret

    @property
    def max_word_length(self):
        if self._use_char_inputs:
            return self._vocab.max_word_length
        else:
            return None

    def iter_batches(self, batch_size, num_steps):
        for X in _get_batch(self.get_sentence(), batch_size, num_steps,
                            self.max_word_length):

            # token_ids = (batch_size, num_steps)
            # char_inputs = (batch_size, num_steps, 50) of character ids
            # targets = word ID of next word (batch_size, num_steps)
            yield X

    @property
    def vocab(self):
        return self._vocab
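
# Illustrative sketch, not part of the original file: typical use of LMDataset.
# The vocabulary constructor arguments and file paths below are assumptions for
# the example; UnicodeCharsVocabulary is defined earlier in this module.
def _demo_lm_dataset():
    vocab = UnicodeCharsVocabulary('vocab.txt', 50)  # hypothetical vocab file
    data = LMDataset('/path/to/training/shards/*', vocab, shuffle_on_load=True)
    for batch in data.iter_batches(batch_size=128, num_steps=20):
        # batch['token_ids'].shape         == (128, 20)
        # batch['tokens_characters'].shape == (128, 20, 50)
        # batch['next_token_id'].shape     == (128, 20)
        # with a vocabulary that has no encode_chars method,
        # batch['tokens_characters'] would be None instead
        return batch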

class BidirectionalLMDataset(object):
    def __init__(self, filepattern, vocab, test=False, shuffle_on_load=False):
        '''
        bidirectional version of LMDataset
        '''
        self._data_forward = LMDataset(
            filepattern, vocab, reverse=False, test=test,
            shuffle_on_load=shuffle_on_load)
        self._data_reverse = LMDataset(
            filepattern, vocab, reverse=True, test=test,
            shuffle_on_load=shuffle_on_load)

    def iter_batches(self, batch_size, num_steps):
        max_word_length = self._data_forward.max_word_length

        for X, Xr in zip(
            _get_batch(self._data_forward.get_sentence(), batch_size,
                       num_steps, max_word_length),
            _get_batch(self._data_reverse.get_sentence(), batch_size,
                       num_steps, max_word_length)
            ):

            for k, v in Xr.items():
                X[k + '_reverse'] = v

            yield X
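
# Illustrative sketch, not part of the original file: the bidirectional dataset
# yields the forward batch plus '_reverse'-suffixed copies built from the
# reversed token stream; vocabulary construction and paths are assumptions.
def _demo_bidirectional_dataset():
    vocab = UnicodeCharsVocabulary('vocab.txt', 50)  # hypothetical vocab file
    data = BidirectionalLMDataset('/path/to/training/shards/*', vocab,
                                  shuffle_on_load=True)
    batch = next(data.iter_batches(batch_size=128, num_steps=20))
    # keys: 'token_ids', 'tokens_characters', 'next_token_id' and the same
    # three with a '_reverse' suffix, each shaped like its forward counterpart
    return batch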
