Skip to content
This repository has been archived by the owner on Oct 8, 2020. It is now read-only.

Commit

Permalink
Added method for fixpoint iteration in Spark Dataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
LorenzBuehmann committed Nov 17, 2017
1 parent a0fb542 commit f0f15e3
Showing 1 changed file with 33 additions and 5 deletions.
@@ -1,10 +1,9 @@
package net.sansa_stack.inference.spark.forwardchaining

import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD

import net.sansa_stack.inference.utils.Logging
import org.apache.spark.sql.Dataset

/**
* Creates a new RDD by performing bulk iterations using the given step function. The first
Expand All @@ -22,10 +21,10 @@ trait FixpointIteration[T] extends Logging {
object FixpointIteration extends Logging {

/**
* Creates a new RDD by performing bulk iterations using the given step function `f`. The first
* RDD the step function returns is the input for the next iteration, the second RDD is
* Creates a new [[RDD]] by performing bulk iterations using the given step function `f`. The first
* RDD the step function returns is the input for the next iteration, the second [[RDD]] is
* the termination criterion. The iterations terminate when either the termination criterion
* RDD contains no elements or when `maxIterations` iterations have been performed.
* [[RDD]] contains no elements or when `maxIterations` iterations have been performed.
*
**/
def apply[T: ClassTag](maxIterations: Int = 10)(rdd: RDD[T], f: RDD[T] => RDD[T]): RDD[T] = {
Expand All @@ -48,4 +47,33 @@ object FixpointIteration extends Logging {
}
newRDD
}

/**
*
* Creates a new [[Dataset]] by performing bulk iterations using the given step function `f`. The first
* [[Dataset]] the step function returns is the input for the next iteration, the second RDD is
* the termination criterion. The iterations terminate when either the termination criterion
* RDD contains no elements or when `maxIterations` iterations have been performed.
*
**/
def apply2[T: ClassTag](maxIterations: Int = 10)(dataset: Dataset[T], f: Dataset[T] => Dataset[T]): Dataset[T] = {
var newDS = dataset
newDS.cache()
var i = 1
var oldCount = 0L
var nextCount = if (newDS.count() == 0) 0L else 1L
while (nextCount != oldCount) {
log.info(s"iteration $i...")
oldCount = nextCount
info(s"i:$nextCount")
newDS = newDS
.union(f(newDS))
.distinct()
.cache()
nextCount = newDS.count()
info(s"i+1:$nextCount")
i += 1
}
newDS
}
}

0 comments on commit f0f15e3

Please sign in to comment.