In [None]:
import math
import os
import sys; sys.path.append('../lib')
from functools import partial
from time import time

import matplotlib.pyplot as plt
import numpy as np

from data import Text, TrumpTweetArchive
from history import TrainHistoryRecurrent
from recurrent_network import RecurrentNetwork

# Constants

In [None]:
# Filesystem locations, relative to this notebook's directory.
DATA_DIR = '../data'        # tweet archive JSON files (read by TrumpTweetArchive)
PICKLE_DIR = '../pickle'    # where the training history is saved / reloaded
FIGURE_DIR = '../figures'   # where exported plots are written

# Training hyperparameters, consumed by the "Train network" cell.
HYPERPARAMS = dict(
    hidden_state_size=100,  # passed to RecurrentNetwork as `hidden_state_size`
    sequence_length=15,     # passed to network.train as `seq_length`
    eta=0.1,                # passed to network.train as `eta` (presumably the learning rate)
    epochs=5,               # number of shuffled passes over all tweets
)

# Load data

In [None]:
# Load every archive file under DATA_DIR matching 'condensed_201*.json'
# (presumably a glob pattern — see data.py).
tweet_archive = TrumpTweetArchive(DATA_DIR, 'condensed_201*.json')

# Encodings of the start/stop marker characters. They are later passed to
# network.synthesize as `init_one_hot` / `stop_character_one_hot`.
# NOTE(review): exact layout of the 'index_one_hot' representation is defined
# in data.py — confirm there.
start_one_hot, stop_one_hot = (
    tweet_archive.get_start_character(rep='index_one_hot'),
    tweet_archive.get_stop_character(rep='index_one_hot'),
)

In [None]:
# Display how many tweets were loaded; each training epoch below iterates over
# this many tweets.
tweet_archive.num_tweets

In [None]:
# Sanity-check the loaded data by previewing part of the archive (presumably a
# randomly chosen tweet). NOTE(review): whether this prints or returns the
# preview is defined in data.py.
tweet_archive.random_preview()

# Train network

In [None]:
# Train a character-level RNN tweet by tweet. Each epoch visits every tweet
# once in shuffled order; after each epoch a few tweets are synthesized so
# progress can be judged by eye. The resulting history is pickled to PICKLE_DIR.
#
# NOTE(review): `network.train` is called once per tweet and only the history
# returned by the *last* call is kept and saved at the end. Whether that object
# also covers the earlier calls (continue_training=True suggests it might) is
# defined in recurrent_network.py — confirm before relying on the saved curves.

PROGRESS_BAR_WIDTH = 50    # characters in the per-epoch progress bar
EPOCH_SAMPLE_COUNT = 5     # tweets synthesized after every epoch
EPOCH_SAMPLE_LENGTH = 140  # passed as `length` to network.synthesize


def _progress_line(epoch, done, total, width=PROGRESS_BAR_WIDTH):
    """Return the carriage-return-prefixed progress line after `done` of `total` tweets."""
    bar = '=' * (width * done // total)
    return "\repoch {}: [{:{width}s}]".format(epoch, bar, width=width)


def _synthesize_samples(net, count=EPOCH_SAMPLE_COUNT, length=EPOCH_SAMPLE_LENGTH):
    """Yield `count` tweets synthesized by `net`, decoded back to text."""
    for _ in range(count):
        encoded = net.synthesize(
            length=length,
            init_one_hot=start_one_hot,
            stop_character_one_hot=stop_one_hot)
        yield tweet_archive.get_characters(encoded, one_hot=True)


num_tweets = tweet_archive.num_tweets
if num_tweets == 0 or HYPERPARAMS['epochs'] < 1:
    # Without at least one train() call there is no history to save below.
    raise ValueError(
        "need at least one tweet and one epoch to train "
        "(num_tweets={}, epochs={})".format(num_tweets, HYPERPARAMS['epochs']))

network = RecurrentNetwork(
    input_size=tweet_archive.num_characters,
    hidden_state_size=HYPERPARAMS['hidden_state_size'],
    random_seed=0)

# Dedicated, seeded RNG for the shuffle. Drawing the order from numpy's global
# RNG (which nothing visible here seeds) made the tweet order, and therefore
# the trained weights, irreproducible across Restart & Run All.
shuffle_rng = np.random.RandomState(0)

for epoch in range(1, HYPERPARAMS['epochs'] + 1):
    last_msg = None

    for i, j in enumerate(shuffle_rng.permutation(num_tweets)):
        msg = _progress_line(epoch, i + 1, num_tweets)

        # Redraw only when the bar actually changes (at most width + 1 times
        # per epoch) instead of once per tweet; the final line ends the epoch.
        if i == num_tweets - 1:
            print(msg, end='\n\n', flush=True)
        elif msg != last_msg:
            print(msg, end='', flush=True)
        last_msg = msg

        history = network.train(
            tweet_archive.tweets[j],
            seq_length=HYPERPARAMS['sequence_length'],
            eta=HYPERPARAMS['eta'],
            n_updates=math.inf,
            n_epochs=1,
            skip_incomplete=False,
            continue_training=True,
            find_best_params=False)

    print("synthesized tweets:\n")

    for sample_tweet in _synthesize_samples(network):
        print(sample_tweet + '\n')

# The pickle directory is not guaranteed to exist on a fresh checkout.
os.makedirs(PICKLE_DIR, exist_ok=True)
history.save(PICKLE_DIR, postfix='rnn_trump')

In [None]:
# Reload the training history saved by the training cell (same PICKLE_DIR and
# postfix) so the cells below can run without retraining.
# NOTE(review): this presumably unpickles a file from disk — only load files
# you produced yourself; unpickling untrusted data can execute arbitrary code.
history = TrainHistoryRecurrent.load(PICKLE_DIR, postfix='rnn_trump')

In [None]:
# Plot the training history and export the current matplotlib figure as SVG.
# NOTE(review): if `visualize()` calls plt.show() itself, the current figure
# may already be closed and the saved SVG would be blank — confirm in history.py.
history.visualize()

# savefig does not create missing parent directories; ensure the target exists.
os.makedirs(FIGURE_DIR, exist_ok=True)
plt.savefig(os.path.join(FIGURE_DIR, 'rnn_tweets.svg'))

In [None]:
# Sample tweets from the final trained network. Seeding numpy's global RNG
# makes the samples repeatable, assuming synthesize draws from it.
np.random.seed(0)

NUM_FINAL_SAMPLES = 10
FINAL_SAMPLE_LENGTH = 140  # passed as `length` to network.synthesize

# Fetch the network once; it was previously re-read on every iteration.
network = history.final_network

for _ in range(NUM_FINAL_SAMPLES):
    encoded_sequence = network.synthesize(
        length=FINAL_SAMPLE_LENGTH,
        init_one_hot=start_one_hot,
        stop_character_one_hot=stop_one_hot)

    # Decode to text under a separate name instead of reusing `sequence`.
    tweet_text = tweet_archive.get_characters(encoded_sequence, one_hot=True)

    print(tweet_text + '\n')