Skip to content

Commit

Permalink
Added Octave support to generic external cost interface
Browse files Browse the repository at this point in the history
  • Loading branch information
mindThomas committed Oct 9, 2020
1 parent 399c6a9 commit 45744ff
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 172 deletions.
Expand Up @@ -73,17 +73,31 @@
ocp_model.set('cost_type', 'ext_cost');
ocp_model.set('cost_type_e', 'ext_cost');

generic_or_casadi = 1; % 0=generic, 1=casadi
generic_or_casadi = 0; % 0=generic, 1=casadi, 2=mixed
if (generic_or_casadi == 0)
ocp_model.set('ext_fun_type', 'generic');
% Generic stage cost
ocp_model.set('ext_fun_type', 'generic');
ocp_model.set('cost_source_ext_cost', 'generic_ext_cost.c');
ocp_model.set('cost_function_ext_cost', 'ext_cost');
% Generic terminal cost
ocp_model.set('ext_fun_type_e', 'generic');
ocp_model.set('cost_source_ext_cost_e', 'generic_ext_cost.c');
ocp_model.set('cost_function_ext_cost_e', 'ext_costN');
else
elseif (generic_or_casadi == 1)
% Casadi stage cost
ocp_model.set('ext_fun_type', 'casadi');
ocp_model.set('cost_expr_ext_cost', model.expr_ext_cost);
% Casadi terminal cost
ocp_model.set('ext_fun_type_e', 'casadi');
ocp_model.set('cost_expr_ext_cost_e', model.expr_ext_cost_e);
elseif (generic_or_casadi == 2)
% Generic stage cost
ocp_model.set('ext_fun_type', 'generic');
ocp_model.set('cost_source_ext_cost', 'generic_ext_cost.c');
ocp_model.set('cost_function_ext_cost', 'ext_cost');
% Casadi terminal cost
ocp_model.set('ext_fun_type_e', 'casadi');
ocp_model.set('cost_expr_ext_cost_e', model.expr_ext_cost_e);
end

