# **Problem 1 (MLP)**

First, we must import the necessary JAX library functions, namely random, vmap, grad, jit, ravel_pytree, and the Adam Optimizer.  In addition, we import the partial decorator and plotting tools used for visualization.

In [None]:
import jax.numpy as np
from jax import random, vmap, grad, jit
from jax.flatten_util import ravel_pytree
from jax.experimental import optimizers

import itertools
from functools import partial
from tqdm import trange
import numpy.random as npr
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import griddata

Now, we can develop a class which can be used as a Multi-Layer Perceptron Network.  We use this as a feed-forward network.

In [None]:
class MLP():
  def __init__(self, X, y, layers, init_method = 'glorot', rng_key = random.PRNGKey(0)):
    # Normalize data
    self.Xmean, self.Xstd = X.mean(0), X.std(0)
    self.Ymean, self.Ystd = y.mean(0), y.std(0)
    X = (X - self.Xmean)/self.Xstd
    y = (y - self.Ymean)/self.Ystd

    # Store the normalized trainind data
    self.X = X
    self.y = y
    self.layers = layers

    # Use stax to set up network initialization and evaluation functions
    self.net_init, self.net_apply = self.init_MLP(init_method)
    
    # Initialize parameters, not committing to a batch shape
    self.net_params = self.net_init(rng_key, layers)
                
    # Use optimizers to set optimizer initialization and update functions
    self.opt_init, \
    self.opt_update, \
    self.get_params = optimizers.sgd(1e-4)
    self.opt_state = self.opt_init(self.net_params)

    # Logger to monitor the loss function
    self.loss_log = []
    self.itercount = itertools.count()

    # Logger to monitor error function

    self.error_log = []

  """
  First, we must initialize our MLP model.  We use a glorot initialization for 
  a uniform or a normal distribution for the weights. The init_W function 
  initializes our parameters (weights) according to each of the type of
  intialization.  We then use the distribution of weights to initialize the
  weights and biases for each layer of the MLP.
  """

  def init_MLP(self, method = 'glorot'):
    # Define init function
    def _init(rng_key, layers):
        # Define methods for initializing the weights
        if method == 'glorot':
          def init_W(rng_key, size):
            in_dim = size[0]
            out_dim = size[1]
            glorot_stddev = 1. / np.sqrt((in_dim + out_dim) / 2.)
            return glorot_stddev*random.normal(rng_key, (in_dim, out_dim))
        elif method == 'random':
          def init_W(rng_key, size):
            in_dim = size[0]
            out_dim = size[1]
            return random.uniform(rng_key, (in_dim, out_dim))
        # Perform initialization
        weights = []
        biases = []
        num_layers = len(layers) 
        for l in range(0,num_layers-1):
            rng_key, _ = random.split(rng_key)
            W = init_W(rng_key, size=[layers[l], layers[l+1]])
            b = np.zeros((1,layers[l+1]))
            weights.append(W)
            biases.append(b)  
        params = weights, biases
        return params
    
    """
    Next, we define the apply function, which creates the output through 
    the layers of the MLP, given a particular input.  The matrix H applies the
    hyperbolic tangent activation function to the (H dot W) + b.  Once we get
    through all layers, we can then take the last entry in the weights and
    biases matrices and perform the same activation to output what the last
    layer will yield.
    """

    # Define apply function
    def _apply(params, input):
        H = input
        weights, biases = params
        num_layers = len(self.layers)
        for l in range(0,num_layers-2):
            W = weights[l]
            b = biases[l]
            H = np.tanh(np.add(np.dot(H, W), b))
        W = weights[-1]
        b = biases[-1]
        H = np.add(np.dot(H, W), b)
        return H.flatten()
    return _init, _apply

  """
  Here, we use the standard formula for squared error to calculate the loss
  for each element in a batch.  We take as inputs the current parameters and 
  batch of data, and output the loss of a single value.
  """

  def per_example_loglikelihood(self, params, batch):
    X, y = batch
    y_pred = self.net_apply(params, X)
    loss = (y - y_pred)**2
    return loss

  def per_example_logerror(self, params, batch):
    X, y = batch
    y_pred = self.net_apply(params, X)
    loss = (y - y_pred)**2
    return loss

  """
  Now, we can leverage the per-batch loss calculation to determine the loss of 
  the entire batch as it passes through the MLP.  We take as inputs the current
  parameters and batch of data, and output the loss of the entire batch.
  """ 

  def loss(self, params, batch):
    # Implementation #1
    pe_loss = lambda x: self.per_example_loglikelihood(params, x)
    loss = np.sum(vmap(pe_loss)(batch))
    return loss

  def error(self, params, batch):
    X, y = batch
    y_pred = self.net_apply(params, X)
    error = np.linalg.norm(y - y_pred, 2)/np.linalg.norm(y, 2)
    return error

  """
  We can now update parameters based on the Adam Optimizer's update rule.  This
  means that we are updating parameters with a variable learning rate.  We use
  the partial decorator to indicate we'd like this function to utilize a
  hardware accelerator to compile the code faster.  We take as inputs the 
  iteration number, the current optimized state of the network, and the current
  batch number.
  """

  # Define a compiled update step
  @partial(jit, static_argnums=(0,))
  def step(self, i, opt_state, batch):
      params = self.get_params(opt_state)
      g = grad(self.loss)(params, batch)
      return self.opt_update(i, g, opt_state)

  """
  This function is used to create each batch from a given data set and output it
  in a stream.  We take as inputs the dimension of the training data, the number
  of batches, and the desired batch size.  We output the training data as the 
  desired number of batches whose dimensions match the desired size.
  """

  def data_stream(self, n, num_batches, batch_size):
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(n)
      for i in range(num_batches):
        batch_idx = perm[i*batch_size:(i+1)*batch_size]
        yield self.X[batch_idx, :], self.y[batch_idx]

  """
  This is how we train the model.  Using the data_stream function to create
  batches, we can iterate through each epoch to determine how the parameters 
  should be updated using our step function.  We take as inputs the number of 
  epochs and the desired batch size.  We output nothing, but after this function
  has run, we have generated the optimized state for our MLP.
  """

  def train(self, num_epochs = 100, batch_size = 64):   
    n = self.X.shape[0]
    num_complete_batches, leftover = divmod(n, batch_size)
    num_batches = num_complete_batches + bool(leftover) 
    batches = self.data_stream(n, num_batches, batch_size)
    pbar = trange(num_epochs)
    for epoch in pbar:
      for _ in range(num_batches):
        batch = next(batches)
        self.opt_state = self.step(next(self.itercount), self.opt_state, batch)
      self.net_params = self.get_params(self.opt_state)
      loss_value = self.loss(self.net_params, batch)
      self.loss_log.append(loss_value)
      # error_value = self.error(self.net_params, batch)
      # self.error_log.append(error_value)
      pbar.set_postfix({'Loss': loss_value})

  """
  This function is used to generate the prediction our model produces given a 
  particular test data set.  We take as inputs the current parameters of our 
  model and the test data, and we output the prediction our model produces using
  the inputs.
  """

  def predict(self, params, X_star):
    X_star = (X_star - self.Xmean)/self.Xstd
    pred_fn = lambda x: self.net_apply(params, x)
    y_pred = vmap(pred_fn)(X_star)
    y_pred = y_pred*self.Ystd + self.Ymean
    return y_pred

  """
  This function is used to compute the activation of each layer in the the MLP
  using the hyperbolic tangent.  This is very similar to how we initialize the 
  network.  We take as inputs the current parameters and a test data set, and we
  output the matrix of each layer's output using the test data.
  """

  def compute_activations(self, params, X_star):
    X_star = (X_star - self.Xmean)/self.Xstd
    def MLP_pass(params, input):
      H = input
      H_list = []
      H_list.append(H)
      weights, biases = params
      num_layers = len(self.layers)
      for l in range(0,num_layers-2):
          W = weights[l]
          b = biases[l]
          H = np.tanh(np.add(np.dot(H, W), b))
          H_list.append(H)
      W = weights[-1]
      b = biases[-1]
      H = np.add(np.dot(H, W), b)
      H_list.append(H)
      return H_list
    # Get predictions
    pred_fn = lambda x: MLP_pass(params, x)
    H_list = vmap(pred_fn)(X_star)
    return H_list

