In [None]:
import pickle

from keras.preprocessing.sequence import pad_sequences
import matplotlib.pyplot as plt
import numpy as np
from numpy.random import randint
from sklearn.metrics import confusion_matrix

import seq2seq_base

%matplotlib inline

# Set random seed for reproducibility: this seeds NumPy's global RNG, which is the
# generator behind the `numpy.random.randint` imported above.
# NOTE(review): this does not seed the Keras/backend RNG, so weight initialisation
# presumably stays non-deterministic — confirm if full reproducibility is required.
np.random.seed(42)

# Data preparation

In [None]:
# Window widths: encoder context length and decoder snippet length.
MAX_LEN = 100
SNIPPET_LEN = 40

# The pickle holds a pair that is unpacked into `feat` and `token2ind`.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
# NOTE(review): the absolute path is machine-specific.
save_features = "/storage/egor/tmp/feat_token2ind.pkl"
with open(save_features, "rb") as f:
    feat, token2ind = pickle.load(f)

# Register the decoder start symbol under index len(token2ind); if the key is
# already present its existing index is reused.
START = token2ind.setdefault("starting_symbol_for_decoding_seq", len(token2ind))

# Shifted copy of the vocabulary: every index is incremented by 1.
token2ind_new = {token: index + 1 for token, index in token2ind.items()}

In [None]:
def prepare_samples(raw_feat, enc_len=MAX_LEN, dec_len=SNIPPET_LEN, max_n_samples=10 ** 5,
                    min_seq_len=50):
    """Draw random (encoder, decoder-input, decoder-target) windows from token sequences.

    A sequence is picked at random, then a split position inside it. Up to `enc_len`
    tokens before the split go to the encoder, and up to `dec_len` tokens from the
    split onwards (preceded by START) go to the decoder. Every index is shifted by +1
    so that 0 can serve as padding.

    Args:
        raw_feat: indexable collection of token-index sequences. The
            `[START] + seq[...]` concatenation assumes each sequence is a plain
            list — TODO confirm.
        enc_len: width of the returned encoder array.
        dec_len: width the decoder array is padded/truncated to before it is split
            into input and target.
        max_n_samples: number of windows to draw.
        min_seq_len: sequences shorter than this are rejected and re-drawn.

    Returns:
        Tuple `(enc, dec_in, dec_target)`: `enc` is `enc_len` wide; `dec_in` and
        `dec_target` are both `dec_len - 1` wide and offset by one step.

    Raises:
        ValueError: if samples are requested but no sequence in `raw_feat` has at
            least `min_seq_len` tokens.
    """
    # Without this check the rejection loop below would never terminate.
    if max_n_samples > 0 and not any(len(seq) >= min_seq_len for seq in raw_feat):
        raise ValueError(
            "no sequence in raw_feat has at least %d tokens" % min_seq_len)

    enc = []
    dec = []

    while len(enc) < max_n_samples:
        # select random sequence of tokens
        seq = raw_feat[randint(len(raw_feat))]
        if len(seq) < min_seq_len:
            continue

        # random split point; everything before it is encoder context
        split = randint(len(seq))
        # increase all values by one to make 0 padding
        enc.append([tok + 1 for tok in seq[max(0, split - enc_len):split]])
        # prepend start symbol (slicing already clamps at the end of the sequence)
        dec.append([tok + 1 for tok in [START] + seq[split:split + dec_len]])

    # pad_sequences defaults to pre-padding and pre-truncation. A full window is
    # dec_len + 1 long (START plus dec_len tokens), so truncation drops START there.
    dec = pad_sequences(dec, maxlen=dec_len)
    dec_in = dec[:, :-1]
    dec_target = dec[:, 1:]

    enc = pad_sequences(enc, maxlen=enc_len, padding="post")
    # The first decoder input is overwritten with the last encoder column.
    # NOTE(review): with post-padding that column is 0 whenever the context is shorter
    # than enc_len — confirm this is intended.
    dec_in[:, 0] = enc[:, -1]

    # `enc` is already exactly enc_len wide, so the former second pad_sequences call
    # was a no-op and has been removed; the returned encoder array stays post-padded.
    return enc, dec_in, dec_target

In [None]:
# Hold out the last 2000 sequences for validation; everything before them is training data.
val = feat[-2000:]
train = feat[:-2000]

# Sample windows: the default sample count for training, 20000 for validation.
train_enc, train_dec_in, train_dec_target = prepare_samples(raw_feat=train)
val_enc, val_dec_in, val_dec_target = prepare_samples(raw_feat=val, max_n_samples=20000)

# Prepare model

