# Boilerplate

In [1]:
# files
TRAINING_DIRECTORY = 'cnn/stories/'
EXTENSION = '.story'
MAX_FILES = 10000

# tokenization
FILTERS = '!"#$%&()*+,-./:;=?@[\\]^_`{|}~\t\n'  # default minus >, <
INPUT_END_CHAR = '<input-end>'
TARGET_END_CHAR = '<target-end>'
OOV_CHAR = '<unk>'
INPUT_END_TOKEN = 1
TARGET_END_TOKEN = 2
OOV_TOKEN = 3

# MODEL_PARAMS
SENTENCE_LEN = 200
MAX_INPUT_LEN = 175
NUM_WORDS = 50000

# Read in files

In [2]:
import glob

In [3]:
FILES = glob.glob('%s/*%s' % (TRAINING_DIRECTORY, EXTENSION))
print(len(FILES))
FILES

92579


['cnn/stories/0001d1afc246a7964130f43ae940af6bc6c57f01.story',
 'cnn/stories/0002095e55fcbd3a2f366d9bf92a95433dc305ef.story',
 'cnn/stories/00027e965c8264c35cc1bc55556db388da82b07f.story',
 'cnn/stories/0002c17436637c4fe1837c935c04de47adb18e9a.story',
 'cnn/stories/0003ad6ef0c37534f80b55b4235108024b407f0b.story',
 'cnn/stories/0004306354494f090ee2d7bc5ddbf80b63e80de6.story',
 'cnn/stories/0005d61497d21ff37a17751829bd7e3b6e4a7c5c.story',
 'cnn/stories/0006021f772fad0aa78a977ce4a31b3faa6e6fe5.story',
 'cnn/stories/00083697263e215e5e7eda753070f08aa374dd45.story',
 'cnn/stories/000940f2bb357ac04a236a232156d8b9b18d1667.story',
 'cnn/stories/0009ebb1967511741629926ef9f5faea2bb6be24.story',
 'cnn/stories/000c835555db62e319854d9f8912061cdca1893e.story',
 'cnn/stories/000ca3fc9d877f8d4bb2ebd1d6858c69be571fd8.story',
 'cnn/stories/000cd1ee0098c4d510a03ddc97d11764448ebac2.story',
 'cnn/stories/000e009f6b1d954d827c9a550f3f24a5474ee82b.story',
 'cnn/stories/001097a19e2c96de11276b3cce11566ccfed0030.

In [4]:
FILES = FILES[:MAX_FILES]

# Define method for generating text from files

In [5]:
def preprocessor(text):
    table = {ord(c): None for c in '<>'}
    text = text.translate(table)
    return text

In [6]:
def text_generator(files, preprocessor=None):
    for f in files:
        text = open(f).read()
        if preprocessor is not None:
            text = preprocessor(text)
        # remove highlights
        body, highlight1, *_ = text.split('@highlight')
        yield body, highlight1

In [7]:
from keras.preprocessing.sequence import pad_sequences

def tokenize(input_text, target_text, tokenizer, max_input_len, target_begin_token, end_token):
    input_tokens = tokenizer([input_text])[0]
    target_tokens = tokenizer([target_text])[0]
    input_tokens = pad_sequences([input_tokens[:max_input_len]])[0].tolist()
    target_tokens = target_tokens
    return [input_tokens + [target_begin_token] + target_tokens + [end_token]]

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [8]:
next(text_generator(FILES, preprocessor=preprocessor))

('It\'s official: U.S. President Barack Obama wants lawmakers to weigh in on whether to use military force in Syria.\n\nObama sent a letter to the heads of the House and Senate on Saturday night, hours after announcing that he believes military action against Syrian targets is the right step to take over the alleged use of chemical weapons.\n\nThe proposed legislation from Obama asks Congress to approve the use of military force "to deter, disrupt, prevent and degrade the potential for future uses of chemical weapons or other weapons of mass destruction."\n\nIt\'s a step that is set to turn an international crisis into a fierce domestic political battle.\n\nThere are key questions looming over the debate: What did U.N. weapons inspectors find in Syria? What happens if Congress votes no? And how will the Syrian government react?\n\nIn a televised address from the White House Rose Garden earlier Saturday, the president said he would take his case to Congress, not because he has to -- but

# Initialize tokenizer

In [9]:
from keras.preprocessing.text import text_to_word_sequence, Tokenizer as _Tokenizer

class Tokenizer(_Tokenizer):
    def fit_on_texts(self, texts):
        """Updates internal vocabulary based on a list of texts.
        In the case where texts contains lists, we assume each entry of the lists
        to be a token.
        Required before using `texts_to_sequences` or `texts_to_matrix`.
        # Arguments
            texts: can be a list of strings,
                a generator of strings (for memory-efficiency),
                or a list of list of strings.
        """
        for text in texts:
            self.document_count += 1
            if self.char_level or isinstance(text, list):
                seq = text
            else:
                seq = text_to_word_sequence(text,
                                            self.filters,
                                            self.lower,
                                            self.split)
            for w in seq:
                if w in self.word_counts:
                    self.word_counts[w] += 1
                else:
                    self.word_counts[w] = 1
            for w in set(seq):
                if w in self.word_docs:
                    self.word_docs[w] += 1
                else:
                    self.word_docs[w] = 1

        wcounts = list(self.word_counts.items())
        wcounts.sort(key=lambda x: x[1], reverse=True)
        sorted_voc = [wc[0] for wc in wcounts]
        # note that index 0, 1, 2 is reserved, never assigned to an existing word
        self.word_index = dict(list(zip(sorted_voc, list(range(4, len(sorted_voc) + 4)))))
        self.word_index[self.oov_token] = 3

        for w, c in list(self.word_docs.items()):
            self.index_docs[self.word_index[w]] = c

    def texts_to_sequences_generator(self, texts):
        """Transforms each text in `texts` in a sequence of integers.
        Each item in texts can also be a list, in which case we assume each item of that list
        to be a token.
        Only top "num_words" most frequent words will be taken into account.
        Only words known by the tokenizer will be taken into account.
        # Arguments
            texts: A list of texts (strings).
        # Yields
            Yields individual sequences.
        """
        num_words = self.num_words
        for text in texts:
            if self.char_level or isinstance(text, list):
                seq = text
            else:
                seq = text_to_word_sequence(text,
                                            self.filters,
                                            self.lower,
                                            self.split)
            vect = []
            for w in seq:
                i = self.word_index.get(w)
                if i is not None and (self.num_words and i < self.num_words):
                    vect.append(i)
                elif self.oov_token is not None:
                    i = self.word_index.get(self.oov_token)
                    if i is not None:
                        vect.append(i)
            yield vect

In [10]:
TOKENIZER = Tokenizer(
    num_words=NUM_WORDS,
    filters=FILTERS,  # no newline
    oov_token=OOV_CHAR)

In [11]:
gen = text_generator(FILES, preprocessor=preprocessor)

In [12]:
%%time
TOKENIZER.fit_on_texts(text for train_pair in gen for text in train_pair)

CPU times: user 8.51 s, sys: 904 ms, total: 9.41 s
Wall time: 11.4 s


In [13]:
TOKENIZER.num_words

50000

In [14]:
TOKENIZER.document_count

20000

In [15]:
len(TOKENIZER.word_index), TOKENIZER.word_index

(100605,
 {'the': 4,
  'to': 5,
  'of': 6,
  'a': 7,
  'and': 8,
  'in': 9,
  'that': 10,
  'for': 11,
  'is': 12,
  'said': 13,
  'on': 14,
  'was': 15,
  'he': 16,
  'with': 17,
  'it': 18,
  'as': 19,
  'at': 20,
  'his': 21,
  'have': 22,
  'from': 23,
  'are': 24,
  'i': 25,
  'be': 26,
  'but': 27,
  'by': 28,
  'this': 29,
  'has': 30,
  'an': 31,
  'not': 32,
  'they': 33,
  'who': 34,
  'we': 35,
  'will': 36,
  'were': 37,
  'their': 38,
  'you': 39,
  'about': 40,
  'one': 41,
  'had': 42,
  'she': 43,
  'been': 44,
  'more': 45,
  'her': 46,
  'or': 47,
  'cnn': 48,
  'people': 49,
  'after': 50,
  'when': 51,
  'new': 52,
  'all': 53,
  'which': 54,
  'there': 55,
  'out': 56,
  'would': 57,
  'up': 58,
  'what': 59,
  'its': 60,
  'also': 61,
  "it's": 62,
  'year': 63,
  'two': 64,
  'time': 65,
  'than': 66,
  'if': 67,
  'so': 68,
  'can': 69,
  'u': 70,
  'no': 71,
  'some': 72,
  's': 73,
  'other': 74,
  'first': 75,
  'into': 76,
  'just': 77,
  'my': 78,
  'like':

In [16]:
index_to_word = {v: k for k, v in TOKENIZER.word_index.items()}
index_to_word[0] = '<pad>'
index_to_word[INPUT_END_TOKEN] = INPUT_END_CHAR
index_to_word[TARGET_END_TOKEN] = TARGET_END_CHAR

In [17]:
sorted(index_to_word.items(), key=lambda x: x[0])

[(0, '<pad>'),
 (1, '<input-end>'),
 (2, '<target-end>'),
 (3, '<unk>'),
 (4, 'the'),
 (5, 'to'),
 (6, 'of'),
 (7, 'a'),
 (8, 'and'),
 (9, 'in'),
 (10, 'that'),
 (11, 'for'),
 (12, 'is'),
 (13, 'said'),
 (14, 'on'),
 (15, 'was'),
 (16, 'he'),
 (17, 'with'),
 (18, 'it'),
 (19, 'as'),
 (20, 'at'),
 (21, 'his'),
 (22, 'have'),
 (23, 'from'),
 (24, 'are'),
 (25, 'i'),
 (26, 'be'),
 (27, 'but'),
 (28, 'by'),
 (29, 'this'),
 (30, 'has'),
 (31, 'an'),
 (32, 'not'),
 (33, 'they'),
 (34, 'who'),
 (35, 'we'),
 (36, 'will'),
 (37, 'were'),
 (38, 'their'),
 (39, 'you'),
 (40, 'about'),
 (41, 'one'),
 (42, 'had'),
 (43, 'she'),
 (44, 'been'),
 (45, 'more'),
 (46, 'her'),
 (47, 'or'),
 (48, 'cnn'),
 (49, 'people'),
 (50, 'after'),
 (51, 'when'),
 (52, 'new'),
 (53, 'all'),
 (54, 'which'),
 (55, 'there'),
 (56, 'out'),
 (57, 'would'),
 (58, 'up'),
 (59, 'what'),
 (60, 'its'),
 (61, 'also'),
 (62, "it's"),
 (63, 'year'),
 (64, 'two'),
 (65, 'time'),
 (66, 'than'),
 (67, 'if'),
 (68, 'so'),
 (69, 'can'

In [18]:
TOKENIZER.num_words = min(len(TOKENIZER.word_index)+1, TOKENIZER.num_words)

In [19]:
gen = text_generator(FILES)
x, y = next(gen)

In [20]:
len(x)

9442

In [21]:
x

'It\'s official: U.S. President Barack Obama wants lawmakers to weigh in on whether to use military force in Syria.\n\nObama sent a letter to the heads of the House and Senate on Saturday night, hours after announcing that he believes military action against Syrian targets is the right step to take over the alleged use of chemical weapons.\n\nThe proposed legislation from Obama asks Congress to approve the use of military force "to deter, disrupt, prevent and degrade the potential for future uses of chemical weapons or other weapons of mass destruction."\n\nIt\'s a step that is set to turn an international crisis into a fierce domestic political battle.\n\nThere are key questions looming over the debate: What did U.N. weapons inspectors find in Syria? What happens if Congress votes no? And how will the Syrian government react?\n\nIn a televised address from the White House Rose Garden earlier Saturday, the president said he would take his case to Congress, not because he has to -- but 

In [22]:
seq = tokenize(
    x,
    y,
    TOKENIZER.texts_to_sequences,
    max_input_len=MAX_INPUT_LEN,
    target_begin_token=INPUT_END_TOKEN,
    end_token=TARGET_END_TOKEN)
seq

[[62,
  241,
  70,
  73,
  90,
  824,
  132,
  821,
  1879,
  5,
  6387,
  9,
  14,
  254,
  5,
  220,
  172,
  445,
  9,
  444,
  132,
  660,
  7,
  1095,
  5,
  4,
  2267,
  6,
  4,
  158,
  8,
  525,
  14,
  331,
  271,
  375,
  50,
  4337,
  10,
  16,
  1142,
  172,
  510,
  112,
  574,
  2258,
  12,
  4,
  188,
  808,
  5,
  146,
  82,
  4,
  1020,
  220,
  6,
  1846,
  662,
  4,
  1969,
  1591,
  23,
  132,
  3857,
  501,
  5,
  4756,
  4,
  220,
  6,
  172,
  445,
  5,
  9294,
  8025,
  1205,
  8,
  15356,
  4,
  802,
  11,
  437,
  2206,
  6,
  1846,
  662,
  47,
  74,
  662,
  6,
  1411,
  3156,
  62,
  7,
  808,
  10,
  12,
  243,
  5,
  809,
  31,
  190,
  739,
  76,
  7,
  4631,
  1315,
  223,
  1038,
  55,
  24,
  654,
  689,
  8731,
  82,
  4,
  934,
  59,
  130,
  70,
  652,
  662,
  4727,
  299,
  9,
  444,
  59,
  1898,
  67,
  501,
  1779,
  71,
  8,
  96,
  36,
  4,
  574,
  95,
  5285,
  9,
  7,
  4922,
  940,
  23,
  4,
  245,
  158,
  2430,
  3886,
  382,
  331,
 

In [23]:
[[index_to_word[i] for i in L] for L in seq]

[["it's",
  'official',
  'u',
  's',
  'president',
  'barack',
  'obama',
  'wants',
  'lawmakers',
  'to',
  'weigh',
  'in',
  'on',
  'whether',
  'to',
  'use',
  'military',
  'force',
  'in',
  'syria',
  'obama',
  'sent',
  'a',
  'letter',
  'to',
  'the',
  'heads',
  'of',
  'the',
  'house',
  'and',
  'senate',
  'on',
  'saturday',
  'night',
  'hours',
  'after',
  'announcing',
  'that',
  'he',
  'believes',
  'military',
  'action',
  'against',
  'syrian',
  'targets',
  'is',
  'the',
  'right',
  'step',
  'to',
  'take',
  'over',
  'the',
  'alleged',
  'use',
  'of',
  'chemical',
  'weapons',
  'the',
  'proposed',
  'legislation',
  'from',
  'obama',
  'asks',
  'congress',
  'to',
  'approve',
  'the',
  'use',
  'of',
  'military',
  'force',
  'to',
  'deter',
  'disrupt',
  'prevent',
  'and',
  'degrade',
  'the',
  'potential',
  'for',
  'future',
  'uses',
  'of',
  'chemical',
  'weapons',
  'or',
  'other',
  'weapons',
  'of',
  'mass',
  'destru

In [24]:
len(seq), len(seq[0])

(1, 193)

In [25]:
s = seq[0]
s

[62,
 241,
 70,
 73,
 90,
 824,
 132,
 821,
 1879,
 5,
 6387,
 9,
 14,
 254,
 5,
 220,
 172,
 445,
 9,
 444,
 132,
 660,
 7,
 1095,
 5,
 4,
 2267,
 6,
 4,
 158,
 8,
 525,
 14,
 331,
 271,
 375,
 50,
 4337,
 10,
 16,
 1142,
 172,
 510,
 112,
 574,
 2258,
 12,
 4,
 188,
 808,
 5,
 146,
 82,
 4,
 1020,
 220,
 6,
 1846,
 662,
 4,
 1969,
 1591,
 23,
 132,
 3857,
 501,
 5,
 4756,
 4,
 220,
 6,
 172,
 445,
 5,
 9294,
 8025,
 1205,
 8,
 15356,
 4,
 802,
 11,
 437,
 2206,
 6,
 1846,
 662,
 47,
 74,
 662,
 6,
 1411,
 3156,
 62,
 7,
 808,
 10,
 12,
 243,
 5,
 809,
 31,
 190,
 739,
 76,
 7,
 4631,
 1315,
 223,
 1038,
 55,
 24,
 654,
 689,
 8731,
 82,
 4,
 934,
 59,
 130,
 70,
 652,
 662,
 4727,
 299,
 9,
 444,
 59,
 1898,
 67,
 501,
 1779,
 71,
 8,
 96,
 36,
 4,
 574,
 95,
 5285,
 9,
 7,
 4922,
 940,
 23,
 4,
 245,
 158,
 2430,
 3886,
 382,
 331,
 4,
 90,
 13,
 16,
 57,
 146,
 21,
 191,
 5,
 501,
 32,
 92,
 16,
 30,
 5,
 27,
 92,
 16,
 821,
 5,
 106,
 25,
 365,
 1,
 574,
 241,
 132,
 6543,
 5,
 4,

In [26]:
one_hot = TOKENIZER.sequences_to_matrix([[i] for i in s])

In [27]:
one_hot

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.]])

In [28]:
# only one per row
import numpy as np
np.argwhere(one_hot == 1)[:50]

array([[   0,   62],
       [   1,  241],
       [   2,   70],
       [   3,   73],
       [   4,   90],
       [   5,  824],
       [   6,  132],
       [   7,  821],
       [   8, 1879],
       [   9,    5],
       [  10, 6387],
       [  11,    9],
       [  12,   14],
       [  13,  254],
       [  14,    5],
       [  15,  220],
       [  16,  172],
       [  17,  445],
       [  18,    9],
       [  19,  444],
       [  20,  132],
       [  21,  660],
       [  22,    7],
       [  23, 1095],
       [  24,    5],
       [  25,    4],
       [  26, 2267],
       [  27,    6],
       [  28,    4],
       [  29,  158],
       [  30,    8],
       [  31,  525],
       [  32,   14],
       [  33,  331],
       [  34,  271],
       [  35,  375],
       [  36,   50],
       [  37, 4337],
       [  38,   10],
       [  39,   16],
       [  40, 1142],
       [  41,  172],
       [  42,  510],
       [  43,  112],
       [  44,  574],
       [  45, 2258],
       [  46,   12],
       [  47,

# Define batch generator

In [29]:
def sequencer(tokens, L):
    return [tokens[i:L+i] for i in range(0, len(tokens)-L+1)]

In [30]:
list(sequencer('a quick brown fox', 10))

['a quick br',
 ' quick bro',
 'quick brow',
 'uick brown',
 'ick brown ',
 'ck brown f',
 'k brown fo',
 ' brown fox']

In [31]:
import random
import numpy as np

class BatchGenerator:
    def __init__(self, files, tokenizer, one_hot_encoder, num_words, max_input_len, sentence_len, batch_size,
                 input_end_token, target_end_token, epoch_end=None):
        self.files = files
        self.tokenizer = tokenizer
        self.one_hot_encoder = one_hot_encoder
        self.num_words = num_words
        self.max_input_len = max_input_len
        self.sentence_len = sentence_len
        self.batch_size = batch_size
        self.input_end_token = input_end_token
        self.target_end_token = target_end_token
        self.epoch_end = epoch_end
        
    def generate_forever(self):
        while True:
            random.shuffle(self.files)
            yield from self.generate_epoch()
            
    def generate_epoch(self):
        steps = []
        for file in self.files:
            training_example = self.process_file(file)
            for target_seq in self.sequence(training_example.target_tokens):
                full_example = training_example.input_tokens + target_seq
                full_example = pad_sequences([full_example], maxlen=self.sentence_len)[0]
                if len(full_example) > self.sentence_len:
                    full_example = full_example[-self.sentence_len:]
                steps.append(full_example)
            
            n_batches = len(steps) // self.batch_size
            n_steps = n_batches * self.batch_size
            batches = self.generate_batches(steps[:n_steps], self.batch_size, n_batches)
            steps = steps[n_steps:]
            yield from batches
    
    def generate_batches(self, steps, batch_size, n_batches):
        for i in range(n_batches):
            start, stop = i*batch_size, (i+1)*batch_size
            X = np.array(steps[start:stop])
            y = self.one_hot_encoder([[i] for s in X for i in s])
            y = y.reshape((self.batch_size, self.sentence_len, self.num_words))
            yield X[:-1], y[1:]

    def process_file(self, file):
        input_text, target_text = self.split_file(file)
        input_tokens, target_tokens = self.tokenizer([input_text])[0], self.tokenizer([target_text])[0]
        # add special tokens
        input_tokens = input_tokens[:self.max_input_len] + [self.input_end_token]
        target_tokens = target_tokens + [self.target_end_token]
        training_example = TrainingExample(input_text, target_text,
                                           input_tokens, target_tokens)
        return training_example

    def split_file(self, file):
        with open(file) as f:
            text = f.read()
        text = self.preprocess(text)
        body, highlight1, *_ = text.split('@highlight')
        return body, highlight1
        
    def preprocess(self, text):
        text = ' \n '.join(t for t in text.split('\n') if t)
        table = {ord(c): None for c in '<>'}
        text = text.translate(table)
        return text

    def sequence(self, tokens):
        return [tokens[:i] for i in range(1, len(tokens)+1)]
    

class TrainingExample:
    def __init__(self, input_text, target_text, input_tokens, target_tokens):
        self.input_text = input_text
        self.target_text = target_text
        self.input_tokens = input_tokens
        self.target_tokens = target_tokens


In [32]:
batch_gen = BatchGenerator(
    files=FILES,
    tokenizer=TOKENIZER.texts_to_sequences,
    one_hot_encoder=TOKENIZER.sequences_to_matrix,
    num_words=TOKENIZER.num_words,
    max_input_len=MAX_INPUT_LEN,
    sentence_len=SENTENCE_LEN,
    batch_size=32,
    input_end_token=INPUT_END_TOKEN,
    target_end_token=TARGET_END_TOKEN
).generate_forever()

In [33]:
X, y = next(batch_gen)

In [34]:
X

array([[   0,    0,    0, ...,  128,    1,  475],
       [   0,    0,    0, ...,    1,  475, 2236],
       [   0,    0,    0, ...,  475, 2236,   15],
       ...,
       [   0,    0,    0, ...,    9, 1062, 2083],
       [   0,    0,    0, ..., 1062, 2083,   12],
       [   0,    0,    0, ..., 2083,   12,  131]], dtype=int32)

In [35]:
X.shape, y.shape

((31, 200), (31, 200, 50000))

In [36]:
X

array([[   0,    0,    0, ...,  128,    1,  475],
       [   0,    0,    0, ...,    1,  475, 2236],
       [   0,    0,    0, ...,  475, 2236,   15],
       ...,
       [   0,    0,    0, ...,    9, 1062, 2083],
       [   0,    0,    0, ..., 1062, 2083,   12],
       [   0,    0,    0, ..., 2083,   12,  131]], dtype=int32)

In [37]:
y

array([[[1., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 1., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[1., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[1., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[1., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0.

In [38]:
print('\n\n'.join([' '.join([index_to_word.get(i, '<pad>') for i in x]) for x in X]))

<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> kabul afghanistan cnn a british soldier who was reported missing from a checkpoint in afghanistan early monday was found dead later in the day the british ministry of defense announced it is with great sadness that i announce the death of a soldier from the <unk> 4th battalion the royal regiment of scotland said lt col tim <unk> spokesman for task force helmand in a statement after an extensive search fellow members of nato's international security assistance force found his body in the nahr e saraj district of helmand province <unk> said the soldier had suffered gunshot wo

In [39]:
# only one per row
import numpy as np
ys = np.argwhere(y[0] == 1)

In [40]:
import numpy as np
for j in range(0, len(y), 5):
    ys = np.argwhere(y[j] == 1)
    assert len(ys) == len({row for row, idx in ys})
    print(' '.join(index_to_word[idx] for row, idx in ys))
    print('\n')

<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> kabul afghanistan cnn a british soldier who was reported missing from a checkpoint in afghanistan early monday was found dead later in the day the british ministry of defense announced it is with great sadness that i announce the death of a soldier from the <unk> 4th battalion the royal regiment of scotland said lt col tim <unk> spokesman for task force helmand in a statement after an extensive search fellow members of nato's international security assistance force found his body in the nahr e saraj district of helmand province <unk> said the soldier had suffered gunshot wounds h

# Training

In [41]:
N_HEADS = 8
N_LAYERS = 6
D_MODEL = 64*N_HEADS
VOCAB_SIZE = TOKENIZER.num_words
WARMUP_STEPS = 200
BATCH_SIZE = 32

In [42]:
batch_gen = BatchGenerator(
    files=FILES,
    tokenizer=TOKENIZER.texts_to_sequences,
    one_hot_encoder=TOKENIZER.sequences_to_matrix,
    num_words=TOKENIZER.num_words,
    max_input_len=MAX_INPUT_LEN,
    sentence_len=SENTENCE_LEN,
    batch_size=BATCH_SIZE,
    input_end_token=INPUT_END_TOKEN,
    target_end_token=TARGET_END_TOKEN
)

In [None]:
%%time
# loop over batch generator until we hit the end of the epoch
# to calculate number of batches in epoch and compute some
# stats along the way
steps_per_epoch = 0
for batch in batch_gen.generate_epoch():
    steps_per_epoch += 1

In [None]:
print('steps per epoch', steps_per_epoch)

In [None]:
train_gen = batch_gen.generate_forever()

In [None]:
from keras.callbacks import TerminateOnNaN
callbacks = [TerminateOnNaN()]

In [None]:
from model_decoder import TransformerDecoder
model = TransformerDecoder(
        n_heads=N_HEADS, decoder_layers=N_LAYERS,
        d_model=D_MODEL, vocab_size=VOCAB_SIZE, sequence_len=SENTENCE_LEN,
        layer_normalization=True, dropout=True,
        residual_connections=True)

In [None]:
model.summary()

In [None]:
# import keras.backend as K
# def loss(y_true, y_pred):
#    return K.categorical_crossentropy(y_true[:,-1:,:], y_pred[:,-1:,:])

In [None]:
loss = 'categorical_crossentropy'

In [None]:
class LRScheduler:
    def __init__(self, d_model, warmup_steps):
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.epoch = 1

    def lr(self, epoch):
        lr = self.d_model**-.5 * min(self.epoch**-.5, epoch*(self.warmup_steps**-1.5))
        self.epoch += 1
        return lr
lr_scheduler = LRScheduler(D_MODEL, WARMUP_STEPS)

In [None]:
from keras.callbacks import LearningRateScheduler
# callbacks.append(LearningRateScheduler(lr_scheduler.lr))

In [None]:
from keras.optimizers import adam
model.compile(loss=loss, optimizer=adam(lr=1e-4))

In [None]:
# from keras import backend as K
# old_lr = K.get_value(model.optimizer.lr)
# K.set_value(model.optimizer.lr, 1e-4)

In [None]:
n_epochs = 1000
model.fit_generator(
    train_gen, steps_per_epoch=steps_per_epoch,
    epochs=n_epochs, callbacks=callbacks)

In [None]:
X, y = next(batch_gen)

In [None]:
y

In [None]:
def show_X(X):
    print('X:', ' '.join(index_to_word[i] for i in X))
    
def show_y(y):
    ones = np.argwhere(y == 1)
    print('y:', ' '.join(index_to_word[idx] for row, idx in ones))

def show_results(model, X, y):
    show_X(X[0])
    show_y(y)
    y_hat = model.predict(X)
    show_y(y_hat[0])