Here we have the function we'd like to develop a predictive model for.

In [None]:
def f(x):
  x1, x2 = x[0], x[1]
  y = np.sqrt(x1**2 + x2**2)
  return y

This is where we generate training data and test data.  The dimension of our data is 2 with lb and ub serving as how we can vectorize our input training/test data distributions to be sampled from.

In [None]:
rng_key = random.PRNGKey(0)

d = 2
lb = -2.0*np.ones(d)
ub = 2.0*np.ones(d)
n = 2000
noise = 0.1

# Create training data
X = lb + (ub-lb)*random.uniform(rng_key, (n, d))
y = vmap(f)(X)
y = y + noise*y.std(0)*random.normal(rng_key, y.shape)

# Create test data
nn = 50
xx = np.linspace(lb[0], ub[0], nn)
yy = np.linspace(lb[1], ub[1], nn)
XX, YY = np.meshgrid(xx, yy)
X_star = np.concatenate([XX.flatten()[:,None], YY.flatten()[:,None]], axis = 1)
y_star = vmap(f)(X_star)

Now we define the dimensions of each layer in the model, and using that, we can define the MLP model we'd like to train.

In [None]:
layers = [2, 32, 32, 32, 32, 1]
init_method = 'glorot'
model = MLP(X, y, layers, init_method, rng_key)

Now we train the model using the generated training data.

In [None]:
model.train(num_epochs = 2000, batch_size = 128)

After our model is trained, we can output the optimized parameters and the prediction of our test data.

In [None]:
opt_params = model.net_params
y_pred = model.predict(opt_params, X_star)

Finally, we can generate a visualization of our model's prediction using the optimal parameters and test data.  We first plot the surface that our model uses
as a prediction for the function.  Then, we plot the Ground Truth as a function of the predicted values of our model (to gauge model performance in testing).  Lastly, we plot the log(loss) of our model over the course of training.

In [None]:
Yplot = griddata(X_star, y_pred.flatten(), (XX, YY), method='cubic')

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(X[:,0], X[:,1], y, 'r.', ms = 6, alpha = 0.5)
ax.plot_surface(XX, YY, Yplot, alpha = 0.8)
# Hide grid lines
ax.grid(False)
# Hide axes ticks
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
ax.set_zlabel('$y$')

plt.figure()
plt.plot(y_pred, y_star, 'r.', ms = 8, alpha = 0.5)
plt.plot(y_star, y_star, 'k--', lw = 3, alpha = 0.5)
plt.xlabel('Prediction')
plt.ylabel('Ground truth')

plt.figure()
plt.plot(model.loss_log)
plt.yscale('log')
plt.xscale('log')

Now, we can take all the data for our model and compute the activations of the optimized model.  We can also compute the gradient of the loss with respect to the parameters and input data.

In [None]:
full_batch = model.X, model.y
activations = model.compute_activations(opt_params, X_star)
weight_grads, bias_grads = grad(model.loss)(opt_params, full_batch)

Now, we can estimate the probability density function for the activation layers and the probability density function for the gradient of the weights.  This is done for each of the layers in our MLP.  Being honest, I am a bit lost on what exactly this cell is doing.

In [None]:
from scipy.stats import kde

def kde1d(x):
  nn = 1000
  xi = np.linspace(x.min(), x.max(), nn)
  k = kde.gaussian_kde(x)
  zi = k(xi)
  return xi,zi

plt.figure(figsize=(8,4))
plt.subplot(1,2,1)
for i in range(1,len(activations)-1):
  x, z = kde1d(activations[i].flatten())
  plt.plot(x, z, label = 'Layer %d' % (i))
plt.legend()
plt.xlabel('Activation')
plt.ylabel('Density')
plt.tight_layout()
plt.subplot(1,2,2)
for i in range(1,len(weight_grads)-1):
  x, z = kde1d(weight_grads[i].flatten())
  plt.plot(x, z, label = 'Layer %d' % (i))
plt.legend()
plt.xlabel('Gradient')
plt.ylabel('Density')
plt.tight_layout()

# **Problem 1 (CNN)**

First, we again import the necessary tools in the JAX library.  Notably, we also include the useful built in functions specifically for Convolutional Neural Networks.  We also import pytorch (for data acquisition), the Adam Optimizer, necessary plotting utilities, and the partial decorator.

In [None]:
import jax.numpy as np
from jax import random
from jax.experimental import stax
from jax.experimental.stax import BatchNorm, Conv, Dense, Flatten, Relu, Softmax
from jax.experimental import optimizers
from jax import jit, grad
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import torch
from torchvision import datasets, transforms

import numpy as onp
from tqdm import tqdm
import itertools
from functools import partial
import time

