![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/10.Clinical_Relation_Extraction.ipynb)

# Clinical Relation Extraction Model

## Colab Setup

In [None]:
import json

with open('workshop_license_keys_365.json') as f:
    license_keys = json.load(f)

license_keys.keys()


dict_keys(['PUBLIC_VERSION', 'JSL_VERSION', 'SECRET', 'SPARK_NLP_LICENSE', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY', 'SPARK_OCR_LICENSE', 'SPARK_OCR_SECRET'])

In [None]:
import os

# Install java
! apt-get update -qq
! apt-get install -y openjdk-8-jdk-headless -qq > /dev/null

os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["PATH"] = os.environ["JAVA_HOME"] + "/bin:" + os.environ["PATH"]
! java -version

secret = license_keys['SECRET']

os.environ['SPARK_NLP_LICENSE'] = license_keys['SPARK_NLP_LICENSE']
os.environ['AWS_ACCESS_KEY_ID']= license_keys['AWS_ACCESS_KEY_ID']
os.environ['AWS_SECRET_ACCESS_KEY'] = license_keys['AWS_SECRET_ACCESS_KEY']
version = license_keys['PUBLIC_VERSION']
jsl_version = license_keys['JSL_VERSION']


! pip install --ignore-installed -q pyspark==2.4.4

! python -m pip install --upgrade spark-nlp-jsl==$jsl_version  --extra-index-url https://pypi.johnsnowlabs.com/$secret

! pip install --ignore-installed -q spark-nlp==$version

import sparknlp

print (sparknlp.version())

import json
import os
from pyspark.ml import Pipeline
from pyspark.sql import SparkSession


from sparknlp.annotator import *
from sparknlp_jsl.annotator import *
from sparknlp.base import *
import sparknlp_jsl

spark = sparknlp_jsl.start(secret)

openjdk version "1.8.0_265"
OpenJDK Runtime Environment (build 1.8.0_265-8u265-b01-0ubuntu2~18.04-b01)
OpenJDK 64-Bit Server VM (build 25.265-b01, mixed mode)
Looking in indexes: https://pypi.org/simple, https://pypi.johnsnowlabs.com/2.5.5-4f4b7f600f8ba3cdc5973a6baa47b901b0c8d8a3
Requirement already up-to-date: spark-nlp-jsl==2.5.5 in /usr/local/lib/python3.6/dist-packages (2.5.5)
2.5.5


## Posology Relation Extraction

This is a demonstration of using Spark NLP for extracting posology relations. The following relations are supported:

DRUG-DOSAGE
DRUG-FREQUENCY
DRUG-ADE (Adverse Drug Events)
DRUG-FORM
DRUG-ROUTE
DRUG-DURATION
DRUG-REASON
DRUG-STRENGTH

The model has been validated against the posology dataset described in (Magge, Scotch, & Gonzalez-Hernandez, 2018).

| Relation | Recall | Precision | F1 | F1 (Magge, Scotch, & Gonzalez-Hernandez, 2018) |
| --- | --- | --- | --- | --- |
| DRUG-ADE | 0.66 | 1.00 | **0.80** | 0.76 |
| DRUG-DOSAGE | 0.89 | 1.00 | **0.94** | 0.91 |
| DRUG-DURATION | 0.75 | 1.00 | **0.85** | 0.92 |
| DRUG-FORM | 0.88 | 1.00 | **0.94** | 0.95* |
| DRUG-FREQUENCY | 0.79 | 1.00 | **0.88** | 0.90 |
| DRUG-REASON | 0.60 | 1.00 | **0.75** | 0.70 |
| DRUG-ROUTE | 0.79 | 1.00 | **0.88** | 0.95* |
| DRUG-STRENGTH | 0.95 | 1.00 | **0.98** | 0.97 |


*Magge, Scotch, Gonzalez-Hernandez (2018) collapsed DRUG-FORM and DRUG-ROUTE into a single relation.
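The F1 column in the table above is the harmonic mean of precision and recall. As a quick sanity check, the sketch below recomputes it for three rows copied from the table. Because the table reports rounded precision and recall, recomputed values for other rows can differ from the printed F1 by 0.01 in the last digit.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (recall, precision) pairs copied from the table above
reported = {
    "DRUG-ADE": (0.66, 1.00),
    "DRUG-DOSAGE": (0.89, 1.00),
    "DRUG-REASON": (0.60, 1.00),
}

for relation, (recall, precision) in reported.items():
    print(relation, round(f1_score(precision, recall), 2))
```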

In [None]:
import os
import re
import pyspark
import sparknlp
import sparknlp_jsl
import functools 
import json

import numpy as np
from scipy import spatial
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from sparknlp_jsl.annotator import *
from sparknlp.annotator import *
from sparknlp.base import *


**Build a pipeline using Spark NLP pretrained models and the relation extraction model optimized for posology.**

The precision of the RE model is controlled by "setMaxSyntacticDistance(4)", which sets the maximum syntactic distance between named entities to 4. A larger value improves recall at the expense of precision. On the validation set above, a value of 4 gives a precision of 1.00 (the model produces no false positives) with reasonably good recall.
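The notebook does not spell out how syntactic distance is measured. One common reading, assumed here, is the number of edges on the path between two tokens in the dependency tree. The sketch below computes that path length for a hand-written parse and applies a cut-off of 4. The `heads` list and the `tree_distance` helper are hypothetical illustrations, not output or API of the pipeline below.

```python
from collections import deque


def tree_distance(heads, a, b):
    """Number of dependency edges on the path between tokens a and b.

    heads[i] is the index of token i's head, or -1 for the root.
    """
    # Build an undirected adjacency list from the head pointers.
    neighbours = {i: set() for i in range(len(heads))}
    for child, head in enumerate(heads):
        if head >= 0:
            neighbours[child].add(head)
            neighbours[head].add(child)
    # Breadth-first search from a until b is reached.
    queue, seen = deque([(a, 0)]), {a}
    while queue:
        node, dist = queue.popleft()
        if node == b:
            return dist
        for nxt in neighbours[node] - seen:
            seen.add(nxt)
            queue.append((nxt, dist + 1))
    return None  # the two tokens are not connected


# Hypothetical parse of "patient was given Metformin daily"
tokens = ["patient", "was", "given", "Metformin", "daily"]
heads = [2, 2, -1, 2, 2]  # every token attaches to "given"

max_distance = 4
distance = tree_distance(heads, tokens.index("Metformin"), tokens.index("daily"))
print(distance, distance <= max_distance)
```

Under this reading, a larger cut-off admits entity pairs that sit further apart in the tree, which is consistent with the recall and precision trade-off described above.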

In [None]:
documenter = DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

sentencer = SentenceDetector()\
    .setInputCols(["document"])\
    .setOutputCol("sentences")

tokenizer = sparknlp.annotators.Tokenizer()\
    .setInputCols(["sentences"])\
    .setOutputCol("tokens")

words_embedder = WordEmbeddingsModel()\
    .pretrained("embeddings_clinical", "en", "clinical/models")\
    .setInputCols(["sentences", "tokens"])\
    .setOutputCol("embeddings")

pos_tagger = PerceptronModel()\
    .pretrained("pos_clinical", "en", "clinical/models") \
    .setInputCols(["sentences", "tokens"])\
    .setOutputCol("pos_tags")

ner_tagger = NerDLModel()\
    .pretrained("ner_posology", "en", "clinical/models")\
    .setInputCols("sentences", "tokens", "embeddings")\
    .setOutputCol("ner_tags")    

ner_chunker = NerConverter()\
    .setInputCols(["sentences", "tokens", "ner_tags"])\
    .setOutputCol("ner_chunks")

dependency_parser = DependencyParserModel()\
    .pretrained("dependency_conllu", "en")\
    .setInputCols(["sentences", "pos_tags", "tokens"])\
    .setOutputCol("dependencies")

reModel = RelationExtractionModel()\
    .pretrained("posology_re", "en", "clinical/models")\
    .setInputCols(["embeddings", "pos_tags", "ner_chunks", "dependencies"])\
    .setOutputCol("relations")\
    .setMaxSyntacticDistance(4)

pipeline = Pipeline(stages=[
    documenter,
    sentencer,
    tokenizer, 
    words_embedder, 
    pos_tagger, 
    ner_tagger,
    ner_chunker,
    dependency_parser,
    reModel
])

embeddings_clinical download started this may take some time.
Approximate size to download 1.6 GB
[OK!]
pos_clinical download started this may take some time.
Approximate size to download 1.7 MB
[OK!]
ner_posology download started this may take some time.
Approximate size to download 13.7 MB
[OK!]
dependency_conllu download started this may take some time.
Approximate size to download 16.6 MB
[OK!]


**Create empty dataframe**

In [None]:
empty_data = spark.createDataFrame([[""]]).toDF("text")


**Create a light pipeline for annotating free text**

In [None]:
model = pipeline.fit(empty_data)
lmodel = sparknlp.base.LightPipeline(model)

**Sample free text**

In [None]:
text = """
The patient was prescribed 1 unit of Advil for 5 days after meals. The patient was also 
given 1 unit of Metformin daily.
He was seen by the endocrinology service and she was discharged on 40 units of insulin glargine at night , 
12 units of insulin lispro with meals , and metformin 1000 mg two times a day.
"""
results = lmodel.fullAnnotate(text)

**Show extracted relations**

In [None]:
for rel in results[0]["relations"]:
    print("{}({}={} - {}={})".format(
        rel.result, 
        rel.metadata['entity1'], 
        rel.metadata['chunk1'], 
        rel.metadata['entity2'],
        rel.metadata['chunk2']
    ))

DOSAGE-DRUG(DOSAGE=1 unit - DRUG=Advil)
DRUG-DURATION(DRUG=Advil - DURATION=for 5 days)
DOSAGE-DRUG(DOSAGE=1 unit - DRUG=Metformin)
DRUG-FREQUENCY(DRUG=Metformin - FREQUENCY=daily)
DOSAGE-DRUG(DOSAGE=40 units - DRUG=insulin glargine)
DRUG-FREQUENCY(DRUG=insulin glargine - FREQUENCY=at night)
DOSAGE-DRUG(DOSAGE=12 units - DRUG=insulin lispro)
DRUG-FREQUENCY(DRUG=insulin lispro - FREQUENCY=with meals)
DRUG-STRENGTH(DRUG=metformin - STRENGTH=1000 mg)
DRUG-FREQUENCY(DRUG=metformin - FREQUENCY=two times a day)
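In the output above, the relation label appears to follow the order of the entities in the text (DOSAGE-DRUG when the dosage comes first), whereas the list of supported relations is written drug-first (DRUG-DOSAGE). If downstream code expects one canonical form, a small helper can normalize the label. This is a minimal sketch; `drug_first` is a hypothetical helper, not part of Spark NLP.

```python
def drug_first(label: str) -> str:
    """Return the relation label with DRUG as the first component."""
    left, right = label.split("-", 1)
    if right == "DRUG" and left != "DRUG":
        return "{}-{}".format(right, left)
    return label


for label in ["DOSAGE-DRUG", "DRUG-FREQUENCY", "DURATION-DRUG"]:
    print(label, "->", drug_first(label))
```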


In [None]:
import pandas as pd

def get_relations_df (results):
  rel_pairs=[]
  for rel in results[0]['relations']:
      rel_pairs.append((
          rel.result, 
          rel.metadata['entity1'], 
          rel.metadata['entity1_begin'],
          rel.metadata['entity1_end'],
          rel.metadata['chunk1'], 
          rel.metadata['entity2'],
          rel.metadata['entity2_begin'],
          rel.metadata['entity2_end'],
          rel.metadata['chunk2'], 
          rel.metadata['confidence']
      ))

  rel_df = pd.DataFrame(rel_pairs, columns=['relation','entity1','entity1_begin','entity1_end','chunk1','entity2','entity2_begin','entity2_end','chunk2', 'confidence'])

  return rel_df


rel_df = get_relations_df (results)

rel_df

|  | relation | entity1 | entity1_begin | entity1_end | chunk1 | entity2 | entity2_begin | entity2_end | chunk2 | confidence |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | DOSAGE-DRUG | DOSAGE | 28 | 33 | 1 unit | DRUG | 38 | 42 | Advil | 1.0 |
| 1 | DRUG-DURATION | DRUG | 38 | 42 | Advil | DURATION | 44 | 53 | for 5 days | 1.0 |
| 2 | DOSAGE-DRUG | DOSAGE | 96 | 101 | 1 unit | DRUG | 106 | 114 | Metformin | 1.0 |
| 3 | DRUG-FREQUENCY | DRUG | 106 | 114 | Metformin | FREQUENCY | 116 | 120 | daily | 1.0 |
| 4 | DOSAGE-DRUG | DOSAGE | 190 | 197 | 40 units | DRUG | 202 | 217 | insulin glargine | 1.0 |
| 5 | DRUG-FREQUENCY | DRUG | 202 | 217 | insulin glargine | FREQUENCY | 219 | 226 | at night | 1.0 |
| 6 | DOSAGE-DRUG | DOSAGE | 231 | 238 | 12 units | DRUG | 243 | 256 | insulin lispro | 1.0 |
| 7 | DRUG-FREQUENCY | DRUG | 243 | 256 | insulin lispro | FREQUENCY | 258 | 267 | with meals | 1.0 |
| 8 | DRUG-STRENGTH | DRUG | 275 | 283 | metformin | STRENGTH | 285 | 291 | 1000 mg | 1.0 |
| 9 | DRUG-FREQUENCY | DRUG | 275 | 283 | metformin | FREQUENCY | 293 | 307 | two times a day | 1.0 |
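The begin and end columns above are character offsets into the input text, with the end offset inclusive. The sketch below checks this against the first sentence of the sample text, including the leading newline that follows the opening triple quote. `chunk_at` is a hypothetical helper introduced for illustration.

```python
# First sentence of the sample text, including the leading newline
sample = "\nThe patient was prescribed 1 unit of Advil for 5 days after meals."


def chunk_at(text: str, begin: int, end: int) -> str:
    """Slice text using inclusive begin/end character offsets."""
    return text[begin:end + 1]


print(chunk_at(sample, 28, 33))  # 1 unit
print(chunk_at(sample, 38, 42))  # Advil
print(chunk_at(sample, 44, 53))  # for 5 days
```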


In [None]:
text ="""A 28-year-old female with a history of gestational diabetes mellitus diagnosed eight years prior to presentation and subsequent type two diabetes mellitus ( T2DM ), 
one prior episode of HTG-induced pancreatitis three years prior to presentation,  associated with an acute hepatitis , and obesity with a body mass index ( BMI ) of 33.5 kg/m2 , presented with a one-week history of polyuria , polydipsia , poor appetite , and vomiting . Two weeks prior to presentation , she was treated with a five-day course of amoxicillin for a respiratory tract infection . She was on metformin , glipizide , and dapagliflozin for T2DM and atorvastatin and gemfibrozil for HTG . She had been on dapagliflozin for six months at the time of presentation. Physical examination on presentation was significant for dry oral mucosa ; significantly , her abdominal examination was benign with no tenderness , guarding , or rigidity . Pertinent laboratory findings on admission were : serum glucose 111 mg/dl , bicarbonate 18 mmol/l , anion gap 20 , creatinine 0.4 mg/dL , triglycerides 508 mg/dL , total cholesterol 122 mg/dL , glycated hemoglobin ( HbA1c ) 10% , and venous pH 7.27 . Serum lipase was normal at 43 U/L . Serum acetone levels could not be assessed as blood samples kept hemolyzing due to significant lipemia . The patient was initially admitted for starvation ketosis , as she reported poor oral intake for three days prior to admission . However , serum chemistry obtained six hours after presentation revealed her glucose was 186 mg/dL , the anion gap was still elevated at 21 , serum bicarbonate was 16 mmol/L , triglyceride level peaked at 2050 mg/dL , and lipase was 52 U/L . The β-hydroxybutyrate level was obtained and found to be elevated at 5.29 mmol/L - the original sample was centrifuged and the chylomicron layer removed prior to analysis due to interference from turbidity caused by lipemia again . The patient was treated with an insulin drip for euDKA and HTG with a reduction in the anion gap to 13 and triglycerides to 1400 mg/dL , within 24 hours . 
Her euDKA was thought to be precipitated by her respiratory tract infection in the setting of SGLT2 inhibitor use . The patient was seen by the endocrinology service and she was discharged on 40 units of insulin glargine at night , 12 units of insulin lispro with meals , and metformin 1000 mg two times a day . It was determined that all SGLT2 inhibitors should be discontinued indefinitely . 
She had close follow-up with endocrinology post discharge .
"""

annotations = lmodel.fullAnnotate(text)

rel_df = get_relations_df (annotations)

rel_df



|  | relation | entity1 | entity1_begin | entity1_end | chunk1 | entity2 | entity2_begin | entity2_end | chunk2 | confidence |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | DURATION-DRUG | DURATION | 493 | 500 | five-day | DRUG | 512 | 522 | amoxicillin | 1.0 |
| 1 | DRUG-DURATION | DRUG | 681 | 693 | dapagliflozin | DURATION | 695 | 708 | for six months | 1.0 |
| 2 | DRUG-ROUTE | DRUG | 1940 | 1946 | insulin | ROUTE | 1948 | 1951 | drip | 1.0 |
| 3 | DOSAGE-DRUG | DOSAGE | 2255 | 2262 | 40 units | DRUG | 2267 | 2282 | insulin glargine | 1.0 |
| 4 | DRUG-FREQUENCY | DRUG | 2267 | 2282 | insulin glargine | FREQUENCY | 2284 | 2291 | at night | 1.0 |
| 5 | DOSAGE-DRUG | DOSAGE | 2295 | 2302 | 12 units | DRUG | 2307 | 2320 | insulin lispro | 1.0 |
| 6 | DRUG-FREQUENCY | DRUG | 2307 | 2320 | insulin lispro | FREQUENCY | 2322 | 2331 | with meals | 1.0 |
| 7 | DRUG-STRENGTH | DRUG | 2339 | 2347 | metformin | STRENGTH | 2349 | 2355 | 1000 mg | 1.0 |
| 8 | DRUG-FREQUENCY | DRUG | 2339 | 2347 | metformin | FREQUENCY | 2357 | 2371 | two times a day | 1.0 |
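Each row above links one drug mention to one attribute, so a per-drug view needs the rows grouped. A minimal sketch of that grouping follows, using the discharge-medication rows copied from the output above. Keying on the drug's begin offset as well as its name keeps repeated mentions of the same drug apart. `medication_records` is a hypothetical helper, not part of Spark NLP.

```python
from collections import defaultdict

# (relation, entity1, entity1_begin, chunk1, entity2, entity2_begin, chunk2)
rows = [
    ("DOSAGE-DRUG", "DOSAGE", 2255, "40 units", "DRUG", 2267, "insulin glargine"),
    ("DRUG-FREQUENCY", "DRUG", 2267, "insulin glargine", "FREQUENCY", 2284, "at night"),
    ("DOSAGE-DRUG", "DOSAGE", 2295, "12 units", "DRUG", 2307, "insulin lispro"),
    ("DRUG-FREQUENCY", "DRUG", 2307, "insulin lispro", "FREQUENCY", 2322, "with meals"),
    ("DRUG-STRENGTH", "DRUG", 2339, "metformin", "STRENGTH", 2349, "1000 mg"),
    ("DRUG-FREQUENCY", "DRUG", 2339, "metformin", "FREQUENCY", 2357, "two times a day"),
]


def medication_records(rows):
    """Group attributes under the drug mention (name, begin offset) they relate to."""
    records = defaultdict(dict)
    for _, ent1, begin1, chunk1, ent2, begin2, chunk2 in rows:
        if ent1 == "DRUG":
            drug, attr, value = (chunk1, begin1), ent2, chunk2
        elif ent2 == "DRUG":
            drug, attr, value = (chunk2, begin2), ent1, chunk1
        else:
            continue  # not a drug-centred relation
        records[drug][attr] = value
    return dict(records)


for (name, begin), attrs in medication_records(rows).items():
    print(name, begin, attrs)
```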


## Clinical RE

### The set of relations defined in the 2010 i2b2 relation challenge

TrIP: A certain treatment has improved or cured a medical problem (eg, ‘infection resolved with antibiotic course’)

TrWP: A patient's medical problem has deteriorated or worsened because of or in spite of a treatment being administered (eg, ‘the tumor was growing despite the drain’)

TrCP: A treatment caused a medical problem (eg, ‘penicillin causes a rash’)

TrAP: A treatment administered for a medical problem (eg, ‘Dexamphetamine for narcolepsy’)

TrNAP: The administration of a treatment was avoided because of a medical problem (eg, ‘Ralafen which is contra-indicated because of ulcers’)

TeRP: A test has revealed some medical problem (eg, ‘an echocardiogram revealed a pericardial effusion’)

TeCP: A test was performed to investigate a medical problem (eg, ‘chest x-ray done to rule out pneumonia’)

PIP: Two problems are related to each other (eg, ‘Azotemia presumed secondary to sepsis’)

In [None]:
clinical_ner_tagger = sparknlp.annotators.NerDLModel()\
    .pretrained("ner_clinical", "en", "clinical/models")\
    .setInputCols("sentences", "tokens", "embeddings")\
    .setOutputCol("ner_tags")    

clinical_re_Model = RelationExtractionModel()\
    .pretrained("re_clinical", "en", 'clinical/models')\
    .setInputCols(["embeddings", "pos_tags", "ner_chunks", "dependencies"])\
    .setOutputCol("relations")\
    .setMaxSyntacticDistance(4)\
    .setRelationPairs(["problem-test", "problem-treatment"]) # we can set the possible relation pairs (if not set, all the relations will be calculated)

loaded_pipeline = Pipeline(stages=[
    documenter,
    sentencer,
    tokenizer, 
    words_embedder, 
    pos_tagger, 
    clinical_ner_tagger,
    ner_chunker,
    dependency_parser,
    clinical_re_Model
])

ner_clinical download started this may take some time.
Approximate size to download 13.8 MB
[OK!]
re_clinical download started this may take some time.
Approximate size to download 6 MB
[OK!]


In [None]:
loaded_model = loaded_pipeline.fit(empty_data)
loaded_lmodel = LightPipeline(loaded_model)

In [None]:
text ="""A 28-year-old female with a history of gestational diabetes mellitus diagnosed eight years prior to presentation and subsequent type two diabetes mellitus ( T2DM ), 
one prior episode of HTG-induced pancreatitis three years prior to presentation,  associated with an acute hepatitis , and obesity with a body mass index ( BMI ) of 33.5 kg/m2 , presented with a one-week history of polyuria , polydipsia , poor appetite , and vomiting . Two weeks prior to presentation , she was treated with a five-day course of amoxicillin for a respiratory tract infection . She was on metformin , glipizide , and dapagliflozin for T2DM and atorvastatin and gemfibrozil for HTG . She had been on dapagliflozin for six months at the time of presentation. Physical examination on presentation was significant for dry oral mucosa ; significantly , her abdominal examination was benign with no tenderness , guarding , or rigidity . Pertinent laboratory findings on admission were : serum glucose 111 mg/dl , bicarbonate 18 mmol/l , anion gap 20 , creatinine 0.4 mg/dL , triglycerides 508 mg/dL , total cholesterol 122 mg/dL , glycated hemoglobin ( HbA1c ) 10% , and venous pH 7.27 . Serum lipase was normal at 43 U/L . Serum acetone levels could not be assessed as blood samples kept hemolyzing due to significant lipemia . The patient was initially admitted for starvation ketosis , as she reported poor oral intake for three days prior to admission . However , serum chemistry obtained six hours after presentation revealed her glucose was 186 mg/dL , the anion gap was still elevated at 21 , serum bicarbonate was 16 mmol/L , triglyceride level peaked at 2050 mg/dL , and lipase was 52 U/L . The β-hydroxybutyrate level was obtained and found to be elevated at 5.29 mmol/L - the original sample was centrifuged and the chylomicron layer removed prior to analysis due to interference from turbidity caused by lipemia again . The patient was treated with an insulin drip for euDKA and HTG with a reduction in the anion gap to 13 and triglycerides to 1400 mg/dL , within 24 hours . 
Her euDKA was thought to be precipitated by her respiratory tract infection in the setting of SGLT2 inhibitor use . The patient was seen by the endocrinology service and she was discharged on 40 units of insulin glargine at night , 12 units of insulin lispro with meals , and metformin 1000 mg two times a day . It was determined that all SGLT2 inhibitors should be discontinued indefinitely . 
She had close follow-up with endocrinology post discharge .
"""

annotations = loaded_lmodel.fullAnnotate(text)

rel_df = get_relations_df (annotations)

rel_df[rel_df.relation!="O"]


|  | relation | entity1 | entity1_begin | entity1_end | chunk1 | entity2 | entity2_begin | entity2_end | chunk2 | confidence |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 6 | TrAP | PROBLEM | 617 | 620 | T2DM | TREATMENT | 626 | 637 | atorvastatin | 0.99955326 |
| 13 | TrWP | TEST | 1246 | 1258 | blood samples | PROBLEM | 1283 | 1301 | significant lipemia | 0.99998724 |
| 17 | TeRP | TEST | 1535 | 1547 | the anion gap | PROBLEM | 1553 | 1566 | still elevated | 0.9965193 |
| 23 | TrAP | TEST | 1838 | 1845 | analysis | PROBLEM | 1854 | 1880 | interference from turbidity | 0.9676019 |
| 27 | TrWP | TREATMENT | 1937 | 1951 | an insulin drip | PROBLEM | 1976 | 2003 | a reduction in the anion gap | 0.94099987 |
| 30 | TeRP | PROBLEM | 1976 | 2003 | a reduction in the anion gap | TEST | 2015 | 2027 | triglycerides | 0.9956793 |
| 31 | TeRP | PROBLEM | 2107 | 2137 | her respiratory tract infection | TREATMENT | 2157 | 2171 | SGLT2 inhibitor | 0.997498 |
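Some rows above pair a relation code with entity types that do not match its definition, for example TrWP (a treatment relation) between a TEST and a PROBLEM. A post-hoc sanity check can flag such rows. The sketch below derives the expected argument types from the i2b2 definitions listed earlier (Tr codes link a treatment and a problem, Te codes a test and a problem, PIP two problems). `is_type_consistent` is a hypothetical helper, not a Spark NLP feature.

```python
# Argument types implied by each i2b2 2010 relation code
EXPECTED_TYPES = {
    "TrIP": {"TREATMENT", "PROBLEM"},
    "TrWP": {"TREATMENT", "PROBLEM"},
    "TrCP": {"TREATMENT", "PROBLEM"},
    "TrAP": {"TREATMENT", "PROBLEM"},
    "TrNAP": {"TREATMENT", "PROBLEM"},
    "TeRP": {"TEST", "PROBLEM"},
    "TeCP": {"TEST", "PROBLEM"},
    "PIP": {"PROBLEM"},
}


def is_type_consistent(relation, entity1, entity2):
    """True when the two entity types match the relation's definition."""
    return {entity1, entity2} == EXPECTED_TYPES.get(relation)


# (index, relation, entity1, entity2) copied from the output above
predicted = [
    (6, "TrAP", "PROBLEM", "TREATMENT"),
    (13, "TrWP", "TEST", "PROBLEM"),
    (17, "TeRP", "TEST", "PROBLEM"),
    (23, "TrAP", "TEST", "PROBLEM"),
    (27, "TrWP", "TREATMENT", "PROBLEM"),
    (30, "TeRP", "PROBLEM", "TEST"),
    (31, "TeRP", "PROBLEM", "TREATMENT"),
]

flagged = [idx for idx, rel, e1, e2 in predicted if not is_type_consistent(rel, e1, e2)]
print(flagged)
```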


## Train a Relation Extraction Model

In [None]:

data = spark.read.option("header","true").format("csv").load("i2b2_clinical_relfeatures.csv")

data.show(10)

+-------+-------------+--------------------+--------------------+-------+--------------------+------+----+-----+--------------------+------+----+---------+-------+------------+-------------+------------+-------------+-------------+-------------+-------------+
|dataset|       source|            txt_file|            sentence|sent_id|              chunk1|begin1|end1|  rel|              chunk2|begin2|end2|   label1| label2|lastCharEnt1|firstCharEnt1|lastCharEnt2|firstCharEnt2|words_in_ent1|words_in_ent2|words_between|
+-------+-------------+--------------------+--------------------+-------+--------------------+------+----+-----+--------------------+------+----+---------+-------+------------+-------------+------------+-------------+-------------+-------------+-------------+
|   test|beth+partners|i2b2 2010 VA/test...|# BRBPR -- The pa...|    100|               brbpr|     1|   1|    O|  his lower gi bleed|    12|  15|  problem|problem|           6|            2|          77|           60|   

In [None]:
#Annotation structure
annotationType = T.StructType([
            T.StructField('annotatorType', T.StringType(), False),
            T.StructField('begin', T.IntegerType(), False),
            T.StructField('end', T.IntegerType(), False),
            T.StructField('result', T.StringType(), False),
            T.StructField('metadata', T.MapType(T.StringType(), T.StringType()), False),
            T.StructField('embeddings', T.ArrayType(T.FloatType()), False)
        ])

# UDF that converts the training data columns into named-entity chunk annotations

@F.udf(T.ArrayType(annotationType))
def createTrainAnnotations(begin1, end1, begin2, end2, chunk1, chunk2, label1, label2):
    
    entity1 = sparknlp.annotation.Annotation("chunk", begin1, end1, chunk1, {'entity': label1.upper(), 'sentence': '0'}, [])
    entity2 = sparknlp.annotation.Annotation("chunk", begin2, end2, chunk2, {'entity': label2.upper(), 'sentence': '0'}, [])    
        
    entity1.annotatorType = "chunk"
    entity2.annotatorType = "chunk"

    return [entity1, entity2]    

#list of valid relations
rels = ["TrIP", "TrAP", "TeCP", "TrNAP", "TrCP", "PIP", "TrWP", "TeRP"]

#a query to select list of valid relations
valid_rel_query = "(" + " OR ".join(["rel = '{}'".format(rel) for rel in rels]) + ")"

data = data\
  .withColumn("begin1i", F.expr("cast(firstCharEnt1 AS Int)"))\
  .withColumn("end1i", F.expr("cast(lastCharEnt1 AS Int)"))\
  .withColumn("begin2i", F.expr("cast(firstCharEnt2 AS Int)"))\
  .withColumn("end2i", F.expr("cast(lastCharEnt2 AS Int)"))\
  .where("begin1i IS NOT NULL")\
  .where("end1i IS NOT NULL")\
  .where("begin2i IS NOT NULL")\
  .where("end2i IS NOT NULL")\
  .where(valid_rel_query)\
  .withColumn(
      "train_ner_chunks", 
      createTrainAnnotations(
          "begin1i", "end1i", "begin2i", "end2i", "chunk1", "chunk2", "label1", "label2"
      ).alias("train_ner_chunks", metadata={'annotatorType': "chunk"}))
    
train_data = data.where("dataset='train'")
test_data = data.where("dataset='test'")
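For reference, the filter string built in the cell above is a chain of equality tests joined by OR. The snippet below rebuilds it in plain Python so it can be inspected without a Spark session, and shows an IN predicate that selects the same rows. The `in_query` variable is an illustrative alternative, not something the notebook uses.

```python
rels = ["TrIP", "TrAP", "TeCP", "TrNAP", "TrCP", "PIP", "TrWP", "TeRP"]

# Same construction as in the cell above
valid_rel_query = "(" + " OR ".join(["rel = '{}'".format(rel) for rel in rels]) + ")"
print(valid_rel_query)

# A shorter SQL predicate that selects the same rows
in_query = "rel IN ({})".format(", ".join("'{}'".format(rel) for rel in rels))
print(in_query)
```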

In [None]:
documenter = sparknlp.DocumentAssembler()\
    .setInputCol("sentence")\
    .setOutputCol("document")

sentencer = SentenceDetector()\
    .setInputCols(["document"])\
    .setOutputCol("sentences")

tokenizer = sparknlp.annotators.Tokenizer()\
    .setInputCols(["sentences"])\
    .setOutputCol("tokens")

words_embedder = WordEmbeddingsModel()\
    .pretrained("embeddings_clinical", "en", "clinical/models")\
    .setInputCols(["sentences", "tokens"])\
    .setOutputCol("embeddings")

pos_tagger = PerceptronModel()\
    .pretrained("pos_clinical", "en", "clinical/models") \
    .setInputCols(["sentences", "tokens"])\
    .setOutputCol("pos_tags")
    
dependency_parser = sparknlp.annotators.DependencyParserModel()\
    .pretrained("dependency_conllu", "en")\
    .setInputCols(["document", "pos_tags", "tokens"])\
    .setOutputCol("dependencies")

# set training params and upload model graph (see ../Healthcare/8.Generic_Classifier.ipynb)
reApproach = sparknlp_jsl.annotator.RelationExtractionApproach()\
    .setInputCols(["embeddings", "pos_tags", "train_ner_chunks", "dependencies"])\
    .setOutputCol("relations")\
    .setLabelColumn("rel")\
    .setEpochsNumber(50)\
    .setBatchSize(200)\
    .setLearningRate(0.001)\
    .setModelFile("/content/RE.in1200D.out20.pb")\
    .setFixImbalance(True)\
    .setValidationSplit(0.2)\
    .setFromEntity("begin1i", "end1i", "label1")\
    .setToEntity("begin2i", "end2i", "label2")
    
finisher = sparknlp.Finisher()\
    .setInputCols(["relations"])\
    .setOutputCols(["relations_out"])\
    .setCleanAnnotations(False)\
    .setValueSplitSymbol(",")\
    .setAnnotationSplitSymbol(",")\
    .setOutputAsArray(False)

train_pipeline = Pipeline(stages=[
    documenter, sentencer, tokenizer, words_embedder, pos_tagger, 
    dependency_parser, reApproach, finisher
])

embeddings_clinical download started this may take some time.
Approximate size to download 1.6 GB
[OK!]
pos_clinical download started this may take some time.
Approximate size to download 1.7 MB
[OK!]
dependency_conllu download started this may take some time.
Approximate size to download 16.6 MB
[OK!]


In [None]:
rel_model = train_pipeline.fit(train_data)

In [None]:
rel_model.stages[-2]

RelationExtractionModel_42c3f7882f9d

In [None]:
rel_model.stages[-2].write().overwrite().save('custom_RE_model')

In [None]:
result = rel_model.transform(test_data)

In [None]:
# Recall per gold relation: share of gold rows whose prediction matches.
# .show() prints the table and returns None, so nothing is assigned.
result\
  .groupBy("rel")\
  .agg(F.avg(F.expr("IF(rel = relations_out, 1, 0)")).alias("recall"))\
  .select(
      F.col("rel").alias("relation"), 
      F.format_number("recall", 2).alias("recall"))\
  .show()

# Precision per predicted relation: share of non-empty predictions that are correct.
result\
  .where("relations_out <> ''")\
  .groupBy("relations_out")\
  .agg(F.avg(F.expr("IF(rel = relations_out, 1, 0)")).alias("precision"))\
  .select(
      F.col("relations_out").alias("relation"), 
      F.format_number("precision", 2).alias("precision"))\
  .show()

+--------+------+
|relation|recall|
+--------+------+
|    TrIP|  0.20|
|    TrAP|  0.93|
|    TeCP|  0.53|
|   TrNAP|  0.14|
|    TrCP|  0.50|
|     PIP|  0.96|
|    TrWP|  0.03|
|    TeRP|  0.92|
+--------+------+

+--------+---------+
|relation|precision|
+--------+---------+
|    TrIP|     0.61|
|    TrAP|     0.79|
|    TeCP|     0.59|
|   TrNAP|     0.41|
|    TrCP|     0.62|
|     PIP|     0.98|
|    TrWP|     0.18|
|    TeRP|     0.90|
+--------+---------+
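The two aggregations above differ only in what they group by: recall groups by the gold label, precision by the predicted label (ignoring empty predictions). The sketch below reproduces the same logic in plain Python on a small made-up list of (gold, predicted) pairs. The data and the `per_label_scores` helper are hypothetical.

```python
from collections import Counter


def per_label_scores(pairs):
    """pairs: iterable of (gold_label, predicted_label); '' means no prediction."""
    pairs = list(pairs)
    correct = Counter(gold for gold, pred in pairs if gold == pred)
    gold_total = Counter(gold for gold, _ in pairs)
    pred_total = Counter(pred for _, pred in pairs if pred != "")
    # Recall: correct predictions divided by rows carrying that gold label.
    recall = {label: correct[label] / n for label, n in gold_total.items()}
    # Precision: correct predictions divided by rows predicted as that label.
    precision = {label: correct[label] / n for label, n in pred_total.items()}
    return recall, precision


# Made-up evaluation pairs, for illustration only
toy = [("TrAP", "TrAP"), ("TrAP", "TrAP"), ("TrAP", "PIP"), ("PIP", "PIP"), ("TrWP", "")]
recall, precision = per_label_scores(toy)
print(recall)
print(precision)
```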



In [None]:
result_df = result.select(F.explode(F.arrays_zip('relations.result', 'relations.metadata')).alias("cols")) \
.select(F.expr("cols['0']").alias("relation"),
        F.expr("cols['1']['entity1']").alias("entity1"),
        F.expr("cols['1']['entity1_begin']").alias("entity1_begin"),
        F.expr("cols['1']['entity1_end']").alias("entity1_end"),
        F.expr("cols['1']['chunk1']").alias("chunk1"),
        F.expr("cols['1']['entity2']").alias("entity2"),
        F.expr("cols['1']['entity2_begin']").alias("entity2_begin"),
        F.expr("cols['1']['entity2_end']").alias("entity2_end"),
        F.expr("cols['1']['chunk2']").alias("chunk2"),
        F.expr("cols['1']['confidence']").alias("confidence")
        )

result_df.show(50, truncate=100)



+--------+---------+-------------+-----------+----------------------------------+-------+-------------+-----------+-------------------------------------------------------------------+----------+
|relation|  entity1|entity1_begin|entity1_end|                            chunk1|entity2|entity2_begin|entity2_end|                                                             chunk2|confidence|
+--------+---------+-------------+-----------+----------------------------------+-------+-------------+-----------+-------------------------------------------------------------------+----------+
|    TeRP|     TEST|            1|         14|                    an angiography|PROBLEM|           22|         44|                                            bleeding in two vessels|       1.0|
|   TrNAP|TREATMENT|            1|         12|                      his coumadin|PROBLEM|           44|         58|                                                    his acute bleed| 0.5183256|
|    TeCP|     TEST|     

## Load trained model from disk

In [None]:
loaded_re_Model = RelationExtractionModel() \
    .load("custom_RE_model")\
    .setInputCols(["embeddings", "pos_tags", "ner_chunks", "dependencies"]) \
    .setOutputCol("relations")\
    .setRelationPairs(["problem-test", "problem-treatment"])\
    .setPredictionThreshold(0.9)\
    .setMaxSyntacticDistance(4)

trained_pipeline = Pipeline(stages=[
    documenter,
    sentencer,
    tokenizer, 
    words_embedder, 
    pos_tagger, 
    clinical_ner_tagger,
    ner_chunker,
    dependency_parser,
    loaded_re_Model
])

empty_data = spark.createDataFrame([[""]]).toDF("sentence")

loaded_re_model = trained_pipeline.fit(empty_data)


text ="""A 28-year-old female with a history of gestational diabetes mellitus diagnosed eight years prior to presentation and subsequent type two diabetes mellitus ( T2DM ), 
one prior episode of HTG-induced pancreatitis three years prior to presentation,  associated with an acute hepatitis , and obesity with a body mass index ( BMI ) of 33.5 kg/m2 , presented with a one-week history of polyuria , polydipsia , poor appetite , and vomiting . Two weeks prior to presentation , she was treated with a five-day course of amoxicillin for a respiratory tract infection . She was on metformin , glipizide , and dapagliflozin for T2DM and atorvastatin and gemfibrozil for HTG . She had been on dapagliflozin for six months at the time of presentation. Physical examination on presentation was significant for dry oral mucosa ; significantly , her abdominal examination was benign with no tenderness , guarding , or rigidity . Pertinent laboratory findings on admission were : serum glucose 111 mg/dl , bicarbonate 18 mmol/l , anion gap 20 , creatinine 0.4 mg/dL , triglycerides 508 mg/dL , total cholesterol 122 mg/dL , glycated hemoglobin ( HbA1c ) 10% , and venous pH 7.27 . Serum lipase was normal at 43 U/L . Serum acetone levels could not be assessed as blood samples kept hemolyzing due to significant lipemia . The patient was initially admitted for starvation ketosis , as she reported poor oral intake for three days prior to admission . However , serum chemistry obtained six hours after presentation revealed her glucose was 186 mg/dL , the anion gap was still elevated at 21 , serum bicarbonate was 16 mmol/L , triglyceride level peaked at 2050 mg/dL , and lipase was 52 U/L . The β-hydroxybutyrate level was obtained and found to be elevated at 5.29 mmol/L - the original sample was centrifuged and the chylomicron layer removed prior to analysis due to interference from turbidity caused by lipemia again . The patient was treated with an insulin drip for euDKA and HTG with a reduction in the anion gap to 13 and triglycerides to 1400 mg/dL , within 24 hours . 
Her euDKA was thought to be precipitated by her respiratory tract infection in the setting of SGLT2 inhibitor use . The patient was seen by the endocrinology service and she was discharged on 40 units of insulin glargine at night , 12 units of insulin lispro with meals , and metformin 1000 mg two times a day . It was determined that all SGLT2 inhibitors should be discontinued indefinitely . 
She had close follow-up with endocrinology post discharge .
"""

loaded_re_model_light = LightPipeline(loaded_re_model)

annotations = loaded_re_model_light.fullAnnotate(text)

rel_df = get_relations_df (annotations)

rel_df[rel_df.relation!="O"]

|  | relation | entity1 | entity1_begin | entity1_end | chunk1 | entity2 | entity2_begin | entity2_end | chunk2 | confidence |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | TrAP | TREATMENT | 512 | 522 | amoxicillin | PROBLEM | 528 | 556 | a respiratory tract infection | 0.99998116 |
| 1 | TrAP | TREATMENT | 571 | 579 | metformin | PROBLEM | 617 | 620 | T2DM | 0.999838 |
| 2 | TrAP | TREATMENT | 571 | 579 | metformin | PROBLEM | 659 | 661 | HTG | 0.99375445 |
| 3 | TrAP | TREATMENT | 583 | 591 | glipizide | PROBLEM | 617 | 620 | T2DM | 0.9999962 |
| 4 | TrAP | TREATMENT | 599 | 611 | dapagliflozin | PROBLEM | 617 | 620 | T2DM | 0.999997 |
| 5 | TrAP | TREATMENT | 599 | 611 | dapagliflozin | PROBLEM | 659 | 661 | HTG | 0.93626064 |
| 6 | TrAP | PROBLEM | 617 | 620 | T2DM | TREATMENT | 643 | 653 | gemfibrozil | 0.9410204 |
| 7 | TrAP | TREATMENT | 626 | 637 | atorvastatin | PROBLEM | 659 | 661 | HTG | 0.99999654 |
| 8 | TrAP | TREATMENT | 643 | 653 | gemfibrozil | PROBLEM | 659 | 661 | HTG | 0.99999964 |
| 9 | TeRP | TEST | 830 | 854 | her abdominal examination | PROBLEM | 875 | 884 | tenderness | 0.99999094 |
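All confidences above are at least 0.9, in line with the threshold set on the loaded model. The same kind of cut can be applied after the fact to relations collected elsewhere. The sketch below assumes confidence arrives as a string, since annotation metadata is declared as a string-to-string map in the training cell. `above_threshold` is a hypothetical helper and is not how the library applies its own threshold internally.

```python
def above_threshold(relations, threshold=0.9):
    """Keep relation dicts whose confidence, parsed as float, is >= threshold."""
    return [r for r in relations if float(r["confidence"]) >= threshold]


# Two records copied from outputs above (confidence kept as string)
relations = [
    {"relation": "TrAP", "chunk1": "amoxicillin",
     "chunk2": "a respiratory tract infection", "confidence": "0.99998116"},
    {"relation": "TrNAP", "chunk1": "his coumadin",
     "chunk2": "his acute bleed", "confidence": "0.5183256"},
]

kept = above_threshold(relations, threshold=0.9)
print([r["relation"] for r in kept])
```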
