3 changes: 2 additions & 1 deletion src/main/resources/application.conf
@@ -52,5 +52,6 @@ settings {
overrideConfPath = "./application.conf"
}
performance {

serialization = "object"
useBroadcast = true
}
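These two keys drive the new Feature serialization behaviour below. A minimal sketch of how they are consumed, assuming the Typesafe Config API that ConfigHelper wraps elsewhere in the project:

```scala
import com.typesafe.config.{Config, ConfigFactory}

// Sketch only: the project goes through ConfigHelper.retrieve; ConfigFactory.load()
// is used here just to keep the example self-contained.
val config: Config = ConfigFactory.load()

val serializationMode: String = config.getString("performance.serialization") // "dataset" or "object"
val useBroadcast: Boolean = config.getBoolean("performance.useBroadcast")
```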
6 changes: 3 additions & 3 deletions src/main/scala/com/johnsnowlabs/nlp/HasFeatures.scala
@@ -14,11 +14,11 @@ trait HasFeatures {

protected def set[T](feature: StructFeature[T], value: T): this.type = {feature.setValue(Some(value)); this}

protected def setDefault[T](feature: ArrayFeature[T], value: Array[T]): this.type = {feature.setValue(Some(value)); this}
protected def setDefault[T](feature: ArrayFeature[T], value: () => Array[T]): this.type = {feature.setFallback(Some(value)); this}

protected def setDefault[K, V](feature: MapFeature[K, V], value: Map[K, V]): this.type = {feature.setValue(Some(value)); this}
protected def setDefault[K, V](feature: MapFeature[K, V], value: () => Map[K, V]): this.type = {feature.setFallback(Some(value)); this}

protected def setDefault[T](feature: StructFeature[T], value: T): this.type = {feature.setValue(Some(value)); this}
protected def setDefault[T](feature: StructFeature[T], value: () => T): this.type = {feature.setFallback(Some(value)); this}

protected def get[T](feature: ArrayFeature[T]): Option[Array[T]] = feature.get

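These overloads now take a thunk, so a default is only computed when the feature value is actually requested. A hedged sketch of a call site after this change (ExampleFeatures and loadDictionary are hypothetical; real annotators get HasFeatures through AnnotatorModel):

```scala
import com.johnsnowlabs.nlp.HasFeatures
import com.johnsnowlabs.nlp.serialization.MapFeature

// Hypothetical holder, trimmed down to the feature-related parts.
class ExampleFeatures extends HasFeatures {

  val dictionary: MapFeature[String, String] = new MapFeature(this, "dictionary")

  // Old style: the default map was built eagerly at construction time.
  // setDefault(dictionary, loadDictionary())

  // New style: the thunk runs only if dictionary.getValue falls back to the default.
  setDefault(dictionary, () => loadDictionary())

  private def loadDictionary(): Map[String, String] = Map("better" -> "good")
}
```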
@@ -40,7 +40,7 @@ class Lemmatizer(override val uid: String) extends AnnotatorModel[Lemmatizer] {
setDefault(lemmaValSep, config.getString("nlp.lemmaDict.vSeparator"))

if (config.getString("nlp.lemmaDict.file").nonEmpty)
setDefault(lemmaDict, Lemmatizer.retrieveLemmaDict(
setDefault(lemmaDict, () => Lemmatizer.retrieveLemmaDict(
config.getString("nlp.lemmaDict.file"),
config.getString("nlp.lemmaDict.format"),
config.getString("nlp.lemmaDict.kvSeparator"),
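The practical effect here: the lemma dictionary file is no longer read just because a Lemmatizer was constructed; Lemmatizer.retrieveLemmaDict runs only if the default is actually consulted. A tiny illustration of that deferral with a stand-in loader (not project code):

```scala
// Stand-in for an expensive loader such as Lemmatizer.retrieveLemmaDict.
var loads = 0
def retrieveDict(): Map[String, String] = { loads += 1; Map("better" -> "good") }

val default: () => Map[String, String] = () => retrieveDict()
assert(loads == 0)   // wrapping the call in a thunk reads nothing yet
val dict = default() // the cost is paid only on first real use
assert(loads == 1)
```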
@@ -25,7 +25,7 @@ class NerCrfModel(override val uid: String) extends ModelWithWordEmbeddings[NerCrfModel] {
def setModel(crf: LinearChainCrfModel): NerCrfModel = set(model, crf)

def setDictionaryFeatures(dictFeatures: DictionaryFeatures): this.type = set(dictionaryFeatures, dictFeatures.dict)
setDefault(dictionaryFeatures, Map.empty[String, String])
setDefault(dictionaryFeatures, () => Map.empty[String, String])

def setEntities(toExtract: Array[String]): NerCrfModel = set(entities, toExtract)

129 changes: 109 additions & 20 deletions src/main/scala/com/johnsnowlabs/nlp/serialization/Feature.scala
@@ -1,41 +1,83 @@
package com.johnsnowlabs.nlp.serialization

import java.io.File
import java.nio.file.{Files, Paths}

import com.johnsnowlabs.nlp.HasFeatures
import com.johnsnowlabs.nlp.embeddings.{ModelWithWordEmbeddings, WordEmbeddings, WordEmbeddingsClusterHelper}
import com.johnsnowlabs.nlp.util.ConfigHelper
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.{SparkContext, SparkFiles}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}

import scala.reflect.ClassTag

abstract class Feature[Serializable1, Serializable2, TComplete: ClassTag](model: HasFeatures, val name: String)(implicit val sparkSession: SparkSession = SparkSession.builder().getOrCreate()) extends Serializable {
model.features.append(this)

final protected var value: Option[Broadcast[TComplete]] = None
private val config = ConfigHelper.retrieve

val serializationMode: String = config.getString("performance.serialization")
val useBroadcast: Boolean = config.getBoolean("performance.useBroadcast")

def serialize(spark: SparkSession, path: String, field: String, value: TComplete): Unit
final protected var broadcastValue: Option[Broadcast[TComplete]] = None
final protected var rawValue: Option[TComplete] = None
final protected var fallback: Option[() => TComplete] = None

final def serialize(spark: SparkSession, path: String, field: String, value: TComplete): Unit = {
serializationMode match {
case "dataset" => serializeDataset(spark, path, field, value)
case "object" => serializeObject(spark, path, field, value)
case _ => throw new IllegalArgumentException("Illegal performance.serialization setting. Can be 'dataset' or 'object'")
}
}

final def serializeInfer(spark: SparkSession, path: String, field: String, value: Any): Unit =
serialize(spark, path, field, value.asInstanceOf[TComplete])

def deserialize(spark: SparkSession, path: String, field: String): Option[_]
final def deserialize(spark: SparkSession, path: String, field: String): Option[_] = {
if (broadcastValue.isDefined || rawValue.isDefined)
throw new Exception(s"Trying de deserialize an already set value for ${this.name}. This should not happen.")
serializationMode match {
case "dataset" => deserializeDataset(spark, path, field)
case "object" => deserializeObject(spark, path, field)
case _ => throw new IllegalArgumentException("Illegal performance.serialization setting. Can be 'dataset' or 'object'")
}
}

protected def serializeDataset(spark: SparkSession, path: String, field: String, value: TComplete): Unit

protected def deserializeDataset(spark: SparkSession, path: String, field: String): Option[_]

protected def serializeObject(spark: SparkSession, path: String, field: String, value: TComplete): Unit

protected def deserializeObject(spark: SparkSession, path: String, field: String): Option[_]

final protected def getFieldPath(path: String, field: String): Path =
Path.mergePaths(new Path(path), new Path("/fields/" + field))

final def get: Option[TComplete] = value.map(_.value)
final def getValue: TComplete = value.map(_.value).getOrElse(throw new Exception(s"feature $name is not set"))
final def get: Option[TComplete] = {
broadcastValue.map(_.value).orElse(rawValue)
}

final def getValue: TComplete = {
broadcastValue.map(_.value).orElse(rawValue).orElse(fallback.map(_())).getOrElse(throw new Exception(s"feature $name is not set"))
}

final def setValue(v: Option[Any]): HasFeatures = {
if (isSet) value.get.destroy()
value = Some(sparkSession.sparkContext.broadcast[TComplete](v.get.asInstanceOf[TComplete]))
if (useBroadcast) {
if (isSet) broadcastValue.get.destroy()
broadcastValue = Some(sparkSession.sparkContext.broadcast[TComplete](v.get.asInstanceOf[TComplete]))
} else {
rawValue = Some(v.get.asInstanceOf[TComplete])
}
model
}
final def isSet: Boolean = value.isDefined

def setFallback(v: Option[() => TComplete]): HasFeatures = {
fallback = v
model
}

final def isSet: Boolean = {
broadcastValue.isDefined || rawValue.isDefined
}

}
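The value resolution that falls out of the methods above is: broadcast value first, then the raw (non-broadcast) value, then, for getValue only, the fallback thunk, and finally an error. A stripped-down sketch of that contract with Broadcast replaced by a plain Option, just to make the precedence explicit (FeatureSketch is illustrative, not project code):

```scala
// Sketch of the precedence implemented above; Broadcast is stood in for by a plain Option.
class FeatureSketch[T](name: String) {
  var broadcastValue: Option[T] = None
  var rawValue: Option[T] = None
  var fallback: Option[() => T] = None

  // get never evaluates the fallback, mirroring Feature.get above.
  def get: Option[T] = broadcastValue.orElse(rawValue)

  // getValue lazily computes the default if nothing was set explicitly.
  def getValue: T =
    broadcastValue
      .orElse(rawValue)
      .orElse(fallback.map(_()))
      .getOrElse(throw new Exception(s"feature $name is not set"))

  // isSet is true only for explicitly set values, never for a pending fallback.
  def isSet: Boolean = broadcastValue.isDefined || rawValue.isDefined
}
```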

@@ -44,12 +86,27 @@ class StructFeature[TValue: ClassTag](model: HasFeatures, override val name: String)

implicit val encoder: Encoder[TValue] = Encoders.kryo[TValue]

override def serialize(spark: SparkSession, path: String, field: String, value: TValue): Unit = {
override def serializeObject(spark: SparkSession, path: String, field: String, value: TValue): Unit = {
val dataPath = getFieldPath(path, field)
spark.sparkContext.parallelize(Seq(value)).saveAsObjectFile(dataPath.toString)
}

override def deserializeObject(spark: SparkSession, path: String, field: String): Option[TValue] = {
val fs: FileSystem = FileSystem.get(spark.sparkContext.hadoopConfiguration)
val dataPath = getFieldPath(path, field)
if (fs.exists(dataPath)) {
Some(spark.sparkContext.objectFile[TValue](dataPath.toString).first)
} else {
None
}
}

override def serializeDataset(spark: SparkSession, path: String, field: String, value: TValue): Unit = {
val dataPath = getFieldPath(path, field)
spark.createDataset(Seq(value)).write.mode("overwrite").parquet(dataPath.toString)
}

override def deserialize(spark: SparkSession, path: String, field: String): Option[TValue] = {
override def deserializeDataset(spark: SparkSession, path: String, field: String): Option[TValue] = {
val fs: FileSystem = FileSystem.get(spark.sparkContext.hadoopConfiguration)
val dataPath = getFieldPath(path, field)
if (fs.exists(dataPath)) {
@@ -66,15 +123,32 @@ class MapFeature[TKey: ClassTag, TValue: ClassTag](model: HasFeatures, override val name: String)

implicit val encoder: Encoder[(TKey, TValue)] = Encoders.kryo[(TKey, TValue)]

override def serialize(spark: SparkSession, path: String, field: String, value: Map[TKey, TValue]): Unit = {
override def serializeObject(spark: SparkSession, path: String, field: String, value: Map[TKey, TValue]): Unit = {
val dataPath = getFieldPath(path, field)
spark.sparkContext.parallelize(value.toSeq).saveAsObjectFile(dataPath.toString)
}



override def deserializeObject(spark: SparkSession, path: String, field: String): Option[Map[TKey, TValue]] = {
val fs: FileSystem = FileSystem.get(spark.sparkContext.hadoopConfiguration)
val dataPath = getFieldPath(path, field)
if (fs.exists(dataPath)) {
Some(spark.sparkContext.objectFile[(TKey, TValue)](dataPath.toString).collect.toMap)
} else {
None
}
}

override def serializeDataset(spark: SparkSession, path: String, field: String, value: Map[TKey, TValue]): Unit = {
import spark.implicits._
val dataPath = getFieldPath(path, field)
value.toSeq.toDS.write.mode("overwrite").parquet(dataPath.toString)
}



override def deserialize(spark: SparkSession, path: String, field: String): Option[Map[TKey, TValue]] = {
override def deserializeDataset(spark: SparkSession, path: String, field: String): Option[Map[TKey, TValue]] = {
val fs: FileSystem = FileSystem.get(spark.sparkContext.hadoopConfiguration)
val dataPath = getFieldPath(path, field)
if (fs.exists(dataPath)) {
@@ -91,12 +165,27 @@ class ArrayFeature[TValue: ClassTag](model: HasFeatures, override val name: String)

implicit val encoder: Encoder[TValue] = Encoders.kryo[TValue]

override def serialize(spark: SparkSession, path: String, field: String, value: Array[TValue]): Unit = {
override def serializeObject(spark: SparkSession, path: String, field: String, value: Array[TValue]): Unit = {
val dataPath = getFieldPath(path, field)
spark.sparkContext.parallelize(value).saveAsObjectFile(dataPath.toString)
}

override def deserializeObject(spark: SparkSession, path: String, field: String): Option[Array[TValue]] = {
val fs: FileSystem = FileSystem.get(spark.sparkContext.hadoopConfiguration)
val dataPath = getFieldPath(path, field)
if (fs.exists(dataPath)) {
Some(spark.sparkContext.objectFile[TValue](dataPath.toString).collect())
} else {
None
}
}

override def serializeDataset(spark: SparkSession, path: String, field: String, value: Array[TValue]): Unit = {
val dataPath = getFieldPath(path, field)
spark.createDataset(value).write.mode("overwrite").parquet(dataPath.toString)
}

override def deserialize(spark: SparkSession, path: String, field: String): Option[Array[TValue]] = {
override def deserializeDataset(spark: SparkSession, path: String, field: String): Option[Array[TValue]] = {
val fs: FileSystem = FileSystem.get(spark.sparkContext.hadoopConfiguration)
val dataPath = getFieldPath(path, field)
if (fs.exists(dataPath)) {
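Each concrete Feature now supports two on-disk layouts, selected by performance.serialization. A rough sketch of both write paths for an array payload, mirroring ArrayFeature above (paths and values are invented, error handling omitted):

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}

val spark = SparkSession.builder().master("local[*]").appName("feature-serialization-sketch").getOrCreate()
val words = Array("better", "good")

// "object" mode: Spark object files, as in ArrayFeature.serializeObject / deserializeObject.
spark.sparkContext.parallelize(words).saveAsObjectFile("/tmp/fields/words_object")
val restored: Array[String] =
  spark.sparkContext.objectFile[String]("/tmp/fields/words_object").collect()

// "dataset" mode: a parquet dataset of Kryo-encoded rows, as in ArrayFeature.serializeDataset
// (the matching deserializeDataset body is not shown in this hunk).
implicit val kryoString: Encoder[String] = Encoders.kryo[String]
spark.createDataset(words).write.mode("overwrite").parquet("/tmp/fields/words_dataset")
```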
6 changes: 3 additions & 3 deletions src/test/resources/application.conf
@@ -36,7 +36,7 @@ nlp {
revert = -1.0
}
}

performance {
useBroadcast = false
}
serialization = "object"
useBroadcast = true
}