Skip to content
This repository has been archived by the owner on Oct 8, 2020. It is now read-only.

Commit

Permalink
Make kge package Scala-style code compliant, as requested on #5
Browse files Browse the repository at this point in the history
  • Loading branch information
GezimSejdiu committed Aug 21, 2018
1 parent 4cb0a84 commit 0d63d61
Show file tree
Hide file tree
Showing 12 changed files with 93 additions and 121 deletions.
@@ -1,5 +1,8 @@
package net.sansa_stack.ml.spark.kge.linkprediction.crossvalidation

import net.sansa_stack.rdf.spark.kge.triples.IntegerTriples
import org.apache.spark.sql._

/**
* Bootstrapping
* -------------
Expand All @@ -8,18 +11,12 @@ package net.sansa_stack.ml.spark.kge.linkprediction.crossvalidation
*
* Created by lpfgarcia
*/

import org.apache.spark.sql._

import net.sansa_stack.rdf.spark.kge.triples.IntegerTriples

/**
 * Bootstrapping split: the training set is drawn from `data` by sampling
 * WITH replacement (fraction 1.0, i.e. roughly |data| rows, duplicates
 * possible); the test set is every row never picked into the training set.
 *
 * NOTE(review): `except` de-duplicates, so `train union test` is not
 * guaranteed to reproduce `data` row-for-row — standard for bootstrap.
 */
class Bootstrapping(data: Dataset[IntegerTriples])
  extends CrossValidation[Dataset[IntegerTriples]] {

  /**
   * @return (train, test) pair of disjoint-by-content datasets.
   */
  def crossValidation(): (Dataset[IntegerTriples], Dataset[IntegerTriples]) = {
    val train = data.sample(withReplacement = true, fraction = 1.0)
    val test = data.except(train)
    (train, test)
  }

}
Expand Up @@ -13,4 +13,4 @@ trait CrossValidation[T] {

def crossValidation: (T, T)

}
}
@@ -1,5 +1,8 @@
package net.sansa_stack.ml.spark.kge.linkprediction.crossvalidation

import net.sansa_stack.rdf.spark.kge.triples.IntegerTriples
import org.apache.spark.sql._

/**
* Hold Out
* ---------
Expand All @@ -8,22 +11,17 @@ package net.sansa_stack.ml.spark.kge.linkprediction.crossvalidation
*
* Created by lpfgarcia
*/

import org.apache.spark.sql._

import net.sansa_stack.rdf.spark.kge.triples.IntegerTriples

/** Thrown when the holdout rate is outside the open interval (0, 1). */
case class rateException(info: String) extends Exception

/**
 * Holdout split: samples a `rate` fraction of `data` (without replacement)
 * as the training set; the remainder is the test set.
 *
 * @param data full dataset of integer-encoded triples
 * @param rate training fraction, must satisfy 0 < rate < 1
 */
class Holdout(data: Dataset[IntegerTriples], rate: Float) extends CrossValidation[Dataset[IntegerTriples]] {

  // rate == 0 was previously accepted although the message promises
  // "higher than 0" (and a 0-rate train set is useless) — reject it too.
  if (rate <= 0 || rate >= 1) {
    throw new rateException("Rate value should be higher than 0 and lower than 1")
  }

  /**
   * @return (train, test) pair; test = data minus the sampled train rows.
   */
  def crossValidation(): (Dataset[IntegerTriples], Dataset[IntegerTriples]) = {
    val train = data.sample(withReplacement = false, fraction = rate)
    val test = data.except(train)
    (train, test)
  }

}
@@ -1,5 +1,8 @@
package net.sansa_stack.ml.spark.kge.linkprediction.crossvalidation

import net.sansa_stack.rdf.spark.kge.triples.IntegerTriples
import org.apache.spark.sql._

/**
* k-fold Cross Validation
* -----------------------
Expand All @@ -9,26 +12,23 @@ package net.sansa_stack.ml.spark.kge.linkprediction.crossvalidation
* Created by lpfgarcia
*/

import org.apache.spark.sql._

