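# Pipeline test

Smoke test of a two-stage Spark pipeline on a small synthetic dataset: tweet text is embedded with Spark NLP's Universal Sentence Encoder, the embedding is combined with a non-text feature, and one gradient-boosted tree classifier is trained per target.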

In [1]:
import com.johnsnowlabs.nlp.SparkNLP
import com.johnsnowlabs.nlp.annotator._
import com.johnsnowlabs.nlp.base._
import com.johnsnowlabs.ml.tensorflow.TensorflowBert
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.types._
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.functions.{udf,to_timestamp}
import org.apache.spark.ml.feature.QuantileDiscretizer
import org.apache.spark.storage._
import org.apache.spark.ml.feature._
import org.apache.spark.ml.classification._

// Relative path to the RecSys 2020 data folder; used by the (currently commented-out) parquet loading cell.
val dataDir: String = "../../data/recsys2020"

In [2]:
// val df = spark.read.parquet(dataDir + "/training1m.parquet").limit(1000).persist(StorageLevel.MEMORY_ONLY) // Remove limit after experiments
// df

In [3]:
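// Convert engagement columns into binary targets: 1 if the value is non-null, else 0.
// NOTE(review): udf type parameters must be Scala/Java types, not Spark DataTypes, so `udf[Int, IntegerType]`
// below is likely to fail at runtime; `when('retweet_timestamp.isNotNull, 1).otherwise(0)` needs no UDF at all.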
// val udf_has_engagement = udf[Int, IntegerType](x => if (x != null) 1 else 0)
// var trans_df = df.withColumn("has_retweet", udf_has_engagement('retweet_timestamp))
// trans_df

In [4]:
// Small synthetic dataset (101 rows, value = 0..100) for smoke-testing the pipelines below.
// Every column uses its own explicitly seeded rand(). This makes the data reproducible across notebook
// runs (for the same partitioning), unlike an unseeded rand(), which picks a new seed each run.
// The seeds are distinct so the columns do not all draw the same random sequence.
val trainDf = Seq(0 to 100: _*).toDF()
  .withColumn("text", when(rand(11L) > 0.2, "Hello World.").otherwise("Other")) // ~80% "Hello World."
  .withColumn("nonTextFeature", when(rand(22L) > 0.5, 1).otherwise(0))         // ~50% ones
  .withColumn("target1", when(rand(33L) > 0.3, 1).otherwise(0))                // ~70% positive
  .withColumn("target2", when(rand(44L) > 0.6, 1).otherwise(0))                // ~40% positive
trainDf.show()

+-----+------------+--------------+-------+-------+
|value|        text|nonTextFeature|target1|target2|
+-----+------------+--------------+-------+-------+
|    0|Hello World.|             1|      0|      1|
|    1|Hello World.|             1|      1|      0|
|    2|Hello World.|             1|      1|      1|
|    3|Hello World.|             0|      0|      0|
|    4|Hello World.|             0|      0|      0|
|    5|Hello World.|             0|      0|      0|
|    6|Hello World.|             1|      0|      0|
|    7|       Other|             1|      0|      0|
|    8|Hello World.|             1|      0|      1|
|    9|Hello World.|             1|      0|      0|
|   10|Hello World.|             1|      1|      0|
|   11|Hello World.|             1|      0|      1|
|   12|Hello World.|             0|      1|      1|
|   13|Hello World.|             1|      0|      0|
|   14|Hello World.|             0|      1|      0|
|   15|       Other|             1|      1|      0|
|   16|Hello

In [5]:
// Text branch: raw "text" -> Spark NLP document -> sentence embedding -> Spark ML vector.

// Step 1: wrap the raw "text" column into Spark NLP "document" annotations, using cleanup mode "shrink".
val doc = new DocumentAssembler()
  .setInputCol("text")
  .setOutputCol("document")
  .setCleanupMode("shrink")

// Step 2: embed every document with the default pretrained Universal Sentence Encoder.
val use = UniversalSentenceEncoder
  .pretrained()
  .setInputCols(Array("document"))
  .setOutputCol("tweet_embeddings")

// Step 3: convert the embedding annotations into Spark ML vectors. The annotation columns are
// kept in the output because of setCleanAnnotations(false).
val fin = new EmbeddingsFinisher()
  .setInputCols("tweet_embeddings")
  .setOutputCols("finished_tweet_embeddings")
  .setOutputAsVector(true)
  .setCleanAnnotations(false)

// Disabled alternative to step 2: token-level multilingual BERT embeddings, averaged into one
// sentence embedding per document.
// val tok = new Tokenizer()
//   .setInputCols("document")
//   .setOutputCol("token")
//   .setContextChars(Array("(", ")", "?", "!"))
//   .setSplitChars(Array("-"))
//   .addException("New York")
//   .addException("e-mail")
// val bert = BertEmbeddings.pretrained(name="bert_multi_cased", lang="xx")
//   .setInputCols("document", "token")
//   .setOutputCol("embeddings")
//   .setPoolingLayer(0) // 0, -1, or -2
// val emb = new SentenceEmbeddings()
//   .setInputCols(Array("document", "embeddings"))
//   .setOutputCol("tweet_embeddings")
//   .setPoolingStrategy("AVERAGE")
// val text_trans_pipeline = new Pipeline().setStages(Array(doc, tok, bert, emb, fin))

val text_trans_pipeline = new Pipeline().setStages(Array(doc, use, fin))

In [6]:
// Fit the text branch and apply it to the same rows.
// TODO: the text stages can later be combined with the prediction stages into one large pipeline.
val textModel = text_trans_pipeline.fit(trainDf)
val intermediateDf = textModel.transform(trainDf)
intermediateDf.show()

+-----+------------+--------------+-------+-------+--------------------+--------------------+-------------------------+
|value|        text|nonTextFeature|target1|target2|            document|    tweet_embeddings|finished_tweet_embeddings|
+-----+------------+--------------+-------+-------+--------------------+--------------------+-------------------------+
|    0|Hello World.|             1|      0|      1|[[document, 0, 11...|[[sentence_embedd...|     [[-0.037585671991...|
|    1|Hello World.|             1|      1|      0|[[document, 0, 11...|[[sentence_embedd...|     [[-0.037585671991...|
|    2|Hello World.|             1|      1|      1|[[document, 0, 11...|[[sentence_embedd...|     [[-0.037585671991...|
|    3|Hello World.|             0|      0|      0|[[document, 0, 11...|[[sentence_embedd...|     [[-0.037585671991...|
|    4|Hello World.|             0|      0|      0|[[document, 0, 11...|[[sentence_embedd...|     [[-0.037585671991...|
|    5|Hello World.|             0|     

In [9]:
// Assumption: each row yields exactly one document and therefore one sentence embedding, so
// "finished_tweet_embeddings" is expected to be a one-element array of vectors.
// Take that element directly instead of using explode():
//   - explode() silently drops rows whose array is empty;
//   - explode() duplicates rows when an array has several elements;
//   - either case would change the row count relative to the targets.
// getItem(0) always keeps one output row per input row. A row without an embedding gets null,
// which VectorAssembler rejects (handleInvalid defaults to "error"), so the row does not
// disappear unnoticed.
val transDf = intermediateDf.withColumn(
  "embedding_features",
  intermediateDf("finished_tweet_embeddings").getItem(0))
transDf.show()

+-----+------------+--------------+-------+-------+--------------------+--------------------+-------------------------+--------------------+
|value|        text|nonTextFeature|target1|target2|            document|    tweet_embeddings|finished_tweet_embeddings|  embedding_features|
+-----+------------+--------------+-------+-------+--------------------+--------------------+-------------------------+--------------------+
|    0|Hello World.|             1|      0|      1|[[document, 0, 11...|[[sentence_embedd...|     [[-0.037585671991...|[-0.0375856719911...|
|    1|Hello World.|             1|      1|      0|[[document, 0, 11...|[[sentence_embedd...|     [[-0.037585671991...|[-0.0375856719911...|
|    2|Hello World.|             1|      1|      1|[[document, 0, 11...|[[sentence_embedd...|     [[-0.037585671991...|[-0.0375856719911...|
|    3|Hello World.|             0|      0|      0|[[document, 0, 11...|[[sentence_embedd...|     [[-0.037585671991...|[-0.0375856719911...|
|    4|Hello 

In [7]:
// Concatenate the non-text feature and the sentence embedding into one "features" vector.
val ass = new VectorAssembler()
  .setInputCols(Array("nonTextFeature", "embedding_features"))
  .setOutputCol("features")

/** Builds one GBT classifier (10 iterations, "auto" feature subset strategy) on "features".
 *
 *  Every classifier must get its own probability/prediction/rawPrediction column names.
 *  With the defaults, several classifiers in the same pipeline would clash on output column names.
 *
 *  @param labelCol         target column to learn
 *  @param probabilityCol   output column for class probabilities
 *  @param predictionCol    output column for the predicted label
 *  @param rawPredictionCol output column for raw prediction scores
 */
def gbtClassifierFor(
    labelCol: String,
    probabilityCol: String,
    predictionCol: String,
    rawPredictionCol: String): GBTClassifier =
  new GBTClassifier()
    .setLabelCol(labelCol)
    .setFeaturesCol("features")
    .setProbabilityCol(probabilityCol)
    .setPredictionCol(predictionCol)
    .setRawPredictionCol(rawPredictionCol)
    .setMaxIter(10)
    .setFeatureSubsetStrategy("auto")

val gbtT1 = gbtClassifierFor(
  labelCol = "target1", probabilityCol = "prob1", predictionCol = "pred1", rawPredictionCol = "rpred1")
val gbtT2 = gbtClassifierFor(
  labelCol = "target2", probabilityCol = "prob2", predictionCol = "pred2", rawPredictionCol = "rpred2")

val pred_pipeline = new Pipeline().setStages(Array(ass, gbtT1, gbtT2))

In [8]:
// Train both target classifiers, then score the same rows they were trained on
// (in-sample predictions only; no held-out evaluation).
val predModel = pred_pipeline.fit(transDf)
val finalDF = predModel.transform(transDf)
finalDF.show()

+-----+------------+--------------+-------+-------+--------------------+--------------------+-------------------------+--------------------+--------------------+--------------------+--------------------+-----+--------------------+--------------------+-----+
|value|        text|nonTextFeature|target1|target2|            document|    tweet_embeddings|finished_tweet_embeddings|  embedding_features|            features|              rpred1|               prob1|pred1|              rpred2|               prob2|pred2|
+-----+------------+--------------+-------+-------+--------------------+--------------------+-------------------------+--------------------+--------------------+--------------------+--------------------+-----+--------------------+--------------------+-----+
|    0|Hello World.|             1|      0|      1|[[document, 0, 11...|[[sentence_embedd...|     [[-0.037585671991...|[-0.0375856719911...|[1.0,-0.037585671...|[-0.3779539256575...|[0.31953537457085...|  1.0|[0.42963234723922

In [10]:
// Target 1: ground truth next to predicted probability, predicted label and input features (untruncated).
finalDF.select("target1", "prob1", "pred1", "features").show(truncate = false)

+-------+----------------------------------------+-----+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [11]:
// Target 2: ground truth next to predicted probability, predicted label and input features (untruncated).
finalDF.select("target2", "prob2", "pred2", "features").show(truncate = false)

+-------+----------------------------------------+-----+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------