Merge pull request #159 from SANSA-Stack/feature/SVA_default_timestamp
Feature/sva default timestamp
carstendraschner committed May 31, 2021
2 parents 41ce829 + b6b8a68 commit 30e1b92
Showing 4 changed files with 96 additions and 1 deletion.
@@ -179,5 +179,9 @@ object DistRDF2ML_Regression {

val metagraph: RDD[Triple] = ml2Graph.transform(predictions)
metagraph.take(10).foreach(println(_))

metagraph
.coalesce(1)
.saveAsNTriplesFile(args(0) + "someFolder")
}
}
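Note: .coalesce(1) collapses the metagraph RDD to a single partition before writing, so saveAsNTriplesFile produces one part file inside the target folder instead of one file per partition; with the default partitioning the serialized triples would be split across many part files.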
@@ -7,6 +7,9 @@ import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.types.{Decimal, DoubleType, StringType, StructType}
import org.apache.spark.sql.functions.{udf, _}
import java.sql.Timestamp

import org.apache.spark.sql.types._

import scala.collection.mutable

@@ -34,6 +37,7 @@ class SmartVectorAssembler extends Transformer{
// null replacement
protected var _nullDigitReplacement: Int = -1
protected var _nullStringReplacement: String = ""
protected var _nullTimestampReplacement: Timestamp = Timestamp.valueOf("1900-01-01 00:00:00")

protected var _word2VecSize = 2
protected var _word2VecMinCount = 1
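These defaults mean the assembler handles nulls without any caller configuration: -1 for digit columns, the empty string for string columns, and midnight 1900-01-01 for timestamp columns, the new default this PR introduces.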
@@ -96,6 +100,9 @@ class SmartVectorAssembler extends Transformer{
else if (datatype.toLowerCase == "digit") _nullDigitReplacement = {
value.asInstanceOf[Int]
}
else if (datatype.toLowerCase == "timestamp") _nullTimestampReplacement = {
value.asInstanceOf[Timestamp]
}
else {
println("only digit and string are supported")
}
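For illustration, a minimal sketch of how a caller overrides the new timestamp default; the concrete values here are arbitrary examples, mirroring the setter calls in the tests further down:

    val assembler = new SmartVectorAssembler()
      .setNullReplacement("string", "unknown")  // example replacement value
      .setNullReplacement("digit", -1000)       // example replacement value
      .setNullReplacement("timestamp", java.sql.Timestamp.valueOf("1970-01-01 00:00:00")) // example value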
@@ -399,6 +406,9 @@ class SmartVectorAssembler extends Transformer{
}
else if (featureType.contains("Timestamp") & featureType.contains("Single")) {
dfCollapsedTwoColumns
.withColumn(featureColumn, col(featureColumn).cast("string"))
.na.fill(value = _nullTimestampReplacement.toString, cols = Array(featureColumn))
.withColumn(featureColumn, col(featureColumn).cast("timestamp"))
.withColumn(featureName + "UnixTimestamp(Single_NonCategorical_Int)", unix_timestamp(col(featureColumn)).cast("int"))
.withColumn(featureName + "DayOfWeek(Single_NonCategorical_Int)", dayofweek(col(featureColumn)))
.withColumn(featureName + "DayOfMonth(Single_NonCategorical_Int)", dayofmonth(col(featureColumn)))
@@ -415,6 +425,9 @@ class SmartVectorAssembler extends Transformer{
val df1 = df0
.select(col(_entityColumn), explode_outer(col(featureColumn)))
.withColumnRenamed("col", featureColumn)
.withColumn(featureColumn, col(featureColumn).cast("string"))
.na.fill(value = _nullTimestampReplacement.toString, cols = Array(featureColumn))
.withColumn(featureColumn, col(featureColumn).cast("timestamp"))

val df2 = df1
.withColumn(featureName + "UnixTimestamp(ListOf_NonCategorical_Int)", unix_timestamp(col(featureColumn)).cast("int"))
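In the list-valued case, explode_outer (rather than explode) keeps entities whose timestamp list is empty or null as a row with a null value, so the same fill-and-cast sequence also applies the default timestamp to entities that have no timestamp at all.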
18 changes: 18 additions & 0 deletions sansa-ml/sansa-ml-spark/src/test/resources/utils/svaTest.ttl
@@ -0,0 +1,18 @@
@prefix : <http://dig.isi.edu/> .
@prefix rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> .
@prefix owl: <http://www.w3.org/2002/07/owl#> .
@prefix xsd: <http://www.w3.org/2001/XMLSchema#> .

:John a :Person ;
:name "John" ;
:hasSpouse :Mary ;
:age "28"^^<http://www.w3.org/2001/XMLSchema#integer>.
:Mary a :Person ;
:name "Mary" ;
:hasSpouse :John ;
:age "25"^^<http://www.w3.org/2001/XMLSchema#integer>.
:John_jr a :Person ;
:name "John Jr." ;
:hasParent :John, :Mary ;
:age "2"^^<http://www.w3.org/2001/XMLSchema#integer>;
:birthday "2000-01-01T00:00:00Z"^^<http://www.w3.org/2001/XMLSchema#dateTime> .
@@ -24,7 +24,7 @@ class SmartVectorAssemblerTest extends FunSuite with SharedSparkContext{
.config("spark.sql.crossJoin.enabled", true)
.getOrCreate()

- private val dataPath = this.getClass.getClassLoader.getResource("utils/test.ttl").getPath
+ private val dataPath = this.getClass.getClassLoader.getResource("utils/svaTest.ttl").getPath
private def getData() = {
import net.sansa_stack.rdf.spark.io._
import net.sansa_stack.rdf.spark.model._
@@ -90,6 +90,66 @@ class SmartVectorAssemblerTest extends FunSuite with SharedSparkContext{
.setLabelColumn("seed__down_age(Single_NonCategorical_Decimal)")
.setNullReplacement("string", "Hallo")
.setNullReplacement("digit", -1000)
.setNullReplacement("timestamp", java.sql.Timestamp.valueOf("1900-01-01 00:00:00"))
.setWord2VecSize(3)
.setWord2VecMinCount(1)



val mlReadyDf = smartVectorAssembler
.transform(collapsedDf)
.cache()

assert(inputDfSize == mlReadyDf.count())

assert(mlReadyDf.columns.toSet == Set("entityID", "label", "features"))

mlReadyDf.show(false)
mlReadyDf.schema.foreach(println(_))
}

test("Test2 SmartVectorAssembler") {
val dataset = getData()

val queryString = """
|SELECT
|?seed
|?seed__down_age
|?seed__down_name
|?seed__down_birthday
|
|WHERE {
| ?seed a <http://dig.isi.edu/Person> .
|
| OPTIONAL {
| ?seed <http://dig.isi.edu/age> ?seed__down_age .
| }
| OPTIONAL {
| ?seed <http://dig.isi.edu/name> ?seed__down_name .
| }
| OPTIONAL {
| ?seed <http://dig.isi.edu/birthday> ?seed__down_birthday .
| }
|
|}""".stripMargin
val sparqlFrame = new SparqlFrame()
.setSparqlQuery(queryString)
.setCollapsByKey(true)
.setCollapsColumnName("seed")
val collapsedDf = sparqlFrame
.transform(dataset)
.cache()

collapsedDf.show(false)

val inputDfSize = collapsedDf.count()

val smartVectorAssembler = new SmartVectorAssembler()
.setEntityColumn("seed")
.setLabelColumn("seed__down_age(Single_NonCategorical_Decimal)")
.setNullReplacement("string", "Hallo")
.setNullReplacement("digit", -1000)
.setNullReplacement("timestamp", java.sql.Timestamp.valueOf("1900-01-01 00:00:00"))
.setWord2VecSize(3)
.setWord2VecMinCount(1)
