In [19]:
%matplotlib inline
import numpy as np
import pandas as pd

from keras.models import Graph, Sequential
from keras.layers.recurrent import LSTM
from keras.layers.core import AutoEncoder, Dense, Activation, TimeDistributedDense, Flatten, Dropout, RepeatVector
from keras.layers.recurrent import LSTM
from keras.layers.embeddings import Embedding
from sklearn.cross_validation import StratifiedKFold
from sklearn.metrics import precision_recall_curve, precision_recall_fscore_support, f1_score
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from keras.utils.visualize_util import plot

from helper import *
from lstm_networks import *

In [2]:
# Parameters
# Hyper-parameters shared by the model-building cells below.
vocabulary_size = 196   # Embedding input dimension and width of the reconstruction softmax
embedding_size = 10     # length of the vector the Embedding produces for each token
encoding_size = 30      # units in the encoder LSTM (shared by both output heads)
decoding_size = 10      # units in the decoder LSTM of the reconstruction head
sequence_length = 90    # integer token ids per input sequence
n_epochs = 5            # not used in the cells shown; presumably the epoch count for a later fit — confirm

# Seed numpy's global RNG up front so repeated Restart & Run All sessions are
# reproducible. NOTE(review): assumes this Keras version draws initial weights /
# dropout seeds from numpy's global RNG — confirm for the backend in use.
random_seed = 1337
np.random.seed(random_seed)

In [65]:
# Two-headed network: a shared Embedding -> LSTM encoder feeds
#   (1) a sequence-reconstruction head, exposed as output 'reconstruction', and
#   (2) a binary classification head, exposed as output 'classification'.
model = Graph()

# Each sample is `sequence_length` integer token ids (the Embedding below is declared
# with input_length=sequence_length), so the input shape is (sequence_length,).
# 'int' is the documented add_input dtype for integer inputs feeding an Embedding.
model.add_input(name='input', input_shape=(sequence_length,), dtype='int')

# Encoding input into a representation
model.add_node(Embedding(vocabulary_size, embedding_size, input_length=sequence_length),
               name='embedding',
               input='input')
# return_sequences=True keeps one encoding_size vector per timestep; both heads consume it.
# NOTE(review): input_length here is likely redundant (shape comes from 'embedding') — confirm.
model.add_node(LSTM(encoding_size, input_length=sequence_length, return_sequences=True),
               name='encoder',
               input='encoder' if False else 'embedding')

# Branch 1: reconstruct — decode each timestep back to a softmax over the vocabulary.
model.add_node(LSTM(decoding_size, return_sequences=True),
               name='decoder',
               input='encoder')
model.add_node(TimeDistributedDense(vocabulary_size, activation='softmax'),
               name='distributed',
               input='decoder')
model.add_output(name='reconstruction', input='distributed')

# Branch 2: classify — flatten all per-timestep encodings into one vector,
# regularise with dropout, and map to a single sigmoid probability.
model.add_node(Flatten(), name='flatten', input='encoder')
model.add_node(Dropout(0.5), name='dropout', input='flatten')
model.add_node(Dense(1, activation='sigmoid'),
               name='dense',
               input='dropout')
model.add_output(name='classification', input='dense')

In [66]:
# The reconstruction head ends in a per-timestep softmax over the vocabulary, so it is
# scored with categorical cross-entropy (the standard pairing for softmax outputs; it
# takes the same one-hot targets of shape (samples, sequence_length, vocabulary_size)
# that mse required). The classification head is a single sigmoid unit -> binary
# cross-entropy.
model.compile(optimizer='rmsprop',
              loss={'reconstruction': 'categorical_crossentropy',
                    'classification': 'binary_crossentropy'})

In [67]:
# Write a diagram of the Graph's nodes and connections to a PNG file.
model_diagram_path = 'model.png'
plot(model, to_file=model_diagram_path)