Here we first define a function whose input is the labels and number of classes in our CNN, and whose output is a one hot encoding of the labels.  The second function is the proper initialization for a convolutional network with 5 Convolutions, each followed by a normalization operation, and a Rectified Linear Unit Activation function.  We have a densely connected layer which computes the class scores and then apply a Softmax activation function to generate our output.  The output will be a classifcation of the input according to the classes we have given.  This is all done in a serial manner.

In [None]:
def one_hot(labels, num_classes, dtype=np.float32):
  return np.array(labels[:, None] == np.arange(num_classes), dtype)

def mnist_simple_cnn(num_classes):
  init_fun, conv_net = stax.serial(Conv(out_chan=32, filter_shape=(5, 5), strides=(2, 2), padding="SAME"),
                                   BatchNorm(), Relu,
                                   Conv(out_chan=32, filter_shape=(5, 5), strides=(2, 2), padding="SAME"),
                                   BatchNorm(), Relu,
                                   Conv(out_chan=10, filter_shape=(3, 3), strides=(2, 2), padding="SAME"),
                                   BatchNorm(), Relu,
                                   Conv(out_chan=10, filter_shape=(3, 3), strides=(2, 2), padding="SAME"), 
                                   BatchNorm(), Relu,
                                   Flatten,
                                   Dense(num_classes),
                                   Softmax)
  return init_fun, conv_net

Now we define the class which will be used to generate and train a Convolutional Neural Network which can classify input data into user given classes.

In [None]:
class CNNclassifier:

  """
  First, we must intialize the network.  We need
  to keep track of the number of classes.  We use the
  above functions to the initialize the network.  We then
  initialize the parameters for the network and instantiate
  the Adam optimizer to be used for updating parameters.
  Lastly, we set up a counter for iterations, and arrays to
  track accuracy and loss.
  """

  # Initialize the class
  def __init__(self, num_classes, rng_key):
    # Store number of classes
    self.num_classes = num_classes

    # Use stax to set up network initialization and evaluation functions
    self.net_init, self.net_apply = mnist_simple_cnn(self.num_classes)
    
    # Initialize parameters, not committing to a batch shape
    _, self.net_params = self.net_init(rng_key, (-1, 1, 28, 28))
                
    # Use optimizers to set optimizer initialization and update functions
    self.opt_init, \
    self.opt_update, \
    self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, 
                                                                    decay_steps=100, 
                                                                    decay_rate=0.99))
    self.opt_state = self.opt_init(self.net_params)

    # Logger
    self.itercount = itertools.count()
    self.log_acc_train = []
    self.log_acc_test = [] 
    self.train_loss = []
  
  """
  We take as input the parameters and the current batch
  of data, and output the mean squared error loss.  
  """

  # Define a simple mean squared-error loss
  def loss(self, params, batch):
      inputs, targets = batch
      predictions = self.net_apply(params, inputs)
      loss = -np.sum(targets*np.log(predictions + 1e-8))
      return loss
    
  """
  We take as input the iteration number, the current
  state of the Adam optimizer, and the current batch
  of data.  We output the updated state of the Adam
  optimizer.  The partial decorator indicates we'd
  like to prioritize the speed of this function's 
  compilation and runtime using software and hardware
  acceleration.
  """

  # Define a compiled update step
  @partial(jit, static_argnums=(0,))
  def step(self, i, opt_state, batch):
      params = self.get_params(opt_state)
      g = grad(self.loss)(params, batch)
      return self.opt_update(i, g, opt_state)

  """
  Now, we can compute the accuracy of our model
  on a given dataset with it's curent parameters.
  This is done using the a one-hot comparison
  for the each batch in a dataset/dataloader.  I'm
  not sure but I would guess that a dataloader functions
  similar to a stream in java.
  """

  def accuracy(self, params, data_loader):
    """ Compute the accuracy for a provided dataloader """
    acc_total = 0
    for batch_idx, (inputs, targets) in enumerate(data_loader):
        inputs = np.array(inputs)
        targets = one_hot(np.array(targets), self.num_classes)
        target_labels = np.argmax(targets, axis=1)
        predicted_labels = np.argmax(self.net_apply(params, inputs), axis=1)
        acc_total += np.sum(predicted_labels == target_labels)
    return acc_total/len(data_loader.dataset)
  
  """
  Here we train our CNN.  This is done by
  performing our accelerated step function
  for each batch in a dataset, for a given
  number of epochs.  We also compute the training
  and test accuracy for classification in 
  each epoch.  We take as inputs the streams of 
  training data and test data, as well as the 
  number of epochs.  At the termination of this
  function, our model will be trained and tested,
  with data characterizing its accuracy, and the set
  of optimal parameters for the CNN.
  """

  # Optimize parameters in a loop
  def train(self, train_loader, test_loader, num_epochs = 1000):
    for epoch in range(num_epochs):
      start_time = time.time()
      # Run epoch
      for batch_idx, (inputs, targets) in enumerate(train_loader):
          batch = np.array(inputs), one_hot(np.array(targets), self.num_classes)
          self.opt_state = self.step(next(self.itercount), self.opt_state, batch)
      epoch_time = time.time() - start_time
      # Compute training and validation accuracy
      self.net_params = self.get_params(self.opt_state)  
      loss = self.loss(self.net_params, batch)
      train_acc = self.accuracy(self.net_params, train_loader)
      test_acc = self.accuracy(self.net_params, test_loader)
      self.train_loss.append(loss)
      self.log_acc_train.append(train_acc)
      self.log_acc_test.append(test_acc)
      print("Epoch {} | Time: {:0.2f} | Train Acc.: {:0.3f}% | Test Acc.: {:0.3f}%".format(epoch+1, epoch_time,
                                                                  train_acc, test_acc))

Here we generate the training data stream and test data stream.  I'm not really sure how this works, but I think it is more important that I understand what it is doing.  This is essentially how we take batches sequentially from input datasets and can pull from them iteratively.

In [None]:
# Set the PyTorch Data Loader for the training & test set
batch_size = 128

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

Here, we define the number of classes for classification and initialize the Convolutional Neural Network accordingly.

In [None]:
num_classes = 10
init_key = random.PRNGKey(0)
model = CNNclassifier(num_classes, init_key)

Here, we train our CNN using the training data stream and test data stream.  We dont need many epochs for this task because the the data streams have a large amount of data.  It is possible that training for the same number of epochs as some of our MLP's (i.e. on the order of 1000 epochs), that overfitting could occur.

In [None]:
model.train(train_loader, test_loader, num_epochs = 10)

Now that we have trained the CNN, we can visualize the log(loss) over the number of epochs, as well as the model accuracy over the number of epochs.

