![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/1.5.Resume_MedicalNer_Model_Training.ipynb)

# 1.5 Resume MedicalNer Model Training

Steps:
- Train a new model for a few epochs.
- Load the same model and train for more epochs, and check stats.
- Take a model already trained on a different dataset and train it on new data.

## Colab Setup

In [1]:
import json

from google.colab import files

license_keys = files.upload()

with open('jsl_keys.json') as f:
    license_keys = json.load(f)

Saving jsl_keys.json to jsl_keys.json


In [2]:

%%capture
for k,v in license_keys.items(): 
    %set_env $k=$v

! wget https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/jsl_colab_setup.sh
! bash jsl_colab_setup.sh -p 3.0.2 -s 3.1.2rc1

! pip install spark-nlp-display

# for Spark 2.4.x and Spark NLP 2.x, do the following
# !wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/colab_setup.sh
# !bash colab_setup.sh -p 2.4.x -s 2.7.x
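Outside a notebook the `%set_env` magic is not available. A minimal sketch of the same step in plain Python follows; it assumes the license file is a flat JSON object of key/value pairs, and the helper name `export_license_keys` is hypothetical.

```python
import json
import os


def export_license_keys(path):
    """Load a flat JSON file of license keys and export each entry as an environment variable."""
    with open(path) as f:
        keys = json.load(f)
    for name, value in keys.items():
        # Environment variables must be strings, so cast defensively.
        os.environ[name] = str(value)
    return keys
```

Calling `export_license_keys('jsl_keys.json')` before starting the session would then play the role of the loop over `license_keys.items()` above.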

In [3]:
# Optional: start the session manually with custom params instead of sparknlp_jsl.start().
# `version` (Spark NLP) and `jsl_version` (Spark NLP JSL) must be defined before calling start().
from pyspark.sql import SparkSession

def start(secret):
    builder = SparkSession.builder \
        .appName("Spark NLP Licensed") \
        .master("local[*]") \
        .config("spark.driver.memory", "16G") \
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
        .config("spark.kryoserializer.buffer.max", "2000M") \
        .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:"+version) \
        .config("spark.jars", "https://pypi.johnsnowlabs.com/"+secret+"/spark-nlp-jsl-"+jsl_version+".jar")

    return builder.getOrCreate()

#spark = start(secret)

In [4]:
import json
import os
from pyspark.ml import Pipeline,PipelineModel
from pyspark.sql import SparkSession

from sparknlp.annotator import *
from sparknlp_jsl.annotator import *
from sparknlp.base import *
import sparknlp_jsl
import sparknlp

params = {"spark.driver.memory":"16G",
"spark.kryoserializer.buffer.max":"2000M",
"spark.driver.maxResultSize":"2000M"}

spark = sparknlp_jsl.start(license_keys['SECRET'],params=params)

print ("Spark NLP Version :", sparknlp.version())
print ("Spark NLP_JSL Version :", sparknlp_jsl.version())

Spark NLP Version : 3.1.2
Spark NLP_JSL Version : 3.1.2


## Download Clinical Word Embeddings for training

In [5]:
clinical_embeddings = WordEmbeddingsModel.pretrained('embeddings_clinical', "en", "clinical/models")\
    .setInputCols(["sentence", "token"])\
    .setOutputCol("embeddings")

embeddings_clinical download started this may take some time.
Approximate size to download 1.6 GB
[OK!]


## Download Data for Training (NCBI Disease Dataset)

In [6]:
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/NCBI_disease_official_test.conll
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/NCBI_disease_official_train_dev.conll

In [7]:
from sparknlp.training import CoNLL

training_data = CoNLL().readDataset(spark, 'NCBI_disease_official_train_dev.conll')

training_data.show(3)

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|                text|            document|            sentence|               token|                 pos|               label|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|Identification of...|[[document, 0, 89...|[[document, 0, 89...|[[token, 0, 13, I...|[[pos, 0, 13, NN,...|[[named_entity, 0...|
|The adenomatous p...|[[document, 0, 21...|[[document, 0, 21...|[[token, 0, 2, Th...|[[pos, 0, 2, NN, ...|[[named_entity, 0...|
|Complex formation...|[[document, 0, 63...|[[document, 0, 63...|[[token, 0, 6, Co...|[[pos, 0, 6, NN, ...|[[named_entity, 0...|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
only showing top 3 rows
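To get a feel for what `CoNLL().readDataset` consumes, here is a rough pure-Python sketch that reads CoNLL-style text and counts the labels. The sample text is invented for illustration and is not taken from the NCBI file; the sketch assumes whitespace-separated columns with the token first and the NER label last, and `read_conll` is a hypothetical helper, not part of Spark NLP.

```python
from collections import Counter

# Illustrative sample only: token, POS, chunk, NER label per line.
SAMPLE = """\
-DOCSTART- -X- -X- O

Identification NN NN O
of NN NN O
APC2 NN NN O

colon NN NN B-Disease
cancer NN NN I-Disease
. NN NN O
"""


def read_conll(text):
    """Split CoNLL-style text into sentences of (token, label) pairs."""
    sentences, current = [], []
    for line in text.splitlines():
        line = line.strip()
        # Blank lines and document markers end the current sentence.
        if not line or line.startswith("-DOCSTART-"):
            if current:
                sentences.append(current)
                current = []
            continue
        parts = line.split()
        current.append((parts[0], parts[-1]))
    if current:
        sentences.append(current)
    return sentences


sentences = read_conll(SAMPLE)
label_counts = Counter(label for sent in sentences for _, label in sent)
print(len(sentences), dict(label_counts))
```

A label count like this is a quick way to see how many distinct tags the training graph has to accommodate.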



In [8]:
from sparknlp.training import CoNLL

test_data = CoNLL().readDataset(spark, 'NCBI_disease_official_test.conll')

test_data.show(3)

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|                text|            document|            sentence|               token|                 pos|               label|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|Clustering of mis...|[[document, 0, 10...|[[document, 0, 10...|[[token, 0, 9, Cl...|[[pos, 0, 9, NN, ...|[[named_entity, 0...|
|Ataxia - telangie...|[[document, 0, 13...|[[document, 0, 13...|[[token, 0, 5, At...|[[pos, 0, 5, NN, ...|[[named_entity, 0...|
|The risk of cance...|[[document, 0, 15...|[[document, 0, 15...|[[token, 0, 2, Th...|[[pos, 0, 2, NN, ...|[[named_entity, 0...|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
only showing top 3 rows



In [9]:
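# save the test data (with embeddings) as parquet so the trainer can evaluate on it after each epoch
# mode('overwrite') lets this cell be re-run without failing on an existing path
clinical_embeddings.transform(test_data).write.mode('overwrite').parquet('test.parquet')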

## Train a new model, pause, and resume training on the same dataset.

### Create a graph

In [None]:
%tensorflow_version 1.x
from sparknlp_jsl.training import tf_graph

tf_graph.print_model_params("ner_dl")

tf_graph.build("ner_dl", build_params={"embeddings_dim": 200, "nchars": 128, "ntags": 12, "is_medical": 1}, model_location="./medical_ner_graphs", model_filename="auto")
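The build parameters need to be large enough for the dataset. The sketch below encodes the rule as I understand it, which is an assumption rather than documented behaviour: the graph's `ntags` and `nchars` must be at least the number of labels and characters in the data, and `embeddings_dim` must match the embeddings exactly. The helper `graph_covers_dataset` is hypothetical; the values 3 and 84 are the `labels` and `chars` figures reported in the training log of this notebook.

```python
def graph_covers_dataset(build_params, n_labels, n_chars, embeddings_dim):
    """Return True if a graph built with build_params can (presumably) fit the dataset."""
    return (
        build_params["ntags"] >= n_labels
        and build_params["nchars"] >= n_chars
        and build_params["embeddings_dim"] == embeddings_dim
    )


build_params = {"embeddings_dim": 200, "nchars": 128, "ntags": 12, "is_medical": 1}

# 3 labels and 84 characters come from the training log; 200 is the embeddings_dim used above.
print(graph_covers_dataset(build_params, n_labels=3, n_chars=84, embeddings_dim=200))
```

Building with some headroom (12 tags for 3 labels) means the same graph file can be reused if the label set grows slightly.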


### Train for 2 epochs

In [11]:
nerTagger = MedicalNerApproach()\
      .setInputCols(["sentence", "token", "embeddings"])\
      .setLabelColumn("label")\
      .setOutputCol("ner")\
      .setMaxEpochs(2)\
      .setLr(0.003)\
      .setBatchSize(8)\
      .setRandomSeed(0)\
      .setVerbose(1)\
      .setEvaluationLogExtended(True) \
      .setEnableOutputLogs(True)\
      .setIncludeConfidence(True)\
      .setTestDataset('./test.parquet')\
      .setGraphFolder('./medical_ner_graphs')\
      .setOutputLogsPath('./ner_logs')

ner_pipeline = Pipeline(stages=[
      clinical_embeddings,
      nerTagger
 ])

In [12]:

%%time
ner_model = ner_pipeline.fit(training_data)


CPU times: user 2.78 s, sys: 280 ms, total: 3.06 s
Wall time: 8min 52s


In [27]:
! cat ner_logs/MedicalNerApproach_372ae93d46cc.log

Name of the selected graph: /content/./medical_ner_graphs/blstm_12_200_128_128.pb
Training started - total epochs: 2 - lr: 0.003 - batch size: 8 - labels: 3 - chars: 84 - training examples: 6347


Epoch 1/2 started, lr: 0.003, dataset size: 6347


Epoch 1/2 - 214.52s - loss: 1841.0598 - batches: 795
Quality on test dataset: 
time to finish evaluation: 14.19s
label	 tp	 fp	 fn	 prec	 rec	 f1
I-Disease	 960	 317	 127	 0.7517619	 0.88316464	 0.81218266
B-Disease	 771	 181	 189	 0.80987394	 0.803125	 0.80648535
tp: 1731 fp: 498 fn: 316 labels: 2
Macro-average	 prec: 0.7808179, rec: 0.84314483, f1: 0.81078535
Micro-average	 prec: 0.7765814, rec: 0.8456277, f1: 0.80963516


Epoch 2/2 started, lr: 0.0029850747, dataset size: 6347


Epoch 2/2 - 242.99s - loss: 754.204 - batches: 795
Quality on test dataset: 
time to finish evaluation: 12.79s
label	 tp	 fp	 fn	 prec	 rec	 f1
I-Disease	 953	 152	 134	 0.86244345	 0.87672496	 0.86952555
B-Disease	 852	 132	 108	 0.86585367	 0.8875	 0.8765433
tp: 
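The per-label and averaged figures in the log can be reproduced from the `tp`, `fp` and `fn` counts alone. The sketch below uses the epoch 1 counts shown above. Note that the macro F1 is computed here as the harmonic mean of macro precision and macro recall; that choice is inferred from the logged numbers, not from the library's source.

```python
def prf(tp, fp, fn):
    """Precision, recall and F1 from raw counts, guarding against division by zero."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# (tp, fp, fn) per label, copied from the epoch 1 block of the log.
counts = {"I-Disease": (960, 317, 127), "B-Disease": (771, 181, 189)}

per_label = {label: prf(*c) for label, c in counts.items()}

# Macro: average precision and recall over labels, then combine them.
macro_p = sum(p for p, _, _ in per_label.values()) / len(per_label)
macro_r = sum(r for _, r, _ in per_label.values()) / len(per_label)
macro_f1 = 2 * macro_p * macro_r / (macro_p + macro_r)

# Micro: pool the counts over labels, then compute the metrics once.
totals = [sum(c[i] for c in counts.values()) for i in range(3)]
micro_p, micro_r, micro_f1 = prf(*totals)

print(f"macro f1: {macro_f1:.4f}  micro f1: {micro_f1:.4f}")
```

These are token-level figures over the `B-` and `I-` tags; the chunk-level evaluation in the next section counts whole entities instead.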

### Evaluate

In [14]:
from sparknlp_jsl.eval import NerDLMetrics
import pyspark.sql.functions as F

pred_df = ner_model.stages[1].transform(clinical_embeddings.transform(test_data))

evaler = NerDLMetrics(mode="full_chunk", dropO=True)

eval_result = evaler.computeMetricsFromDF(pred_df.select("label","ner"), prediction_col="ner", label_col="label").cache()

eval_result.withColumn("precision", F.round(eval_result["precision"],4))\
    .withColumn("recall", F.round(eval_result["recall"],4))\
    .withColumn("f1", F.round(eval_result["f1"],4)).show(100)

# show() prints the table itself and returns None, so it is not wrapped in print()
eval_result.selectExpr("avg(f1) as macro").show()
eval_result.selectExpr("sum(f1*total) as sumprod","sum(total) as sumtotal").selectExpr("sumprod/sumtotal as micro").show()

+-------+-----+-----+-----+-----+---------+------+------+
| entity|   tp|   fp|   fn|total|precision|recall|    f1|
+-------+-----+-----+-----+-----+---------+------+------+
|Disease|832.0|152.0|123.0|955.0|   0.8455|0.8712|0.8582|
+-------+-----+-----+-----+-----+---------+------+------+

+------------------+
|             macro|
+------------------+
|0.8581743166580712|
+------------------+

+------------------+
|             micro|
+------------------+
|0.8581743166580712|
+------------------+
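The evaluation above counts whole entities rather than individual tags. As a rough illustration of what exact-match chunk scoring involves, the sketch below turns BIO tag sequences into spans and compares them. It is not the `NerDLMetrics` implementation, and reading `full_chunk` as exact span matching is an assumption; the tag sequences are made up.

```python
def extract_chunks(tags):
    """Return a set of (start, end, entity) spans from a BIO sequence; end is exclusive."""
    chunks, start, entity = set(), None, None
    # A trailing "O" sentinel closes any chunk that runs to the end of the sequence.
    for i, tag in enumerate(list(tags) + ["O"]):
        continues = tag.startswith("I-") and entity == tag[2:]
        if continues:
            continue
        if entity is not None:
            chunks.add((start, i, entity))
        if tag.startswith("B-") or tag.startswith("I-"):
            # An I- tag without a matching open chunk is treated as a chunk start.
            start, entity = i, tag[2:]
        else:
            start, entity = None, None
    return chunks


def chunk_counts(gold_tags, pred_tags):
    """Count exact-match true positives, false positives and false negatives."""
    gold, pred = extract_chunks(gold_tags), extract_chunks(pred_tags)
    return len(gold & pred), len(pred - gold), len(gold - pred)


gold = ["O", "B-Disease", "I-Disease", "O", "B-Disease"]
pred = ["O", "B-Disease", "O", "O", "B-Disease"]
tp, fp, fn = chunk_counts(gold, pred)
print(tp, fp, fn)
```

In this toy example the truncated first entity counts as both a false positive and a false negative, which is why chunk-level scores are usually lower than token-level ones.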


### Save the model to disk

In [15]:
ner_model.stages[1].write().overwrite().save('models/NCBI_NER_2_epoch')

### Train using the saved model

In [16]:

nerTagger = MedicalNerApproach()\
      .setInputCols(["sentence", "token", "embeddings"])\
      .setLabelColumn("label")\
      .setOutputCol("ner")\
      .setMaxEpochs(5)\
      .setLr(0.003)\
      .setBatchSize(8)\
      .setRandomSeed(0)\
      .setVerbose(1)\
      .setEvaluationLogExtended(True) \
      .setEnableOutputLogs(True)\
      .setIncludeConfidence(True)\
      .setTestDataset('/content/test.parquet')\
      .setOutputLogsPath('ner_logs')\
      .setGraphFolder('medical_ner_graphs')\
      .setPretrainedModelPath("models/NCBI_NER_2_epoch") # load the existing model saved above
    
ner_pipeline = Pipeline(stages=[
      clinical_embeddings,
      nerTagger
 ])

In [17]:

%%time
ner_model_retrained = ner_pipeline.fit(training_data)


CPU times: user 6.38 s, sys: 632 ms, total: 7.02 s
Wall time: 20min 24s


In [25]:
!cat ./ner_logs/MedicalNerApproach_11757a341252.log

Name of the selected graph: pretrained graph
Training started - total epochs: 5 - lr: 0.003 - batch size: 8 - labels: 3 - chars: 84 - training examples: 6347


Epoch 1/5 started, lr: 0.003, dataset size: 6347


Epoch 1/5 - 219.99s - loss: 594.5702 - batches: 795
Quality on test dataset: 
time to finish evaluation: 13.01s
label	 tp	 fp	 fn	 prec	 rec	 f1
I-Disease	 1013	 274	 74	 0.7871018	 0.93192273	 0.853412
B-Disease	 844	 200	 116	 0.8084291	 0.87916666	 0.8423153
tp: 1857 fp: 474 fn: 190 labels: 2
Macro-average	 prec: 0.7977655, rec: 0.9055447, f1: 0.8482451
Micro-average	 prec: 0.7966538, rec: 0.90718126, f1: 0.84833264


Epoch 2/5 started, lr: 0.0029850747, dataset size: 6347


Epoch 2/5 - 223.11s - loss: 518.59485 - batches: 795
Quality on test dataset: 
time to finish evaluation: 12.49s
label	 tp	 fp	 fn	 prec	 rec	 f1
I-Disease	 919	 106	 168	 0.89658535	 0.84544617	 0.87026507
B-Disease	 830	 121	 130	 0.8727655	 0.8645833	 0.86865515
tp: 1749 fp: 227 fn: 298 labels: 2
Macro
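The log shows the learning rate dropping from 0.003 to 0.0029850747 between the first and second epoch. That is consistent with a decay of the form `lr / (1 + po * epoch)` with a coefficient of 0.005. The sketch below reproduces the schedule under that assumption; the coefficient is inferred from the logged values, and `decayed_lr` is a hypothetical helper.

```python
def decayed_lr(initial_lr, epoch_index, po=0.005):
    """Learning rate for a zero-based epoch index under inverse-time decay."""
    return initial_lr / (1.0 + po * epoch_index)


# Five epochs, matching the "total epochs: 5" line in the log above.
schedule = [decayed_lr(0.003, epoch) for epoch in range(5)]
for epoch, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch}: lr = {lr:.10f}")
```

Because the schedule restarts from the configured `lr` when training resumes, the resumed run begins again at 0.003, as the first epoch line of the log shows.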

In [19]:
from sparknlp_jsl.eval import NerDLMetrics
import pyspark.sql.functions as F

pred_df = ner_model_retrained.stages[1].transform(clinical_embeddings.transform(test_data))

evaler = NerDLMetrics(mode="full_chunk", dropO=True)

eval_result = evaler.computeMetricsFromDF(pred_df.select("label","ner"), prediction_col="ner", label_col="label").cache()

eval_result.withColumn("precision", F.round(eval_result["precision"],4))\
    .withColumn("recall", F.round(eval_result["recall"],4))\
    .withColumn("f1", F.round(eval_result["f1"],4)).show(100)

# show() prints the table itself and returns None, so it is not wrapped in print()
eval_result.selectExpr("avg(f1) as macro").show()
eval_result.selectExpr("sum(f1*total) as sumprod","sum(total) as sumtotal").selectExpr("sumprod/sumtotal as micro").show()

+-------+-----+-----+-----+-----+---------+------+------+
| entity|   tp|   fp|   fn|total|precision|recall|    f1|
+-------+-----+-----+-----+-----+---------+------+------+
|Disease|825.0|126.0|130.0|955.0|   0.8675|0.8639|0.8657|
+-------+-----+-----+-----+-----+---------+------+------+

+------------------+
|             macro|
+------------------+
|0.8656873032528857|
+------------------+

+------------------+
|             micro|
+------------------+
|0.8656873032528857|
+------------------+


## Now let's take a model trained on a different dataset and train using new data

In [20]:
jsl_ner = MedicalNerModel.pretrained('ner_jsl','en','clinical/models')

jsl_ner.getClasses()

ner_jsl download started this may take some time.
Approximate size to download 14.5 MB
[OK!]


['O',
 'B-Injury_or_Poisoning',
 'B-Direction',
 'B-Test',
 'I-Route',
 'B-Admission_Discharge',
 'B-Death_Entity',
 'I-Oxygen_Therapy',
 'I-Drug_BrandName',
 'B-Relationship_Status',
 'B-Duration',
 'I-Alcohol',
 'I-Triglycerides',
 'I-Date',
 'B-Respiration',
 'B-Hyperlipidemia',
 'I-Test',
 'B-Birth_Entity',
 'I-VS_Finding',
 'B-Age',
 'I-Social_History_Header',
 'B-Labour_Delivery',
 'I-Medical_Device',
 'B-Family_History_Header',
 'B-BMI',
 'I-Fetus_NewBorn',
 'I-BMI',
 'B-Temperature',
 'I-Section_Header',
 'I-Communicable_Disease',
 'I-ImagingFindings',
 'I-Psychological_Condition',
 'I-Obesity',
 'I-Sexually_Active_or_Sexual_Orientation',
 'I-Modifier',
 'B-Alcohol',
 'I-Temperature',
 'I-Vaccine',
 'I-Symptom',
 'B-Kidney_Disease',
 'I-Pulse',
 'B-Oncological',
 'I-EKG_Findings',
 'B-Medical_History_Header',
 'I-Relationship_Status',
 'I-Blood_Pressure',
 'B-Cerebrovascular_Disease',
 'I-Diabetes',
 'B-Oxygen_Therapy',
 'B-O2_Saturation',
 'B-Psychological_Condition',
 'B-Hear

### Now train a model using this model as base

In [21]:

nerTagger = MedicalNerApproach()\
      .setInputCols(["sentence", "token", "embeddings"])\
      .setLabelColumn("label")\
      .setOutputCol("ner")\
      .setMaxEpochs(5)\
      .setLr(0.003)\
      .setBatchSize(8)\
      .setRandomSeed(0)\
      .setVerbose(1)\
      .setEvaluationLogExtended(True) \
      .setEnableOutputLogs(True)\
      .setIncludeConfidence(True)\
      .setTestDataset('/content/test.parquet')\
      .setOutputLogsPath('ner_logs')\
      .setGraphFolder('medical_ner_graphs')\
      .setPretrainedModelPath("/root/cache_pretrained/ner_jsl_en_3.1.0_2.4_1624566960534")\
      .setOverrideExistingTags(True) # since the tags do not align, set this flag to true
    
# tune the hyperparameters above (max epochs, learning rate, dropout, etc.) to get better results
ner_pipeline = Pipeline(stages=[
      clinical_embeddings,
      nerTagger
 ])
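Before reusing a pretrained model it helps to see how far its tag set overlaps with the new dataset's labels. The sketch below compares the two sets in plain Python. The pretrained tags are only the first few entries of the `getClasses()` output above (the full list is truncated there), the dataset tags are the ones seen in the training logs plus `O`, and `tag_alignment` is a hypothetical helper that says nothing about how `setOverrideExistingTags` works internally.

```python
def tag_alignment(pretrained_tags, dataset_tags):
    """Report which dataset tags the pretrained model already knows and which it lacks."""
    pretrained, dataset = set(pretrained_tags), set(dataset_tags)
    return {
        "shared": sorted(pretrained & dataset),
        "missing": sorted(dataset - pretrained),
    }


# Excerpt of the ner_jsl classes listed above; not the complete tag set.
pretrained_excerpt = ["O", "B-Injury_or_Poisoning", "B-Direction", "B-Test", "I-Route"]
dataset_tags = ["O", "B-Disease", "I-Disease"]

report = tag_alignment(pretrained_excerpt, dataset_tags)
print(report)
```

With a live session, `jsl_ner.getClasses()` could be passed in place of the excerpt; a non-empty `missing` list indicates the tags do not align.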

In [22]:

%%time
ner_jsl_retrained = ner_pipeline.fit(training_data)


CPU times: user 9 s, sys: 854 ms, total: 9.85 s
Wall time: 29min 36s


In [24]:
!cat ./ner_logs/MedicalNerApproach_ea16f43dd897.log

Name of the selected graph: pretrained graph
Training started - total epochs: 5 - lr: 0.003 - batch size: 8 - labels: 3 - chars: 84 - training examples: 6347


Epoch 1/5 started, lr: 0.003, dataset size: 6347


Epoch 1/5 - 292.36s - loss: 2103.3435 - batches: 795
Quality on test dataset: 
time to finish evaluation: 68.17s
label	 tp	 fp	 fn	 prec	 rec	 f1
I-Disease	 853	 218	 234	 0.7964519	 0.7847286	 0.7905467
B-Disease	 759	 188	 201	 0.8014783	 0.790625	 0.79601467
tp: 1612 fp: 406 fn: 435 labels: 2
Macro-average	 prec: 0.7989651, rec: 0.7876768, f1: 0.7932808
Micro-average	 prec: 0.7988107, rec: 0.7874939, f1: 0.79311186


Epoch 2/5 started, lr: 0.0029850747, dataset size: 6347


Epoch 2/5 - 300.79s - loss: 978.90295 - batches: 795
Quality on test dataset: 
time to finish evaluation: 48.62s
label	 tp	 fp	 fn	 prec	 rec	 f1
I-Disease	 905	 165	 182	 0.8457944	 0.8325667	 0.83912843
B-Disease	 792	 144	 168	 0.84615386	 0.825	 0.835443
tp: 1697 fp: 309 fn: 350 labels: 2
Macro-average
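To compare runs it can be handy to pull the per-epoch loss and duration out of the log files rather than reading them by eye. A minimal sketch with a regular expression follows; the pattern is written against the line format visible in the logs above and may need adjusting for other versions, and `parse_epochs` is a hypothetical helper.

```python
import re

# Matches the summary line of an epoch, e.g. "Epoch 1/5 - 292.36s - loss: 2103.3435 - batches: 795".
EPOCH_LINE = re.compile(r"Epoch (\d+)/(\d+) - ([\d.]+)s - loss: ([\d.]+) - batches: (\d+)")

# Lines copied from the log shown above.
LOG_EXCERPT = """\
Epoch 1/5 started, lr: 0.003, dataset size: 6347
Epoch 1/5 - 292.36s - loss: 2103.3435 - batches: 795
Epoch 2/5 started, lr: 0.0029850747, dataset size: 6347
Epoch 2/5 - 300.79s - loss: 978.90295 - batches: 795
"""


def parse_epochs(log_text):
    """Extract epoch number, duration, loss and batch count from a training log."""
    rows = []
    for match in EPOCH_LINE.finditer(log_text):
        epoch, total, seconds, loss, batches = match.groups()
        rows.append({
            "epoch": int(epoch),
            "total": int(total),
            "seconds": float(seconds),
            "loss": float(loss),
            "batches": int(batches),
        })
    return rows


rows = parse_epochs(LOG_EXCERPT)
for row in rows:
    print(row["epoch"], row["loss"], row["seconds"])
```

To parse a real log, read the file contents with `open(path).read()` and pass the text to `parse_epochs`.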

In [23]:
from sparknlp_jsl.eval import NerDLMetrics
import pyspark.sql.functions as F

pred_df = ner_jsl_retrained.stages[1].transform(clinical_embeddings.transform(test_data))

evaler = NerDLMetrics(mode="full_chunk", dropO=True)

eval_result = evaler.computeMetricsFromDF(pred_df.select("label","ner"), prediction_col="ner", label_col="label").cache()

eval_result.withColumn("precision", F.round(eval_result["precision"],4))\
    .withColumn("recall", F.round(eval_result["recall"],4))\
    .withColumn("f1", F.round(eval_result["f1"],4)).show(100)

# show() prints the table itself and returns None, so it is not wrapped in print()
eval_result.selectExpr("avg(f1) as macro").show()
eval_result.selectExpr("sum(f1*total) as sumprod","sum(total) as sumtotal").selectExpr("sumprod/sumtotal as micro").show()

+-------+-----+-----+-----+-----+---------+------+------+
| entity|   tp|   fp|   fn|total|precision|recall|    f1|
+-------+-----+-----+-----+-----+---------+------+------+
|Disease|816.0|190.0|139.0|955.0|   0.8111|0.8545|0.8322|
+-------+-----+-----+-----+-----+---------+------+------+

+------------------+
|             macro|
+------------------+
|0.8322284548699643|
+------------------+

+------------------+
|             micro|
+------------------+
|0.8322284548699643|
+------------------+
