![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/2.2.FewShot_Assertion_Classifier.ipynb)

# Few-Shot Assertion Classifier

**Few-Shot Assertion Classifier Model for Higher Accuracy with Less Data**

The Few-Shot Assertion Classifier Model is an annotator designed to reach higher accuracy with fewer training samples, inspired by the SetFit framework. A Few-Shot Assertion model consists of a sentence embedding component paired with a classifier (or head). Current support focuses on MPNet-based Few-Shot Assertion models; future updates will extend compatibility to other popular models such as BERT, DistilBERT, and RoBERTa.

This classifier model supports several classifier types, including sklearn's LogisticRegression and custom PyTorch models, so it can accommodate different model setups. Users must specify the classifier type when exporting the model to Spark NLP.
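The two-part design described above (a sentence embedding component plus a classifier head) can be illustrated with a deliberately tiny, pure-Python sketch. Everything in it is a hypothetical stand-in: `embed` is a bag-of-words counter rather than an MPNet encoder, `CentroidHead` is a nearest-centroid rule rather than LogisticRegression or a PyTorch head, and the four training sentences are made up for illustration. It shows only how the two parts fit together, not how the library implements them.

```python
import math
from collections import Counter, defaultdict


def embed(sentence):
    """Toy stand-in for a sentence-embedding model: bag-of-words counts."""
    return Counter(sentence.lower().split())


def cosine(a, b):
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[t] * b[t] for t in a if t in b)
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


class CentroidHead:
    """Toy classifier head: keeps one summed embedding per label."""

    def fit(self, embeddings, labels):
        self.centroids = defaultdict(Counter)
        for emb, label in zip(embeddings, labels):
            self.centroids[label].update(emb)
        return self

    def predict(self, emb):
        # Pick the label whose centroid is most similar to the input.
        return max(self.centroids, key=lambda lab: cosine(emb, self.centroids[lab]))


# Made-up few-shot training set (illustrative only).
train = [
    ("patient has hypertension", "present"),
    ("she reports chest pain", "present"),
    ("no evidence of fever", "absent"),
    ("denies chest pain", "absent"),
]

head = CentroidHead().fit([embed(s) for s, _ in train], [lab for _, lab in train])
prediction = head.predict(embed("no evidence of hypertension"))
print(prediction)  # absent
```

In a real setup the embedding component carries most of the weight, which is why a handful of labeled examples per class can be enough for the head.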

## Healthcare NLP for Data Scientists Course

If you are not familiar with the components in this notebook, see the [Healthcare NLP for Data Scientists Udemy Course](https://www.udemy.com/course/healthcare-nlp-for-data-scientists/) and the [MOOC Notebooks](https://github.com/JohnSnowLabs/spark-nlp-workshop/tree/master/Spark_NLP_Udemy_MOOC/Healthcare_NLP) for each component.

**Colab Setup**

In [None]:
import json, os
from google.colab import files

if 'spark_jsl.json' not in os.listdir():
    license_keys = files.upload()
    os.rename(list(license_keys.keys())[0], 'spark_jsl.json')

with open('spark_jsl.json') as f:
    license_keys = json.load(f)

# Defining license key-value pairs as local variables
locals().update(license_keys)
os.environ.update(license_keys)
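The later cells read `SECRET`, `PUBLIC_VERSION`, and `JSL_VERSION` from the license file, so a missing entry only surfaces as a failed install or session start. A small check like the one below can catch that earlier. This is a sketch: `missing_license_keys` is a hypothetical helper, not part of any John Snow Labs package, and the example dictionary holds placeholder values.

```python
# Keys that the setup cells in this notebook read from the license file.
REQUIRED_KEYS = ("SECRET", "PUBLIC_VERSION", "JSL_VERSION")


def missing_license_keys(license_keys, required=REQUIRED_KEYS):
    """Return the required keys that are absent or empty in the license dict."""
    return [key for key in required if not license_keys.get(key)]


# Placeholder values for illustration; a real license file has more entries.
example = {"SECRET": "xxxx", "PUBLIC_VERSION": "5.3.2"}
print(missing_license_keys(example))  # ['JSL_VERSION']
```

Calling `missing_license_keys(license_keys)` right after loading `spark_jsl.json` and stopping when the list is non-empty keeps the error close to its cause.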

In [None]:
# Installing pyspark and spark-nlp
! pip install --upgrade -q pyspark==3.4.1 spark-nlp==$PUBLIC_VERSION

# Installing Spark NLP Healthcare
! pip install --upgrade -q spark-nlp-jsl==$JSL_VERSION  --extra-index-url https://pypi.johnsnowlabs.com/$SECRET

# Installing Spark NLP Display Library for visualization
! pip install -q spark-nlp-display

In [None]:
import json
import os

import sparknlp
import sparknlp_jsl

from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp_jsl.annotator import *

from pyspark.ml import Pipeline,PipelineModel
from pyspark.sql import SparkSession

import warnings
warnings.filterwarnings('ignore')

params = {"spark.driver.memory":"16G",
          "spark.kryoserializer.buffer.max":"2000M",
          "spark.driver.maxResultSize":"2000M"}

spark = sparknlp_jsl.start(license_keys['SECRET'],params=params)

print("Spark NLP Version :", sparknlp.version())
print("Spark NLP_JSL Version :", sparknlp_jsl.version())

spark

Spark NLP Version : 5.3.2
Spark NLP_JSL Version : 5.3.2


In [None]:
# Optional: build the session manually with custom params, mirroring the sparknlp_jsl.start() call above
from pyspark.sql import SparkSession

def start(SECRET):
    builder = SparkSession.builder \
        .appName("Spark NLP Licensed") \
        .master("local[*]") \
        .config("spark.driver.memory", "16G") \
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
        .config("spark.kryoserializer.buffer.max", "2000M") \
        .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:"+PUBLIC_VERSION) \
        .config("spark.jars", "https://pypi.johnsnowlabs.com/"+SECRET+"/spark-nlp-jsl-"+JSL_VERSION+".jar")

    return builder.getOrCreate()

# spark = start(SECRET)

# Pipeline

## assertion_fewshotclassifier

In [None]:
document_assembler = DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

sentence_detector = SentenceDetector()\
   .setInputCols("document")\
   .setOutputCol("sentence")

tokenizer = Tokenizer()\
   .setInputCols(["sentence"])\
   .setOutputCol("token")

embeddings = WordEmbeddingsModel.pretrained("embeddings_clinical", "en", "clinical/models")\
   .setInputCols(["sentence", "token"])\
   .setOutputCol("embeddings") \
   .setCaseSensitive(False)

ner = MedicalNerModel.pretrained("ner_jsl", "en", "clinical/models") \
   .setInputCols(["sentence", "token", "embeddings"]) \
   .setOutputCol("ner")

ner_converter = NerConverterInternal()\
   .setInputCols(["sentence", "token", "ner"])\
   .setWhiteList(["Disease_Syndrome_Disorder", "Hypertension","Symptom","VS_Finding"])\
   .setOutputCol("ner_chunk")

few_shot_assertion_classifier = FewShotAssertionClassifierModel.pretrained("assertion_fewshotclassifier", "en", "clinical/models")\
  .setInputCols(["sentence", "ner_chunk"])\
  .setOutputCol("assertion_fewshot")

pipeline = Pipeline()\
  .setStages([
  document_assembler,
  sentence_detector,
  tokenizer,
  embeddings,
  ner,
  ner_converter,
  few_shot_assertion_classifier
])

embeddings_clinical download started this may take some time.
Approximate size to download 1.6 GB
[OK!]
ner_jsl download started this may take some time.
[OK!]
assertion_fewshotclassifier download started this may take some time.
[OK!]


In [None]:
texts = [
    ["Includes hypertension and chronic obstructive pulmonary disease."],
    ["Her former vascular no arteriovenous malformations are identified; there is no evidence of recurrence of her former vascular malformation."],
    ["He is an elderly gentleman in no acute distress. He is sitting up in bed eating his breakfast."],
    ["Trachea is midline. No jugular venous pressure distention is noted. No adenopathy in the cervical, supraclavicular, or axillary areas."],
    ["Soft and not tender. There may be some fullness in the left upper quadrant, although I do not appreciate a true spleen with inspiration."]
]
spark_df = spark.createDataFrame(texts).toDF("text")
result = pipeline.fit(spark_df).transform(spark_df)

In [None]:
result.select("assertion_fewshot").show(1, truncate=False)

+-----------------+
|assertion_fewshot|
+-----------------+

In [None]:
from pyspark.sql import functions as F

# Zip the parallel annotation fields, then emit one row per assertion chunk
result.select(F.explode(F.arrays_zip(result.assertion_fewshot.metadata,
                                     result.assertion_fewshot.begin,
                                     result.assertion_fewshot.end,
                                     result.assertion_fewshot.result)).alias("cols")) \
      .select(F.expr("cols['0']['chunk']").alias("chunk"),
              F.expr("cols['1']").alias("begin"),
              F.expr("cols['2']").alias("end"),
              F.expr("cols['0']['entity']").alias("entity"),
              F.expr("cols['3']").alias("assertion"),
              F.expr("cols['0']['confidence']").alias("confidence")).show(truncate=False)

+-------------------------------------+-----+---+-------------------------+---------+----------+
|chunk                                |begin|end|entity                   |assertion|confidence|
+-------------------------------------+-----+---+-------------------------+---------+----------+
|hypertension                         |9    |20 |Hypertension             |present  |1.0       |
|chronic obstructive pulmonary disease|26   |62 |Disease_Syndrome_Disorder|present  |1.0       |
|arteriovenous malformations          |23   |49 |Disease_Syndrome_Disorder|absent   |1.0       |
|vascular malformation                |116  |136|Disease_Syndrome_Disorder|absent   |0.9999956 |
|distress                             |39   |46 |Symptom                  |absent   |1.0       |
|jugular venous pressure distention   |23   |56 |Symptom                  |absent   |1.0       |
|adenopathy                           |71   |80 |Symptom                  |absent   |1.0       |
|tender                       
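The `begin` and `end` columns above can be checked against the input texts: for the rows shown, they behave as inclusive character offsets into the full input text rather than into the individual sentence. The sketch below verifies that on three rows copied from the table; `span` is a hypothetical helper name introduced here for illustration.

```python
def span(text, begin, end):
    """Slice text using an inclusive end offset, as in the table above."""
    return text[begin:end + 1]


text_1 = "Includes hypertension and chronic obstructive pulmonary disease."
text_4 = ("Trachea is midline. No jugular venous pressure distention is noted. "
          "No adenopathy in the cervical, supraclavicular, or axillary areas.")

# (input text, chunk, begin, end) copied from the notebook's inputs and output table.
examples = [
    (text_1, "hypertension", 9, 20),
    (text_1, "chronic obstructive pulmonary disease", 26, 62),
    (text_4, "adenopathy", 71, 80),
]

for text, chunk, begin, end in examples:
    # Each recovered slice should match the chunk text exactly.
    assert span(text, begin, end) == chunk
    print(f"{begin:>3}-{end:<3} {span(text, begin, end)}")
```

The `adenopathy` row sits in the third sentence of its text, yet its offsets still count from the start of the text, which matters when highlighting chunks in the original input.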

**Display the results of the FewShotAssertionClassifierModel using sparknlp_display.**

In [None]:
from google.colab import widgets
from sparknlp_display import AssertionVisualizer

assertion_visualiser = AssertionVisualizer()
results = result.collect()

In [None]:
for row in results:
    assertion_visualiser.display(row, label_col='ner_chunk', assertion_col='assertion_fewshot')