Fixed bug with Hessian for functions with multiple outputs.
The Hessian of multi-output functions was wrong due to CasADi's Fortran (column-major) ordering.
Tim-Salzmann committed Nov 30, 2023
1 parent c3b9bb5 commit 53e8340
Showing 6 changed files with 75 additions and 31 deletions.
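
CasADi stores matrices and flattened derivative tensors in column-major (Fortran) order, while PyTorch tensors are row-major. For the last two dimensions the mismatch can be handled by a transpose, but the Jacobian and Hessian of multi-output functions carry extra output dimensions, so all axes have to be reversed before the memory is handed back to CasADi. A minimal NumPy sketch of the equivalence the fix relies on (the shape is illustrative, not the library's actual layout):

import numpy as np

# Mock 4-D derivative tensor of a matrix-valued function: flattening it in
# Fortran order equals reversing *all* axes and flattening in C order --
# reversing only the last two axes is not enough.
deriv = np.arange(2 * 3 * 4 * 5, dtype=float).reshape(2, 3, 4, 5)

fortran_flat = deriv.ravel(order="F")
reversed_axes_c_flat = deriv.transpose(3, 2, 1, 0).ravel(order="C")

assert np.array_equal(fortran_flat, reversed_axes_c_flat)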
1 change: 0 additions & 1 deletion examples/fish_turbulent_flow/trajectory_generation.py
@@ -83,7 +83,6 @@ def trajectory_generator_solver(fU, fV, dt, N, u_lim, T, GT):
}
nlp_opts = {
"ipopt.linear_solver": "mumps",
"ipopt.hessian_approximation": "limited-memory",
"ipopt.sb": "yes",
"ipopt.max_iter": 1000,
"ipopt.tol": 1e-4,
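
With the Hessian fixed, the example no longer needs IPOPT's limited-memory Hessian approximation and can rely on the exact second derivatives supplied through L4CasADi. A minimal sketch of the trimmed solver options (values copied from the context lines above; the full example sets further entries not shown here):

nlp_opts = {
    "ipopt.linear_solver": "mumps",
    "ipopt.sb": "yes",
    "ipopt.max_iter": 1000,
    "ipopt.tol": 1e-4,
}
# typical usage: solver = cs.nlpsol("solver", "ipopt", nlp, nlp_opts)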
22 changes: 8 additions & 14 deletions l4casadi/l4casadi.py
@@ -8,9 +8,9 @@
import casadi as cs
import torch
try:
from torch.func import vmap, jacrev, hessian
from torch.func import jacrev, hessian, functionalize
except ImportError:
from functorch import vmap, jacrev, hessian
from functorch import jacrev, hessian, functionalize
from l4casadi.ts_compiler import ts_compile
from torch.fx.experimental.proxy_tensor import make_fx

@@ -200,16 +200,10 @@ def compile(self):
raise Exception(f'Compilation failed!\n\nAttempted to execute OS command:\n{os_cmd}\n\n')

def _trace_jac_model(self, inp):
if self.has_batch:
return make_fx(vmap(jacrev(self.model)))(inp)
else:
return make_fx(jacrev(self.model))(inp)
return make_fx(functionalize(jacrev(self.model), remove='mutations_and_views'))(inp)

def _trace_hess_model(self, inp):
if self.has_batch:
return make_fx(vmap(hessian(self.model)))(inp)
else:
return make_fx(hessian(self.model))(inp)
return make_fx(functionalize(hessian(self.model), remove='mutations_and_views'))(inp)

def export_torch_traces(self, rows: int, cols: int) -> Tuple[bool, bool]:
if self.has_batch:
@@ -252,13 +246,13 @@ def export_torch_traces(self, rows: int, cols: int) -> Tuple[bool, bool]:
@staticmethod
def _jit_compile_and_save(model, file_path: str, dummy_inp: torch.Tensor):
# TODO: Could switch to torch export https://pytorch.org/docs/stable/export.html
# Try tracing
try:
torch.jit.trace(model, dummy_inp).save(file_path)
except: # noqa
# Try scripting
ts_compile(model).save(file_path)
except: # noqa
# Try tracing
try:
ts_compile(model).save(file_path)
torch.jit.trace(model, dummy_inp).save(file_path)
except: # noqa
return False
return True
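
The tracing helpers no longer special-case a batch dimension with vmap; instead they wrap the derivative functions in functionalize so the exported FX graph contains no mutations or views, and _jit_compile_and_save now tries TorchScript scripting before falling back to tracing. A rough, self-contained sketch of the new tracing path (the model and shapes are placeholders, not l4casadi internals):

import torch
from torch.func import jacrev, hessian, functionalize
from torch.fx.experimental.proxy_tensor import make_fx

model = torch.nn.Linear(3, 2)  # stand-in for the user model
x = torch.rand(3)

# functionalize strips in-place mutations and views from the traced graph,
# which keeps the FX graph produced by make_fx usable for later compilation.
jac_trace = make_fx(functionalize(jacrev(model), remove="mutations_and_views"))(x)
hess_trace = make_fx(functionalize(hessian(model), remove="mutations_and_views"))(x)

print(jac_trace(x).shape)   # (2, 3): one Jacobian row per output
print(hess_trace(x).shape)  # (2, 3, 3): one Hessian per output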
12 changes: 6 additions & 6 deletions libl4casadi/src/l4casadi.cpp
@@ -89,10 +89,10 @@ void L4CasADi::forward(const double* in, int rows, int cols, double* out) {
if (this->model_expects_batch_dim) {
in_tensor = torch::from_blob(( void * )in, {1, rows}, at::kDouble).to(torch::kFloat);
} else {
in_tensor = torch::from_blob(( void * )in, {rows, cols}, at::kDouble).to(torch::kFloat);
in_tensor = torch::from_blob(( void * )in, {cols, rows}, at::kDouble).to(torch::kFloat).permute({1, 0});
}

torch::Tensor out_tensor = this->pImpl->forward(in_tensor).to(torch::kDouble).contiguous();
torch::Tensor out_tensor = this->pImpl->forward(in_tensor).to(torch::kDouble).permute({1, 0}).contiguous();
std::memcpy(out, out_tensor.data_ptr<double>(), out_tensor.numel() * sizeof(double));
}

@@ -101,10 +101,10 @@ void L4CasADi::jac(const double* in, int rows, int cols, double* out) {
if (this->model_expects_batch_dim) {
in_tensor = torch::from_blob(( void * )in, {1, rows}, at::kDouble).to(torch::kFloat);
} else {
in_tensor = torch::from_blob(( void * )in, {rows, cols}, at::kDouble).to(torch::kFloat);
in_tensor = torch::from_blob(( void * )in, {cols, rows}, at::kDouble).to(torch::kFloat).permute({1, 0});
}
// CasADi expects the return in Fortran order -> Transpose last two dimensions
torch::Tensor out_tensor = this->pImpl->jac(in_tensor).to(torch::kDouble).transpose(-2, -1).contiguous();
torch::Tensor out_tensor = this->pImpl->jac(in_tensor).to(torch::kDouble).permute({3, 2, 1, 0}).contiguous();
std::memcpy(out, out_tensor.data_ptr<double>(), out_tensor.numel() * sizeof(double));
}

@@ -113,11 +113,11 @@ void L4CasADi::hess(const double* in, int rows, int cols, double* out) {
if (this->model_expects_batch_dim) {
in_tensor = torch::from_blob(( void * )in, {1, rows}, at::kDouble).to(torch::kFloat);
} else {
in_tensor = torch::from_blob(( void * )in, {rows, cols}, at::kDouble).to(torch::kFloat);
in_tensor = torch::from_blob(( void * )in, {cols, rows}, at::kDouble).to(torch::kFloat).permute({1, 0});
}

// CasADi expects the return in Fortran order -> Transpose last two dimensions
torch::Tensor out_tensor = this->pImpl->hess(in_tensor).to(torch::kDouble).transpose(-2, -1).contiguous();
torch::Tensor out_tensor = this->pImpl->hess(in_tensor).to(torch::kDouble).permute({5, 4, 3, 2, 1, 0}).contiguous();
std::memcpy(out, out_tensor.data_ptr<double>(), out_tensor.numel() * sizeof(double));
}
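
On the C++ side the raw buffer coming from CasADi is column-major, so in the non-batched case the input is now read with swapped dimensions and permuted back instead of being interpreted as (rows, cols) directly, and the outputs are permuted into full reverse order before the memcpy. A roughly equivalent PyTorch-level sketch of the input handling (dimensions are illustrative):

import torch

rows, cols = 3, 4
mat = torch.arange(rows * cols, dtype=torch.float64).reshape(rows, cols)

# What CasADi hands over: the matrix flattened in column-major order.
fortran_buffer = mat.t().contiguous().flatten()

# The fix: read the buffer as (cols, rows) row-major data, then permute,
# which recovers the intended (rows, cols) matrix.
recovered = fortran_buffer.reshape(cols, rows).permute(1, 0)
assert torch.equal(recovered, mat)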

65 changes: 58 additions & 7 deletions tests/test_l4casadi.py
@@ -31,7 +31,7 @@ def forward(self, x):

class TrigModel(torch.nn.Module):
def forward(self, x):
return torch.stack([torch.sin(x[:1]), torch.cos(x[1:2])], dim=-1)
return torch.concat([torch.sin(x[:1]), torch.cos(x[1:2])], dim=0)


class TestL4CasADi:
@@ -52,15 +52,48 @@ def test_l4casadi_deep_model(self, deep_model):

l4c_out = l4c.L4CasADi(deep_model, model_expects_batch_dim=True)(rand_inp.transpose(-2, -1).detach().numpy())

np.allclose(l4c_out, torch_out.transpose(-2, -1).detach().numpy())
assert np.allclose(l4c_out, torch_out.transpose(-2, -1).detach().numpy())

def test_l4casadi_triag_model(self, triag_model):
rand_inp = torch.rand((12, 12))
torch_out = triag_model(rand_inp)

l4c_out = l4c.L4CasADi(triag_model, model_expects_batch_dim=False)(rand_inp.detach().numpy())

np.allclose(l4c_out, torch_out.detach().numpy())
assert np.allclose(l4c_out, torch_out.detach().numpy())

def test_l4casadi_triag_model_jac(self, triag_model):
rand_inp = torch.rand((12, 12))
torch_out = torch.func.jacrev(triag_model)(rand_inp)

mx_inp = cs.MX.sym('x', 12, 12)

jac_fun = cs.Function('f_jac',
[mx_inp],
[cs.jacobian(l4c.L4CasADi(triag_model, model_expects_batch_dim=False)(mx_inp), mx_inp)])

l4c_out = jac_fun(rand_inp.detach().numpy())

assert np.allclose(
np.moveaxis(np.array(l4c_out).reshape((12, 2, 12, 12)), (0, 1, 2, 3), (1, 0, 3, 2)), # Reshape in Fortran
torch_out.detach().numpy())

def test_l4casadi_triag_model_hess_double_jac(self, triag_model):
rand_inp = torch.rand((12, 12))
torch_out = torch.func.hessian(triag_model)(rand_inp)

mx_inp = cs.MX.sym('x', 12, 12)

hess_fun = cs.Function('f_hess_double_jac',
[mx_inp],
[cs.jacobian(
cs.jacobian(
l4c.L4CasADi(triag_model, model_expects_batch_dim=False)(mx_inp), mx_inp
)[0, 0], mx_inp)])

l4c_out = hess_fun(rand_inp.transpose(-2, -1).detach().numpy())

assert np.allclose(np.reshape(l4c_out, (12, 12)).transpose((-2, -1)), torch_out[0, 0, 0, 0].detach().numpy())

def test_l4casadi_deep_model_jac(self, deep_model):
rand_inp = torch.rand((1, deep_model.input_layer.in_features))
@@ -74,20 +74,38 @@ def test_l4casadi_deep_model_jac(self, deep_model):

l4c_out = jac_fun(rand_inp.transpose(-2, -1).detach().numpy())

np.allclose(l4c_out, torch_out.detach().numpy())
assert np.allclose(l4c_out, torch_out.detach().numpy())

def test_l4casadi_deep_model_hess(self):
deep_model = DeepModel(1, 1)
deep_model = DeepModel(4, 1)
rand_inp = torch.rand((1, deep_model.input_layer.in_features))
torch_out = torch.func.vmap(torch.func.hessian(deep_model))(rand_inp)[0]

mx_inp = cs.MX.sym('x', deep_model.input_layer.in_features, 1)

hess_fun = cs.Function('f_jac',
hess_fun = cs.Function('f_hess',
[mx_inp],
[cs.hessian(l4c.L4CasADi(deep_model, model_expects_batch_dim=True)(mx_inp), mx_inp)[0]])

l4c_out = hess_fun(rand_inp.transpose(-2, -1).detach().numpy())

np.allclose(l4c_out, torch_out.detach().numpy())
assert np.allclose(l4c_out, torch_out.detach().numpy())

def test_l4casadi_deep_model_hess_double_jac(self):
deep_model = DeepModel(4, 2)
rand_inp = torch.rand((1, deep_model.input_layer.in_features))
torch_out = torch.func.vmap(torch.func.hessian(deep_model))(rand_inp)[0]

mx_inp = cs.MX.sym('x', deep_model.input_layer.in_features, 1)

hess_fun = cs.Function('f_hess_double_jac',
[mx_inp],
[cs.jacobian(
cs.jacobian(
l4c.L4CasADi(deep_model, model_expects_batch_dim=True)(mx_inp), mx_inp
)[0], mx_inp)])

l4c_out = hess_fun(rand_inp.transpose(-2, -1).detach().numpy())

assert np.allclose(l4c_out, torch_out[0, 0].detach().numpy())

2 changes: 1 addition & 1 deletion tests/test_naive_l4casadi.py
@@ -16,4 +16,4 @@ def test_naive_l4casadi_mlp(self):

l4c_out = l4c.L4CasADi(naive_mlp, model_expects_batch_dim=True)(cs_inp)

np.allclose(l4c_out, torch_out.transpose(-2, -1).detach().numpy())
assert np.allclose(l4c_out, torch_out.transpose(-2, -1).detach().numpy())
4 changes: 2 additions & 2 deletions tests/test_realtime_l4casadi.py
@@ -53,7 +53,7 @@ def test_realtime_l4casadi_deep_model(self, deep_model):

l4c_out = cs_fun(rand_inp.transpose(-2, -1).detach().numpy(), params)

np.allclose(l4c_out, torch_out.transpose(-2, -1).detach().numpy())
assert np.allclose(l4c_out, torch_out.transpose(-2, -1).detach().numpy())

def test_realtime_l4casadi_deep_model_second_order(self, deep_model):
rand_inp = torch.rand((1, deep_model.input_layer.in_features))
@@ -71,5 +71,5 @@

l4c_out = cs_fun(rand_inp.transpose(-2, -1).detach().numpy(), params)

np.allclose(l4c_out, torch_out.transpose(-2, -1).detach().numpy())
assert np.allclose(l4c_out, torch_out.transpose(-2, -1).detach().numpy())
