diff --git a/src/main/resources/application.conf b/src/main/resources/application.conf
index 4e497d0bb46fbb..abf43b92886f32 100644
--- a/src/main/resources/application.conf
+++ b/src/main/resources/application.conf
@@ -52,5 +52,6 @@ settings {
   overrideConfPath = "./application.conf"
 }
 performance {
-
+  serialization = "object"
+  useBroadcast = true
 }
\ No newline at end of file
diff --git a/src/main/scala/com/johnsnowlabs/nlp/HasFeatures.scala b/src/main/scala/com/johnsnowlabs/nlp/HasFeatures.scala
index b8738116d7a385..d713366ae16eac 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/HasFeatures.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/HasFeatures.scala
@@ -14,11 +14,11 @@ trait HasFeatures {
 
   protected def set[T](feature: StructFeature[T], value: T): this.type = {feature.setValue(Some(value)); this}
 
-  protected def setDefault[T](feature: ArrayFeature[T], value: Array[T]): this.type = {feature.setValue(Some(value)); this}
+  protected def setDefault[T](feature: ArrayFeature[T], value: () => Array[T]): this.type = {feature.setFallback(Some(value)); this}
 
-  protected def setDefault[K, V](feature: MapFeature[K, V], value: Map[K, V]): this.type = {feature.setValue(Some(value)); this}
+  protected def setDefault[K, V](feature: MapFeature[K, V], value: () => Map[K, V]): this.type = {feature.setFallback(Some(value)); this}
 
-  protected def setDefault[T](feature: StructFeature[T], value: T): this.type = {feature.setValue(Some(value)); this}
+  protected def setDefault[T](feature: StructFeature[T], value: () => T): this.type = {feature.setFallback(Some(value)); this}
 
   protected def get[T](feature: ArrayFeature[T]): Option[Array[T]] = feature.get
 
diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/Lemmatizer.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/Lemmatizer.scala
index efdb6d6adc0961..0c1863b0d33fa6 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/annotators/Lemmatizer.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/Lemmatizer.scala
@@ -40,7 +40,7 @@ class Lemmatizer(override val uid: String) extends AnnotatorModel[Lemmatizer] {
   setDefault(lemmaValSep, config.getString("nlp.lemmaDict.vSeparator"))
 
   if (config.getString("nlp.lemmaDict.file").nonEmpty)
-    setDefault(lemmaDict, Lemmatizer.retrieveLemmaDict(
+    setDefault(lemmaDict, () => Lemmatizer.retrieveLemmaDict(
       config.getString("nlp.lemmaDict.file"),
       config.getString("nlp.lemmaDict.format"),
       config.getString("nlp.lemmaDict.kvSeparator"),
diff --git a/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/crf/NerCrfModel.scala b/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/crf/NerCrfModel.scala
index d940f6b96fedb1..a444b6de44b06d 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/crf/NerCrfModel.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/annotators/ner/crf/NerCrfModel.scala
@@ -25,7 +25,7 @@ class NerCrfModel(override val uid: String) extends ModelWithWordEmbeddings[NerC
 
   def setModel(crf: LinearChainCrfModel): NerCrfModel = set(model, crf)
 
   def setDictionaryFeatures(dictFeatures: DictionaryFeatures): this.type = set(dictionaryFeatures, dictFeatures.dict)
-  setDefault(dictionaryFeatures, Map.empty[String, String])
+  setDefault(dictionaryFeatures, () => Map.empty[String, String])
 
   def setEntities(toExtract: Array[String]): NerCrfModel = set(entities, toExtract)
diff --git a/src/main/scala/com/johnsnowlabs/nlp/serialization/Feature.scala b/src/main/scala/com/johnsnowlabs/nlp/serialization/Feature.scala
index da916d90f31bf5..f3a251f533423b 100644
--- a/src/main/scala/com/johnsnowlabs/nlp/serialization/Feature.scala
+++ b/src/main/scala/com/johnsnowlabs/nlp/serialization/Feature.scala
@@ -1,14 +1,9 @@
 package com.johnsnowlabs.nlp.serialization
 
-import java.io.File
-import java.nio.file.{Files, Paths}
-
 import com.johnsnowlabs.nlp.HasFeatures
-import com.johnsnowlabs.nlp.embeddings.{ModelWithWordEmbeddings, WordEmbeddings, WordEmbeddingsClusterHelper}
+import com.johnsnowlabs.nlp.util.ConfigHelper
 import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.spark.{SparkContext, SparkFiles}
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.ml.param.{ParamMap, Params}
 import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
 
 import scala.reflect.ClassTag
@@ -16,26 +11,73 @@ import scala.reflect.ClassTag
 abstract class Feature[Serializable1, Serializable2, TComplete: ClassTag](model: HasFeatures, val name: String)(implicit val sparkSession: SparkSession = SparkSession.builder().getOrCreate()) extends Serializable {
   model.features.append(this)
 
-  final protected var value: Option[Broadcast[TComplete]] = None
+  private val config = ConfigHelper.retrieve
+
+  val serializationMode: String = config.getString("performance.serialization")
+  val useBroadcast: Boolean = config.getBoolean("performance.useBroadcast")
 
-  def serialize(spark: SparkSession, path: String, field: String, value: TComplete): Unit
+  final protected var broadcastValue: Option[Broadcast[TComplete]] = None
+  final protected var rawValue: Option[TComplete] = None
+  final protected var fallback: Option[() => TComplete] = None
+
+  final def serialize(spark: SparkSession, path: String, field: String, value: TComplete): Unit = {
+    serializationMode match {
+      case "dataset" => serializeDataset(spark, path, field, value)
+      case "object" => serializeObject(spark, path, field, value)
+      case _ => throw new IllegalArgumentException("Illegal performance.serialization setting. Can be 'dataset' or 'object'")
+    }
+  }
 
   final def serializeInfer(spark: SparkSession, path: String, field: String, value: Any): Unit =
     serialize(spark, path, field, value.asInstanceOf[TComplete])
 
-  def deserialize(spark: SparkSession, path: String, field: String): Option[_]
+  final def deserialize(spark: SparkSession, path: String, field: String): Option[_] = {
+    if (broadcastValue.isDefined || rawValue.isDefined)
+      throw new Exception(s"Trying de deserialize an already set value for ${this.name}. This should not happen.")
+    serializationMode match {
+      case "dataset" => deserializeDataset(spark, path, field)
+      case "object" => deserializeObject(spark, path, field)
+      case _ => throw new IllegalArgumentException("Illegal performance.serialization setting. Can be 'dataset' or 'object'")
+    }
+  }
+
+  protected def serializeDataset(spark: SparkSession, path: String, field: String, value: TComplete): Unit
+
+  protected def deserializeDataset(spark: SparkSession, path: String, field: String): Option[_]
+
+  protected def serializeObject(spark: SparkSession, path: String, field: String, value: TComplete): Unit
+
+  protected def deserializeObject(spark: SparkSession, path: String, field: String): Option[_]
 
   final protected def getFieldPath(path: String, field: String): Path =
     Path.mergePaths(new Path(path), new Path("/fields/" + field))
 
-  final def get: Option[TComplete] = value.map(_.value)
-  final def getValue: TComplete = value.map(_.value).getOrElse(throw new Exception(s"feature $name is not set"))
+  final def get: Option[TComplete] = {
+    broadcastValue.map(_.value).orElse(rawValue)
+  }
+
+  final def getValue: TComplete = {
+    broadcastValue.map(_.value).orElse(rawValue).orElse(fallback.map(_())).getOrElse(throw new Exception(s"feature $name is not set"))
+  }
+
   final def setValue(v: Option[Any]): HasFeatures = {
-    if (isSet) value.get.destroy()
-    value = Some(sparkSession.sparkContext.broadcast[TComplete](v.get.asInstanceOf[TComplete]))
+    if (useBroadcast) {
+      if (isSet) broadcastValue.get.destroy()
+      broadcastValue = Some(sparkSession.sparkContext.broadcast[TComplete](v.get.asInstanceOf[TComplete]))
+    } else {
+      rawValue = Some(v.get.asInstanceOf[TComplete])
+    }
     model
   }
-  final def isSet: Boolean = value.isDefined
+
+  def setFallback(v: Option[() => TComplete]): HasFeatures = {
+    fallback = v
+    model
+  }
+
+  final def isSet: Boolean = {
+    broadcastValue.isDefined || rawValue.isDefined
+  }
 
 }
 
@@ -44,12 +86,27 @@ class StructFeature[TValue: ClassTag](model: HasFeatures, override val name: Str
 
   implicit val encoder: Encoder[TValue] = Encoders.kryo[TValue]
 
-  override def serialize(spark: SparkSession, path: String, field: String, value: TValue): Unit = {
+  override def serializeObject(spark: SparkSession, path: String, field: String, value: TValue): Unit = {
+    val dataPath = getFieldPath(path, field)
+    spark.sparkContext.parallelize(Seq(value)).saveAsObjectFile(dataPath.toString)
+  }
+
+  override def deserializeObject(spark: SparkSession, path: String, field: String): Option[TValue] = {
+    val fs: FileSystem = FileSystem.get(spark.sparkContext.hadoopConfiguration)
+    val dataPath = getFieldPath(path, field)
+    if (fs.exists(dataPath)) {
+      Some(spark.sparkContext.objectFile[TValue](dataPath.toString).first)
+    } else {
+      None
+    }
+  }
+
+  override def serializeDataset(spark: SparkSession, path: String, field: String, value: TValue): Unit = {
     val dataPath = getFieldPath(path, field)
     spark.createDataset(Seq(value)).write.mode("overwrite").parquet(dataPath.toString)
   }
 
-  override def deserialize(spark: SparkSession, path: String, field: String): Option[TValue] = {
+  override def deserializeDataset(spark: SparkSession, path: String, field: String): Option[TValue] = {
     val fs: FileSystem = FileSystem.get(spark.sparkContext.hadoopConfiguration)
     val dataPath = getFieldPath(path, field)
     if (fs.exists(dataPath)) {
@@ -66,7 +123,24 @@ class MapFeature[TKey: ClassTag, TValue: ClassTag](model: HasFeatures, override
 
   implicit val encoder: Encoder[(TKey, TValue)] = Encoders.kryo[(TKey, TValue)]
 
-  override def serialize(spark: SparkSession, path: String, field: String, value: Map[TKey, TValue]): Unit = {
+  override def serializeObject(spark: SparkSession, path: String, field: String, value: Map[TKey, TValue]): Unit = {
+    val dataPath = getFieldPath(path, field)
+    spark.sparkContext.parallelize(value.toSeq).saveAsObjectFile(dataPath.toString)
+  }
+
+
+
+  override def deserializeObject(spark: SparkSession, path: String, field: String): Option[Map[TKey, TValue]] = {
+    val fs: FileSystem = FileSystem.get(spark.sparkContext.hadoopConfiguration)
+    val dataPath = getFieldPath(path, field)
+    if (fs.exists(dataPath)) {
+      Some(spark.sparkContext.objectFile[(TKey, TValue)](dataPath.toString).collect.toMap)
+    } else {
+      None
+    }
+  }
+
+  override def serializeDataset(spark: SparkSession, path: String, field: String, value: Map[TKey, TValue]): Unit = {
     import spark.implicits._
     val dataPath = getFieldPath(path, field)
     value.toSeq.toDS.write.mode("overwrite").parquet(dataPath.toString)
@@ -74,7 +148,7 @@ class MapFeature[TKey: ClassTag, TValue: ClassTag](model: HasFeatures, override
 
 
 
-  override def deserialize(spark: SparkSession, path: String, field: String): Option[Map[TKey, TValue]] = {
+  override def deserializeDataset(spark: SparkSession, path: String, field: String): Option[Map[TKey, TValue]] = {
     val fs: FileSystem = FileSystem.get(spark.sparkContext.hadoopConfiguration)
     val dataPath = getFieldPath(path, field)
     if (fs.exists(dataPath)) {
@@ -91,12 +165,27 @@ class ArrayFeature[TValue: ClassTag](model: HasFeatures, override val name: Stri
 
   implicit val encoder: Encoder[TValue] = Encoders.kryo[TValue]
 
-  override def serialize(spark: SparkSession, path: String, field: String, value: Array[TValue]): Unit = {
+  override def serializeObject(spark: SparkSession, path: String, field: String, value: Array[TValue]): Unit = {
+    val dataPath = getFieldPath(path, field)
+    spark.sparkContext.parallelize(value).saveAsObjectFile(dataPath.toString)
+  }
+
+  override def deserializeObject(spark: SparkSession, path: String, field: String): Option[Array[TValue]] = {
+    val fs: FileSystem = FileSystem.get(spark.sparkContext.hadoopConfiguration)
+    val dataPath = getFieldPath(path, field)
+    if (fs.exists(dataPath)) {
+      Some(spark.sparkContext.objectFile[TValue](dataPath.toString).collect())
+    } else {
+      None
+    }
+  }
+
+  override def serializeDataset(spark: SparkSession, path: String, field: String, value: Array[TValue]): Unit = {
     val dataPath = getFieldPath(path, field)
     spark.createDataset(value).write.mode("overwrite").parquet(dataPath.toString)
   }
 
-  override def deserialize(spark: SparkSession, path: String, field: String): Option[Array[TValue]] = {
+  override def deserializeDataset(spark: SparkSession, path: String, field: String): Option[Array[TValue]] = {
     val fs: FileSystem = FileSystem.get(spark.sparkContext.hadoopConfiguration)
     val dataPath = getFieldPath(path, field)
     if (fs.exists(dataPath)) {
diff --git a/src/test/resources/application.conf b/src/test/resources/application.conf
index c65ea165f0ffb8..f5872cfcf6efb3 100644
--- a/src/test/resources/application.conf
+++ b/src/test/resources/application.conf
@@ -36,7 +36,7 @@ nlp {
     revert = -1.0
   }
 }
-
 performance {
-  useBroadcast = false
-}
+  serialization = "object"
+  useBroadcast = true
+}
\ No newline at end of file
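
For reference, the setDefault overloads in HasFeatures now take a () => T thunk and store it through Feature.setFallback, so a default such as the lemma dictionary is only computed when getValue is first asked for it and no explicit value has been set. Below is a minimal, Spark-free sketch of that lazy-fallback pattern; the names DemoFeature and loadDict are illustrative only and do not appear in the patch.

object DemoFeature {
  // Explicit value wins; the fallback thunk is evaluated only on demand.
  private var value: Option[Map[String, String]] = None
  private var fallback: Option[() => Map[String, String]] = None

  def setValue(v: Map[String, String]): Unit = value = Some(v)
  def setFallback(f: () => Map[String, String]): Unit = fallback = Some(f)

  // Mirrors Feature.getValue: explicit value, then lazily evaluated fallback, else an error.
  def getValue: Map[String, String] =
    value.orElse(fallback.map(_())).getOrElse(throw new Exception("feature is not set"))

  def main(args: Array[String]): Unit = {
    def loadDict(): Map[String, String] = {
      println("building default dictionary") // runs only because no explicit value was set
      Map("better" -> "good")
    }
    setFallback(() => loadDict())             // analogous to setDefault(lemmaDict, () => ...)
    println(getValue("better"))               // prints "good"; loadDict ran exactly once, here
  }
}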