![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/1.5.Resume_MedicalNer_Model_Training.ipynb)

# 1.5 Resume MedicalNer Model Training

Steps:
- Train a new model for a few epochs.
- Load the same model, train it for more epochs on the same taxonomy, and check the stats.
- Continue training a model that was already trained on a different dataset.

## Colab Setup

In [2]:
import json

from google.colab import files

license_keys = files.upload()

with open(list(license_keys.keys())[0]) as f:
    license_keys = json.load(f)

# Defining license key-value pairs as local variables
locals().update(license_keys)

# Adding license key-value pairs to environment variables
import os
os.environ.update(license_keys)

Saving public330_jsl330_ocr_380.json to public330_jsl330_ocr_380 (1).json
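
The later cells read `SECRET`, `PUBLIC_VERSION` and `JSL_VERSION` from the uploaded license file, so it can be worth checking that they are present before installing anything. The snippet below is a minimal sketch of such a check; `missing_license_keys` and the sample dictionary are hypothetical and not part of the Spark NLP API.

```python
import json

# Keys this notebook reads later on (pip install lines and sparknlp_jsl.start).
REQUIRED_KEYS = ("SECRET", "PUBLIC_VERSION", "JSL_VERSION")


def missing_license_keys(license_keys, required=REQUIRED_KEYS):
    """Return the required keys that are absent or empty in the license dict."""
    return [key for key in required if not license_keys.get(key)]


# Hypothetical license content, for illustration only.
sample = json.loads('{"SECRET": "xxx", "PUBLIC_VERSION": "3.3.0"}')
print(missing_license_keys(sample))  # ['JSL_VERSION']
```

In the notebook itself the same helper could be called on `license_keys` right after `json.load`.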


In [3]:
# Installing pyspark and spark-nlp
! pip install --upgrade -q pyspark spark-nlp==$PUBLIC_VERSION

# Installing Spark NLP Healthcare
! pip install --upgrade -q spark-nlp-jsl==$JSL_VERSION  --extra-index-url https://pypi.johnsnowlabs.com/$SECRET

# Installing Spark NLP Display Library for visualization
! pip install -q spark-nlp-display


In [4]:
# Optional: a custom start function, in case you want to start the session with your own Spark parameters
from pyspark.sql import SparkSession

def start(SECRET):
    builder = SparkSession.builder \
        .appName("Spark NLP Licensed") \
        .master("local[*]") \
        .config("spark.driver.memory", "16G") \
        .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
        .config("spark.kryoserializer.buffer.max", "2000M") \
        .config("spark.jars.packages", "com.johnsnowlabs.nlp:spark-nlp_2.12:"+PUBLIC_VERSION) \
        .config("spark.jars", "https://pypi.johnsnowlabs.com/"+SECRET+"/spark-nlp-jsl-"+JSL_VERSION+".jar")
      
    return builder.getOrCreate()

#spark = start(SECRET)

In [5]:
import json
import os
from pyspark.ml import Pipeline,PipelineModel
from pyspark.sql import SparkSession

from sparknlp.annotator import *
from sparknlp_jsl.annotator import *
from sparknlp.base import *
import sparknlp_jsl
import sparknlp

import warnings
warnings.filterwarnings('ignore')

params = {"spark.driver.memory":"16G", # Amount of memory to use for the driver process, i.e. where SparkContext is initialized
          "spark.kryoserializer.buffer.max":"2000M", # Maximum allowable size of Kryo serialization buffer, in MiB unless otherwise specified. 
          "spark.driver.maxResultSize":"2000M"} # Limit of total size of serialized results of all partitions for each Spark action (e.g. collect) in bytes. 
                                                # Should be at least 1M, or 0 for unlimited. 

spark = sparknlp_jsl.start(license_keys['SECRET'],params=params)

print ("Spark NLP Version :", sparknlp.version())
print ("Spark NLP_JSL Version :", sparknlp_jsl.version())

Spark NLP Version : 3.3.0
Spark NLP_JSL Version : 3.3.0


## Download Clinical Word Embeddings for training

In [6]:
clinical_embeddings = WordEmbeddingsModel.pretrained('embeddings_clinical', "en", "clinical/models")\
    .setInputCols(["sentence", "token"])\
    .setOutputCol("embeddings")

embeddings_clinical download started this may take some time.
Approximate size to download 1.6 GB
[OK!]


## Download Data for Training (NCBI Disease Dataset)

In [7]:
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/NCBI_disease_official_test.conll
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/NCBI_disease_official_train_dev.conll

In [8]:
from sparknlp.training import CoNLL

training_data = CoNLL().readDataset(spark, 'NCBI_disease_official_train_dev.conll')

training_data.show(3)

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|                text|            document|            sentence|               token|                 pos|               label|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|Identification of...|[{document, 0, 89...|[{document, 0, 89...|[{token, 0, 13, I...|[{pos, 0, 13, NN,...|[{named_entity, 0...|
|The adenomatous p...|[{document, 0, 21...|[{document, 0, 21...|[{token, 0, 2, Th...|[{pos, 0, 2, NN, ...|[{named_entity, 0...|
|Complex formation...|[{document, 0, 63...|[{document, 0, 63...|[{token, 0, 6, Co...|[{pos, 0, 6, NN, ...|[{named_entity, 0...|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
only showing top 3 rows



In [9]:
from sparknlp.training import CoNLL

test_data = CoNLL().readDataset(spark, 'NCBI_disease_official_test.conll')

test_data.show(3)

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|                text|            document|            sentence|               token|                 pos|               label|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|Clustering of mis...|[{document, 0, 10...|[{document, 0, 10...|[{token, 0, 9, Cl...|[{pos, 0, 9, NN, ...|[{named_entity, 0...|
|Ataxia - telangie...|[{document, 0, 13...|[{document, 0, 13...|[{token, 0, 5, At...|[{pos, 0, 5, NN, ...|[{named_entity, 0...|
|The risk of cance...|[{document, 0, 15...|[{document, 0, 15...|[{token, 0, 2, Th...|[{pos, 0, 2, NN, ...|[{named_entity, 0...|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
only showing top 3 rows



## Split the test data into two parts:
- We keep the first part separate and use it to train the model further, since it is completely unseen data from the same taxonomy.
- The second part is used for testing and evaluation.

In [10]:
(test_data_1, test_data_2) = test_data.randomSplit([0.5, 0.5], seed = 100)

# save the test data as parquet for easy testing
clinical_embeddings.transform(test_data_1).write.parquet('test_1.parquet')

clinical_embeddings.transform(test_data_2).write.parquet('test_2.parquet')
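
The `seed` argument is what makes the split above repeatable between runs. The plain-Python sketch below illustrates the idea only; it is not Spark's implementation, and `seeded_split` is a hypothetical helper. Like `randomSplit`, it decides row by row, so the two parts are only approximately equal in size.

```python
import random


def seeded_split(rows, weights=(0.5, 0.5), seed=100):
    """Assign each row to one of two parts by drawing from a seeded RNG."""
    rng = random.Random(seed)
    threshold = weights[0] / sum(weights)
    first, second = [], []
    for row in rows:
        (first if rng.random() < threshold else second).append(row)
    return first, second


rows = list(range(20))
part_1, part_2 = seeded_split(rows)
again_1, again_2 = seeded_split(rows)

# The same seed gives the same assignment on every call.
print(part_1 == again_1 and part_2 == again_2)
# Every row ends up in exactly one part.
print(sorted(part_1 + part_2) == rows)
```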

## Train a new model, pause, and resume training on the same dataset.

### Create a graph

In [11]:
from sparknlp_jsl.training import tf_graph
%tensorflow_version 1.x
tf_graph.print_model_params("ner_dl")

tf_graph.build("ner_dl", build_params={"embeddings_dim": 200, "nchars": 128, "ntags": 12, "is_medical": 1}, model_location="./medical_ner_graphs", model_filename="auto")


TensorFlow 1.x selected.
ner_dl parameters.
Parameter            Required   Default value        Description
ntags                yes        -                    Number of tags.
embeddings_dim       no         200                  Embeddings dimension.
nchars               no         100                  Number of chars.
lstm_size            no         128                  Number of LSTM units.
gpu_device           no         0                    Device for training.
is_medical           no         0                    Build a Medical Ner graph.
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
non-resource variables are not supported in the long term
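
When choosing `ntags` for the graph, it can help to count how many distinct labels the training file actually contains, since the graph presumably needs room for at least that many tags. The sketch below assumes the usual CoNLL layout (one token per line, NER tag in the last column, blank lines between sentences); `distinct_ner_tags` is a hypothetical helper and the sample text is hand-written, not taken from the NCBI files.

```python
def distinct_ner_tags(conll_text):
    """Collect the distinct tags found in the last column of CoNLL lines."""
    tags = set()
    for line in conll_text.splitlines():
        line = line.strip()
        # Skip blank sentence separators and document markers.
        if not line or line.startswith("-DOCSTART-"):
            continue
        tags.add(line.split()[-1])
    return tags


# Tiny hand-written sample in the assumed layout.
sample = """-DOCSTART- -X- -X- O

Ataxia NN NN B-Disease
- NN NN I-Disease
telangiectasia NN NN I-Disease
is NN NN O
"""

tags = distinct_ner_tags(sample)
print(sorted(tags), len(tags))
```

To inspect the real data, the same function could be applied to the contents of `NCBI_disease_official_train_dev.conll`.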

### Train for 2 epochs

In [12]:
nerTagger = MedicalNerApproach()\
      .setInputCols(["sentence", "token", "embeddings"])\
      .setLabelColumn("label")\
      .setOutputCol("ner")\
      .setMaxEpochs(2)\
      .setLr(0.003)\
      .setBatchSize(8)\
      .setRandomSeed(0)\
      .setVerbose(1)\
      .setEvaluationLogExtended(True) \
      .setEnableOutputLogs(True)\
      .setIncludeConfidence(True)\
      .setTestDataset('./test_2.parquet')\
      .setGraphFolder('./medical_ner_graphs')\
      .setOutputLogsPath('./ner_logs')

ner_pipeline = Pipeline(stages=[
      clinical_embeddings,
      nerTagger
 ])

In [13]:

%%time
ner_model = ner_pipeline.fit(training_data)


CPU times: user 2.58 s, sys: 267 ms, total: 2.85 s
Wall time: 8min 13s


In [14]:
# Training logs: the log file name contains a run-specific ID, so print the most recent log instead of a hard-coded name
! ls -t ner_logs/*.log | head -n 1 | xargs cat


In [15]:
# Logs of 4 consecutive epochs to compare with 2+2 epochs on separate datasets from same taxonomy

#!cat ner_logs/MedicalNerApproach_4d3d69967c3f.log

### Evaluate

In [16]:
from sparknlp_jsl.eval import NerDLMetrics
import pyspark.sql.functions as F

pred_df = ner_model.stages[1].transform(clinical_embeddings.transform(test_data_2))

evaler = NerDLMetrics(mode="full_chunk", dropO=True)

eval_result = evaler.computeMetricsFromDF(pred_df.select("label","ner"), prediction_col="ner", label_col="label").cache()

eval_result.withColumn("precision", F.round(eval_result["precision"],4))\
    .withColumn("recall", F.round(eval_result["recall"],4))\
    .withColumn("f1", F.round(eval_result["f1"],4)).show(100)

# Macro F1: unweighted mean of the per-entity F1 scores
eval_result.selectExpr("avg(f1) as macro").show()
# Per-entity F1 averaged with `total` as the weight (reported under the name "micro")
eval_result.selectExpr("sum(f1*total) as sumprod","sum(total) as sumtotal").selectExpr("sumprod/sumtotal as micro").show()

+-------+-----+----+----+-----+---------+------+------+
| entity|   tp|  fp|  fn|total|precision|recall|    f1|
+-------+-----+----+----+-----+---------+------+------+
|Disease|367.0|66.0|93.0|460.0|   0.8476|0.7978|0.8219|
+-------+-----+----+----+-----+---------+------+------+

+------------------+
|             macro|
+------------------+
|0.8219484882418813|
+------------------+

+------------------+
|             micro|
+------------------+
|0.8219484882418813|
+------------------+
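
The precision, recall and F1 columns in the table above follow from the `tp`, `fp` and `fn` counts, and `total` appears to be `tp + fn` (367 + 93 = 460). The sketch below recomputes them with the standard formulas; `chunk_metrics` is a hypothetical helper, not part of `NerDLMetrics`.

```python
def chunk_metrics(tp, fp, fn):
    """Standard precision, recall and F1 from true/false positive and false negative counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return precision, recall, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Counts copied from the Disease row of the evaluation table.
precision, recall, f1 = chunk_metrics(tp=367, fp=66, fn=93)
print(round(precision, 4), round(recall, 4), round(f1, 4))  # 0.8476 0.7978 0.8219
```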


### Save the model to disk

In [17]:
ner_model.stages[1].write().overwrite().save('models/NCBI_NER_2_epoch')

### Resume training from the saved model on an unseen dataset
#### We use unseen data from the same taxonomy

In [18]:

nerTagger = MedicalNerApproach()\
      .setInputCols(["sentence", "token", "embeddings"])\
      .setLabelColumn("label")\
      .setOutputCol("ner")\
      .setMaxEpochs(2)\
      .setLr(0.003)\
      .setBatchSize(8)\
      .setRandomSeed(0)\
      .setVerbose(1)\
      .setEvaluationLogExtended(True) \
      .setEnableOutputLogs(True)\
      .setIncludeConfidence(True)\
      .setTestDataset('/content/test_2.parquet')\
      .setOutputLogsPath('ner_logs')\
      .setGraphFolder('medical_ner_graphs')\
      .setPretrainedModelPath("models/NCBI_NER_2_epoch") # load the existing model
    
ner_pipeline = Pipeline(stages=[
      clinical_embeddings,
      nerTagger
 ])
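
Before pointing `setPretrainedModelPath` at a folder, a quick sanity check on the path can save a long failed run. The sketch below relies on the assumption that a saved Spark ML stage is a directory containing a `metadata` subfolder; `looks_like_saved_model` is a hypothetical helper and only a heuristic, not a guarantee that the model will load.

```python
import os
import tempfile


def looks_like_saved_model(path):
    """Heuristic: a saved stage is a directory that contains a `metadata` subfolder (assumption)."""
    return os.path.isdir(path) and os.path.isdir(os.path.join(path, "metadata"))


# Demonstration on a throwaway directory that mimics the assumed layout.
with tempfile.TemporaryDirectory() as tmp:
    fake_model = os.path.join(tmp, "NCBI_NER_2_epoch")
    os.makedirs(os.path.join(fake_model, "metadata"))
    print(looks_like_saved_model(fake_model))                    # True
    print(looks_like_saved_model(os.path.join(tmp, "missing")))  # False
```

In this notebook the check would be called as `looks_like_saved_model("models/NCBI_NER_2_epoch")`.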

In [19]:

%%time
ner_model_retrained = ner_pipeline.fit(test_data_1)


CPU times: user 332 ms, sys: 40.8 ms, total: 373 ms
Wall time: 57.7 s


In [20]:
# The log file name contains a run-specific ID, so print the most recent log instead of a hard-coded name
! ls -t ./ner_logs/*.log | head -n 1 | xargs cat


In [21]:
from sparknlp_jsl.eval import NerDLMetrics
import pyspark.sql.functions as F

pred_df = ner_model_retrained.stages[1].transform(clinical_embeddings.transform(test_data_2))

evaler = NerDLMetrics(mode="full_chunk", dropO=True)

eval_result = evaler.computeMetricsFromDF(pred_df.select("label","ner"), prediction_col="ner", label_col="label").cache()

eval_result.withColumn("precision", F.round(eval_result["precision"],4))\
    .withColumn("recall", F.round(eval_result["recall"],4))\
    .withColumn("f1", F.round(eval_result["f1"],4)).show(100)

# Macro F1: unweighted mean of the per-entity F1 scores
eval_result.selectExpr("avg(f1) as macro").show()
# Per-entity F1 averaged with `total` as the weight (reported under the name "micro")
eval_result.selectExpr("sum(f1*total) as sumprod","sum(total) as sumtotal").selectExpr("sumprod/sumtotal as micro").show()

+-------+-----+----+----+-----+---------+------+------+
| entity|   tp|  fp|  fn|total|precision|recall|    f1|
+-------+-----+----+----+-----+---------+------+------+
|Disease|418.0|65.0|42.0|460.0|   0.8654|0.9087|0.8865|
+-------+-----+----+----+-----+---------+------+------+

+------------------+
|             macro|
+------------------+
|0.8865323435843054|
+------------------+

+------------------+
|             micro|
+------------------+
|0.8865323435843054|
+------------------+
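
The two `selectExpr` calls compute an unweighted mean of the per-entity F1 scores and a mean weighted by `total`. With a single entity (`Disease`) the two must coincide, which is why both tables show the same number. The sketch below mirrors those expressions in plain Python; the two-entity rows are hypothetical values chosen only to show the averages diverging. Note that a support-weighted mean of per-entity F1 is generally not the same quantity as a micro F1 computed from pooled counts.

```python
import math


def macro_f1(rows):
    """Unweighted mean of per-entity F1, like avg(f1)."""
    return sum(f1 for f1, _ in rows) / len(rows)


def weighted_f1(rows):
    """Mean of per-entity F1 weighted by total, like sum(f1*total)/sum(total)."""
    return sum(f1 * total for f1, total in rows) / sum(total for _, total in rows)


# (f1, total) for the single Disease row reported above.
one_entity = [(0.8865323435843054, 460.0)]
print(math.isclose(macro_f1(one_entity), weighted_f1(one_entity)))

# Hypothetical two-entity case: a frequent, well-predicted entity and a rare, weak one.
two_entities = [(0.9, 100.0), (0.5, 10.0)]
print(round(macro_f1(two_entities), 4), round(weighted_f1(two_entities), 4))
```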


## Now let's take a model trained on a different dataset and continue training it on this dataset

In [22]:
jsl_ner = MedicalNerModel.pretrained('ner_jsl','en','clinical/models')

jsl_ner.getClasses()

ner_jsl download started this may take some time.
Approximate size to download 14.5 MB
[OK!]


['O',
 'B-Injury_or_Poisoning',
 'B-Direction',
 'B-Test',
 'I-Route',
 'B-Admission_Discharge',
 'B-Death_Entity',
 'I-Oxygen_Therapy',
 'I-Drug_BrandName',
 'B-Relationship_Status',
 'B-Duration',
 'I-Alcohol',
 'I-Triglycerides',
 'I-Date',
 'B-Respiration',
 'B-Hyperlipidemia',
 'I-Test',
 'B-Birth_Entity',
 'I-VS_Finding',
 'B-Age',
 'I-Social_History_Header',
 'B-Labour_Delivery',
 'I-Medical_Device',
 'B-Family_History_Header',
 'B-BMI',
 'I-Fetus_NewBorn',
 'I-BMI',
 'B-Temperature',
 'I-Section_Header',
 'I-Communicable_Disease',
 'I-ImagingFindings',
 'I-Psychological_Condition',
 'I-Obesity',
 'I-Sexually_Active_or_Sexual_Orientation',
 'I-Modifier',
 'B-Alcohol',
 'I-Temperature',
 'I-Vaccine',
 'I-Symptom',
 'B-Kidney_Disease',
 'I-Pulse',
 'B-Oncological',
 'I-EKG_Findings',
 'B-Medical_History_Header',
 'I-Relationship_Status',
 'I-Blood_Pressure',
 'B-Cerebrovascular_Disease',
 'I-Diabetes',
 'B-Oxygen_Therapy',
 'B-O2_Saturation',
 'B-Psychological_Condition',
 'B-Hear

### Now train a model using this model as base

In [23]:

nerTagger = MedicalNerApproach()\
      .setInputCols(["sentence", "token", "embeddings"])\
      .setLabelColumn("label")\
      .setOutputCol("ner")\
      .setMaxEpochs(2)\
      .setLr(0.003)\
      .setBatchSize(8)\
      .setRandomSeed(0)\
      .setVerbose(1)\
      .setEvaluationLogExtended(True) \
      .setEnableOutputLogs(True)\
      .setIncludeConfidence(True)\
      .setTestDataset('/content/test_2.parquet')\
      .setOutputLogsPath('ner_logs')\
      .setGraphFolder('medical_ner_graphs')\
      .setPretrainedModelPath("/root/cache_pretrained/ner_jsl_en_3.1.0_2.4_1624566960534")\
      .setOverrideExistingTags(True) # since the tags do not align, set this flag to true
    
# Tune the hyperparameters above (max epochs, learning rate, dropout, etc.) to get better results
ner_pipeline = Pipeline(stages=[
      clinical_embeddings,
      nerTagger
 ])
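
One way to read "the tags do not align" is that the new dataset contains labels the base model was never trained on. The sketch below checks this with a simple set comparison. `tags_align` is a hypothetical helper, the class list holds only the first few entries of the `getClasses()` output, and the NCBI label set is an assumption (IOB tags for the single `Disease` entity seen in the evaluation tables).

```python
def tags_align(pretrained_classes, dataset_labels):
    """True when every label in the new dataset is already known to the base model."""
    return set(dataset_labels) <= set(pretrained_classes)


# A few classes copied from the `ner_jsl` output above (the full list is longer).
ner_jsl_sample = ["O", "B-Injury_or_Poisoning", "B-Direction", "B-Test", "I-Test"]

# Assumed label set of the NCBI data (IOB tags for one entity type).
ncbi_labels = ["O", "B-Disease", "I-Disease"]

print(tags_align(ner_jsl_sample, ncbi_labels))  # False
```

In the notebook, the first argument would be the full `jsl_ner.getClasses()` list and the second the distinct labels of `training_data`.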

In [24]:

%%time
ner_jsl_retrained = ner_pipeline.fit(training_data)


CPU times: user 4.14 s, sys: 438 ms, total: 4.58 s
Wall time: 14min


In [25]:
# The log file name contains a run-specific ID, so print the most recent log instead of a hard-coded name
! ls -t ./ner_logs/*.log | head -n 1 | xargs cat


In [26]:
from sparknlp_jsl.eval import NerDLMetrics
import pyspark.sql.functions as F

pred_df = ner_jsl_retrained.stages[1].transform(clinical_embeddings.transform(test_data_2))

evaler = NerDLMetrics(mode="full_chunk", dropO=True)

eval_result = evaler.computeMetricsFromDF(pred_df.select("label","ner"), prediction_col="ner", label_col="label").cache()

eval_result.withColumn("precision", F.round(eval_result["precision"],4))\
    .withColumn("recall", F.round(eval_result["recall"],4))\
    .withColumn("f1", F.round(eval_result["f1"],4)).show(100)

# Macro F1: unweighted mean of the per-entity F1 scores
eval_result.selectExpr("avg(f1) as macro").show()
# Per-entity F1 averaged with `total` as the weight (reported under the name "micro")
eval_result.selectExpr("sum(f1*total) as sumprod","sum(total) as sumtotal").selectExpr("sumprod/sumtotal as micro").show()

+-------+-----+----+----+-----+---------+------+------+
| entity|   tp|  fp|  fn|total|precision|recall|    f1|
+-------+-----+----+----+-----+---------+------+------+
|Disease|384.0|80.0|76.0|460.0|   0.8276|0.8348|0.8312|
+-------+-----+----+----+-----+---------+------+------+

+------------------+
|             macro|
+------------------+
|0.8311688311688312|
+------------------+

+------------------+
|             micro|
+------------------+
|0.8311688311688312|
+------------------+
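
To compare the three runs side by side, the counts from the evaluation tables can be collected in one place. The sketch below recomputes F1 directly from `tp`, `fp` and `fn` and prints the difference from the first run; the run names are my own labels, and all three runs were evaluated on `test_data_2`.

```python
# (tp, fp, fn) copied from the three evaluation tables in this notebook.
runs = {
    "2 epochs from scratch": (367, 66, 93),
    "+2 epochs resumed on test_data_1": (418, 65, 42),
    "ner_jsl base, 2 epochs": (384, 80, 76),
}


def f1_from_counts(tp, fp, fn):
    """F1 written in terms of counts: 2*tp / (2*tp + fp + fn)."""
    return 2 * tp / (2 * tp + fp + fn)


scores = {name: f1_from_counts(*counts) for name, counts in runs.items()}
baseline = scores["2 epochs from scratch"]

for name, score in scores.items():
    print(f"{name:35s} F1={score:.4f}  delta={score - baseline:+.4f}")
```

These are single runs with one seed on a small test split, so small differences between them should be read with some caution.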