In [None]:
plt.figure()
plt.subplot(1,2,1)
plt.plot(model.train_loss)
plt.yscale('log')
plt.xlabel('Epoch')
plt.ylabel('Training loss')
plt.subplot(1,2,2)
plt.plot(model.log_acc_train, label = 'Training')
plt.plot(model.log_acc_test, label = 'Testing')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Accuracy')

We can then visualize the model applying a classification to the data in the test stream.  In this case, we are attempting to classify handwritten numbers.

In [None]:
# Visualize some predictions
plt.figure(figsize = (8,8))
for batch_idx, (image, label) in enumerate(test_loader): 
  # Perform predictions
  opt_params = model.net_params
  predicted_labels = np.argmax(model.net_apply(opt_params, np.array(image)), axis=-1)
  plt.imshow(image[-1].reshape(28,28))
  plt.title('This is a %d' % (predicted_labels[-1]))
  plt.pause(0.5)

# **Problem 1 (RNN)**

Again, we first import the necessary JAX library functions, the Adam optimizer, partial decorators, and the appropriate graphing tools.  These will be used to generate, train, and test a Recurrent Neural Network, whose output accuracy/error can be visualized.

In [None]:
import jax.numpy as np
from jax import random, vmap, grad, jit
from jax.experimental import optimizers
from jax.ops import index_update, index

import itertools
from functools import partial
from tqdm import trange
import numpy.random as npr
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import griddata

We define a function to generate lags.  We take as input a dataset and a lag value.  We output a 3D array of the inputs and empirical outputs X,Y whose index in a particular layer is "lagged" according to the lag value.

In [None]:
def create_lags(data, L):
    N = data.shape[0] - L
    D = data.shape[1]
    X = np.zeros((L, N, D))
    Y = np.zeros((N, D))
    for i in range(0,N):
        X = index_update(X, index[:,i,:], data[i:(i+L), :])
        Y = index_update(Y, index[i,:], data[i+L, :])
    return X, Y

Here we develop the class that can generate, train, and test a Reccuring Neural Network, whose accuracy can be quantified accordingly.

In [None]:
class RNN():

  """
  First, we must initialize the RNN.  We take as inputs
  the dataset, the number of lags, and the number of hidden
  layers.  We first normalize the dataset according to its
  mean and standard deviation.  We then create our lagged
  normalized dataset using the above function.  We initialize
  the network, apply rule, and parameters.  Then, we designate
  the Adam optimizer and its state, as well as an array to track
  the loss and a counter variable.
  """
  def __init__(self, dataset, num_lags, hidden_dim, rng_key = random.PRNGKey(0)):
    # Normalize across data-points dimension
    self.mean, self.std = dataset.mean(0), dataset.std(0)
    dataset = (dataset - self.mean)/self.std

    # Create the lagged normalized trainind data
    # X: L x N x D
    # Y: N x D
    self.X, self.Y = create_lags(dataset, num_lags)
    self.X_dim = self.X.shape[-1]
    self.Y_dim = self.Y.shape[-1]
    self.hidden_dim = hidden_dim
    self.num_lags = num_lags

    # Initialization and evaluation functions
    self.net_init, self.net_apply = self.init_RNN()
    
    # Initialize parameters, not committing to a batch shape
    self.net_params = self.net_init(rng_key)
                
    # Use optimizers to set optimizer initialization and update functions
    self.opt_init, \
    self.opt_update, \
    self.get_params = optimizers.adam(1e-3)
    self.opt_state = self.opt_init(self.net_params)

    # Logger to monitor the loss function
    self.loss_log = []
    self.itercount = itertools.count()

  """
  Here we initialize the RNN.  We use the glorot
  initialization according the number of hidden
  layers and the dimension of the lagged normalized
  input data X.  We also initialize the biases, weights,
  and the output layer.  We output all of these.  We
  also define the apply function for our RNN.  This
  takes as input the current parameters and some input
  data.  We apply a hyperbolic activation function to
  (H*W + input*U + b).  This performs the application of 
  the RNN to input data, according to the activation
  functions.
  """

  def init_RNN(self):
    # Define init function
    def _init(rng_key):
        # Define methods for initializing the weights
        def glorot_normal(rng_key, size):
          in_dim = size[0]
          out_dim = size[1]
          glorot_stddev = 1. / np.sqrt((in_dim + out_dim) / 2.)
          return glorot_stddev*random.normal(rng_key, (in_dim, out_dim))
        # Inputs
        U = glorot_normal(rng_key, (self.X_dim, self.hidden_dim))
        b = np.zeros(self.hidden_dim)
        # Transition dynamics
        W = np.eye(self.hidden_dim)
        # Outputs
        V = glorot_normal(rng_key, (self.hidden_dim, self.Y_dim))
        c = np.zeros(self.Y_dim)
        return (U, b, W, V, c)
    # Define apply function
    def _apply(params, input):
        U, b, W, V, c = params
        H = np.zeros((input.shape[1], self.hidden_dim))
        for i in range(self.num_lags):
            H = np.tanh(np.matmul(H, W) + np.matmul(input[i,:,:], U) + b)       
        H = np.matmul(H, V) + c
        return H
    return _init, _apply

  """
  Here we calculate the mean squared error loss on an input
  of the current parameters and the current batch of data.
  """

  def loss(self, params, batch):
    X, y = batch
    y_pred = self.net_apply(params, X)
    loss = np.mean((y - y_pred)**2)
    return loss

  """
  Here we definte the step function of our optimizer.
  We take as inputs the iteration number, current Adam
  optimizer state, and current batch of data.  We then
  output the updated state of the Adam optimizer using
  the given data.  We use the partial decorator to indicate
  that we will be using this function many times in rapid
  succession, so it is best to prioritize the speed of its
  compilation and processing using hardware and software
  accleration.
  """

  # Define a compiled update step
  @partial(jit, static_argnums=(0,))
  def step(self, i, opt_state, batch):
      params = self.get_params(opt_state)
      g = grad(self.loss)(params, batch)
      return self.opt_update(i, g, opt_state)

  """
  This function is used to create each batch from a given data set and output it
  in a stream.  We take as inputs the dimension of the training data, the number
  of batches, and the desired batch size.  We output the training data as the 
  desired number of batches whose dimensions match the desired size.
  """

  def data_stream(self, n, num_batches, batch_size):
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(n)
      for i in range(num_batches):
        batch_idx = perm[i*batch_size:(i+1)*batch_size]
        yield self.X[:, batch_idx, :], self.Y[batch_idx, :]

  """
  This is how we train the RNN.  We perform the accelerated
  update rule on each batch in the training data stream, for
  a given number of epochs.  We also update our parameters and
  loss at each epoch.  We take as input te 
  """

  def train(self, num_epochs = 100, batch_size = 64):   
    n = self.X.shape[1]
    num_complete_batches, leftover = divmod(n, batch_size)
    num_batches = num_complete_batches + bool(leftover) 
    batches = self.data_stream(n, num_batches, batch_size)
    pbar = trange(num_epochs)
    for epoch in pbar:
      for _ in range(num_batches):
        batch = next(batches)
        self.opt_state = self.step(next(self.itercount), self.opt_state, batch)
      self.net_params = self.get_params(self.opt_state)
      loss_value = self.loss(self.net_params, batch)
      self.loss_log.append(loss_value)
      pbar.set_postfix({'Loss': loss_value})

  """
  Finally, we have a function that can take in the current
  parameters and some input test data to make a prediction
  using the RNN.  We again use the partial decorator to speed
  up the testing of our RNN.
  """

  @partial(jit, static_argnums=(0,))
  def predict(self, params, inputs):
    Y_pred = self.net_apply(params, inputs)
    return Y_pred

