![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/healthcare-nlp/08.6.Text_Classification_with_FewShotClassifier.ipynb)

## Colab Setup

In [None]:
# Install the johnsnowlabs library to access Spark-OCR and Spark-NLP for Healthcare, Finance, and Legal.
! pip install -q johnsnowlabs

In [None]:
from google.colab import files
print('Please Upload your John Snow Labs License using the button below')
license_keys = files.upload()

In [None]:
from johnsnowlabs import nlp, medical

# After uploading your license, run this to install all licensed Python wheels and pre-download the JARs for the Spark session JVM
nlp.settings.enforce_versions=True
nlp.install(refresh_install=True)

In [None]:
from johnsnowlabs import nlp, medical
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

# Automatically load the license data and start a Spark session with all JARs the user has access to
spark = nlp.start()


In [5]:
spark

In [None]:
!pip install -q git+https://github.com/tensorflow/addons.git

In [14]:
from pyspark.sql import DataFrame
import pyspark.sql.functions as F
import pyspark.sql.types as T
import pyspark.sql as SQL
from pyspark import keyword_only
from sklearn.metrics import classification_report

## Healthcare NLP for Data Scientists Course

If you are not familiar with the components in this notebook, you can check the [Healthcare NLP for Data Scientists Udemy Course](https://www.udemy.com/course/healthcare-nlp-for-data-scientists/) and the [MOOC Notebooks](https://github.com/JohnSnowLabs/spark-nlp-workshop/tree/master/Spark_NLP_Udemy_MOOC/Healthcare_NLP) for each component.

# Few Shot Classification

The `FewShotClassifierApproach` and `FewShotClassifierModel` annotators are new additions to Healthcare NLP. They target few-shot classification: training a model that makes accurate predictions from only a small amount of labeled data. Following the usual Spark NLP convention, `FewShotClassifierApproach` is the trainable annotator and `FewShotClassifierModel` is the trained (or pretrained) model used for inference.

These annotators are useful when labeled data is scarce or expensive to obtain. By making the most of a few labeled examples, the few-shot approach can produce models that generalize to new instances even with minimal training data.

`FewShotClassifier` takes sentence embeddings as input and produces category annotations, each carrying a label and a confidence score between 0 and 1. The input annotation type is `SENTENCE_EMBEDDINGS` and the output annotation type is `CATEGORY`.
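The `TFGraphBuilder` log later in this notebook reports `'hidden_layers': []` and `'output_act': 'sigmoid'` for the exported graph, which suggests a single linear layer over the embedding with a sigmoid per class. The sketch below illustrates that idea in plain Python under that assumption only; it is not the library's implementation. The weights, the toy 3-dimensional embedding, and the function `score_classes` are hypothetical.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic function mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def score_classes(embedding, weights, biases, labels):
    """Return {label: confidence} from one linear layer plus a sigmoid per label.

    embedding: list of floats (768 values for sbiobert_base_cased_mli)
    weights:   one weight vector per label, same length as the embedding
    biases:    one bias per label
    """
    scores = {}
    for label, w, b in zip(labels, weights, biases):
        z = sum(x * wi for x, wi in zip(embedding, w)) + b
        scores[label] = sigmoid(z)
    return scores

# Hypothetical toy values: a 3-dimensional "embedding" instead of 768.
labels = ["ADE_positive", "ADE_negative"]
weights = [[0.8, -0.2, 0.5], [-0.6, 0.3, -0.4]]
biases = [0.1, -0.1]
embedding = [0.9, 0.1, 0.4]

confidences = score_classes(embedding, weights, biases, labels)
prediction = max(confidences, key=confidences.get)
print(prediction, confidences)
```

Because each label gets its own sigmoid, the confidences are independent of one another and need not sum to 1.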

In [15]:
data = [
    ["ADE_positive", 'Both PAN and methotrexate have been independently demonstrated to cause sensorineural hearing loss.'],
    ["ADE_positive", 'Increased lash length, thickness, and pigmentation are well-documented side effects of prostaglandin analog glaucoma drops.'],
    ["ADE_positive", 'We reviewed the records of 3 patients with lymphoproliferative disorders who experienced acute coronary syndromes associated with their initial infusion of rituximab.'],
    ["ADE_positive", 'A 58-year-old woman with rheumatoid arthritis (RA) developed fever, skin eruptions, leukocytopenia, and thrombocytopenia, 3 weeks after treatment with sulfasalazine.'],
    ["ADE_positive", 'Adrenal suppression in a fetus due to administration of methylprednisolone has hitherto been rarely published.'],
    ["ADE_negative", 'Serum concentration of cerivastatin at 6 h after taking the last dose (0.15 mg) was 8062.5 ng/L, which was almost 5.7 times higher than that of normal persons.'],
    ["ADE_negative", 'The usual treatment includes quick relief bronchodilator medications of the sympathomimetic class and controller medications that may include the long-acting inhaled bronchodilator salmeterol.'],
    ["ADE_negative", 'Pathogenic mechanisms for the development of pseudomembranous colitis and the epidemiology of this condition in patients with AIDS are discussed.'],
    ["ADE_negative", 'On the basis of the clinico-radiologic presentation, a pulmonary hemorrhage was likely to occur; so to clarify the origin of this process, a complete serologic examination was performed but all the antibodies were negative.'],
    ["ADE_negative", 'I report a patient who developed the syndrome during treatment for schizophrenia with the antipsychotic agent molindone hydrochloride.']
]

In [16]:
train_data = spark.createDataFrame(data).toDF("label","text")
train_data.show(truncate=100)

+------------+----------------------------------------------------------------------------------------------------+
|       label|                                                                                                text|
+------------+----------------------------------------------------------------------------------------------------+
|ADE_positive| Both PAN and methotrexate have been independently demonstrated to cause sensorineural hearing loss.|
|ADE_positive|Increased lash length, thickness, and pigmentation are well-documented side effects of prostaglan...|
|ADE_positive|We reviewed the records of 3 patients with lymphoproliferative disorders who experienced acute co...|
|ADE_positive|A 58-year-old woman with rheumatoid arthritis (RA) developed fever, skin eruptions, leukocytopeni...|
|ADE_positive|Adrenal suppression in a fetus due to administration of methylprednisolone has hitherto been rare...|
|ADE_negative|Serum concentration of cerivastatin at 6 h after taking th

In [17]:
document_asm = nlp.DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("sentence")

sentence_embeddings = nlp.BertSentenceEmbeddings\
    .pretrained("sbiobert_base_cased_mli","en","clinical/models")\
    .setInputCols(["sentence"])\
    .setOutputCol("sentence_embeddings")

# Export a TensorFlow graph sized for the embedding dimension and the number of labels;
# FewShotClassifierApproach loads it through setModelFile, so the two paths must match.
graph_builder = medical.TFGraphBuilder()\
    .setModelName("fewshot_classifier")\
    .setInputCols(["sentence_embeddings"])\
    .setLabelColumn("label")\
    .setGraphFolder("/tmp")\
    .setGraphFile("log_reg_graph.pb")

few_shot_approach = medical.FewShotClassifierApproach()\
    .setLabelColumn("label")\
    .setInputCols(["sentence_embeddings"])\
    .setOutputCol("prediction")\
    .setModelFile("/tmp/log_reg_graph.pb")\
    .setEpochsNumber(10)\
    .setBatchSize(1)\
    .setLearningRate(0.001)

pipeline = nlp.Pipeline(
    stages=[
        document_asm,
        sentence_embeddings,
        graph_builder,
        few_shot_approach
    ])

sbiobert_base_cased_mli download started this may take some time.
Approximate size to download 384.3 MB
[OK!]


In [18]:
%%time
model = pipeline.fit(train_data)

TF Graph Builder configuration:
Model name: fewshot_classifier
Graph folder: /tmp
Graph file name: log_reg_graph.pb
Build params: {'input_dim': 768, 'output_dim': 2, 'hidden_layers': [], 'output_act': 'sigmoid'}
fewshot_classifier graph exported to /tmp/log_reg_graph.pb
CPU times: user 302 ms, sys: 28.8 ms, total: 331 ms
Wall time: 4.75 s
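The `Build params` line above describes a graph with no hidden layers. Assuming this means one dense layer from the embedding to the outputs (a weight matrix plus one bias per class), the number of trainable parameters is simple arithmetic. `dense_param_count` is a hypothetical helper; the dimensions come from the graph builder logs in this notebook.

```python
def dense_param_count(input_dim: int, output_dim: int) -> int:
    """Weights (input_dim x output_dim) plus one bias per output unit."""
    return input_dim * output_dim + output_dim

# Dimensions taken from the TF Graph Builder logs in this notebook.
ade_params = dense_param_count(768, 2)        # ADE example, 2 labels
mtsamples_params = dense_param_count(768, 4)  # MTSamples example, 4 labels
print(ade_params, mtsamples_params)  # 1538 3076
```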


In [19]:
model.stages[-1].write().overwrite().save("/tmp/few_shot_model")

In [20]:
few_shot_model = medical.FewShotClassifierModel.load("/tmp/few_shot_model")

## with LightPipeline

In [21]:
lpipeline = nlp.LightPipeline(nlp.Pipeline(
                stages=[
                    document_asm,
                    sentence_embeddings,
                    few_shot_model.setMultiClass(False)
                ]).fit(train_data))

In [22]:
tests = [
    'Bleomycin pneumonitis potentiated by oxygen administration.',
    'Enzymes derived from two different bacterial sources (Escherichia coli and Erwinia carotovora) are in common use.',
    'These cases and others quoted indicate that dependence on pressurized aerosol bronchodilators can and does occur in young people.',
    'PVT during chemotherapy in children is a rare event and appears to be closely related to intensive chemotherapy containing busulfan and to be associated with HVOD.',
]

In [23]:
for i, r in enumerate(lpipeline.fullAnnotate(tests)):
    print(i, r)


0 {'sentence': [Annotation(document, 0, 58, Bleomycin pneumonitis potentiated by oxygen administration., {}, [])], 'sentence_embeddings': [Annotation(sentence_embeddings, 0, 58, Bleomycin pneumonitis potentiated by oxygen administration., {'sentence': '0', 'token': 'Bleomycin pneumonitis potentiated by oxygen administration.', 'pieceId': '-1', 'isWordStart': 'true'}, [])], 'prediction': [Annotation(category, 0, 0, ADE_positive, {'confidence': '0.7748149'}, [])]}
1 {'sentence': [Annotation(document, 0, 112, Enzymes derived from two different bacterial sources (Escherichia coli and Erwinia carotovora) are in common use., {}, [])], 'sentence_embeddings': [Annotation(sentence_embeddings, 0, 112, Enzymes derived from two different bacterial sources (Escherichia coli and Erwinia carotovora) are in common use., {'sentence': '0', 'token': 'Enzymes derived from two different bacterial sources (Escherichia coli and Erwinia carotovora) are in common use.', 'pieceId': '-1', 'isWordStart': 'true'},

In [24]:
for i, r in enumerate(lpipeline.fullAnnotate(tests)):
    print(i, r["prediction"][0].result)

0 ADE_positive
1 ADE_negative
2 ADE_positive
3 ADE_positive


In [25]:
for i, r in enumerate(lpipeline.fullAnnotate(tests)):
    print(i, r["prediction"][0], r["sentence"][0].result)

0 Annotation(category, 0, 0, ADE_positive, {'confidence': '0.7748149'}, []) Bleomycin pneumonitis potentiated by oxygen administration.
1 Annotation(category, 0, 0, ADE_negative, {'confidence': '0.68406105'}, []) Enzymes derived from two different bacterial sources (Escherichia coli and Erwinia carotovora) are in common use.
2 Annotation(category, 0, 0, ADE_positive, {'confidence': '0.5916398'}, []) These cases and others quoted indicate that dependence on pressurized aerosol bronchodilators can and does occur in young people.
3 Annotation(category, 0, 0, ADE_positive, {'confidence': '0.62139165'}, []) PVT during chemotherapy in children is a rare event and appears to be closely related to intensive chemotherapy containing busulfan and to be associated with HVOD.
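In single-class mode each result holds one `category` annotation whose `metadata['confidence']` is stored as a string. A small helper can turn the `fullAnnotate` output into `(label, confidence)` pairs and flag borderline predictions for manual review. This is a sketch: the `0.6` cut-off and the names `top_prediction` and `flag_uncertain` are illustrative assumptions, and the `Ann` namedtuple only mimics the `result` and `metadata` fields of a Spark NLP `Annotation` so the example runs without Spark.

```python
from collections import namedtuple

# Stand-in exposing the two Annotation fields used below; with a real
# LightPipeline, pass lpipeline.fullAnnotate(tests) instead of `results`.
Ann = namedtuple("Ann", ["result", "metadata"])

def top_prediction(annotated):
    """Return (label, confidence) for the first 'prediction' annotation."""
    ann = annotated["prediction"][0]
    return ann.result, float(ann.metadata["confidence"])

def flag_uncertain(results, threshold=0.6):
    """Return indices of documents whose confidence is below the threshold."""
    return [i for i, r in enumerate(results) if top_prediction(r)[1] < threshold]

# Confidences copied from the single-class output above.
results = [
    {"prediction": [Ann("ADE_positive", {"confidence": "0.7748149"})]},
    {"prediction": [Ann("ADE_negative", {"confidence": "0.68406105"})]},
    {"prediction": [Ann("ADE_positive", {"confidence": "0.5916398"})]},
    {"prediction": [Ann("ADE_positive", {"confidence": "0.62139165"})]},
]
print(flag_uncertain(results))  # [2]
```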


In [26]:
lpipeline.fullAnnotate('After 1 week of nefazodone therapy the patient experienced headache, confusion, and "gray areas" in her vision, without abnormal ophthalmologic findings.')

[{'sentence': [Annotation(document, 0, 152, After 1 week of nefazodone therapy the patient experienced headache, confusion, and "gray areas" in her vision, without abnormal ophthalmologic findings., {}, [])],
  'sentence_embeddings': [Annotation(sentence_embeddings, 0, 152, After 1 week of nefazodone therapy the patient experienced headache, confusion, and "gray areas" in her vision, without abnormal ophthalmologic findings., {'sentence': '0', 'token': 'After 1 week of nefazodone therapy the patient experienced headache, confusion, and "gray areas" in her vision, without abnormal ophthalmologic findings.', 'pieceId': '-1', 'isWordStart': 'true'}, [])],
  'prediction': [Annotation(category, 0, 0, ADE_negative, {'confidence': '0.50165457'}, [])]}]

In [27]:
lpipeline_multi = nlp.LightPipeline(nlp.Pipeline(
                      stages=[
                          document_asm,
                          sentence_embeddings,
                          few_shot_model.setMultiClass(True)
                      ]).fit(train_data))

In [28]:
for i, r in enumerate(lpipeline_multi.fullAnnotate(tests)):
    print(i,r)


0 {'sentence': [Annotation(document, 0, 58, Bleomycin pneumonitis potentiated by oxygen administration., {}, [])], 'sentence_embeddings': [Annotation(sentence_embeddings, 0, 58, Bleomycin pneumonitis potentiated by oxygen administration., {'sentence': '0', 'token': 'Bleomycin pneumonitis potentiated by oxygen administration.', 'pieceId': '-1', 'isWordStart': 'true'}, [])], 'prediction': [Annotation(category, 0, 0, ADE_positive, {'confidence': '0.7748149'}, []), Annotation(category, 0, 0, ADE_negative, {'confidence': '0.26696292'}, [])]}
1 {'sentence': [Annotation(document, 0, 112, Enzymes derived from two different bacterial sources (Escherichia coli and Erwinia carotovora) are in common use., {}, [])], 'sentence_embeddings': [Annotation(sentence_embeddings, 0, 112, Enzymes derived from two different bacterial sources (Escherichia coli and Erwinia carotovora) are in common use., {'sentence': '0', 'token': 'Enzymes derived from two different bacterial sources (Escherichia coli and Erwin

In [29]:
for i, r in enumerate(lpipeline_multi.fullAnnotate(tests)):
    print(i, r["prediction"])


0 [Annotation(category, 0, 0, ADE_positive, {'confidence': '0.7748149'}, []), Annotation(category, 0, 0, ADE_negative, {'confidence': '0.26696292'}, [])]
1 [Annotation(category, 0, 0, ADE_positive, {'confidence': '0.2453861'}, []), Annotation(category, 0, 0, ADE_negative, {'confidence': '0.68406105'}, [])]
2 [Annotation(category, 0, 0, ADE_positive, {'confidence': '0.5916398'}, []), Annotation(category, 0, 0, ADE_negative, {'confidence': '0.5114412'}, [])]
3 [Annotation(category, 0, 0, ADE_positive, {'confidence': '0.62139165'}, []), Annotation(category, 0, 0, ADE_negative, {'confidence': '0.59960246'}, [])]
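With `setMultiClass(True)` every label receives its own confidence, and the scores are not normalized: for sentence 2 they add up to more than 1 (0.59 + 0.51). One way to use them, sketched below, is to keep every label whose confidence clears a threshold. The `0.5` threshold and the `labels_above` helper are assumptions for illustration, and the `Ann` namedtuple is a minimal stand-in for Spark NLP's `Annotation`.

```python
from collections import namedtuple

# Minimal stand-in for a Spark NLP Annotation (only the fields used here).
Ann = namedtuple("Ann", ["result", "metadata"])

def labels_above(prediction, threshold=0.5):
    """Return labels whose confidence is >= threshold, highest first."""
    scored = [(a.result, float(a.metadata["confidence"])) for a in prediction]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return [label for label, conf in scored if conf >= threshold]

# Multi-class predictions for sentences 0 and 2, copied from the output above.
sent0 = [Ann("ADE_positive", {"confidence": "0.7748149"}),
         Ann("ADE_negative", {"confidence": "0.26696292"})]
sent2 = [Ann("ADE_positive", {"confidence": "0.5916398"}),
         Ann("ADE_negative", {"confidence": "0.5114412"})]

print(labels_above(sent0))  # ['ADE_positive']
print(labels_above(sent2))  # ['ADE_positive', 'ADE_negative']
```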


In [30]:
lpipeline_multi.fullAnnotate('After 1 week of nefazodone therapy the patient experienced headache, confusion, and "gray areas" in her vision, without abnormal ophthalmologic findings.')

[{'sentence': [Annotation(document, 0, 152, After 1 week of nefazodone therapy the patient experienced headache, confusion, and "gray areas" in her vision, without abnormal ophthalmologic findings., {}, [])],
  'sentence_embeddings': [Annotation(sentence_embeddings, 0, 152, After 1 week of nefazodone therapy the patient experienced headache, confusion, and "gray areas" in her vision, without abnormal ophthalmologic findings., {'sentence': '0', 'token': 'After 1 week of nefazodone therapy the patient experienced headache, confusion, and "gray areas" in her vision, without abnormal ophthalmologic findings.', 'pieceId': '-1', 'isWordStart': 'true'}, [])],
  'prediction': [Annotation(category, 0, 0, ADE_positive, {'confidence': '0.4123436'}, []),
   Annotation(category, 0, 0, ADE_negative, {'confidence': '0.50165457'}, [])]}]

## with transform

In [31]:
test_data = spark.createDataFrame([
    ['After 1 week of nefazodone therapy the patient experienced headache, confusion, and "gray areas" in her vision, without abnormal ophthalmologic findings.'],
    ['Possible linkage of amprenavir with intracranial bleeding in an HIV-infected hemophiliac.'],
    ['Results Radiographic responses to sirolimus were observed in all patients.'],
    ['Obtaining appropriate cultures can be critical in making the diagnosis and directing treatment.']
]).toDF("text")

In [32]:
result = model.transform(test_data)

result.select("text","prediction").show(truncate=False)

+---------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------+
|text                                                                                                                                                     |prediction                                                      |
+---------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------------------------------------+
|After 1 week of nefazodone therapy the patient experienced headache, confusion, and "gray areas" in her vision, without abnormal ophthalmologic findings.|[{category, 0, 0, ADE_negative, {confidence -> 0.50165457}, []}]|
|Possible linkage of amprenavir with intracranial bleeding in an HIV-infected hemophiliac.                          

In [33]:
model_multi = nlp.Pipeline(
                  stages=[
                      document_asm,
                      sentence_embeddings,
                      few_shot_model.setMultiClass(True)
                  ]).fit(train_data)

result_2 = model_multi.transform(test_data)

result_2.select("text","prediction").show(truncate=False)

+---------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+
|text                                                                                                                                                     |prediction                                                                                                                      |
+---------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------+
|After 1 week of nefazodone therapy the patient experienced headache, confusion, and "gray areas" in her vision, without abnormal ophthalmologic 

## with word embeddings

In [34]:
tokenizer = nlp.Tokenizer() \
    .setInputCols(["sentence"]) \
    .setOutputCol("token")

word_embeddings = nlp.WordEmbeddingsModel.pretrained("embeddings_clinical","en","clinical/models")\
    .setInputCols(["sentence","token"])\
    .setOutputCol("word_embeddings")

sentence_embeddings = nlp.SentenceEmbeddings() \
    .setInputCols(["sentence", "word_embeddings"]) \
    .setOutputCol("sentence_embeddings") \
    .setPoolingStrategy("AVERAGE")

pipeline = nlp.Pipeline(
    stages=[
        document_asm,
        tokenizer,
        word_embeddings,
        sentence_embeddings,
        few_shot_approach
    ])

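# Fit on a single placeholder row so the stages can be wrapped in a LightPipeline
# for this demo; the empty row is not meaningful training data.
model = pipeline.fit(spark.createDataFrame([["",""]]).toDF("label","text"))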

lpipeline = nlp.LightPipeline(model)

embeddings_clinical download started this may take some time.
Approximate size to download 1.6 GB
[OK!]
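`SentenceEmbeddings` with `setPoolingStrategy("AVERAGE")` combines the token vectors of a sentence into a single vector by averaging them dimension by dimension. A plain-Python sketch of that pooling step, using made-up 3-dimensional vectors (real `embeddings_clinical` vectors have many more dimensions), looks like this; `average_pool` is a hypothetical helper, not part of the library.

```python
def average_pool(token_vectors):
    """Element-wise mean of equally sized token vectors."""
    if not token_vectors:
        raise ValueError("need at least one token vector")
    dim = len(token_vectors[0])
    if any(len(v) != dim for v in token_vectors):
        raise ValueError("all token vectors must have the same length")
    n = len(token_vectors)
    return [sum(v[d] for v in token_vectors) / n for d in range(dim)]

# Hypothetical vectors for three tokens.
tokens = [[1.0, 0.0, 2.0],
          [3.0, 2.0, 0.0],
          [2.0, 4.0, 1.0]]
print(average_pool(tokens))  # [2.0, 2.0, 1.0]
```

The pooled vector keeps the dimensionality of the word embeddings, which is the input size a classifier trained on these features has to accept.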


In [35]:
lpipeline.fullAnnotate('Nephrotic syndrome associated with interferon-beta-1b therapy for multiple sclerosis.')

[{'word_embeddings': [Annotation(word_embeddings, 0, 8, Nephrotic, {'isOOV': 'false', 'pieceId': '-1', 'isWordStart': 'true', 'token': 'Nephrotic', 'sentence': '0'}, []),
   Annotation(word_embeddings, 10, 17, syndrome, {'isOOV': 'false', 'pieceId': '-1', 'isWordStart': 'true', 'token': 'syndrome', 'sentence': '0'}, []),
   Annotation(word_embeddings, 19, 28, associated, {'isOOV': 'false', 'pieceId': '-1', 'isWordStart': 'true', 'token': 'associated', 'sentence': '0'}, []),
   Annotation(word_embeddings, 30, 33, with, {'isOOV': 'false', 'pieceId': '-1', 'isWordStart': 'true', 'token': 'with', 'sentence': '0'}, []),
   Annotation(word_embeddings, 35, 52, interferon-beta-1b, {'isOOV': 'false', 'pieceId': '-1', 'isWordStart': 'true', 'token': 'interferon-beta-1b', 'sentence': '0'}, []),
   Annotation(word_embeddings, 54, 60, therapy, {'isOOV': 'false', 'pieceId': '-1', 'isWordStart': 'true', 'token': 'therapy', 'sentence': '0'}, []),
   Annotation(word_embeddings, 62, 64, for, {'isOOV': '

# MTSamples Dataset

In [36]:
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/mtsamples_classifier.csv

In [37]:
spark_df = spark.read.csv("mtsamples_classifier.csv", header=True)

spark_df.show(10, truncate=100)

+----------------+----------------------------------------------------------------------------------------------------+
|        category|                                                                                                text|
+----------------+----------------------------------------------------------------------------------------------------+
|Gastroenterology| PROCEDURES PERFORMED: Colonoscopy. INDICATIONS: Renewed symptoms likely consistent with active f...|
|Gastroenterology| OPERATION 1. Ivor-Lewis esophagogastrectomy. 2. Feeding jejunostomy. 3. Placement of two right-s...|
|Gastroenterology| PREOPERATIVE DIAGNOSES: 1. Gastroesophageal reflux disease. 2. Chronic dyspepsia. POSTOPERATIVE ...|
|Gastroenterology| PROCEDURE: Colonoscopy. PREOPERATIVE DIAGNOSES: Rectal bleeding and perirectal abscess. POSTOPER...|
|Gastroenterology| PREOPERATIVE DIAGNOSIS: Right colon tumor. POSTOPERATIVE DIAGNOSES: 1. Right colon cancer. 2. As...|
|Gastroenterology| PREOPERATIVE DIAGNOSI

In [38]:
spark_df.groupBy("category").count().show()

+----------------+-----+
|        category|count|
+----------------+-----+
|         Urology|  115|
|       Neurology|  143|
|      Orthopedic|  223|
|Gastroenterology|  157|
+----------------+-----+
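The classes are moderately imbalanced. Converting the counts above into proportions shows how often always guessing the largest class would be right on the full dataset. The dictionary simply copies the `groupBy` output, so no Spark is needed.

```python
# Counts copied from spark_df.groupBy("category").count() above.
counts = {"Urology": 115, "Neurology": 143, "Orthopedic": 223, "Gastroenterology": 157}

total = sum(counts.values())
shares = {label: n / total for label, n in counts.items()}
majority = max(shares, key=shares.get)

print(total)  # 638
for label, share in sorted(shares.items(), key=lambda kv: -kv[1]):
    print(f"{label:<17}{share:.1%}")
print("majority class:", majority)
```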



In [39]:
(trainingData, testData) = spark_df.randomSplit([0.8, 0.2], seed = 42)
(trainingData_part1, trainingData_part2) = trainingData.randomSplit([0.5, 0.5], seed = 42)

print("trainingData:       ", trainingData.count())
print("testData:           ", testData.count())
print("trainingData_part1: ", trainingData_part1.count())

trainingData:        536
testData:            102
trainingData_part1:  280
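`randomSplit` samples rows according to the given weights rather than cutting at exact sizes, so the realized proportions only approximate the requested 80/20 and 50/50 splits. Working out the counts printed above:

```python
total = 638            # all MTSamples rows (sum of the category counts)
train, test = 536, 102
part1 = 280

print(f"train share: {train / total:.3f}")           # 0.840
print(f"test share:  {test / total:.3f}")            # 0.160
print(f"part1 share of train: {part1 / train:.3f}")  # 0.522
```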


We extract [sbiobert_base_cased_mli](https://nlp.johnsnowlabs.com/2020/11/27/sbiobert_base_cased_mli_en.html) sentence embeddings, which are 768-dimensional, and use them as the input features for model training.

In [40]:
document_assembler = nlp.DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

bert_sent = nlp.BertSentenceEmbeddings.pretrained("sbiobert_base_cased_mli", 'en','clinical/models')\
    .setInputCols(["document"])\
    .setOutputCol("sentence_embeddings")

embeddings_pipeline = nlp.Pipeline(
    stages = [document_assembler,
              bert_sent])

sbiobert_base_cased_mli download started this may take some time.
Approximate size to download 384.3 MB
[OK!]


In [41]:
trainingData_with_embeddings = embeddings_pipeline.fit(trainingData).transform(trainingData)\
                                                  .select("text","category","sentence_embeddings")

trainingData_with_embeddings.show(2,truncate=60)

+------------------------------------------------------------+----------------+------------------------------------------------------------+
|                                                        text|        category|                                         sentence_embeddings|
+------------------------------------------------------------+----------------+------------------------------------------------------------+
| ADMISSION DIAGNOSIS: Symptomatic cholelithiasis. DISCHAR...|Gastroenterology|[{sentence_embeddings, 0, 2228,  ADMISSION DIAGNOSIS: Sym...|
| ADMITTING DIAGNOSES: Hiatal hernia, gastroesophageal ref...|Gastroenterology|[{sentence_embeddings, 0, 3237,  ADMITTING DIAGNOSES: Hia...|
+------------------------------------------------------------+----------------+------------------------------------------------------------+
only showing top 2 rows



In [42]:
testData_with_embeddings = embeddings_pipeline.fit(testData).transform(testData)\
                                                  .select("text","category","sentence_embeddings")

testData_with_embeddings.show(2,truncate=60)

+------------------------------------------------------------+----------------+------------------------------------------------------------+
|                                                        text|        category|                                         sentence_embeddings|
+------------------------------------------------------------+----------------+------------------------------------------------------------+
| ADMITTING DIAGNOSIS: Gastrointestinal bleed. HISTORY OF ...|Gastroenterology|[{sentence_embeddings, 0, 3978,  ADMITTING DIAGNOSIS: Gas...|
| CHIEF COMPLAINT: Dysphagia and hematemesis while vomitin...|Gastroenterology|[{sentence_embeddings, 0, 6515,  CHIEF COMPLAINT: Dysphag...|
+------------------------------------------------------------+----------------+------------------------------------------------------------+
only showing top 2 rows



In [43]:
trainingData_part1_with_embeddings = embeddings_pipeline.fit(trainingData_part1).transform(trainingData_part1)\
                                                  .select("text","category","sentence_embeddings")

trainingData_part2_with_embeddings = embeddings_pipeline.fit(trainingData_part2).transform(trainingData_part2)\
                                                  .select("text","category","sentence_embeddings")

### ClassifierDL

In [44]:
log_folder="classifier_dl_logs_bert"

In [45]:
classifier_dl = nlp.ClassifierDLApproach()\
        .setInputCols(["sentence_embeddings"])\
        .setOutputCol("prediction_class")\
        .setLabelColumn("category")\
        .setBatchSize(8)\
        .setMaxEpochs(10)\
        .setLr(0.002)\
        .setDropout(0.1)\
        .setEnableOutputLogs(True)\
        .setOutputLogsPath(log_folder)

classifier_dl_pipeline = nlp.Pipeline(stages=[classifier_dl])

In [46]:
%%time
clfDL_model_bert = classifier_dl_pipeline.fit(trainingData_with_embeddings)

CPU times: user 36.4 ms, sys: 17.3 ms, total: 53.7 ms
Wall time: 2min 54s


In [47]:
preds = clfDL_model_bert.transform(testData_with_embeddings)

preds_df = preds.select("category","text","prediction_class.result").toPandas()
preds_df["result"] = preds_df["result"].apply(lambda x : x[0])

print(classification_report(preds_df["category"], preds_df["result"]))

                  precision    recall  f1-score   support

Gastroenterology       1.00      0.88      0.94        25
       Neurology       0.83      0.86      0.84        22
      Orthopedic       0.91      0.83      0.87        35
         Urology       0.76      0.95      0.84        20

        accuracy                           0.87       102
       macro avg       0.87      0.88      0.87       102
    weighted avg       0.88      0.87      0.87       102
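The `macro avg` and `weighted avg` rows of scikit-learn's `classification_report`, which the next cell stores in `results_df`, aggregate the per-class scores differently. For $K$ classes with precision $P_k$, recall $R_k$, support $n_k$ and $N = \sum_k n_k$ test samples:

$$
\mathrm{F1}_k = \frac{2\,P_k R_k}{P_k + R_k}, \qquad
\text{macro-F1} = \frac{1}{K}\sum_{k=1}^{K} \mathrm{F1}_k, \qquad
\text{weighted-F1} = \sum_{k=1}^{K} \frac{n_k}{N}\,\mathrm{F1}_k
$$

For the report above, macro-F1 ≈ (0.94 + 0.84 + 0.87 + 0.84) / 4 ≈ 0.87 and weighted-F1 ≈ (25·0.94 + 22·0.84 + 35·0.87 + 20·0.84) / 102 ≈ 0.87, matching the printed rows up to the rounding of the per-class scores.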



In [48]:
results_df = pd.DataFrame(columns=["macro-f1-score","weighted-f1-score","accuracy"])

res = classification_report(preds_df["category"], preds_df["result"], output_dict=True)
results_df.loc["ClassifierDL_full_Data"] = [res["macro avg"]["f1-score"], res["weighted avg"]["f1-score"], res["accuracy"]]

In [49]:
results_df

|                        | macro-f1-score | weighted-f1-score | accuracy |
|:-----------------------|---------------:|------------------:|---------:|
| ClassifierDL_full_Data | 0.872683 | 0.87421 | 0.872549 |


### ClassifierDL partial

In [50]:
log_folder="classifier_dl_logs_bert"

In [51]:
classifier_dl = nlp.ClassifierDLApproach()\
        .setInputCols(["sentence_embeddings"])\
        .setOutputCol("prediction_class")\
        .setLabelColumn("category")\
        .setBatchSize(8)\
        .setMaxEpochs(10)\
        .setLr(0.002)\
        .setDropout(0.1)\
        .setEnableOutputLogs(True)\
        .setOutputLogsPath(log_folder)

classifier_dl_pipeline = nlp.Pipeline(stages=[classifier_dl])

In [52]:
%%time

clfDL_model_bert = classifier_dl_pipeline.fit(trainingData_part1_with_embeddings)

CPU times: user 21.1 ms, sys: 5.18 ms, total: 26.3 ms
Wall time: 1min 27s


In [53]:
preds = clfDL_model_bert.transform(testData_with_embeddings)

preds_df = preds.select("category","text","prediction_class.result").toPandas()
preds_df["result"] = preds_df["result"].apply(lambda x : x[0])

print(classification_report(preds_df["category"], preds_df["result"]))

                  precision    recall  f1-score   support

Gastroenterology       0.00      0.00      0.00        25
       Neurology       0.00      0.00      0.00        22
      Orthopedic       0.34      1.00      0.51        35
         Urology       0.00      0.00      0.00        20

        accuracy                           0.34       102
       macro avg       0.09      0.25      0.13       102
    weighted avg       0.12      0.34      0.18       102



In [54]:
res = classification_report(preds_df["category"], preds_df["result"], output_dict=True)
results_df.loc["ClassifierDL_partial_Data"] = [res["macro avg"]["f1-score"], res["weighted avg"]["f1-score"], res["accuracy"]]

In [55]:
results_df

|                           | macro-f1-score | weighted-f1-score | accuracy |
|:--------------------------|---------------:|------------------:|---------:|
| ClassifierDL_full_Data    | 0.872683 | 0.87421  | 0.872549 |
| ClassifierDL_partial_Data | 0.127737 | 0.175326 | 0.343137 |
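The partial-data ClassifierDL report shows recall 1.00 for `Orthopedic` and 0.00 for every other class, which suggests it predicted `Orthopedic` for every test document. A plain-Python check (the helper `report_constant_prediction` is hypothetical and does not use scikit-learn) confirms that a constant `Orthopedic` predictor on a test set with the supports shown in the report reproduces the `ClassifierDL_partial_Data` row exactly.

```python
# Test-set supports copied from the classification reports above.
support = {"Gastroenterology": 25, "Neurology": 22, "Orthopedic": 35, "Urology": 20}

def report_constant_prediction(support, predicted):
    """Macro-F1, weighted-F1 and accuracy when every sample gets `predicted`."""
    n = sum(support.values())
    f1 = {}
    for label, n_k in support.items():
        tp = n_k if label == predicted else 0
        predicted_count = n if label == predicted else 0
        # Classes that are never predicted get precision 0 (scikit-learn's default).
        precision = tp / predicted_count if predicted_count else 0.0
        recall = tp / n_k
        denom = precision + recall
        f1[label] = 2 * precision * recall / denom if denom else 0.0
    macro = sum(f1.values()) / len(f1)
    weighted = sum(f1[k] * support[k] / n for k in support)
    accuracy = support[predicted] / n
    return macro, weighted, accuracy

macro, weighted, accuracy = report_constant_prediction(support, "Orthopedic")
print(f"{macro:.6f} {weighted:.6f} {accuracy:.6f}")  # 0.127737 0.175326 0.343137
```

The few-shot classifier trained on the same 280 examples (`FewShot_partial_Data`) does not show this single-class behavior.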


### Few Shot partial

In [56]:
graph_builder = medical.TFGraphBuilder()\
    .setModelName("fewshot_classifier")\
    .setInputCols(["sentence_embeddings"])\
    .setLabelColumn("category")\
    .setGraphFolder("/tmp")\
    .setGraphFile("log_reg_graph.pb")

few_shot_approach = medical.FewShotClassifierApproach()\
    .setLabelColumn("category")\
    .setInputCols(["sentence_embeddings"])\
    .setOutputCol("prediction")\
    .setModelFile("/tmp/log_reg_graph.pb")\
    .setEpochsNumber(10)\
    .setBatchSize(8)\
    .setLearningRate(0.002)

pipeline = nlp.Pipeline(
    stages=[
        graph_builder,
        few_shot_approach
    ])

In [57]:
%%time
model = pipeline.fit(trainingData_part1_with_embeddings)

TF Graph Builder configuration:
Model name: fewshot_classifier
Graph folder: /tmp
Graph file name: log_reg_graph.pb
Build params: {'input_dim': 768, 'output_dim': 4, 'hidden_layers': [], 'output_act': 'sigmoid'}
fewshot_classifier graph exported to /tmp/log_reg_graph.pb
CPU times: user 446 ms, sys: 19.9 ms, total: 466 ms
Wall time: 2min 45s
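The training schedules in this notebook can be compared by counting optimizer steps. Assuming one weight update per mini-batch (the usual mini-batch setup, not something confirmed here from the library documentation), the number of updates is `ceil(rows / batch_size) * epochs`. `updates` is a hypothetical helper; the row counts and settings come from this notebook.

```python
import math

def updates(rows: int, batch_size: int, epochs: int) -> int:
    """Mini-batch updates, counting a final partial batch as one update."""
    return math.ceil(rows / batch_size) * epochs

# ADE toy example: 10 sentences, setBatchSize(1), setEpochsNumber(10)
print(updates(10, 1, 10))    # 100
# MTSamples partial run: 280 rows, setBatchSize(8), setEpochsNumber(10)
print(updates(280, 8, 10))   # 350
```

Under the same assumption, the partial-data ClassifierDL run (280 rows, `setBatchSize(8)`, `setMaxEpochs(10)`) also performs 350 updates, so the gap between the two models is not a matter of one receiving more training steps.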


In [58]:
preds = model.transform(testData_with_embeddings)

In [59]:
preds_df = preds.select("category","text","prediction.result").toPandas()
preds_df["result"] = preds_df["result"].apply(lambda x : x[0])

print(classification_report(preds_df["category"], preds_df["result"]))

                  precision    recall  f1-score   support

Gastroenterology       0.88      0.88      0.88        25
       Neurology       0.84      0.73      0.78        22
      Orthopedic       0.79      0.89      0.84        35
         Urology       0.84      0.80      0.82        20

        accuracy                           0.83       102
       macro avg       0.84      0.82      0.83       102
    weighted avg       0.84      0.83      0.83       102



In [60]:
res = classification_report(preds_df["category"], preds_df["result"], output_dict=True)
results_df.loc["FewShot_partial_Data"] = [res["macro avg"]["f1-score"], res["weighted avg"]["f1-score"], res["accuracy"]]

In [61]:
results_df

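|                           | macro-f1-score | weighted-f1-score | accuracy |
|:--------------------------|---------------:|------------------:|---------:|
| ClassifierDL_full_Data    | 0.872683 | 0.87421  | 0.872549 |
| ClassifierDL_partial_Data | 0.127737 | 0.175326 | 0.343137 |
| FewShot_partial_Data      | 0.82971  | 0.832405 | 0.833333 |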


# Few Shot Classification Models

|index|model|
|-----:|:-----|
|1|[few_shot_classifier_age_group_sbiobert_cased_mli](https://nlp.johnsnowlabs.com/2023/08/17/few_shot_classifier_age_group_sbiobert_cased_mli_en.html)|
|2|[few_shot_classifier_patient_complaint_sbiobert_cased_mli](https://nlp.johnsnowlabs.com/2023/08/30/few_shot_classifier_patient_complaint_sbiobert_cased_mli_en.html)|

1. `few_shot_classifier_age_group_sbiobert_cased_mli`

In [62]:
document_assembler = nlp.DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

bert_sent = nlp.BertSentenceEmbeddings.pretrained("sbiobert_base_cased_mli", "en", "clinical/models")\
    .setInputCols(["document"])\
    .setOutputCol("sentence_embeddings")

few_shot_classifier = medical.FewShotClassifierModel.pretrained("few_shot_classifier_age_group_sbiobert_cased_mli", "en", "clinical/models")\
    .setInputCols(["sentence_embeddings"])\
    .setOutputCol("prediction")

clf_Pipeline = nlp.Pipeline(stages=[
    document_assembler,
    bert_sent,
    few_shot_classifier
])

data = spark.createDataFrame([
    ["""A patient presented with complaints of chest pain and shortness of breath. The medical history revealed the patient had a smoking habit for over 30 years, and was diagnosed with hypertension two years ago. After a detailed physical examination, the doctor found a noticeable wheeze on lung auscultation and prescribed a spirometry test, which showed irreversible airway obstruction. The patient was diagnosed with Chronic obstructive pulmonary disease (COPD) caused by smoking."""],
    ["""Hi, wondering if anyone has had a similar situation. My 1 year old daughter has the following; loose stools/ pale stools, elevated liver enzymes, low iron.  5 months and still no answers from drs. """],
    ["""Hi have chronic gastritis from 4 month(confirmed by endoscopy).I do not have acid reflux.Only dull ache above abdomen and left side of chest.I am on reberprozole and librax.My question is whether chronic gastritis is curable or is it a lifetime condition?I am loosing hope because this dull ache is not going away.Please please reply"""]
]).toDF("text")

result = clf_Pipeline.fit(data).transform(data)

sbiobert_base_cased_mli download started this may take some time.
Approximate size to download 384.3 MB
[OK!]
few_shot_classifier_age_group_sbiobert_cased_mli download started this may take some time.
Approximate size to download 46.1 KB
[OK!]


In [63]:
few_shot_classifier.getClasses()

['Unknown', 'Adult', 'Child']
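In this pipeline, `BertSentenceEmbeddings` turns each document into a single fixed-size vector. `FewShotClassifierModel` then maps that vector to one of the labels returned by `getClasses()`. This notebook does not describe the model's internals, so the sketch below swaps in a classic few-shot technique for illustration only: nearest-centroid (prototype) classification. It averages a handful of labeled embeddings per class and assigns a new embedding to the class whose centroid is most cosine-similar.

The helper names and the 3-dimensional toy vectors are hypothetical. Real `sbiobert_base_cased_mli` embeddings are far higher-dimensional.

```python
import math


def centroid(vectors):
    """Element-wise mean of equal-length vectors."""
    dim = len(vectors[0])
    return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]


def cosine(u, v):
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def fit_prototypes(examples):
    """examples: dict mapping label -> a few embedding vectors for that label."""
    return {label: centroid(vecs) for label, vecs in examples.items()}


def predict(prototypes, embedding):
    """Return the label whose prototype is most similar to the embedding."""
    return max(prototypes, key=lambda label: cosine(prototypes[label], embedding))


# Two toy "embeddings" per class, using the labels from getClasses() above.
prototypes = fit_prototypes({
    "Adult": [[0.9, 0.1, 0.0], [0.8, 0.2, 0.1]],
    "Child": [[0.1, 0.9, 0.0], [0.2, 0.8, 0.1]],
    "Unknown": [[0.1, 0.1, 0.9], [0.0, 0.2, 0.8]],
})
print(predict(prototypes, [0.1, 0.7, 0.2]))  # Child
```

A prototype classifier needs no gradient training, which is why it is a common baseline when only a few labeled examples per class exist. It is not a description of how `FewShotClassifierModel` is trained.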

In [64]:
result.show()

+--------------------+--------------------+--------------------+--------------------+
|                text|            document| sentence_embeddings|          prediction|
+--------------------+--------------------+--------------------+--------------------+
|A patient present...|[{document, 0, 47...|[{sentence_embedd...|[{category, 0, 0,...|
|Hi, wondering if ...|[{document, 0, 19...|[{sentence_embedd...|[{category, 0, 0,...|
|Hi have chronic g...|[{document, 0, 33...|[{sentence_embedd...|[{category, 0, 0,...|
+--------------------+--------------------+--------------------+--------------------+



In [65]:
result.select('prediction.result','text').show(truncate=150)

+---------+------------------------------------------------------------------------------------------------------------------------------------------------------+
|   result|                                                                                                                                                  text|
+---------+------------------------------------------------------------------------------------------------------------------------------------------------------+
|  [Adult]|A patient presented with complaints of chest pain and shortness of breath. The medical history revealed the patient had a smoking habit for over 30...|
|  [Child]|Hi, wondering if anyone has had a similar situation. My 1 year old daughter has the following; loose stools/ pale stools, elevated liver enzymes, l...|
|[Unknown]|Hi have chronic gastritis from 4 month(confirmed by endoscopy).I do not have acid reflux.Only dull ache above abdomen and left side of chest.I am o...|
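+---------+------------------------------------------------------------------------------------------------------------------------------------------------------+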
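Each `prediction.result` value above is an array holding a single label, for example `[Adult]`. Metrics such as `classification_report` expect one scalar label per row.

The sketch below flattens such rows and checks every label against `getClasses()`. It assumes the rows have the shape returned by `result.select('prediction.result', 'text').collect()`. Each PySpark `Row` unpacks like a tuple, and the array column arrives as a Python list. Plain tuples stand in for those rows here. The helper `flatten_predictions` is hypothetical.

```python
from collections import Counter


def flatten_predictions(rows, classes):
    """Turn (result_list, text) pairs into (label, text) pairs.

    Raises ValueError if a row does not hold exactly one label, or if the
    label is not one of `classes` (e.g. the output of getClasses()).
    """
    allowed = set(classes)
    flat = []
    for result, text in rows:
        if len(result) != 1:
            raise ValueError(f"expected one label per document, got {result!r}")
        label = result[0]
        if label not in allowed:
            raise ValueError(f"label {label!r} is not in {classes!r}")
        flat.append((label, text))
    return flat


# Stand-in rows based on the predictions shown above (texts shortened).
rows = [
    (["Adult"], "A patient presented with complaints of chest pain ..."),
    (["Child"], "Hi, wondering if anyone has had a similar situation ..."),
    (["Unknown"], "Hi have chronic gastritis from 4 month ..."),
]
labeled = flatten_predictions(rows, ["Unknown", "Adult", "Child"])
print(Counter(label for label, _ in labeled))
```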

2.	`few_shot_classifier_patient_complaint_sbiobert_cased_mli`

In [66]:
document_assembler = nlp.DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

bert_sent = nlp.BertSentenceEmbeddings.pretrained("sbiobert_base_cased_mli", "en", "clinical/models")\
    .setInputCols(["document"])\
    .setOutputCol("sentence_embeddings")

few_shot_classifier = medical.FewShotClassifierModel.pretrained("few_shot_classifier_patient_complaint_sbiobert_cased_mli", "en", "clinical/models")\
    .setInputCols(["sentence_embeddings"])\
    .setOutputCol("prediction")

clf_Pipeline = nlp.Pipeline(stages=[
    document_assembler,
    bert_sent,
    few_shot_classifier
])

data = spark.createDataFrame([
    ["""The Medical Center is a large state of the art hospital facility with great doctors, nurses, technicians and receptionists.  Service is top notch, knowledgeable and friendly.  This hospital site has plenty of parking"""],
    ["""My gf dad wasn’t feeling well so we decided to take him to this place cus it’s his insurance and we waited for a while and mind that my girl dad couldn’t breath good while the staff seem not to care and when they got to us they said they we’re gonna a take some blood samples and they made us wait again and to see the staff workers talking to each other and laughing taking there time and not seeming to care about there patience, while we were in the lobby there was another guy who told us they also made him wait while he can hardly breath and they left him there to wait my girl dad is coughing and not doing better and when the lady came in my girl dad didn’t have his shirt because he was hot and the lady came in said put on his shirt on and then left still waiting to get help rn"""]
]).toDF("text")

result = clf_Pipeline.fit(data).transform(data)

sbiobert_base_cased_mli download started this may take some time.
Approximate size to download 384.3 MB
[OK!]
few_shot_classifier_patient_complaint_sbiobert_cased_mli download started this may take some time.
Approximate size to download 37.5 KB
[OK!]


In [67]:
few_shot_classifier.getClasses()

['No_Complaint', 'Complaint']

In [68]:
result.show()


+--------------------+--------------------+--------------------+--------------------+
|                text|            document| sentence_embeddings|          prediction|
+--------------------+--------------------+--------------------+--------------------+
|The Medical Cente...|[{document, 0, 21...|[{sentence_embedd...|[{category, 0, 0,...|
|My gf dad wasn’t ...|[{document, 0, 78...|[{sentence_embedd...|[{category, 0, 0,...|
+--------------------+--------------------+--------------------+--------------------+



In [69]:
result.select('prediction.result', 'text').show(truncate=150)


+--------------+------------------------------------------------------------------------------------------------------------------------------------------------------+
|        result|                                                                                                                                                  text|
+--------------+------------------------------------------------------------------------------------------------------------------------------------------------------+
|[No_Complaint]|The Medical Center is a large state of the art hospital facility with great doctors, nurses, technicians and receptionists.  Service is top notch, ...|
|   [Complaint]|My gf dad wasn’t feeling well so we decided to take him to this place cus it’s his insurance and we waited for a while and mind that my girl dad co...|
+--------------+------------------------------------------------------------------------------------------------------------------------------------------------------+