Ported ML-CasADi functionality to RealTimeL4CasADi.
Tim-Salzmann committed Sep 9, 2023
1 parent 139232c commit 62219f6
Showing 8 changed files with 396 additions and 11 deletions.
1 change: 1 addition & 0 deletions l4casadi/__init__.py
@@ -2,6 +2,7 @@
import ctypes

from .l4casadi import L4CasADi, dynamic_lib_file_ending
from .realtime import RealTimeL4CasADi


file_dir = files('l4casadi')
30 changes: 19 additions & 11 deletions l4casadi/l4casadi.py
@@ -22,8 +22,8 @@ class L4CasADi(object):
def __init__(self,
model: Callable[[torch.Tensor], torch.Tensor],
model_expects_batch_dim: bool = True,
- device: Union[torch.device, Text] = "cpu",
- name: Text = "l4casadi_f",
+ device: Union[torch.device, Text] = 'cpu',
+ name: Text = 'l4casadi_f',
build_dir: Text = './_l4c_generated'):
"""
:param model: PyTorch model.
@@ -44,7 +44,7 @@ def __init__(self,

self.build_dir = pathlib.Path(build_dir)

- self._ext_cs_fun: Optional[cs.Function] = None
+ self._cs_fun: Optional[cs.Function] = None
self._built = False

def __call__(self, *args):
Expand All @@ -62,7 +62,7 @@ def forward(self, inp: Union[cs.MX, cs.SX, cs.DM]):
if not self._built:
self.build(inp)

- out = self._ext_cs_fun(inp)  # type: ignore[misc]
+ out = self._cs_fun(inp)  # type: ignore[misc]

return out

@@ -93,7 +93,7 @@ def build(self, inp: Union[cs.MX, cs.SX, cs.DM]) -> None:
self.generate_cpp_function_template(rows, cols, has_jac, has_hess)
self.compile_cs_function()

- self._ext_cs_fun = cs.external(
+ self._cs_fun = cs.external(
f'{self.name}',
f"{self.build_dir / f'lib{self.name}'}{dynamic_lib_file_ending()}"
)
@@ -153,6 +153,18 @@ def compile_cs_function(self):
if status != 0:
raise Exception(f'Compilation failed!\n\nAttempted to execute OS command:\n{os_cmd}\n\n')

def _trace_jac_model(self, inp):
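# Trace the derivative transform into a torch.fx graph with make_fx so it can
# be JIT-compiled and exported alongside the forward trace.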
if self.has_batch:
return make_fx(vmap(jacrev(self.model)))(inp)
else:
return make_fx(jacrev(self.model))(inp)

def _trace_hess_model(self, inp):
if self.has_batch:
return make_fx(vmap(hessian(self.model)))(inp)
else:
return make_fx(hessian(self.model))(inp)

def export_torch_traces(self, rows: int, cols: int) -> Tuple[bool, bool]:
if self.has_batch:
d_inp = torch.zeros((1, rows))
@@ -164,12 +176,8 @@ def export_torch_traces(self, rows: int, cols: int) -> Tuple[bool, bool]:

torch.jit.trace(self.model, d_inp).save((out_folder / f'{self.name}_forward.pt').as_posix())

- if self.has_batch:
-     jac_model = make_fx(vmap(jacrev(self.model)))(d_inp)
-     hess_model = make_fx(vmap(hessian(self.model)))(d_inp)
- else:
-     jac_model = make_fx(jacrev(self.model))(d_inp)
-     hess_model = make_fx(hessian(self.model))(d_inp)
+ jac_model = self._trace_jac_model(d_inp)
+ hess_model = self._trace_hess_model(d_inp)

exported_jacrev = self._jit_compile_and_save(
jac_model,
18 changes: 18 additions & 0 deletions l4casadi/realtime/README.md
@@ -0,0 +1,18 @@
# Real-time L4CasADi
This is the underlying framework enabling Real-time Neural-MPC in our paper:
```
Real-time Neural-MPC: Deep Learning Model Predictive Control for Quadrotors and Agile Robotic Platforms
```
[arXiv link](https://arxiv.org/pdf/2203.07747)
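
## Example

A minimal usage sketch distilled from `examples/readme.py` in this commit; the `torch.nn.Linear` model here is a hypothetical stand-in for any `torch.nn.Module` mapping 2 inputs to 1 output:

```python
import casadi as cs
import numpy as np
import torch
import l4casadi as l4c

pyTorch_model = torch.nn.Linear(2, 1)  # any torch.nn.Module works

l4c_model = l4c.RealTimeL4CasADi(pyTorch_model, approximation_order=1)

x_sym = cs.MX.sym('x', 2, 1)
y_sym = l4c_model(x_sym)  # symbolic Taylor approximation of the model

# The approximation is parametric: evaluate its parameters at a set point,
# then call the resulting CasADi function with input and parameters.
casadi_func = cs.Function('model_rt_approx',
                          [x_sym, l4c_model.get_sym_params()],
                          [y_sym])
casadi_param = l4c_model.get_params(np.ones([1, 2]))  # torch expects a batch dim
y = casadi_func(np.ones([2, 1]), casadi_param)
```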

## Citing
If you use our work, please cite our paper:
```
@article{salzmann2023neural,
title={Real-time Neural-MPC: Deep Learning Model Predictive Control for Quadrotors and Agile Robotic Platforms},
author={Salzmann, Tim and Kaufmann, Elia and Arrizabalaga, Jon and Pavone, Marco and Scaramuzza, Davide and Ryll, Markus},
journal={IEEE Robotics and Automation Letters},
doi={10.1109/LRA.2023.3246839},
year={2023}
}
```
1 change: 1 addition & 0 deletions l4casadi/realtime/__init__.py
@@ -0,0 +1 @@
from .realtime_l4casadi import RealTimeL4CasADi
48 changes: 48 additions & 0 deletions l4casadi/realtime/examples/readme.py
@@ -0,0 +1,48 @@
import l4casadi as l4c
import casadi as cs
import numpy as np
import torch


class MultiLayerPerceptron(torch.nn.Module):
def __init__(self):
super().__init__()

self.input_layer = torch.nn.Linear(2, 512)

hidden_layers = []
for _ in range(20):
hidden_layers.append(torch.nn.Linear(512, 512))

self.hidden_layer = torch.nn.ModuleList(hidden_layers)
self.out_layer = torch.nn.Linear(512, 1)

def forward(self, x):
x = self.input_layer(x)
for layer in self.hidden_layer:
x = torch.tanh(layer(x))
x = self.out_layer(x)
return x


pyTorch_model = MultiLayerPerceptron()

size_in = 2
size_out = 1
l4c_model = l4c.RealTimeL4CasADi(pyTorch_model, approximation_order=1)  # or approximation_order=2 for a quadratic approximation

x_sym = cs.MX.sym('x', 2, 1)
y_sym = l4c_model(x_sym)

casadi_func = cs.Function('model_rt_approx',
[x_sym, l4c_model.get_sym_params()],
[y_sym])

x = np.ones([1, size_in]) # torch needs batch dimension
casadi_param = l4c_model.get_params(x)
casadi_out = casadi_func(x.transpose((-1, -2)), casadi_param)  # transpose to the column-vector representation expected by casadi

t_out = pyTorch_model(torch.tensor(x, dtype=torch.float32))

print(casadi_out)
print(t_out)
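
# The Taylor parameters are cheap to re-evaluate online: querying get_params at
# a new set point updates the approximation without rebuilding the CasADi
# function (a sketch reusing the variables defined above).
x_new = np.zeros([1, size_in])
casadi_param_new = l4c_model.get_params(x_new)
casadi_out_new = casadi_func(x_new.transpose((-1, -2)), casadi_param_new)
print(casadi_out_new)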
151 changes: 151 additions & 0 deletions l4casadi/realtime/realtime_l4casadi.py
@@ -0,0 +1,151 @@
from typing import Union, Callable, Optional, Text, List

import casadi as cs
import numpy as np
import torch

from l4casadi import L4CasADi
from .sensitivities import batched_jacobian, batched_hessian


class RealTimeL4CasADi(L4CasADi):
def __init__(self,
model: Callable[[torch.Tensor], torch.Tensor],
approximation_order: int = 1,
device: Union[torch.device, Text] = 'cpu',
name: Text = 'rt_l4casadi_f'):
"""
:param model: PyTorch model.
:param approximation_order: Order of the Taylor approximation. 1 for linearization, 2 for quadratic
approximation.
:param device: Device on which the PyTorch model is executed.
:param name: Unique name of the generated L4CasADi model. This name is used for autogenerated files.
Creating two L4CasADi models with the same name will result in overwriting the files of the first model.
"""
super().__init__(model, model_expects_batch_dim=True, device=device, name=name)

if approximation_order > 2 or approximation_order < 1:
raise ValueError("Taylor approximation order must be 1 or 2.")

self._approximation_order = approximation_order
self._taylor_params = None

@property
def order(self):
return self._approximation_order

def _init_taylor_params(self, rows_in: int, rows_out: int):
a = cs.MX.sym('a', rows_in, 1)
f_a = cs.MX.sym('f_a', rows_out, 1)
df_a = cs.MX.sym('df_a', rows_out, rows_in)

if self.order == 2:
ddf_as = []
for i in range(rows_out):
ddf_a_i = cs.MX.sym(f'ddf_a_{i}', rows_in, rows_in)
ddf_as.append(ddf_a_i)
return a, f_a, df_a, ddf_as
else:
return a, f_a, df_a

@property
def sym_params(self):
return self._flatten_taylor_params(self._taylor_params)

def get_sym_params(self):
if len(self.sym_params) == 0:
return cs.vertcat([])
return cs.vcat([cs.reshape(mx, np.prod(mx.shape), 1) for mx in self.sym_params])

def _get_params(self, a_t: torch.Tensor):
if len(a_t.shape) == 1:
a_t = a_t.unsqueeze(0)
if self._approximation_order == 1:
df_a, f_a = batched_jacobian(self.model, a_t, return_func_output=True)
return [a_t.cpu().numpy(), f_a.cpu().numpy(), df_a.transpose(-2, -1).cpu().numpy()]
elif self._approximation_order == 2:
ddf_a, df_a, f_a = batched_hessian(self.model, a_t, return_func_output=True, return_jacobian=True)
return ([a_t.cpu().numpy(), f_a.cpu().numpy(), df_a.transpose(-2, -1).cpu().numpy()]
+ [ddf_a[:, i].transpose(-2, -1).cpu().numpy() for i in range(ddf_a.shape[1])])
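# The transposes above align row-major (C-order) flattening on the numpy side
# with CasADi's column-major reshape in get_sym_params, so numeric parameters
# are packed in the same order as the symbols: [a, f(a), J(a), H_0, ..., H_n].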

def get_params(self, a: Union[np.ndarray, torch.Tensor]):
a_t = torch.tensor(a).float().to(self.device)
params = self._get_params(a_t)

if len(params) == 0:
return np.array([])
if len(a.shape) > 1:
return np.hstack([p.reshape(p.shape[0], -1) for p in params])
return np.hstack([p.flatten() for p in params])

def taylor_approx(self, x: cs.MX, a: cs.MX, f_a: cs.MX, df_a: cs.MX, ddf_a: Optional[List[cs.MX]] = None, parallel: bool = False):
"""
Approximation of the model using a first- or second-order Taylor expansion around the set point ``a``.
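
The implemented expansion is
    order 1:  f(x) ≈ f(a) + J_f(a) (x - a)
    order 2:  f(x) ≈ f(a) + J_f(a) (x - a) + 1/2 [(x - a)^T H_i (x - a)]_i
where J_f(a) is the Jacobian and H_i the Hessian of output component i, both
evaluated at the set point a.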
"""
x_minus_a = x - a
if ddf_a is None:
return (f_a
+ cs.mtimes(df_a, x_minus_a))
else:
if parallel:
# Use an OpenMP-parallelized CasADi map to compute the second-order Taylor term for all output dims

def second_order_oi_term(x_minus_a, f_ddf_a):
return cs.mtimes(cs.transpose(x_minus_a), cs.mtimes(f_ddf_a, x_minus_a))

ddf_a_expl = ddf_a[0]  # exemplar Hessian symbol fixing the shape of the mapped function's input
x_minus_a_exp = cs.MX.sym('x_minus_a_exp', x_minus_a.shape[0], x_minus_a.shape[1])
second_order_term_oi_fun = cs.Function('second_order_term_fun',
[x_minus_a_exp, ddf_a_expl],
[second_order_oi_term(x_minus_a_exp, ddf_a_expl)])

n_o = f_a.shape[0]

second_order_term_fun = second_order_term_oi_fun.map(n_o, 'openmp')

x_minus_a_rep = cs.repmat(x_minus_a, 1, n_o)
f_ddf_a_stack = cs.hcat(ddf_a)

second_order_term = 0.5 * cs.transpose(second_order_term_fun(x_minus_a_rep, f_ddf_a_stack))
else:
f_ddf_as = ddf_a
second_order_term = 0.5 * cs.vcat(
[cs.mtimes(cs.transpose(x_minus_a), cs.mtimes(f_ddf_a, x_minus_a))
for f_ddf_a in f_ddf_as])

return (f_a
+ cs.mtimes(df_a, x_minus_a)
+ second_order_term)

@staticmethod
def _flatten_taylor_params(taylor_params):
flat_params = list()
for param in taylor_params:
if isinstance(param, cs.MX):
flat_params.append(param)
else:
flat_params.extend(param)
return flat_params

def build(self, inp: Union[cs.MX, cs.SX, cs.DM]) -> None:
rows, cols = inp.shape # type: ignore[attr-defined]
rows_out = self.model(torch.zeros(1, rows).to(self.device)).shape[-1]

self._taylor_params = self._init_taylor_params(rows, rows_out)
self._cs_fun = cs.Function(
f'taylor_approx_{self.name}',
[inp] + self.sym_params,
[self.taylor_approx(inp, *self._taylor_params)])

self._built = True

def forward(self, inp: Union[cs.MX, cs.SX, cs.DM]):
if inp.shape[-1] != 1:  # type: ignore[attr-defined]
raise ValueError("RealTimeL4CasADi only accepts vector inputs.")

if not self._built:
self.build(inp)

out = self._cs_fun(inp, *self.sym_params) # type: ignore[misc]

return out
83 changes: 83 additions & 0 deletions l4casadi/realtime/sensitivities.py
@@ -0,0 +1,83 @@
from typing import Callable

import torch
import torch.func as functorch


def aux_function(func):
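"""Wrap ``func`` so it returns its output twice; torch.func's ``has_aux``
mechanism then hands the function value back alongside the derivative."""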
def inner_aux(inputs):
out = func(inputs)
return out, out

return inner_aux


def batched_jacobian(func: Callable, inputs: torch.Tensor, create_graph=False, return_func_output=False):
r"""Function that computes batches of the Jacobian of a given function and a batch of inputs.
Args:
func: a Python function that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
inputs: inputs to the function ``func``. First dimension is treated as batch dimension
create_graph: If ``True``, the Jacobian will be computed in a differentiable manner.
return_func_output: If ``True``, the function output will be returned.
Returns:Jacobian
"""

if not create_graph:
with torch.no_grad():
if not return_func_output:
return functorch.vmap(functorch.jacrev(func))(inputs)
return functorch.vmap(functorch.jacrev(aux_function(func), has_aux=True))(inputs)
else:
if not return_func_output:
return functorch.vmap(functorch.jacrev(func))(inputs)
return functorch.vmap(functorch.jacrev(aux_function(func), has_aux=True))(inputs)


def batched_hessian(func: Callable, inputs: torch.Tensor, create_graph=False,
return_jacobian=False, return_func_output=False):
r"""
Args:
func: a Python function that takes Tensor inputs and returns a tuple of Tensors or a Tensor.
inputs: inputs to the function ``func``. First dimension is treated as batch dimension
create_graph: If ``True``, the Hessian will be computed in a differentiable manner.
return_jacobian: If ``True``, the Jacobian will be returned.
return_func_output: If ``True``, the function output will be returned.
Returns: Hessian
"""
def aux_function_jac(func):
def inner_aux(inputs):
out = func(inputs)
return out[0], (out[0], out[1])
return inner_aux

if not create_graph:
with torch.no_grad():
if not return_func_output and not return_jacobian:
return functorch.vmap(functorch.jacrev(functorch.jacrev(func)))(inputs)
elif not return_func_output and return_jacobian:
return functorch.vmap(functorch.jacrev(aux_function_jac(functorch.jacrev(func)), has_aux=True))(inputs)
elif return_func_output and not return_jacobian:
return functorch.vmap(functorch.jacrev(functorch.jacrev(aux_function(func), has_aux=True)))(inputs)
elif return_func_output and return_jacobian:
(hessian, (jacobian, value)) = functorch.vmap(
functorch.jacrev(aux_function_jac(functorch.jacrev(aux_function(func), has_aux=True)),
has_aux=True))(inputs)
return hessian, jacobian, value
else:
if not return_func_output and not return_jacobian:
return functorch.vmap(functorch.jacrev(functorch.jacrev(func)))(inputs)
elif not return_func_output and return_jacobian:
return functorch.vmap(functorch.jacrev(aux_function_jac(functorch.jacrev(func)), has_aux=True))(inputs)
elif return_func_output and not return_jacobian:
return functorch.vmap(functorch.jacrev(functorch.jacrev(aux_function(func), has_aux=True)))(inputs)
elif return_func_output and return_jacobian:
(hessian, (jacobian, value)) = functorch.vmap(
functorch.jacrev(aux_function_jac(functorch.jacrev(aux_function(func), has_aux=True)), has_aux=True))(
inputs)
return hessian, jacobian, value
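
# Usage sketch (shapes follow torch.func conventions; ``model`` and ``x`` are
# hypothetical stand-ins):
#   model = torch.nn.Linear(3, 2)
#   x = torch.randn(5, 3)  # batch of 5 inputs
#   jac, out = batched_jacobian(model, x, return_func_output=True)
#   # jac.shape == (5, 2, 3), out.shape == (5, 2)
#   hess, jac, out = batched_hessian(model, x, return_jacobian=True, return_func_output=True)
#   # hess.shape == (5, 2, 3, 3)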