Here we define a function that we'd like to model using an RNN.

In [None]:
def f(x):
    f = np.sin(np.pi*t)
    return f

Now, we can generate the training and test data sets using linearly spaced inputs and the function output as validation.  We use 2/3 of the data for
training, which is a reasonable split.

In [None]:
rng_key = random.PRNGKey(0)
noise = 0.0

t = np.arange(0,10,0.1)[:,None]
dataset = f(t)
dataset = dataset + dataset.std(0)*noise*random.normal(rng_key, dataset.shape)

# Use 2/3 of all data as training Data
train_size = int(len(dataset) * (2.0/3.0))
train_data = dataset[0:train_size,:]

Here, we create our Recurrent Neural Network with 5 lags and 4 hidden layers, with the generated training data.

In [None]:
# Model creation
num_lags = 5
hidden_dim = 4
model = RNN(train_data, num_lags, hidden_dim, rng_key)

We train the model on 10,000 epochs and designate the batch size to be 128.

In [None]:
model.train(num_epochs = 10000, batch_size = 128)

Now that our model is trained according to the generated training data, we can extract the optimized parameters.  We can then make a prediction of the normalized data according the parameters and the number of lags.  We output the de-normalized predictions, as well as teh L2 error for our testing.

In [None]:
opt_params = model.net_params
# One-step ahead prediction (normalized)
N, D = dataset.shape
pred = np.zeros((N-num_lags, D))
X_tmp =  model.X[:,0:1,:]
for i in trange(N-num_lags):
    pred = index_update(pred, index[i:i+1], model.net_apply(opt_params, X_tmp))
    X_tmp = index_update(X_tmp, index[:-1,:,:], X_tmp[1:,:,:])
    X_tmp = index_update(X_tmp, index[-1,:,:], pred[i])
# De-normalize predictions
pred = pred*model.std + model.mean
error = np.linalg.norm(dataset[num_lags:] - pred, 2)/np.linalg.norm(dataset[num_lags:], 2)
print('Relative L2 prediction error: %e' % (error))

Finally, we can visualize the how closely our trained model can predict the output of the desired function.

In [None]:
plt.figure(1)
plt.plot(dataset[num_lags:], 'b-', linewidth = 2, label = "Exact")
plt.plot(pred, 'r--', linewidth = 3, label = "Prediction")
# plt.plot(X.shape[1]*np.ones((2,1)), np.linspace(-1.75,1.75,2), 'k--', linewidth=2)
plt.axvline(train_size)
plt.axis('tight')
plt.xlabel('$t$')
plt.ylabel('$y_t$')
plt.legend(loc='lower left')

# **Problem 2**

We'd like to generate a Deep Neural Network to model a given function f(x), such that there are 3 hidden layers, 50 neurons per layer, and a hyperbolic tangent activation function.  We will use the framework given in class for an RNN.

In [None]:
import jax.numpy as np
from jax import random, vmap, grad, jit
from jax.flatten_util import ravel_pytree
from jax.experimental import optimizers

import itertools
from functools import partial
from tqdm import trange
import numpy.random as npr
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import griddata

We will use the code from part 1 on number 1 for the MLP class.

In [None]:
def f(x):
  x1, x2 = x[0], x[1]
  f = np.sin(2*np.pi*x1) + np.sin(3*np.pi*x2) + np.sin(4*np.pi*x2)
  return f

In [None]:
rng_key = random.PRNGKey(0)

d = 2
lb = 3.0*np.ones(d)
ub = 4.0*np.ones(d)
n = 2000
noise = 0.1

# Create training data
X = lb + (ub-lb)*random.uniform(rng_key, (n, d))
y = vmap(f)(X)
y = y + noise*y.std(0)*random.normal(rng_key, y.shape)

# Create test data
nn = 50
xx = np.linspace(lb[0], ub[0], nn)
yy = np.linspace(lb[1], ub[1], nn)
XX, YY = np.meshgrid(xx, yy)
X_star = np.concatenate([XX.flatten()[:,None], YY.flatten()[:,None]], axis = 1)
y_star = vmap(f)(X_star)

In [None]:
# Model Creation
layers = [2, 50, 50, 50, 1]
init_method = 'glorot'
model = MLP(X, y, layers, init_method, rng_key)

In [None]:
model.train(num_epochs = 500, batch_size = 128)

In [None]:
opt_params = model.net_params
y_pred = model.predict(opt_params, X_star)

error = np.linalg.norm(y_star - y_pred, 2)/np.linalg.norm(y_star, 2)
print('Relative L2 prediction error: %e' % (error))

Relative L2 prediction error: 4.929670e+01


In [None]:
Yplot = griddata(X_star, y_pred.flatten(), (XX, YY), method='cubic')

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.plot(X[:,0], X[:,1], y, 'r.', ms = 6, alpha = 0.5)
ax.plot_surface(XX, YY, Yplot, alpha = 0.8)
# Hide grid lines
ax.grid(False)
# Hide axes ticks
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlabel('$x_1$')
ax.set_ylabel('$x_2$')
ax.set_zlabel('$y$')

plt.figure()
plt.plot(y_pred, y_star, 'r.', ms = 8, alpha = 0.5)
plt.plot(y_star, y_star, 'k--', lw = 3, alpha = 0.5)
plt.xlabel('Prediction')
plt.ylabel('Ground truth')

