## Notebook Setup 
The following cell will install Drake, check out the underactuated repository, and set up the path (only if necessary).
- On Google's Colaboratory, this **will take approximately two minutes** the first time it runs (to provision the machine), but it should only need to reinstall once every 12 hours.  Colab will ask you to "Reset all runtimes"; say no to save yourself the reinstall.
- On Binder, the machines should already be provisioned by the time you can run this; it should return (almost) instantly.

More details are available [here](http://underactuated.mit.edu/drake.html).

In [None]:
# Use Drake and the course package if they are already importable; otherwise
# download the course's setup script and run it before continuing.
try:
  import pydrake
  import underactuated
except ImportError:
  # `!` is IPython shell syntax (this cell is not plain Python): save the
  # setup helper as jupyter_setup.py so it can be imported on the next line.
  !curl -s https://raw.githubusercontent.com/RussTedrake/underactuated/master/scripts/setup/jupyter_setup.py > jupyter_setup.py
  from jupyter_setup import setup_underactuated
  setup_underactuated()

# Set up the matplotlib backend (to notebook, if possible, or inline).
# The returned value is kept in plt_is_interactive for later cells.
from underactuated.jupyter import setup_matplotlib_backend
plt_is_interactive = setup_matplotlib_backend()

# Value Iteration using Neural Networks as a Function Approximator

In this notebook, we'll use [PyTorch](https://pytorch.org/tutorials/) to implement a basic fitted value iteration algorithm using neural networks.

Note: I have done very little architecture/parameter tuning in this example (so far).  Try your hand at improving it!  Can you get a nice reproduction of the true cost-to-go function?  (And if you do improve the code, please contribute it back!)

Let's start by setting up the double integrator plant, the neural network architecture, and two potential cost functions.

In [None]:
from IPython import get_ipython
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Double integrator dynamics, xdot = A x + B u, with state x = [q, qdot].
A = torch.tensor([
    [0.0, 1.0],
    [0.0, 0.0],
])
B = torch.tensor([
    [0.0],
    [1.0],
])
# Transposed copies, so batches of row-vector states/inputs can be
# right-multiplied: x A^T and u B^T.
At = A.t()
Bt = B.t()

# Function approximator for the cost-to-go J
class Net(nn.Module):
    """Fully connected network mapping a 2-d state to a scalar value.

    Layer widths are 2 -> 120 -> 84 -> 1, with ReLU after the two hidden
    layers and a linear output layer.
    """

    def __init__(self):
        super().__init__()
        # nn.Linear computes y = x W^T + b, so inputs are row vectors.
        self.fc1 = nn.Linear(2, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)

    def forward(self, x):
        """Return the network output for input ``x`` (last dimension 2)."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

    def num_flat_features(self, x):
        """Return the number of elements per sample in ``x``.

        This is the product of every dimension of ``x`` except the first
        (batch) dimension; it is 1 when there are no further dimensions.
        """
        count = 1
        for dim in x.size()[1:]:
            count *= dim
        return count

# Running cost for the minimum-time problem
def min_time_cost(xt, ut):
    """Return 0 where the state is (numerically) at the origin, else 1.

    Args:
        xt: States with last dimension 2, e.g. [..., 1, 2] row vectors.
        ut: Inputs; accepted for interface parity with the other running
            costs but not used.

    Returns:
        A float tensor shaped like ``xt`` with the last dimension reduced
        to size 1.
    """
    off_goal = ~torch.isclose(xt, torch.zeros(1, 2))
    # The cost is 1 as soon as any state component is away from zero.
    return off_goal.any(dim=-1, keepdim=True).float()

# Quadratic cost weights: identity on the 2-d state and on the scalar input.
Q = torch.eye(2)
R = torch.eye(1)


def quadratic_regulator_cost(xt, ut):
    """Return the quadratic running cost x Q x^T + u R u^T.

    Args:
        xt: Row-vector states with shape [..., 1, 2].
        ut: Row-vector inputs with shape [..., 1, 1].

    Returns:
        Tensor of costs with shape [..., 1, 1].
    """
    x_col = xt.transpose(-2, -1)
    u_col = ut.transpose(-2, -1)
    state_term = torch.matmul(xt, torch.matmul(Q, x_col))
    input_term = torch.matmul(ut, torch.matmul(R, u_col))
    return state_term + input_term

# B scaled by the inverse of R. `*` is an elementwise (broadcast) product; it
# coincides with the matrix product B R^{-1} here only because R is 1x1.
# NOTE(review): BRinv is not referenced by any cell visible here.
BRinv = B*R.inverse()

# plot with pyplot
def plot(net):
    """Draw the cost-to-go surface predicted by ``net`` on [-3, 3] x [-3, 3].

    The network is evaluated on a 31 x 51 grid of states and the result is
    shown as a 3-d surface in a new pyplot figure.

    Args:
        net: A ``torch.nn.Module`` mapping row-vector states of shape
            [num_samples, 1, 2] to one value per sample.

    Returns:
        None. The figure is created as a side effect.
    """
    x1s = torch.linspace(-3, 3, 31)
    x2s = torch.linspace(-3, 3, 51)
    X1s, X2s = torch.meshgrid(x1s, x2s)
    # Batch of row vectors, size [num_samples, 1, 2], as the network expects.
    X = torch.stack((X1s.flatten(), X2s.flatten()), 1).unsqueeze(1)

    # Evaluation only, so no autograd graph is needed.  Call the module
    # itself rather than .forward() so that any registered hooks still run.
    with torch.no_grad():
        J = net(X)

    fig = plt.figure(figsize=(9, 4))
    ax = fig.subplots(1, 1, subplot_kw=dict(projection='3d'))
    ax.set_xlabel("q")
    ax.set_ylabel("qdot")
    ax.set_title("Cost-to-Go")
    # Hand matplotlib plain numpy arrays instead of relying on implicit
    # tensor-to-array conversion.
    ax.plot_surface(X1s.numpy(), X2s.numpy(), J.view(X1s.size()).numpy(),
                    rstride=1, cstride=1, cmap=cm.jet)

## Discrete time, continuous state, discrete action

This is the standard "fitted value iteration" algorithm with a torch network as the function approximator, and a single step of gradient descent performed on each iteration.

In [None]:
# Fitted value iteration on a fixed grid of states: every epoch recomputes the
# one-step Bellman targets with the current weights and takes a single SGD
# step on the squared error between the network output and those targets.
net = Net()
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)
# Alternative running cost; swap the two assignments below to use it.
#running_cost = min_time_cost
running_cost = quadratic_regulator_cost
timestep = 0.1

# Sample grid: 31 x 51 states on [-3, 3]^2 and 9 inputs on [-1, 1].
x1s = torch.linspace(-3,3,31)
x2s = torch.linspace(-3,3,51)
us = torch.linspace(-1,1,9)

X1s, X2s = torch.meshgrid(x1s, x2s)
# Want x as batch row vectors... size [num_state_samples, 1, num_states]
# (because the linear units in net expect row vectors)
X = torch.stack((X1s.flatten(), X2s.flatten()), 1).unsqueeze(1)

# Every (state, input) pair; X1s and X2s are rebound to the 3-d grids here.
X1s, X2s, Us = torch.meshgrid(x1s, x2s, us)
# XwithU has size [num_state_samples, num_input_samples, 1, num_states]
# UwithX has size [num_state_samples, num_input_samples, 1, num_inputs]
XwithU = torch.stack((X1s.flatten(0,1), X2s.flatten(0,1)), 2).unsqueeze(2)
UwithX = Us.flatten(0,1).unsqueeze(-1).unsqueeze(-1)

# One forward-Euler step of the dynamics for every pair, in row-vector form:
# x_next = x + timestep * (x A^T + u B^T).
Xnext = XwithU + timestep * (XwithU.matmul(At) + UwithX.matmul(Bt))
# One-step cost: the running cost scaled by the timestep.
G = timestep*running_cost(XwithU, UwithX)

target_net = Net()
# 1000 iterations when get_ipython() returns a shell, otherwise only 2.
for epoch in range(1000 if get_ipython() else 2):
  net.zero_grad()
  # Copy the current weights, so the targets come from a separate network
  # object and are computed without tracking gradients.
  target_net.load_state_dict(net.state_dict())
  with torch.no_grad():
    Jnext = target_net.forward(Xnext)
    # Minimize over the input dimension (dim=1); ind holds the argmin indices.
    Jd, ind = torch.min(G + Jnext, dim=1)
  J = net.forward(X)
  loss = criterion(J, Jd)
  loss.backward()
  optimizer.step()

  # Report the loss every 20th epoch.
  if epoch % 20 == 19:
    print('[%d] loss: %.3f' % (epoch + 1, loss.item()))

plot(net)

Here is a similar version, but with states chosen at random on each iteration.

In [None]:
# Fitted value iteration with a fresh random batch of states every epoch:
# sample states, build the one-step Bellman targets with the current weights,
# then take a single SGD step on the squared error.
net = Net()
criterion = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.01)
# Alternative running cost; swap the two assignments below to use it.
#running_cost = min_time_cost
running_cost = quadratic_regulator_cost
timestep = 0.1

us = torch.linspace(-1,1,9)
num_state_samples = 1000
# Buffers reused across epochs and refilled in place on each iteration:
# XwithU has size [num_state_samples, num_input_samples, 1, 2]
# UwithX has size [num_state_samples, num_input_samples, 1, 1]
XwithU = torch.empty((num_state_samples, us.size()[-1], 1, 2))
UwithX = torch.empty((num_state_samples, us.size()[-1], 1, 1))

target_net = Net()
# 1000 iterations when get_ipython() returns a shell, otherwise only 2.
for epoch in range(1000 if get_ipython() else 2):
  # States drawn uniformly from [-3, 3) in each coordinate.
  X = 6*torch.rand((num_state_samples, 1, 2)) - 3.
  X[0, :, :] = torch.zeros(1,2) # make sure zero appears

  # Pair every sampled state with every candidate input.
  for i in range(us.size()[-1]):
    XwithU[:, i, :, :] = X
    UwithX[:, i, :, :] = us[i]

  # One forward-Euler step of the dynamics for every pair, in row-vector form:
  # x_next = x + timestep * (x A^T + u B^T).
  Xnext = XwithU + timestep * (XwithU.matmul(At) + UwithX.matmul(Bt))
  # One-step cost: the running cost scaled by the timestep.
  G = timestep*running_cost(XwithU, UwithX)

  net.zero_grad()
  # Copy the current weights, so the targets come from a separate network
  # object and are computed without tracking gradients.
  target_net.load_state_dict(net.state_dict())
  with torch.no_grad():
    Jnext = target_net.forward(Xnext)
    # Minimize over the input dimension (dim=1); ind holds the argmin indices.
    Jd, ind = torch.min(G + Jnext, dim=1)
  J = net.forward(X)
  loss = criterion(J, Jd)
  loss.backward()
  optimizer.step()

  # Report the loss every 20th epoch.
  if epoch % 20 == 19:
    print('[%d] loss: %.6f' % (epoch + 1, loss.item()))

plot(net)