# Load Jars + Imports

In [16]:
%jars target/trustyai-drools-1.0-SNAPSHOT.jar
%jars target/lib/*

In [7]:
import drools_integrators.DroolsWrapper;
import rulebases.buspass.Person;
import org.kie.api.KieServices;
import org.kie.api.runtime.KieContainer;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualConfig;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualExplainer;
import org.kie.kogito.explainability.local.counterfactual.CounterfactualResult;
import org.kie.kogito.explainability.local.counterfactual.SolverConfigBuilder;
import org.kie.kogito.explainability.local.counterfactual.entities.CounterfactualEntity;
import org.kie.kogito.explainability.model.CounterfactualPrediction;
import org.kie.kogito.explainability.model.Feature;
import org.kie.kogito.explainability.model.Output;
import org.kie.kogito.explainability.model.Prediction;
import org.kie.kogito.explainability.model.PredictionInput;
import org.kie.kogito.explainability.model.PredictionOutput;
import org.kie.kogito.explainability.model.PredictionProvider;
import org.kie.kogito.explainability.model.Type;
import org.kie.kogito.explainability.model.Value;
import org.kie.kogito.explainability.model.domain.NumericalFeatureDomain;
import org.optaplanner.core.config.solver.EnvironmentMode;
import org.optaplanner.core.config.solver.SolverConfig;
import org.optaplanner.core.config.solver.termination.TerminationConfig;

import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Supplier;
import java.util.stream.Collectors;

In [8]:
private CounterfactualResult runCounterfactualSearch(Long randomSeed, List<Output> goal,
                                                     List<Feature> features,
                                                     PredictionProvider model,
                                                     double goalThresold) throws InterruptedException, ExecutionException, TimeoutException {
    final TerminationConfig terminationConfig = new TerminationConfig().withScoreCalculationCountLimit(30_000L);
    final SolverConfig solverConfig = SolverConfigBuilder
            .builder().withTerminationConfig(terminationConfig).build();
    solverConfig.setRandomSeed(randomSeed);
    solverConfig.setEnvironmentMode(EnvironmentMode.REPRODUCIBLE);
    final CounterfactualConfig counterfactualConfig = new CounterfactualConfig();
    counterfactualConfig.withSolverConfig(solverConfig).withGoalThreshold(goalThresold);
    final CounterfactualExplainer explainer = new CounterfactualExplainer(counterfactualConfig);
    final PredictionInput input = new PredictionInput(features);
    PredictionOutput output = new PredictionOutput(goal);
    Prediction prediction =
            new CounterfactualPrediction(input,
                    output,
                    null,
                    UUID.randomUUID(),
                    600L);
    return explainer.explainAsync(prediction, model)
            .get(11L, TimeUnit.MINUTES);
}

# Buspass Demo

In [21]:
KieServices ks = KieServices.Factory.get();
KieContainer kieContainer = ks.getKieClasspathContainer();

// build the function to supply objects into the model
Supplier<List<Object>> objectSupplier = () -> {
    Person p = new Person("Yoda", 10);
    return List.of(p);
};

// initialize the wrapper
DroolsWrapper droolsWrapper = new DroolsWrapper(kieContainer,"BussPassGoodKS", objectSupplier);

// setup Feature extraction
droolsWrapper.displayFeatureCandidates();
droolsWrapper.setFeatureExtractorFilters(List.of("(age)"));
droolsWrapper.displayFeatureCandidates();
for (Feature f: droolsWrapper.featureExtractor(objectSupplier.get()).keySet()) {
    droolsWrapper.addFeatureDomain(f.getName(), NumericalFeatureDomain.create(0., 100.));
}
PredictionInput samplePI = new PredictionInput(new ArrayList<>(droolsWrapper.featureExtractor(objectSupplier.get()).keySet()));
droolsWrapper.generateOutputCandidates(true);
droolsWrapper.selectOutputIndicesFromCandidates(List.of(10));

// wrap model into predictionprovider
PredictionProvider wrappedModel = droolsWrapper.wrap();
System.out.println("== Original Output ==");
wrappedModel.predictAsync(List.of(samplePI)).get().get(0).getOutputs().get(0).getValue();

// run counterfactual
List<Output> goal = new ArrayList<>();
goal.add(new Output("rulebases.buspass.ChildBusPass_5", Type.CATEGORICAL, new Value("Not Created"), 0.0d));
try{
    CounterfactualResult result = runCounterfactualSearch(0L, goal, samplePI.getFeatures(), wrappedModel, .01);
} catch (Exception e) {
    e.printStackTrace();
}
System.out.println(result.getEntities().stream().map(CounterfactualEntity::asFeature).collect(Collectors.toList()));
System.out.println(result.isValid());
System.out.println(result.getOutput().get(0).getOutputs());

                                                  Feature |  Value
---------------------------------------------------------------------
 rulebases.buspass.Person_0_rulebases.buspass.Person.name |   Yoda
  rulebases.buspass.Person_0_rulebases.buspass.Person.age |     10
                                                 Feature |  Value
--------------------------------------------------------------------
 rulebases.buspass.Person_0_rulebases.buspass.Person.age |     10
 Index |                 Rule |                                  Field Name |  Final Value
---------------------------------------------------------------------------------------
     0 |    IssueChildBusPass |                  rulebases.buspass.Person_1 |      Created
     1 |    IssueChildBusPass |                rulebases.buspass.Person.age |           10
     2 |    IssueChildBusPass |       rulebases.buspass.IsChild.person.name |         Yoda
     3 |    IssueChildBusPass |                 rulebases.buspass.IsChild_4 

java.util.concurrent.ExecutionException: java.lang.IllegalArgumentException: The scoreDirectorFactory lacks a configuration for an easyScoreCalculatorClass or an incrementalScoreCalculatorClass.
	at java.base/java.util.concurrent.CompletableFuture.reportGet(CompletableFuture.java:395)
	at java.base/java.util.concurrent.CompletableFuture.get(CompletableFuture.java:2022)
	at REPL.$JShell$48.runCounterfactualSearch($JShell$48.java:66)
	at REPL.$JShell$163.do_it$($JShell$163.java:47)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at io.github.spencerpark.ijava.execution.IJavaExecutionControl.lambda$execute$1(IJavaExecutionControl.java:95)
	at java.base/java.util.concurrent.FutureTask.r

EvalException: null