Skip to content

Commit

Permalink
[PMML 6.x] Add Naive Bayes Model
Browse files Browse the repository at this point in the history
  • Loading branch information
sotty committed Feb 7, 2014
1 parent a2a9095 commit 88f4f7f
Show file tree
Hide file tree
Showing 23 changed files with 750 additions and 49 deletions.
Expand Up @@ -17,6 +17,7 @@
package org.drools.pmml.pmml_4_1;

import org.dmg.pmml.pmml_4_1.descr.ClusteringModel;
import org.dmg.pmml.pmml_4_1.descr.NaiveBayesModel;
import org.dmg.pmml.pmml_4_1.descr.NeuralNetwork;
import org.dmg.pmml.pmml_4_1.descr.PMML;
import org.dmg.pmml.pmml_4_1.descr.RegressionModel;
Expand Down Expand Up @@ -172,6 +173,14 @@ public class PMML4Compiler implements PMMLCompiler {
"models/svm/svmOutputVote1v1.drlt",
};

protected static boolean naiveBayesLoaded = false;
protected static final String[] NAIVE_BAYES_TEMPLATES = new String[] {
"models/bayes/naiveBayesDeclare.drlt",
"models/bayes/naiveBayesEval.drlt",
"models/bayes/naiveBayesBuildCounts.drlt",
"models/bayes/naiveBayesBuildOuts.drlt",
};

