Welcome!  If you are new to Google Colab/Jupyter notebooks, you might take a look at [this notebook](https://colab.research.google.com/notebooks/basic_features_overview.ipynb) first.

**I recommend you run the first code cell of this notebook immediately to start provisioning Drake on the cloud machine; you can then leave this window open as you [read the textbook](http://underactuated.csail.mit.edu/dp.html).**

# Notebook Setup

The following cell will:
- on Colab (only), install Drake to `/opt/drake`, install Drake's prerequisites via `apt`, and add pydrake to `sys.path`.  This takes approximately two minutes the first time it runs (to provision the machine), but Drake should only need to be reinstalled once every 12 hours.  If you navigate between notebooks using Colab's "File->Open" menu, you can avoid provisioning a separate machine for each notebook.
- import packages used throughout the notebook.

You will need to rerun this cell if you restart the kernel, but it should be fast (even on Colab) because the machine will already have Drake installed.

In [None]:
import importlib.util  # `import importlib` alone does not guarantee that importlib.util is loaded
import sys
from urllib.request import urlretrieve

# Install Drake (and underactuated).
if 'google.colab' in sys.modules and importlib.util.find_spec('underactuated') is None:
    urlretrieve("http://underactuated.csail.mit.edu/scripts/setup/setup_underactuated_colab.py",
                "setup_underactuated_colab.py")
    from setup_underactuated_colab import setup_underactuated
    setup_underactuated(underactuated_sha='560c2adace05eb20ebd78377582015d5b2d3859a', drake_version='0.25.0', drake_build='releases')


# Value Iteration using Neural Networks as a Function Approximator

In this notebook, we'll use [PyTorch](https://pytorch.org/tutorials/) to implement a basic fitted value iteration algorithm using neural networks.

Let's start by setting up the double integrator plant, the neural network architecture, and two potential cost functions.
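The next cell stores each state as a row vector with a leading batch dimension, so the Euler step $x[n+1] = x[n] + h\,(A x[n] + B u[n])$ is written as `x + h * (x @ A.T + u @ B.T)`. As a minimal sketch of that convention, here is a NumPy (not PyTorch) version of the dynamics and of both running costs; the names `euler_step`, `quadratic_cost`, and `min_time_cost_np` are hypothetical:

```python
import numpy as np

# Double integrator q_ddot = u, with state x = [q, qdot].
A = np.array([[0., 1.], [0., 0.]])
B = np.array([[0.], [1.]])
Q = np.eye(2)
R = np.eye(1)

def euler_step(x, u, h):
    """One forward-Euler step for batched row vectors.

    x has shape [..., 1, 2] and u has shape [..., 1, 1]. Right-multiplying a row
    vector by A^T is equivalent to A @ x for a column vector.
    """
    return x + h * (x @ A.T + u @ B.T)

def quadratic_cost(x, u):
    # x Q x^T + u R u^T for every batch element -> shape [..., 1, 1]
    return x @ Q @ np.swapaxes(x, -2, -1) + u @ R @ np.swapaxes(u, -2, -1)

def min_time_cost_np(x, u):
    # 0 when every state component is (numerically) zero, i.e. at the goal; 1 otherwise.
    at_goal = np.all(np.isclose(x, 0.0), axis=-1, keepdims=True)
    return np.where(at_goal, 0.0, 1.0)

# A batch of three states, all with input u = 1.
x = np.array([[[0., 0.]], [[1., 0.]], [[0., 2.]]])   # shape [3, 1, 2]
u = np.ones((3, 1, 1))
x_next = euler_step(x, u, h=0.1)
```

Right-multiplying by the transposes is why the notebook precomputes `At` and `Bt` once, instead of transposing every batch of states.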

In [None]:
from IPython import get_ipython
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from pydrake.all import DiscreteAlgebraicRiccatiEquation
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from underactuated.jupyter import SetupMatplotlibBackend
plt_is_interactive = SetupMatplotlibBackend()

# Define the double integrator
A = torch.tensor([[0., 1.], [0., 0.]])
B = torch.tensor([[0.], [1.]])
At = A.transpose(0, 1)
Bt = B.transpose(0, 1)

# Define the function approximator for J
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # Linear implements y = xA^T + b
        self.fc1 = nn.Linear(2, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

# Define the cost function
def min_time_cost(xt, ut):
    at_goal = torch.isclose(xt, torch.zeros(1,2))
    # cost = 0 when every component of x is (numerically) zero, i.e. at the goal; 1 otherwise.
    return torch.min((~at_goal).float().matmul(torch.ones(2,1)), torch.ones(1))

Q = torch.eye(2)
R = torch.eye(1)
def quadratic_regulator_cost(xt, ut):
    return xt.matmul(Q.matmul(xt.transpose(-2,-1))) + ut.matmul(R.matmul(ut.transpose(-2,-1)))

BRinv = B*R.inverse()

def min_time_solution(xt):
  # Caveat: this does not take the time discretization (zero-order hold on u) into account.
  q = xt[:,:,0]
  qdot = xt[:,:,1]
  # mask indicates that we are in the regime where u = +1.
  mask = ((qdot < 0) & (2*q <= qdot.pow(2))) | ((qdot >= 0) & (2*q < -qdot.pow(2)))
  T = torch.empty(q.size())
  T[mask] = 2*(.5*(qdot[mask].pow(2)) - q[mask]).sqrt() - qdot[mask]
  T[~mask] = qdot[~mask] + 2*(.5*(qdot[~mask].pow(2)) + q[~mask]).sqrt()
  return T.unsqueeze(-1)
  
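def quadratic_regulator_solution(xt, timestep):
  # Each step of fitted value iteration accumulates timestep*running_cost, so the
  # matching discrete-time LQR problem weights the state and input by timestep*Q
  # and timestep*R.
  S = DiscreteAlgebraicRiccatiEquation(A=(np.eye(2)+timestep*A.numpy()),
                                       B=timestep*B.numpy(),
                                       Q=timestep*Q.numpy(), R=timestep*R.numpy())
  return xt.matmul(torch.from_numpy(S).float().matmul(xt.transpose(-2,-1)))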

def plot_and_compare(net, running_cost, timestep):
  x1s = torch.linspace(-3,3,31)
  x2s = torch.linspace(-3,3,51)
  X1s, X2s = torch.meshgrid(x1s, x2s)
  X = torch.stack((X1s.flatten(), X2s.flatten()), 1).unsqueeze(1)
  
  with torch.no_grad():
    J = net.forward(X)

  fig = plt.figure(figsize=(9, 4))
  ax1, ax2 = fig.subplots(1, 2, subplot_kw=dict(projection='3d'))
  ax1.set_xlabel("q")
  ax1.set_ylabel("qdot")
  ax1.set_title("Estimated Cost-to-Go")
  ax1.plot_surface(X1s, X2s, J.view(X1s.size()).detach().numpy(), rstride=1, cstride=1, cmap=cm.jet)
  
  if running_cost == min_time_cost:
    Jd = min_time_solution(X)
  elif running_cost == quadratic_regulator_cost:
    Jd = quadratic_regulator_solution(X, timestep)

  ax2.set_xlabel("q")
  ax2.set_ylabel("qdot")
  ax2.set_title("Analytical Cost-to-Go")
  ax2.plot_surface(X1s, X2s, Jd.view(X1s.size()).detach().numpy(), rstride=1, cstride=1, cmap=cm.jet)
    
  # Score is the mean squared error between the estimated and analytical
  # cost-to-go, averaged over the grid samples.
  criterion = nn.MSELoss()
  score = criterion(J, Jd).item()
  print("MSE(Ĵᵢ,Jᵢ) = %.2f" % score)


## Discrete time, continuous state, discrete action

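This is the standard "fitted value iteration" algorithm with a PyTorch network as the function approximator. The states are sampled on a fixed grid, the actions are restricted to nine evenly spaced values in [-1, 1], and each iteration performs a single step of gradient descent toward the one-step Bellman targets.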
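Writing $h$ for `timestep` and $\ell$ for `running_cost`, each iteration regresses the network onto the target $J_d(x_i) = \min_{u} \left[ h\,\ell(x_i, u) + \hat{J}(x_i + h(Ax_i + Bu)) \right]$, where the minimum runs over the finite action set `us`. A minimal NumPy sketch of just that target computation, assuming a generic callable `J_hat` in place of the network (the helper names here are hypothetical):

```python
import numpy as np

A = np.array([[0., 1.], [0., 0.]])
B = np.array([[0.], [1.]])

def bellman_targets(J_hat, running_cost, X, us, h):
    """One-step Bellman targets for fitted value iteration (illustrative sketch).

    X:  sampled states as row vectors, shape [N, 1, 2]
    us: 1-D array of M candidate scalar actions
    J_hat: callable mapping [..., 1, 2] -> [..., 1, 1]; stands in for the network
    Returns the minimum over actions, shape [N, 1, 1], and the argmin indices.
    """
    N, M = X.shape[0], us.shape[0]
    # Pair every state with every action: [N, M, 1, 2] and [N, M, 1, 1].
    XwithU = np.broadcast_to(X[:, None], (N, M, 1, 2))
    UwithX = np.broadcast_to(us[None, :, None, None], (N, M, 1, 1))
    # Euler step of the double integrator for each state-action pair.
    Xnext = XwithU + h * (XwithU @ A.T + UwithX @ B.T)
    # One-step cost of taking u, plus the estimated cost-to-go from the next state.
    q_values = h * running_cost(XwithU, UwithX) + J_hat(Xnext)
    return q_values.min(axis=1), q_values.argmin(axis=1)

def quad_cost(x, u):
    # x^T x + u^T u (i.e. Q = I, R = I), shape [..., 1, 1]
    return (x**2).sum(axis=-1, keepdims=True) + (u**2).sum(axis=-1, keepdims=True)

def J_zero(x):
    # Hypothetical initial estimate: J = 0 everywhere.
    return np.zeros(x.shape[:-1] + (1,))

X = np.array([[[0., 0.]], [[1., -1.]]])
us = np.linspace(-1., 1., 9)
Jd, best = bellman_targets(J_zero, quad_cost, X, us, h=0.1)
```

In the notebook's `solve`, `torch.min(G + Jnext, dim=1)` plays the role of `q_values.min(axis=1)`. Because `target_net` is reloaded from `net` just before the targets are computed, it acts as a gradient-free copy of the current estimate.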

In [None]:
def solve(net, optimizer, running_cost):
  criterion = nn.MSELoss()
  timestep = 0.1

  x1s = torch.linspace(-3,3,31)
  x2s = torch.linspace(-3,3,51)
  us = torch.linspace(-1,1,9)

  X1s, X2s = torch.meshgrid(x1s, x2s)
  # Want x as batch row vectors... size [num_state_samples, 1, num_states]
  # (because the linear units in net expect row vectors)
  X = torch.stack((X1s.flatten(), X2s.flatten()), 1).unsqueeze(1)

  X1s, X2s, Us = torch.meshgrid(x1s, x2s, us)
  # XwithU has size [num_state_samples, num_input_samples, 1, num_states]
  # UwithX has size [num_state_samples, num_input_samples, 1, num_inputs]
  XwithU = torch.stack((X1s.flatten(0,1), X2s.flatten(0,1)), 2).unsqueeze(2)
  UwithX = Us.flatten(0,1).unsqueeze(-1).unsqueeze(-1)

  Xnext = XwithU + timestep * (XwithU.matmul(At) + UwithX.matmul(Bt))
  G = timestep*running_cost(XwithU, UwithX)

  target_net = Net()
  for epoch in range(1000 if get_ipython() else 2):
    net.zero_grad()
    target_net.load_state_dict(net.state_dict())
    with torch.no_grad():
      Jnext = target_net.forward(Xnext)
      Jd, ind = torch.min(G + Jnext, dim=1)
    J = net.forward(X)
    loss = criterion(J, Jd)
    loss.backward()
    optimizer.step()

    if epoch % 20 == 19:
      print('[%d] loss: %.3f' % (epoch + 1, loss.item()))

  plot_and_compare(net, running_cost, timestep)

Let's apply it to the quadratic-regulator cost.
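For this cost, the natural reference is the unconstrained discrete-time LQR cost-to-go $x^T S x$ of the Euler-discretized system $A_d = I + hA$, $B_d = hB$. Because each step adds $h\,\ell(x, u)$, the matching per-step weights are $hQ$ and $hR$. The notebook obtains $S$ from Drake's `DiscreteAlgebraicRiccatiEquation`. As an illustrative alternative (the helper `riccati_iteration` is hypothetical), the same $S$ can be approximated by iterating the Riccati recursion from $S = 0$. That recursion is exactly value iteration restricted to quadratic cost-to-go functions:

```python
import numpy as np

def riccati_iteration(Ad, Bd, Qd, Rd, num_iters=5000, tol=1e-10):
    """Approximate the DARE solution by iterating the Riccati recursion from S = 0.

    Each pass is one exact value-iteration backup for J(x) = x^T S x:
      S <- Qd + Ad^T S Ad - Ad^T S Bd (Rd + Bd^T S Bd)^{-1} Bd^T S Ad
    """
    S = np.zeros_like(Qd)
    for _ in range(num_iters):
        K = np.linalg.solve(Rd + Bd.T @ S @ Bd, Bd.T @ S @ Ad)
        S_next = Qd + Ad.T @ S @ Ad - Ad.T @ S @ Bd @ K
        if np.max(np.abs(S_next - S)) < tol:
            return S_next
        S = S_next
    return S

h = 0.1
A = np.array([[0., 1.], [0., 0.]])
B = np.array([[0.], [1.]])
Ad, Bd = np.eye(2) + h * A, h * B
S = riccati_iteration(Ad, Bd, h * np.eye(2), h * np.eye(1))
```

With the $h$-scaled weights, $S$ should land close to the continuous-time LQR solution $\begin{bmatrix}\sqrt{3} & 1\\ 1 & \sqrt{3}\end{bmatrix}$ for $Q = I$, $R = 1$. Note that the fitted iteration only has bounded actions in [-1, 1], so its fixed point can differ from this reference far from the origin.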

In [None]:
torch.random.manual_seed(12345)  # for scoring
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01)
solve(net, optimizer, quadratic_regulator_cost)

And to the minimum-time cost:
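For reference, `min_time_solution` evaluates the closed-form minimum time to reach the origin for $\ddot{q} = u$ with $|u| \le 1$ (ignoring the zero-order hold, as the caveat in that function notes):

$$
T(q, \dot{q}) =
\begin{cases}
2\sqrt{\tfrac{1}{2}\dot{q}^2 - q} - \dot{q}, & \text{if } \left(\dot{q} < 0 \text{ and } q \le \tfrac{1}{2}\dot{q}^2\right) \text{ or } \left(\dot{q} \ge 0 \text{ and } q < -\tfrac{1}{2}\dot{q}^2\right), \\
\dot{q} + 2\sqrt{\tfrac{1}{2}\dot{q}^2 + q}, & \text{otherwise.}
\end{cases}
$$

The first case is the region where the optimal bang-bang policy starts with $u = +1$. For example, starting at rest at $q = -1$ gives $T = 2$: accelerate with $u = +1$ for one second, then brake with $u = -1$ for one second.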

In [None]:
torch.random.manual_seed(12345)  # for scoring
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.0015)
solve(net, optimizer, min_time_cost)

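Here is a similar version, but with the states (q, qdot) resampled uniformly at random from [-3, 3] × [-3, 3] on each iteration; the first sample is always set to the origin so that the goal appears in every batch.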
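In the `solve` below, the loop over `us` only copies each sampled state once per action. A minimal NumPy sketch of the same sampling step using broadcasting instead (the helper `sample_state_action_batch` is hypothetical):

```python
import numpy as np

def sample_state_action_batch(rng, num_state_samples, us, low=-3.0, high=3.0):
    """Uniformly sample states in [low, high)^2 and pair each one with every action.

    Returns X [N, 1, 2], XwithU [N, M, 1, 2], and UwithX [N, M, 1, 1].
    The first sample is pinned to the origin, mirroring `X[0, :, :] = 0` in the notebook.
    """
    X = rng.uniform(low, high, size=(num_state_samples, 1, 2))
    X[0] = 0.0
    M = us.shape[0]
    # Broadcasting creates read-only views instead of copying X once per action.
    XwithU = np.broadcast_to(X[:, None], (num_state_samples, M, 1, 2))
    UwithX = np.broadcast_to(us[None, :, None, None], (num_state_samples, M, 1, 1))
    return X, XwithU, UwithX

rng = np.random.default_rng(0)
us = np.linspace(-1.0, 1.0, 9)
X, XwithU, UwithX = sample_state_action_batch(rng, 1000, us)
```

In PyTorch the analogous call would be `X.unsqueeze(1).expand(-1, len(us), -1, -1)`, which likewise returns a view rather than a copy.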

In [None]:
def solve(net, optimizer, running_cost):
  criterion = nn.MSELoss()
  timestep = 0.1

  us = torch.linspace(-1,1,9)
  num_state_samples = 1000
  XwithU = torch.empty((num_state_samples, us.size()[-1], 1, 2))
  UwithX = torch.empty((num_state_samples, us.size()[-1], 1, 1))

  target_net = Net()
  for epoch in range(1000 if get_ipython() else 2):
    X = 6*torch.rand((num_state_samples, 1, 2)) - 3.
    X[0, :, :] = torch.zeros(1,2) # make sure zero appears

    for i in range(us.size()[-1]):
      XwithU[:, i, :, :] = X
      UwithX[:, i, :, :] = us[i]

    Xnext = XwithU + timestep * (XwithU.matmul(At) + UwithX.matmul(Bt))
    G = timestep*running_cost(XwithU, UwithX)

    net.zero_grad()
    target_net.load_state_dict(net.state_dict())
    with torch.no_grad():
      Jnext = target_net.forward(Xnext)
      Jd, ind = torch.min(G + Jnext, dim=1)
    J = net.forward(X)
    loss = criterion(J, Jd)
    loss.backward()
    optimizer.step()

    if epoch % 20 == 19:
      print('[%d] loss: %.6f' % (epoch + 1, loss.item()))

  plot_and_compare(net, running_cost, timestep)

In [None]:
torch.random.manual_seed(12345)  # for scoring
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.01)
solve(net, optimizer, quadratic_regulator_cost)

In [None]:
torch.random.manual_seed(12345)  # for scoring
net = Net()
optimizer = optim.SGD(net.parameters(), lr=0.0015)
solve(net, optimizer, min_time_cost)

# My challenge to you

I am intensely interested in knowing how well these representations can work, compared to the known analytical solutions, but have so far not spent any time tuning either the architecture or the learning parameters.  Can you do better? If you can make a substantial improvement to this example, I would love to incorporate your updates here, and record your contributions.

While there are no precise rules to this game, ideally I would like to see the same network architecture work for both problems, and a solution that does not bake in any information about the double integrator (e.g. it should work almost immediately on the pendulum, and with only small changes on the acrobot, cart-pole, etc.).
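One way to keep the double integrator out of the solver is to pass the batched dynamics in as a function, so that the Euler step in `solve` no longer hard-codes `At` and `Bt`. A minimal NumPy sketch, in which the function names and the pendulum parameters (m = l = 1, g = 9.81, b = 0.1) are hypothetical choices for illustration:

```python
import numpy as np

def double_integrator_dynamics(x, u):
    # x: [..., 1, 2] rows [q, qdot]; u: [..., 1, 1]. Returns xdot = [qdot, u].
    return np.concatenate([x[..., 1:2], u], axis=-1)

def pendulum_dynamics(x, u, m=1.0, l=1.0, g=9.81, b=0.1):
    # m l^2 theta_ddot + b theta_dot + m g l sin(theta) = u
    theta, theta_dot = x[..., 0:1], x[..., 1:2]
    theta_ddot = (u[..., 0:1] - b * theta_dot - m * g * l * np.sin(theta)) / (m * l**2)
    return np.concatenate([theta_dot, theta_ddot], axis=-1)

def euler_step_generic(f, x, u, h):
    # Forward-Euler step for any batched dynamics function f(x, u).
    return x + h * f(x, u)

x = np.array([[[1.0, 2.0]]])
u = np.array([[[0.5]]])
x_next = euler_step_generic(double_integrator_dynamics, x, u, 0.1)
```

Swapping `double_integrator_dynamics` for `pendulum_dynamics` changes only the argument, though the sampling range for the angle (and its wrap-around) would still need problem-specific thought.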

<style>
th {
    font-weight:bold;
}
</style>    

<table style="width:90%; .td { text-align:left; };">
    <tr><th style="text-align:left">Contributor</th><th>MSE: Quadratic Regulator</th><th>MSE: Minimum-time problem</th><th>What changed?</th></tr>
    <tr><td style="text-align:left">Russ Tedrake</td><td>13928.50</td><td>6.37</td><td>Initial Example</td></tr>
</table>
