# CuDNN GRU in TensorFlow: Issues and Example Usage

## Introduction

Gated recurrent units (GRUs) are a type of recurrent neural network designed to have a persistent memory that can capture long-term dependencies. This repository contains a frozen GRU-based waveform generator inspired by Google DeepMind's WaveRNN (2018).

These are the four equations governing the CuDNN version of the GRU cell:

![GRU equations](img/gru_small.png "GRU equations")

The update gate $u$ (sometimes referred to as $z$) controls how much the hidden state is updated at each timestep by weighting the previous hidden state against a new candidate $\tilde h$. The reset gate $r$ controls the creation of the candidate $\tilde h$ by gating the contribution of the previous state.
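
For reference, the equations in the figure can be written out as follows. This is a transcription inferred from the `run_gru_step` implementation later in this document rather than from the image itself, so the notation is an assumption: $x_t$ is the input, $h_{t-1}$ the previous hidden state, $\sigma$ the logistic sigmoid, and $\odot$ element-wise multiplication.

$$
\begin{aligned}
r_t &= \sigma\left(W_r x_t + b_{Wr} + R_r h_{t-1} + b_{Rr}\right) \\
u_t &= \sigma\left(W_u x_t + b_{Wu} + R_u h_{t-1} + b_{Ru}\right) \\
\tilde h_t &= \tanh\left(W_h x_t + b_{Wh} + r_t \odot \left(R_h h_{t-1} + b_{Rh}\right)\right) \\
h_t &= u_t \odot h_{t-1} + (1 - u_t) \odot \tilde h_t
\end{aligned}
$$

Note that in this form the reset gate is applied after the recurrent matrix multiply and its bias, not to $h_{t-1}$ directly.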

## CuDNN GRU in TensorFlow

The CuDNN implementation of the GRU cell is much faster than a GRU written in native TensorFlow with `tf.while_loop` and explicit operations. For the WaveRNN model in this repository, CuDNN GRU yielded a __7.3x__ training speedup over the explicit version, reducing the training time needed for convergence from 8-24 hours to 1-3 hours (depending on the size of the hidden state and other parameters).

CuDNN GRU is supported by TF's `CudnnGRU` class, which has a weights buffer that is saved in a `CudnnOpaqueParamsSaveable` object and a `call(...)` method for running a forward step. The class also has methods for converting between TF and CuDNN canonical weights and biases.



## Issues with TF's `CudnnGRU` --- and workaround

Unfortunately, the CuDNN RNN module is poorly documented, buggy, and difficult to work with. After slotting the CuDNN GRU into our WaveRNN model, I was unable to get the network to train at all, even after trying to manually avoid some of the problematic built-in functions. 

But there is a workaround: the `_forward(...)` method that directly handles the GRU forward pass works just fine. So if we create a weight buffer of the correct shape and pass it to `_forward(...)`, we can run a CuDNN GRU without fussing with the rest of the module.

First let's import TensorFlow, `cudnn_rnn_ops`, etc. (Note: CuDNN requires __tensorflow-gpu__, and `tf.contrib` is only available in TensorFlow 1.x.)

In [None]:
import math
import os
import random
import time
import numpy as np
import tensorflow as tf
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops

Now define and initialize a weight buffer of the correct shape (stored in a variable named `"Kernel"`) and pass it to CuDNN GRU's `_forward(...)` method. The function returns a node corresponding to the GRU outputs.

In [None]:
def run_cudnn_gru(inputs, input_channels, recurrent_size):
    """
    Run the CuDNN GRU cell.
    
    Args:
        inputs: Input tensor, shape [batch_size, timesteps, input_channels]
        input_channels: Number of input channels
        recurrent_size: Size of GRU hidden state
        
    Returns:
        TF node corresponding to GRU outputs.
    """
    batch_size = tf.shape(inputs)[0]

    initial_state = tf.zeros(
        (1, batch_size, recurrent_size), dtype=tf.float32)
    dummy = tf.constant([], dtype=tf.float32)

    with tf.variable_scope("GRU"):
        kernel = tf.get_variable("Kernel",
            shape=[recurrent_size *
                (3 * recurrent_size + 3 * input_channels + 6)],
            initializer=tf.contrib.layers.xavier_initializer())

    # Transpose inputs from batch-major to time-major for CuDNN GRU.
    inputs = tf.transpose(inputs, [1, 0, 2])

    # CuDNN GRU forward pass.
    gru = tf.contrib.cudnn_rnn.CudnnGRU(1, recurrent_size)
    hidden, _ = gru._forward(
        inputs, initial_state, dummy, kernel, training=False)

    # Transpose outputs from time-major to batch-major.
    hidden = tf.transpose(hidden, [1, 0, 2])
    return hidden
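
The length of the `"Kernel"` buffer can be understood by counting parameters. The sketch below assumes a single-layer, unidirectional GRU with three gates, each having one input weight matrix, one recurrent weight matrix, and two bias vectors. The helper name `gru_param_count` is hypothetical and not part of TensorFlow.

```python
def gru_param_count(input_channels, recurrent_size, num_gates=3):
    """Count the floats in a single-layer, unidirectional GRU weight buffer.

    Hypothetical helper for illustration; not a TensorFlow function.
    """
    # One [input_channels, recurrent_size] matrix per gate (W_r, W_u, W_h).
    input_weights = num_gates * input_channels * recurrent_size
    # One [recurrent_size, recurrent_size] matrix per gate (R_r, R_u, R_h).
    recurrent_weights = num_gates * recurrent_size * recurrent_size
    # Two bias vectors per gate: one for the input transform and one for
    # the hidden-state transform.
    biases = 2 * num_gates * recurrent_size
    return input_weights + recurrent_weights + biases


# Same value as recurrent_size * (3 * recurrent_size + 3 * input_channels + 6).
n_params = gru_param_count(input_channels=100, recurrent_size=256)
print(n_params)  # 274944
```

Factoring `recurrent_size` out of the three terms gives the expression used for the variable shape in `run_cudnn_gru`.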

## Explicit TensorFlow implementation of CuDNN GRU

The CuDNN GRU only runs on GPU, so if we want to run inference on a CPU, we can use an explicit TensorFlow implementation based on `tf.while_loop` instead. (An explicit implementation is also useful if we want to experiment with specific changes to the GRU equations that are not supported by the CuDNN cell.)

To extract the explicit weights and biases from the CuDNN GRU weight buffer, we can use the saveable object's `_OpaqueParamsToCanonical()` method. This gives us a list of weights and a list of biases, which can be manually processed to build the individual weights and biases used in the GRU equations ($W_r$, $W_u$, $W_h$, $R_r$, $R_u$, $R_h$, $b_{Wr}$, $b_{Wu}$, $b_{Wh}$, $b_{Rr}$, $b_{Ru}$, $b_{Rh}$):

In [45]:
def run_cudnn_gru_explicit(inputs, input_channels, recurrent_size):
    """
    Run a replica of the CuDNN GRU cell manually, using a slow
    TensorFlow while-loop and explicit GRU equations.
    
    Args:
        inputs: Input tensor, shape [batch_size, timesteps, input_channels].
        input_channels: Number of input channels.
        recurrent_size: Size of GRU hidden state.
        
    Returns:
        TF node corresponding to GRU outputs.
    """
    batch_size = tf.shape(inputs)[0]
    timesteps = tf.shape(inputs)[1]

    tf.get_variable_scope().reuse_variables()
    with tf.variable_scope("GRU"):
        kernel = tf.get_variable("Kernel",
            shape=[recurrent_size *
                (3 * recurrent_size + 3 * input_channels + 6)],
            initializer=tf.contrib.layers.xavier_initializer())
        
    # Extract lists of weights/biases from the CuDNN GRU weight buffer.
    saveable = cudnn_rnn_ops.CudnnGRUSaveable(
        kernel, 1, recurrent_size, input_channels)
    weights, biases = saveable._OpaqueParamsToCanonical()
    
    # Build the individual weights and biases.
    W_r = tf.transpose(weights[0])
    W_u = tf.transpose(weights[1])
    W_h = tf.transpose(weights[2])
    R_r = tf.transpose(weights[3])
    R_u = tf.transpose(weights[4])
    R_h = tf.transpose(weights[5])
    b_Wr = tf.expand_dims(biases[0], 0)
    b_Wu = tf.expand_dims(biases[1], 0)
    b_Wh = tf.expand_dims(biases[2], 0)
    b_Rr = tf.expand_dims(biases[3], 0)
    b_Ru = tf.expand_dims(biases[4], 0)
    b_Rh = tf.expand_dims(biases[5], 0) 

    def condition(i, *_state):
        """Stopping condition."""
        return tf.less(i, timesteps)

    def body(i, state, array):
        """Loop body."""
        # Run a single step of explicit GRU using the above weights/biases.
        state = run_gru_step(state, inputs[:, i, :], 
                             W_r, W_u, W_h, 
                             R_r, R_u, R_h,
                             b_Wr, b_Wu, b_Wh,
                             b_Rr, b_Ru, b_Rh)
        array = array.write(i, state)
        return i + 1, state, array

    initial_state = [
        tf.constant(0),
        tf.zeros((batch_size, recurrent_size), dtype=tf.float32),
        tf.TensorArray(tf.float32, size=timesteps),
    ]
    final_state = tf.while_loop(condition, body, initial_state)
    return tf.transpose(final_state[-1].stack(), [1, 0, 2])


The single forward step is easy to implement using the GRU equations:

![GRU equations](img/gru_small.png "GRU equations")

In [None]:
def run_gru_step(state, inp, W_r, W_u, W_h, R_r, R_u, R_h,
                 b_Wr, b_Wu, b_Wh, b_Rr, b_Ru, b_Rh):
    """
    A carbon copy of the CuDNN-GRU forward-step.
    
    Args:
        state:            Previous hidden state,
                          shape [batch_size, recurrent_size].
        inp:              Input tensor, 
                          shape [batch_size, input_channels].
        W_r, W_u, W_h:    Weights applied to input,
                          each with shape [input_channels, recurrent_size].
        R_r, R_u, R_h:    Weights applied to previous hidden state,
                          each with shape [recurrent_size, recurrent_size].
        b_Wr, b_Wu, b_Wh: Biases applied to input transform,
                          each with shape [1, recurrent_size].
        b_Rr, b_Ru, b_Rh: Biases applied to hidden state transform,
                          each with shape [1, recurrent_size].
    Returns:
        New hidden state, shape [batch_size, recurrent_size]
    """
    # Individual matrix multiplies are shown for clarity; these would be 
    # more computationally efficient in block form.
    X_r = tf.matmul(inp, W_r) + b_Wr
    X_u = tf.matmul(inp, W_u) + b_Wu
    X_h = tf.matmul(inp, W_h) + b_Wh
    H_r = tf.matmul(state, R_r) + b_Rr
    H_u = tf.matmul(state, R_u) + b_Ru
    H_h = tf.matmul(state, R_h) + b_Rh
    
    r = tf.nn.sigmoid(X_r + H_r)
    u = tf.nn.sigmoid(X_u + H_u)
    candidate = tf.tanh(X_h + r * H_h)
    state = state * u + candidate * (1 - u)

    return state
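
The comment in `run_gru_step` mentions that the six matrix multiplies would be more efficient in block form. The following is a minimal NumPy sketch of that idea, not the repository's implementation: the three input matrices are concatenated into one, as are the three recurrent matrices, so each step needs two multiplies instead of six. The function names and the small random test shapes are assumptions chosen for illustration.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_step_separate(state, inp, W, R, b_W, b_R):
    """Six separate multiplies; W, R, b_W, b_R are lists ordered (r, u, h)."""
    X = [inp @ W[k] + b_W[k] for k in range(3)]
    H = [state @ R[k] + b_R[k] for k in range(3)]
    r = sigmoid(X[0] + H[0])
    u = sigmoid(X[1] + H[1])
    candidate = np.tanh(X[2] + r * H[2])
    return state * u + candidate * (1 - u)


def gru_step_block(state, inp, W, R, b_W, b_R):
    """Same step with two multiplies on concatenated (block) matrices."""
    # In practice the concatenation would be done once, outside the loop.
    W_all = np.concatenate(W, axis=1)     # [input_channels, 3 * recurrent_size]
    R_all = np.concatenate(R, axis=1)     # [recurrent_size, 3 * recurrent_size]
    bW_all = np.concatenate(b_W, axis=1)  # [1, 3 * recurrent_size]
    bR_all = np.concatenate(b_R, axis=1)  # [1, 3 * recurrent_size]
    X_r, X_u, X_h = np.split(inp @ W_all + bW_all, 3, axis=1)
    H_r, H_u, H_h = np.split(state @ R_all + bR_all, 3, axis=1)
    r = sigmoid(X_r + H_r)
    u = sigmoid(X_u + H_u)
    candidate = np.tanh(X_h + r * H_h)
    return state * u + candidate * (1 - u)


# Small random problem to check that both forms agree.
rng = np.random.RandomState(0)
batch, channels, hidden = 2, 5, 4
W = [rng.randn(channels, hidden) for _ in range(3)]
R = [rng.randn(hidden, hidden) for _ in range(3)]
b_W = [rng.randn(1, hidden) for _ in range(3)]
b_R = [rng.randn(1, hidden) for _ in range(3)]
state = rng.randn(batch, hidden)
inp = rng.randn(batch, channels)

out_separate = gru_step_separate(state, inp, W, R, b_W, b_R)
out_block = gru_step_block(state, inp, W, R, b_W, b_R)
print(np.allclose(out_separate, out_block))
```

The same concatenation could be applied to the TensorFlow version with `tf.concat` and `tf.split`; the separate form is kept in `run_gru_step` because it maps one-to-one onto the equations.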

## Sanity check: compare both models with the same input

Let's do a forward pass using a random input array to confirm that both implementations give the same result.
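
Because the two implementations run different float32 kernels, their outputs are expected to agree only up to rounding error, so a tolerance-based comparison is more reliable than reading printed arrays. The sketch below is one way to do this with NumPy. The helper name `outputs_match` and the tolerances are assumptions, and the sample values are copied from the first row of the printed outputs in this section.

```python
import numpy as np


def outputs_match(a, b, rtol=1e-4, atol=1e-6):
    """Return (match, max_abs_diff) for two arrays of the same shape.

    Hypothetical helper; the tolerances are illustrative, not tuned.
    """
    a = np.asarray(a, dtype=np.float32)
    b = np.asarray(b, dtype=np.float32)
    max_abs_diff = float(np.max(np.abs(a - b)))
    return bool(np.allclose(a, b, rtol=rtol, atol=atol)), max_abs_diff


# First three values of each printed output array.
cudnn_sample = np.array(
    [1.0947134e-02, 5.3554983e-04, -3.4574396e-03], dtype=np.float32)
explicit_sample = np.array(
    [1.0947133e-02, 5.3554960e-04, -3.4574401e-03], dtype=np.float32)

match, max_abs_diff = outputs_match(cudnn_sample, explicit_sample)
print(match, max_abs_diff)
```

In the session code, the same helper could be called on the full `outputs` and `outputs_explicit` arrays once both have been evaluated.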

In [47]:
RECURRENT_SIZE = 256
BATCH_SIZE = 1
TIMESTEPS = 10
INPUT_CHANNELS = 100
input_shape = (BATCH_SIZE, TIMESTEPS, INPUT_CHANNELS)

inp = tf.placeholder(tf.float32, shape=input_shape)

# Outputs from CuDNN GRU.
outputs = run_cudnn_gru(inp, INPUT_CHANNELS, RECURRENT_SIZE)

# Outputs from explicit-TF GRU.
outputs_explicit = run_cudnn_gru_explicit(inp, INPUT_CHANNELS, RECURRENT_SIZE)

init = tf.global_variables_initializer()
np.random.seed(9999)
inputs = np.random.random_sample(input_shape)

with tf.Session() as sess:
    init.run()
    outputs = sess.run(outputs, feed_dict={inp: inputs})
    print("Outputs of forward pass using CuDNN GRU cell:\n")
    print("Shape: {}".format(outputs.shape))
    print(outputs)
    print("\n")
    outputs_explicit = sess.run(outputs_explicit, feed_dict={inp: inputs})
    print("Outputs of forward pass using explicit TensorFlow operations and `tf.while_loop`:\n")
    print("Shape: {}".format(outputs_explicit.shape))
    print(outputs_explicit)

Outputs of forward pass using CuDNN GRU cell:

(1, 10, 256)
[[[ 1.0947134e-02  5.3554983e-04 -3.4574396e-03 ... -3.3140765e-05
   -3.8697512e-03  4.5041819e-03]
  [ 9.0450915e-03 -3.1024218e-03  4.7755952e-04 ...  1.0971370e-03
   -5.9758904e-03  7.8519406e-03]
  [ 1.1812703e-02 -7.2304592e-03  3.4528917e-03 ... -3.6396014e-03
   -6.8984684e-03  1.1104198e-02]
  ...
  [ 1.5333121e-02 -8.1954943e-03 -7.7663868e-04 ... -8.4617957e-03
   -4.9419687e-03  3.2436673e-03]
  [ 1.8759312e-02 -8.8395784e-03  1.1147156e-04 ... -9.3853874e-03
   -4.9264096e-03  4.8976815e-03]
  [ 1.6782669e-02 -7.2793141e-03 -8.4629748e-03 ... -5.4480475e-03
   -2.0670865e-03  6.7125186e-03]]]


Outputs of forward pass using explicit TensorFlow operations and `tf.while_loop`:

(1, 10, 256)
[[[ 1.0947133e-02  5.3554960e-04 -3.4574401e-03 ... -3.3140306e-05
   -3.8697512e-03  4.5041828e-03]
  [ 9.0450915e-03 -3.1024241e-03  4.7755882e-04 ...  1.0971368e-03
   -5.9758904e-03  7.8519396e-03]
  [ 1.1812702e-02 -7.23046