Skip to content

Commit

Permalink
add variant for catboost inference that does not create a spark dataframe
Browse files Browse the repository at this point in the history
  • Loading branch information
jdries committed May 5, 2022
1 parent 3c80014 commit 2e00af6
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 33 deletions.
6 changes: 6 additions & 0 deletions openeo-geotrellis/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@
<version>1.0.4</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>ai.catboost</groupId>
<artifactId>catboost-prediction</artifactId>
<version>1.0.4</version>
<scope>compile</scope>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.openeo.geotrellis

import ai.catboost.CatBoostModel
import ai.catboost.spark.CatBoostClassificationModel
import geotrellis.raster.mapalgebra.local._
import geotrellis.raster.{ArrayTile, ByteUserDefinedNoDataCellType, CellType, Dimensions, DoubleConstantTile, FloatConstantNoDataCellType, FloatConstantTile, IntConstantNoDataCellType, IntConstantTile, MultibandTile, MutableArrayTile, NODATA, ShortConstantTile, Tile, UByteConstantTile, UByteUserDefinedNoDataCellType, UShortUserDefinedNoDataCellType, isData, isNoData}
Expand Down Expand Up @@ -898,50 +899,67 @@ class OpenEOProcessScriptBuilder {
// Operator behind predict_catboost: `rs` holds the feature tiles (one per feature) and the
// "context" entry of `context` holds the CatBoost model used for inference.
// NOTE(review): this span is a rendered diff; before/after versions of several statements both
// appear (e.g. the two consecutive model checks below), so read it as a diff, not as final source.
val operator = (rs: Seq[Tile], context: Map[String, Any]) => {
// 1. Checks.
// A missing "context" entry yields null, which fails both isInstanceOf tests and is rejected below.
val modelCheck = context.getOrElse("context", null)
// NOTE(review): this line and the next look like the old and new versions of the same check.
if (!modelCheck.isInstanceOf[CatBoostClassificationModel])
// Accept either the Spark CatBoostClassificationModel or the plain CatBoostModel.
if (!modelCheck.isInstanceOf[CatBoostClassificationModel] && !modelCheck.isInstanceOf[CatBoostModel])
throw new IllegalArgumentException(
s"The 'model' argument should contain a valid Catboost model, but got: $modelCheck.")
// Validate the feature tiles: dimension check via assertEqualDimensions, and at least one tile required.
rs.assertEqualDimensions()
val layerCount = rs.length
if (layerCount == 0) sys.error(s"No features provided for predict_catboost.")

// 2. Convert Seq[Tile] to Dataframe.
// NOTE(review): the statements from here to the model-type branch appear to be the pre-change copy
// of the DataFrame conversion; the same code reappears inside the else-branch further down.
val newCellType = rs.map(_.cellType).reduce(_.union(_))
val Dimensions(cols, rows) = rs.head.dimensions
val evalData = mutable.Buffer[Row]()
cfor(0)(_ < rows, _ + 1) { row =>
cfor(0)(_ < cols, _ + 1) { col =>
val features = new Array[Double](layerCount)
cfor(0)(_ < layerCount, _ + 1) { i =>
features(i) = rs(i).getDouble(col, row)
// Branch for a plain CatBoostModel: prediction goes through multibandReduce and the model's own
// predict call, without building a Spark DataFrame.
if(context("context").isInstanceOf[CatBoostModel]){

val model = context("context").asInstanceOf[CatBoostModel]

// Logistic function, applied to each raw model output below.
def sigmoid(x: Double) = 1. / (1 + Math.pow(Math.E, -x))
// NOTE(review): presumably the callback runs once per pixel with that pixel's band values in `ts`;
// confirm against multibandReduce, which is defined outside this view.
multibandReduce(MultibandTile(rs),ts => {
val numericalFeatures = ts.map(_.floatValue()).toArray
// The null is cast to Array[Int], presumably to pick the predict overload that takes categorical
// feature hashes while passing none; confirm against the CatBoostModel API.
val rawPrediction = model.predict(numericalFeatures,null.asInstanceOf[Array[Int]])
val sigmoids: Array[Double] = rawPrediction.copyRowMajorPredictions().map(sigmoid)
// The index of the largest sigmoid value is used as the predicted class.
// NOTE(review): if the model emits a single raw value per object (e.g. binary classification),
// this index is always 0; confirm the model is multiclass.
val theClass = sigmoids.indices.maxBy(sigmoids)
theClass

// NOTE(review): the meaning of the trailing `true` is defined by multibandReduce (not visible here).
},true)
}else{

// 2. Convert Seq[Tile] to Dataframe.
// Output cell type is the union of all input cell types.
val newCellType = rs.map(_.cellType).reduce(_.union(_))
val Dimensions(cols, rows) = rs.head.dimensions
val evalData = mutable.Buffer[Row]()
// Pixels are visited with row as the outer loop and col as the inner loop; each pixel becomes one
// Row holding a dense vector with one value per feature tile.
cfor(0)(_ < rows, _ + 1) { row =>
cfor(0)(_ < cols, _ + 1) { col =>
val features = new Array[Double](layerCount)
cfor(0)(_ < layerCount, _ + 1) { i =>
features(i) = rs(i).getDouble(col, row)
}
evalData.append(Row(Vectors.dense(features)))
}
evalData.append(Row(Vectors.dense(features)))
}
}
val spark = SparkSession.builder.config(SparkContext.getOrCreate().getConf).getOrCreate()
val srcDataSchema = Seq(StructField("features", SQLDataTypes.VectorType))
val evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema))

// 3. Generate predictions.
val model = context("context").asInstanceOf[CatBoostClassificationModel]
val predictions: DataFrame = model.transform(evalDf)

// 4. Convert DataFrame back to Seq[Tile].
val tile = ArrayTile.alloc(newCellType, cols, rows)
val predIter: util.Iterator[Row] = predictions.toLocalIterator()
cfor(0)(_ < rows, _ + 1) { row =>
cfor(0)(_ < cols, _ + 1) { col =>
if (!predIter.hasNext) {
throw new IllegalStateException("Not enough predictions to fill all pixels.")
// Build the SparkSession from the configuration of the current (or newly created) SparkContext.
val spark = SparkSession.builder.config(SparkContext.getOrCreate().getConf).getOrCreate()
// Single-column schema: one "features" vector per pixel.
val srcDataSchema = Seq(StructField("features", SQLDataTypes.VectorType))
val evalDf = spark.createDataFrame(spark.sparkContext.parallelize(evalData), StructType(srcDataSchema))

// 3. Generate predictions.
val model = context("context").asInstanceOf[CatBoostClassificationModel]
val predictions: DataFrame = model.transform(evalDf)

// 4. Convert DataFrame back to Seq[Tile].
val tile = ArrayTile.alloc(newCellType, cols, rows)
// Predictions are consumed in the same row-outer / col-inner order used to build evalData.
// NOTE(review): assumes the DataFrame preserves the input row order; TODO confirm.
val predIter: util.Iterator[Row] = predictions.toLocalIterator()
cfor(0)(_ < rows, _ + 1) { row =>
cfor(0)(_ < cols, _ + 1) { col =>
if (!predIter.hasNext) {
throw new IllegalStateException("Not enough predictions to fill all pixels.")
}
// Note: predIter.next() = [features, rawPrediction, probability, prediction]
tile.setDouble(col, row, predIter.next().get(3) match {
case d: Double => d
case o => throw new IllegalStateException(s"Predictions are not in Double format: $o")
})
}
// Note: predIter.next() = [features, rawPrediction, probability, prediction]
tile.setDouble(col, row, predIter.next().get(3) match {
case d: Double => d
case o => throw new IllegalStateException(s"Predictions are not in Double format: $o")
})
}
// The DataFrame branch yields a single prediction tile.
Seq(tile)
}
Seq(tile)

}
// Return our operator in a composed function.
val storedArgs = contextStack.head
Expand Down

0 comments on commit 2e00af6

Please sign in to comment.