
Scala Library for extracting useful information from trained Spark Model (DecisionTreeClassificationModel)


RaistlinTAO/SparkModelHelper


Spark Model Helper

A helper that extracts useful information from a trained Spark model

For my beloved 
We were made for each other.
You are the giver
I am the wonder
And I am so blessed to be your disciple.


Are you tired of staring at model.toDebugString for hours and still having no clue? Something like this:

DecisionTreeClassificationModel (uid=dtc_e933455b) of depth 5 with 341 nodes
  If (feature 518 <= 1.5)
   If (feature 6 <= 2.5)
    If (feature 30 <= 20)
     If (feature 45 <= 7)
      If (feature 24 <= 2.5)
        ...
      Else (feature 160 > 0.99)
      If (feature 64 <= 3.5)
       Predict: 0.0
      Else (feature 64 > 3.5)
       Predict: 1.0

Well, now you have this helper, designed for humans, which matters.


This helper was built for Scala 2.12.15 and Spark 2.4.8. Also tested with the latest Scala and Spark.

Usage

DecisionTreeClassificationModel Analysis


1. Get Root Feature Index from trained model

In automated ML, you sometimes need to retrain a model because an evaluation metric (precision or recall, for instance) is not within the desired range. With the root feature index you can change the DataFrame accordingly.

  val helper = new DecisionModelHelper(model)
  println("Root Feature Index: " + helper.getRootFeature)

2. Get the JSON string from DecisionTreeClassificationModel

  val helper = new DecisionModelHelper(model)
  println("toJson: " + helper.toJson + "\n")

Returns a beautified JSON string:

    {
      "featureIndex": 367,
      "gain": 0.10617781879627633,
      "impurity": 0.3144732024264778,
      "threshold": 1.5,
      "nodeType": "internal",
      "splitType": "continuous",
      "prediction": 0.0,
      "leftChild": {
        ....
      "path": "F(367)|0.3144732024264778|0.0|1.5|"
    }

3. Return an Object of Model Node-Tree

    val nodeObj = helper.getDecisionNode
    println("getDecisionNode: " + nodeObj)

This is the object version of the JSON string.

4. Return Root to Leaf Path of Rules

    val rules = helper.getRulesList(1, 0.2)
    rules.foreach(rule => {
      println("Rule: " + rule.mkString(", "))
    })

The above code prints:

    Rule: F(396)|0.3144732024264778|0.0|1.5|L, F(12)|0.49791192623975383|1.0|2.5|R, F(223)|0.2998340735773348|1.0|2500000.0|R, F(20)|0.19586076183802947|1.0|3.523971665E10|L, None|0.1902980108641974|1.0|None|E

The function returns a List[List[String]]. Each String has the following structure:

    Feature_Index | impurity | prediction | threshold | node_type

For example, if Feature Index 45 has impurity 3.5, prediction 1 and threshold 1.5, and the path goes right after this node, the string will be:

    F(45)|3.5|1|1.5|R

Leaf nodes have "E" as node_type.
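
If you want to work with these strings programmatically, they can be split back into typed fields. The following is a minimal sketch in plain Scala, not part of the helper: the names RuleStep and RuleStepParser are hypothetical, and it assumes that missing values appear as the literal text None, as in the leaf step of the printed rule above.

```scala
// Hypothetical names: RuleStep and RuleStepParser are not part of the helper.
final case class RuleStep(
    feature: Option[String],   // None for a leaf step
    impurity: Double,
    prediction: Double,
    threshold: Option[Double], // None for a leaf step
    nodeType: String           // "L", "R" or "E"
)

object RuleStepParser {
  // Assumption: a missing value is written as the literal text "None".
  private def optional(s: String): Option[String] =
    if (s == "None") None else Some(s)

  // Split "Feature_Index|impurity|prediction|threshold|node_type" into fields.
  def parse(step: String): RuleStep = {
    val parts = step.split('|').map(_.trim)
    require(parts.length == 5, s"expected 5 fields, got ${parts.length}: $step")
    RuleStep(
      feature = optional(parts(0)),
      impurity = parts(1).toDouble,
      prediction = parts(2).toDouble,
      threshold = optional(parts(3)).map(_.toDouble),
      nodeType = parts(4)
    )
  }
}
```

For instance, RuleStepParser.parse("F(45)|3.5|1|1.5|R") gives a RuleStep whose feature is Some("F(45)") and whose nodeType is "R".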

5. Customise the Feature_Index

A feature index is not designed for human reading, especially with a large number of columns, so the helper also supports custom feature names.

    val helper = new DecisionModelHelper(model)
    helper.setFeatureName(
      Map(0 -> "UserID", 1 -> "UserCity", 2 -> "Salary" ... )
    )

Once setFeatureName(Map[Int, String]) has been called, the helper automatically changes F(1) into "UserCity":

    F(1)|3.5|1|1.5|R

will be output as

    UserCity|3.5|1|1.5|R
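
The effect of the renaming on a single rule string can be illustrated with a small regular-expression substitution. This is only a sketch of the observable behaviour, not the helper's actual implementation. The object FeatureRenamer is hypothetical, and leaving unnamed indices as F(n) is an assumption.

```scala
import scala.util.matching.Regex

// Hypothetical illustration; the helper itself does this via setFeatureName.
object FeatureRenamer {
  private val FeatureRef: Regex = """F\((\d+)\)""".r

  // Replace every F(n) with names(n).
  // Assumption: indices without a custom name stay as F(n).
  def rename(step: String, names: Map[Int, String]): String =
    FeatureRef.replaceAllIn(step, m =>
      Regex.quoteReplacement(names.getOrElse(m.group(1).toInt, m.matched)))
}
```

With Map(0 -> "UserID", 1 -> "UserCity", 2 -> "Salary"), FeatureRenamer.rename("F(1)|3.5|1|1.5|R", names) yields the UserCity string shown above.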

DecisionTreeClassificationModel Rule Helper

1. Convert Rules into Python or Scala

    rules.foreach(rule => {
      println("Rule (Scala): " + '\n' + DecisionRuleHelper.getStatementFromPathList(rule, language = Language.Scala))
      println("Rule (Python): " + '\n' + DecisionRuleHelper.getStatementFromPathList(rule, language = Language.Python))
    })

The function returns code that you can copy and paste directly:

Rule (Scala): 
if (F(58) <= 1.5 && F(56) > 2.5 && F(20) > 2500000.0 && F(20) <= 3.523971665E10) 1.0
Rule (Python): 
if F(58) <= 1.5 and F(56) > 2.5 and F(20) > 2500000.0 and F(20) <= 3.523971665E10: 1.0

2. Evaluate new data against Rules

You can test new data against the extracted rules. Remember that this ONLY supports rules with non-customised feature names.

    import scala.collection.mutable.ListBuffer
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.functions.{col, struct, udf}
    import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType}

    // Load the data into a DataFrame from CSV
    var data = spark.read.format("csv").option("header", value = false).option("inferSchema", value = true).load("sample/testing_data.csv").na.drop("all")
    println("Data has been loaded into Dataframe from CSV file, or you can use existing Dataframe directly")
    // Convert the DataFrame schema data types into a ListBuffer
    val typeList = ListBuffer[String]()
    data.schema.foreach(schemaNode => {
      schemaNode.dataType match {
        case IntegerType => typeList += "INT"
        case DoubleType => typeList += "DOUBLE"
        case LongType => typeList += "LONG"
        case _ => typeList += "DEC"
      }
    })
    // Loop through every rule (starting from 1 for easy column names)
    for (ruleIndex <- 1 to rules.length) {
      val columns = data.columns

      // Define the user-defined function using DecisionRuleHelper.evaluateRule
      val combineUdf = udf((row: Row) => DecisionRuleHelper.evaluateRule(row, rules(ruleIndex - 1), typeList.toList))

      // Create a new column that stores the evaluation result of each row
      // If the data in a row fits the rule, the result is 1, otherwise 0
      data = data.withColumn("Rule" + ruleIndex.toString, combineUdf(struct(columns.map(col): _*)))
      println("Rule " + ruleIndex + " processed")
    }
    data.write.format("csv").save("result/" + System.currentTimeMillis() / 1000)
    println("Result CSV Saved, or you can use data (Dataframe) directly")

After the loop ends, data is a new DataFrame with one additional column per rule extracted from the model. Each additional column holds the result of evaluating that rule against the current row.

For example, if row 1 fits Rule 1:

(F(58) <= 1.5 && F(56) > 2.5 && F(20) > 2500000.0 && F(20) <= 3.523971665E10) 1.0

you will get 1 in that row's Rule1 column, otherwise 0.
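
The 1-or-0 outcome can also be illustrated without Spark. The sketch below is not DecisionRuleHelper.evaluateRule. RuleEvalSketch is a hypothetical name, the row is a plain sequence of doubles rather than a Spark Row, and it assumes that F(n) refers to the value at position n of the row and that "L" and "R" mean <= and > as in the generated statements.

```scala
// Hypothetical, Spark-free illustration of evaluating one rule on one row.
object RuleEvalSketch {
  private val FeatureRef = """F\((\d+)\)""".r

  // Returns 1 when every "L"/"R" step holds for the given values, else 0.
  // Assumption: F(n) is the value at position n of the row.
  def evaluate(values: IndexedSeq[Double], path: List[String]): Int = {
    val holds = path.map(_.split('|')).forall {
      case Array(FeatureRef(idx), _, _, threshold, "L") =>
        values(idx.toInt) <= threshold.toDouble
      case Array(FeatureRef(idx), _, _, threshold, "R") =>
        values(idx.toInt) > threshold.toDouble
      case _ => true // leaf ("E") or unrecognised steps add no condition
    }
    if (holds) 1 else 0
  }
}

object RuleEvalDemo extends App {
  val rule = List("F(0)|0.3|0.0|1.5|L", "F(2)|0.2|1.0|2.5|R", "None|0.1|1.0|None|E")
  println(RuleEvalSketch.evaluate(Vector(1.0, 9.0, 3.0), rule)) // 1: both conditions hold
  println(RuleEvalSketch.evaluate(Vector(2.0, 9.0, 3.0), rule)) // 0: F(0) <= 1.5 fails
}
```

The rule in RuleEvalDemo is made-up sample data, chosen only so that one row fits and the other does not.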