# RNN training tutorial

This is a tutorial on training various RNNs on simple datasets and analysing what they learn.

Structure:

  1. basic (vanilla RNN) implementation
  2. observing exploding/vanishing gradients
  3. training an LSTM on a character-level language modelling task
    * comparing training of an LSTM and an RNN, playing with architectures
  4. Interpretability by plotting and analysing activations of a network:
    * identifying interpretable neurons
    * identifying neurons-gates interactions
    * identifying hidden state dynamics through time
  
    

The first three sections are almost independent: you can switch between them without any code dependencies (except that the vanilla RNN cannot be used in section 4 if it was not implemented in section 1).

Cells that include "starting point" in their title require filling in some code gaps; all remaining ones are complete (but feel free to play with them if you want!)

Please pay attention to the questions after each section. Answering them is crucial to make sure you understand the various modes of RNN operation.

Language model exercises are based on [Sonnet LSTM example](https://github.com/deepmind/sonnet/blob/master/sonnet/examples/rnn_shakespeare.py).
Apart from loading the dataset, we make no further use of Sonnet in this colab.

## Imports

We will use PyTorch: `torch.nn` layers and its built-in recurrent cells (`nn.LSTMCell`, `nn.GRUCell`).

In [0]:
#@title Imports

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import numpy as np
import seaborn as sns

import torch
import torch.nn as nn 


from matplotlib import pyplot as plt

# from sonnet.examples import dataset_shakespeare
  
sns.set_style('ticks')  


# Ex 1.    Vanilla RNN

Implement a basic RNN cell using `torch.nn` layers:

   $$ h_t = f( Wx_t + Vh_{t-1}  + b) $$
   
   Where
   
   * $x_t$ input at time $t$
   * $h_t$ hidden state at time $t$
   * $W$ input-to-hidden mapping (trainable)
   * $V$ hidden-to-hidden mapping (trainable)
   * $b$ bias (trainable)
   * $f$ non-linearity chosen (usually tanh)
   
   
   You do not need to worry about the plotting and running code; focus on the RNN implementation.
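For reference outside PyTorch, the recurrence above can be unrolled in a few lines of NumPy. This is a minimal sketch that assumes $f = \tanh$, a zero initial state and randomly initialised weights. `vanilla_rnn_unroll` is a hypothetical helper and is not part of the exercise code.

```python
import numpy as np

def vanilla_rnn_unroll(xs, W, V, b, h0=None):
    """Apply h_t = tanh(W x_t + V h_{t-1} + b) over a sequence xs of shape (T, input_size)."""
    hidden_size = V.shape[0]
    h = np.zeros(hidden_size) if h0 is None else h0  # h_{-1} = 0 unless given
    states = []
    for x_t in xs:
        h = np.tanh(W @ x_t + V @ h + b)
        states.append(h)
    return np.stack(states)  # shape (T, hidden_size)

rng = np.random.default_rng(0)
T, input_size, hidden_size = 5, 1, 4
W = rng.standard_normal((hidden_size, input_size))                      # input-to-hidden
V = rng.standard_normal((hidden_size, hidden_size)) / np.sqrt(hidden_size)  # hidden-to-hidden
b = np.zeros(hidden_size)
xs = rng.standard_normal((T, input_size))
hs = vanilla_rnn_unroll(xs, W, V, b)
print(hs.shape)  # (5, 4)
```

In the PyTorch exercise code, `in_to_hidden` computes the $Wx_t$ term and the cell's `_linear_hh` layer computes $Vh_{t-1} + b$.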

In [0]:
#@title Vanilla RNN
class VanillaRNNCell(nn.Module):
  
    def __init__(self, hidden_size, activation=nn.Tanh, bias=True):    
        """
        Constructor for a simple RNNCell where the hidden-to-hidden transitions
        are defined by a linear layer and the default activation of `tanh` 
        :param hidden_size: the size of the hidden state
        :param activation: the activation function used for computing the next hidden state
        :param bias: whether the hidden-to-hidden layer has a bias term
        """
        super(VanillaRNNCell, self).__init__()
    
        self._hidden_size = hidden_size
        self._activation = activation()  
        self._bias = bias
            
        # Create the hidden-to-hidden layer
        self._linear_hh = nn.Linear(hidden_size, hidden_size, bias=bias)


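    def forward(self, inputs, hidden=None):
        """
        :param inputs: the input already projected to hidden_size (the W x_t term)
        :param hidden: the previous hidden state h_{t-1}; None means h_{t-1} = 0
        :return: (output, new_hidden); for a vanilla RNN both are h_t
        """
        if hidden is None:
            # Treat a missing state as zeros so the bias b is still applied at the first step.
            hidden = torch.zeros_like(inputs)
        
        # Out-of-place addition, so the caller's `inputs` tensor is not modified.
        out = self._activation(inputs + self._linear_hh(hidden))
        return out, out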


class VanillaRNN(nn.Module):
    def __init__(self, input_size=1, output_size=1, hidden_size=20, bias=False):
        """
        Creates a vanilla RNN where input-to-hidden is a nn.Linear layer
        and hidden-to-output is a nn.Linear layer
        
        :param input_size: the size of the input to the RNN
        :param hidden_size: size of the hidden state of the RNN
        :param output_size: size of the output
        :param bias: whether the linear layers use bias terms
        """
        super(VanillaRNN, self).__init__()
        
        self._input_size = input_size
        self._hidden_size = hidden_size
        self._output_size = output_size
        self._bias = bias
        
        self.in_to_hidden = nn.Linear(self._input_size, self._hidden_size, bias=self._bias)
        self.rnn_cell = VanillaRNNCell(self._hidden_size, bias=self._bias)
        self.hidden_to_out = nn.Linear(self._hidden_size, self._output_size, bias=self._bias)
    
    def step(self, input, hidden=None):
        input_ = self.in_to_hidden(input)
        _, hidden_ = self.rnn_cell(input_, hidden=hidden)
        output_ = self.hidden_to_out(hidden_)
        
        return output_, hidden_
    
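    def forward(self, inputs, hidden=None, force=True, warm_start=10):
        """
        Unrolls the RNN over `inputs`, whose first dimension indexes time.
        :param force: teacher forcing; if True, always feed the ground-truth input
        :param warm_start: number of initial steps that always use the ground-truth input
        """
        steps = len(inputs)
        outputs = []
        
        output_ = None
        hidden_ = hidden
        
        for i in range(steps):
            # Feed the true input when teacher forcing, at the first step, or during warm start;
            # otherwise feed back the model's own previous prediction.
            if force or i == 0 or i < warm_start:
                input_ = inputs[i]
            else:
                input_ = output_
                
            output_, hidden_ = self.step(input_, hidden_)
            outputs.append(output_)
        
        # Stacking keeps the outputs on the same device as the inputs (so this also works on a GPU).
        return torch.stack(outputs), hidden_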
        

### Train RNN on sine wave

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f'Running code @ {device}')

In [None]:
UNROLL_LENGTH = 30  #@param {type:"integer"}
NUM_ITERATIONS = 10000  #@param {type:"integer"}
WARM_START = 10  #@param {type:"integer"}
TEACHER_FORCING = False  #@param {type:"boolean"}
HIDDEN_UNITS = 20  #@param {type:"integer"}
LEARNING_RATE = 0.0001  #@param {type:"number"}
REPORTING_INTERVAL = 200  #@param {type:"integer"}

# We create training data, sine wave over [0, 2pi]
x_train = np.arange(0, 2*np.pi, 0.1).reshape(-1, 1, 1)
y_train = np.sin(x_train)

net = VanillaRNN(hidden_size=HIDDEN_UNITS, bias=False)
net.train()
net = net.to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)

running_loss = 0

for iteration in range(NUM_ITERATIONS):
    # select a start point in the training set for a sequence of UNROLL_LENGTH
    start = np.random.choice(range(x_train.shape[0] - UNROLL_LENGTH))
    train_sequence = y_train[start : (start + UNROLL_LENGTH)]
    
    train_inputs = torch.from_numpy(train_sequence[:-1]).float().to(device)
    train_targets = torch.from_numpy(train_sequence[1:]).float().to(device)
    
    optimizer.zero_grad()
    
    outputs, hidden = net(train_inputs, hidden=None, force=TEACHER_FORCING, warm_start=WARM_START)
    loss = criterion(outputs, train_targets)
    loss.backward()
    
    running_loss += loss.item()
    
    optimizer.step()
    
    if iteration % REPORTING_INTERVAL == REPORTING_INTERVAL - 1:
        # let's see how well we do on predictions for the whole sequence
        avg_loss = running_loss / REPORTING_INTERVAL
        
        with torch.no_grad():
            report_sequence = torch.from_numpy(y_train[:-1]).float().to(device)
            report_targets = torch.from_numpy(y_train[1:]).float().to(device)
            report_output, report_hidden = net(report_sequence, hidden=None, force=False, warm_start=WARM_START)
            report_loss = criterion(report_output, report_targets)
        
        print('[%d] avg_loss: %.5f, report_loss: %.5f, ' % (iteration + 1, avg_loss, report_loss.item()))
        
        plt.figure()
        plt.title('Training Loss %.5f;  Sampling loss %.5f; Iteration %d' % (avg_loss, report_loss.item(), iteration + 1))
        
        plt.plot(y_train[1:].ravel(), c='blue', label='Ground truth',
                 linestyle=":", lw=6)
        # .detach().cpu() lets the plotting code also work when training on a GPU
        plt.plot(range(start, start+UNROLL_LENGTH-1), outputs.detach().cpu().numpy().ravel(), c='gold',
                 label='Train prediction', lw=5, marker="o", markersize=5,
                 alpha=0.7)
        plt.plot(report_output.cpu().numpy().ravel(), c='r', label='Generated', lw=4, alpha=0.7)
        plt.legend()
        plt.show()
        
        # Reset the accumulator so avg_loss covers only the last REPORTING_INTERVAL iterations
        running_loss = 0

### Train the RNN

Train the RNN on sine data - predict the next sine value from *predicted* sine values.

Predict $\sin(x + t\epsilon)$ from $\sin(x), \sin(x + \epsilon), \dots, \sin(x + (t-1)\epsilon)$.

In particular, we want the network to predict the next value in a loop, conditioning the prediction on some initial values (provided) and all subsequent predictions.

To learn the prediction model, we will use *teacher forcing*. This means that when training the model, the input at time $t$ is the real sequence at time $t$, rather than the output produced by the model at $t-1$. When we want to generate data from the model, we do not have access to the true sequence, so we do not use teacher forcing. However, for this problem we also use *warm starting*: the first `WARM_START` inputs always come from the real sequence, even without teacher forcing. We need this because the model requires several time steps to predict the next sine value (at least 2: one for the initial value and one for the step).
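The choice of input at each step can be made explicit with a tiny pure-Python sketch. `input_sources` is a hypothetical helper that mirrors the branch in `VanillaRNN.forward`. It returns, for each time step, whether the model is fed the ground truth or its own previous prediction.

```python
def input_sources(steps, force, warm_start):
    """For each time step, report whether the model sees the ground truth or its own prediction.

    The true input is used when teacher forcing, at the first step, or during the first
    `warm_start` steps; otherwise the previous output is fed back in.
    """
    sources = []
    for i in range(steps):
        if force or i == 0 or i < warm_start:
            sources.append("truth")
        else:
            sources.append("prediction")
    return sources

print(input_sources(5, force=True, warm_start=2))
print(input_sources(5, force=False, warm_start=2))
```

With teacher forcing every step sees the truth. Without it, only the warm-start steps do, and the remaining steps are free-running (which is also how the "generated" curve is produced).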

The training code unrolls the RNN core you have defined above, trains it using backpropagation through time, and plots three curves: the real data ("ground truth"), the data generated during training ("train prediction") and the model samples ("generated").


Please add your final sampling losses to this spreadsheet in a new row: https://docs.google.com/spreadsheets/d/1Zpi_A6RP89E00vurqz9dRHYCd29PqzB9VB7Y4giydyA/edit#gid=0

In [0]:
# Default hypers:
# UNROLL_LENGTH = 30  #@param {type:"integer"}
# NUM_ITERATIONS = 10000  #@param {type:"integer"}
# WARM_START = 2  #@param {type:"integer"}
# TEACHER_FORCING = False  #@param {type:"boolean"}
# HIDDEN_UNITS = 20  #@param {type:"integer"}
# LEARNING_RATE = 0.0001  #@param {type:"number"}
# REPORTING_INTERVAL = 2000  #@param {type:"integer"}

# You may want to try:
# default hypers with/without teacher forcing
# use UNROLL_LENGTH = 62 to train on the whole sequence (is teacher forcing useful?)
# use UNROLL_LENGTH = 62, no teacher forcing and warm_start = 2 # this should break training

**Note:** initialization is not fixed (we do not fix a random seed), so each time the cell is executed, the parameters take new initial values and hence training can lead to different results. What happens if you run it multiple times?
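To make runs reproducible, you can fix the seeds of both random number generators this notebook relies on. NumPy's picks the training windows and PyTorch's initialises the parameters. This is a minimal sketch: `seeded_layer` is a hypothetical helper used only to show that two seeded initialisations match. Even with fixed seeds, some GPU operations can still be nondeterministic.

```python
import numpy as np
import torch
import torch.nn as nn

def seeded_layer(seed):
    # Seed NumPy (used for sampling start points) and PyTorch (used for initialisation).
    np.random.seed(seed)
    torch.manual_seed(seed)
    return nn.Linear(20, 20)

layer_a, layer_b = seeded_layer(0), seeded_layer(0)
print(torch.equal(layer_a.weight, layer_b.weight))
```

Calling the seeding lines at the top of the training cell would make repeated runs start from the same parameters.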

### What is worth trying/understanding here?

* Difference between teacher forcing and learning on own samples:
  * What are the pros and cons of teacher forcing?
  * Why is the model struggling to learn in one of the setups?
  * What is it we actually care about for models like this? What should be the actual surrogate?
* How does warm starting affect our training? Why?
* What happens if the structure of interest is much longer than the unroll length?

Answers:
* Teacher forcing makes BPTT much easier and works better in practice; the intuition is similar to imitation learning. Without teacher forcing it is very hard to learn, because errors tend to accumulate along the unroll. The price is that with teacher forcing the model mostly learns very local (one-step-ahead) structure and never sees its own predictions as inputs during training.
* No teacher forcing makes training very difficult.
* Depending on what you want to model, the teacher-forced loss may be fine if you care about one-step predictions (probabilities), but not if you want to generate long samples.

# Ex. 2      Vanishing and exploding gradients

Given an input sequence $(x_1, ..., x_N)$ of random floats (sampled from a normal distribution), unroll an RNN over it as before and compute the gradients of the last hidden state w.r.t. every previous state:
$$
\left \| \frac{\partial h_{N}}{\partial h_i} \right \|
$$
for each step $i$ of the unroll, and plot these quantities for various RNNs.

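Note that during learning one computes
$$
\frac{\partial L}{\partial \theta}  
$$
which, by the chain rule, involves products of terms like
$$
\frac{\partial L}{\partial h_N} \cdot
\frac{\partial h_N}{\partial h_{N-1}} \cdot
\dots \cdot
\frac{\partial h_i}{\partial h_{i-1}} \cdot
\dots \cdot
\frac{\partial h_0}{\partial \theta}
$$
Suppose the factors $\frac{\partial h_k}{\partial h_{k-1}}$ have norms consistently below 1. Then the product shrinks exponentially with the number of steps between $h_i$ and $h_N$ (vanishing gradients). If instead the factors are consistently large, the product can grow exponentially (exploding gradients).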
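To see this product of Jacobians without autograd, consider a toy tanh RNN $h_t = \tanh(V h_{t-1} + W x_t)$. For this cell, $\partial h_t / \partial h_{t-1} = \mathrm{diag}(1 - h_t^2)\, V$, so the Jacobians can be multiplied explicitly. This is a NumPy sketch under stated assumptions: $V$ is a Gaussian matrix scaled by a hypothetical `scale` parameter, there is no bias, and `jacobian_product_norms` is a hypothetical helper.

```python
import numpy as np

def jacobian_product_norms(scale, hidden=20, steps=15, seed=0):
    """Spectral norms of d h_N / d h_i for a toy RNN h_t = tanh(V h_{t-1} + W x_t)."""
    rng = np.random.default_rng(seed)
    V = scale * rng.standard_normal((hidden, hidden)) / np.sqrt(hidden)
    W = rng.standard_normal((hidden, 1))
    h = np.zeros(hidden)
    jacobians = []
    for _ in range(steps):
        x = rng.standard_normal(1)
        h = np.tanh(V @ h + W @ x)
        # d h_t / d h_{t-1} = diag(1 - h_t^2) V  (scales the rows of V)
        jacobians.append((1.0 - h ** 2)[:, None] * V)
    norms = []
    J = np.eye(hidden)                  # d h_N / d h_N
    for J_t in reversed(jacobians):     # J_t = d h_t / d h_{t-1}, from t = N down to 1
        J = J @ J_t                     # now d h_N / d h_{t-1}
        norms.append(np.linalg.norm(J, 2))
    return norms                        # norms[k] = ||d h_N / d h_{N-1-k}||

norms = jacobian_product_norms(scale=0.25)
for distance, norm in enumerate(norms, start=1):
    print(f"||dh_N/dh_(N-{distance})|| = {norm:.2e}")
```

With a small `scale`, the norms shrink roughly geometrically with the distance from $h_N$. You can try larger values of `scale` as well. Saturation of tanh (small $1 - h_t^2$) works against growth, so large weights alone do not guarantee explosion.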

# Hints:

PyTorch already implements many types of RNN cells, such as `nn.LSTMCell` and `nn.GRUCell`.

NB: there is no training here; we are just computing the norms of the gradients of the last hidden state with respect to the hidden states at earlier steps in the sequence.


In [None]:
#@title Vanishing and exploding gradients

SEQ_LENGTH = 15  #@param {type:"integer"}
HIDDEN_UNITS = 20  #@param {type:"integer"}

dummy_input = [torch.from_numpy(np.array([[np.random.normal()]])) for _ in range(SEQ_LENGTH)] 

######################
#   YOUR CODE HERE   #
######################
# Add several cell constructors (use those already defined in PyTorch) to the
# dict (e.g., also add a GRU, and a few more LSTMs with their initial 
# forget-gate bias values set to: 0, +1, +2 and -2).
# If in doubt, check the documentation.

def _set_forget_bias(lstm_cell, fill_value=0.):
    # PyTorch packs each LSTMCell bias vector as (b_input_gate | b_forget_gate | b_gain_gate | b_output_gate),
    # each chunk of length hidden_size, so the forget-gate bias occupies indices [n//4, n//2).
    # LSTMCell has two bias vectors (bias_ih and bias_hh) that are summed inside the gates, so we put
    # fill_value into bias_ih and zero into bias_hh, making the effective forget bias equal to fill_value.
    with torch.no_grad():
        for name, bias in lstm_cell.named_parameters():
            if name.startswith("bias"):
                n = bias.size(0)
                start, end = n // 4, n // 2
                bias[start:end].fill_(float(fill_value) if name == "bias_ih" else 0.)
            
    return lstm_cell


### Solution
rnn_types = {
    'LSTM (0)': lambda nhid:  _set_forget_bias(nn.modules.LSTMCell(input_size=1, hidden_size=nhid), fill_value=0.),
    'LSTM (+1)': lambda nhid:  _set_forget_bias(nn.modules.LSTMCell(input_size=1, hidden_size=nhid), fill_value=1.),
    'LSTM (-2)': lambda nhid:  _set_forget_bias(nn.modules.LSTMCell(input_size=1, hidden_size=nhid), fill_value=-2.),
    'LSTM (+2)': lambda nhid:  _set_forget_bias(nn.modules.LSTMCell(input_size=1, hidden_size=nhid), fill_value=2.),
    'LSTM (+10)': lambda nhid:  _set_forget_bias(nn.modules.LSTMCell(input_size=1, hidden_size=nhid), fill_value=10.),
    'GRU': lambda nhid: nn.modules.GRUCell(input_size=1, hidden_size=nhid),
    'RNN': lambda nhid: VanillaRNN(input_size=1, hidden_size=nhid),
}



depths = {rnn_type: [] for rnn_type in rnn_types}
grad_norms = {rnn_type: [] for rnn_type in rnn_types}

for rnn_type in rnn_types:
  
    constructor = rnn_types[rnn_type]
    rnn = constructor(HIDDEN_UNITS)
    
    rnn.zero_grad()
    
    rnn_at_time = []
    gradients_at_time = []
    
    prev_state = None
    
    for i in range(SEQ_LENGTH):
        if prev_state is None:
            prev_state = rnn(dummy_input[i].float())
        else:
            if rnn_type.startswith('RNN'):
                prev_state = rnn(dummy_input[i].float(), hidden=prev_state[1])
            else:
                prev_state = rnn(dummy_input[i].float(), prev_state)
        
        if rnn_type.startswith('LSTM'):
            prev_state[0].retain_grad()
            prev_state[1].retain_grad()
            rnn_at_time.append(prev_state[1])
            
        elif rnn_type.startswith('GRU'):
            prev_state.retain_grad()
            rnn_at_time.append(prev_state)
        
        elif rnn_type.startswith('RNN'):
            prev_state[1].retain_grad()
            rnn_at_time.append(prev_state[1])
    
    # We don't really care about the loss here: we are not solving a specific 
    # problem, any loss will work to inspect the behavior of the gradient.
    dummy_loss = torch.sum(rnn_at_time[-1])
    dummy_loss.backward()
    
    for i in range(1, SEQ_LENGTH):
        current_gradient = rnn_at_time[i].grad
        gradients_at_time.append(current_gradient)
    
    for gid, grad in enumerate(gradients_at_time):
        depths[rnn_type].append(len(gradients_at_time) - gid)    
        grad_norms[rnn_type].append(np.linalg.norm(grad))
        
    dummy_loss.detach_()

plt.figure()
for rnn_type in depths:
    plt.plot(depths[rnn_type], grad_norms[rnn_type], label="%s" % rnn_type, alpha=0.7, lw=2)
plt.legend()  
plt.ylabel("$ \\| \\partial \\sum_i {c_{N}}_i / \\partial c_t \\|$", fontsize=15)
plt.xlabel("Steps through time - $t$", fontsize=15)
plt.xlim((1, SEQ_LENGTH-1))
plt.title("Gradient magnitudes across time for: RNN-Type (forget_bias value)")
#plt.savefig("mygraph.png")
plt.show()


### What do we learn from this?

This particular experiment is an extremely simple surrogate for the actual problem, but it shows a few interesting aspects:

* Is LSTM by construction free of *exploding* gradients too?
* What are other ways of avoiding explosions you can think of?
* Does initialisation (of gates here, but in general) matter a lot?
* Does this look like a solution that can really scale in time? Say, to do credit assignment over years of experience?
* If not, what might be a next step?

See http://proceedings.mlr.press/v37/jozefowicz15.pdf for a more detailed discussion of the effect of the forget gate bias.
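Some intuition for why the forget-gate bias matters: in an LSTM the cell state is updated as $c_t = f_t \odot c_{t-1} + i_t \odot g_t$. Along this direct path, the gradient is therefore scaled by $f_t$ at every step. The NumPy sketch below relies on a simplifying assumption: the forget-gate pre-activation is dominated by its bias $b$, so $f_t \approx \sigma(b)$. `direct_path_scale` is a hypothetical helper, and 14 steps matches `SEQ_LENGTH - 1` in the experiment above.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def direct_path_scale(forget_bias, steps):
    """Approximate gradient scale along the direct cell-state path after `steps` steps.

    Assumes f_t ~= sigmoid(forget_bias) at every step, so the direct-path
    contribution to d c_N / d c_{N-steps} is roughly f ** steps.
    """
    return sigmoid(forget_bias) ** steps

for b in [-2.0, 0.0, 1.0, 2.0, 10.0]:
    print(f"forget bias {b:+5.1f}: f = {sigmoid(b):.4f}, "
          f"f**14 = {direct_path_scale(b, 14):.3e}")
```

This only captures the direct path. The full gradient also flows through the gates and $h_t$, so the measured curves need not match these numbers exactly.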

Canonical Answers:
* If you set forget_bias=10, the gradients will 'explode'.
* No. LSTM still has the problem but it can be a little less bad. Clipping is very often used by default to introduce some robustness.
* Gradient clipping.
* Init matters. Identity hidden_to_hidden can help alleviate gradient 'explosion'.
* No: backpropagating through very long sequences does not scale, since memory and compute grow with the unroll length.
* Transformers.
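Gradient clipping, mentioned in the answers above, rescales the gradients whenever their joint norm exceeds a threshold. The NumPy sketch below shows the clip-by-global-norm rule on plain arrays; `clip_by_global_norm` is a hypothetical helper. In PyTorch, `torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm)` applies essentially the same rule in place to the parameters' `.grad`. In the sine-wave training loop it would be called between `loss.backward()` and `optimizer.step()`.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale a list of gradient arrays so that their joint L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, max_norm / (total_norm + eps))  # never scale gradients up
    return [g * scale for g in grads], total_norm

grads = [np.full((3, 3), 10.0), np.full(3, -10.0)]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
print(norm_before, norm_after)
```

Clipping the joint norm, rather than each tensor separately, preserves the direction of the overall gradient update.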

# Ex. 3    Language Modelling

Now we will train a character level RNN on text data - specifically Shakespeare sonnets. We will reuse the same concepts, such as teacher forcing and different types of RNN cores. 

At the end of the exercise, after you have filled in the TextModel class, you can train the model and see that it generates text that has sonnet structure and learns words. You should focus on the TextModel class implementation and use the provided code for training, visualization and data loading.
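At the character level, teacher forcing looks exactly as it did for the sine wave. The text is mapped to integer ids, the inputs are the sequence and the targets are the same sequence shifted by one. This is a pure-Python sketch with a toy string: `encode_text` is a hypothetical helper, and the real data loading goes through the provided dataset code. The string uses "|" as end of line, as the `string_plot` helper in this notebook does.

```python
def encode_text(text):
    """Build a character vocabulary and return (char_to_id, id_to_char, ids)."""
    vocab = sorted(set(text))
    char_to_id = {c: i for i, c in enumerate(vocab)}
    id_to_char = {i: c for c, i in char_to_id.items()}
    ids = [char_to_id[c] for c in text]
    return char_to_id, id_to_char, ids

text = "To be, or not to be|"
char_to_id, id_to_char, ids = encode_text(text)
# Teacher forcing: at every position the target is the next character.
inputs, targets = ids[:-1], ids[1:]
print(len(char_to_id), "".join(id_to_char[i] for i in targets[:5]))
```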

## Ex 3.1   Analysis of single neurons and gates

We will now look at the individual activations of neurons in a recurrent network. For this to work, you need to have trained a model and completed the previous exercise, in which you expose the network activations.

For a similar analysis, see [this paper](https://arxiv.org/pdf/1506.02078.pdf).

In [0]:
#@title String plot function

def string_plot(chars, values, title=None):
  """
  Given a string "chars" and a vector of numbers "values" of the same length
  displays the string, using "|" as EOL symbol, and colors each character
  background using corresponding value in values
  """
  
  assert len(chars) == len(values)
  
  lines = []
  line = ""
  for char in chars:
    if char != '|':
      line += char
    else:
      line += " "
      lines.append(line)
      line = ""
  lines.append(line)
  
  height = len(lines) 
  width = max(map(len, lines))
    
  data = np.zeros((height, width))
  data[:,:] = np.nan
  
  pos = 0
  for lid, line in enumerate(lines):
    data[lid, :len(line)] = values[pos: pos+len(line)]
    pos = pos+len(line)
    
  assert pos == len(values)
    
  plt.figure(figsize=(width * 0.3, height * 0.3))
  plt.title(title)
  plt.imshow(data.reshape((height, width)), interpolation='none',
             cmap='Reds', alpha=0.5)
  plt.axis('off')
  
  for lid, line in enumerate(lines):
    for cid, char in enumerate(line):
      plt.text(cid-0.2,lid+0.2,char,color='k',fontsize=9)

  plt.show()
  plt.close()

### What kind of neurons can we expect to find?

* Lots of counting neurons (their activity just grows/decreases independently of the input)
* Name neurons - activate around names of people in the play, such as HAMLET: or JOHN OF GAUNT:
* Line width neuron - with activity proportional to the length of the current line (number of characters since the last "|")
* Paragraph length neuron - activity proportional to the length of the paragraph in lines
* Special character neurons - such as coding for probability of generating ":"
* Many, many mixtures of the above
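One simple way to search for, e.g., a line width neuron is to build the feature it should track and correlate it with every hidden unit. This is a NumPy sketch under stated assumptions. The activations are taken to be collected as an array of shape `(T, hidden_size)` with one row per character; how you obtain them depends on your TextModel implementation. `chars_since_eol` and `most_correlated_neuron` are hypothetical helpers, and the activations below are synthetic with one planted unit.

```python
import numpy as np

def chars_since_eol(text, eol="|"):
    """Number of characters since the last end-of-line symbol (0 on the symbol itself)."""
    counts, n = [], 0
    for ch in text:
        n = 0 if ch == eol else n + 1
        counts.append(n)
    return np.array(counts, dtype=float)

def most_correlated_neuron(activations, feature):
    """Index and Pearson correlation of the unit (column) most correlated with `feature`."""
    a = activations - activations.mean(axis=0)
    f = feature - feature.mean()
    denom = np.linalg.norm(a, axis=0) * np.linalg.norm(f)
    corr = np.divide(a.T @ f, denom, out=np.zeros(a.shape[1]), where=denom > 0)
    best = int(np.argmax(np.abs(corr)))
    return best, corr[best]

text = "shall i compare thee|to a summer's day|"
feature = chars_since_eol(text)
rng = np.random.default_rng(0)
activations = rng.standard_normal((len(text), 8))
activations[:, 3] = 0.1 * feature + 0.01 * rng.standard_normal(len(text))  # planted unit
best, r = most_correlated_neuron(activations, feature)
print(best, round(r, 3))
```

A high correlation is only a screening signal and does not prove what a unit encodes.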

Note that if neurons like these do not appear, it does not mean that the network does not "know" these elements. Neural networks are not trained to produce highly selective, single-neuron representations; such neurons are just an empirical observation shared across many domains (cat neurons in visual classifiers, etc.). Knowledge can be represented in many other ways. In particular, even when a concept is represented by a single neuron, the network may also keep a distributed "backup" of the same knowledge elsewhere.


## Ex 3.2   Analysis of the state dynamics

In this exercise, we will visualize the activations in a different way, by projecting them to 2 dimensions, via dimensionality reduction. 

When using different projection techniques, you will see different results. For example, PCA will display the directions with the most variance in the data.
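A PCA projection of the hidden states can be computed directly with an SVD. This is a NumPy sketch assuming the states are stacked into an array of shape `(T, hidden_size)`. `pca_2d` is a hypothetical helper, and the states below are synthetic. scikit-learn's `sklearn.decomposition.PCA(n_components=2)` gives the same projection up to the sign of each component.

```python
import numpy as np

def pca_2d(states):
    """Project states of shape (T, hidden_size) onto their top-2 principal components."""
    centered = states - states.mean(axis=0)
    # Rows of vt are the principal directions, sorted by decreasing singular value.
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    projected = centered @ vt[:2].T                # shape (T, 2)
    explained = (s[:2] ** 2) / np.sum(s ** 2)      # fraction of variance per component
    return projected, explained

rng = np.random.default_rng(0)
states = rng.standard_normal((50, 10))
states[:, 0] *= 10.0  # one high-variance direction
projected, explained = pca_2d(states)
print(projected.shape, explained.round(3))
```

Plotting `projected[:, 0]` against `projected[:, 1]` in time order gives the trajectory of the hidden state.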

### So what am I looking at?

2D projections of high-dimensional spaces always lose a lot of information, but the general structure can still be recovered. Here, both paragraph splits and line splits can be decoded just by looking at the dynamics of the hidden state, which gives more insight into the internals of an RNN. Note that, unlike single-neuron analysis, this looks at the whole hidden state rather than at individual units, so what we observe is more likely to reflect the actual dynamics of this model.

Canonical Answer:
Recurrent nets use the dimensions of the hidden state to encode position in the sequence, like a counter. The trajectory through the hidden space can be thought of as a form of memory: e.g. the RNN could be storing bits of information by setting a dimension to +1 or -1.

# Done.