From 2906a9e9c8fac03a178f1a89c9348fa64fb297e4 Mon Sep 17 00:00:00 2001
From: carstendraschner
Date: Tue, 15 Jun 2021 11:07:06 +0200
Subject: [PATCH 1/3] integrate hashing as an alternative to indexing of
 categorical strings; also adjust unit tests and offer a new setter

---
 .../SmartVectorAssembler.scala                | 465 ++++++++++--------
 ...SimilarityExperimentMetaGraphFactory.scala |   2 +-
 .../SmartVectorAssemblerTest.scala            |   4 +
 3 files changed, 265 insertions(+), 206 deletions(-)

diff --git a/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssembler.scala b/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssembler.scala
index fabb84aa3..09567fe6d 100644
--- a/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssembler.scala
+++ b/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssembler.scala
@@ -1,7 +1,7 @@
 package net.sansa_stack.ml.spark.featureExtraction
 
 import org.apache.spark.ml.Transformer
-import org.apache.spark.ml.feature.{StopWordsRemover, StringIndexer, Tokenizer, VectorAssembler, Word2Vec}
+import org.apache.spark.ml.feature.{HashingTF, StopWordsRemover, StringIndexer, Tokenizer, VectorAssembler, Word2Vec}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
@@ -34,6 +34,8 @@ class SmartVectorAssembler extends Transformer{
   protected var _numericCollapsingStrategy: String = "median"
   protected var _stringCollapsingStrategy: String = "concat"
 
+  protected var _digitStringStrategy: String = "hash"
+
   // null replacement
   protected var _nullDigitReplacement: Int = -1
   protected var _nullStringReplacement: String = ""
@@ -149,6 +151,17 @@ class SmartVectorAssembler extends Transformer{
     this
   }
 
+  /**
+   * Setter for the strategy that transforms categorical strings into digits. Option one is hash, option two is index.
+   * @param digitStringStrategy strategy, either hash or index
+   * @return transformer
+   */
+  def setDigitStringStrategy(digitStringStrategy: String): this.type = {
+    assert(Seq("hash", "index").contains(digitStringStrategy))
+    _digitStringStrategy = digitStringStrategy
+    this
+  }
+
   /**
    * Validate set column to check if we need fallback to first column if not set
    * and if set if it is in available cols
@@ -257,7 +270,8 @@
       .select(_entityColumn, featureColumn)
 
     var newFeatureColumnName: String = featureName
-    val digitizedDf: DataFrame = if (featureType == "Single_NonCategorical_String") {
+    val digitizedDf: DataFrame =
+      if (featureType == "Single_NonCategorical_String") {
       newFeatureColumnName += "(Word2Vec)"
 
       val dfCollapsedTwoColumnsNullsReplaced = dfCollapsedTwoColumns
@@ -307,226 +321,263 @@ class SmartVectorAssembler extends Transformer{
         .withColumnRenamed("output", newFeatureColumnName)
         .select(_entityColumn, newFeatureColumnName)
     }
-    else if (featureType == "ListOf_NonCategorical_String") {
-      newFeatureColumnName += "(Word2Vec)"
+      else if (featureType == "ListOf_NonCategorical_String") {
+        newFeatureColumnName += "(Word2Vec)"
 
-      val dfCollapsedTwoColumnsNullsReplaced = dfCollapsedTwoColumns
-        .na.fill(_nullStringReplacement)
-        .withColumn("sentences", concat_ws(". ", col(featureColumn)))
-        .select(_entityColumn, "sentences")
+        val dfCollapsedTwoColumnsNullsReplaced = dfCollapsedTwoColumns
+          .na.fill(_nullStringReplacement)
+          .withColumn("sentences", concat_ws(". 
", col(featureColumn))) + .select(_entityColumn, "sentences") - val tokenizer = new Tokenizer() - .setInputCol("sentences") - .setOutputCol("words") + val tokenizer = new Tokenizer() + .setInputCol("sentences") + .setOutputCol("words") - val tokenizedDf = tokenizer - .transform(dfCollapsedTwoColumnsNullsReplaced) - .select(_entityColumn, "words") + val tokenizedDf = tokenizer + .transform(dfCollapsedTwoColumnsNullsReplaced) + .select(_entityColumn, "words") - val remover = new StopWordsRemover() - .setInputCol("words") - .setOutputCol("filtered") + val remover = new StopWordsRemover() + .setInputCol("words") + .setOutputCol("filtered") - val inputDf = remover - .transform(tokenizedDf) - .select(_entityColumn, "filtered") - .persist() - - val word2vec = new Word2Vec() - .setInputCol("filtered") - .setOutputCol("output") - .setMinCount(_word2VecMinCount) - .setVectorSize(_word2VecSize) - - val word2vecTrainingDf = if (_word2vecTrainingDfSizeRatio == 1) { - inputDf + val inputDf = remover + .transform(tokenizedDf) + .select(_entityColumn, "filtered") .persist() - } else { - inputDf - .sample(withReplacement = false, fraction = _word2vecTrainingDfSizeRatio).toDF() - .persist() - } - - val word2vecModel = word2vec - .fit(word2vecTrainingDf) - - word2vecTrainingDf.unpersist() - - word2vecModel - .transform(inputDf) - .withColumnRenamed("output", newFeatureColumnName) - .select(_entityColumn, newFeatureColumnName) - } - else if (featureType == "Single_Categorical_String") { - newFeatureColumnName += "(IndexedString)" - val inputDf = dfCollapsedTwoColumns - .na.fill(_nullStringReplacement) - .cache() - - val indexer = new StringIndexer() - .setInputCol(featureColumn) - .setOutputCol("output") - - indexer - .fit(inputDf) - .transform(inputDf) - .withColumnRenamed("output", newFeatureColumnName) - .select(_entityColumn, newFeatureColumnName) - } - else if (featureType == "ListOf_Categorical_String") { - newFeatureColumnName += "(ListOfIndexedString)" - - val inputDf = dfCollapsedTwoColumns - .select(col(_entityColumn), explode_outer(col(featureColumn))) - .na.fill(_nullStringReplacement) - .cache() - - val stringIndexerTrainingDf = if (_stringIndexerTrainingDfSizeRatio == 1) { - inputDf - .persist() - } else { - inputDf - .sample(withReplacement = false, fraction = _stringIndexerTrainingDfSizeRatio).toDF() - .persist() + val word2vec = new Word2Vec() + .setInputCol("filtered") + .setOutputCol("output") + .setMinCount(_word2VecMinCount) + .setVectorSize(_word2VecSize) + + val word2vecTrainingDf = if (_word2vecTrainingDfSizeRatio == 1) { + inputDf + .persist() + } else { + inputDf + .sample(withReplacement = false, fraction = _word2vecTrainingDfSizeRatio).toDF() + .persist() + } + + val word2vecModel = word2vec + .fit(word2vecTrainingDf) + + word2vecTrainingDf.unpersist() + + word2vecModel + .transform(inputDf) + .withColumnRenamed("output", newFeatureColumnName) + .select(_entityColumn, newFeatureColumnName) + } + else if (featureType == "Single_Categorical_String") { + + val inputDf = dfCollapsedTwoColumns + .na.fill(_nullStringReplacement) + .cache() + + if (_digitStringStrategy == "index") { + newFeatureColumnName += "(IndexedString)" + + val indexer = new StringIndexer() + .setInputCol(featureColumn) + .setOutputCol("output") + + indexer + .fit(inputDf) + .transform(inputDf) + .withColumnRenamed("output", newFeatureColumnName) + .select(_entityColumn, newFeatureColumnName) + } + else { + newFeatureColumnName += "(Single_Categorical_HashedString)" + + inputDf + .withColumn("output", 
hash(col(featureColumn)).cast(DoubleType)) + .withColumnRenamed("output", newFeatureColumnName) + .select(_entityColumn, newFeatureColumnName) + /* val hashingTF = new HashingTF() + .setInputCol(featureColumn) + .setOutputCol("output") + + hashingTF + .transform(inputDf) + .withColumnRenamed("output", newFeatureColumnName) + .select(_entityColumn, newFeatureColumnName) */ + } } + else if (featureType == "ListOf_Categorical_String") { + + val inputDf = dfCollapsedTwoColumns + .select(col(_entityColumn), explode_outer(col(featureColumn))) + .na.fill(_nullStringReplacement) + .cache() + + + val stringIndexerTrainingDf = if (_stringIndexerTrainingDfSizeRatio == 1) { + inputDf + .persist() + } else { + inputDf + .sample(withReplacement = false, fraction = _stringIndexerTrainingDfSizeRatio).toDF() + .persist() + } + + if (_digitStringStrategy == "index") { + newFeatureColumnName += "(ListOfIndexedString)" + + val indexer = new StringIndexer() + .setInputCol("col") + .setOutputCol("outputTmp") + + indexer + .fit(stringIndexerTrainingDf) + .transform(inputDf) + .groupBy(_entityColumn) + .agg(collect_list("outputTmp") as "output") + .select(_entityColumn, "output") + .withColumnRenamed("output", newFeatureColumnName) + .select(_entityColumn, newFeatureColumnName) + } + else { + newFeatureColumnName += "(ListOf_Categorical_HashedString)" + + inputDf + .withColumn("output", hash(col("col")).cast(DoubleType)) + .groupBy(_entityColumn) + .agg(collect_list("output") as "output") + .select(_entityColumn, "output") + .withColumnRenamed("output", newFeatureColumnName) + .select(_entityColumn, newFeatureColumnName) + } - val indexer = new StringIndexer() - .setInputCol("col") - .setOutputCol("outputTmp") - indexer - .fit(stringIndexerTrainingDf) - .transform(inputDf) - .groupBy(_entityColumn) - .agg(collect_list("outputTmp") as "output") - .select(_entityColumn, "output") - .withColumnRenamed("output", newFeatureColumnName) - .select(_entityColumn, newFeatureColumnName) - } - else if (featureType.contains("Timestamp") & featureType.contains("Single")) { - dfCollapsedTwoColumns - .withColumn(featureColumn, col(featureColumn).cast("string")) - .na.fill(value = _nullTimestampReplacement.toString, cols = Array(featureColumn)) - .withColumn(featureColumn, col(featureColumn).cast("timestamp")) - .withColumn(featureName + "UnixTimestamp(Single_NonCategorical_Int)", unix_timestamp(col(featureColumn)).cast("int")) - .withColumn(featureName + "DayOfWeek(Single_NonCategorical_Int)", dayofweek(col(featureColumn))) - .withColumn(featureName + "DayOfMonth(Single_NonCategorical_Int)", dayofmonth(col(featureColumn))) - .withColumn(featureName + "DayOfYear(Single_NonCategorical_Int)", dayofyear(col(featureColumn))) - .withColumn(featureName + "Year(Single_NonCategorical_Int)", year(col(featureColumn))) - .withColumn(featureName + "Month(Single_NonCategorical_Int)", month(col(featureColumn))) - .withColumn(featureName + "Hour(Single_NonCategorical_Int)", hour(col(featureColumn))) - .withColumn(featureName + "Minute(Single_NonCategorical_Int)", minute(col(featureColumn))) - .withColumn(featureName + "Second(Single_NonCategorical_Int)", second(col(featureColumn))) - .drop(featureColumn) - } - else if (featureType.contains("Timestamp") & featureType.contains("ListOf")) { - val df0 = dfCollapsedTwoColumns - val df1 = df0 - .select(col(_entityColumn), explode_outer(col(featureColumn))) - .withColumnRenamed("col", featureColumn) - .withColumn(featureColumn, col(featureColumn).cast("string")) - .na.fill(value = 
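          /* Editorial note — the Timestamp branches that follow decompose one timestamp
             column into several integer columns rather than one feature. A condensed sketch
             of the idea, assuming a hypothetical df with a timestamp column "ts":

               import org.apache.spark.sql.functions._

               val expanded = df
                 .withColumn("tsUnixTimestamp", unix_timestamp(col("ts")).cast("int"))
                 .withColumn("tsDayOfWeek", dayofweek(col("ts")))
                 .withColumn("tsMonth", month(col("ts")))
                 // ... dayofmonth, dayofyear, year, hour, minute, second accordingly
                 .drop("ts")

             Each derived column is then treated downstream as an ordinary
             Single/ListOf_NonCategorical_Int feature. */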
_nullTimestampReplacement.toString, cols = Array(featureColumn)) - .withColumn(featureColumn, col(featureColumn).cast("timestamp")) - - val df2 = df1 - .withColumn(featureName + "UnixTimestamp(ListOf_NonCategorical_Int)", unix_timestamp(col(featureColumn)).cast("int")) - .withColumn(featureName + "DayOfWeek(ListOf_NonCategorical_Int)", dayofweek(col(featureColumn))) - .withColumn(featureName + "DayOfMonth(ListOf_NonCategorical_Int)", dayofmonth(col(featureColumn))) - .withColumn(featureName + "DayOfYear(ListOf_NonCategorical_Int)", dayofyear(col(featureColumn))) - .withColumn(featureName + "Year(ListOf_NonCategorical_Int)", year(col(featureColumn))) - .withColumn(featureName + "Month(ListOf_NonCategorical_Int)", month(col(featureColumn))) - .withColumn(featureName + "Hour(ListOf_NonCategorical_Int)", hour(col(featureColumn))) - .withColumn(featureName + "Minute(ListOf_NonCategorical_Int)", minute(col(featureColumn))) - .withColumn(featureName + "Second(ListOf_NonCategorical_Int)", second(col(featureColumn))) - .drop(featureColumn) - .persist() - val subFeatureColumns = df2.columns.filter(_ != _entityColumn) - var df3 = df0 - .select(_entityColumn) - .persist() - for (subFeatureColumn <- subFeatureColumns) { - val df4 = df2.select(_entityColumn, subFeatureColumn) - .groupBy(_entityColumn) - .agg(collect_list(subFeatureColumn) as subFeatureColumn) - df3 = df3.join(df4, _entityColumn) } + else if (featureType.contains("Timestamp") & featureType.contains("Single")) { + dfCollapsedTwoColumns + .withColumn(featureColumn, col(featureColumn).cast("string")) + .na.fill(value = _nullTimestampReplacement.toString, cols = Array(featureColumn)) + .withColumn(featureColumn, col(featureColumn).cast("timestamp")) + .withColumn(featureName + "UnixTimestamp(Single_NonCategorical_Int)", unix_timestamp(col(featureColumn)).cast("int")) + .withColumn(featureName + "DayOfWeek(Single_NonCategorical_Int)", dayofweek(col(featureColumn))) + .withColumn(featureName + "DayOfMonth(Single_NonCategorical_Int)", dayofmonth(col(featureColumn))) + .withColumn(featureName + "DayOfYear(Single_NonCategorical_Int)", dayofyear(col(featureColumn))) + .withColumn(featureName + "Year(Single_NonCategorical_Int)", year(col(featureColumn))) + .withColumn(featureName + "Month(Single_NonCategorical_Int)", month(col(featureColumn))) + .withColumn(featureName + "Hour(Single_NonCategorical_Int)", hour(col(featureColumn))) + .withColumn(featureName + "Minute(Single_NonCategorical_Int)", minute(col(featureColumn))) + .withColumn(featureName + "Second(Single_NonCategorical_Int)", second(col(featureColumn))) + .drop(featureColumn) + } + else if (featureType.contains("Timestamp") & featureType.contains("ListOf")) { + val df0 = dfCollapsedTwoColumns + val df1 = df0 + .select(col(_entityColumn), explode_outer(col(featureColumn))) + .withColumnRenamed("col", featureColumn) + .withColumn(featureColumn, col(featureColumn).cast("string")) + .na.fill(value = _nullTimestampReplacement.toString, cols = Array(featureColumn)) + .withColumn(featureColumn, col(featureColumn).cast("timestamp")) + + val df2 = df1 + .withColumn(featureName + "UnixTimestamp(ListOf_NonCategorical_Int)", unix_timestamp(col(featureColumn)).cast("int")) + .withColumn(featureName + "DayOfWeek(ListOf_NonCategorical_Int)", dayofweek(col(featureColumn))) + .withColumn(featureName + "DayOfMonth(ListOf_NonCategorical_Int)", dayofmonth(col(featureColumn))) + .withColumn(featureName + "DayOfYear(ListOf_NonCategorical_Int)", dayofyear(col(featureColumn))) + .withColumn(featureName + 
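          /* Editorial note — for list-of-timestamp features the code first explodes the list
             so the scalar date functions apply per element, and afterwards folds every derived
             column back into one list per entity. A condensed sketch of that regrouping, with
             hypothetical names and using the functions already imported in this file:

               val exploded = df.select(col("entityID"), explode_outer(col("tsList")).as("ts"))
               // ... derive tsDayOfWeek etc. on the exploded rows ...
               val regrouped = exploded
                 .groupBy("entityID")
                 .agg(collect_list("tsDayOfWeek") as "tsDayOfWeek")
               // one groupBy/agg per derived column, each joined back on entityID

             This mirrors the df2/df3 join loop that follows. */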
"Year(ListOf_NonCategorical_Int)", year(col(featureColumn))) + .withColumn(featureName + "Month(ListOf_NonCategorical_Int)", month(col(featureColumn))) + .withColumn(featureName + "Hour(ListOf_NonCategorical_Int)", hour(col(featureColumn))) + .withColumn(featureName + "Minute(ListOf_NonCategorical_Int)", minute(col(featureColumn))) + .withColumn(featureName + "Second(ListOf_NonCategorical_Int)", second(col(featureColumn))) + .drop(featureColumn) + .persist() - df2.unpersist() - df3 - } - - else if ( - featureType.startsWith("ListOf") && - (featureType.endsWith("Double") || featureType.endsWith("Decimal") || featureType.endsWith("Int") || featureType.endsWith("Integer")) - ) { - newFeatureColumnName += s"(${featureType})" - - dfCollapsedTwoColumns - .select(col(_entityColumn), explode_outer(col(featureColumn))) - .withColumnRenamed("col", "output") - .na.fill(_nullDigitReplacement) - .groupBy(_entityColumn) - .agg(collect_list("output") as "output") - .select(_entityColumn, "output") - .withColumnRenamed("output", newFeatureColumnName) - .select(_entityColumn, newFeatureColumnName) + val subFeatureColumns = df2.columns.filter(_ != _entityColumn) + var df3 = df0 + .select(_entityColumn) + .persist() + for (subFeatureColumn <- subFeatureColumns) { + val df4 = df2.select(_entityColumn, subFeatureColumn) + .groupBy(_entityColumn) + .agg(collect_list(subFeatureColumn) as subFeatureColumn) + df3 = df3.join(df4, _entityColumn) + } + + df2.unpersist() + df3 + } + else if ( + featureType.startsWith("ListOf") && + (featureType.endsWith("Double") || featureType.endsWith("Decimal") || featureType.endsWith("Int") || featureType.endsWith("Integer")) + ) { + newFeatureColumnName += s"(${featureType})" - } - else if (featureType.endsWith("Double")) { - newFeatureColumnName += s"(${featureType})" + dfCollapsedTwoColumns + .select(col(_entityColumn), explode_outer(col(featureColumn))) + .withColumnRenamed("col", "output") + .na.fill(_nullDigitReplacement) + .groupBy(_entityColumn) + .agg(collect_list("output") as "output") + .select(_entityColumn, "output") + .withColumnRenamed("output", newFeatureColumnName) + .select(_entityColumn, newFeatureColumnName) - dfCollapsedTwoColumns - .withColumnRenamed(featureColumn, "output") - .na.fill(_nullDigitReplacement) - .select(_entityColumn, "output") - .withColumnRenamed("output", newFeatureColumnName) - .select(_entityColumn, newFeatureColumnName) - } - else if (featureType.endsWith("Integer") || featureType.endsWith("Int")) { - newFeatureColumnName += s"(${featureType})" - - dfCollapsedTwoColumns - .withColumn("output", col(featureColumn).cast(DoubleType)) - // .withColumnRenamed(featureColumn, "output") - .na.fill(_nullDigitReplacement) - .select(_entityColumn, "output") - .withColumnRenamed("output", newFeatureColumnName) - .select(_entityColumn, newFeatureColumnName) - } - else if (featureType.endsWith("Boolean")) { - newFeatureColumnName += s"(${featureType})" - - dfCollapsedTwoColumns - .withColumn("output", col(featureColumn).cast(DoubleType)) - // .withColumnRenamed(featureColumn, "output") - .na.fill(_nullDigitReplacement) - .select(_entityColumn, "output") - .withColumnRenamed("output", newFeatureColumnName) - .select(_entityColumn, newFeatureColumnName) - } - else if (featureType.endsWith("Decimal")) { - newFeatureColumnName += s"(${featureType})" - - dfCollapsedTwoColumns - // .withColumn("output", col(featureColumn).cast(DoubleType)) - .withColumnRenamed(featureColumn, "output") - .na.fill(_nullDigitReplacement) - .select(_entityColumn, "output") - 
.withColumnRenamed("output", newFeatureColumnName) - .select(_entityColumn, newFeatureColumnName) - } - else { - newFeatureColumnName += ("(notDigitizedYet)") - println("transformation not possible yet") - dfCollapsedTwoColumns - .withColumnRenamed(featureColumn, "output") - .withColumnRenamed("output", newFeatureColumnName) - .select(_entityColumn, newFeatureColumnName) - } + } + else if (featureType.endsWith("Double")) { + newFeatureColumnName += s"(${featureType})" + + dfCollapsedTwoColumns + .withColumnRenamed(featureColumn, "output") + .na.fill(_nullDigitReplacement) + .select(_entityColumn, "output") + .withColumnRenamed("output", newFeatureColumnName) + .select(_entityColumn, newFeatureColumnName) + } + else if (featureType.endsWith("Integer") || featureType.endsWith("Int")) { + newFeatureColumnName += s"(${featureType})" + + dfCollapsedTwoColumns + .withColumn("output", col(featureColumn).cast(DoubleType)) + // .withColumnRenamed(featureColumn, "output") + .na.fill(_nullDigitReplacement) + .select(_entityColumn, "output") + .withColumnRenamed("output", newFeatureColumnName) + .select(_entityColumn, newFeatureColumnName) + } + else if (featureType.endsWith("Boolean")) { + newFeatureColumnName += s"(${featureType})" + + dfCollapsedTwoColumns + .withColumn("output", col(featureColumn).cast(DoubleType)) + // .withColumnRenamed(featureColumn, "output") + .na.fill(_nullDigitReplacement) + .select(_entityColumn, "output") + .withColumnRenamed("output", newFeatureColumnName) + .select(_entityColumn, newFeatureColumnName) + } + else if (featureType.endsWith("Decimal")) { + newFeatureColumnName += s"(${featureType})" + + dfCollapsedTwoColumns + // .withColumn("output", col(featureColumn).cast(DoubleType)) + .withColumnRenamed(featureColumn, "output") + .na.fill(_nullDigitReplacement) + .select(_entityColumn, "output") + .withColumnRenamed("output", newFeatureColumnName) + .select(_entityColumn, newFeatureColumnName) + } + else { + newFeatureColumnName += ("(notDigitizedYet)") + + println("transformation not possible yet") + dfCollapsedTwoColumns + .withColumnRenamed(featureColumn, "output") + .withColumnRenamed("output", newFeatureColumnName) + .select(_entityColumn, newFeatureColumnName) + } fullDigitizedDf = fullDigitizedDf.join( digitizedDf, @@ -543,6 +594,8 @@ class SmartVectorAssembler extends Transformer{ val onlyDigitizedDf = fullDigitizedDf .select(digitzedColumns.map(col(_)): _*) + // onlyDigitizedDf.show(false) + fullDigitizedDf.unpersist() // println("FIX FEATURE LENGTH") @@ -586,6 +639,8 @@ class SmartVectorAssembler extends Transformer{ } // println(s"columns to assemble:\n${columnsToAssemble.mkString(", ")}") + // fixedLengthFeatureDf.show(false) + val assembler = new VectorAssembler() .setInputCols(columnsToAssemble) .setOutputCol("features") diff --git a/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/utils/SimilarityExperimentMetaGraphFactory.scala b/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/utils/SimilarityExperimentMetaGraphFactory.scala index eab8d7299..4ddef0049 100644 --- a/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/utils/SimilarityExperimentMetaGraphFactory.scala +++ b/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/utils/SimilarityExperimentMetaGraphFactory.scala @@ -111,7 +111,7 @@ class SimilarityExperimentMetaGraphFactory { } - //noinspection ScalaStyle + // noinspection ScalaStyle /* def transform( df: DataFrame )( diff --git 
a/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssemblerTest.scala b/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssemblerTest.scala index f5c936b7d..42f2790c7 100644 --- a/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssemblerTest.scala +++ b/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssemblerTest.scala @@ -79,6 +79,8 @@ class SmartVectorAssemblerTest extends FunSuite with SharedSparkContext{ .setCollapsByKey(true) val collapsedDf = sparqlFrame .transform(dataset) + .withColumnRenamed("seed__down_name(Single_NonCategorical_String)", "seed__down_name(Single_Categorical_String)") + .withColumnRenamed("seed__down_hasParent__down_name(ListOf_NonCategorical_String)", "seed__down_hasParent__down_name(ListOf_Categorical_String)") .cache() collapsedDf.show(false) @@ -91,6 +93,7 @@ class SmartVectorAssemblerTest extends FunSuite with SharedSparkContext{ .setNullReplacement("string", "Hallo") .setNullReplacement("digit", -1000) .setNullReplacement("timestamp", java.sql.Timestamp.valueOf("1900-01-01 00:00:00")) + .setDigitStringStrategy("hash") .setWord2VecSize(3) .setWord2VecMinCount(1) @@ -149,6 +152,7 @@ class SmartVectorAssemblerTest extends FunSuite with SharedSparkContext{ .setLabelColumn("seed__down_age(Single_NonCategorical_Decimal)") .setNullReplacement("string", "Hallo") .setNullReplacement("digit", -1000) + .setDigitStringStrategy("index") .setNullReplacement("timestamp", java.sql.Timestamp.valueOf("1900-01-01 00:00:00")) .setWord2VecSize(3) .setWord2VecMinCount(1) From 4ad855ace978af3d8155f22771aea158705685c6 Mon Sep 17 00:00:00 2001 From: carstendraschner Date: Tue, 15 Jun 2021 14:32:20 +0200 Subject: [PATCH 2/3] integrated getter for feature vector descriptions --- .../SmartVectorAssembler.scala | 29 +++++++++++++++++++ .../SmartVectorAssemblerTest.scala | 14 +++++++++ 2 files changed, 43 insertions(+) diff --git a/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssembler.scala b/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssembler.scala index 09567fe6d..d453eec6e 100644 --- a/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssembler.scala +++ b/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssembler.scala @@ -12,6 +12,7 @@ import java.sql.Timestamp import org.apache.spark.sql.types._ import scala.collection.mutable +import scala.collection.mutable.ListBuffer /** * This Transformer creates a needed Dataframe for common ML approaches in Spark MLlib. 
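Editorial note on the hunks below: _featureVectorDescription records, for every index of the assembled feature vector, which source column it came from; Word2Vec columns expand into one entry per vector dimension. A minimal standalone sketch of that bookkeeping (hypothetical column names and size — the patch itself reads the size from the first row's DenseVector):

    import scala.collection.mutable.ListBuffer

    val columnsToAssemble = Seq("age(Single_NonCategorical_Int)", "name(Word2Vec)")
    val word2VecSize = 3
    val description = new ListBuffer[String]
    for (c <- columnsToAssemble) {
      if (c.contains("Word2Vec")) (0 until word2VecSize).foreach(i => description.append(c + "_" + i))
      else description.append(c)
    }
    // description: ListBuffer(age(Single_NonCategorical_Int),
    //   name(Word2Vec)_0, name(Word2Vec)_1, name(Word2Vec)_2)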
@@ -30,6 +31,9 @@ class SmartVectorAssembler extends Transformer{
   // list of columns which should be used as features
   protected var _featureColumns: List[String] = null
 
+  // feature vector description, adjusted within the process
+  var _featureVectorDescription: ListBuffer[String] = null
+
   // working process configuration
   protected var _numericCollapsingStrategy: String = "median"
   protected var _stringCollapsingStrategy: String = "concat"
@@ -162,6 +166,14 @@ class SmartVectorAssembler extends Transformer{
     this
   }
 
+  /**
+   * Get the description of the explainable feature vector
+   * @return ListBuffer of Strings, describing the content behind each index of the feature vector
+   */
+  def getFeatureVectorDescription(): ListBuffer[String] = {
+    _featureVectorDescription
+  }
+
   /**
    * Validate set column to check if we need fallback to first column if not set
    * and if set if it is in available cols
@@ -640,6 +652,23 @@ class SmartVectorAssembler extends Transformer{
     // println(s"columns to assemble:\n${columnsToAssemble.mkString(", ")}")
     // fixedLengthFeatureDf.show(false)
 
+    // fixedLengthFeatureDf.schema.foreach(println(_))
+    // fixedLengthFeatureDf.first().toSeq.map(_.getClass).foreach(println(_))
+
+    _featureVectorDescription = new ListBuffer[String]
+    for (c <- columnsToAssemble) {
+      // println(sf.name)
+      if (c.contains("Word2Vec")) { // sf.dataType == org.apache.spark.ml.linalg.Vectors) {
+        // println fixedLengthFeatureDf.first().getAs[org.apache.spark.ml.linalg.DenseVector](sf.name).size
+        for (w2v_index <- (0 until fixedLengthFeatureDf.first().getAs[org.apache.spark.ml.linalg.DenseVector](c).size)) {
+          _featureVectorDescription.append(c + "_" + w2v_index)
+        }
+      }
+      else {
+        _featureVectorDescription.append(c)
+      }
+    }
+    // _featureVectorDescription.foreach(println(_))
 
     val assembler = new VectorAssembler()
       .setInputCols(columnsToAssemble)
diff --git a/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssemblerTest.scala b/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssemblerTest.scala
index 42f2790c7..159e47b11 100644
--- a/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssemblerTest.scala
+++ b/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssemblerTest.scala
@@ -103,6 +103,13 @@ class SmartVectorAssemblerTest extends FunSuite with SharedSparkContext{
       .transform(collapsedDf)
       .cache()
 
+    println("Feature vector description:")
+    smartVectorAssembler
+      .getFeatureVectorDescription()
+      .zipWithIndex
+      .map(_.swap)
+      .foreach(println(_))
+
     assert(inputDfSize == mlReadyDf.count())
 
     assert(mlReadyDf.columns.toSet == Set("entityID", "label", "features"))
@@ -163,6 +170,13 @@ class SmartVectorAssemblerTest extends FunSuite with SharedSparkContext{
       .transform(collapsedDf)
       .cache()
 
+    println("Feature vector description:")
+    smartVectorAssembler
+      .getFeatureVectorDescription()
+      .zipWithIndex
+      .map(_.swap)
+      .foreach(println(_))
+
     assert(inputDfSize == mlReadyDf.count())
 
     assert(mlReadyDf.columns.toSet == Set("entityID", "label", "features"))

From e778e74c639517fbcab4d888f08c7cb68e1ac943 Mon Sep 17 00:00:00 2001
From: carstendraschner
Date: Tue, 15 Jun 2021 16:45:47 +0200
Subject: [PATCH 3/3] offer semantic representation of transformer
 hyperparameters

---
 .../SmartVectorAssembler.scala                | 207 +++++++++++++++++-
 .../spark/featureExtraction/SparqlFrame.scala |  96 +++++++-
 .../SmartVectorAssemblerTest.scala            |   4 +
 .../featureExtraction/SparqlFrameTest.scala   |   4 +
 4 files changed, 309 insertions(+), 2 deletions(-)

diff --git a/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssembler.scala b/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssembler.scala
index d453eec6e..90a05b28f 100644
--- a/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssembler.scala
+++ b/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssembler.scala
@@ -8,7 +8,11 @@
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.types.{Decimal, DoubleType, StringType, StructType}
 import org.apache.spark.sql.functions.{udf, _}
 import java.sql.Timestamp
+import java.util.Calendar
 
+import org.apache.jena.datatypes.xsd.XSDDatatype
+import org.apache.jena.graph.{Node, NodeFactory, Triple}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.types._
 
 import scala.collection.mutable
@@ -55,7 +59,7 @@ class SmartVectorAssembler extends Transformer{
   protected val spark = SparkSession.builder().getOrCreate()
 
   // needed default elements
-  override val uid: String = Identifiable.randomUID("sparqlFrame")
+  override val uid: String = Identifiable.randomUID("SmartVectorAssembler")
   override def copy(extra: ParamMap): Transformer = defaultCopy(extra)
   override def transformSchema(schema: StructType): StructType =
     throw new NotImplementedError()
@@ -174,6 +178,207 @@ class SmartVectorAssembler extends Transformer{
     _featureVectorDescription
   }
 
+  /**
+   * Gain all information about this transformer as a knowledge graph
+   * @return RDD[Triple] describing the meta information
+   */
+  def getSemanticTransformerDescription(): RDD[org.apache.jena.graph.Triple] = {
+    /*
+    svaHash type sva
+    svaHash hyperparameter hyperparameterHash1
+    hyperparameterHash1 label label
+    hyperparameterHash1 value value
+    hyperparameterHash1 type hyperparameter
+    ...
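    For example, serialising the _word2VecSize hyperparameter with value 100 through the
    pattern below would roughly yield (editorial illustration, prefixes shortened):
      _:sva  sansaVocab/hyperparameter  _:hp .
      _:hp   rdfs/label                 "_word2VecSize" .
      _:hp   sansaVocab/value           "100"^^xsd:int .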
+    */
+    val svaNode = NodeFactory.createBlankNode(uid)
+    val hyperparameterNodeP = NodeFactory.createURI("sansa-stack/sansaVocab/hyperparameter")
+    val hyperparameterNodeValue = NodeFactory.createURI("sansa-stack/sansaVocab/value")
+    val nodeLabel = NodeFactory.createURI("rdfs/label")
+
+
+    val triples = List(
+      Triple.create(
+        svaNode,
+        NodeFactory.createURI("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
+        NodeFactory.createURI("sansa-stack/sansaVocab/Transformer")
+      ), Triple.create(
+        svaNode,
+        NodeFactory.createURI("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
+        NodeFactory.createURI("sansa-stack/sansaVocab/SmartVectorAssembler")
+      ),
+      // _entityColumn
+      Triple.create(
+        svaNode,
+        hyperparameterNodeP,
+        NodeFactory.createBlankNode((uid + "_entityColumn").hashCode.toString)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_entityColumn").hashCode.toString),
+        hyperparameterNodeValue,
+        NodeFactory.createLiteralByValue({if (_entityColumn != null) _entityColumn else "_entityColumn not set"}, XSDDatatype.XSDstring)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_entityColumn").hashCode.toString),
+        nodeLabel,
+        NodeFactory.createLiteralByValue("_entityColumn", XSDDatatype.XSDstring)
+      ),
+      // _labelColumn
+      Triple.create(
+        svaNode,
+        hyperparameterNodeP,
+        NodeFactory.createBlankNode((uid + "_labelColumn").hashCode.toString)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_labelColumn").hashCode.toString),
+        hyperparameterNodeValue,
+        NodeFactory.createLiteralByValue({if (_labelColumn != null) _labelColumn else "_labelColumn not set"}, XSDDatatype.XSDstring)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_labelColumn").hashCode.toString),
+        nodeLabel,
+        NodeFactory.createLiteralByValue("_labelColumn", XSDDatatype.XSDstring)
+      ),
+      // _featureColumns
+      Triple.create(
+        svaNode,
+        hyperparameterNodeP,
+        NodeFactory.createBlankNode((uid + "_featureColumns").hashCode.toString)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_featureColumns").hashCode.toString),
+        hyperparameterNodeValue,
+        NodeFactory.createLiteralByValue({if (_featureColumns != null) _featureColumns.mkString(", ") else "_featureColumns not set"}, XSDDatatype.XSDstring)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_featureColumns").hashCode.toString),
+        nodeLabel,
+        NodeFactory.createLiteralByValue("_featureColumns", XSDDatatype.XSDstring)
+      ),
+      // _featureVectorDescription
+      Triple.create(
+        svaNode,
+        hyperparameterNodeP,
+        NodeFactory.createBlankNode((uid + "_featureVectorDescription").hashCode.toString)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_featureVectorDescription").hashCode.toString),
+        hyperparameterNodeValue,
+        NodeFactory.createLiteralByValue(_featureVectorDescription.zipWithIndex.toSeq.map(_.swap).mkString(", "), XSDDatatype.XSDstring)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_featureVectorDescription").hashCode.toString),
+        nodeLabel,
+        NodeFactory.createLiteralByValue("_featureVectorDescription", XSDDatatype.XSDstring)
+      ),
+      // _digitStringStrategy
+      Triple.create(
+        svaNode,
+        hyperparameterNodeP,
+        NodeFactory.createBlankNode((uid + "_digitStringStrategy").hashCode.toString)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_digitStringStrategy").hashCode.toString),
+        hyperparameterNodeValue,
+        NodeFactory.createLiteralByValue(_digitStringStrategy, XSDDatatype.XSDstring)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_digitStringStrategy").hashCode.toString),
+        nodeLabel,
+
NodeFactory.createLiteralByValue("_digitStringStrategy", XSDDatatype.XSDstring) + ), + // _nullDigitReplacement + Triple.create( + svaNode, + hyperparameterNodeP, + NodeFactory.createBlankNode((uid + "_nullDigitReplacement").hashCode.toString) + ), Triple.create( + NodeFactory.createBlankNode((uid + "_nullDigitReplacement").hashCode.toString), + hyperparameterNodeValue, + NodeFactory.createLiteralByValue(_nullDigitReplacement, XSDDatatype.XSDint) + ), Triple.create( + NodeFactory.createBlankNode((uid + "_nullDigitReplacement").hashCode.toString), + nodeLabel, + NodeFactory.createLiteralByValue("_nullDigitReplacement", XSDDatatype.XSDstring) + ), + // _nullStringReplacement + Triple.create( + svaNode, + hyperparameterNodeP, + NodeFactory.createBlankNode((uid + "_nullStringReplacement").hashCode.toString) + ), Triple.create( + NodeFactory.createBlankNode((uid + "_nullStringReplacement").hashCode.toString), + hyperparameterNodeValue, + NodeFactory.createLiteralByValue(_nullStringReplacement, XSDDatatype.XSDstring) + ), Triple.create( + NodeFactory.createBlankNode((uid + "_nullStringReplacement").hashCode.toString), + nodeLabel, + NodeFactory.createLiteralByValue("_nullStringReplacement", XSDDatatype.XSDstring) + ), + // _nullTimestampReplacement + Triple.create( + svaNode, + hyperparameterNodeP, + NodeFactory.createBlankNode((uid + "_nullTimestampReplacement").hashCode.toString) + ), Triple.create( + NodeFactory.createBlankNode((uid + "_nullTimestampReplacement").hashCode.toString), + hyperparameterNodeValue, + NodeFactory.createLiteralByValue(_nullTimestampReplacement, XSDDatatype.XSDstring) + ), Triple.create( + NodeFactory.createBlankNode((uid + "_nullTimestampReplacement").hashCode.toString), + nodeLabel, + NodeFactory.createLiteralByValue("_nullTimestampReplacement", XSDDatatype.XSDstring) + ), + // _word2VecSize + Triple.create( + svaNode, + hyperparameterNodeP, + NodeFactory.createBlankNode((uid + "_word2VecSize").hashCode.toString) + ), Triple.create( + NodeFactory.createBlankNode((uid + "_word2VecSize").hashCode.toString), + hyperparameterNodeValue, + NodeFactory.createLiteralByValue(_word2VecSize, XSDDatatype.XSDint) + ), Triple.create( + NodeFactory.createBlankNode((uid + "_word2VecSize").hashCode.toString), + nodeLabel, + NodeFactory.createLiteralByValue("_word2VecSize", XSDDatatype.XSDstring) + ), + // _word2VecMinCount + Triple.create( + svaNode, + hyperparameterNodeP, + NodeFactory.createBlankNode((uid + "_word2VecMinCount").hashCode.toString) + ), Triple.create( + NodeFactory.createBlankNode((uid + "_word2VecMinCount").hashCode.toString), + hyperparameterNodeValue, + NodeFactory.createLiteralByValue(_word2VecMinCount, XSDDatatype.XSDint) + ), Triple.create( + NodeFactory.createBlankNode((uid + "_word2VecMinCount").hashCode.toString), + nodeLabel, + NodeFactory.createLiteralByValue("_word2VecMinCount", XSDDatatype.XSDstring) + ), + // _word2vecTrainingDfSizeRatio + Triple.create( + svaNode, + hyperparameterNodeP, + NodeFactory.createBlankNode((uid + "_word2vecTrainingDfSizeRatio").hashCode.toString) + ), Triple.create( + NodeFactory.createBlankNode((uid + "_word2vecTrainingDfSizeRatio").hashCode.toString), + hyperparameterNodeValue, + NodeFactory.createLiteralByValue(_word2vecTrainingDfSizeRatio, XSDDatatype.XSDdouble) + ), Triple.create( + NodeFactory.createBlankNode((uid + "_word2vecTrainingDfSizeRatio").hashCode.toString), + nodeLabel, + NodeFactory.createLiteralByValue("_word2vecTrainingDfSizeRatio", XSDDatatype.XSDstring) + ), + // _stringIndexerTrainingDfSizeRatio + 
Triple.create(
+        svaNode,
+        hyperparameterNodeP,
+        NodeFactory.createBlankNode((uid + "_stringIndexerTrainingDfSizeRatio").hashCode.toString)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_stringIndexerTrainingDfSizeRatio").hashCode.toString),
+        hyperparameterNodeValue,
+        NodeFactory.createLiteralByValue(_stringIndexerTrainingDfSizeRatio, XSDDatatype.XSDdouble)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_stringIndexerTrainingDfSizeRatio").hashCode.toString),
+        nodeLabel,
+        NodeFactory.createLiteralByValue("_stringIndexerTrainingDfSizeRatio", XSDDatatype.XSDstring)
+      )
+    )
+    spark.sqlContext.sparkContext.parallelize(triples)
+  }
+
   /**
    * Validate set column to check if we need fallback to first column if not set
    * and if set if it is in available cols
diff --git a/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SparqlFrame.scala b/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SparqlFrame.scala
index a9a5c1577..0ed8b04b1 100644
--- a/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SparqlFrame.scala
+++ b/sansa-ml/sansa-ml-spark/src/main/scala/net/sansa_stack/ml/spark/featureExtraction/SparqlFrame.scala
@@ -12,6 +12,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.types.{BooleanType, DataType, DoubleType, FloatType, IntegerType, NullType, StringType, StructField, StructType}
 import org.apache.spark.sql.{DataFrame, Dataset, Encoders, Row, SparkSession}
 import org.aksw.sparqlify.core.sql.common.serialization.SqlEscaperDoubleQuote
+import org.apache.jena.datatypes.xsd.XSDDatatype
 import org.apache.jena.graph.{Node, NodeFactory, Triple}
 import org.apache.spark.sql.functions.{col, collect_list, max, min, size}
 
@@ -90,7 +91,7 @@ class SparqlFrame extends Transformer{
     this
   }
 
-  def getRDFdescription(): Array[org.apache.jena.graph.Triple] = {
+  /* def getRDFdescription(): Array[org.apache.jena.graph.Triple] = {
 
     val transformerURI: Node = NodeFactory.createURI("sansa-stack/ml/transfomer/Sparqlframe")
     val pipelineElementURI: Node = NodeFactory.createURI("sansa-stack/ml/sansaVocab/ml/pipelineElement")
@@ -105,6 +106,99 @@ class SparqlFrame extends Transformer{
 
     description
   }
+  */
+
+  /**
+   * Gain all information about this transformer as a knowledge graph
+   * @return RDD[Triple] describing the meta information
+   */
+  def getSemanticTransformerDescription(): RDD[org.apache.jena.graph.Triple] = {
+    /*
+    svaHash type sva
+    svaHash hyperparameter hyperparameterHash1
+    hyperparameterHash1 label label
+    hyperparameterHash1 value value
+    hyperparameterHash1 type hyperparameter
+    ...
+    */
+    val svaNode = NodeFactory.createBlankNode(uid)
+    val hyperparameterNodeP = NodeFactory.createURI("sansa-stack/sansaVocab/hyperparameter")
+    val hyperparameterNodeValue = NodeFactory.createURI("sansa-stack/sansaVocab/value")
+    val nodeLabel = NodeFactory.createURI("rdfs/label")
+
+
+    val triples = List(
+      Triple.create(
+        svaNode,
+        NodeFactory.createURI("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
+        NodeFactory.createURI("sansa-stack/sansaVocab/Transformer")
+      ), Triple.create(
+        svaNode,
+        NodeFactory.createURI("http://www.w3.org/1999/02/22-rdf-syntax-ns#type"),
+        NodeFactory.createURI("sansa-stack/sansaVocab/SparqlFrame")
+      ),
+      // _query
+      Triple.create(
+        svaNode,
+        hyperparameterNodeP,
+        NodeFactory.createBlankNode((uid + "_query").hashCode.toString)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_query").hashCode.toString),
+        hyperparameterNodeValue,
+        NodeFactory.createLiteralByValue(_query, XSDDatatype.XSDstring)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_query").hashCode.toString),
+        nodeLabel,
+        NodeFactory.createLiteralByValue("_query", XSDDatatype.XSDstring)
+      ),
+      // _queryExcecutionEngine
+      Triple.create(
+        svaNode,
+        hyperparameterNodeP,
+        NodeFactory.createBlankNode((uid + "_queryExcecutionEngine").hashCode.toString)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_queryExcecutionEngine").hashCode.toString),
+        hyperparameterNodeValue,
+        NodeFactory.createLiteralByValue(_queryExcecutionEngine, XSDDatatype.XSDstring)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_queryExcecutionEngine").hashCode.toString),
+        nodeLabel,
+        NodeFactory.createLiteralByValue("_queryExcecutionEngine", XSDDatatype.XSDstring)
+      ),
+      // _collapsByKey (label triple corrected: it previously reused the
+      // "_featureColumns" blank node and label by copy-paste mistake)
+      Triple.create(
+        svaNode,
+        hyperparameterNodeP,
+        NodeFactory.createBlankNode((uid + "_collapsByKey").hashCode.toString)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_collapsByKey").hashCode.toString),
+        hyperparameterNodeValue,
+        NodeFactory.createLiteralByValue(_collapsByKey, XSDDatatype.XSDboolean)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_collapsByKey").hashCode.toString),
+        nodeLabel,
+        NodeFactory.createLiteralByValue("_collapsByKey", XSDDatatype.XSDstring)
+      ),
+      /* // _keyColumnNameString
+      Triple.create(
+        svaNode,
+        hyperparameterNodeP,
+        NodeFactory.createBlankNode((uid + "_keyColumnNameString").hashCode.toString)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_keyColumnNameString").hashCode.toString),
+        hyperparameterNodeValue,
+        NodeFactory.createLiteral(_keyColumnNameString)
+      ), Triple.create(
+        NodeFactory.createBlankNode((uid + "_keyColumnNameString").hashCode.toString),
+        nodeLabel,
+        NodeFactory.createLiteralByValue("_keyColumnNameString", XSDDatatype.XSDstring)
+      )
+
+      */
+    )
+    spark.sqlContext.sparkContext.parallelize(triples)
+  }
+
   override def transformSchema(schema: StructType): StructType =
     throw new NotImplementedError()
 
diff --git a/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssemblerTest.scala b/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssemblerTest.scala
index 159e47b11..064403785 100644
--- a/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssemblerTest.scala
+++ b/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SmartVectorAssemblerTest.scala
@@ -183,5 +183,9 @@
mlReadyDf.show(false) mlReadyDf.schema.foreach(println(_)) + + smartVectorAssembler + .getSemanticTransformerDescription() + .foreach(println(_)) } } diff --git a/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SparqlFrameTest.scala b/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SparqlFrameTest.scala index 486378d9e..1eb801a65 100644 --- a/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SparqlFrameTest.scala +++ b/sansa-ml/sansa-ml-spark/src/test/scala/net/sansa_stack/ml/spark/featureExtraction/SparqlFrameTest.scala @@ -132,5 +132,9 @@ class SparqlFrameTest extends FunSuite with SharedSparkContext{ assert(collapsedDf.columns.toSet == Set("seed", "seed__down_age(Single_NonCategorical_Decimal)", "seed__down_name(Single_NonCategorical_String)", "seed__down_hasParent__down_name(ListOf_NonCategorical_String)", "seed__down_hasParent__down_age(ListOf_NonCategorical_Decimal)")) assert(featureTypes("seed__down_hasParent__down_age")("isListOfEntries") == true) assert(featureTypes("seed__down_hasParent__down_name")("datatype") == StringType) + + collapsingSparqlFrame + .getSemanticTransformerDescription() + .foreach(println(_)) } }
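Editorial note on the series as a whole: both getSemanticTransformerDescription implementations repeat the same three-triple pattern per hyperparameter (transformer -> blank node via hyperparameter, blank node -> value literal, blank node -> label literal). A possible follow-up — a sketch only, assuming the vocabulary URIs stay as above, and not part of these patches — would centralise that pattern in a helper, so each hyperparameter needs one call instead of three hand-written Triple.create blocks:

    import org.apache.jena.datatypes.xsd.XSDDatatype
    import org.apache.jena.graph.{Node, NodeFactory, Triple}

    // hypothetical helper: emits the three triples describing one hyperparameter
    def hyperparameterTriples(transformerNode: Node, uid: String, label: String,
                              value: Any, datatype: XSDDatatype): List[Triple] = {
      // same blank-node naming scheme as in the patches: hash of uid + label
      val hp = NodeFactory.createBlankNode((uid + label).hashCode.toString)
      List(
        Triple.create(transformerNode,
          NodeFactory.createURI("sansa-stack/sansaVocab/hyperparameter"), hp),
        Triple.create(hp, NodeFactory.createURI("sansa-stack/sansaVocab/value"),
          // box the value so it satisfies the Java Object parameter
          NodeFactory.createLiteralByValue(value.asInstanceOf[AnyRef], datatype)),
        Triple.create(hp, NodeFactory.createURI("rdfs/label"),
          NodeFactory.createLiteralByValue(label, XSDDatatype.XSDstring))
      )
    }

Using one blank node per hyperparameter would also rule out the kind of copy-paste slip seen in the _collapsByKey label triple.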