Skip to content
Permalink
Browse files

add eval_para_sens to matlab mex interface

  • Loading branch information...
giaf committed Aug 21, 2019
1 parent 57d320e commit 273c8499b26b96de172f4a65287142df677c62ce
@@ -109,7 +109,7 @@ ACADOS_WITH_OPENMP = 0
ACADOS_NUM_THREADS = 4

# include QPOASES
ACADOS_WITH_QPOASES = 1
ACADOS_WITH_QPOASES = 0

# include HPMPC
ACADOS_WITH_HPMPC = 0
@@ -896,11 +896,13 @@ int main()
ocp_nlp_opts_update(config, dims, nlp_opts);

/************************************************
* ocp_nlp out
* ocp_nlp_out & solver
************************************************/

ocp_nlp_out *nlp_out = ocp_nlp_out_create(config, dims);

ocp_nlp_out *sens_nlp_out = ocp_nlp_out_create(config, dims);

ocp_nlp_solver *solver = ocp_nlp_solver_create(config, dims, nlp_opts);

/************************************************
@@ -978,7 +980,10 @@ int main()
// solve NLP
status = ocp_nlp_solve(solver, nlp_in, nlp_out);

// config->eval_param_sens(config, dims, solver->opts, solver->mem, solver->work, "ex", 0, 0, nlp_out);
// evaluate parametric sensitivity of solution
// ocp_nlp_out_print(dims, nlp_out);
ocp_nlp_eval_param_sens(solver, "ex", 0, 0, sens_nlp_out);
// ocp_nlp_out_print(dims, nlp_out);

// update initial condition
// TODO(dimitris): maybe simulate system instead of passing x[1] as next state
@@ -1072,6 +1077,7 @@ int main()
ocp_nlp_opts_destroy(nlp_opts);
ocp_nlp_in_destroy(nlp_in);
ocp_nlp_out_destroy(nlp_out);
ocp_nlp_out_destroy(sens_nlp_out);
ocp_nlp_solver_destroy(solver);
ocp_nlp_dims_destroy(dims);
ocp_nlp_config_destroy(config);
@@ -75,8 +75,8 @@
qp_solver_ric_alg = 0;
qp_solver_warm_start = 2;
%sim_method = 'erk';
%sim_method = 'irk';
sim_method = 'irk_gnsf';
sim_method = 'irk';
%sim_method = 'irk_gnsf';
sim_method_num_stages = 4;
sim_method_num_steps = 3;
cost_type = 'linear_ls';
@@ -323,10 +323,10 @@

%% figures

% for ii=1:N+1
% x_cur = x(:,ii);
% visualize;
% end
for ii=1:N+1
x_cur = x(:,ii);
% visualize;
end

figure(2);
subplot(2,1,1);
@@ -364,6 +364,45 @@
end


% paramteric sensitivity of solution

field = 'ex'; % equality constraint on states
stage = 0;
index = 0;
ocp.eval_param_sens(field, stage, index);

sens_u = ocp.get('sens_u');
sens_x = ocp.get('sens_x');

% plot sensitivity
figure(4);
subplot(2,1,1);
plot(0:N, sens_x);
xlim([0 N]);
legend('p', 'theta', 'v', 'omega');
subplot(2,1,2);
plot(0:N-1, sens_u);
xlim([0 N]);
legend('F');

% plot predicted solution
figure(5);
subplot(2,1,1);
plot(0:N, x+sens_x);
xlim([0 N]);
legend('p', 'theta', 'v', 'omega');
subplot(2,1,2);
plot(0:N-1, u+sens_u);
xlim([0 N]);
legend('F');

for ii=1:N+1
x_cur = x(:,ii)+sens_x(:,ii);
% visualize;
end



waitforbuttonpress;


@@ -92,6 +92,12 @@ function solve(obj)



function eval_param_sens(obj, field, stage, index)
ocp_eval_param_sens(obj.C_ocp, field, stage, index);
end



% function set(obj, field, value)
% ocp_set(obj.model_struct, obj.opts_struct, obj.C_ocp, obj.C_ocp_ext_fun, field, value);
% end
@@ -55,6 +55,7 @@ function ocp_compile_mex(opts)
'ocp_precompute', ...
'ocp_set', ...
'ocp_get' ...
'ocp_eval_param_sens', ...
};
mex_files = cell(length(mex_names), 1);
for k=1:length(mex_names)
@@ -828,30 +828,33 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
/* LHS */

