# sentence encoder


Based on [https://towardsdatascience.com/text-classification-in-spark-nlp-with-bert-and-universal-sentence-encoders-e644d618ca32](https://towardsdatascience.com/text-classification-in-spark-nlp-with-bert-and-universal-sentence-encoders-e644d618ca32)




In [1]:
import com.johnsnowlabs.nlp.SparkNLP
import com.johnsnowlabs.nlp.annotator._
import com.johnsnowlabs.nlp.base._
import com.johnsnowlabs.ml.tensorflow.TensorflowBert
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.types._
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.functions.{udf,to_timestamp}
import org.apache.spark.ml.feature.QuantileDiscretizer
import org.apache.spark.storage.StorageLevel
import scala.collection.mutable.WrappedArray

val dataDir = sys.env("HOME") + "/recsys2020"

In [2]:
var df = spark.read.parquet(dataDir + "/training1k.parquet")
df.show(false)

+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------+--------------------------------+----------------------------+----------------------------------+----------------------------------+----------+-------------------

In [3]:
// Convert targets
val udf_has_engagement = udf[Boolean, IntegerType](_ != null)
val target_df = df.withColumn("has_retweet", udf_has_engagement('retweet_timestamp))
target_df

[text_tokens: array<string>, hashtags: array<string> ... 23 more fields]

In [4]:
val text_df = df.withColumn("sample_text", lit("Hello World."))

In [5]:
// If there is media like photo or video, the last link in the text is always a link to the tweet itself
val ignored_tokens = Set("[CLS]","[UNK]","[SEP]","UNKN")
val udf_tweet = udf((text_tokens: WrappedArray[String], present_media: WrappedArray[String]) => {
    text_tokens.filterNot(token => ignored_tokens.contains(token)).foldLeft(List[String]()){(soFar, next) => {
            var m: String = null;
            if (!soFar.isEmpty && soFar.last.startsWith("https") && !(next == "https")) {
                soFar.last.length match {
                    case 5 => if (next == ":") m = next;
                    case 6 | 7 | 12 => if (next == "/") m = next;
                    case 8 => if (next == "t") m = next;
                    case 9 => if (next == ".") m = next;
                    case 10 => if (next == "co") m = next;
                    case x  if (x == 13) => 
                        if (next.forall(_.isLetterOrDigit)) {
                            m = next;
                        }
                    case x if (x > 13) =>
                        if (next.startsWith("##")) {
                            val nwop = next.stripPrefix("##");
                            if (nwop.forall(_.isLetterOrDigit))
                            {
                                m = nwop;
                            }
                        }
                }
            }
            else if (!soFar.isEmpty && (soFar.last.startsWith("#") || soFar.last.startsWith("@"))) {
                val isFirst = Set('_', '#', '@').contains(soFar.last.last);
                if (isFirst || next.startsWith("##") || next == "_")
                {
                    val nwop = next.stripPrefix("##");
                    if (nwop.forall(c => c.isLetterOrDigit || c == '_'))
                    {
                        m = nwop;
                    }
                }
            }
            else if (next.startsWith("##"))
            {
                m = next.stripPrefix("##");
            }
            if (m != null) soFar.init :+ (soFar.last + m) else soFar :+ next
        }
    }.mkString(" ")
})

val converted_df = df.withColumn("sentence", udf_tweet('text_tokens, 'present_media))
converted_df.select('present_domains, 'sentence, 'text_tokens).show(100, false)

+--------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [6]:
"// Define stages
val use = UniversalSentenceEncoder.pretrained().setInputCols(Array("text_tokens")).setOutputCol("tweet_embeddings")
val fin = new EmbeddingsFinisher()
      .setInputCols("tweet_embeddings"
      .setOutputCols("finished_tweet_embeddings")
      .setOutputAsVector(true)
      .setCleanAnnotations(false)

In [7]:
val sent_emb = new SentenceEmbeddings().setInputCols(Array("text_tokens"))

java.lang.IllegalArgumentException: requirement failed: setInputCols in SENTENCE_EMBEDDINGS_0a833852716e expecting 2 columns. Provided column amount: 1. Which should be columns from the following annotators: document, word_embeddings

In [8]:
val use_clf_pipeline = new Pipeline().setStages(Array(use, fin))

In [9]:
val use_transformed = use_clf_pipeline.fit(df).transform(df)
use_transformed

java.lang.IllegalArgumentException: requirement failed: Wrong or missing inputCols annotators in UNIVERSAL_SENTENCE_ENCODER_4de71669b7ec.

Current inputCols: text_tokens. Dataset's columns:
(column_name=text_tokens,is_nlp_annotator=false)
(column_name=hashtags,is_nlp_annotator=false)
(column_name=tweet_id,is_nlp_annotator=false)
(column_name=present_media,is_nlp_annotator=false)
(column_name=present_links,is_nlp_annotator=false)
(column_name=present_domains,is_nlp_annotator=false)
(column_name=tweet_type,is_nlp_annotator=false)
(column_name=language,is_nlp_annotator=false)
(column_name=tweet_timestamp,is_nlp_annotator=false)
(column_name=engaged_with_user_id,is_nlp_annotator=false)
(column_name=engaged_with_user_follower_count,is_nlp_annotator=false)
(column_name=engaged_with_user_following_count,is_nlp_annotator=false)
(column_name=engaged_with_user_is_verified,is_nlp_annotator=false)
(column_name=engaged_with_user_account_creation,is_nlp_annotator=false)
(column_name=engaging_user_id,is_nlp_annotator=false)
(column_name=engaging_user_follower_count,is_nlp_annotator=false)
(column_name=engaging_user_following_count,is_nlp_annotator=false)
(column_name=engaging_user_is_verified,is_nlp_annotator=false)
(column_name=engaging_user_account_creation,is_nlp_annotator=false)
(column_name=engagee_follows_engager,is_nlp_annotator=false)
(column_name=reply_timestamp,is_nlp_annotator=false)
(column_name=retweet_timestamp,is_nlp_annotator=false)
(column_name=retweet_with_comment_timestamp,is_nlp_annotator=false)
(column_name=like_timestamp,is_nlp_annotator=false).
Make sure such annotators exist in your pipeline, with the right output names and that they have following annotator types: document