% dynamics
Expand Down
Expand Up @@ -3,7 +3,9 @@ extern "C" {
#endif

#include <math.h>
#include "acados_c/ocp_nlp_interface.h"
#include "acados/utils/external_function_generic.h"
#include "blasfeo_d_blas.h"
#include "blasfeo_d_aux.h"

//void ext_cost(void *ext_fun, ext_fun_arg_t *type_in, void **in, ext_fun_arg_t *type_out, void **out)
void ext_cost(void **in, void **out, void *params)
Expand Down
7 changes: 5 additions & 2 deletions interfaces/acados_matlab_octave/acados_ocp_model.m
Expand Up @@ -46,6 +46,7 @@
% default values
obj.model_struct.name = 'ocp_model';
obj.model_struct.ext_fun_type = 'casadi'; % generic
obj.model_struct.ext_fun_type_e = 'casadi'; % generic
obj.model_struct.cost_type = 'auto';
obj.model_struct.cost_type_e = 'auto';
obj.model_struct.dyn_type = 'implicit';
Expand Down Expand Up @@ -94,13 +95,13 @@
obj.model_struct.ext_fun_type = 'casadi'
elseif (strcmp(field, 'cost_expr_ext_cost_e'))
obj.model_struct.cost_expr_ext_cost_e = value;
obj.model_struct.ext_fun_type = 'casadi';
obj.model_struct.ext_fun_type_e = 'casadi';
elseif (strcmp(field, 'cost_source_ext_cost'))
obj.model_struct.cost_source_ext_cost = value;
obj.model_struct.ext_fun_type = 'generic';
elseif (strcmp(field, 'cost_source_ext_cost_e'))
obj.model_struct.cost_source_ext_cost_e = value;
obj.model_struct.ext_fun_type = 'generic';
obj.model_struct.ext_fun_type_e = 'generic';
elseif (strcmp(field, 'cost_function_ext_cost'))
obj.model_struct.cost_function_ext_cost = value;
elseif (strcmp(field, 'cost_function_ext_cost_e'))
Expand Down Expand Up @@ -347,6 +348,8 @@
obj.model_struct.name = value;
elseif (strcmp(field, 'ext_fun_type'))
obj.model_struct.ext_fun_type = value;
elseif (strcmp(field, 'ext_fun_type_e'))
obj.model_struct.ext_fun_type_e = value;
elseif (strcmp(field, 'T'))
obj.model_struct.T = value;
else
Expand Down
8 changes: 4 additions & 4 deletions interfaces/acados_matlab_octave/generate_c_code_ext_cost.m
Expand Up @@ -77,7 +77,7 @@ function generate_c_code_ext_cost( model, opts, target_dir )

model_name = model.name;

if isfield(model, 'cost_expr_ext_cost')
if isfield(model, 'cost_expr_ext_cost') && strcmp(model.ext_fun_type, 'casadi')
ext_cost = model.cost_expr_ext_cost;
% generate jacobians
jac_x = jacobian(ext_cost, x);
Expand All @@ -89,7 +89,7 @@ function generate_c_code_ext_cost( model, opts, target_dir )
hes_xx = jacobian(jac_x', x);
% Set up functions
ext_cost_fun = Function([model_name,'_cost_ext_cost_fun'], {x, u, p}, {ext_cost});
ext_cost_fun_jac = Function([model_name,'_cost_ext_cost_fun_jac'], {x, u, p}, {ext_cost});
ext_cost_fun_jac = Function([model_name,'_cost_ext_cost_fun_jac'], {x, u, p}, {ext_cost, [jac_u'; jac_x']});
ext_cost_fun_jac_hess = Function([model_name,'_cost_ext_cost_fun_jac_hess'], {x, u, p},...
{ext_cost, [jac_u'; jac_x'], [hes_uu, hes_xu; hes_ux, hes_xx]});
% generate C code
Expand All @@ -98,12 +98,12 @@ function generate_c_code_ext_cost( model, opts, target_dir )
ext_cost_fun_jac.generate([model_name,'_cost_ext_cost_fun_jac'], casadi_opts);
end

if isfield(model, 'cost_expr_ext_cost_e')
if isfield(model, 'cost_expr_ext_cost_e') && strcmp(model.ext_fun_type_e, 'casadi')
ext_cost_e = model.cost_expr_ext_cost_e;
% generate jacobians
jac_x_e = jacobian(ext_cost_e, x);
% generate hessians
hes_xx_e = jacobian(jac_x', x);
hes_xx_e = jacobian(jac_x_e', x);
% Set up functions
ext_cost_e_fun = Function([model_name,'_cost_ext_cost_e_fun'], {x, p}, {ext_cost_e});
ext_cost_e_fun_jac = Function([model_name,'_cost_ext_cost_e_fun_jac'], {x, p}, {ext_cost_e, jac_x_e'});
Expand Down
66 changes: 45 additions & 21 deletions interfaces/acados_matlab_octave/ocp_destroy_ext_fun.c
Expand Up @@ -61,6 +61,10 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
if (mxGetField( matlab_model, 0, "ext_fun_type" )!=NULL)
ext_fun_type = mxArrayToString( mxGetField( matlab_model, 0, "ext_fun_type" ) );

char *ext_fun_type_e;
if (mxGetField( matlab_model, 0, "ext_fun_type_e" )!=NULL)
ext_fun_type_e = mxArrayToString( mxGetField( matlab_model, 0, "ext_fun_type_e" ) );

// dims
ptr = (long long *) mxGetData( mxGetField( prhs[1], 0, "dims" ) );
ocp_nlp_dims *dims = (ocp_nlp_dims *) ptr[0];
Expand All @@ -69,48 +73,68 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])

// XXX hard-code number and size of phases for now !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
int NN[] = {N, 1}; // number of phases, i.e. shooting nodes with same dimensions
int Nf = 2; // number of phases
int Nf;

//
int struct_size = mxGetNumberOfFields( prhs[2] );
for (ii=0; ii<struct_size; ii++)
{
// printf("\n%s\n", mxGetFieldNameByNumber( prhs[2], ii) );
//printf("\n%s\n", mxGetFieldNameByNumber( prhs[2], ii) );
mex_field = mxGetFieldByNumber( prhs[2], 0, ii );
ptr = (long long *) mxGetData( mex_field );
Nf = mxGetN( mex_field );
for (jj=0; jj<Nf; jj++)
{
// external function param casadi
if (!strcmp(ext_fun_type, "casadi"))

if (!strcmp(mxGetFieldNameByNumber(prhs[2], ii), "cost_ext_cost_fun") ||
!strcmp(mxGetFieldNameByNumber(prhs[2], ii), "cost_ext_cost_fun_jac_hess")) {

for (jj=0; jj<Nf; jj++)
{
external_function_param_casadi *ext_fun_ptr = (external_function_param_casadi *) ptr[jj];
if (ext_fun_ptr!=0)
{
for (kk=0; kk<NN[jj]; kk++)
// external function param casadi
if ((jj == 0 && !strcmp(ext_fun_type, "casadi")) ||
(jj == 1 && !strcmp(ext_fun_type_e, "casadi")))
{
external_function_param_casadi_free(ext_fun_ptr+kk);
external_function_param_casadi *ext_fun_ptr = (external_function_param_casadi *) ptr[jj];
if (ext_fun_ptr!=0)
{
for (kk=0; kk<NN[jj]; kk++)
{
external_function_param_casadi_free(ext_fun_ptr+kk);
}
free(ext_fun_ptr);
}
}
// external function param generic
else if ((jj == 0 && !strcmp(ext_fun_type, "generic")) ||
(jj == 1 && !strcmp(ext_fun_type_e, "generic")))
{
external_function_param_generic *ext_fun_ptr = (external_function_param_generic *) ptr[jj];
if (ext_fun_ptr!=0)
{
for (kk=0; kk<NN[jj]; kk++)
{
external_function_param_generic_free(ext_fun_ptr+kk);
}
free(ext_fun_ptr);
}
}
else
{
MEX_FIELD_VALUE_NOT_SUPPORTED_SUGGEST(fun_name, "ext_fun_type", ext_fun_type, "casadi, generic");
}
free(ext_fun_ptr);
}
}
// external function param generic
else if (!strcmp(ext_fun_type, "generic"))
} else {
for (jj=0; jj<Nf; jj++)
{
external_function_param_generic *ext_fun_ptr = (external_function_param_generic *) ptr[jj];
external_function_param_casadi *ext_fun_ptr = (external_function_param_casadi *) ptr[jj];
if (ext_fun_ptr!=0)
{
for (kk=0; kk<NN[jj]; kk++)
{
external_function_param_generic_free(ext_fun_ptr+kk);
external_function_param_casadi_free(ext_fun_ptr+kk);
}
free(ext_fun_ptr);
}
}
else
{
MEX_FIELD_VALUE_NOT_SUPPORTED_SUGGEST(fun_name, "ext_fun_type", ext_fun_type, "casadi, generic");
}
}
}

Expand Down
47 changes: 31 additions & 16 deletions interfaces/acados_matlab_octave/ocp_generate_casadi_ext_fun.m
Expand Up @@ -35,6 +35,13 @@ function ocp_generate_casadi_ext_fun(model_struct, opts_struct)

model_name = model_struct.name;

% get acados folder
acados_folder = getenv('ACADOS_INSTALL_DIR');

% set paths
acados_include = ['-I' acados_folder];
blasfeo_include = ['-I' fullfile(acados_folder, 'external' , 'blasfeo', 'include')];

% select files to compile
c_files = {};
% dynamics
Expand Down Expand Up @@ -132,20 +139,20 @@ function ocp_generate_casadi_ext_fun(model_struct, opts_struct)
end

% external cost
if (strcmp(model_struct.ext_fun_type, 'casadi') && (strcmp(model_struct.cost_type, 'ext_cost') || strcmp(model_struct.cost_type_e, 'ext_cost')))
% generate c for function and derivatives using casadi
if (strcmp(opts_struct.codgen_model, 'true'))
generate_c_code_ext_cost(model_struct, opts_struct);
end
% sources list
if isfield(model_struct, 'cost_expr_ext_cost')
c_files{end+1} = [model_name, '_cost_ext_cost_fun.c'];
c_files{end+1} = [model_name, '_cost_ext_cost_fun_jac_hess.c'];
end
if isfield(model_struct, 'cost_expr_ext_cost_e')
c_files{end+1} = [model_name, '_cost_ext_cost_e_fun.c'];
c_files{end+1} = [model_name, '_cost_ext_cost_e_fun_jac_hess.c'];
end
if (strcmp(opts_struct.codgen_model, 'true') && ...
((strcmp(model_struct.ext_fun_type, 'casadi') && strcmp(model_struct.cost_type, 'ext_cost')) || ...
(strcmp(model_struct.ext_fun_type_e, 'casadi') && strcmp(model_struct.cost_type_e, 'ext_cost'))))
% generate c for function and derivatives using casadi
generate_c_code_ext_cost(model_struct, opts_struct);
end
% external cost sources list
if (strcmp(model_struct.cost_type, 'ext_cost') && strcmp(model_struct.ext_fun_type, 'casadi') && isfield(model_struct, 'cost_expr_ext_cost'))
c_files{end+1} = [model_name, '_cost_ext_cost_fun.c'];
c_files{end+1} = [model_name, '_cost_ext_cost_fun_jac_hess.c'];
end
if (strcmp(model_struct.cost_type_e, 'ext_cost') && strcmp(model_struct.ext_fun_type_e, 'casadi') && isfield(model_struct, 'cost_expr_ext_cost_e'))
c_files{end+1} = [model_name, '_cost_ext_cost_e_fun.c'];
c_files{end+1} = [model_name, '_cost_ext_cost_e_fun_jac_hess.c'];
end

if ispc
Expand All @@ -167,10 +174,18 @@ function ocp_generate_casadi_ext_fun(model_struct, opts_struct)
c_files_path{k} = fullfile(opts_struct.output_dir, c_files{k});
end

% generic external cost
if (strcmp(model_struct.cost_type, 'ext_cost') && strcmp(model_struct.ext_fun_type, 'generic') && isfield(model_struct, 'cost_source_ext_cost') && isfield(model_struct, 'cost_function_ext_cost'))
c_files_path{end+1} = model_struct.cost_source_ext_cost;
end
if (strcmp(model_struct.cost_type_e, 'ext_cost') && strcmp(model_struct.ext_fun_type_e, 'generic') && isfield(model_struct, 'cost_source_ext_cost_e') && isfield(model_struct, 'cost_function_ext_cost_e'))
c_files_path{end+1} = model_struct.cost_source_ext_cost_e;
end

if ispc
mbuild(c_files_path{:}, '-output', lib_name, 'CFLAGS="$CFLAGS"', 'LDTYPE="-shared"', ['LDEXT=', ldext]);
mbuild(unique(c_files_path{:}), '-output', lib_name, 'CFLAGS="$CFLAGS"', 'LDTYPE="-shared"', ['LDEXT=', ldext]);
else
system(['gcc -O2 -fPIC -shared ', strjoin(c_files_path, ' '), ' -o ', [lib_name, ldext]]);
system(['gcc -O2 -fPIC -shared ', acados_include, ' ', blasfeo_include, ' ', strjoin(unique(c_files_path), ' '), ' -o ', [lib_name, ldext]]);
end

movefile([lib_name, ldext], fullfile(opts_struct.output_dir, [lib_name, ldext]));
Expand Down

0 comments on commit 45744ff

Please sign in to comment.