plt.figure()
plt.plot(model.loss_log)
plt.yscale('log')
plt.xscale('log')


In [None]:
rnge = np.array([50, 100, 250, 500, 1000, 2000, 5000, 10000])
error = []

for n in rnge:
  rng_key = random.PRNGKey(0)

  d = 2
  lb = 3.0*np.ones(d)
  ub = 4.0*np.ones(d)
  noise = 0.1

  # Create training data
  X = lb + (ub-lb)*random.uniform(rng_key, (n, d))
  y = vmap(f)(X)
  y = y + noise*y.std(0)*random.normal(rng_key, y.shape)

  # Create test data
  nn = 50
  xx = np.linspace(lb[0], ub[0], nn)
  yy = np.linspace(lb[1], ub[1], nn)
  XX, YY = np.meshgrid(xx, yy)
  X_star = np.concatenate([XX.flatten()[:,None], YY.flatten()[:,None]], axis = 1)
  y_star = vmap(f)(X_star)

  # Model Creation
  layers = [2, 50, 50, 50, 1]
  init_method = 'glorot'
  model = MLP(X, y, layers, init_method, rng_key)

  # Model Training
  model.train(num_epochs = 1000, batch_size = 128)

  # Determining the L2 error
  opt_params = model.net_params
  y_pred = model.predict(opt_params, X_star)

  error.append(np.linalg.norm(y_star - y_pred, 2)/np.linalg.norm(y_star, 2))

plt.figure()
plt.plot(rnge, error)

# **Problem 3**

In [None]:
import jax.numpy as np
from jax import random
from jax.experimental import stax
from jax.experimental.stax import BatchNorm, Conv, Dense, Flatten, Relu, Softmax, MaxPool
from jax.experimental import optimizers
from jax import jit, grad
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import torch
from torchvision import datasets, transforms

import numpy as onp
from tqdm import tqdm
import itertools
from functools import partial
import time

Here we initialize the same one hot encoding for the 10 labels for the CIFAR10 dataset.  Additionally, we create the CIFAR10 Convolutional Neural Network according to the architechture given in the homework.

In [None]:
def one_hot(labels, num_classes, dtype=np.float32):
  return np.array(labels[:, None] == np.arange(num_classes), dtype)

def cifar_simple_cnn(num_classes):
  init_fun, conv_net = stax.serial(Conv(out_chan=6, filter_shape=(5, 5), strides=(1, 1), padding="SAME"),
                                   Relu, MaxPool(window_shape=(2,2), strides=(2, 2), padding="SAME"),
                                   Conv(out_chan=12, filter_shape=(5, 5), strides=(1, 1), padding="SAME"),
                                   Relu, MaxPool(window_shape=(2,2), strides=(2, 2), padding="SAME"),
                                   Conv(out_chan=24, filter_shape=(5, 5), strides=(1, 1), padding="SAME"),
                                   Relu, MaxPool(window_shape=(2,2), strides=(2, 2), padding="SAME"),
                                   Flatten,
                                   Dense(384),
                                   Dense(120), Dense(84), # Hidden fully connected layers
                                   Dense(num_classes), Softmax)
  return init_fun, conv_net

We make some small adjustments to the CNN Classifier given.

In [None]:
class CNNclassifier:

  """
  First, we must intialize the network.  We need
  to keep track of the number of classes.  We use the
  above functions to the initialize the network.  We then
  initialize the parameters for the network and instantiate
  the Adam optimizer to be used for updating parameters.
  Lastly, we set up a counter for iterations, and arrays to
  track accuracy and loss.
  """

  # Initialize the class
  def __init__(self, num_classes, rng_key):
    # Store number of classes
    self.num_classes = num_classes

    # Use stax to set up network initialization and evaluation functions
    self.net_init, self.net_apply = cifar_simple_cnn(self.num_classes)
    
    # Initialize parameters, not committing to a batch shape
    _, self.net_params = self.net_init(rng_key, (-1, 3, 32, 32)) # this has changed to the dimensions of the input for CIFAR-10
                
    # Use optimizers to set optimizer initialization and update functions
    self.opt_init, \
    self.opt_update, \
    self.get_params = optimizers.adam(optimizers.exponential_decay(1e-3, 
                                                                    decay_steps=100, 
                                                                    decay_rate=0.99))
    self.opt_state = self.opt_init(self.net_params)

    # Logger
    self.itercount = itertools.count()
    self.log_acc_train = []
    self.log_acc_test = [] 
    self.train_loss = []
  
  """
  We take as input the parameters and the current batch
  of data, and output the mean squared error loss.  
  """

  def loss(self, params, batch):
      inputs, targets = batch
      predictions = self.net_apply(params, inputs)
      loss = -np.sum(targets*np.log(predictions + 1e-8))
      return loss
    
  """
  We take as input the iteration number, the current
  state of the Adam optimizer, and the current batch
  of data.  We output the updated state of the Adam
  optimizer.  The partial decorator indicates we'd
  like to prioritize the speed of this function's 
  compilation and runtime using software and hardware
  acceleration.
  """

  # Define a compiled update step
  @partial(jit, static_argnums=(0,))
  def step(self, i, opt_state, batch):
      params = self.get_params(opt_state)
      g = grad(self.loss)(params, batch)
      return self.opt_update(i, g, opt_state)

  """
  Now, we can compute the accuracy of our model
  on a given dataset with it's curent parameters.
  This is done using the a one-hot comparison
  for the each batch in a dataset/dataloader.  I'm
  not sure but I would guess that a dataloader functions
  similar to a stream in java.
  """

  def accuracy(self, params, data_loader):
    """ Compute the accuracy for a provided dataloader """

    acc_total = 0
    total_targets = []
    total_pred = []

    for batch_idx, (inputs, targets) in enumerate(data_loader):
        inputs = np.array(inputs)
        targets = one_hot(np.array(targets), self.num_classes)
        target_labels = np.argmax(targets, axis=1)
        predicted_labels = np.argmax(self.net_apply(params, inputs), axis=1)
        acc_total += np.sum(predicted_labels == target_labels)

        total_targets.extend(target_labels) # We include these logs to fill our confusion matrix
        total_pred.extend(predicted_labels)

    self.M = confusion_matrix(total_targets, total_pred) # Using the built in confusion matrix

    return acc_total/len(data_loader.dataset)
  
  """
  Here we train our CNN.  This is done by
  performing our accelerated step function
  for each batch in a dataset, for a given
  number of epochs.  We also compute the training
  and test accuracy for classification in 
  each epoch.  We take as inputs the streams of 
  training data and test data, as well as the 
  number of epochs.  At the termination of this
  function, our model will be trained and tested,
  with data characterizing its accuracy, and the set
  of optimal parameters for the CNN.
  """

  # Optimize parameters in a loop
  def train(self, train_loader, test_loader, num_epochs = 1000):
    for epoch in range(num_epochs):
      start_time = time.time()
      # Run epoch
      for batch_idx, (inputs, targets) in enumerate(train_loader):
          batch = np.array(inputs), one_hot(np.array(targets), self.num_classes)
          self.opt_state = self.step(next(self.itercount), self.opt_state, batch)
      epoch_time = time.time() - start_time
      # Compute training and validation accuracy
      self.net_params = self.get_params(self.opt_state)  
      loss = self.loss(self.net_params, batch)
      train_acc = self.accuracy(self.net_params, train_loader)
      test_acc = self.accuracy(self.net_params, test_loader)
      self.train_loss.append(loss)
      self.log_acc_train.append(train_acc)
      self.log_acc_test.append(test_acc)
      print("Epoch {} | Time: {:0.2f} | Train Acc.: {:0.3f}% | Test Acc.: {:0.3f}%".format(epoch+1, epoch_time,
                                                                  train_acc, test_acc))

