![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

# **AverageEmbeddings**

This notebook will cover the different parameters and usages of `AverageEmbeddings`.

**📖 Learning Objectives:**

1. Understand how to use `AverageEmbeddings`.

2. Become comfortable using the different parameters of the annotator.



**🔗 Helpful Links:**


- Python Docs : [AverageEmbeddings](https://nlp.johnsnowlabs.com/licensed/api/python/reference/autosummary/sparknlp_jsl/annotator/embeddings/average_embeddings/index.html#sparknlp_jsl.annotator.embeddings.average_embeddings.AverageEmbeddings)

- Scala Docs : [AverageEmbeddings](https://nlp.johnsnowlabs.com/licensed/api/com/johnsnowlabs/nlp/annotators/embeddings/AverageEmbeddings.html)


## **📜 Background**


`AverageEmbeddings` averages two sentence embedding vectors of the same size into a single vector.
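
As a rough illustration of the operation (not the annotator's actual implementation), element-wise averaging of two equal-length vectors can be sketched in NumPy. The function name `average_vectors` and the toy values are hypothetical and exist only for this example.

```python
import numpy as np

def average_vectors(vec_a, vec_b):
    """Element-wise mean of two vectors that must have the same length."""
    a = np.asarray(vec_a, dtype=float)
    b = np.asarray(vec_b, dtype=float)
    # Averaging is only defined here when both vectors have the same shape
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    return (a + b) / 2.0

# Toy 3-dimensional vectors (real sentence embeddings are much longer)
toy_a = [1.0, 2.0, 3.0]
toy_b = [3.0, 4.0, 5.0]
toy_avg = average_vectors(toy_a, toy_b)
print(toy_avg)  # [2. 3. 4.]
```

The shape check mirrors the same-size requirement stated above: two vectors of different dimensions have no element-wise average.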

## **🎬 Colab Setup**

In [None]:
!pip install -q johnsnowlabs

In [None]:
from google.colab import files
print('Please Upload your John Snow Labs License using the button below')
license_keys = files.upload()

In [None]:
from johnsnowlabs import nlp, medical

# After uploading your license, run this to install all licensed Python wheels and pre-download the jars for the Spark session JVM
nlp.install()

In [None]:
from johnsnowlabs import nlp, medical
import pandas as pd

# Automatically load license data and start a session with all jars user has access to
spark = nlp.start()

In [5]:
spark

In [6]:
from pyspark.sql import DataFrame
import pyspark.sql.functions as F
import pyspark.sql.types as T

## **🖨️ Input/Output Annotation Types**

- Input: `SENTENCE_EMBEDDINGS`, `SENTENCE_EMBEDDINGS`, `CHUNK`

- Output: `EMBEDDINGS`

## **🔎 Parameters**


- `inputCols`: The names of the columns containing the input annotations. It accepts either a single string or an array of column names; this annotator expects two sentence-embedding columns and one chunk column.
- `outputCol`: The name of the generated output column, which holds `EMBEDDINGS` annotations. Only one column can be specified here.

All parameters can be set using the corresponding setter method in camel case, for example `.setInputCols()`.

### `inputCols` and `outputCol`

Let's define the annotators that turn raw text into `DOCUMENT`, `SENTENCE`, and `CHUNK` annotations, plus the two sentence-embedding models whose outputs will be averaged.

In [27]:
document_assembler = nlp.DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

sentence_detector = nlp.SentenceDetector()\
    .setInputCols(["document"])\
    .setOutputCol("sentence")

doc2Chunk = nlp.Doc2Chunk() \
    .setInputCols("sentence") \
    .setOutputCol("chunk") \
    .setIsArray(True)

sbiobert_base_cased_mli = nlp.BertSentenceEmbeddings\
    .pretrained("sbiobert_base_cased_mli", "en", "clinical/models")\
    .setInputCols("sentence")\
    .setOutputCol("sbiobert_base_cased_mli")

sent_biobert_clinical_base_cased = nlp.BertSentenceEmbeddings.pretrained("sent_biobert_clinical_base_cased", "en") \
    .setInputCols("sentence") \
    .setOutputCol("sent_biobert_clinical_base_cased")


sbiobert_base_cased_mli download started this may take some time.
Approximate size to download 384.3 MB
[OK!]
sent_biobert_clinical_base_cased download started this may take some time.
Approximate size to download 386.6 MB
[OK!]


Next, define `AverageEmbeddings` to compute the average of the two sentence embeddings:

In [28]:
avg_embeddings = medical.AverageEmbeddings()\
    .setInputCols(["sent_biobert_clinical_base_cased","sbiobert_base_cased_mli","chunk"])\
    .setOutputCol("embeddings")

pipeline = nlp.Pipeline(
    stages=[
        document_assembler,
        sentence_detector,
        doc2Chunk,
        sbiobert_base_cased_mli,
        sent_biobert_clinical_base_cased,
        avg_embeddings
    ])

data = spark.createDataFrame([[" The patient was prescribed 1 capsule of Advil for 5 days "]]).toDF("text")

result = pipeline.fit(data).transform(data)

In [29]:
result.select("sent_biobert_clinical_base_cased.embeddings").show(truncate=False)


In [30]:
result.select("sbiobert_base_cased_mli.embeddings").show(truncate=False)


In [31]:
import numpy as np

# Collect the first sentence's vector from each model; new variable names are
# used so the annotator objects defined earlier are not overwritten
mli_vector = np.array(result.select("sbiobert_base_cased_mli.embeddings").collect()[0][0][0])
clinical_vector = np.array(result.select("sent_biobert_clinical_base_cased.embeddings").collect()[0][0][0])

# Element-wise average, for comparison with the AverageEmbeddings output
average = (mli_vector + clinical_vector) / 2.0
average[:10]

array([ 0.32466834,  0.12497781, -0.20237188,  0.3716198 ,  0.27896178,
        0.01155752, -0.18116346,  0.31467234, -0.15455821,  0.21972417])
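
To compare a manual average like the one above with the `embeddings` column produced by `AverageEmbeddings` without reading the numbers by eye, a small helper along these lines could be used. This is a minimal sketch: `matches_manual_average` is a hypothetical name, the tolerance is an assumption, and the toy vectors stand in for the collected embeddings.

```python
import numpy as np

def matches_manual_average(vec_a, vec_b, averaged, atol=1e-5):
    """Return True if `averaged` equals the element-wise mean of vec_a and vec_b
    within a small tolerance (assumed value; embeddings are stored as floats)."""
    expected = (np.asarray(vec_a, dtype=float) + np.asarray(vec_b, dtype=float)) / 2.0
    return bool(np.allclose(expected, np.asarray(averaged, dtype=float), atol=atol))

# Toy stand-ins for two source embeddings and a candidate averaged vector
check_ok = matches_manual_average([0.2, 0.4], [0.4, 0.0], [0.3, 0.2])
check_bad = matches_manual_average([0.2, 0.4], [0.4, 0.0], [0.3, 0.5])
print(check_ok, check_bad)  # True False
```

In this notebook, the third argument would be the vector collected from `result.select("embeddings.embeddings")`, assuming the output column is named `embeddings` as in the pipeline above.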

In [32]:
result.select("embeddings.embeddings").show(truncate=False)


In [33]:
# Zip the annotation arrays in the same order as the aliases below,
# so that each column name matches the data it holds
result_df = result.select(F.explode(F.arrays_zip(result.sentence.result,
                                                 result.sentence.metadata,
                                                 result.chunk.result,
                                                 result.embeddings.embeddings,
                                                 result.sent_biobert_clinical_base_cased.embeddings,
                                                 result.sbiobert_base_cased_mli.embeddings)).alias("cols"))\
                  .select(F.expr("cols['0']").alias("sentence"),
                          F.expr("cols['1']").alias("sentence_metadata"),
                          F.expr("cols['2']").alias("chunk"),
                          F.expr("cols['3']").alias("embeddings"),
                          F.expr("cols['4']").alias("sent_biobert_clinical_base_cased"),
                          F.expr("cols['5']").alias("sbiobert_base_cased_mli"))

result_df.show(50, truncate=1000)


In [34]:
result_df.columns

['sentence',
 'sentence_metadata',
 'chunk',
 'embeddings',
 'sent_biobert_clinical_base_cased',
 'sbiobert_base_cased_mli']