In [None]:
# Settings handed to the seq2seq wrapper.
enc_token2ind = token2ind_new  # vocabulary with every index shifted by +1
enc_latent_dim = 256
optimizer = "rmsprop"
encoder_seq_len = MAX_LEN
# NOTE(review): prepare_samples yields decoder arrays that are SNIPPET_LEN - 1 wide;
# confirm Seq2SeqBase expects `decoder_seq_len` to be SNIPPET_LEN rather than that width.
decoder_seq_len = SNIPPET_LEN

s2s = seq2seq_base.Seq2SeqBase(
    enc_token2ind=enc_token2ind,
    enc_latent_dim=enc_latent_dim,
    optimizer=optimizer,
    encoder_seq_len=encoder_seq_len,
    decoder_seq_len=decoder_seq_len,
)

# Train model

In [None]:
batch_size = 256
epochs = 1
# Three rounds of `epochs` epoch(s) each; every round draws a fresh random set of
# training and validation windows, so the arrays left over afterwards are those of
# the last round.
for _ in range(3):
    train_enc, train_dec_in, train_dec_target = prepare_samples(train)
    val_enc, val_dec_in, val_dec_target = prepare_samples(val, max_n_samples=20000)
    # Targets get a trailing axis of size 1.
    # NOTE(review): presumably required by the loss configured in seq2seq_base — confirm.
    s2s.train_model.fit(
        [train_enc, train_dec_in], np.expand_dims(train_dec_target, axis=-1),
        batch_size=batch_size, shuffle=True, epochs=epochs,
        # Keras documents validation_data as a tuple (x_val, y_val). A list is not
        # guaranteed to be unpacked the same way in every Keras version.
        validation_data=([val_enc, val_dec_in], np.expand_dims(val_dec_target, axis=-1)))

# Let's build a confusion matrix

In [None]:
def plot_confusion_matrix(cm, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues,
                          classes=None):
    """
    Plot a confusion matrix as a heat map on the current matplotlib figure.

    Args:
        cm: square 2-D array; rows are drawn along the y axis ("True label") and
            columns along the x axis ("Predicted label").
        normalize: if True, divide every row by its sum before plotting.
        title: figure title.
        cmap: matplotlib colormap used for the heat map.
        classes: optional sequence of tick labels, one per row/column of `cm`.
    """
    if normalize:
        row_sums = cm.sum(axis=1)[:, np.newaxis]
        # Rows without any sample stay all-zero instead of turning into NaN.
        cm = cm.astype('float') / np.where(row_sums == 0, 1, row_sums)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    if classes is not None:
        tick_marks = np.arange(len(classes))
        plt.xticks(tick_marks, classes, rotation=90)
        plt.yticks(tick_marks, classes)

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Run the layout pass last so the axis and tick labels are taken into account.
    plt.tight_layout()


val_pred = s2s.train_model.predict([val_enc, val_dec_in])
val_pred_ind = val_pred.argmax(axis=-1)

# Flatten ground truth and predictions over all decoder steps. Both arrays must have
# the same shape so that the flattened positions line up.
assert val_pred_ind.shape == val_dec_target.shape, (val_pred_ind.shape, val_dec_target.shape)
gt = val_dec_target.ravel()
pred = val_pred_ind.ravel()

# Compute confusion matrix. scikit-learn expects (y_true, y_pred): rows are then true
# labels and columns predictions, matching the axis labels used by the plot.
labels = np.union1d(gt, pred)
cnf_matrix = confusion_matrix(gt, pred, labels=labels)
np.set_printoptions(precision=2)

# One name per matrix row/column; indices without a token (e.g. padding) map to "".
class_names = [s2s.ind2token.get(int(label), "") for label in labels]

# Plot normalized confusion matrix
plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                      title='Normalized confusion matrix')

In [None]:
# Inference step for several samples
for seq_index in range(10):
    # Take one validation sequence, kept as a batch of size 1, for trying out decoding.
    input_seq = val_enc[seq_index: seq_index + 1]
    input_sent = " ".join(s2s.ind2token.get(ind, "")
                          for ind in input_seq[0]).strip()
    expected_output = " ".join(s2s.ind2token.get(ind, "")
                               for ind in val_dec_target[seq_index]).strip()
    # The second argument is the token found at the last encoder position, mirroring
    # how prepare_samples fills the first decoder input.
    decoded_sentence = s2s.decode_sequence(input_seq, s2s.ind2token.get(input_seq[0][-1], ""))
    print('Input sentence:', input_sent)
    # Previously computed but never shown; print it so the decoding can be compared.
    print('Expected output:', expected_output)
    print('Decoded sentence:', decoded_sentence)
    print("-" * 20)