// field names of output struct
char *fieldnames[6];
char *fieldnames[7];
fieldnames[0] = (char*) mxMalloc(50);
fieldnames[1] = (char*) mxMalloc(50);
fieldnames[2] = (char*) mxMalloc(50);
fieldnames[3] = (char*) mxMalloc(50);
fieldnames[4] = (char*) mxMalloc(50);
fieldnames[5] = (char*) mxMalloc(50);
fieldnames[6] = (char*) mxMalloc(50);

memcpy(fieldnames[0],"config",sizeof("config"));
memcpy(fieldnames[1],"dims",sizeof("dims"));
memcpy(fieldnames[2],"opts",sizeof("opts"));
memcpy(fieldnames[3],"in",sizeof("in"));
memcpy(fieldnames[4],"out",sizeof("out"));
memcpy(fieldnames[5],"solver",sizeof("solver"));
memcpy(fieldnames[6],"sens_out",sizeof("sens_out"));

// create output struct
plhs[0] = mxCreateStructMatrix(1, 1, 6, (const char **) fieldnames);
plhs[0] = mxCreateStructMatrix(1, 1, 7, (const char **) fieldnames);

mxFree( fieldnames[0] );
mxFree( fieldnames[1] );
mxFree( fieldnames[2] );
mxFree( fieldnames[3] );
mxFree( fieldnames[4] );
mxFree( fieldnames[5] );
mxFree( fieldnames[6] );



@@ -1967,6 +1970,12 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])



/* sens_out */

ocp_nlp_out *sens_out = ocp_nlp_out_create(config, dims);



/* populate output struct */

// config
@@ -2005,6 +2014,12 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
l_ptr[0] = (long long) solver;
mxSetField(plhs[0], 0, "solver", solver_mat);

// sens_out
mxArray *sens_out_mat = mxCreateNumericMatrix(1, 1, mxINT64_CLASS, mxREAL);
l_ptr = mxGetData(sens_out_mat);
l_ptr[0] = (long long) sens_out;
mxSetField(plhs[0], 0, "sens_out", sens_out_mat);



/* return */
@@ -75,6 +75,9 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
// solver
ptr = (long long *) mxGetData( mxGetField( prhs[0], 0, "solver" ) );
ocp_nlp_solver *solver = (ocp_nlp_solver *) ptr[0];
// sens_out
ptr = (long long *) mxGetData( mxGetField( prhs[0], 0, "sens_out" ) );
ocp_nlp_out *sens_out = (ocp_nlp_out *) ptr[0];



@@ -86,6 +89,7 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
ocp_nlp_in_destroy(in);
ocp_nlp_out_destroy(out);
ocp_nlp_solver_destroy(solver);
ocp_nlp_out_destroy(sens_out);



@@ -0,0 +1,88 @@
/*
* Copyright 2019 Gianluca Frison, Dimitris Kouzoupis, Robin Verschueren,
* Andrea Zanelli, Niels van Duijkeren, Jonathan Frey, Tommaso Sartor,
* Branimir Novoselnik, Rien Quirynen, Rezart Qelibari, Dang Doan,
* Jonas Koenemann, Yutao Chen, Tobias Schöls, Jonas Schlagenhauf, Moritz Diehl
*
* This file is part of acados.
*
* The 2-Clause BSD License
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
* this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
* POSSIBILITY OF SUCH DAMAGE.;
*/

// system
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
// acados
//#include "acados/sim/sim_common.h"
#include "acados_c/ocp_nlp_interface.h"
// mex
#include "mex.h"