import net.sansa_stack.rdf.spark.kge.triples.IntegerTriples

/** Thrown when the requested number of folds k is outside (1, 10]. */
case class kException(info: String) extends Exception

// Row schema pairing an integer-encoded triple with its fold index `k`.
case class withIndex(Subject: Int, Predicate: Int, Object: Int, k: Int)

class kFold(data: Dataset[IntegerTriples], k: Int, sk: SparkSession)
extends CrossValidation[Seq[Dataset[IntegerTriples]]] {
extends CrossValidation[Seq[Dataset[IntegerTriples]]] {

import sk.implicits._

// Guard: k must lie in (1, 10]. The original condition was inverted —
// `if (k > 1 && k <= 10) throw …` rejected every VALID k and accepted
// every invalid one. Throw only when k is OUT of range.
if (k <= 1 || k > 10) {
  throw new kException("The k value should be higher than 1 and lower or equal to 10")
}

val id = (1 to data.count().toInt / k).flatMap(List.fill(k)(_))
val fold = sk.sparkContext.parallelize(id, data.rdd.getNumPartitions)

def crossValidation() = {
def crossValidation(): (IndexedSeq[Dataset[IntegerTriples]], IndexedSeq[Dataset[IntegerTriples]]) = {

val df = sk.createDataFrame(data.rdd.zip(fold).map { r =>
withIndex(r._1.Subject, r._1.Predicate, r._1.Object, r._2)
Expand All @@ -45,4 +45,4 @@ class kFold(data: Dataset[IntegerTriples], k: Int, sk: SparkSession)
(train, test)
}

}
}
Expand Up @@ -9,9 +9,8 @@ package net.sansa_stack.ml.spark.kge.linkprediction.evaluate

object Evaluate {

  /**
   * Arithmetic mean of the left-corruption and right-corruption rank arrays.
   *
   * @param left  ranks obtained when corrupting the subject
   * @param right ranks obtained when corrupting the object
   * @return (mean of left, mean of right)
   *
   * NOTE(review): an empty array yields Float NaN (0/0); callers are
   * expected to pass non-empty rank arrays.
   */
  def meanRank(left: Array[Float], right: Array[Float]): (Float, Float) = {
    (left.sum / left.length, right.sum / right.length)
  }

}
@@ -1,5 +1,11 @@
package net.sansa_stack.ml.spark.kge.linkprediction.models

import com.intel.analytics.bigdl.optim.Adam
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric.NumericFloat
import net.sansa_stack.rdf.spark.kge.triples.{ IntegerTriples, StringTriples }
import org.apache.spark.sql._

/**
* DistMult: diagonal bilinear model
* ---------------------------------
Expand All @@ -9,32 +15,23 @@ package net.sansa_stack.ml.spark.kge.linkprediction.models
*
* Created by lpfgarcia on 20/11/2017.
*/

import org.apache.spark.sql._

import com.intel.analytics.bigdl.optim.Adam
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric.NumericFloat

import net.sansa_stack.rdf.spark.kge.triples.{StringTriples,IntegerTriples}

class DistMult(train: Dataset[IntegerTriples], ne: Int, nr: Int, batch: Int, k: Int, sk: SparkSession)
extends Models(ne: Int, nr: Int, batch: Int, k: Int, sk: SparkSession) {
extends Models(ne: Int, nr: Int, batch: Int, k: Int, sk: SparkSession) {

val epochs = 100
val rate = 0.01f

var opt = new Adam(learningRate = rate)

/**
 * DistMult score aggregated over `data`: sums the element-wise products
 * e(s) * r(p) * e(o) across all triples, then collapses the resulting
 * tensor through L2.
 *
 * NOTE(review): `collect()` pulls the dataset to the driver and `reduce`
 * throws on an empty dataset — intended for small (mini-batch) inputs.
 */
def dist(data: Dataset[IntegerTriples]): Float = {
  val aux = data.collect().map { i =>
    e(i.Subject) * r(i.Predicate) * e(i.Object)
  }.reduce((a, b) => a + b)

  L2(aux)
}

def run() = {
def run(): Unit = {

for (i <- 1 to epochs) {

Expand All @@ -53,5 +50,4 @@ class DistMult(train: Dataset[IntegerTriples], ne: Int, nr: Int, batch: Int, k:

}
}

}
}
@@ -1,23 +1,21 @@
package net.sansa_stack.ml.spark.kge.linkprediction.models

/**
* Model Abstract Class
* --------------------
*
* Created by lpfgarcia on 14/11/2017.
*/

import scala.math._
import scala.util._

import org.apache.spark.sql._

import com.intel.analytics.bigdl.nn.Power
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric.NumericFloat
import net.sansa_stack.rdf.spark.kge.triples.{ IntegerTriples, StringTriples }
import org.apache.spark.sql._

import net.sansa_stack.rdf.spark.kge.triples.{StringTriples,IntegerTriples}

/**
* Model Abstract Class
* --------------------
*
* Created by lpfgarcia on 14/11/2017.
*/
abstract class Models(ne: Int, nr: Int, batch: Int, k: Int, sk: SparkSession) {

val Ne = ne
Expand All @@ -26,40 +24,40 @@ abstract class Models(ne: Int, nr: Int, batch: Int, k: Int, sk: SparkSession) {
var e = initialize(ne)
var r = normalize(initialize(nr))

/**
 * Creates a (size x k) embedding tensor with entries drawn uniformly from
 * [-6/sqrt(k), 6/sqrt(k)] — the TransE/Xavier-style initialization range.
 */
def initialize(size: Int): Tensor[Float] = {
  Tensor(size, k).rand(-6 / sqrt(k), 6 / sqrt(k))
}

/**
 * Scales `data` so that the sum of absolute values (its L1 mass) is 1.
 *
 * NOTE(review): divides by zero if `data` is all zeros — assumed non-zero
 * because inputs come from `initialize`.
 */
def normalize(data: Tensor[Float]): Tensor[Float] = {
  data / data.abs().sum()
}

import sk.implicits._

val seed = new Random(System.currentTimeMillis())

/**
 * Produces a corrupted copy of `aux` for negative sampling: with equal
 * probability replaces either the subject or the object with a random
 * entity id in [1, Ne].
 *
 * NOTE(review): the corrupted triple may coincidentally equal a true
 * triple; no filtering is done here.
 */
def tuple(aux: IntegerTriples): IntegerTriples = {
  if (seed.nextBoolean()) {
    IntegerTriples(seed.nextInt(Ne) + 1, aux.Predicate, aux.Object)
  } else {
    IntegerTriples(aux.Subject, aux.Predicate, seed.nextInt(Ne) + 1)
  }
}

/** Maps every triple in `data` to a corrupted (negative) counterpart. */
def negative(data: Dataset[IntegerTriples]): Dataset[IntegerTriples] = {
  data.map(i => tuple(i))
}

/**
 * Draws a mini-batch of at most `batch` rows: over-samples at twice the
 * target fraction (sampling is approximate) and trims with `limit`.
 */
def subset(data: Dataset[IntegerTriples]): Dataset[IntegerTriples] = {
  data.sample(false, 2 * (batch.toDouble / data.count().toDouble)).limit(batch)
}

/** L1 norm: sum of absolute values of `vec`'s entries. */
def L1(vec: Tensor[Float]): Float = {
  vec.abs().sum()
}

/**
 * L2 (Euclidean) norm: sqrt of the sum of squared entries.
 *
 * Fix: the previous body `vec.pow(2).sqrt().sum()` applied sqrt
 * ELEMENT-WISE after squaring, which reduces to sum(|v_i|) — that is the
 * L1 norm, not L2. The sqrt must be taken once, over the scalar sum.
 */
def L2(vec: Tensor[Float]): Float = {
  sqrt(vec.pow(2).sum()).toFloat
}

}
}
@@ -1,5 +1,14 @@
package net.sansa_stack.ml.spark.kge.linkprediction.models

import scala.math._

import com.intel.analytics.bigdl.optim.Adam
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric.NumericFloat
import net.sansa_stack.rdf.spark.kge.triples.{ IntegerTriples, StringTriples }
import org.apache.spark.sql._


/**
* TransE embedding model
* ----------------------
Expand All @@ -9,19 +18,8 @@ package net.sansa_stack.ml.spark.kge.linkprediction.models
*
* Created by lpfgarcia on 14/11/2017.
*/

import scala.math._

import org.apache.spark.sql._

import com.intel.analytics.bigdl.optim.Adam
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric.NumericFloat

import net.sansa_stack.rdf.spark.kge.triples.{StringTriples,IntegerTriples}

class TransE(train: Dataset[IntegerTriples], ne: Int, nr: Int, batch: Int, k: Int, margin: Float, L: String, sk: SparkSession)
extends Models(ne: Int, nr: Int, batch: Int, k: Int, sk: SparkSession) {
extends Models(ne: Int, nr: Int, batch: Int, k: Int, sk: SparkSession) {

val epochs = 1000
val rate = 0.01f
Expand All @@ -30,12 +28,12 @@ class TransE(train: Dataset[IntegerTriples], ne: Int, nr: Int, batch: Int, k: In

val myL = L match {
case "L2" => L2 _
case _ => L1 _
case _ => L1 _
}

import sk.implicits._

def dist(data: Dataset[IntegerTriples]) = {
def dist(data: Dataset[IntegerTriples]): Float = {

val aux = data.collect().map { i =>
e(i.Subject) + r(i.Predicate) - e(i.Object)
Expand All @@ -44,11 +42,11 @@ class TransE(train: Dataset[IntegerTriples], ne: Int, nr: Int, batch: Int, k: In
myL(aux)
}

/**
 * TransE translation residual for a single triple:
 * e(subject) + r(predicate) - e(object). A small residual means the
 * triple fits the learned embeddings well.
 */
def dist(row: IntegerTriples): Tensor[Float] = {
  e(row.Subject) + r(row.Predicate) - e(row.Object)
}

def run() = {
def run(): Unit = {

for (i <- 1 to epochs) {

Expand All @@ -70,5 +68,4 @@ class TransE(train: Dataset[IntegerTriples], ne: Int, nr: Int, batch: Int, k: In

}
}

}
}
@@ -1,29 +1,27 @@
package net.sansa_stack.ml.spark.kge.linkprediction.prediction

import net.sansa_stack.rdf.spark.kge.triples.{ IntegerTriples, StringTriples }
import org.apache.spark.sql._

/**
* Predict Abstract Class
* ----------------------
*
* Created by lpfgarcia on 14/11/2017.
*/

import org.apache.spark.sql._

import net.sansa_stack.rdf.spark.kge.triples.{StringTriples,IntegerTriples}

abstract class Evaluate(test: Dataset[IntegerTriples]) {

/**
 * Left corruption: replaces the subject of `row` with candidate entity
 * id `i`, keeping predicate and object.
 */
def left(row: IntegerTriples, i: Int): IntegerTriples = {
  IntegerTriples(i, row.Predicate, row.Object)
}

/**
 * Right corruption: replaces the object of `row` with candidate entity
 * id `i`, keeping subject and predicate.
 */
def right(row: IntegerTriples, i: Int): IntegerTriples = {
  IntegerTriples(row.Subject, row.Predicate, i)
}

def rank(row: IntegerTriples, spo: String): Integer

def ranking() = {
def ranking(): (Seq[Integer], Seq[Integer]) = {

var l, r = Seq[Integer]()

Expand All @@ -35,7 +33,7 @@ abstract class Evaluate(test: Dataset[IntegerTriples]) {
(l, r)
}

def rawHits10() = {
def rawHits10(): (Seq[Boolean], Seq[Boolean]) = {

var l, r = Seq[Boolean]()

Expand All @@ -46,5 +44,4 @@ abstract class Evaluate(test: Dataset[IntegerTriples]) {

(l, r)
}

}
}

0 comments on commit 0d63d61

Please sign in to comment.