![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/2.2.FewShot_Assertion_Classifier.ipynb)

# Few-Shot Assertion Classifier

**Few-Shot Assertion Classifier Model for Higher Accuracy with Less Data**

The Few-Shot Assertion Classifier Model is an annotator designed to reach higher accuracy with fewer training samples; it is inspired by the SetFit framework. A Few-Shot Assertion model consists of a sentence embedding component paired with a classifier (or head). While current support is focused on MPNet-based Few-Shot Assertion models, future updates will extend compatibility to other popular models such as BERT, DistilBERT, and RoBERTa.

This classifier model supports various classifier types, including sklearn's LogisticRegression and custom PyTorch models, which gives flexibility across model setups. Users must specify the classifier type when exporting the model to Spark NLP.
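
The embedding-plus-head design described above can be sketched without any Spark NLP dependency. The following is a minimal illustration, not the library's implementation: the 2-dimensional vectors are made-up stand-ins for sentence embeddings, and `train_logistic_head` and `predict_label` are hypothetical helper names. It fits a tiny logistic-regression head on fixed vectors with plain stochastic gradient descent.

```python
import math

# Toy stand-ins for sentence embeddings (assumption: real sentence embeddings
# have hundreds of dimensions; two are enough to show the idea).
train_vectors = [[0.9, 0.1], [0.8, 0.2], [0.1, 0.9], [0.2, 0.8]]
train_labels = [1, 1, 0, 0]  # 1 = "present", 0 = "absent"


def train_logistic_head(vectors, labels, epochs=200, lr=0.5):
    """Fit a logistic-regression head on fixed embeddings with plain SGD."""
    weights = [0.0] * len(vectors[0])
    bias = 0.0
    for _ in range(epochs):
        for x, y in zip(vectors, labels):
            z = sum(w * xi for w, xi in zip(weights, x)) + bias
            p = 1.0 / (1.0 + math.exp(-z))
            error = p - y  # gradient of the log loss with respect to z
            weights = [w - lr * error * xi for w, xi in zip(weights, x)]
            bias -= lr * error
    return weights, bias


def predict_label(weights, bias, x):
    """Map the sign of the linear score to an assertion label."""
    z = sum(w * xi for w, xi in zip(weights, x)) + bias
    return "present" if z > 0 else "absent"


weights, bias = train_logistic_head(train_vectors, train_labels)
print(predict_label(weights, bias, [0.85, 0.15]))
print(predict_label(weights, bias, [0.15, 0.85]))
```

Only the small head is trained here; the embeddings stay fixed, which is why this kind of setup can work with few labeled samples.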

**Colab Setup**

In [None]:
import json, os
from google.colab import files

if 'spark_jsl.json' not in os.listdir():
    license_keys = files.upload()
    os.rename(list(license_keys.keys())[0], 'spark_jsl.json')

with open('spark_jsl.json') as f:
    license_keys = json.load(f)

# Defining license key-value pairs as local variables
locals().update(license_keys)
os.environ.update(license_keys)
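
The setup cells in this notebook read `SECRET`, `PUBLIC_VERSION`, and `JSL_VERSION` from the license file, so it can help to check for them before installing anything. A minimal sketch, assuming those three keys are the ones needed; `missing_license_keys` is a hypothetical helper and the JSON string is placeholder content, not real credentials.

```python
import json

# Keys the setup cells in this notebook read from the license file.
REQUIRED_KEYS = ("SECRET", "PUBLIC_VERSION", "JSL_VERSION")


def missing_license_keys(license_keys):
    """Return the required keys that are absent or empty, in a fixed order."""
    return [key for key in REQUIRED_KEYS if not license_keys.get(key)]


# Placeholder content standing in for spark_jsl.json (not real credentials).
example = json.loads('{"SECRET": "xxx", "PUBLIC_VERSION": "5.4.0"}')
print(missing_license_keys(example))
```

An empty list means the install and session-start cells have everything they reference.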

In [None]:
# Installing pyspark and spark-nlp
! pip install --upgrade -q pyspark==3.4.1 spark-nlp==$PUBLIC_VERSION

# Installing Spark NLP Healthcare
! pip install --upgrade -q spark-nlp-jsl==$JSL_VERSION  --extra-index-url https://pypi.johnsnowlabs.com/$SECRET

# Installing Spark NLP Display Library for visualization
! pip install -q spark-nlp-display

In [None]:
import json
import os

import sparknlp
import sparknlp_jsl

from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp_jsl.annotator import *

from pyspark.ml import Pipeline,PipelineModel
from pyspark.sql import SparkSession

import warnings
warnings.filterwarnings('ignore')

params = {"spark.driver.memory":"16G",
          "spark.kryoserializer.buffer.max":"2000M",
          "spark.driver.maxResultSize":"2000M"}

spark = sparknlp_jsl.start(license_keys['SECRET'],params=params)

print("Spark NLP Version :", sparknlp.version())
print("Spark NLP_JSL Version :", sparknlp_jsl.version())

spark

Spark NLP Version : 5.4.0
Spark NLP_JSL Version : 5.4.0


    # If you prefer to build the session yourself with custom params, as the start() call above does:
    from pyspark.sql import SparkSession

    def start(SECRET):
        builder = SparkSession.builder \
            .appName("Spark NLP Licensed") \
            .master("local[*]") \
            .config("spark.driver.memory", "16G") \
            .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
            .config("spark.kryoserializer.buffer.max", "2000M") \
            .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:"+PUBLIC_VERSION) \
            .config("spark.jars", "https://pypi.johnsnowlabs.com/"+SECRET+"/spark-nlp-jsl-"+JSL_VERSION+".jar")

        return builder.getOrCreate()

    # spark = start(SECRET)

# Pretrained Models

|FewShot Assertion Model Name| Predicted Classes | Trained Embeddings |
|----------------------------|-------------------|--------------------|
|[fewhot_assertion_i2b2_e5_base_v2_i2b2]() |   | [e5_base_v2_embeddings_medical_assertion_i2b2]()  |

## Oncology

In [None]:
document_assembler = DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

sentence_detector = SentenceDetectorDLModel.pretrained("sentence_detector_dl_healthcare","en","clinical/models")\
    .setInputCols(["document"])\
    .setOutputCol("sentence")

tokenizer = Tokenizer()\
    .setInputCols(["sentence"])\
    .setOutputCol("token")\
    .setSplitChars(["-", "\/"])

word_embeddings = WordEmbeddingsModel.pretrained("embeddings_clinical","en","clinical/models")\
    .setInputCols(["sentence","token"])\
    .setOutputCol("embeddings")

# ner_oncology
ner_oncology = MedicalNerModel.pretrained("ner_oncology","en","clinical/models")\
    .setInputCols(["sentence","token","embeddings"])\
    .setOutputCol("ner_oncology")

ner_oncology_converter = NerConverterInternal()\
    .setInputCols(["sentence","token","ner_oncology"])\
    .setOutputCol("ner_chunk")

few_shot_assertion_converter = FewShotAssertionSentenceConverter()\
    .setInputCols(["sentence", "token", "ner_chunk"])\
    .setOutputCol("assertion_sentence")

e5_embeddings = E5Embeddings.pretrained("e5_base_v2_embeddings_medical_assertion_oncology", "en", "clinical/models")\
    .setInputCols(["assertion_sentence"])\
    .setOutputCol("assertion_embedding")

few_shot_assertion_classifier = FewShotAssertionClassifierModel\
    .pretrained("fewhot_assertion_oncology_e5_base_v2_oncology", "en", "clinical/models")\
    .setInputCols(["assertion_embedding"])\
    .setOutputCol("assertion_fewshot")

assertion_pipeline = Pipeline(
    stages=[
        document_assembler,
        sentence_detector,
        tokenizer,
        word_embeddings,
        ner_oncology,
        ner_oncology_converter,
        few_shot_assertion_converter,
        e5_embeddings,
        few_shot_assertion_classifier
])

sentence_detector_dl_healthcare download started this may take some time.
Approximate size to download 367.3 KB
[OK!]
embeddings_clinical download started this may take some time.
Approximate size to download 1.6 GB
[OK!]
ner_oncology download started this may take some time.
[OK!]
e5_base_v2_embeddings_medical_assertion_oncology download started this may take some time.
Approximate size to download 375.4 MB
[OK!]
fewhot_assertion_oncology_e5_base_v2_oncology download started this may take some time.
[OK!]
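
Each stage in the pipeline above reads columns that an earlier stage must have produced. The sketch below checks that wiring in plain Python; it is an illustration only, not a Spark NLP feature. The stage list is copied from the cell above, and `find_unresolved_inputs` is a hypothetical helper.

```python
# (stage name, input columns, output column), copied from the pipeline above.
STAGES = [
    ("document_assembler", ["text"], "document"),
    ("sentence_detector", ["document"], "sentence"),
    ("tokenizer", ["sentence"], "token"),
    ("word_embeddings", ["sentence", "token"], "embeddings"),
    ("ner_oncology", ["sentence", "token", "embeddings"], "ner_oncology"),
    ("ner_oncology_converter", ["sentence", "token", "ner_oncology"], "ner_chunk"),
    ("few_shot_assertion_converter", ["sentence", "token", "ner_chunk"], "assertion_sentence"),
    ("e5_embeddings", ["assertion_sentence"], "assertion_embedding"),
    ("few_shot_assertion_classifier", ["assertion_embedding"], "assertion_fewshot"),
]


def find_unresolved_inputs(stages, initial_columns):
    """List (stage, column) pairs whose input is not produced upstream."""
    available = set(initial_columns)
    problems = []
    for name, inputs, output in stages:
        problems.extend((name, col) for col in inputs if col not in available)
        available.add(output)
    return problems


# Full pipeline: every input is resolved.
print(find_unresolved_inputs(STAGES, ["text"]))
# Dropping the first stage leaves sentence_detector without "document".
print(find_unresolved_inputs(STAGES[1:], ["text"]))
```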


In [None]:
sample_text= [
"""The patient is suspected to have colorectal cancer. Her family history is positive for other cancers. The result of the biopsy was positive. A CT scan was ordered to rule out metastases."""
]

data = spark.createDataFrame([sample_text]).toDF("text")

result = assertion_pipeline.fit(data).transform(data)

In [None]:
result.select("assertion_fewshot").show(1, truncate=False)


In [None]:
from pyspark.sql import functions as F

result.select(F.explode(F.arrays_zip(result.assertion_fewshot.metadata,
                                     result.assertion_fewshot.begin,
                                     result.assertion_fewshot.end,
                                     result.assertion_fewshot.result,)).alias("cols")) \
      .select(F.expr("cols['0']['ner_chunk']").alias("ner_chunk"),
              F.expr("cols['1']").alias("begin"),
              F.expr("cols['2']").alias("end"),
              F.expr("cols['0']['ner_label']").alias("ner_label"),
              F.expr("cols['3']").alias("assertion"),
              F.expr("cols['0']['confidence']").alias("confidence") ).show(truncate=False)

+-----------------+-----+---+----------------+---------+----------+
|ner_chunk        |begin|end|ner_label       |assertion|confidence|
+-----------------+-----+---+----------------+---------+----------+
|colorectal cancer|33   |49 |Cancer_Dx       |Possible |0.5812815 |
|Her              |52   |54 |Gender          |Present  |0.9562998 |
|cancers          |93   |99 |Cancer_Dx       |Family   |0.23465642|
|biopsy           |120  |125|Pathology_Test  |Past     |0.95732147|
|positive         |131  |138|Pathology_Result|Present  |0.9564386 |
|CT scan          |143  |149|Imaging_Test    |Past     |0.9571699 |
|metastases       |175  |184|Metastasis      |Possible |0.54986554|
+-----------------+-----+---+----------------+---------+----------+
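
The `begin` and `end` columns are character offsets into the input text, and `end` is inclusive. The sketch below re-checks every row of the table against the sample sentence and also flags predictions under a confidence cut-off. The rows are copied from the output above; the 0.6 threshold and the helper name `chunk_at` are arbitrary choices made for illustration.

```python
sample_text = (
    "The patient is suspected to have colorectal cancer. "
    "Her family history is positive for other cancers. "
    "The result of the biopsy was positive. "
    "A CT scan was ordered to rule out metastases."
)

# (ner_chunk, begin, end, assertion, confidence), copied from the table above.
ROWS = [
    ("colorectal cancer", 33, 49, "Possible", 0.5812815),
    ("Her", 52, 54, "Present", 0.9562998),
    ("cancers", 93, 99, "Family", 0.23465642),
    ("biopsy", 120, 125, "Past", 0.95732147),
    ("positive", 131, 138, "Present", 0.9564386),
    ("CT scan", 143, 149, "Past", 0.9571699),
    ("metastases", 175, 184, "Possible", 0.54986554),
]

THRESHOLD = 0.6  # arbitrary cut-off, for illustration only


def chunk_at(text, begin, end):
    """Slice a chunk out of the text; `end` is inclusive, hence the + 1."""
    return text[begin:end + 1]


mismatches = [row for row in ROWS if chunk_at(sample_text, row[1], row[2]) != row[0]]
low_confidence = [row[0] for row in ROWS if row[4] < THRESHOLD]

print("offset mismatches:", mismatches)
print("low-confidence chunks:", low_confidence)
```

Low-confidence rows such as `cancers` (0.23) are natural candidates for manual review.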



**Display the result of the FewShotAssertionClassifierModel using sparknlp_display.**

In [None]:
from google.colab import widgets
from sparknlp_display import AssertionVisualizer

assertion_visualiser = AssertionVisualizer()
results = result.collect()

In [None]:
for row in results:
    assertion_visualiser.display(row, label_col='ner_chunk', assertion_col='assertion_fewshot')

# Train a custom Few-Shot Assertion Model

In [None]:
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/i2b2_assertion_sample_short.csv

In [None]:
import pandas as pd

assertion_df = spark.read.option("header", True).option("inferSchema", "True").csv("i2b2_assertion_sample_short.csv")

assertion_df.show(3, truncate=100)

+-------------------------------------------------+-------------------+-------+-----+---+
|                                             text|             target|  label|start|end|
+-------------------------------------------------+-------------------+-------+-----+---+
|She has no history of liver disease , hepatitis .|      liver disease| absent|    5|  6|
|                         1. Undesired fertility .|undesired fertility|present|    1|  2|
|                            3) STATUS POST FALL .|               fall|present|    3|  3|
+-------------------------------------------------+-------------------+-------+-----+---+
only showing top 3 rows
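
In this dataset, `start` and `end` are token indices rather than character offsets, and `end` is inclusive. A minimal sketch that recovers `target` from the three rows shown above, assuming plain whitespace tokenization (the sample text already has spaces around punctuation); `target_from_token_span` is a hypothetical helper.

```python
# (text, target, start, end), copied from the three rows shown above.
ROWS = [
    ("She has no history of liver disease , hepatitis .", "liver disease", 5, 6),
    ("1. Undesired fertility .", "undesired fertility", 1, 2),
    ("3) STATUS POST FALL .", "fall", 3, 3),
]


def target_from_token_span(text, start, end):
    """Join whitespace tokens start..end (inclusive) and lowercase them."""
    tokens = text.split()
    return " ".join(tokens[start:end + 1]).lower()


for text, target, start, end in ROWS:
    recovered = target_from_token_span(text, start, end)
    print(f"{recovered!r} matches target: {recovered == target}")
```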



In [None]:
(training_data, test_data) = assertion_df.randomSplit([0.8, 0.2], seed = 100)
print("Training Dataset Count: " + str(training_data.count()))
print("Test Dataset Count: " + str(test_data.count()))

Training Dataset Count: 721
Test Dataset Count: 170


In [None]:
training_data.groupBy('label').count().orderBy('count', ascending=False).show(truncate=False)

test_data.groupBy('label').count().orderBy('count', ascending=False).show(truncate=False)

+-------+-----+
|label  |count|
+-------+-----+
|present|546  |
|absent |175  |
+-------+-----+

+-------+-----+
|label  |count|
+-------+-----+
|present|117  |
|absent |53   |
+-------+-----+
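
Because `present` dominates both splits, a classifier that always predicts the majority class already scores well on accuracy. The sketch below computes that baseline from the counts printed above, together with the actual share of rows that `randomSplit` placed in the test set; `majority_baseline` is a hypothetical helper name.

```python
# Label counts copied from the two tables above.
train_counts = {"present": 546, "absent": 175}
test_counts = {"present": 117, "absent": 53}


def majority_baseline(train, test):
    """Accuracy on `test` of always predicting the most frequent training label."""
    majority = max(train, key=train.get)
    return majority, test[majority] / sum(test.values())


label, accuracy = majority_baseline(train_counts, test_counts)
total_rows = sum(train_counts.values()) + sum(test_counts.values())
test_share = sum(test_counts.values()) / total_rows

print(f"always predicting {label!r}: accuracy {accuracy:.3f}")
print(f"test share of all rows: {test_share:.3f}")
```

Any accuracy reported for a trained model on this split is best read against this baseline, and per-class recall on `absent` is the more informative number.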



In [None]:
document = DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

token = Tokenizer()\
    .setInputCols(["document"])\
    .setOutputCol("token")

chunk2doc = Doc2ChunkInternal()\
    .setInputCols(["document","token"])\
    .setOutputCol("ner_chunk")\
    .setChunkCol("target")\
    .setStartCol("start")\
    .setStartColByTokenIndex(True)\
    .setFailOnMissing(False)\
    .setLowerCase(True)

few_shot_assertion_sentence_converter = FewShotAssertionSentenceConverter()\
    .setInputCols(["document", "token","ner_chunk"])\
    .setOutputCol("assertion_sentence")

e5_embeddings = E5Embeddings.pretrained("e5_base_v2")\
    .setInputCols(["assertion_sentence"])\
    .setOutputCol("assertion_embedding")

embeddings_pipeline = Pipeline(
    stages = [
        document,
        token,
        chunk2doc,
        few_shot_assertion_sentence_converter,
        e5_embeddings
])

e5_base_v2 download started this may take some time.
Approximate size to download 246.7 MB
[OK!]


In [None]:
assertion_test_data = embeddings_pipeline.fit(test_data).transform(test_data)
#assertion_test_data.write.mode('overwrite').parquet('i2b2_assertion_sample_test_data.parquet')

assertion_train_data = embeddings_pipeline.fit(training_data).transform(training_data)
#assertion_train_data.write.mode('overwrite').parquet('i2b2_assertion_sample_train_data.parquet')

## Graph setup

In [None]:
!pip install -q tensorflow==2.12.0
!pip install -q tensorflow-addons

We will use the TFGraphBuilder annotator, which creates the TensorFlow graph for the classifier as a stage of the model training pipeline.


In [None]:
from sparknlp_jsl.annotator import TFGraphBuilder

graph_folder = "./tf_graphs"
graph_name = "assertion_graph.pb"

assertion_graph_builder = TFGraphBuilder()\
    .setModelName("fewshot_assertion")\
    .setInputCols(["assertion_embedding"]) \
    .setLabelColumn("label")\
    .setGraphFolder(graph_folder)\
    .setGraphFile(graph_name)\
    .setHiddenUnitsNumber(100)

fewshot_assertion_approach = FewShotAssertionClassifierApproach()\
    .setInputCols("assertion_embedding")\
    .setOutputCol("assertion")\
    .setLabelCol("label")\
    .setBatchSize(32)\
    .setDropout(0.1)\
    .setLearningRate(0.001)\
    .setEpochsNumber(40)\
    .setValidationSplit(0.2)\
    .setModelFile(f"{graph_folder}/{graph_name}")

clinical_assertion_pipeline = Pipeline(
    stages = [
        assertion_graph_builder,
        fewshot_assertion_approach
])

In [None]:
%%time

assertion_model = clinical_assertion_pipeline.fit(assertion_train_data)

TF Graph Builder configuration:
Model name: fewshot_assertion
Graph folder: ./tf_graphs
Graph file name: assertion_graph.pb
Build params: {'input_dim': 768, 'output_dim': 2}
fewshot_assertion graph exported to ./tf_graphs/assertion_graph.pb
CPU times: user 1.25 s, sys: 77.2 ms, total: 1.33 s
Wall time: 2min 27s
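
The `Build params` line in the log above reports `input_dim: 768` and `output_dim: 2`. Those values match the width of the embedding vectors and the number of distinct labels in the training data. The sketch below illustrates that relationship; it is not TFGraphBuilder's actual code. `infer_build_params` is a hypothetical helper, and the vectors are zero-filled placeholders rather than real embeddings.

```python
def infer_build_params(embeddings, labels):
    """Derive graph dimensions from embedding width and distinct label count."""
    dims = {len(vector) for vector in embeddings}
    if len(dims) != 1:
        raise ValueError(f"inconsistent embedding widths: {sorted(dims)}")
    return {"input_dim": dims.pop(), "output_dim": len(set(labels))}


# Placeholder vectors with the width reported in the log (assumption: 768).
embeddings = [[0.0] * 768 for _ in range(4)]
labels = ["present", "absent", "present", "present"]

print(infer_build_params(embeddings, labels))
```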


## Checking the results

Evaluate the trained model on the held-out test set.

In [None]:
preds = assertion_model.transform(assertion_test_data)\
                       .selectExpr('label','assertion.result[0] as result')

preds_df = preds.toPandas()
preds_df

| | label | result |
|---|---|---|
| 0 | present | present |
| 1 | absent | present |
| 2 | present | present |
| 3 | present | present |
| 4 | present | present |
| ... | ... | ... |
| 165 | present | present |
| 166 | absent | absent |
| 167 | absent | absent |
| 168 | absent | absent |


In [None]:
# We are going to use sklearn to evaluate the results on the test dataset
from sklearn.metrics import classification_report

print(classification_report(preds_df['label'], preds_df['result']))

              precision    recall  f1-score   support

      absent       0.93      0.75      0.83        53
     present       0.90      0.97      0.93       117

    accuracy                           0.91       170
   macro avg       0.91      0.86      0.88       170
weighted avg       0.91      0.91      0.90       170
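
The per-class numbers in the report follow from four confusion-matrix counts. The counts below are not printed anywhere in this notebook; they are inferred as the integers consistent with the rounded report and the supports of 53 and 117, so treat them as an assumption. `per_class_metrics` is a hypothetical helper.

```python
# Inferred counts (assumption): keys are (true label, predicted label).
confusion = {
    ("absent", "absent"): 40,
    ("absent", "present"): 13,
    ("present", "absent"): 3,
    ("present", "present"): 114,
}
LABELS = ("absent", "present")


def per_class_metrics(confusion, label):
    """Precision, recall, and F1 for one label from confusion counts."""
    true_positives = confusion[(label, label)]
    predicted = sum(confusion[(true, label)] for true in LABELS)
    actual = sum(confusion[(label, pred)] for pred in LABELS)
    precision = true_positives / predicted
    recall = true_positives / actual
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


for label in LABELS:
    precision, recall, f1 = per_class_metrics(confusion, label)
    print(f"{label:>8}  {precision:.2f}  {recall:.2f}  {f1:.2f}")

accuracy = sum(confusion[(label, label)] for label in LABELS) / sum(confusion.values())
print(f"accuracy  {accuracy:.2f}")
```

Under these counts, most errors are `absent` mentions predicted as `present`, which matches the lower recall on `absent`.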



In [None]:
# save model
assertion_model.stages[-1].write().overwrite().save('custom_fewshot_assertion_model')

## Load saved model

**Build Pipeline**

In [None]:
documentAssembler = DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

# Sentence Detector annotator, processes various sentences per line
sentenceDetector = SentenceDetector()\
    .setInputCols(["document"])\
    .setOutputCol("sentence")

# Tokenizer splits words in a relevant format for NLP
tokenizer = Tokenizer()\
    .setInputCols(["sentence"])\
    .setOutputCol("token")

# Clinical word embeddings trained on the PubMed dataset
word_embeddings = WordEmbeddingsModel.pretrained("embeddings_clinical", "en", "clinical/models")\
    .setInputCols(["sentence", "token"])\
    .setOutputCol("embeddings")

clinical_ner = MedicalNerModel.pretrained("ner_clinical", "en", "clinical/models") \
    .setInputCols(["sentence", "token", "embeddings"]) \
    .setOutputCol("ner")

ner_converter = NerConverterInternal() \
    .setInputCols(["sentence", "token", "ner"]) \
    .setOutputCol("ner_chunk")

embeddings_clinical download started this may take some time.
Approximate size to download 1.6 GB
[OK!]
ner_clinical download started this may take some time.
[OK!]


In [None]:
few_shot_assertion_sentence_converter = FewShotAssertionSentenceConverter()\
    .setInputCols(["sentence", "token", "ner_chunk"])\
    .setOutputCol("assertion_sentence")

e5_embeddings = E5Embeddings.pretrained("e5_base_v2")\
    .setInputCols(["assertion_sentence"])\
    .setOutputCol("assertion_embedding")

few_shot_assertion_classifier = FewShotAssertionClassifierModel.load("custom_fewshot_assertion_model")\
    .setInputCols(["assertion_embedding"])\
    .setOutputCol("assertion")


nlpPipeline = Pipeline(
    stages=[
        documentAssembler,
        sentenceDetector,
        tokenizer,
        word_embeddings,
        clinical_ner,
        ner_converter,
        few_shot_assertion_sentence_converter,
        e5_embeddings,
        few_shot_assertion_classifier
])

empty_data = spark.createDataFrame([[""]]).toDF("text")

model = nlpPipeline.fit(empty_data)

e5_base_v2 download started this may take some time.
Approximate size to download 246.7 MB
[OK!]


In [None]:
text = 'Patient has a headache for the last 2 weeks, needs to get a head CT, and appears anxious when she walks fast. No alopecia and pain noted'

light_model = LightPipeline(model)

light_result = light_model.fullAnnotate(text)[0]

print(text)

chunks=[]
entities=[]
status=[]
confidence=[]

for n,m in zip(light_result['ner_chunk'],light_result['assertion']):

    chunks.append(n.result)
    entities.append(n.metadata['entity'])
    status.append(m.result)
    confidence.append(m.metadata['confidence'])

df = pd.DataFrame({'chunks':chunks, 'entities':entities, 'assertion':status, 'confidence':confidence})

df

Patient has a headache for the last 2 weeks, needs to get a head CT, and appears anxious when she walks fast. No alopecia and pain noted


| | chunks | entities | assertion | confidence |
|---|---|---|---|---|
| 0 | a headache | PROBLEM | present | 0.9482695 |
| 1 | a head CT | TEST | present | 0.9325676 |
| 2 | anxious | PROBLEM | present | 0.9496433 |
| 3 | alopecia | PROBLEM | absent | 0.93149626 |
| 4 | pain | PROBLEM | absent | 0.93775314 |
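
For downstream use it is often handy to group the chunks by assertion status, for example to keep only findings asserted as `present`. A minimal sketch over the rows shown above; `group_by_assertion` is a hypothetical helper and not part of Spark NLP.

```python
from collections import defaultdict

# (chunk, entity, assertion), copied from the output above.
ROWS = [
    ("a headache", "PROBLEM", "present"),
    ("a head CT", "TEST", "present"),
    ("anxious", "PROBLEM", "present"),
    ("alopecia", "PROBLEM", "absent"),
    ("pain", "PROBLEM", "absent"),
]


def group_by_assertion(rows):
    """Collect chunk texts under their assertion label, keeping input order."""
    grouped = defaultdict(list)
    for chunk, _entity, assertion in rows:
        grouped[assertion].append(chunk)
    return dict(grouped)


print(group_by_assertion(ROWS))
```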


In [None]:
from sparknlp_display import AssertionVisualizer

assertion_visualiser = AssertionVisualizer()
assertion_visualiser.display(light_result, label_col ='ner_chunk', assertion_col='assertion')