Here we import the CIFAR-10 Dataset for training and test data.

In [None]:
# Set the PyTorch Data Loader for the training & test set
batch_size = 128

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

In [None]:
num_classes = 10
init_key = random.PRNGKey(0)
model = CNNclassifier(num_classes, init_key)

In [None]:
model.train(train_loader, test_loader, num_epochs = 20)

We use the same confusion matrix plotting function from last week.

In [None]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()

In [None]:
plt.figure(figsize=(24, 8))
plt.subplot(1,3,1)
plt.plot(model.train_loss)
plt.yscale('log')
plt.xlabel('Epoch')
plt.ylabel('Training loss')
plt.title('Training Loss over Time')
plt.subplot(1,3,2)
plt.plot(model.log_acc_train, label = 'Training')
plt.plot(model.log_acc_test, label = 'Testing')
plt.legend()
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Testing Accuracy over Time')
plt.subplot(1,3,3)
cifarClasses = {'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'}
plot_confusion_matrix(model.M, cifarClasses, normalize=True)

# **Problem 4**

In [1]:
import jax.numpy as np
from jax import random, vmap, grad, jit
from jax.experimental import optimizers
from jax.ops import index_update, index

import itertools
from functools import partial
from tqdm import trange
import numpy.random as npr
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import griddata
from scipy.integrate import odeint

First, we use the SciPy ODE integration package to generate sequence data for the system of ODE's outlined in the problem.

In [2]:
def dydt(input, t):

  x = input[0]
  y = input[1]

  mu = 5
  xdot = mu*(x - (x**3)/3 - y)
  ydot = x/mu

  return np.array([xdot, ydot])

In [None]:
# initial conditions
y0 = np.array([0.1, 0.5])

# time
t = np.linspace(0, 60, 2000)

# solution to our ODE system for the given sequence
output = odeint(dydt, y0, t)
print(output)

In [None]:
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.plot(output)
plt.title('State Trajectories')
plt.subplot(1,2,2)
plt.plot(output[:,0], output[:,1])
plt.title('System Trajectories')

In [5]:
def create_lags(data, L):
    N = data.shape[0] - L
    D = data.shape[1]
    X = np.zeros((L, N, D))
    Y = np.zeros((N, D))
    for i in range(0,N):
        X = index_update(X, index[:,i,:], data[i:(i+L), :])
        Y = index_update(Y, index[i,:], data[i+L, :])
    return X, Y

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

Now we make some small eedits to the RNN from question 1 in order to make it Long Short-Term.  We must keep track of the gates for this network, so instead of just a reccurent model, our output vector utilizes an output gate, a track of the cell state, an external input gate, and a forget gate.  We perform a series of operations on a larger number of parameters, according to the user input hyperparameters, to produce our output of a given layer.  However, we also maintain the same general architechture for RNNs.

