diff --git a/drools-scorecards/src/main/java/org/drools/scorecards/DroolsScorecard.java b/drools-scorecards/src/main/java/org/drools/scorecards/DroolsScorecard.java index 9ea309cdefd..0755a52acea 100644 --- a/drools-scorecards/src/main/java/org/drools/scorecards/DroolsScorecard.java +++ b/drools-scorecards/src/main/java/org/drools/scorecards/DroolsScorecard.java @@ -56,35 +56,10 @@ public void setCalculatedScore(double calculatedScore) { this.calculatedScore = calculatedScore; } - public void sortReasonCodes() { - - } - -// public void addPartialScore(int partialScore) { -// this.calculatedScore += partialScore; -// } -// -// public void setInitialScore(int initialScore) { -// this.calculatedScore = initialScore; -// } - public void setInitialScore(double initialScore) { this.calculatedScore = initialScore; } -// public void addPartialScore(double partialScore) { -// this.calculatedScore += partialScore; -// } -// -// public void addPartialScore(String field, double partialScore, String reasonCode) { -// this.calculatedScore += partialScore; -// reasonCodes.add(reasonCode); -// } - -// public void addReasonCode(String reasonCode){ -// reasonCodes.add(reasonCode); -// } -// public List getReasonCodes() { return Collections.unmodifiableList(reasonCodes); } @@ -93,26 +68,40 @@ public void setReasonCodes(List reasonCodes) { this.reasonCodes = reasonCodes; } - public void sortReasonCodes(List partialScores) { + public void sortAndSetReasonCodes(List partialScores) { + sortAndSetReasonCodes(reasonCodeAlgorithm, partialScores); + } + + public void sortAndSetReasonCodes(int reasonCodeAlgorithm, List partialScores) { + setReasonCodeAlgorithm(reasonCodeAlgorithm); TreeMap distanceMap = new TreeMap(); for (PartialScore partialScore : partialScores ){ - if (baselineScoreMap.get(partialScore.getCharacteristic()) != null ) { - double baseline = baselineScoreMap.get(partialScore.getCharacteristic()); - double distance = 0; - if (getReasonCodeAlgorithm() == REASON_CODE_ALGORITHM_POINTSABOVE) { - distance = (baseline - partialScore.getScore())+partialScore.getPosition(); + double baseline = partialScore.getBaselineScore(); + double distance = 0; + if (getReasonCodeAlgorithm() == REASON_CODE_ALGORITHM_POINTSABOVE) { + distance = (baseline - partialScore.getScore())+partialScore.getPosition(); + if( distance >= baseline) { distanceMap.put(distance, partialScore.getReasoncode()); - } else if (getReasonCodeAlgorithm() == REASON_CODE_ALGORITHM_POINTSBELOW){ - distance = (partialScore.getScore()-baseline)+partialScore.getPosition(); + } + } else if (getReasonCodeAlgorithm() == REASON_CODE_ALGORITHM_POINTSBELOW){ + distance = (partialScore.getScore()-baseline)+partialScore.getPosition(); + if( distance <= baseline) { distanceMap.put(distance, partialScore.getReasoncode()); } } } + List reasonCodes = new ArrayList(); for ( Double distance : distanceMap.descendingKeySet()) { - System.out.println(distance+" "+distanceMap.get(distance)); + reasonCodes.add(distanceMap.get(distance)); + } + while (reasonCodes.size() < partialScores.size()){ + reasonCodes.add(reasonCodes.get(reasonCodes.size()-1)); } + setReasonCodes(reasonCodes); } + + public DroolsScorecard() { } } diff --git a/drools-scorecards/src/main/java/org/drools/scorecards/PartialScore.java b/drools-scorecards/src/main/java/org/drools/scorecards/PartialScore.java index 94effc3c94b..cc49af5aa2e 100644 --- a/drools-scorecards/src/main/java/org/drools/scorecards/PartialScore.java +++ b/drools-scorecards/src/main/java/org/drools/scorecards/PartialScore.java @@ -20,6 +20,7 @@ public class PartialScore extends BaselineScore implements Serializable { protected String reasoncode; protected int position; + protected double baselineScore; public PartialScore(String scorecardName, String characteristic, double score, String reasoncode, int position) { super(scorecardName, characteristic, score); @@ -27,6 +28,13 @@ public PartialScore(String scorecardName, String characteristic, double score, S this.position = position; } + public PartialScore(String scorecardName, String characteristic, double score, String reasoncode, double baselineScore, int position) { + super(scorecardName, characteristic, score); + this.reasoncode = reasoncode; + this.position = position; + this.baselineScore = baselineScore; + } + public PartialScore(String scorecardName, String characteristic, double score) { super(scorecardName, characteristic, score); this.scorecardName = scorecardName; @@ -45,4 +53,16 @@ public String getReasoncode() { public void setReasoncode(String reasoncode) { this.reasoncode = reasoncode; } + + public double getBaselineScore() { + return baselineScore; + } + + public void setBaselineScore(double baselineScore) { + this.baselineScore = baselineScore; + } + + public void setPosition(int position) { + this.position = position; + } } diff --git a/drools-scorecards/src/main/java/org/drools/scorecards/ScoringStrategy.java b/drools-scorecards/src/main/java/org/drools/scorecards/ScoringStrategy.java new file mode 100644 index 00000000000..689c7449560 --- /dev/null +++ b/drools-scorecards/src/main/java/org/drools/scorecards/ScoringStrategy.java @@ -0,0 +1,23 @@ +/* + * Copyright 2012 JBoss Inc + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.drools.scorecards; + +public enum ScoringStrategy { + AGGREGATE_SCORE, AVERAGE_SCORE, MAXIMUM_SCORE, MINIMUM_SCORE, + WEIGHTED_AGGREGATE_SCORE, WEIGHTED_AVERAGE_SCORE, WEIGHTED_MAXIMUM_SCORE, WEIGHTED_MINIMUM_SCORE + +} diff --git a/drools-scorecards/src/main/java/org/drools/scorecards/drl/AbstractDRLEmitter.java b/drools-scorecards/src/main/java/org/drools/scorecards/drl/AbstractDRLEmitter.java index 4cdec06cc69..1ee602ad744 100644 --- a/drools-scorecards/src/main/java/org/drools/scorecards/drl/AbstractDRLEmitter.java +++ b/drools-scorecards/src/main/java/org/drools/scorecards/drl/AbstractDRLEmitter.java @@ -17,6 +17,7 @@ import org.dmg.pmml.pmml_4_1.descr.*; import org.drools.core.util.StringUtils; +import org.drools.scorecards.ScoringStrategy; import org.drools.scorecards.parser.xls.XLSKeywords; import org.drools.scorecards.pmml.PMMLExtensionNames; import org.drools.scorecards.pmml.PMMLOperators; @@ -100,11 +101,15 @@ public String emitDRL( PMML pmml ) { private void addImports( PMML pmml, Package aPackage ) { String importsFromDelimitedString = ScorecardPMMLUtils.getExtensionValue( pmml.getHeader().getExtensions(), PMMLExtensionNames.SCORECARD_IMPORTS ); - if ( !( importsFromDelimitedString == null || importsFromDelimitedString.isEmpty() ) ) { - for ( String importStatement : importsFromDelimitedString.split( "," ) ) { + if ( StringUtils.isEmpty(importsFromDelimitedString) ) { + Import imp = new Import(); + imp.setClassName("java.util.*"); + aPackage.addImport(imp); + } else { + for (String importStatement : importsFromDelimitedString.split(",")) { Import imp = new Import(); - imp.setClassName( importStatement ); - aPackage.addImport( imp ); + imp.setClassName(importStatement); + aPackage.addImport(imp); } } Import defaultScorecardImport = new Import(); @@ -155,9 +160,9 @@ protected List createRuleList( PMML pmmlDocument ) { if ( desc != null ) { rule.setDescription( desc ); } - attributePosition++; - populateLHS( rule, pmmlDocument, scorecard, c, scoreAttribute ); + populateLHS(rule, pmmlDocument, scorecard, c, scoreAttribute); populateRHS( rule, pmmlDocument, scorecard, c, scoreAttribute, attributePosition ); + attributePosition++; ruleList.add( rule ); } } @@ -176,7 +181,9 @@ protected void createInitialRule( List ruleList, rule.setDescription( "set the initial score" ); Condition condition = createInitialRuleCondition( scorecard, objectClass ); - rule.addCondition( condition ); + if ( condition != null) { + rule.addCondition(condition); + } if ( scorecard.getInitialScore() > 0 ) { Consequence consequence = new Consequence(); //consequence.setSnippet("$sc.setInitialScore(" + scorecard.getInitialScore() + ");"); @@ -201,16 +208,16 @@ protected void createInitialRule( List ruleList, } } } - if ( scorecard.getReasonCodeAlgorithm() != null ) { - Consequence consequence = new Consequence(); - if ( "pointsAbove".equalsIgnoreCase( scorecard.getReasonCodeAlgorithm() ) ) { - //TODO: ReasonCode Algorithm - consequence.setSnippet( "//$sc.setReasonCodeAlgorithm(DroolsScorecard.REASON_CODE_ALGORITHM_POINTSABOVE);" ); - } else if ( "pointsBelow".equalsIgnoreCase( scorecard.getReasonCodeAlgorithm() ) ) { - consequence.setSnippet( "//$sc.setReasonCodeAlgorithm(DroolsScorecard.REASON_CODE_ALGORITHM_POINTSBELOW);" ); - } - rule.addConsequence( consequence ); - } +// if ( scorecard.getReasonCodeAlgorithm() != null ) { +// Consequence consequence = new Consequence(); +// if ( "pointsAbove".equalsIgnoreCase( scorecard.getReasonCodeAlgorithm() ) ) { +// //TODO: ReasonCode Algorithm +// consequence.setSnippet( "//$sc.setReasonCodeAlgorithm(DroolsScorecard.REASON_CODE_ALGORITHM_POINTSABOVE);" ); +// } else if ( "pointsBelow".equalsIgnoreCase( scorecard.getReasonCodeAlgorithm() ) ) { +// consequence.setSnippet( "//$sc.setReasonCodeAlgorithm(DroolsScorecard.REASON_CODE_ALGORITHM_POINTSBELOW);" ); +// } +// rule.addConsequence( consequence ); +// } } ruleList.add( rule ); } @@ -331,13 +338,21 @@ protected void populateRHS( Rule rule, String setter = "insertLogical(new PartialScore(\""; String field = ScorecardPMMLUtils.extractFieldNameFromCharacteristic( c ); - stringBuilder.append( setter ).append( objectClass ).append( "\",\"" ).append( field ).append( "\"," ).append( scoreAttribute.getPartialScore() ); + //stringBuilder.append( setter ).append( objectClass ).append( "\",\"" ).append( field ).append( "\"," ).append( scoreAttribute.getPartialScore() ); + ScoringStrategy scoringStrategy = getScoringStrategy(scorecard); + if ( scoringStrategy.toString().startsWith("WEIGHTED")) { + String weight = ScorecardPMMLUtils.getExtensionValue(scoreAttribute.getExtensions(), PMMLExtensionNames.CHARACTERTISTIC_WEIGHT); + stringBuilder.append(setter).append(objectClass).append("\",\"").append(field).append("\",(").append(scoreAttribute.getPartialScore()).append("*").append(weight).append(")"); + } else { + stringBuilder.append(setter).append(objectClass).append("\",\"").append(field).append("\",").append(scoreAttribute.getPartialScore()); + } if ( scorecard.isUseReasonCodes() ) { String reasonCode = scoreAttribute.getReasonCode(); if ( reasonCode == null || StringUtils.isEmpty( reasonCode ) ) { reasonCode = c.getReasonCode(); } - stringBuilder.append( ",\"" ).append( reasonCode ).append( "\", " ).append( position ); + stringBuilder.append(",\"").append(reasonCode).append("\", ").append(c.getBaselineScore()); + stringBuilder.append(",").append(position); } stringBuilder.append( "));" ); consequence.setSnippet( stringBuilder.toString() ); @@ -350,7 +365,30 @@ protected void createSummationRules( List ruleList, Rule calcTotalRule = new Rule( objectClass + "_calculateTotalScore", 1, 1 ); StringBuilder stringBuilder = new StringBuilder(); Condition condition = new Condition(); - stringBuilder.append( "$calculatedScore : Double() from accumulate (PartialScore(scorecardName ==\"" ).append( objectClass ).append( "\", $partialScore:score), sum($partialScore))" ); + //stringBuilder.append( "$calculatedScore : Double() from accumulate (PartialScore(scorecardName ==\"" ).append( objectClass ).append( "\", $partialScore:score), sum($partialScore))" ); + ScoringStrategy strategy = getScoringStrategy(scorecard); + switch (strategy) { + case WEIGHTED_AGGREGATE_SCORE: + case AGGREGATE_SCORE: { + stringBuilder.append("$calculatedScore : Double() from accumulate (PartialScore(scorecardName ==\"").append(objectClass).append("\", $partialScore:score), sum($partialScore))"); + break; + } + case WEIGHTED_AVERAGE_SCORE: + case AVERAGE_SCORE:{ + stringBuilder.append("$calculatedScore : Double() from accumulate (PartialScore(scorecardName ==\"").append(objectClass).append("\", $partialScore:score), average($partialScore))"); + break; + } + case WEIGHTED_MAXIMUM_SCORE: + case MAXIMUM_SCORE:{ + stringBuilder.append("$calculatedScore : Double() from accumulate (PartialScore(scorecardName ==\"").append(objectClass).append("\", $partialScore:score), max($partialScore))"); + break; + } + case WEIGHTED_MINIMUM_SCORE: + case MINIMUM_SCORE:{ + stringBuilder.append("$calculatedScore : Double() from accumulate (PartialScore(scorecardName ==\"").append(objectClass).append("\", $partialScore:score), min($partialScore))"); + break; + } + } condition.setSnippet( stringBuilder.toString() ); calcTotalRule.addCondition( condition ); if ( scorecard.getInitialScore() > 0 ) { @@ -368,7 +406,8 @@ protected void createSummationRules( List ruleList, rule.setDescription( "collect and sort the reason codes as per the specified algorithm" ); condition = new Condition(); stringBuilder = new StringBuilder(); - stringBuilder.append( "$reasons : List() from accumulate ( PartialScore(scorecardName == \"" ).append( objectClass ).append( "\", $reasonCode : reasoncode ); collectList($reasonCode) )" ); + // stringBuilder.append("$reasons : List() from accumulate ( PartialScore(scorecardName == \"").append(objectClass).append("\", $reasonCode : reasoncode ); collectList($reasonCode) )"); + stringBuilder.append("$partialScoresList : List() from collect ( PartialScore(scorecardName == \"").append(objectClass).append("\"))"); condition.setSnippet( stringBuilder.toString() ); rule.addCondition( condition ); ruleList.add( rule ); @@ -381,32 +420,22 @@ protected void createSummationRules( List ruleList, addAdditionalSummationConsequence( calcTotalRule, scorecard ); } - protected abstract void addDeclaredTypeContents( PMML pmmlDocument, - StringBuilder stringBuilder, - Scorecard scorecard ); - - protected abstract void internalEmitDRL( PMML pmml, - List ruleList, - Package aPackage ); - - protected abstract void addLHSConditions( Rule rule, - PMML pmmlDocument, - Scorecard scorecard, - Characteristic c, - Attribute scoreAttribute ); - - protected abstract void addAdditionalReasonCodeConsequence( Rule rule, - Scorecard scorecard ); - - protected abstract void addAdditionalReasonCodeCondition( Rule rule, - Scorecard scorecard ); - - protected abstract void addAdditionalSummationConsequence( Rule rule, - Scorecard scorecard ); - - protected abstract void addAdditionalSummationCondition( Rule rule, - Scorecard scorecard ); + protected ScoringStrategy getScoringStrategy(Scorecard scorecard) { + ScoringStrategy strategy = ScoringStrategy.AGGREGATE_SCORE; + String scoringStrategyName = ScorecardPMMLUtils.getExtensionValue(scorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), PMMLExtensionNames.SCORECARD_SCORING_STRATEGY); + if ( !StringUtils.isEmpty(scoringStrategyName)) { + strategy = ScoringStrategy.valueOf(scoringStrategyName); + } + return strategy; + } - protected abstract Condition createInitialRuleCondition( Scorecard scorecard, - String objectClass ); + protected abstract void addDeclaredTypeContents( PMML pmmlDocument, StringBuilder stringBuilder, Scorecard scorecard ); + protected abstract void internalEmitDRL( PMML pmml, List ruleList, Package aPackage ); + protected abstract void addLHSConditions( Rule rule, PMML pmmlDocument, Scorecard scorecard, + Characteristic c, Attribute scoreAttribute ); + protected abstract void addAdditionalReasonCodeConsequence( Rule rule, Scorecard scorecard ); + protected abstract void addAdditionalReasonCodeCondition( Rule rule, Scorecard scorecard ); + protected abstract void addAdditionalSummationConsequence( Rule rule, Scorecard scorecard ); + protected abstract void addAdditionalSummationCondition( Rule rule, Scorecard scorecard ); + protected abstract Condition createInitialRuleCondition( Scorecard scorecard, String objectClass ); } diff --git a/drools-scorecards/src/main/java/org/drools/scorecards/drl/DeclaredTypesDRLEmitter.java b/drools-scorecards/src/main/java/org/drools/scorecards/drl/DeclaredTypesDRLEmitter.java index 65ce10c3d95..b3d2db271cd 100644 --- a/drools-scorecards/src/main/java/org/drools/scorecards/drl/DeclaredTypesDRLEmitter.java +++ b/drools-scorecards/src/main/java/org/drools/scorecards/drl/DeclaredTypesDRLEmitter.java @@ -17,6 +17,7 @@ package org.drools.scorecards.drl; import org.dmg.pmml.pmml_4_1.descr.*; +import org.drools.scorecards.ScoringStrategy; import org.drools.scorecards.parser.xls.XLSKeywords; import org.drools.scorecards.pmml.ScorecardPMMLUtils; import org.drools.template.model.Condition; @@ -68,11 +69,14 @@ protected void addLHSConditions(Rule rule, PMML pmmlDocument, Scorecard scorecar @Override protected void addAdditionalReasonCodeConsequence(Rule rule, Scorecard scorecard) { Consequence consequence = new Consequence(); - consequence.setSnippet("$sc.setReasonCodes($reasons);"); - rule.addConsequence(consequence); - consequence = new Consequence(); - consequence.setSnippet("$sc.sortReasonCodes();"); - rule.addConsequence(consequence); + if (scorecard.getReasonCodeAlgorithm() != null) { + if ("pointsAbove".equalsIgnoreCase(scorecard.getReasonCodeAlgorithm())) { + consequence.setSnippet("$sc.sortAndSetReasonCodes(DroolsScorecard.REASON_CODE_ALGORITHM_POINTSABOVE, $partialScoresList);"); + } else if ("pointsBelow".equalsIgnoreCase(scorecard.getReasonCodeAlgorithm())) { + consequence.setSnippet("$sc.sortAndSetReasonCodes(DroolsScorecard.REASON_CODE_ALGORITHM_POINTSBELOW, $partialScoresList);"); + } + rule.addConsequence(consequence); + } } @Override @@ -89,10 +93,27 @@ protected void addAdditionalSummationCondition(Rule calcTotalRule, Scorecard sco protected void addAdditionalSummationConsequence(Rule calcTotalRule, Scorecard scorecard) { Consequence consequence = new Consequence(); + ScoringStrategy scoringStrategy = getScoringStrategy(scorecard); + switch (scoringStrategy) { + case AGGREGATE_SCORE: + case MINIMUM_SCORE: + case MAXIMUM_SCORE: + case AVERAGE_SCORE: + case WEIGHTED_AVERAGE_SCORE: + case WEIGHTED_MAXIMUM_SCORE: + case WEIGHTED_MINIMUM_SCORE: + case WEIGHTED_AGGREGATE_SCORE: { + consequence.setSnippet("double calculatedScore = $calculatedScore;"); + break; + } + } + + calcTotalRule.addConsequence(consequence); + consequence = new Consequence(); if (scorecard.getInitialScore() > 0) { - consequence.setSnippet("$sc.setCalculatedScore(($calculatedScore+$initialScore));"); + consequence.setSnippet("$sc.setCalculatedScore(calculatedScore+$initialScore);"); } else { - consequence.setSnippet("$sc.setCalculatedScore($calculatedScore);"); + consequence.setSnippet("$sc.setCalculatedScore(calculatedScore);"); } calcTotalRule.addConsequence(consequence); diff --git a/drools-scorecards/src/main/java/org/drools/scorecards/drl/ExternalModelDRLEmitter.java b/drools-scorecards/src/main/java/org/drools/scorecards/drl/ExternalModelDRLEmitter.java index 95381832e06..6cd90a0b401 100644 --- a/drools-scorecards/src/main/java/org/drools/scorecards/drl/ExternalModelDRLEmitter.java +++ b/drools-scorecards/src/main/java/org/drools/scorecards/drl/ExternalModelDRLEmitter.java @@ -99,11 +99,35 @@ protected void addAdditionalReasonCodeConsequence( Rule rule, } if ( !( reasonCodesField == null || reasonCodesField.isEmpty() ) && !( externalClassName == null || externalClassName.isEmpty() ) && !( fieldName == null || fieldName.isEmpty() ) ) { Consequence consequence = new Consequence(); - StringBuilder stringBuilder = new StringBuilder( "$" ); - stringBuilder.append( fieldName ).append( "Var" ).append( ".set" ).append( Character.toUpperCase( reasonCodesField.charAt( 0 ) ) ).append( reasonCodesField.substring( 1 ) ); - stringBuilder.append( "($reasons);" ); - consequence.setSnippet( stringBuilder.toString() ); - rule.addConsequence( consequence ); + consequence.setSnippet("DroolsScorecard $sc = new DroolsScorecard();"); + rule.addConsequence(consequence); + + consequence = new Consequence(); + if ("pointsAbove".equalsIgnoreCase(scorecard.getReasonCodeAlgorithm())) { + consequence.setSnippet("$sc.setReasonCodeAlgorithm(DroolsScorecard.REASON_CODE_ALGORITHM_POINTSABOVE);"); + } else if ("pointsBelow".equalsIgnoreCase(scorecard.getReasonCodeAlgorithm())) { + consequence.setSnippet("$sc.setReasonCodeAlgorithm(DroolsScorecard.REASON_CODE_ALGORITHM_POINTSBELOW);"); + } + rule.addConsequence(consequence); + + consequence = new Consequence(); + if (scorecard.getReasonCodeAlgorithm() != null) { + if ("pointsAbove".equalsIgnoreCase(scorecard.getReasonCodeAlgorithm())) { + consequence.setSnippet("$sc.sortAndSetReasonCodes(DroolsScorecard.REASON_CODE_ALGORITHM_POINTSABOVE, $partialScoresList);"); + } else if ("pointsBelow".equalsIgnoreCase(scorecard.getReasonCodeAlgorithm())) { + consequence.setSnippet("$sc.sortAndSetReasonCodes(DroolsScorecard.REASON_CODE_ALGORITHM_POINTSBELOW, $partialScoresList);"); + } + rule.addConsequence(consequence); + } +// consequence.setSnippet("$sc.sortAndSetReasonCodes($partialScoresList);"); +// rule.addConsequence(consequence); + + consequence = new Consequence(); + StringBuilder stringBuilder = new StringBuilder("$"); + stringBuilder.append(fieldName).append("Var").append(".set").append(Character.toUpperCase(reasonCodesField.charAt(0))).append(reasonCodesField.substring(1)); + stringBuilder.append("($sc.getReasonCodes());"); + consequence.setSnippet(stringBuilder.toString()); + rule.addConsequence(consequence); } } @@ -134,10 +158,19 @@ protected void addAdditionalReasonCodeCondition( Rule rule, } if ( !( reasonCodesField == null || reasonCodesField.isEmpty() ) && !( externalClassName == null || externalClassName.isEmpty() ) && !( fieldName == null || fieldName.isEmpty() ) ) { Condition condition = new Condition(); - StringBuilder stringBuilder = new StringBuilder( "$" ); - stringBuilder.append( fieldName ).append( "Var : " ).append( externalClassName ).append( "()" ); - condition.setSnippet( stringBuilder.toString() ); - rule.addCondition( condition ); + +// String objectClass = scorecard.getModelName().replaceAll(" ", ""); +// StringBuilder stringBuilder = new StringBuilder("$sc:"); +// +// stringBuilder.append(objectClass).append("()"); +// condition.setSnippet(stringBuilder.toString()); +// rule.addCondition(condition); +// +// condition = new Condition(); + StringBuilder stringBuilder = new StringBuilder("$"); + stringBuilder.append(fieldName).append("Var : ").append(externalClassName).append("()"); + condition.setSnippet(stringBuilder.toString()); + rule.addCondition(condition); } } diff --git a/drools-scorecards/src/main/java/org/drools/scorecards/parser/xls/ExcelScorecardValidator.java b/drools-scorecards/src/main/java/org/drools/scorecards/parser/xls/ExcelScorecardValidator.java index a9c97feafe3..d7d987a2fff 100644 --- a/drools-scorecards/src/main/java/org/drools/scorecards/parser/xls/ExcelScorecardValidator.java +++ b/drools-scorecards/src/main/java/org/drools/scorecards/parser/xls/ExcelScorecardValidator.java @@ -22,6 +22,7 @@ import org.dmg.pmml.pmml_4_1.descr.Scorecard; import org.drools.core.util.StringUtils; import org.drools.scorecards.ScorecardError; +import org.drools.scorecards.ScoringStrategy; import org.drools.scorecards.StringUtil; import org.drools.scorecards.pmml.PMMLExtensionNames; import org.drools.scorecards.pmml.ScorecardPMMLUtils; @@ -47,6 +48,36 @@ public static void runAdditionalValidations(Scorecard scorecard, List dataExpectations = resolveExpectations(currentRowCtr, currentColCtr); CellReference cellRef = new CellReference(currentRowCtr, currentColCtr); + Method method = null; for (DataExpectation dataExpectation : dataExpectations) { try { if (dataExpectation != null && dataExpectation.object != null) { @@ -62,7 +63,7 @@ private void fulfillExpectation(int currentRowCtr, int currentColCtr, Object cel } } String setter = "set" + Character.toUpperCase(dataExpectation.property.charAt(0)) + dataExpectation.property.substring(1); - Method method = getSuitableMethod(cellValue, expectedClass, dataExpectation, setter); + method = getSuitableMethod(cellValue, expectedClass, dataExpectation, setter); if ( method == null ) { if (cellValue != null && !StringUtils.isEmpty(cellValue.toString())) { parseErrors.add(new ScorecardError(cellRef.formatAsString(), "Unexpected Value! Wrong Datatype?")); @@ -70,11 +71,14 @@ private void fulfillExpectation(int currentRowCtr, int currentColCtr, Object cel return; } if (method.getParameterTypes()[0] == Double.class) { - cellValue = new Double(Double.parseDouble(cellValue.toString())); + cellValue = Double.parseDouble(cellValue.toString()); } if (method.getParameterTypes()[0] == Boolean.class) { cellValue = Boolean.valueOf(cellValue.toString()); } + if (method.getParameterTypes()[0] == String.class && !(cellValue instanceof String) && cellValue != null) { + cellValue = cellValue.toString(); + } method.invoke(dataExpectation.object, cellValue); if (dataExpectation.object instanceof Extension && ("cellRef".equals(((Extension) dataExpectation.object).getName()))) { ((Extension) dataExpectation.object).setValue(cellRef.formatAsString()); @@ -125,6 +129,12 @@ private void setAdditionalExpectation(int currentRowCtr, int currentColCtr, Stri if (XLSKeywords.SCORECARD_NAME.equalsIgnoreCase(stringCellValue)) { addExpectation(currentRowCtr, currentColCtr + 1, "modelName", scorecard, "Model Name is missing!"); + } else if (XLSKeywords.SCORECARD_SCORING_STRATEGY.equalsIgnoreCase(stringCellValue)) { + Extension extension = new Extension(); + extension.setName(PMMLExtensionNames.SCORECARD_SCORING_STRATEGY); + scorecard.getExtensionsAndCharacteristicsAndMiningSchemas().add(extension); + addExpectation(currentRowCtr, currentColCtr + 1, "value", extension, null); + } else if (XLSKeywords.SCORECARD_REASONCODE_ALGORITHM.equalsIgnoreCase(stringCellValue)) { addExpectation(currentRowCtr, currentColCtr + 1, "reasonCodeAlgorithm", scorecard, null); } else if (XLSKeywords.SCORECARD_USE_REASONCODES.equalsIgnoreCase(stringCellValue)) { @@ -201,7 +211,7 @@ private void setAdditionalExpectation(int currentRowCtr, int currentColCtr, Stri addExpectation(currentRowCtr, currentColCtr+1, "baselineScore", scorecard, null); } } else if (XLSKeywords.SCORECARD_REASONCODE.equalsIgnoreCase(stringCellValue)) { - String value = xlsScorecardParser.peekValueAt(currentRowCtr, currentColCtr-4); + String value = xlsScorecardParser.peekValueAt(currentRowCtr, currentColCtr - 4); if ("Name".equalsIgnoreCase(value)){ //only for characteristics... addExpectation(currentRowCtr + 1, currentColCtr, "reasonCode", _characteristic, null); @@ -209,33 +219,60 @@ private void setAdditionalExpectation(int currentRowCtr, int currentColCtr, Stri } else if (XLSKeywords.SCORECARD_CHARACTERISTIC_BIN_ATTRIBUTE.equalsIgnoreCase(stringCellValue)) { MergedCellRange cellRange = getMergedRegionForCell(currentRowCtr + 1, currentColCtr); + if (cellRange != null) { + int indexOfPartialScore = indexOfColumn(cellRange, XLSKeywords.SCORECARD_CHARACTERISTIC_BIN_INITIALSCORE); + int indexOfDescription = indexOfColumn(cellRange, XLSKeywords.SCORECARD_CHARACTERISTIC_BIN_DESC); + int indexOfReasonCodes = indexOfColumn(cellRange, XLSKeywords.SCORECARD_REASONCODE); + int indexOfValue = indexOfColumn(cellRange, XLSKeywords.SCORECARD_CHARACTERISTIC_BIN_LABEL); + int indexOfWeight = indexOfColumn(cellRange, XLSKeywords.SCORECARD_WEIGHT); + for (int r = cellRange.getFirstRow(); r <= cellRange.getLastRow(); r++) { + Attribute attribute = new Attribute(); _characteristic.getAttributes().add(attribute); - addExpectation(r, currentColCtr + 2, "partialScore", attribute, "Characteristic (Property) Partial Score is missing."); - - Extension extension = new Extension(); - extension.setName("description"); - attribute.getExtensions().add(extension); - addExpectation(r, currentColCtr + 3, "value", extension, null); - - extension = new Extension(); - extension.setName(PMMLExtensionNames.CHARACTERTISTIC_FIELD); - attribute.getExtensions().add(extension); - addExpectation(currentRowCtr + 1, currentColCtr, "value", extension, "Characteristic (Property) Name is missing."); - - extension = new Extension(); - extension.setName("predicateResolver"); - attribute.getExtensions().add(extension); - addExpectation(r, currentColCtr + 1, "value", extension, "Characteristic (Property) Value is missing."); - - extension = new Extension(); - extension.setName("cellRef"); - addExpectation(r, currentColCtr + 1, "value", extension, null); - attribute.getExtensions().add(extension); - addExpectation(r, currentColCtr+4, "reasonCode", attribute,null); + + if ( indexOfPartialScore != -1 ) { + addExpectation(r, currentColCtr + indexOfPartialScore, "partialScore", attribute, "Characteristic (Property) Partial Score is missing."); + } + + if ( indexOfDescription != -1) { + Extension extension = new Extension(); + extension.setName("description"); + attribute.getExtensions().add(extension); + addExpectation(r, currentColCtr + indexOfDescription, "value", extension, null); + } + + if ( indexOfValue != -1){ + Extension extension = new Extension(); + extension.setName(PMMLExtensionNames.CHARACTERTISTIC_FIELD); + attribute.getExtensions().add(extension); + addExpectation(currentRowCtr + indexOfValue, currentColCtr, "value", extension, "Characteristic (Property) Name is missing."); + + extension = new Extension(); + extension.setName("predicateResolver"); + attribute.getExtensions().add(extension); + addExpectation(r, currentColCtr + indexOfValue, "value", extension, "Characteristic (Property) Value is missing."); + + extension = new Extension(); + extension.setName("cellRef"); + attribute.getExtensions().add(extension); + addExpectation(r, currentColCtr + indexOfValue, "value", extension, null); + } + + if ( indexOfReasonCodes != -1) { + addExpectation(r, currentColCtr+indexOfReasonCodes, "reasonCode", attribute,null); + } + + if ( indexOfWeight != -1) { + Extension extension = new Extension(); + extension.setName(PMMLExtensionNames.CHARACTERTISTIC_WEIGHT); + attribute.getExtensions().add(extension); + extension.setValue("1"); + addExpectation(r, currentColCtr + indexOfWeight, "value", extension, "Characteristic (Weight) Value is missing."); + } } + MiningField miningField = new MiningField(); miningField.setInvalidValueTreatment(INVALIDVALUETREATMENTMETHOD.AS_MISSING); miningField.setUsageType(FIELDUSAGETYPE.ACTIVE); @@ -249,6 +286,21 @@ private void setAdditionalExpectation(int currentRowCtr, int currentColCtr, Stri } } + private int indexOfColumn(MergedCellRange mergedCellRange, String columnHeading) { + int row = mergedCellRange.getFirstRow()-1; + for ( int i=0; i<10;i++) { + try { + String peekValue = xlsScorecardParser.peekValueAt(row, i); + if ( columnHeading.equalsIgnoreCase(peekValue)) { + return i-mergedCellRange.getFirstCol(); + } + } catch (NullPointerException npe) { + //stay silent. This means the specific cell was not found. Continue looking. + } + } + return -1; + } + private void addExpectation(int row, int column, String property, Object ref, String errorMessage) { expectations.add(new DataExpectation(row, column, ref, property, errorMessage)); } diff --git a/drools-scorecards/src/main/java/org/drools/scorecards/parser/xls/XLSKeywords.java b/drools-scorecards/src/main/java/org/drools/scorecards/parser/xls/XLSKeywords.java index 1bbe16251ee..c543a518ee2 100644 --- a/drools-scorecards/src/main/java/org/drools/scorecards/parser/xls/XLSKeywords.java +++ b/drools-scorecards/src/main/java/org/drools/scorecards/parser/xls/XLSKeywords.java @@ -31,6 +31,9 @@ public interface XLSKeywords { public static final String SCORECARD_REASONCODE = "Reason Code"; public static final String SCORECARD_REASONCODE_ALGORITHM = "Reason Code Algorithm"; + public static final String SCORECARD_SCORING_STRATEGY = "Scoring Strategy"; + public static final String SCORECARD_WEIGHT = "Weight"; + public static final String SCORECARD_CHARACTERISTIC_NAME = "Name"; public static final String SCORECARD_CHARACTERISTIC_DATATYPE = "Data Type"; public static final String SCORECARD_CHARACTERISTIC_BASELINE_SCORE = "Baseline Score"; diff --git a/drools-scorecards/src/main/java/org/drools/scorecards/pmml/PMMLExtensionNames.java b/drools-scorecards/src/main/java/org/drools/scorecards/pmml/PMMLExtensionNames.java index c4422049ce9..dd0522fff72 100644 --- a/drools-scorecards/src/main/java/org/drools/scorecards/pmml/PMMLExtensionNames.java +++ b/drools-scorecards/src/main/java/org/drools/scorecards/pmml/PMMLExtensionNames.java @@ -20,6 +20,8 @@ public class PMMLExtensionNames { public static final String SCORECARD_PACKAGE = "scorecardPackage"; + public static final String SCORECARD_SCORING_STRATEGY = "scoringStrategy"; + public static final String SCORECARD_CELL_REF = "cellRef"; public static final String SCORECARD_RESULTANT_SCORE_FIELD = "final"; public static final String SCORECARD_RESULTANT_REASONCODES_FIELD = "reasonCodeField"; @@ -30,7 +32,10 @@ public class PMMLExtensionNames { public static final String CHARACTERTISTIC_FACTTYPE = "factType"; public static final String CHARACTERTISTIC_FIELD = "field"; public static final String CHARACTERTISTIC_DATATYPE = "dataType"; + public static final String CHARACTERTISTIC_WEIGHT = "weight"; + public static final String CHARACTERTISTIC_DESCRIPTION = "description"; + public static final String PREDICATE_SOLVER = "predicateResolver"; public static final String DEFAULT_PREDICTED_FIELD = "scorecard__calculatedScore"; } diff --git a/drools-scorecards/src/test/java/org/drools/scorecards/ScorecardReasonCodeTest.java b/drools-scorecards/src/test/java/org/drools/scorecards/ScorecardReasonCodeTest.java index 92d4bb2fc60..ac1fe494c24 100644 --- a/drools-scorecards/src/test/java/org/drools/scorecards/ScorecardReasonCodeTest.java +++ b/drools-scorecards/src/test/java/org/drools/scorecards/ScorecardReasonCodeTest.java @@ -1,6 +1,6 @@ package org.drools.scorecards; -import junit.framework.Assert; +import org.junit.Assert; import org.dmg.pmml.pmml_4_1.descr.Attribute; import org.dmg.pmml.pmml_4_1.descr.Characteristic; import org.dmg.pmml.pmml_4_1.descr.Characteristics; @@ -18,11 +18,11 @@ import org.kie.internal.runtime.StatefulKnowledgeSession; import org.kie.api.io.ResourceType; -import static junit.framework.Assert.assertEquals; -import static junit.framework.Assert.assertFalse; -import static junit.framework.Assert.assertNotNull; -import static junit.framework.Assert.assertTrue; -import static junit.framework.Assert.fail; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.drools.scorecards.ScorecardCompiler.DrlType.INTERNAL_DECLARED_TYPES; public class ScorecardReasonCodeTest { @@ -72,7 +72,7 @@ public void testUseReasonCodes() throws Exception { for (Object serializable : pmmlDocument.getAssociationModelsAndBaselineModelsAndClusteringModels()){ if (serializable instanceof Scorecard){ assertTrue(((Scorecard)serializable).isUseReasonCodes()); - assertEquals(100.0, ((Scorecard)serializable).getInitialScore()); + assertEquals(100.0, ((Scorecard)serializable).getInitialScore(), 0.0); assertEquals("pointsBelow",((Scorecard)serializable).getReasonCodeAlgorithm()); } } @@ -107,11 +107,11 @@ public void testBaselineScores() throws Exception { if (obj instanceof Characteristics){ Characteristics characteristics = (Characteristics)obj; assertEquals(4, characteristics.getCharacteristics().size()); - assertEquals(10.0, characteristics.getCharacteristics().get(0).getBaselineScore()); - assertEquals(99.0, characteristics.getCharacteristics().get(1).getBaselineScore()); - assertEquals(12.0, characteristics.getCharacteristics().get(2).getBaselineScore()); - assertEquals(0.0, characteristics.getCharacteristics().get(3).getBaselineScore()); - assertEquals(25.0, ((Scorecard)serializable).getBaselineScore()); + assertEquals(10.0, characteristics.getCharacteristics().get(0).getBaselineScore(), 0.0); + assertEquals(99.0, characteristics.getCharacteristics().get(1).getBaselineScore(), 0.0); + assertEquals(12.0, characteristics.getCharacteristics().get(2).getBaselineScore(), 0.0); + assertEquals(15.0, characteristics.getCharacteristics().get(3).getBaselineScore(), 0.0); + assertEquals(25.0, ((Scorecard)serializable).getBaselineScore(), 0.0); return; } } @@ -170,7 +170,8 @@ public void testReasonCodesCombinations() throws Exception { assertTrue(29 == scorecard.getCalculatedScore()); //age-reasoncode=AGE02, license-reasoncode=VL002 assertEquals(2, scorecard.getReasonCodes().size()); - assertTrue(scorecard.getReasonCodes().contains("AGE02")); + //AGE02 - should be knocked out as we are using the pointsBelow Algorithm. + assertFalse(scorecard.getReasonCodes().contains("AGE02")); assertTrue(scorecard.getReasonCodes().contains("VL099")); session = kbase.newStatefulKnowledgeSession(); @@ -198,11 +199,11 @@ public void testReasonCodesCombinations() throws Exception { session.fireAllRules(); session.dispose(); //occupation = +10, age = +40, state = -10, validLicense = 1 - assertEquals(41.0,scorecard.getCalculatedScore()); + assertEquals(41.0,scorecard.getCalculatedScore(), 0.0); //[OCC02, AGE03, VL001, RS001] assertEquals(4, scorecard.getReasonCodes().size()); assertTrue(scorecard.getReasonCodes().contains("OCC99")); - assertTrue(scorecard.getReasonCodes().contains("AGE03")); + assertFalse(scorecard.getReasonCodes().contains("AGE03")); assertTrue(scorecard.getReasonCodes().contains("VL001")); assertTrue(scorecard.getReasonCodes().contains("RS001")); } @@ -216,6 +217,7 @@ public void testDRLExecution() throws Exception { System.out.println(error.getMessage()); } assertFalse(kbuilder.hasErrors()); + //System.out.println(drl); //BUILD RULEBASE KnowledgeBase kbase = KnowledgeBaseFactory.newKnowledgeBase(); @@ -234,8 +236,10 @@ public void testDRLExecution() throws Exception { assertTrue(129 == scorecard.getCalculatedScore()); //age-reasoncode=AGE02, license-reasoncode=VL002 assertEquals(2, scorecard.getReasonCodes().size()); - assertTrue(scorecard.getReasonCodes().contains("AGE02")); - assertTrue(scorecard.getReasonCodes().contains("VL002")); + //AGE02 - should be knocked out as we are using the pointsBelow Algorithm. + assertEquals(-1, scorecard.getReasonCodes().indexOf("AGE02")); + assertEquals(0, scorecard.getReasonCodes().indexOf("VL002")); + assertEquals(1, scorecard.getReasonCodes().lastIndexOf("VL002")); session = kbase.newStatefulKnowledgeSession(); scorecard = (DroolsScorecard) scorecardType.newInstance(); @@ -245,7 +249,7 @@ public void testDRLExecution() throws Exception { session.fireAllRules(); session.dispose(); //occupation = -10, age = +10, validLicense = -1, initialScore = 100; - assertEquals(99.0, scorecard.getCalculatedScore()); + assertEquals(99.0, scorecard.getCalculatedScore(), 0.0); assertEquals(3, scorecard.getReasonCodes().size()); //[AGE01, VL002, OCC01] @@ -263,13 +267,156 @@ public void testDRLExecution() throws Exception { session.fireAllRules(); session.dispose(); //occupation = +10, age = +40, state = -10, validLicense = 1, initialScore = 100; - assertEquals(141.0,scorecard.getCalculatedScore()); + assertEquals(141.0,scorecard.getCalculatedScore(), 0.0); //[OCC02, AGE03, VL001, RS001] assertEquals(4, scorecard.getReasonCodes().size()); assertTrue(scorecard.getReasonCodes().contains("OCC02")); - assertTrue(scorecard.getReasonCodes().contains("AGE03")); + assertFalse(scorecard.getReasonCodes().contains("AGE03")); assertTrue(scorecard.getReasonCodes().contains("VL001")); assertTrue(scorecard.getReasonCodes().contains("RS001")); } + @Test + public void testPointsAbove() throws Exception { + ScorecardCompiler scorecardCompiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES); + scorecardCompiler.compileFromExcel(PMMLDocumentTest.class.getResourceAsStream("/scoremodel_reasoncodes.xls"), "scorecards_pointsAbove"); + assertEquals(0, scorecardCompiler.getScorecardParseErrors().size()); + String drl = scorecardCompiler.getDRL(); + assertNotNull(drl); + KnowledgeBuilder kbuilder = KnowledgeBuilderFactory.newKnowledgeBuilder(); + + kbuilder.add( ResourceFactory.newByteArrayResource(drl.getBytes()), ResourceType.DRL); + for (KnowledgeBuilderError error : kbuilder.getErrors()){ + System.out.println(error.getMessage()); + } + assertFalse( kbuilder.hasErrors() ); + + //BUILD RULEBASE + KnowledgeBase kbase = KnowledgeBaseFactory.newKnowledgeBase(); + kbase.addKnowledgePackages( kbuilder.getKnowledgePackages() ); + + //NEW WORKING MEMORY + StatefulKnowledgeSession session = kbase.newStatefulKnowledgeSession(); + FactType scorecardType = kbase.getFactType( "org.drools.scorecards.example","SampleScore" ); + + DroolsScorecard scorecard = (DroolsScorecard) scorecardType.newInstance(); + scorecardType.set(scorecard, "age", 10); + session.insert(scorecard); + session.fireAllRules(); + session.dispose(); + //age = 30, validLicence -1 + assertEquals(29.0, scorecard.getCalculatedScore(), 0.0); + //age-reasoncode=AGE02, license-reasoncode=VL002 + assertEquals(2, scorecard.getReasonCodes().size()); + assertEquals(0, scorecard.getReasonCodes().indexOf("VL002")); + //AGE02 - should be knocked out as we are using the pointsAbove Algorithm. + assertEquals(-1, scorecard.getReasonCodes().indexOf("AGE02")); + assertEquals(1, scorecard.getReasonCodes().lastIndexOf("VL002")); + + session = kbase.newStatefulKnowledgeSession(); + scorecard = (DroolsScorecard) scorecardType.newInstance(); + scorecardType.set(scorecard, "age", 0); + scorecardType.set(scorecard, "occupation", "SKYDIVER"); + session.insert(scorecard); + session.fireAllRules(); + session.dispose(); + //occupation = -10, age = +10, validLicense = -1; + assertTrue(-1 == scorecard.getCalculatedScore()); + assertEquals(3, scorecard.getReasonCodes().size()); + //[AGE01, VL002, OCC01] + assertEquals(0, scorecard.getReasonCodes().indexOf("OCC01")); + assertTrue("VL002".equalsIgnoreCase(scorecard.getReasonCodes().get(1))); + assertTrue("VL002".equalsIgnoreCase(scorecard.getReasonCodes().get(2))); + //AGE01 - should be knocked out as we are using the pointsAbove Algorithm. + assertEquals(-1, scorecard.getReasonCodes().indexOf("AGE01")); + + session = kbase.newStatefulKnowledgeSession(); + scorecard = (DroolsScorecard) scorecardType.newInstance(); + scorecardType.set(scorecard, "age", 20); + scorecardType.set(scorecard, "occupation", "TEACHER"); + scorecardType.set(scorecard, "residenceState", "AP"); + scorecardType.set(scorecard, "validLicense", true); + session.insert( scorecard ); + session.fireAllRules(); + session.dispose(); + //occupation = +10, age = +40, state = -10, validLicense = 1 + assertEquals(41.0,scorecard.getCalculatedScore(), 0.0); + //[OCC02, AGE03, VL001, RS001] + assertEquals(4, scorecard.getReasonCodes().size()); + assertEquals(-1, scorecard.getReasonCodes().indexOf("OCC02")); + assertEquals(-1, scorecard.getReasonCodes().indexOf("AGE03")); + assertEquals(-1, scorecard.getReasonCodes().indexOf("VL001")); + assertEquals("RS001", scorecard.getReasonCodes().get(0)); + assertEquals("RS001", scorecard.getReasonCodes().get(1)); + assertEquals("RS001", scorecard.getReasonCodes().get(2)); + assertEquals("RS001", scorecard.getReasonCodes().get(3)); + } + + @Test + public void testPointsBelow() throws Exception { + ScorecardCompiler scorecardCompiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES); + scorecardCompiler.compileFromExcel(PMMLDocumentTest.class.getResourceAsStream("/scoremodel_reasoncodes.xls"), "scorecards_pointsBelow"); + assertEquals(0, scorecardCompiler.getScorecardParseErrors().size()); + String drl = scorecardCompiler.getDRL(); + assertNotNull(drl); + KnowledgeBuilder kbuilder = KnowledgeBuilderFactory.newKnowledgeBuilder(); + + kbuilder.add( ResourceFactory.newByteArrayResource(drl.getBytes()), ResourceType.DRL); + for (KnowledgeBuilderError error : kbuilder.getErrors()){ + System.out.println(error.getMessage()); + } + assertFalse( kbuilder.hasErrors() ); + + //BUILD RULEBASE + KnowledgeBase kbase = KnowledgeBaseFactory.newKnowledgeBase(); + kbase.addKnowledgePackages( kbuilder.getKnowledgePackages() ); + + //NEW WORKING MEMORY + StatefulKnowledgeSession session = kbase.newStatefulKnowledgeSession(); + FactType scorecardType = kbase.getFactType( "org.drools.scorecards.example","SampleScore" ); + + DroolsScorecard scorecard = (DroolsScorecard) scorecardType.newInstance(); + scorecardType.set(scorecard, "age", 10); + session.insert(scorecard); + session.fireAllRules(); + session.dispose(); + //age = 30, validLicence -1 + assertEquals(29.0, scorecard.getCalculatedScore(), 0.0); + //age-reasoncode=AGE02, license-reasoncode=VL002 + assertEquals(2, scorecard.getReasonCodes().size()); + //VL002 - should be knocked out as we are using the pointsBelow Algorithm. + assertEquals(0, scorecard.getReasonCodes().indexOf("VL002")); + + session = kbase.newStatefulKnowledgeSession(); + scorecard = (DroolsScorecard) scorecardType.newInstance(); + scorecardType.set(scorecard, "age", 0); + scorecardType.set(scorecard, "occupation", "SKYDIVER"); + session.insert(scorecard); + session.fireAllRules(); + session.dispose(); + //occupation = -10, age = +10, validLicense = -1; + assertTrue(-1 == scorecard.getCalculatedScore()); + assertEquals(3, scorecard.getReasonCodes().size()); + //[AGE01, VL002, OCC01] + assertEquals(2, scorecard.getReasonCodes().indexOf("OCC01")); + assertEquals(1, scorecard.getReasonCodes().indexOf("VL002")); + assertEquals(0, scorecard.getReasonCodes().indexOf("AGE01")); + + session = kbase.newStatefulKnowledgeSession(); + scorecard = (DroolsScorecard) scorecardType.newInstance(); + scorecardType.set(scorecard, "age", 20); + scorecardType.set(scorecard, "occupation", "TEACHER"); + scorecardType.set(scorecard, "residenceState", "AP"); + scorecardType.set(scorecard, "validLicense", true); + session.insert( scorecard ); + session.fireAllRules(); + session.dispose(); + //occupation = +10, age = +40, state = -10, validLicense = 1 + assertEquals(41.0,scorecard.getCalculatedScore(), 0.0); + //[OCC02, AGE03, VL001, RS001] + assertEquals(4, scorecard.getReasonCodes().size()); + assertEquals(2, scorecard.getReasonCodes().indexOf("OCC02")); + assertEquals(1, scorecard.getReasonCodes().indexOf("RS001")); + assertEquals(0, scorecard.getReasonCodes().indexOf("VL001")); + } } diff --git a/drools-scorecards/src/test/java/org/drools/scorecards/ScoringStrategiesTest.java b/drools-scorecards/src/test/java/org/drools/scorecards/ScoringStrategiesTest.java new file mode 100644 index 00000000000..4887ef79f2b --- /dev/null +++ b/drools-scorecards/src/test/java/org/drools/scorecards/ScoringStrategiesTest.java @@ -0,0 +1,229 @@ +package org.drools.scorecards; + +import org.dmg.pmml.pmml_4_1.descr.Extension; +import org.dmg.pmml.pmml_4_1.descr.PMML; +import org.dmg.pmml.pmml_4_1.descr.Scorecard; +import org.drools.scorecards.pmml.PMMLExtensionNames; +import org.drools.scorecards.pmml.ScorecardPMMLUtils; +import org.junit.Before; +import org.junit.Test; +import org.kie.api.definition.type.FactType; +import org.kie.api.io.ResourceType; +import org.kie.api.runtime.StatelessKieSession; +import org.kie.internal.KnowledgeBase; +import org.kie.internal.KnowledgeBaseFactory; +import org.kie.internal.builder.KnowledgeBuilder; +import org.kie.internal.builder.KnowledgeBuilderError; +import org.kie.internal.builder.KnowledgeBuilderFactory; +import org.kie.internal.builder.ScoreCardConfiguration; +import org.kie.internal.io.ResourceFactory; +import org.kie.internal.runtime.StatefulKnowledgeSession; + +import java.io.InputStream; + +import static junit.framework.Assert.*; +import static org.drools.scorecards.ScorecardCompiler.DrlType.INTERNAL_DECLARED_TYPES; + +public class ScoringStrategiesTest { + + + @Before + public void setUp() throws Exception { + } + + @Test + public void testScoringExtension() throws Exception { + PMML pmmlDocument; + ScorecardCompiler scorecardCompiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES); + if (scorecardCompiler.compileFromExcel(PMMLDocumentTest.class.getResourceAsStream("/scoremodel_scoring_strategies.xls")) ) { + pmmlDocument = scorecardCompiler.getPMMLDocument(); + assertNotNull(pmmlDocument); + String drl = scorecardCompiler.getDRL(); + assertNotNull(drl); + for (Object serializable : pmmlDocument.getAssociationModelsAndBaselineModelsAndClusteringModels()){ + if (serializable instanceof Scorecard){ + Scorecard scorecard = (Scorecard)serializable; + assertEquals("Sample Score",scorecard.getModelName()); + Extension extension = ScorecardPMMLUtils.getExtension(scorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), PMMLExtensionNames.SCORECARD_SCORING_STRATEGY); + assertNotNull(extension); + assertEquals(extension.getValue(), ScoringStrategy.AGGREGATE_SCORE.toString()); + return; + } + } + } + fail(); + } + + @Test + public void testAggregate() throws Exception { + + double finalScore = executeAndFetchScore("scorecards"); + //age==10 (30), validLicense==FALSE (-1) + assertEquals(29.0, finalScore); + } + + @Test + public void testAverage() throws Exception { + + double finalScore = executeAndFetchScore("scorecards_avg"); + //age==10 (30), validLicense==FALSE (-1) + //count = 2 + assertEquals(14.5, finalScore); + } + + @Test + public void testMinimum() throws Exception { + double finalScore = executeAndFetchScore("scorecards_min"); + //age==10 (30), validLicense==FALSE (-1) + assertEquals(-1.0, finalScore); + } + + @Test + public void testMaximum() throws Exception { + + double finalScore = executeAndFetchScore("scorecards_max"); + //age==10 (30), validLicense==FALSE (-1) + assertEquals(30.0, finalScore); + } + + @Test + public void testWeightedAggregate() throws Exception { + + double finalScore = executeAndFetchScore("scorecards_w_aggregate"); + //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1) + assertEquals(599.0, finalScore); + } + + @Test + public void testWeightedAverage() throws Exception { + + double finalScore = executeAndFetchScore("scorecards_w_avg"); + //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1) + assertEquals(299.5, finalScore); + } + + @Test + public void testWeightedMaximum() throws Exception { + + double finalScore = executeAndFetchScore("scorecards_w_max"); + //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1) + assertEquals(600.0, finalScore); + } + + @Test + public void testWeightedMinimum() throws Exception { + + double finalScore = executeAndFetchScore("scorecards_w_min"); + //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1) + assertEquals(-1.0, finalScore); + } + + /* Tests with Initial Score */ + @Test + public void testAggregateInitialScore() throws Exception { + + double finalScore = executeAndFetchScore("scorecards_initial_score"); + //age==10 (30), validLicense==FALSE (-1) + //initialScore = 100 + assertEquals(129.0, finalScore); + } + + @Test + public void testAverageInitialScore() throws Exception { + double finalScore = executeAndFetchScore("scorecards_avg_initial_score"); + //age==10 (30), validLicense==FALSE (-1) + //count = 2 + //initialScore = 100 + assertEquals(114.5, finalScore); + } + + @Test + public void testMinimumInitialScore() throws Exception { + double finalScore = executeAndFetchScore("scorecards_min_initial_score"); + //age==10 (30), validLicense==FALSE (-1) + //initialScore = 100 + assertEquals(99.0, finalScore); + } + + @Test + public void testMaximumInitialScore() throws Exception { + + double finalScore = executeAndFetchScore("scorecards_max_initial_score"); + //age==10 (30), validLicense==FALSE (-1) + //initialScore = 100 + assertEquals(130.0, finalScore); + } + + @Test + public void testWeightedAggregateInitialScore() throws Exception { + + double finalScore = executeAndFetchScore("scorecards_w_aggregate_initial"); + //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1) + //initialScore = 100 + assertEquals(699.0, finalScore); + } + + @Test + public void testWeightedAverageInitialScore() throws Exception { + + double finalScore = executeAndFetchScore("scorecards_w_avg_initial"); + //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1) + //initialScore = 100 + assertEquals(399.5, finalScore); + } + + @Test + public void testWeightedMaximumInitialScore() throws Exception { + + double finalScore = executeAndFetchScore("scorecards_w_max_initial"); + //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1) + //initialScore = 100 + assertEquals(700.0, finalScore); + } + + @Test + public void testWeightedMinimumInitialScore() throws Exception { + + double finalScore = executeAndFetchScore("scorecards_w_min_initial"); + //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1) + //initialScore = 100 + assertEquals(99.0, finalScore); + } + + /* Internal functions */ + private double executeAndFetchScore(String sheetName) throws Exception { + + ScorecardCompiler scorecardCompiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES); + InputStream inputStream = PMMLDocumentTest.class.getResourceAsStream("/scoremodel_scoring_strategies.xls"); + boolean compileResult = scorecardCompiler.compileFromExcel(inputStream, sheetName); + if (!compileResult) { + for(ScorecardError error : scorecardCompiler.getScorecardParseErrors()){ + System.err.println("Scorecard Compiler Error :"+error.getErrorLocation()+"->"+error.getErrorMessage()); + } + return -999999; + } + String drl = scorecardCompiler.getDRL(); + //System.out.println(drl); + KnowledgeBuilder kbuilder = KnowledgeBuilderFactory.newKnowledgeBuilder(); + kbuilder.add( ResourceFactory.newByteArrayResource(drl.getBytes()), ResourceType.DRL); + for (KnowledgeBuilderError error : kbuilder.getErrors()){ + System.out.println(error.getMessage()); + } + assertFalse( kbuilder.hasErrors() ); + + KnowledgeBase kbase = KnowledgeBaseFactory.newKnowledgeBase(); + assertNotNull(kbase); + kbase.addKnowledgePackages( kbuilder.getKnowledgePackages() ); + + //NEW WORKING MEMORY + StatelessKieSession session = kbase.newStatelessKnowledgeSession(); + FactType scorecardType = kbase.getFactType( "org.drools.scorecards.example","SampleScore" ); + DroolsScorecard scorecard = (DroolsScorecard) scorecardType.newInstance(); + scorecardType.set(scorecard, "age", 10); + session.execute(scorecard); + + return scorecard.getCalculatedScore(); + + } + +} diff --git a/drools-scorecards/src/test/resources/scoremodel_c.xls b/drools-scorecards/src/test/resources/scoremodel_c.xls index 23a70367ab9..f1d77dadd0c 100644 Binary files a/drools-scorecards/src/test/resources/scoremodel_c.xls and b/drools-scorecards/src/test/resources/scoremodel_c.xls differ diff --git a/drools-scorecards/src/test/resources/scoremodel_reasoncodes.xls b/drools-scorecards/src/test/resources/scoremodel_reasoncodes.xls index c0b1d70b6a5..296810b52a3 100644 Binary files a/drools-scorecards/src/test/resources/scoremodel_reasoncodes.xls and b/drools-scorecards/src/test/resources/scoremodel_reasoncodes.xls differ diff --git a/drools-scorecards/src/test/resources/scoremodel_scoring_strategies.xls b/drools-scorecards/src/test/resources/scoremodel_scoring_strategies.xls new file mode 100644 index 00000000000..0c707bde905 Binary files /dev/null and b/drools-scorecards/src/test/resources/scoremodel_scoring_strategies.xls differ