void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{

// mexPrintf("\nin ocp_solve\n");

long long *ptr;

/* RHS */

// C_ocp

// solver
ptr = (long long *) mxGetData( mxGetField( prhs[0], 0, "solver" ) );
ocp_nlp_solver *solver = (ocp_nlp_solver *) ptr[0];
// sens_out
ptr = (long long *) mxGetData( mxGetField( prhs[0], 0, "sens_out" ) );
ocp_nlp_out *sens_out = (ocp_nlp_out *) ptr[0];

// field
char *field = mxArrayToString( prhs[1] );

// stage
int stage = mxGetScalar( prhs[2] );

// index
int index = mxGetScalar( prhs[3] );



/* solver */
ocp_nlp_eval_param_sens(solver, field, stage, index, sens_out);



/* return */

return;

}




@@ -67,6 +67,9 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
// solver
ptr = (long long *) mxGetData( mxGetField( prhs[0], 0, "solver" ) );
ocp_nlp_solver *solver = (ocp_nlp_solver *) ptr[0];
// sens_out
ptr = (long long *) mxGetData( mxGetField( prhs[0], 0, "sens_out" ) );
ocp_nlp_out *sens_out = (ocp_nlp_out *) ptr[0];

// field
char *field = mxArrayToString( prhs[1] );
@@ -152,6 +155,78 @@ void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
return;
}
}
else if(!strcmp(field, "sens_x"))
{
if(nrhs==2)
{
plhs[0] = mxCreateNumericMatrix(nx, N+1, mxDOUBLE_CLASS, mxREAL);
double *x = mxGetPr( plhs[0] );
for(ii=0; ii<=N; ii++)
{
ocp_nlp_out_get(config, dims, sens_out, ii, "x", x+ii*nx);
}
}
else if(nrhs==3)
{
plhs[0] = mxCreateNumericMatrix(nx, 1, mxDOUBLE_CLASS, mxREAL);
double *x = mxGetPr( plhs[0] );
int stage = mxGetScalar( prhs[2] );
ocp_nlp_out_get(config, dims, sens_out, stage, "x", x);
}
else
{
mexPrintf("\nocp_get: wrong nrhs: %d\n", nrhs);
return;
}
}
else if(!strcmp(field, "sens_u"))
{
if(nrhs==2)
{
plhs[0] = mxCreateNumericMatrix(nu, N, mxDOUBLE_CLASS, mxREAL);
double *u = mxGetPr( plhs[0] );
for(ii=0; ii<N; ii++)
{
ocp_nlp_out_get(config, dims, sens_out, ii, "u", u+ii*nu);
}
}
else if(nrhs==3)
{
plhs[0] = mxCreateNumericMatrix(nu, 1, mxDOUBLE_CLASS, mxREAL);
double *u = mxGetPr( plhs[0] );
int stage = mxGetScalar( prhs[2] );
ocp_nlp_out_get(config, dims, sens_out, stage, "u", u);
}
else
{
mexPrintf("\nocp_get: wrong nrhs: %d\n", nrhs);
return;
}
}
else if(!strcmp(field, "sens_pi"))
{
if(nrhs==2)
{
plhs[0] = mxCreateNumericMatrix(nx, N, mxDOUBLE_CLASS, mxREAL);
double *pi = mxGetPr( plhs[0] );
for(ii=0; ii<N; ii++)
{
ocp_nlp_out_get(config, dims, sens_out, ii, "pi", pi+ii*nx);
}
}
else if(nrhs==3)
{
plhs[0] = mxCreateNumericMatrix(nx, 1, mxDOUBLE_CLASS, mxREAL);
double *pi = mxGetPr( plhs[0] );
int stage = mxGetScalar( prhs[2] );
ocp_nlp_out_get(config, dims, sens_out, stage, "pi", pi);
}
else
{
mexPrintf("\nocp_get: wrong nrhs: %d\n", nrhs);
return;
}
}
else if(!strcmp(field, "status"))
{
plhs[0] = mxCreateNumericMatrix(1, 1, mxDOUBLE_CLASS, mxREAL);

0 comments on commit 273c849

Please sign in to comment.
You can’t perform that action at this time.