Skip to content

Commit

Permalink
Tried to support piecewise constant signals but not working well
Browse files Browse the repository at this point in the history
  • Loading branch information
MasWag committed May 23, 2024
1 parent 1b3fd75 commit 08a5ac0
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 55 deletions.
46 changes: 44 additions & 2 deletions matlab/src/main/java/net/maswag/SimulinkModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.mathworks.engine.MatlabEngine;
import de.learnlib.exception.SULException;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import net.automatalib.word.Word;
import org.apache.commons.lang3.ArrayUtils;
Expand Down Expand Up @@ -42,6 +43,8 @@ public double getCurrentTime() {
@Getter
private int counter = 0;
private final TimeMeasure simulationTime = new TimeMeasure();
@Setter
private InterpolationMethod interpolationMethod = InterpolationMethod.LINEAR;

/**
* Setter of simulinkSimulationStep
Expand Down Expand Up @@ -166,6 +169,15 @@ private void makeDataSet(StringBuilder builder) throws ExecutionException, Inter
double[] tmp = inputSignal.dimensionGet(i).stream().mapToDouble(Double::doubleValue).toArray();
matlab.putVariable("tmp" + i, tmp);
builder.append("input").append(i).append(" = timeseries(tmp").append(i).append(", timeVector);");
if (this.interpolationMethod == InterpolationMethod.CONSTANT) {
// Set the interpolation method to zoh https://jp.mathworks.com/help/matlab/ref/timeseries.setinterpmethod.html
//builder.append("input").append(i).append(" = setinterpmethod(input").append(i).append(", 'zoh');");
builder.append("input").append(i).append(".DataInfo.Interpolation = tsdata.interpolation('zoh');");
} else {
// Set the interpolation method to linear https://jp.mathworks.com/help/matlab/ref/timeseries.setinterpmethod.html
//builder.append("input").append(i).append(" = setinterpmethod(input").append(i).append(", 'linear');");
builder.append("input").append(i).append(".DataInfo.Interpolation = tsdata.interpolation('linear');");
}
builder.append("ds = ds.addElement(input").append(i).append(", '").append(paramNames.get(i)).append("');");
}
}
Expand Down Expand Up @@ -241,16 +253,31 @@ private void runSimulation(StringBuilder builder, double stopTime) {
protected double[][] getResult() throws ExecutionException, InterruptedException {
double[][] y;
try {
matlab.eval("ySize = size(y);");
double[] ySize = matlab.getVariable("ySize");
if (this.inputSignal.duration() == 0.0) {
double[] tmpY = matlab.getVariable("y");
double[] tmpY;
if (ySize[1] == 1.0) {
// When the output is one dimensional, we need to convert it to 1D array first.
double tmp = matlab.getVariable("y");
tmpY = new double[]{tmp};
} else {
tmpY = matlab.getVariable("y");
}
if (Objects.isNull(tmpY)) {
log.error("The simulation output is null");
y = null;
} else {
y = new double[][]{tmpY};
}
} else {
y = matlab.getVariable("y");
if (ySize[1] == 1.0) {
// When the output is one dimensional, we need to convert it to 2D array.
double[] tmpY = matlab.getVariable("y");
y = Arrays.stream(tmpY).mapToObj(d -> new double[]{d}).toArray(double[][]::new);
} else {
y = matlab.getVariable("y");
}
}
} catch (Exception e) {
log.error("There was an error in the simulation: {}", e.getMessage());
Expand Down Expand Up @@ -341,4 +368,19 @@ public void close() throws EngineException {
public double getSimulationTimeSecond() {
return this.simulationTime.getSecond();
}

/**
* Enum for the interpolation methods of the input signal
*/
public enum InterpolationMethod {
/**
* Piecewise constant interpolation
* Note: This is not supported in the current version.
*/
CONSTANT,
/**
* Piecewise linear interpolation
*/
LINEAR
}
}
146 changes: 93 additions & 53 deletions matlab/src/test/java/net/maswag/SimulinkModelTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@

import com.mathworks.engine.EngineException;
import net.automatalib.word.Word;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.*;

