This is an implementation of the Connectionist Temporal Classification (CTC) loss function from:

> Graves, A., Fernández, S., Gomez, F., & Schmidhuber, J. (2006, June). Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks. In Proceedings of the 23rd international conference on Machine learning (pp. 369-376). ACM. ftp://ftp.idsia.ch/pub/juergen/icml2006.pdf

This notebook only shows the training procedure: no thorough testing is performed, and prefix search decoding is not implemented (contributions are welcome!).

The original paper appears to use minibatches of size 1, whereas this notebook uses minibatches of 16. Otherwise, there should be no significant differences.

Please download the [TIMIT dataset](http://academictorrents.com/details/34e2b78745138186976cbc27939b1b34d18bd5b3) and place the `TIMIT.zip` file next to this notebook.

The following python packages are required:
- scipy
- lasagne
- matplotlib
- [sphfile](https://pypi.python.org/pypi/sphfile) (to read the sound files)
- [python_speech_features](https://github.com/jameslyons/python_speech_features) (to generate mfcc features)


In [None]:
# IPython setup: reload edited modules automatically before running code
# (useful for e.g. the local `ctc` module) and use interactive figures.
%load_ext autoreload
%autoreload 2
%matplotlib notebook

import os
# Theano reads THEANO_FLAGS when it is imported (next cell), so it must be set
# here. NOTE: this overwrites any THEANO_FLAGS already in the environment.
os.environ['THEANO_FLAGS'] = "device=cuda"
# Uncomment to make CUDA kernel launches synchronous (easier error tracing).
#os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [None]:
import pickle as pkl
import numpy as np
from zipfile import ZipFile
from sphfile import SPHFile
from python_speech_features import mfcc
import theano
import theano.tensor as T
import lasagne
from lasagne.layers import InputLayer, LSTMLayer, DenseLayer, ConcatLayer, GaussianNoiseLayer
from lasagne.init import Uniform
from lasagne.nonlinearities import tanh, sigmoid
import matplotlib
import matplotlib.pyplot as plt
from ctc import ctc_loss, log_softmax, ctc_backward
import time

## Small utility functions

In [None]:
def smooth(x, w):
    """Moving-average smoothing whose strength grows with ``w``.

    The window width scales from 0 (``w == 0``) up to ``len(x) / 2``
    (``w == 1``) along a log-like curve, and is forced to be odd. When the
    resulting width is below 3 or exceeds ``len(x)``, ``x`` itself is
    returned unchanged.

    The interior is a centred uniform moving average. The first and last
    ``width // 2`` samples are replaced by running means accumulated from
    the respective end of the signal, so the output has the same length as
    ``x``.
    """
    n = len(x)
    width = int(np.ceil(n / 2 * (1000 ** w - 1) / 999))
    if width % 2 == 0:
        width += 1  # odd width keeps the average centred

    if width < 3 or n < width:
        return x

    half = width // 2
    counts = np.arange(1, half + 1)
    head = np.cumsum(x[:half]) / counts
    body = np.convolve(x, np.full([width], 1 / width), 'valid')
    # running mean from the end backwards, flipped back into forward order
    tail = (np.cumsum(x[::-1][:half]) / counts)[::-1]
    return np.concatenate([head, body, tail])

def argmax_decode(preds, exclude=()):
    """Greedy (best-path) decoding of per-frame label scores.

    Takes the highest-scoring label of every frame, merges runs of
    identical consecutive labels into a single label, then removes every
    label listed in ``exclude`` (e.g. the CTC blank).

    Parameters
    ----------
    preds : array-like, shape (n_frames, n_labels)
        Per-frame scores or probabilities.
    exclude : container of int
        Label indices dropped after run merging.

    Returns
    -------
    numpy.ndarray of int
        The decoded label sequence; empty when ``preds`` has no frames or
        every label is excluded.
    """
    best = np.argmax(preds, axis=1)
    if best.size == 0:
        return np.zeros((0,), dtype=best.dtype)

    # keep a frame only when its label differs from the previous frame's
    keep = np.concatenate([[True], best[1:] != best[:-1]])
    collapsed = best[keep]
    # explicit dtype so an all-excluded result is still an integer array
    return np.array([v for v in collapsed if v not in exclude], dtype=best.dtype)

## Prepare dataset

In [None]:
# Unpack the TIMIT archive on first run; later runs reuse the extracted tree.
if not os.path.isdir("data/lisa/data/timit/raw/TIMIT"):
    # raise instead of assert: assert statements are stripped under `python -O`
    if not os.path.exists("TIMIT.zip"):
        raise FileNotFoundError("Missing data archive")
    with ZipFile("TIMIT.zip", 'r') as archive:
        archive.extractall(path=".")

In [None]:
# Collect every recording (path without the ".WAV" extension) and whether it
# belongs to the TRAIN split.
files = []
train_subset = []

timit_root = "data/lisa/data/timit/raw/TIMIT"
for dirpath, _, filenames in os.walk(timit_root):
    # the first directory below the root names the split ("TRAIN" / "TEST");
    # relpath avoids depending on the exact character length of timit_root
    split = os.path.relpath(dirpath, timit_root).split(os.sep)[0]
    for f in filenames:
        if f.endswith("WAV"):
            # no need to decode the audio here; preprocessing reads it later
            files.append(dirpath + "/" + f[:-4])
            train_subset.append(split == "TRAIN")

files = np.array(files)
# builtin bool: the np.bool alias was removed in NumPy 1.24
train_subset = np.array(train_subset, dtype=bool)

## Preprocessing

In [None]:
# Build (once) and cache the normalised features and integer label sequences.
if not os.path.exists("preprocessed_dataset.pkl"):
    features = []
    labels = []

    for f in files:
        recording = SPHFile(f + ".WAV")
        signal = recording.content
        samplerate = recording.format['sample_rate']

        # 13 cepstral coefficients (appendEnergy=True) from 10 ms windows
        # taken every 5 ms
        mfccfeats = mfcc(signal, samplerate=samplerate, winlen=0.01, winstep=0.005,
                         numcep=13, nfilt=26, appendEnergy=True)
        # first-order time derivatives: one-sided differences at both edges,
        # central differences everywhere else
        derivatives = np.concatenate([
            mfccfeats[1, None] - mfccfeats[0, None],
            .5 * mfccfeats[2:] - .5 * mfccfeats[0:-2],
            mfccfeats[-1, None] - mfccfeats[-2, None]], axis=0)

        features.append(np.concatenate([mfccfeats, derivatives], axis=1).astype(np.float32))

        # each .PHN line is "<start> <end> <phoneme>"; keep the phoneme
        with open(f + ".PHN") as phonem_file:
            labels.append([line.split()[2] for line in phonem_file])

    # Standardise each feature dimension separately: cepstra, energy and
    # their derivatives have very different scales, which a single global
    # mean/std would leave unbalanced. Concatenate only once to bound memory.
    stacked = np.concatenate(features, axis=0)
    m = np.mean(stacked, axis=0)
    s = np.std(stacked, axis=0)
    s[s == 0] = 1  # constant dimension: avoid division by zero
    del stacked
    features = [(feat - m) / s for feat in features]

    # Label indices must not depend on set iteration order (string hashing
    # is randomised per process), so sort the phonemes; 'h#' is placed last
    # and doubles as the CTC blank.
    phonemes = set()
    for lseq in labels:
        phonemes.update(lseq)
    vocabulary = sorted(phonemes - {'h#'}) + ['h#']
    blank = len(vocabulary) - 1

    index_of = {phoneme: idx for idx, phoneme in enumerate(vocabulary)}
    labels = [np.array([index_of[l] for l in lseq], dtype=np.int32) for lseq in labels]

    with open("preprocessed_dataset.pkl", 'wb') as out_file:
        pkl.dump((features, labels, vocabulary, blank), out_file, -1)


# NOTE: pickle is only safe here because this file is produced locally above.
with open("preprocessed_dataset.pkl", 'rb') as in_file:
    features, labels, vocabulary, blank = pkl.load(in_file)

In [None]:
# let's go brutal and shove that in GPU memory:
# pad every sequence to a common length and keep the whole dataset as Theano
# constants, so assembling a minibatch is a single indexing operation.

n_sequences = len(features)
feat_size = features[0].shape[1]

durations = np.array([len(seq) for seq in features], dtype=np.int32)
# -2: the initial and final blank of every transcription are dropped
nlabels = np.array([len(seq) - 2 for seq in labels], dtype=np.int32)
max_duration = int(durations.max())
max_labels = int(nlabels.max())

all_features = np.zeros((n_sequences, max_duration, feat_size), dtype=np.float32)
all_labels = np.zeros((n_sequences, max_labels), dtype=np.int32)
for i, (seq_feats, seq_labels) in enumerate(zip(features, labels)):
    all_features[i, :durations[i]] = seq_feats
    all_labels[i, :nlabels[i]] = seq_labels[1:-1]

durations_var = T.as_tensor_variable(durations, name="durations")
all_features_var = T.as_tensor_variable(all_features, name="all_features")
nlabels_var = T.as_tensor_variable(nlabels, name="nlabels")
all_labels_var = T.as_tensor_variable(all_labels, name="all_labels")

# symbolic minibatch: an int32 vector of sequence indices selects the
# matching rows of every dataset tensor
minibatch_indexes = T.ivector()
batch_features, batch_durations, batch_nlabels, batch_labels = (
    tensor[minibatch_indexes]
    for tensor in (all_features_var, durations_var, nlabels_var, all_labels_var))

## Model

In [None]:
batch_size = 16

l_in = InputLayer(
    input_var=batch_features,
    shape=(batch_size, max_duration, feat_size))

# One duration per sequence. The batch dimension is left unspecified (None)
# because the evaluation function feeds single sequences.
l_duration = InputLayer(input_var=batch_durations, shape=(None,))

# (batch, time) boolean mask: True on frames within each sequence's duration.
# output_shape is given explicitly, since by default ExpressionLayer would
# report the 1-D shape of its input.
l_mask = lasagne.layers.ExpressionLayer(
    l_duration,
    lambda d: T.arange(max_duration)[None, :] < d[:, None],
    output_shape=(None, max_duration))

l_noise = GaussianNoiseLayer(l_in, sigma=0.6)
# l_noise = l_in


def _peephole_lstm(incoming, mask, backwards=False):
    """Return a 100-unit peephole LSTMLayer over `incoming`, masked by `mask`.

    Every Gate gets W_cell=Uniform(0.1); the in/forget/out gates use sigmoid,
    the cell gate and the output nonlinearity use tanh.
    """
    def gate(nonlinearity):
        return lasagne.layers.Gate(W_cell=Uniform(0.1), nonlinearity=nonlinearity)

    return LSTMLayer(
        incoming, 100,
        ingate=gate(sigmoid),
        forgetgate=gate(sigmoid),
        cell=gate(tanh),
        outgate=gate(sigmoid),
        nonlinearity=tanh,
        mask_input=mask, peepholes=True, backwards=backwards)


# bidirectional LSTM: one pass forward in time, one backward
l_fwlstm = _peephole_lstm(l_noise, l_mask)
l_bwlstm = _peephole_lstm(l_noise, l_mask, backwards=True)

l_cat = ConcatLayer([l_fwlstm, l_bwlstm], axis=2)

# linear per-frame output over the vocabulary (no softmax here)
l_linout = DenseLayer(
    l_cat, len(vocabulary),
    nonlinearity=None,
    num_leading_axes=2)

## Training

In [None]:
# Network outputs are (batch, time, labels); transpose them to
# (time, batch, labels) before handing them to ctc_loss.
train_output = lasagne.layers.get_output(l_linout, deterministic=False)
train_output = train_output.dimshuffle(1, 0, 2)

# CTC loss of the minibatch (presumably one value per sequence — it is summed
# for the gradient and returned as-is by update_fn)
loss = ctc_loss(linout=train_output, durations=batch_durations,
                labels=batch_labels, label_sizes=batch_nlabels, blank=blank)

# Adam on the summed minibatch loss, over every trainable network parameter
params = lasagne.layers.get_all_params(l_linout, trainable=True)
grads = theano.grad(loss.sum(), params)
updates = lasagne.updates.adam(grads, params, learning_rate=1e-4)

# one compiled training step: int32 sequence indices in, loss out
update_fn = theano.function(inputs=[minibatch_indexes], outputs=loss, updates=updates)

In [None]:
# training state (kept at module level so the training cell can be resumed)
i = 0
nsteps = int(100 * n_sequences / batch_size)  # ~100 dataset-sized passes
params_history = []
loss_history = np.zeros(nsteps)

def update_plot(fig, ax1, ax2, loss_history):
    """Redraw the training-loss figure from the given (non-empty) history.

    ax1 shows every step's loss as faint dots plus a smoothed curve on a log
    axis; ax2 (the twin axis) carries one tick labelling the latest smoothed
    value. Both share y-limits [0.8 * 1st pct, 1.2 * 99th pct].
    """
    low = 0.8 * np.percentile(loss_history, 1)
    high = 1.2 * np.percentile(loss_history, 99)
    steps = np.arange(len(loss_history))

    ax1.clear()
    ax1.set_xlim(0, len(loss_history))
    ax1.set_yscale('log')
    ax1.set_ylim(low, high)
    ax1.grid(color='gray', linestyle='-', linewidth=1)
    ax1.grid(color='gray', linestyle=':', which='minor', linewidth=1)
    ax1.set_axisbelow(True)
    ax1.scatter(steps, loss_history, marker='.',
                color='firebrick', edgecolor="none", alpha=0.1)
    trend = smooth(loss_history, 0.6)
    ax1.plot(steps, trend, linewidth=2, color='firebrick')

    ax2.clear()
    ax2.set_yscale('log')
    ax2.set_ylim(low, high)
    ax2.grid(False)
    ax2.yaxis.set_label_position("right")
    ax2.set_yticks([], minor=True)
    ax2.set_yticks([trend[-1]])
    ax2.yaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter())

    fig.canvas.draw()

In [None]:
fig = plt.figure()
ax1 = fig.add_subplot(111)
ax1.set_xlim(0, i + 1)
ax1.set_ylim(0, 1)
ax2 = ax1.twinx()

# Draw minibatches from the TRAIN split only, so TEST utterances stay unseen
# during training. train_subset follows the order of `files`, which is the
# order in which the features were extracted.
if len(train_subset) != n_sequences:
    raise ValueError("train_subset does not match the preprocessed dataset; "
                     "delete preprocessed_dataset.pkl and rerun preprocessing")
train_indexes = np.flatnonzero(train_subset).astype(np.int32)

# Note: you can interrupt and resume the execution of this cell
while i < nsteps:
    t1 = time.time()
    batch_loss = np.mean(update_fn(
        np.random.choice(train_indexes, batch_size).astype(np.int32)))
    t2 = time.time()

    print("\r{:<6d} loss = {:>5.0f}, (d={:1.2f})".format(i, batch_loss, t2 - t1), end='', flush=True)
    loss_history[i] = batch_loss

    if (i + 1) % 10 == 0:
        # [:i + 1] so the loss recorded just above is included in the plot
        update_plot(fig, ax1, ax2, loss_history[:i + 1])

#     if (i + 1) % 1000 == 0:
#         params_history.append(lasagne.layers.get_all_param_values(l_linout))

    i += 1

## Evaluate model

In [None]:
# Deterministic forward pass (input noise disabled); batch-major output.
test_output = lasagne.layers.get_output(l_linout, deterministic=True)

# For the given int32 sequence indices, return the padded features, the
# durations, the padded labels, the label counts and the raw linear outputs.
logits_fn = theano.function(
    inputs=[minibatch_indexes],
    outputs=[batch_features, batch_durations, batch_labels, batch_nlabels, test_output])

In [None]:
sequence = 3

f, d, l, n, p = logits_fn(np.array([sequence], dtype=np.int32))
# strip the padding of this single-sequence batch
f = f[0, :d[0]]
l = l[0, :n[0]]
p = p[0, :d[0]]
# numerically stable softmax over the label axis (exponent computed once)
shifted = np.exp(p - np.max(p, axis=-1, keepdims=True))
s = shifted / np.sum(shifted, axis=-1, keepdims=True)

fig = plt.figure()
ax = plt.subplot(111)
lines = []
frames = np.arange(len(p))

# one pickable curve per non-blank label present in the target, in
# alphabetical order of the label names
for c in np.argsort(vocabulary[:-1]):
    if c in l:
        line, = ax.plot(frames, s[:, c], label=vocabulary[c], picker=5)
        lines.append(line)

# dotted curve: probability of the blank label
ax.plot(frames, s[:, blank], linestyle=":")

ax.set_ylim(0.0, 1.2)
# ax.set_yscale('log')
ax.set_title('Select curve to see the label')

ax.legend(
    framealpha=1,
    loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=8)

fig.subplots_adjust(bottom=0.5)
fig.show()

def onpick(event):
    """Highlight the picked curve and show its label in the title."""
    for line in lines:
        line.set_alpha(0.3)
        line.set_linewidth(1)  # thinner than the selected curve

    event.artist.set_alpha(1)
    event.artist.set_linewidth(2)
    event.artist.axes.set_title(event.artist.get_label())
    # artist changes are not visible until the canvas is redrawn
    event.canvas.draw_idle()

cid = fig.canvas.mpl_connect('pick_event', onpick)

print("target    : {}".format(", ".join(vocabulary[l_] for l_ in l)))
print("prediction: {}".format(", ".join(vocabulary[l_] for l_ in argmax_decode(s, [blank]))))