In [6]:
class RNN():

  """
  First, we must initialize the RNN.  We take as inputs
  the dataset, the number of lags, and the number of hidden
  layers.  We first normalize the dataset according to its
  mean and standard deviation.  We then create our lagged
  normalized dataset using the above function.  We initialize
  the network, apply rule, and parameters.  Then, we designate
  the Adam optimizer and its state, as well as an array to track
  the loss and a counter variable.
  """
  def __init__(self, dataset, num_lags, hidden_dim, rng_key = random.PRNGKey(0)):
    # Normalize across data-points dimension
    self.mean, self.std = dataset.mean(0), dataset.std(0)
    dataset = (dataset - self.mean)/self.std

    # Create the lagged normalized trainind data
    # X: L x N x D
    # Y: N x D
    self.X, self.Y = create_lags(dataset, num_lags)
    self.X_dim = self.X.shape[-1]
    self.Y_dim = self.Y.shape[-1]
    self.hidden_dim = hidden_dim
    self.num_lags = num_lags

    # Initialization and evaluation functions
    self.net_init, self.net_apply = self.init_RNN()
    
    # Initialize parameters, not committing to a batch shape
    self.net_params = self.net_init(rng_key)
                
    # Use optimizers to set optimizer initialization and update functions
    self.opt_init, \
    self.opt_update, \
    self.get_params = optimizers.adam(1e-3)
    self.opt_state = self.opt_init(self.net_params)

    # Logger to monitor the loss function
    self.loss_log = []
    self.itercount = itertools.count()

  """
  Here we initialize the RNN.  We use the glorot
  initialization according the number of hidden
  layers and the dimension of the lagged normalized
  input data X.  We also initialize the biases, weights,
  and the output layer.  We output all of these.  We
  also define the apply function for our RNN.  This
  takes as input the current parameters and some input
  data.  We apply a hyperbolic activation function to
  (H*W + input*U + b).  This performs the application of 
  the RNN to input data, according to the activation
  functions.
  """

  def init_RNN(self):
    # Define init function
    def _init(rng_key):

        # Define methods for initializing the weights
        def glorot_normal(rng_key, size):
          in_dim = size[0]
          out_dim = size[1]
          glorot_stddev = 1. / np.sqrt((in_dim + out_dim) / 2.)
          return glorot_stddev*random.normal(rng_key, (in_dim, out_dim))

        # First we define inputs
        Uo = glorot_normal(rng_key, (self.X_dim, self.hidden_dim))
        Us = glorot_normal(rng_key, (self.X_dim, self.hidden_dim))
        Ui = glorot_normal(rng_key, (self.X_dim, self.hidden_dim))
        Uf = glorot_normal(rng_key, (self.X_dim, self.hidden_dim))

        # Now we can initialize our biases to 0
        bo = np.zeros(self.hidden_dim)
        bs = np.zeros(self.hidden_dim)
        bi = np.zeros(self.hidden_dim)
        bf = np.zeros(self.hidden_dim)

        # Now we can define the transition dynamics
        Wo = np.eye(self.hidden_dim)
        Ws = np.eye(self.hidden_dim)
        Wi = np.eye(self.hidden_dim)
        Wf = np.eye(self.hidden_dim)

        # Now we can define our outputs 
        V = glorot_normal(rng_key, (self.hidden_dim, self.Y_dim))
        c = np.zeros(self.Y_dim)

        return (Uo, Us, Ui, Uf, bo, bs, bi, bf, Wo, Ws, Wi, Wf, V, c)
    # Define apply function
    def _apply(params, input):
        Uo, Us, Ui, Uf, bo, bs, bi, bf, Wo, Ws, Wi, Wf, V, c = params
        H = np.zeros((input.shape[1], self.hidden_dim))
        s_t = np.zeros((input.shape[1], self.hidden_dim))
        for i in range(self.num_lags):
            s_t_tilde = np.tanh(np.matmul(H, Ws) + np.matmul(input[i,:,:], Us) + bs)
            f_t = sigmoid(np.matmul(H, Wf) + np.matmul(input[i,:,:], Uf) + bf)
            i_t = sigmoid(np.matmul(H, Wi) + np.matmul(input[i,:,:], Ui) + bi)
            s_t = (f_t * s_t) + (i_t * s_t_tilde)
            o_t = sigmoid(np.matmul(H, Wo) + np.matmul(input[i,:,:], Uo) + bo)
            H = o_t * np.tanh(s_t)
        H = np.matmul(H, V) + c
        return H
    return _init, _apply

  """
  Here we calculate the mean squared error loss on an input
  of the current parameters and the current batch of data.
  """

  def loss(self, params, batch):
    X, y = batch
    y_pred = self.net_apply(params, X)
    loss = np.mean((y - y_pred)**2)
    return loss

  """
  Here we definte the step function of our optimizer.
  We take as inputs the iteration number, current Adam
  optimizer state, and current batch of data.  We then
  output the updated state of the Adam optimizer using
  the given data.  We use the partial decorator to indicate
  that we will be using this function many times in rapid
  succession, so it is best to prioritize the speed of its
  compilation and processing using hardware and software
  accleration.
  """

  # Define a compiled update step
  @partial(jit, static_argnums=(0,))
  def step(self, i, opt_state, batch):
      params = self.get_params(opt_state)
      g = grad(self.loss)(params, batch)
      return self.opt_update(i, g, opt_state)

  """
  This function is used to create each batch from a given data set and output it
  in a stream.  We take as inputs the dimension of the training data, the number
  of batches, and the desired batch size.  We output the training data as the 
  desired number of batches whose dimensions match the desired size.
  """

  def data_stream(self, n, num_batches, batch_size):
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(n)
      for i in range(num_batches):
        batch_idx = perm[i*batch_size:(i+1)*batch_size]
        yield self.X[:, batch_idx, :], self.Y[batch_idx, :]

  """
  This is how we train the RNN.  We perform the accelerated
  update rule on each batch in the training data stream, for
  a given number of epochs.  We also update our parameters and
  loss at each epoch.
  """

  
  def train(self, num_epochs = 20000, batch_size = 128):   
    n = self.X.shape[1]
    num_complete_batches, leftover = divmod(n, batch_size)
    num_batches = num_complete_batches + bool(leftover) 
    batches = self.data_stream(n, num_batches, batch_size)
    pbar = trange(num_epochs)
    for epoch in pbar:
      for _ in range(num_batches):
        batch = next(batches)
        self.opt_state = self.step(next(self.itercount), self.opt_state, batch)
      self.net_params = self.get_params(self.opt_state)
      loss_value = self.loss(self.net_params, batch)
      self.loss_log.append(loss_value)
      pbar.set_postfix({'Loss': loss_value})

  """
  Finally, we have a function that can take in the current
  parameters and some input test data to make a prediction
  using the RNN.  We again use the partial decorator to speed
  up the testing of our RNN.
  """

  @partial(jit, static_argnums=(0,))
  def predict(self, params, inputs):
    Y_pred = self.net_apply(params, inputs)
    return Y_pred

In [7]:
rng_key = random.PRNGKey(0)
noise = 0.0

dataset = output
dataset = dataset + dataset.std(0)*noise*random.normal(rng_key, dataset.shape)

# Use 2/3 of all data as training Data
train_size = int(len(dataset) * (2.0/3.0))
train_data = dataset[0:train_size,:]

In [24]:
# Model creation
num_lags = 8
hidden_dim = 20
model = RNN(train_data, num_lags, hidden_dim, rng_key)

In [None]:
model.train(num_epochs = 100, batch_size = 128)

In [None]:
opt_params = model.net_params
# One-step ahead prediction (normalized)
N, D = dataset.shape
pred = np.zeros((N-num_lags, D))
X_tmp =  model.X[:,0:1,:]

for i in trange(N-num_lags):
    pred = index_update(pred, index[i:i+1], model.net_apply(opt_params, X_tmp))
    X_tmp = index_update(X_tmp, index[:-1,:,:], X_tmp[1:,:,:])
    X_tmp = index_update(X_tmp, index[-1,:,:], pred[i])
# De-normalize predictions
pred = pred*model.std + model.mean
error = np.linalg.norm(dataset[num_lags:] - pred, 2)/np.linalg.norm(dataset[num_lags:], 2)
print('Relative L2 prediction error: %e' % (error))

As we see from the above output, the L2 error is fairly small.

Additionally, below, the graph of our predicted vs actual values is shown.  The prediction appears to be fairly accurate in this regard as well.

In [None]:
plt.figure(figsize=(16, 8))
plt.plot(dataset[num_lags:], 'b-', linewidth = 2, label = "Exact")
plt.plot(pred, 'r--', linewidth = 3, label = "Prediction")
# plt.plot(X.shape[1]*np.ones((2,1)), np.linspace(-1.75,1.75,2), 'k--', linewidth=2)
plt.axvline(train_size)
plt.axis('tight')
plt.xlabel('$t$')
plt.ylabel('$y_t, outputs$')
plt.legend(loc='lower left')