protected static boolean simpleRegLoaded = false;
protected static final String[] SIMPLEREG_TEMPLATES = new String[] {
"models/regression/regDeclare.drlt",
Expand Down Expand Up @@ -321,6 +330,16 @@ private static KieBase checkBuildingResources( PMML pmml ) throws IOException {

for ( Object o : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels() ) {

if ( o instanceof NaiveBayesModel ) {
if ( ! naiveBayesLoaded ) {
for ( String ntempl : NAIVE_BAYES_TEMPLATES ) {
prepareTemplate( ntempl );
}
naiveBayesLoaded = true;
}
chosenKieBase = chosenKieBase == null ? "PMML-Bayes" : "PMML";
}

if ( o instanceof NeuralNetwork ) {
if ( ! neuralLoaded ) {
for ( String ntempl : NEURAL_TEMPLATES ) {
Expand Down
Expand Up @@ -1034,4 +1034,5 @@ public DATATYPE mapFeatureType( DATATYPE srcType, RESULTFEATURE feat ) {
public void reset() {
definedModelBeans = new HashSet<String>();
}

}
4 changes: 3 additions & 1 deletion drools-pmml/src/main/resources/META-INF/kmodule.xml
Expand Up @@ -2,6 +2,7 @@
<kmodule xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://jboss.org/kie/6.0.0/kmodule">

<kbase name="PMML-Bayes-Rules" packages="org.drools.pmml.pmml_4_1.compiler.bayes" eventProcessingMode="stream" />
<kbase name="PMML-Cluster-Rules" packages="org.drools.pmml.pmml_4_1.compiler.clustering" eventProcessingMode="stream" />
<kbase name="PMML-Neural-Rules" packages="org.drools.pmml.pmml_4_1.compiler.neural" eventProcessingMode="stream" />
<kbase name="PMML-Regression-Rules" packages="org.drools.pmml.pmml_4_1.compiler.regression" eventProcessingMode="stream" />
Expand All @@ -11,13 +12,14 @@

<kbase name="PMML-Base" packages="org.drools.pmml.pmml_4_1.compiler" eventProcessingMode="stream" />

<kbase name="PMML-Bayes" packages="-" eventProcessingMode="stream" includes="PMML-Base,PMML-Bayes-Rules" />
<kbase name="PMML-Cluster" packages="-" eventProcessingMode="stream" includes="PMML-Base,PMML-Cluster-Rules" />
<kbase name="PMML-Neural" packages="-" eventProcessingMode="stream" includes="PMML-Base,PMML-Neural-Rules"/>
<kbase name="PMML-Regression" packages="-" eventProcessingMode="stream" includes="PMML-Base,PMML-Regression-Rules"/>
<kbase name="PMML-Scorecard" packages="-" eventProcessingMode="stream" includes="PMML-Base,PMML-Scorecard-Rules"/>
<kbase name="PMML-SVM" packages="-" eventProcessingMode="stream" includes="PMML-Base,PMML-SVM-Rules"/>
<kbase name="PMML-Tree" packages="-" eventProcessingMode="stream" includes="PMML-Base,PMML-Tree-Rules"/>

<kbase name="PMML" packages="-" eventProcessingMode="stream" includes="PMML-Base,PMML-Cluster-Rules,PMML-Neural-Rules,PMML-Regression-Rules,PMML-Scorecard-Rules,PMML-SVM-Rules,PMML-Tree-Rules" />
<kbase name="PMML" packages="-" eventProcessingMode="stream" includes="PMML-Base,PMML-Bayes-Rules,PMML-Cluster-Rules,PMML-Neural-Rules,PMML-Regression-Rules,PMML-Scorecard-Rules,PMML-SVM-Rules,PMML-Tree-Rules" />

</kmodule>
@@ -0,0 +1,98 @@
package org.drools.pmml.pmml_4_1.compiler;

import org.dmg.pmml.pmml_4_1.descr.*;
import java.util.*;



rule "SVMRoot"
when
$nbm : NaiveBayesModel()
then
utils.applyTemplate( "naiveBayesDeclare.drlt", utils, registry, null, theory );
end

rule "visitNaiveBayes_context"
salience -9
when
$nbm : NaiveBayesModel( $name : modelName )
then
utils.context = utils.compactUpperCase( $name );

HashMap map = new HashMap( 3 );
map.put( "type","NaiveBayesModel" );
map.put( "name", utils.context );

utils.applyTemplate( "modelMark.drlt", utils, registry, map, theory );
end

rule "visitNaiveBayes_inputs"
salience -9
when
$nbm : NaiveBayesModel( $name : modelName, $info : extensionsAndBayesOutputsAndBayesInputs, $thold : threshold )
BayesOutput( this memberOf $info, $tgtFld : fieldName )
TypeOfField( $tgtFld, $type ; )

BayesInputs( this memberOf $info, $inputs : bayesInputs )
$bin : BayesInput( this memberOf $inputs, $fld : fieldName, $pairs : pairCounts, $dfld : derivedField == null )
then
utils.context = utils.compactUpperCase( $name );

HashMap map = new HashMap( 7 );
map.put( "name", utils.context );
map.put( "threshold", $thold );
map.put( "field", utils.compactUpperCase( $fld ) );
map.put( "datatype", $type );
map.put( "pairs", $pairs );

utils.applyTemplate( "naiveBayesBuildCounts.drlt", utils, registry, map, theory );

utils.applyTemplate( "naiveBayesEval.drlt", utils, registry, map, theory );

end

rule "visitNaiveBayes_inputs_with_discretized_fields"
salience -9
when
$nbm : NaiveBayesModel( $name : modelName, $info : extensionsAndBayesOutputsAndBayesInputs, $thold : threshold )
BayesOutput( this memberOf $info, $tgtFld : fieldName )
TypeOfField( $tgtFld, $type ; )

BayesInputs( this memberOf $info, $inputs : bayesInputs )
$bin : BayesInput( this memberOf $inputs, $pairs : pairCounts, $dfld : derivedField != null )
$dfl : DerivedField( this == $dfld, $fld : name )
then
utils.context = utils.compactUpperCase( $name );

HashMap map = new HashMap( 5 );
map.put( "name", utils.context );
map.put( "field", utils.compactUpperCase( $fld ) );
map.put( "datatype", $type );
map.put( "pairs", $pairs );

utils.applyTemplate( "naiveBayesBuildCounts.drlt", utils, registry, map, theory );

utils.applyTemplate( "naiveBayesEval.drlt", utils, registry, map, theory );

end



rule "visitNaiveBayes_outputs"
salience -9
when
$nbm : NaiveBayesModel( $name : modelName, $info : extensionsAndBayesOutputsAndBayesInputs, $thold : threshold )
BayesOutput( this memberOf $info, $fld : fieldName, $tvc : targetValueCounts )
TypeOfField( $fld, $type ; )

then
utils.context = utils.compactUpperCase( $name );

HashMap map = new HashMap( 7 );
map.put( "name", utils.context );
map.put( "threshold", $thold );
map.put( "field", utils.compactUpperCase( $fld ) );
map.put( "datatype", $type );
map.put( "tvc", $tvc );
utils.applyTemplate( "naiveBayesBuildOuts.drlt", utils, registry, map, theory );
end
Expand Up @@ -631,6 +631,17 @@ end
//******************************************************************************************************************


rule "Inherit_Discretize_Datatype"
salience 100
when
$fld : DerivedField( $dx : discretize, $type : dataType )
$dsc : Discretize( this == $dx, $dataType : dataType == null )
then
modify ( $dsc ) {
setDataType( $type );
}
end


rule "processDerivedField_Discretize_mapMissing"
dialect "mvel"
Expand Down Expand Up @@ -682,12 +693,13 @@ when
Discretize( this == $dx, $f : field , $d : defaultValue != null, $bins : discretizeBins, $dataType : dataType )
then
HashMap map = new HashMap(7);
map.put("context",utils.context);
map.put("name",utils.compactUpperCase($fld.name));
map.put("origField",utils.compactUpperCase($f));
map.put("bins",$bins);
map.put("target",utils.format($dataType,$d));
utils.applyTemplate("outOfBinningDefault.drlt", utils, registry, map, theory);
map.put( "context", utils.context );
map.put( "name", utils.compactUpperCase( $fld.name ) );
map.put( "origField", utils.compactUpperCase( $f ) );
map.put( "bins", $bins );
map.put( "datatype", $dataType );
map.put( "target", $d );
utils.applyTemplate( "outOfBinningDefault.drlt", utils, registry, map, theory );
end


Expand All @@ -696,17 +708,18 @@ rule "processDerivedField_Discretize_bin"
dialect "mvel"
when
$fld : DerivedField( $dx : discretize )
Discretize( this == $dx, $f : field , $d : defaultValue != null, $bins : discretizeBins, $dataType : dataType )
Discretize( this == $dx, $f : field , $d : defaultValue, $bins : discretizeBins, $dataType : dataType )
$bin : DiscretizeBin( $interval : interval, $x : binValue ) from $bins
then
HashMap map = new HashMap(7);
map.put("context",utils.context);
map.put("name",utils.compactUpperCase($fld.name));
map.put("origField",utils.compactUpperCase($f));
map.put("intv",$interval);
map.put("index",utils.nextCount());
map.put("target",utils.format($dataType,$x));
utils.applyTemplate("intervalBinning.drlt", utils, registry, map, theory);
HashMap map = new HashMap( 11 );
map.put( "context", utils.context );
map.put( "name", utils.compactUpperCase( $fld.name ) );
map.put( "origField", utils.compactUpperCase( $f ) );
map.put( "intv", $interval );
map.put( "index", utils.nextCount() );
map.put( "datatype", $dataType );
map.put( "target", $x );
utils.applyTemplate( "intervalBinning.drlt", utils, registry, map, theory );
end


Expand Down Expand Up @@ -1807,10 +1820,10 @@ end
rule "Output_Type2"
no-loop
when
$of : OutputField( $name : name, dataType == null, $tgt : targetField )
$of : OutputField( $name : name, dataType == null, $feat : feature, $tgt : targetField )
TypeOfField( name == $tgt, $dataType : dataType )
then
insertLogical( new TypeOfField( $name, $dataType ) );
insertLogical( new TypeOfField( $name, utils.mapFeatureType( $dataType, $feat ) ) );
end

rule "Bind_FactField_Outputield"
Expand Down Expand Up @@ -1890,13 +1903,15 @@ when
feature == RESULTFEATURE.PREDICTED_DISPLAY_VALUE || == RESULTFEATURE.WARNING
|| == RESULTFEATURE.PROBABILITY )
$tf : TypeOfField( name == $name, $type : dataType )
$sf : TypeOfField( name == $tgt, $srcType : dataType )
then
HashMap map = new HashMap( 7 );
map.put( "context", utils.context );
map.put( "origField", utils.compactUpperCase( $tgt ) );
map.put( "name", utils.compactUpperCase( $name ) );
map.put( "value", $val );
map.put( "type", $type );
map.put( "srcType", $srcType );
map.put( "feature", $feat );
utils.applyTemplate( "addOutputFeature.drlt", utils, registry, map, theory );
end
Expand Down
Expand Up @@ -853,4 +853,37 @@ when
$a : Attribute( $set : simpleSetPredicate != null )
then
insertLogical( $set );
end
end


//------------------------------------------------------------------------------------------------------------//

rule "visitNaiveBayes"
salience -10
when
$nbm : NaiveBayesModel( $info : extensionsAndBayesOutputsAndBayesInputs, $name : modelName )
$o : Object() from $info
then
insertLogical( $o );
end

rule "visitBayesInputs"
salience -11
when
$bis : BayesInputs( $in : bayesInputs )
$bin : BayesInput() from $in
then
insertLogical( $bin );
end

rule "visitBayesDerivedField"
salience -11
when
$bin : BayesInput( $dfld : derivedField != null, $fld : fieldName )
then
if ( $dfld.getName() == null ) {
// only allowed derivation is discretization
$dfld.setName( $fld + "__Discrete" );
}
insertLogical( $dfld );
end
Expand Up @@ -29,14 +29,14 @@
rule "propagateMissing_@{name}"
when
$src : @{origField}( missing == true, $ctx : context
@if{ context != null } , context == @{context} @end{})
@if{ context != null } , context == "@{context}" @end{} )
then
@{name} x = new @{name}();
x.setValue(null);
x.setName(@{format("string",name)});
x.setMissing(true);
x.setValid(true);
x.setContext($ctx);
x.setValue( null );
x.setName( @{ format( "string", name ) } );
x.setMissing( true );
x.setValid( true );
x.setContext( "@{context}" );
//x.setContinous();
insertLogical(x);
end
Expand Down
@@ -0,0 +1,43 @@
@comment{

Copyright 2011 JBoss Inc

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
}




@comment{

}

@declare{'naiveBayesBuildCounts'}

rule "BuildCounts_@{name}_@{field}"
dialect "mvel"
when
ModelMarker( "@{name}" ; enabled == true )
then
map = new java.util.HashMap();
bc = new BayesCounts( "@{name}", "@{field}", map, null );
@foreach{ pair : pairs }
map.put( "@{ pair.value }", [ @foreach{ tvc : pair.targetValueCounts.targetValueCounts } @{ tvc.count } @end{ ',' } ] );
@end{}
insertLogical( bc );
end

@end{}


@includeNamed{'naiveBayesBuildCounts'}

0 comments on commit 88f4f7f

Please sign in to comment.