![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)


# **ZeroShotRelationExtractionModel**

This notebook covers the parameters and usage of the `ZeroShotRelationExtractionModel` annotator.

**📖 Learning Objectives:**

1. Understand how to use `ZeroShotRelationExtractionModel`.

2. Become comfortable using the different parameters of the annotator.

3. Identify clinical relations in text without training data.


**🔗 Helpful Links:**

- Documentation : [ZeroShotRelationExtractionModel](https://nlp.johnsnowlabs.com/docs/en/licensed_annotators#zeroshotrelationextractionmodel)

- Python Docs : [ZeroShotRelationExtractionModel](https://nlp.johnsnowlabs.com/licensed/api/python/reference/autosummary/sparknlp_jsl/annotator/re/zero_shot_relation_extraction/index.html#sparknlp_jsl.annotator.re.zero_shot_relation_extraction.ZeroShotRelationExtractionModel)

- Scala Docs : [ZeroShotRelationExtractionModel](https://nlp.johnsnowlabs.com/licensed/api/com/johnsnowlabs/finance/graph/relation_extraction/ZeroShotRelationExtractionModel.html)

- For extended examples of usage, see the [Spark NLP Workshop repository](https://github.com/JohnSnowLabs/spark-nlp-workshop/tree/master/healthcare-nlp/).

## **📜 Background**


`ZeroShotRelationExtractionModel` implements zero-shot binary relation extraction by using `BERT` transformer models trained on the NLI (Natural Language Inference) task.

The model inputs consist of documents/sentences and paired NER chunks, usually obtained with `RENerChunksFilter`. The relations to extract are defined by a dictionary that specifies, for each relation, a set of statements about the relationship between named entities.

These statements are automatically appended to each document in the dataset, and the NLI model is used to determine whether a particular relationship holds between the entities.

As a zero-shot model, it does not need to be trained on a task-specific dataset, nor do the entities need to be fixed in advance.
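To make the template mechanism concrete, the sketch below fills hypothesis templates with the text of one entity pair, which is roughly what has to happen before an NLI model can score each statement. It is a plain-Python illustration, not the annotator's implementation; the function name `build_hypotheses` and the shape of the `chunks` dictionary are assumptions made for this example.

```python
def build_hypotheses(relational_categories, chunks):
    """Fill each template with the chunk text of the matching entity labels.

    `chunks` maps an NER label (e.g. "DRUG") to the chunk text found in the sentence.
    Templates that mention a label missing from `chunks` are skipped.
    """
    hypotheses = []
    for relation, templates in relational_categories.items():
        for template in templates:
            try:
                hypotheses.append((relation, template.format(**chunks)))
            except KeyError:
                # The template needs an entity label this pair does not provide.
                continue
    return hypotheses


categories = {
    "ADE": ["{DRUG} causes {PROBLEM}."],
    "IMPROVE": ["{DRUG} improves {PROBLEM}.", "{DRUG} cures {PROBLEM}."],
    "REVEAL": ["{TEST} reveals {PROBLEM}."],
}

# One candidate entity pair (hypothetical input format for this sketch).
pair = {"DRUG": "Paracetamol", "PROBLEM": "headache"}
for relation, hypothesis in build_hypotheses(categories, pair):
    print(f"{relation}: {hypothesis}")
```

Each generated statement would then be paired with the original sentence as an NLI premise/hypothesis pair; the `REVEAL` template is skipped here because the pair contains no `TEST` chunk.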

## **🎬 Colab Setup**

In [None]:
!pip install -q johnsnowlabs

In [None]:
from johnsnowlabs import nlp


nlp.install(force_browser=True)

Downloading license...
Licenses extracted successfully
👌 JSL-Home is up to date! 
👌 Everything is already installed, no changes made


In [None]:
from johnsnowlabs import nlp, medical

spark = nlp.start()

Spark Session already created, some configs may not take.


## **🖨️ Input/Output Annotation Types**

- Input: `CHUNK`, `DOCUMENT`

- Output: `CATEGORY`

## **🔎 Parameters**


- `relationalCategories`: A dictionary that defines the relational categories. The keys are the relation labels and the values are lists of hypothesis templates.
- `predictionThreshold`: Minimum confidence score required to output a relation (Default: `0.5`).
- `multiLabel`: Whether a pair of entities can be assigned more than one relation (Default: `False`).


All the parameters can be set using the corresponding set method in camel case. For example, `.setMultiLabel()`.
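The interplay between `predictionThreshold` and `multiLabel` can be pictured with a small plain-Python sketch. This is a conceptual model only, under the assumption that single-label mode keeps the most confident relation among those passing the threshold; the function `select_relations` and the scores are made up for illustration and are not part of the library.

```python
def select_relations(scores, prediction_threshold=0.5, multi_label=False):
    """Return the relation labels kept for one entity pair.

    `scores` maps each relation label to the confidence of its best hypothesis.
    """
    passing = {label: s for label, s in scores.items() if s >= prediction_threshold}
    if not passing:
        return []
    if multi_label:
        # Keep every label above the threshold, highest confidence first.
        return sorted(passing, key=passing.get, reverse=True)
    # Single-label mode: keep only the most confident label.
    return [max(passing, key=passing.get)]


# Illustrative scores for one entity pair (made up for this example).
scores = {"ADE": 0.03, "IMPROVE": 0.97, "CURE": 0.91}

print(select_relations(scores))                             # ['IMPROVE']
print(select_relations(scores, multi_label=True))           # ['IMPROVE', 'CURE']
print(select_relations(scores, prediction_threshold=0.99))  # []
```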

### `relationalCategories`

In [None]:
sample_text = "Paracetamol can alleviate headache or sickness. An MRI test can be used to find cancer."

data = spark.createDataFrame([[sample_text]]).toDF("text")
data.show(truncate=False)

+---------------------------------------------------------------------------------------+
|text                                                                                   |
+---------------------------------------------------------------------------------------+
|Paracetamol can alleviate headache or sickness. An MRI test can be used to find cancer.|
+---------------------------------------------------------------------------------------+



The full pipeline contains many stages. Let's first define the initial stages in a preprocessing pipeline.

In [None]:
documenter = nlp.DocumentAssembler().setInputCol("text").setOutputCol("document")

sentencer = (
    nlp.SentenceDetectorDLModel.pretrained(
        "sentence_detector_dl_healthcare", "en", "clinical/models"
    )
    .setInputCols(["document"])
    .setOutputCol("sentences")
)

tokenizer = nlp.Tokenizer().setInputCols(["sentences"]).setOutputCol("tokens")

words_embedder = (
    nlp.WordEmbeddingsModel()
    .pretrained("embeddings_clinical", "en", "clinical/models")
    .setInputCols(["sentences", "tokens"])
    .setOutputCol("embeddings")
)

ner_clinical = (
    medical.NerModel.pretrained("ner_clinical", "en", "clinical/models")
    .setInputCols(["sentences", "tokens", "embeddings"])
    .setOutputCol("ner_clinical")
)

ner_clinical_converter = (
    medical.NerConverterInternal()
    .setInputCols(["sentences", "tokens", "ner_clinical"])
    .setOutputCol("ner_clinical_chunks")
    .setWhiteList(["PROBLEM", "TEST"])
) 

ner_posology = (
    medical.NerModel.pretrained("ner_posology", "en", "clinical/models")
    .setInputCols(["sentences", "tokens", "embeddings"])
    .setOutputCol("ner_posology")
)

ner_posology_converter = (
    medical.NerConverterInternal()
    .setInputCols(["sentences", "tokens", "ner_posology"])
    .setOutputCol("ner_posology_chunks")
    .setWhiteList(["DRUG"])
)

chunk_merger = (
    medical.ChunkMergeApproach()
    .setInputCols("ner_clinical_chunks", "ner_posology_chunks")
    .setOutputCol("merged_ner_chunks")
)

pos_tagger = (
    nlp.PerceptronModel()
    .pretrained("pos_clinical", "en", "clinical/models")
    .setInputCols(["sentences", "tokens"])
    .setOutputCol("pos_tags")
)

dependency_parser = (
    nlp.DependencyParserModel()
    .pretrained("dependency_conllu", "en")
    .setInputCols(["document", "pos_tags", "tokens"])
    .setOutputCol("dependencies")
)

re_ner_chunk_filter = (
    medical.RENerChunksFilter()
    .setRelationPairs(["problem-test", "problem-drug"])
    .setMaxSyntacticDistance(4)
    .setDocLevelRelations(False)
    .setInputCols(["merged_ner_chunks", "dependencies"])
    .setOutputCol("re_ner_chunks")
)


pipeline = nlp.Pipeline().setStages(
    [
        documenter,
        sentencer,
        tokenizer,
        words_embedder,
        ner_clinical,
        ner_clinical_converter,
        ner_posology,
        ner_posology_converter,
        chunk_merger,
        pos_tagger,
        dependency_parser,
        re_ner_chunk_filter,
    ]
)


processed_df = pipeline.fit(data).transform(data)
processed_df.show()


sentence_detector_dl_healthcare download started this may take some time.
Approximate size to download 367.3 KB
[OK!]
embeddings_clinical download started this may take some time.
Approximate size to download 1.6 GB
[OK!]
ner_clinical download started this may take some time.
[OK!]
ner_posology download started this may take some time.
[OK!]
pos_clinical download started this may take some time.
Approximate size to download 1.5 MB
[OK!]
dependency_conllu download started this may take some time.
Approximate size to download 16.7 MB
[OK!]
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|                text|            document|           sentences|              tokens|          embeddings|        ner_clinical| ner_clinical_chunks|        ner_posology| ner_posolog

We set `relationalCategories` with a dictionary whose keys are the relation labels and whose values are lists of hypothesis templates. Each template should reference the `NER` labels of the two entities in curly brackets (e.g., `{PERSON}` or `{DRUG}`) inside a sentence that states the relation between them. See the example below.
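Because a template can only match entity labels that the upstream NER converters actually emit, it may help to check the placeholders before running the model. The following is a minimal sketch using only the standard library; `find_unknown_labels` is a hypothetical helper, and the `TREATS` category is added purely to show a failing case.

```python
from string import Formatter


def find_unknown_labels(relational_categories, known_labels):
    """Return {relation: [unknown placeholders]} for templates using labels outside `known_labels`."""
    problems = {}
    for relation, templates in relational_categories.items():
        used = {
            field
            for template in templates
            for _, field, _, _ in Formatter().parse(template)
            if field  # skip literal-only segments, where the field name is None
        }
        unknown = sorted(used - set(known_labels))
        if unknown:
            problems[relation] = unknown
    return problems


# Labels whitelisted by the two NER converters in the pipeline above.
known = {"PROBLEM", "TEST", "DRUG"}

categories = {
    "IMPROVE": ["{DRUG} improves {PROBLEM}."],
    "REVEAL": ["{TEST} reveals {PROBLEM}."],
    "TREATS": ["{TREATMENT} treats {PROBLEM}."],  # TREATMENT is not whitelisted
}

print(find_unknown_labels(categories, known))
```

An empty dictionary means every placeholder corresponds to a label that can reach the relation extraction stage.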

In [None]:
re_model = (
    medical.ZeroShotRelationExtractionModel.pretrained(
        "re_zeroshot_biobert", "en", "clinical/models"
    )
    .setInputCols(["re_ner_chunks", "sentences"])
    .setOutputCol("relations")
    .setRelationalCategories(
        {
            "ADE": ["{DRUG} causes {PROBLEM}."],
            "IMPROVE": ["{DRUG} improves {PROBLEM}.", "{DRUG} cures {PROBLEM}."],
            "REVEAL": ["{TEST} reveals {PROBLEM}."],
        }
    )
)

re_zeroshot_biobert download started this may take some time.
[OK!]


In [None]:
result = re_model.transform(processed_df)

result.selectExpr("explode(relations) as relation").show(truncate=False)

+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|relation                                                                                                                                                                                                                                                                                                                                                         |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [None]:
import pyspark.sql.functions as F

result.select(
    F.explode(F.arrays_zip("relations.metadata", "relations.result")).alias("cols")
).select(
    F.expr("cols['0']['sentence']").alias("sentence"),
    F.expr("cols['0']['entity1_begin']").alias("entity1_begin"),
    F.expr("cols['0']['entity1_end']").alias("entity1_end"),
    F.expr("cols['0']['chunk1']").alias("chunk1"),
    F.expr("cols['0']['entity1']").alias("entity1"),
    F.expr("cols['0']['entity2_begin']").alias("entity2_begin"),
    F.expr("cols['0']['entity2_end']").alias("entity2_end"),
    F.expr("cols['0']['chunk2']").alias("chunk2"),
    F.expr("cols['0']['entity2']").alias("entity2"),
    F.expr("cols['0']['hypothesis']").alias("hypothesis"),
    F.expr("cols['0']['nli_prediction']").alias("nli_prediction"),
    F.expr("cols['1']").alias("relation"),
    F.expr("cols['0']['confidence']").alias("confidence"),
).show(
    truncate=70
)

+--------+-------------+-----------+-----------+-------+-------------+-----------+--------+-------+------------------------------+--------------+--------+----------+
|sentence|entity1_begin|entity1_end|     chunk1|entity1|entity2_begin|entity2_end|  chunk2|entity2|                    hypothesis|nli_prediction|relation|confidence|
+--------+-------------+-----------+-----------+-------+-------------+-----------+--------+-------+------------------------------+--------------+--------+----------+
|       0|            0|         10|Paracetamol|   DRUG|           38|         45|sickness|PROBLEM|Paracetamol improves sickness.|        entail| IMPROVE|0.98819494|
|       0|            0|         10|Paracetamol|   DRUG|           26|         33|headache|PROBLEM|Paracetamol improves headache.|        entail| IMPROVE| 0.9929625|
|       1|           48|         58|An MRI test|   TEST|           80|         85|  cancer|PROBLEM|   An MRI test reveals cancer.|        entail|  REVEAL| 0.9760039|
+---
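The flattening query above is reused several times in this notebook, so the column expressions could be generated rather than typed out. The sketch below only builds the expression strings in plain Python; `relation_column_exprs` is a hypothetical helper, and it places `relation` last instead of before `confidence`.

```python
METADATA_FIELDS = [
    "sentence", "entity1_begin", "entity1_end", "chunk1", "entity1",
    "entity2_begin", "entity2_end", "chunk2", "entity2",
    "hypothesis", "nli_prediction", "confidence",
]


def relation_column_exprs(fields=METADATA_FIELDS):
    """Return (sql_expression, alias) pairs for the zipped `cols` struct."""
    pairs = [(f"cols['0']['{name}']", name) for name in fields]
    # Field '1' of the zipped struct holds the predicted relation label.
    pairs.append(("cols['1']", "relation"))
    return pairs


for expression, alias in relation_column_exprs():
    print(f"{alias:15} <- {expression}")

# With a Spark session and `import pyspark.sql.functions as F`, the pairs can be applied as:
# result.select(
#     F.explode(F.arrays_zip("relations.metadata", "relations.result")).alias("cols")
# ).select(*[F.expr(e).alias(a) for e, a in relation_column_exprs()]).show(truncate=70)
```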

### `predictionThreshold`

All the relations in the previous example had high scores, but let's raise the threshold to keep only the most confident ones.

In [None]:
re_model.setPredictionThreshold(0.99)

result = re_model.transform(processed_df)

result.select(
    F.explode(F.arrays_zip("relations.metadata", "relations.result")).alias("cols")
).select(
    F.expr("cols['0']['sentence']").alias("sentence"),
    F.expr("cols['0']['entity1_begin']").alias("entity1_begin"),
    F.expr("cols['0']['entity1_end']").alias("entity1_end"),
    F.expr("cols['0']['chunk1']").alias("chunk1"),
    F.expr("cols['0']['entity1']").alias("entity1"),
    F.expr("cols['0']['entity2_begin']").alias("entity2_begin"),
    F.expr("cols['0']['entity2_end']").alias("entity2_end"),
    F.expr("cols['0']['chunk2']").alias("chunk2"),
    F.expr("cols['0']['entity2']").alias("entity2"),
    F.expr("cols['0']['hypothesis']").alias("hypothesis"),
    F.expr("cols['0']['nli_prediction']").alias("nli_prediction"),
    F.expr("cols['1']").alias("relation"),
    F.expr("cols['0']['confidence']").alias("confidence"),
).show(
    truncate=70
)

+--------+-------------+-----------+-----------+-------+-------------+-----------+--------+-------+------------------------------+--------------+--------+----------+
|sentence|entity1_begin|entity1_end|     chunk1|entity1|entity2_begin|entity2_end|  chunk2|entity2|                    hypothesis|nli_prediction|relation|confidence|
+--------+-------------+-----------+-----------+-------+-------------+-----------+--------+-------+------------------------------+--------------+--------+----------+
|       0|            0|         10|Paracetamol|   DRUG|           26|         33|headache|PROBLEM|Paracetamol improves headache.|        entail| IMPROVE| 0.9929625|
+--------+-------------+-----------+-----------+-------+-------------+-----------+--------+-------+------------------------------+--------------+--------+----------+



### `multiLabel`

With multi-label classification enabled, more than one label can be assigned to the same pair of entities, as long as each label passes the threshold.

In [None]:
re_model.setPredictionThreshold(0.1).setMultiLabel(True).setRelationalCategories(
    {
        "ADE": ["{DRUG} causes {PROBLEM}."],
        "IMPROVE": ["{DRUG} improves {PROBLEM}.", "{DRUG} cures {PROBLEM}."],
        "REVEAL": ["{TEST} reveals {PROBLEM}."],
        "CURE": ["{DRUG} cures {PROBLEM}."],
    }
)

result = re_model.transform(processed_df)

result.select(
    F.explode(F.arrays_zip("relations.metadata", "relations.result")).alias("cols")
).select(
    F.expr("cols['0']['sentence']").alias("sentence"),
    F.expr("cols['0']['entity1_begin']").alias("entity1_begin"),
    F.expr("cols['0']['entity1_end']").alias("entity1_end"),
    F.expr("cols['0']['chunk1']").alias("chunk1"),
    F.expr("cols['0']['entity1']").alias("entity1"),
    F.expr("cols['0']['entity2_begin']").alias("entity2_begin"),
    F.expr("cols['0']['entity2_end']").alias("entity2_end"),
    F.expr("cols['0']['chunk2']").alias("chunk2"),
    F.expr("cols['0']['entity2']").alias("entity2"),
    F.expr("cols['0']['hypothesis']").alias("hypothesis"),
    F.expr("cols['0']['nli_prediction']").alias("nli_prediction"),
    F.expr("cols['1']").alias("relation"),
    F.expr("cols['0']['confidence']").alias("confidence"),
).show(
    truncate=70
)

+--------+-------------+-----------+-----------+-------+-------------+-----------+--------+-------+------------------------------+--------------+--------+----------+
|sentence|entity1_begin|entity1_end|     chunk1|entity1|entity2_begin|entity2_end|  chunk2|entity2|                    hypothesis|nli_prediction|relation|confidence|
+--------+-------------+-----------+-----------+-------+-------------+-----------+--------+-------+------------------------------+--------------+--------+----------+
|       0|            0|         10|Paracetamol|   DRUG|           26|         33|headache|PROBLEM|   Paracetamol cures headache.|        entail|    CURE|0.99268025|
|       0|            0|         10|Paracetamol|   DRUG|           38|         45|sickness|PROBLEM|Paracetamol improves sickness.|        entail| IMPROVE|0.98819494|
|       0|            0|         10|Paracetamol|   DRUG|           26|         33|headache|PROBLEM|Paracetamol improves headache.|        entail| IMPROVE| 0.9929625|
|   
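To see which entity pairs received more than one label, the flattened rows can be grouped by chunk pair. The sketch below works on plain dictionaries holding the three rows visible in the table above; `group_labels_by_pair` is a hypothetical helper, not a Spark NLP function.

```python
from collections import defaultdict


def group_labels_by_pair(rows):
    """Map each (chunk1, chunk2) pair to the list of relation labels assigned to it."""
    grouped = defaultdict(list)
    for row in rows:
        grouped[(row["chunk1"], row["chunk2"])].append(row["relation"])
    return dict(grouped)


# The three rows visible in the table above.
rows = [
    {"chunk1": "Paracetamol", "chunk2": "headache", "relation": "CURE"},
    {"chunk1": "Paracetamol", "chunk2": "sickness", "relation": "IMPROVE"},
    {"chunk1": "Paracetamol", "chunk2": "headache", "relation": "IMPROVE"},
]

for pair, labels in group_labels_by_pair(rows).items():
    print(pair, labels)
```

Here the pair (`Paracetamol`, `headache`) carries both `CURE` and `IMPROVE`, which is the situation `multiLabel` is meant to allow.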

## Fast inference with [LightPipelines](https://nlp.johnsnowlabs.com/docs/en/concepts#using-spark-nlps-lightpipeline)

We can use Spark NLP's `LightPipeline` to run fast inference directly on a text (or a list of texts) instead of using Spark DataFrames.

Let's check how to do that.

In [None]:
pipeline = nlp.Pipeline().setStages(
    [
        documenter,
        sentencer,
        tokenizer,
        words_embedder,
        ner_clinical,
        ner_clinical_converter,
        ner_posology,
        ner_posology_converter,
        chunk_merger,
        pos_tagger,
        dependency_parser,
        re_ner_chunk_filter,
        re_model
    ]
)


pipelineModel = pipeline.fit(data)

lp = nlp.LightPipeline(pipelineModel)

In [None]:
result = lp.fullAnnotate("Paracetamol can alleviate headache or sickness. An MRI test can be used to find cancer.")
result

[{'sentences': [Annotation(document, 0, 46, Paracetamol can alleviate headache or sickness., {'sentence': '0'}, []),
   Annotation(document, 48, 86, An MRI test can be used to find cancer., {'sentence': '1'}, [])],
  'document': [Annotation(document, 0, 86, Paracetamol can alleviate headache or sickness. An MRI test can be used to find cancer., {}, [])],
  'ner_clinical': [Annotation(named_entity, 0, 10, B-TREATMENT, {'word': 'Paracetamol', 'confidence': '0.9999', 'sentence': '0'}, []),
   Annotation(named_entity, 12, 14, O, {'word': 'can', 'confidence': '0.9997', 'sentence': '0'}, []),
   Annotation(named_entity, 16, 24, O, {'word': 'alleviate', 'confidence': '0.9764', 'sentence': '0'}, []),
   Annotation(named_entity, 26, 33, B-PROBLEM, {'word': 'headache', 'confidence': '0.9877', 'sentence': '0'}, []),
   Annotation(named_entity, 35, 36, O, {'word': 'or', 'confidence': '0.998', 'sentence': '0'}, []),
   Annotation(named_entity, 38, 45, B-PROBLEM, {'word': 'sickness', 'confidence': '

In [None]:
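for re_annotation in result[0]["relations"]:
    # Print each predicted relation with its two chunks and its confidence score.
    meta = re_annotation.metadata
    print(f"{meta['chunk1']} -> {meta['chunk2']}: {re_annotation.result} ({meta['confidence']})")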