import java.util.Arrays;
import java.util.List;
Expand All @@ -15,63 +12,106 @@
class SimulinkModelTest {
private SimulinkModel mdl;
private final String PWD = System.getenv("PWD");
private final String initScript = "cd " + PWD + "/src/test/resources/MATLAB; initAFC;";
private final Double signalStep = 2.0;

@BeforeEach
void setUp() throws ExecutionException, InterruptedException {
this.mdl = new SimulinkModel(initScript,
Arrays.asList("Pedal Angle", "Engine Speed"),
signalStep, 0.0025);
this.mdl.setSimulationStep(0.0001);
}

@AfterEach
void tearDown() throws EngineException {
this.mdl.close();
mdl.close();
}

@Test
void execute() throws ExecutionException, InterruptedException {
// Give [80.0, 900.0] by repeating 10 times
List<List<Double>> input = Arrays.asList(
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0));
ValueWithTime<List<Double>> result = this.mdl.execute(Word.fromList(input));
// Test that the first time stamp is 0.0
Assertions.assertEquals(0.0, result.getTimestamps().get(0).doubleValue());
// Test that the last time stamp is 18.0
Assertions.assertEquals(18.0, result.getTimestamps().get(result.getTimestamps().size() - 1).doubleValue());
}
@Nested
class AFC {
private final String initScript = "cd " + PWD + "/src/test/resources/MATLAB; initAFC;";

@BeforeEach
void setUp() throws ExecutionException, InterruptedException {
mdl = new SimulinkModel(initScript,
Arrays.asList("Pedal Angle", "Engine Speed"),
signalStep, 0.0025);
mdl.setSimulationStep(0.0001);
}

@Test
void execute() throws ExecutionException, InterruptedException {
// Give [80.0, 900.0] by repeating 10 times
List<List<Double>> input = Arrays.asList(
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0),
Arrays.asList(80.0, 900.0));
ValueWithTime<List<Double>> result = mdl.execute(Word.fromList(input));
// Test that the first time stamp is 0.0
Assertions.assertEquals(0.0, result.getTimestamps().get(0).doubleValue());
// Test that the last time stamp is 18.0
Assertions.assertEquals(18.0, result.getTimestamps().get(result.getTimestamps().size() - 1).doubleValue());
}

@Test
void step() {
double lastTime = Double.NEGATIVE_INFINITY; // The dummy value for the first case
// Give [80.0, 900.0] for 10 times
List<Double> input = Arrays.asList(80.0, 900.0);
for (int i = 0; i < 10; i++) {
ValueWithTime<List<Double>> result = this.mdl.step(input);
if (lastTime < 0.0) {
// The result only contains the information at time 0
Assertions.assertEquals(1, result.getTimestamps().size());
Assertions.assertEquals(0.0, result.getTimestamps().get(0));
} else {
// Test that the first time stamp is always 0
Assertions.assertEquals(0.0, result.getTimestamps().get(0));
// Test that the last time stamp is 2.0 larger than the latest last time stamp
Assertions.assertEquals(result.getTimestamps().get(result.getTimestamps().size() - 1), lastTime + signalStep);
@Test
void step() {
double lastTime = Double.NEGATIVE_INFINITY; // The dummy value for the first case
// Give [80.0, 900.0] for 10 times
List<Double> input = Arrays.asList(80.0, 900.0);
for (int i = 0; i < 10; i++) {
ValueWithTime<List<Double>> result = mdl.step(input);
if (lastTime < 0.0) {
// The result only contains the information at time 0
Assertions.assertEquals(1, result.getTimestamps().size());
Assertions.assertEquals(0.0, result.getTimestamps().get(0));
} else {
// Test that the first time stamp is always 0
Assertions.assertEquals(0.0, result.getTimestamps().get(0));
// Test that the last time stamp is 2.0 larger than the latest last time stamp
Assertions.assertEquals(result.getTimestamps().get(result.getTimestamps().size() - 1), lastTime + signalStep);
}
lastTime = result.getTimestamps().get(result.getTimestamps().size() - 1);
}
lastTime = result.getTimestamps().get(result.getTimestamps().size() - 1);
// Since the total number of steps is 10, the last time stamp should be 18.0
Assertions.assertEquals(signalStep * 9, lastTime);
}
}

@Nested
class PassThrough {
private final String initScript = "cd " + PWD + "/src/test/resources/MATLAB; mdl = 'pass_through'; load_system(mdl);";

@BeforeEach
void setUp() throws ExecutionException, InterruptedException {
mdl = new SimulinkModel(initScript,
List.of("Input"),
signalStep, 0.0025);
}

@Test
void setInterpolationMethod() throws ExecutionException, InterruptedException {
// We test that the system behaves differently when we use different interpolation methods
// We use the same input signal for both cases
List<List<Double>> input = Arrays.asList(
List.of(100.0),
List.of(0.0));
// First, we run a simulation with piecewise-linear interpolation
mdl.setInterpolationMethod(SimulinkModel.InterpolationMethod.LINEAR);
ValueWithTime<List<Double>> resultLinear = mdl.execute(Word.fromList(input));
mdl.reset();

// Then, we run a simulation with piecewise-constant interpolation
mdl.setInterpolationMethod(SimulinkModel.InterpolationMethod.CONSTANT);
ValueWithTime<List<Double>> resultConstant = mdl.execute(Word.fromList(input));
mdl.reset();

// We test that the two results are different
// Assertions.assertNotEquals(resultLinear.getValues(), resultConstant.getValues());

// We test that this difference is not due to the non-deterministic nature of the simulation
mdl.setInterpolationMethod(SimulinkModel.InterpolationMethod.LINEAR);
ValueWithTime<List<Double>> resultLinear2 = mdl.execute(Word.fromList(input));
mdl.reset();
Assertions.assertEquals(resultLinear.getValues(), resultLinear2.getValues());
}
// Since the total number of steps is 10, the last time stamp should be 18.0
Assertions.assertEquals(this.signalStep * 9, lastTime);
}
}

0 comments on commit 08a5ac0

Please sign in to comment.