
# Memorize short sequences of letters with an LSTM

The goal of this test is to have an LSTM memorize a short sequence of letters and then reproduce that sequence with 100% accuracy when given only its first letter.

In addition, we want to do this in a strictly *online* setting: each character is fed to the network one at a time, the network produces an output for each input, and we take a gradient step after each output is produced.
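
As a minimal sketch of the task itself (not of the notebook's model), the snippet below builds the cyclic next-character pairs and then rolls a predictor forward from the first letter, feeding each output back in as the next input. The helpers `next_char_pairs` and `rollout`, and the lookup-table predictor, are hypothetical stand-ins used only for illustration.

```python
def next_char_pairs(sequence):
    """Return (input, target) pairs, wrapping from the last character to the first."""
    return [(sequence[i], sequence[(i + 1) % len(sequence)])
            for i in range(len(sequence))]


def rollout(predict, first_letter, steps):
    """Feed each predicted character back in as the next input."""
    generated = []
    current = first_letter
    for _ in range(steps):
        current = predict(current)
        generated.append(current)
    return "".join(generated)


sequence = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
pairs = next_char_pairs(sequence)

# Stand-in "model": a lookup table built directly from the pairs.
table = dict(pairs)
produced = rollout(table.__getitem__, sequence[0], steps=len(sequence) - 1)
print(produced == sequence[1:])
```

Success in this notebook means the produced string matches the target at every step of such a rollout.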

## Questions

Below we show that an LSTM is able to memorize and reproduce *some* short strings easily but not others. 


*   Why are some strings much harder to predict than others?
*   Why does increasing the network size help in case 7?
*   What are the underlying scaling factors in terms of memorization difficulty, capacity, and speed?





In [0]:
import torch
import six
import numpy as np
from difflib import SequenceMatcher
from plotly.offline import download_plotlyjs, init_notebook_mode
from plotly.offline import iplot as plot
from torch import nn

torch.set_printoptions(precision=4)

# Lib

In [0]:
# https://github.com/keras-team/keras-preprocessing/blob/master/keras_preprocessing/sequence.py#L15
def pad_sequences(sequences, maxlen=None, dtype='int32', padding='pre', truncating='pre', value=0.):
    """Pads sequences to the same length.
    This function transforms a list of
    `num_samples` sequences (lists of integers)
    into a 2D Numpy array of shape `(num_samples, num_timesteps)`.
    `num_timesteps` is either the `maxlen` argument if provided,
    or the length of the longest sequence otherwise.
    Sequences that are shorter than `num_timesteps`
    are padded with `value` until they are `num_timesteps` long.
    Sequences longer than `num_timesteps` are truncated
    so that they fit the desired length.
    The position where padding or truncation happens is determined by
    the arguments `padding` and `truncating`, respectively.
    Pre-padding is the default.
    # Arguments
        sequences: List of lists, where each element is a sequence.
        maxlen: Int, maximum length of all sequences.
        dtype: Type of the output sequences.
            To pad sequences with variable length strings, you can use `object`.
        padding: String, 'pre' or 'post':
            pad either before or after each sequence.
        truncating: String, 'pre' or 'post':
            remove values from sequences larger than
            `maxlen`, either at the beginning or at the end of the sequences.
        value: Float or String, padding value.
    # Returns
        x: Numpy array with shape `(len(sequences), maxlen)`
    # Raises
        ValueError: In case of invalid values for `truncating` or `padding`,
            or in case of invalid shape for a `sequences` entry.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')
    lengths = []
    for x in sequences:
        if not hasattr(x, '__len__'):
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(x))
        lengths.append(len(x))

    num_samples = len(sequences)
    if maxlen is None:
        maxlen = np.max(lengths)

    # take the sample shape from the first non empty sequence
    # checking for consistency in the main loop below.
    sample_shape = tuple()
    for s in sequences:
        if len(s) > 0:
            sample_shape = np.asarray(s).shape[1:]
            break

    # `np.unicode_` was an alias of `np.str_` and was removed in NumPy 2.0
    is_dtype_str = np.issubdtype(dtype, np.str_)
    if isinstance(value, six.string_types) and dtype != object and not is_dtype_str:
        raise ValueError("`dtype` {} is not compatible with `value`'s type: {}\n"
                         "You should set `dtype=object` for variable length strings."
                         .format(dtype, type(value)))

    x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)
    for idx, s in enumerate(sequences):
        if not len(s):
            continue  # empty list/array was found
        if truncating == 'pre':
            trunc = s[-maxlen:]
        elif truncating == 'post':
            trunc = s[:maxlen]
        else:
            raise ValueError('Truncating type "%s" '
                             'not understood' % truncating)

        # check `trunc` has expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError('Shape of sample %s of sequence at position %s '
                             'is different from expected shape %s' %
                             (trunc.shape[1:], idx, sample_shape))

        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return x


# https://github.com/keras-team/keras/blob/master/keras/utils/np_utils.py#L9-L37
def to_categorical(y, num_classes=None, dtype='float32'):
    """Converts a class vector (integers) to binary class matrix.
    E.g. for use with categorical_crossentropy.
    # Arguments
        y: class vector to be converted into a matrix
            (integers from 0 to num_classes).
        num_classes: total number of classes.
        dtype: The data type expected by the input, as a string
            (`float32`, `float64`, `int32`...)
    # Returns
        A binary matrix representation of the input. The classes axis
        is placed last.
    # Example
    ```python
    # Consider an array of 5 labels out of a set of 3 classes {0, 1, 2}:
    > labels
    array([0, 2, 1, 2, 0])
    # `to_categorical` converts this into a matrix with as many
    # columns as there are classes. The number of rows
    # stays the same.
    > to_categorical(labels)
    array([[ 1.,  0.,  0.],
           [ 0.,  0.,  1.],
           [ 0.,  1.,  0.],
           [ 0.,  0.,  1.],
           [ 1.,  0.,  0.]], dtype=float32)
    ```
    """

    y = np.array(y, dtype='int')
    input_shape = y.shape
    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
        input_shape = tuple(input_shape[:-1])
    y = y.ravel()
    if not num_classes:
        num_classes = np.max(y) + 1
    n = y.shape[0]
    categorical = np.zeros((n, num_classes), dtype=dtype)
    categorical[np.arange(n), y] = 1
    output_shape = input_shape + (num_classes,)
    categorical = np.reshape(categorical, output_shape)
    return categorical
  
def configure_plotly_browser_state():
    import IPython
    display(IPython.core.display.HTML('''
          <script src="/static/components/requirejs/require.js"></script>
          <script>
            requirejs.config({
              paths: {
                base: '/static/base',
                plotly: 'https://cdn.plot.ly/plotly-latest.min.js?noext',
              },
            });
          </script>
          '''))


# Source

In [0]:
def get_alphabet():
    """
    All the characters that might be seen in a sequence
    """
    return "ABCDEFGHIJKLMNOPQRSTUVWXYZ"


def get_sequence(name="alphabet"):

    sequences = dict(
        alphabet="ABCDEFGHIJKLMNOPQRSTUVWXYZ",
        # Repeated letter H
        repeatone="ABCDEFGHIJKLMNOPHRSTUVWXYZ",
        # Repeat five letters - Seq len 31
        # [C, D, F, J, T]
        repeatfive="ABCDETFGCHIJKLDMNOPFQRSTUJVWXYZ",
        # A sequence of 30 randomly selected characters
        random30="VNMVWKTFFLOZKUNSDYGOEFOBTZTTVU",
    )
    return sequences[name]


def get_char_to_int(alphabet):
    char_to_int = dict((c, i) for i, c in enumerate(alphabet))
    return char_to_int


def get_int_to_char(alphabet):
    int_to_char = dict((i, c) for i, c in enumerate(alphabet))
    return int_to_char


def get_data(alphabet, sequence):

    # Get mapping of characters to integers (0-25) and the reverse
    char_to_int = get_char_to_int(alphabet)

    # Prepare the dataset of input to output pairs encoded as integers
    dataX = []
    y = []
    for i in range(0, len(sequence)):
        # 0 - 25
        char_in = sequence[i]
        next_ind = (i + 1) % len(sequence)
        char_out = sequence[next_ind]
        dataX.append([char_to_int[char_in]])
        y.append(char_to_int[char_out])
        print(char_in, '->', char_out)

    # Convert list of lists to array and pad sequences if needed
    X = pad_sequences(dataX, maxlen=1, dtype='float32')

    # Reshape X to be [samples, time steps, features]
    X = np.reshape(X, (X.shape[0], 1, 1))

    # One hot encode input
    X = to_categorical(X, len(alphabet))
    
    # Convert our training data into a tensor
    X = torch.Tensor(X)

    return X, y

def letter_to_one_hot_tensor(letter):
    alphabet = get_alphabet()
    char_to_int = get_char_to_int(alphabet)

    letter_as_int = char_to_int[letter]
    l_as_array = [[letter_as_int]]
    l_as_one_hot = to_categorical(l_as_array, len(alphabet))
    one_hot_tensor = torch.Tensor(l_as_one_hot)
    return one_hot_tensor


def one_hot_tensor_to_letter(one_hot_tensor):
    one_hot_tensor = one_hot_tensor.detach()
    alphabet = get_alphabet()
    int_to_char = get_int_to_char(alphabet)
    in_int = int(np.argmax(one_hot_tensor[0]))
    letter = int_to_char[in_int]
    return letter


def run_forward_output(model, start, steps):
    x = letter_to_one_hot_tensor(start)
    for i in range(steps):
        in_char = one_hot_tensor_to_letter(x)
        # LSTMMod.forward returns only the output tensor
        output = model(x)
        out_char = one_hot_tensor_to_letter(output)
        print(in_char, "->", out_char)
        # Feed the prediction back in to the next timestep
        x = output


def run_forward_pred(model, sequence, start=0, steps=25, one_step=False):
    """
    Run the model forward for <steps> steps starting at <start>
    index in the sequence, feeding the output of the model back
    in as the input in the next step.
    """
    start_letter = sequence[start]
    x = letter_to_one_hot_tensor(start_letter)
    steps_correct = 0
    broken = False
    generated = []
    # Make sure model is in a clean starting state
    model.reset_prev_state()
    for i in range(steps):
        # If the model was trained without state, evaluate it in the same way
        if one_step:
            model.reset_prev_state()
        in_char = one_hot_tensor_to_letter(x)
        output = model(x)
        out_char = one_hot_tensor_to_letter(output)
        generated.append(out_char)
        next_letter = sequence[(start + i + 1) % len(sequence)]

        # Feed the prediction back in to the next timestep
        x = output
        if out_char != next_letter:
            broken = True
        if not broken:
            steps_correct += 1

    target_sequence = sequence[1:]
    print("Target Sequence:", target_sequence)
    generated = "".join(generated)
    print("Produced Sequence:", generated)
    matcher = SequenceMatcher(None, target_sequence, generated)
    match = matcher.find_longest_match(0, len(target_sequence), 0, len(generated))
    longest_substring = target_sequence[match.a: match.a + match.size]
    longest_len = len(longest_substring)
    print("Longest Substring:", longest_substring)

    print("Initial steps without breaking: {}".format(steps_correct))
    print("Longest substring produced: {}".format(longest_len))
    return steps_correct


def evaluate_model(model, sequence, one_step):

    steps_correct = run_forward_pred(
        model,
        sequence,
        start=0, 
        steps=len(sequence) - 1,
        one_step=one_step
    )
    return steps_correct == len(sequence) - 1
    

class LSTMMod(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.lin = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=2)

        # Init hidden state and context to zeros
        self.reset_prev_state()

    def _get_prev_state_size(self):
        return self.hidden_size

    def _store_prev(self, hidden, context):
        self.prev_state = hidden
        self.prev_context = context

    def reset_prev_state(self):
        self.prev_state = torch.zeros(1, 1, self._get_prev_state_size())
        self.prev_state.requires_grad = True

        self.prev_context = torch.zeros(1, 1, self._get_prev_state_size())
        self.prev_context.requires_grad = True

    def forward(self, inp):
        inp = inp.unsqueeze(0)  # Add extra batch dimension
        hidden, state = self.lstm(inp, (self.prev_state, self.prev_context))
        h_n, c_n = state
        lin_out = self.lin(hidden)
        output = self.softmax(lin_out)

        output = output.squeeze(dim=0)  # Remove extra dimension

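        # Store the final hidden state h_n; for a single step and a single
        # layer it holds the same values as the per-step output `hidden`
        self._store_prev(h_n, c_n)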
        return output


def train(model, loss_module, opt, alphabet, sequence, max_epochs, one_step=False):
    X, y = get_data(alphabet, sequence)

    ff_losses = []

    for i in range(max_epochs):
        if i % 10 == 0:
            print(" ========== EPOCH: {} ============".format(i))
            
            if evaluate_model(model, sequence, one_step):
                print("Training SUCCESSFUL.")
                break
        # Reset state at the beginning of each epoch
        model.reset_prev_state()
        # Feed the network the sequence character by character
        for j, inp in enumerate(X):
            # If one_step is set, reset the state after each character,
            # which truncates backprop to a single step
            if one_step:
                model.reset_prev_state()
            next_ind = (j + 1) % len(sequence)
            next_char = sequence[next_ind]
            label = letter_to_one_hot_tensor(next_char)
            # Labels must be Long tensors
            # This exact format is required
            label = torch.argmax(label, dim=1)

            output = model(inp)

            # Compute the feed forward loss
            loss = loss_module(output, label)
            ff_losses.append(loss.detach().item())

            # Zero gradients
            opt.zero_grad()

            # Calculate the gradients of the losses
            loss.backward(retain_graph=True)

            # Take an optimization step
            opt.step()

    return ff_losses




# Main

In [0]:
def run_experiment(sequence, one_step, units=32, max_epochs=1000):
    alphabet = get_alphabet()

    model = LSTMMod(
        input_size=len(alphabet),
        hidden_size=units,
        output_size=len(alphabet)
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
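    # Note: nn.CrossEntropyLoss applies log-softmax internally and expects raw
    # logits, while LSTMMod.forward already returns softmax probabilities.
    loss_module = nn.CrossEntropyLoss()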

    ff_losses = train(
        model,
        loss_module,
        optimizer,
        alphabet,
        sequence,
        max_epochs,
        one_step
    )

    print("Double check the evaluation ...")
    evaluate_model(model, sequence, one_step)

    # Plot losses from training run
    configure_plotly_browser_state()
    init_notebook_mode(connected=False)

    ff_losses = np.array(ff_losses)
    ff_loss_trace = dict(y=ff_losses, name="Feed Forward")
    data = [
        ff_loss_trace,
    ]

    plot(data)

## Case 1 (Working)

Target sequence: the alphabet itself, "ABCDEFGHIJKLMNOPQRSTUVWXYZ".

Here we train an LSTM with 32 units as if it were a feed-forward network. 

We reset the state after each character fed to the network, so backprop spans only 1 step (controlled by the variable `one_step`).

This converges in around 50-100 epochs.
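
With the state reset after every character, the network can only use the current input, so this case amounts to learning a 26-entry lookup table online. The sketch below runs the same online protocol (one character in, one prediction out, one gradient step) on a plain softmax table instead of an LSTM. The table model, the learning rate, and the epoch count are illustrative assumptions, not the notebook's settings.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def train_online(sequence, alphabet, epochs=5, lr=0.5):
    """Train a table of logits, taking one SGD step after every character."""
    index = {c: i for i, c in enumerate(alphabet)}
    size = len(alphabet)
    logits = [[0.0] * size for _ in range(size)]  # logits[current][next]
    for _ in range(epochs):
        for pos, char in enumerate(sequence):
            target = index[sequence[(pos + 1) % len(sequence)]]
            row = logits[index[char]]
            probs = softmax(row)
            # Cross-entropy gradient w.r.t. the logits: probs - one_hot(target)
            for k in range(size):
                row[k] -= lr * (probs[k] - (1.0 if k == target else 0.0))
    return logits


def generate(logits, alphabet, first_letter, steps):
    """Greedy closed-loop generation from the trained table."""
    out, current = [], alphabet.index(first_letter)
    for _ in range(steps):
        row = logits[current]
        current = max(range(len(row)), key=row.__getitem__)
        out.append(alphabet[current])
    return "".join(out)


alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
logits = train_online(alphabet, alphabet)
print(generate(logits, alphabet, "A", steps=25))
```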

In [0]:
sequence = get_sequence("alphabet")
one_step = True

run_experiment(sequence, one_step)

## Case 2 (Working)

Target sequence: the alphabet itself, "ABCDEFGHIJKLMNOPQRSTUVWXYZ".

Here we train an LSTM with 32 units as normal for an RNN.

We **DO NOT** reset the state after each character fed to the network. We **do** reset it at the start of each epoch, so backprop reaches back at most 25 steps.

This converges in around 50-100 epochs.

Loss reduction over time is less smooth.

In [0]:
sequence = get_sequence("alphabet")
one_step = False

run_experiment(sequence, one_step)

## Case 3 (Failing - But Expected)

The target sequence has a **single repeated character**: the letter "H" appears twice, and there is no "Q".

"ABCDEFGHIJKLMNOPHRSTUVWXYZ"

We reset state after each character fed to the network.

With the state reset, the prediction depends only on the current character, so a network trained in this manner cannot correctly predict BOTH of the transitions out of the repeated letter:

H -> I

H -> R

This reaches `max_epochs` without converging
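
When the prediction depends only on the current character, any character that is followed by two different letters cannot be predicted correctly in both places. A small sketch for listing such characters in a sequence follows; the helper name `ambiguous_characters` is hypothetical, and the sequence is treated as cyclic, as in `get_data`.

```python
def ambiguous_characters(sequence):
    """Map each character with more than one distinct successor to those successors."""
    successors = {}
    for i, char in enumerate(sequence):
        nxt = sequence[(i + 1) % len(sequence)]
        successors.setdefault(char, set()).add(nxt)
    return {c: sorted(s) for c, s in successors.items() if len(s) > 1}


print(ambiguous_characters("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))  # {}
print(ambiguous_characters("ABCDEFGHIJKLMNOPHRSTUVWXYZ"))  # {'H': ['I', 'R']}
```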

In [0]:
sequence = get_sequence("repeatone")
one_step = True

run_experiment(sequence, one_step)

## Case 4 (Working - Slower)

The target sequence has a **single repeated character**: the letter "H" appears twice, and there is no "Q".

"ABCDEFGHIJKLMNOPHRSTUVWXYZ"

We **DO NOT** reset state after each character fed to the network.

This converges after 100 to 200 epochs
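
One way to see why carrying state helps here is to ask how many trailing characters are needed before the next character is always determined. The sketch below computes that length for a sequence treated as cyclic. This is a simplifying assumption, since the notebook resets the state at the start of every epoch, and `min_context_length` is a hypothetical helper.

```python
def min_context_length(sequence):
    """Smallest k such that the last k characters always determine the next one."""
    n = len(sequence)
    for k in range(1, n + 1):
        seen = {}
        consistent = True
        for i in range(n):
            # The k characters ending at position i, wrapping around the start
            context = "".join(sequence[(i - k + 1 + j) % n] for j in range(k))
            nxt = sequence[(i + 1) % n]
            if seen.setdefault(context, nxt) != nxt:
                consistent = False
                break
        if consistent:
            return k
    return n


print(min_context_length("ABCDEFGHIJKLMNOPQRSTUVWXYZ"))
print(min_context_length("ABCDEFGHIJKLMNOPHRSTUVWXYZ"))
```

For the alphabet the current character alone is enough. For the repeated-H sequence one extra character of history is needed, because "GH" and "PH" have different successors.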


In [0]:
sequence = get_sequence("repeatone")
one_step = False

run_experiment(sequence, one_step)

## Case 5 (Working - Slow)

The target sequence has **five repeated characters**.

"ABCDETFGCHIJKLDMNOPFQRSTUJVWXYZ"

We **DO NOT** reset state after each character fed to the network.

This often converges after 200-300 epochs

In [0]:
sequence = get_sequence("repeatfive")
one_step = False

run_experiment(sequence, one_step)

## Case 6 (Failing)

Target sequence is 30 characters that were selected randomly from the alphabet (with replacement)

"VNMVWKTFFLOZKUNSDYGOEFOBTZTTVU"

We **DO NOT** reset state after each character fed to the network.

This fails to converge within 1000 epochs.
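
For trying other random targets, a sequence of this kind can be drawn with the standard library as sketched below. The helper `random_sequence` and the seed value are assumptions for illustration; the seed behind the string above is not known, so this will not reproduce it.

```python
import random


def random_sequence(alphabet, length, seed=None):
    """Sample `length` characters from `alphabet` with replacement."""
    rng = random.Random(seed)
    return "".join(rng.choices(alphabet, k=length))


candidate = random_sequence("ABCDEFGHIJKLMNOPQRSTUVWXYZ", 30, seed=0)
print(candidate)
```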



In [0]:
sequence = get_sequence("random30")
one_step = False

run_experiment(sequence, one_step)

## Case 7 (Working - Slow)

Here we significantly increase the size of the network to have **128 hidden units**.
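
To put the size increase in numbers, the sketch below counts trainable parameters. It assumes the layout of a single-layer PyTorch `nn.LSTM` (four gates, each with input weights, recurrent weights, and two bias vectors) followed by the `nn.Linear` output layer. The helper names are hypothetical.

```python
def lstm_param_count(input_size, hidden_size):
    # Four gates, each with input weights, recurrent weights, and two biases
    per_gate = input_size * hidden_size + hidden_size * hidden_size + 2 * hidden_size
    return 4 * per_gate


def linear_param_count(in_features, out_features):
    # Weight matrix plus bias vector
    return in_features * out_features + out_features


def model_param_count(alphabet_size, hidden_size):
    return (lstm_param_count(alphabet_size, hidden_size)
            + linear_param_count(hidden_size, alphabet_size))


print(model_param_count(26, 32))   # 8538
print(model_param_count(26, 128))  # 83226
```

Under these assumptions, going from 32 to 128 units multiplies the parameter count by roughly ten.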

Target sequence is 30 characters that were selected randomly from the alphabet (with replacement)

"VNMVWKTFFLOZKUNSDYGOEFOBTZTTVU"

We **DO NOT** reset state after each character fed to the network.

This often converges after 400-600 epochs



In [0]:
sequence = get_sequence("random30")
one_step = False

run_experiment(sequence, one_step, units=128, max_epochs=2000)