![JohnSnowLabs](https://sparknlp.org/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Healthcare/4.13.End2End_Preannotation_and_Training_Pipeline.ipynb)

# **End2End Preannotation and Training Pipeline**

## Spark Setup

In [None]:
import json
import os

from google.colab import files

if 'spark_jsl.json' not in os.listdir():
  license_keys = files.upload()
  os.rename(list(license_keys.keys())[0], 'spark_jsl.json')

with open('spark_jsl.json') as f:
    license_keys = json.load(f)

# Defining license key-value pairs as local variables
locals().update(license_keys)
os.environ.update(license_keys)

In [None]:
# Installing pyspark and spark-nlp
! pip install --upgrade -q pyspark==3.5.1  spark-nlp==$PUBLIC_VERSION

# Installing Spark NLP Healthcare
! pip install --upgrade -q spark-nlp-jsl==$JSL_VERSION  --extra-index-url https://pypi.johnsnowlabs.com/$SECRET

# Installing Spark NLP Display Library for visualization
! pip install -q spark-nlp-display

In [3]:
import os
import json
import numpy as np
import pandas as pd

import sparknlp
import sparknlp_jsl

from sparknlp.base import *
from sparknlp.util import *
from sparknlp.annotator import *
from sparknlp_jsl.annotator import *

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.ml import Pipeline, PipelineModel

import warnings
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

params = {"spark.driver.memory":"48G", # Amount of memory to use for the driver process, i.e. where SparkContext is initialized
          "spark.kryoserializer.buffer.max":"2000M", # Maximum allowable size of Kryo serialization buffer, in MiB unless otherwise specified.
          "spark.driver.maxResultSize":"2000M"} # Limit of total size of serialized results of all partitions for each Spark action (e.g. collect) in bytes.
                                                # Should be at least 1M, or 0 for unlimited.

spark = sparknlp_jsl.start(license_keys['SECRET'],params=params)
spark.sparkContext.setLogLevel("ERROR")
print ("Spark NLP Version :", sparknlp.version())
print ("Spark NLP_JSL Version :", sparknlp_jsl.version())

spark

Spark NLP Version : 6.1.3
Spark NLP_JSL Version : 6.1.1


## Loading the Pretrained Pipeline

Spark NLP's pretrained pipeline, `clinical_deidentification_docwise_benchmark`, is loaded. This pipeline is designed to mask and obfuscate sensitive information in medical texts, such as names, ID numbers, contact information, locations, ages, and dates. The existing stages of the pipeline are examined to understand its structure.

In [4]:
from sparknlp.pretrained import PretrainedPipeline

deid_pipeline = PretrainedPipeline("clinical_deidentification_docwise_benchmark", "en", "clinical/models")

clinical_deidentification_docwise_benchmark download started this may take some time.
Approx size to download 2.3 GB
[OK!]


In [5]:
deid_pipeline.model.stages

[DocumentAssembler_ae0f203deedd,
 InternalDocumentSplitter_cc36578ceda6,
 REGEX_TOKENIZER_2e85686aea12,
 WORD_EMBEDDINGS_MODEL_9004b1d00302,
 MedicalNerModel_1a8637089929,
 NER_CONVERTER_1aef7e9d2de5,
 MedicalNerModel_d92d47622e85,
 MedicalNerModel_32184c1db80b,
 MedicalNerModel_ada39ac0d359,
 NER_CONVERTER_a99db4e6a79d,
 NER_CONVERTER_4a9436714344,
 NER_CONVERTER_ea6433988e18,
 PretrainedZeroShotNER_5f30ab9002f1,
 NER_CONVERTER_c97040caf7b3,
 MedicalNerModel_b8b167ec3114,
 NER_CONVERTER_06db473f3215,
 ContextualEntityRuler_11ff6711ef6b,
 ChunkMergeModel_95d6827691bb,
 CONTEXTUAL-PARSER_bf2a6abaf5fa,
 CONTEXTUAL-PARSER_ff6bad379d91,
 CONTEXTUAL-PARSER_89341cae7221,
 CONTEXTUAL-PARSER_c6b9eded8d31,
 CONTEXTUAL-PARSER_9480c24bd9f8,
 CONTEXTUAL-PARSER_3886bce391c8,
 CONTEXTUAL-PARSER_0bb3fb75cd01,
 ENTITY_EXTRACTOR_6792f2f6e85a,
 ENTITY_EXTRACTOR_74ace4be4f73,
 CONTEXTUAL-PARSER_dfb32adc7555,
 REGEX_MATCHER_5003669d6422,
 CONTEXTUAL-PARSER_746a25662aa6,
 CONTEXTUAL-PARSER_079220479a3d,
 C

### Sample text

In [6]:
text = """
(NOTE) Patient Name: John Lee. MR#: 7789201 Location: LERE Date Reported: 2025-05-12 16:30
Specimen #RD23-4897 Clinical History: None Given. CLINICAL INFORMATION: Date of Last Menstrual Period: N/A
Electronically Signed Out By Dr. Smith, Dr. Carter, CT(ASCP) Date Reported: 2025-05-12 16:30
General Hospital Dr. Fan Gabriel 90210 CPT Code(s) A: 88305

General Hospital in New York City Dr. Williams, NYC, NY
(212) 555-7890 Patient Name: John Lee Accession #: GH-556672
Patient ID #: 7789201 Collected: 2025-05-10 Address:
123 Main Street, FALL RIVER
NIAGARA FALLS, NY 14304
Received: 2025-05-10 Reported: 2025-05-12
Soc. Sec. #: XXX-XX-1234 DOB/Age/Sex: 1973 (Age: 52) M
Physician(s): Dr. Jameson. Peripheral sequestration, i.e. splenomegaly or hepatomegaly should be excluded to be sure if peripheral sequestration is not present.
The following special studies were performed at Barstow Heights Christus Southeast, NY – St Elizabeth; New York City.
· Chromosome analysis cytogenetics. (ADDENDUM REPORT TO FOLLOW.)
· Leukemic immunophenotyping flow cytometry.

...., and there is no evidence of dysplasia.
Fr/ap MATERIAL RECEIVED 6 SLIDES LABELED 032-1902, COLLECTED 2025-05-10
SPECIMEN SOURCE: GASTRIC, ILEUM AND RANDOM COLON, BIOPSIES
REFERRING FACILITY: NY
"""

## Extending the Pipeline with New Stages

New and customized stages are added to enhance the capabilities of the existing pipeline.

In [7]:
document_assembler = DocumentAssembler()\
      .setInputCol("text")\
      .setOutputCol("document")

splitter = (
            InternalDocumentSplitter()
            .setInputCols("document")
            .setOutputCol("splitter")
            .setSplitMode("recursive")
            .setSplitPatterns([r"\s+"])  # Split on whitespace (token-based)
            .setPatternsAreRegex(True)
            .setChunkSize(512)    # Maximum chunk length: 512 characters
            .setChunkOverlap(50)
            .setEnableSentenceIncrement(True)  # Like sentenceDetector
)

tokenizer = (
    Tokenizer()
    .setInputCols("splitter")
    .setOutputCol("token")
)

### Create a Custom `CPT Code` Parser

Using `ContextualParserApproach`, a new parser is created to detect CPT (Current Procedural Terminology) codes within the text based on regex rules. This allows the pipeline to recognize a custom entity type not found in the standard de-identification pipeline.

In [8]:
cpt_rule = {
    "entity": "CPT_CODE",
    "ruleScope": "sentence",
    "regex": r"(?:CPT(?: Code\(s\)?|#|:)?\s*:?[\s#]*)?(\b88[0-9]{3}\b)",
    "matchScope": "token"
}

with open('cpt.json', 'w') as f:
    json.dump(cpt_rule, f)

cpt_parser = ContextualParserApproach() \
    .setInputCols(["splitter", "token"]) \
    .setOutputCol("entity_cpt") \
    .setJsonPath("cpt.json") \
    .setCaseSensitive(False) \
    .setPrefixAndSuffixMatch(False)

cpt_parser_pipeline = Pipeline(stages=[
    document_assembler,
    splitter,
    tokenizer,
    cpt_parser
  ])

empty_data = spark.createDataFrame([[""]]).toDF("text")

cpt_parser_model = cpt_parser_pipeline.fit(empty_data)
cpt_parser_model.stages[-1].write().overwrite().save("./parsers/cpt_parser")

cpt_parser = ContextualParserModel.load("parsers/cpt_parser") \
    .setInputCols(["splitter", "token"])\
    .setOutputCol("entity_cpt")

In [9]:
annotations = LightPipeline(cpt_parser_model).annotate(text)

annotations["entity_cpt"]

['88305']
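
The rule's regex can be sanity-checked outside Spark before the parser is fitted. The sketch below is an illustration only: it uses Python's `re` module, which is assumed to behave like the parser's regex engine for this simple pattern, and it does not reproduce `ContextualParserApproach` options such as `ruleScope` or `matchScope`. The helper name `find_cpt_codes` and the second and third sample strings are made up for the example.

```python
import re

# Same pattern as in cpt_rule; the capturing group holds the 5-digit code.
CPT_REGEX = re.compile(
    r"(?:CPT(?: Code\(s\)?|#|:)?\s*:?[\s#]*)?(\b88[0-9]{3}\b)",
    flags=re.IGNORECASE,  # mirrors setCaseSensitive(False)
)

def find_cpt_codes(text):
    """Return every captured code, in order of appearance (hypothetical helper)."""
    return [m.group(1) for m in CPT_REGEX.finditer(text)]

# First string is taken from the sample text; the others are invented test cases.
print(find_cpt_codes("General Hospital Dr. Fan Gabriel 90210 CPT Code(s) A: 88305"))
print(find_cpt_codes("CPT# 88112 and CPT: 88342"))
print(find_cpt_codes("MR#: 7789201, zip 90210"))
```

Because the pattern requires a word boundary before `88`, longer numbers that merely contain such a sequence are not picked up.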

### Create a Custom `Specimen ID` Parser

Similarly, another parser is created with `ContextualParserApproach` to extract specimen IDs from medical texts.

In [10]:
with open('specimen.json', 'w') as f:
    json.dump({
        "entity": "IDNUM",
        "ruleScope": "sentence",
        # Raw string keeps the regex backslashes literal and avoids invalid-escape warnings
        "regex": r"(?:Specimen(?:\s*(?:ID|Number|Code|#|No\.?)?:?)?\s*)?#?[A-Z]{1,5}[0-9]{2,4}-?[0-9]{3,6}",
        "contextLength": 25,
        "matchScope": "token"
    }, f)

specimen_parser = ContextualParserApproach() \
    .setInputCols(["splitter", "token"]) \
    .setOutputCol("entity_specimen") \
    .setJsonPath("specimen.json") \
    .setCaseSensitive(False) \
    .setPrefixAndSuffixMatch(False)

specimen_parser_pipeline = Pipeline(stages=[
    document_assembler,
    splitter,
    tokenizer,
    specimen_parser
  ])

empty_data = spark.createDataFrame([[""]]).toDF("text")

specimen_parser_model = specimen_parser_pipeline.fit(empty_data)
specimen_parser_model.stages[-1].write().overwrite().save("./parsers/specimen_parser")

specimen_parser = ContextualParserModel.load("./parsers/specimen_parser") \
    .setInputCols(["splitter", "token"])\
    .setOutputCol("entity_specimen")

In [11]:
annotations = LightPipeline(specimen_parser_model).annotate(text)

annotations["entity_specimen"]

['#RD23-4897']
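
Since the rule travels through a JSON file, it can be worth confirming that the regex survives serialization unchanged. The following sketch round-trips the rule in memory with `json.dumps`/`json.loads` and then tries the decoded pattern with Python's `re` module, which is assumed here to be a reasonable stand-in for the parser's regex engine. Note that a plain regex search returns the whole match, including the optional `Specimen` prefix, whereas the parser output above is limited to the matching token.

```python
import json
import re

SPECIMEN_REGEX = (
    r"(?:Specimen(?:\s*(?:ID|Number|Code|#|No\.?)?:?)?\s*)?"
    r"#?[A-Z]{1,5}[0-9]{2,4}-?[0-9]{3,6}"
)

rule = {
    "entity": "IDNUM",
    "ruleScope": "sentence",
    "regex": SPECIMEN_REGEX,
    "contextLength": 25,
    "matchScope": "token",
}

# JSON escapes each backslash on write and restores it on read.
encoded = json.dumps(rule)
decoded = json.loads(encoded)
round_trip_ok = decoded["regex"] == SPECIMEN_REGEX

# Try the decoded pattern on a fragment of the sample text.
match = re.search(
    decoded["regex"],
    "Specimen #RD23-4897 Clinical History: None Given.",
    flags=re.IGNORECASE,  # mirrors setCaseSensitive(False)
)
matched_text = match.group(0) if match else None

print(round_trip_ok, matched_text)
```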

### **IOBTagger**

The `IOBTagger` is added to tag the entities recognized by the Named Entity Recognition (NER) model in the IOB (Inside, Outside, Beginning) format. This format provides a standard data structure required for training the NER model.

In [12]:
iobTagger = sparknlp_jsl.annotator.IOBTagger()\
  .setInputCols(["token", "ner_chunk"])\
  .setOutputCol("ner_label")
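
To make the IOB format concrete, here is a minimal plain-Python sketch of how chunk spans can be projected onto tokens. It is not the `IOBTagger` implementation; the function name `to_iob`, the tuple layouts, and the example offsets are assumptions chosen for illustration. Offsets use inclusive end positions, as in the annotation output shown elsewhere in this notebook.

```python
def to_iob(tokens, chunks):
    """Assign B-/I-/O tags to tokens given labeled chunk spans.

    tokens: list of (text, begin, end) with inclusive end offsets
    chunks: list of (begin, end, label) with inclusive end offsets
    """
    tags = []
    for _, t_begin, t_end in tokens:
        tag = "O"  # default: token is outside every chunk
        for c_begin, c_end, label in chunks:
            if t_begin >= c_begin and t_end <= c_end:
                # First token of the chunk gets B-, the rest get I-
                prefix = "B-" if t_begin == c_begin else "I-"
                tag = prefix + label
                break
        tags.append(tag)
    return tags

# "Patient Name: John Lee" with one NAME chunk covering "John Lee"
tokens = [
    ("Patient", 0, 6),
    ("Name", 8, 11),
    (":", 12, 12),
    ("John", 14, 17),
    ("Lee", 19, 21),
]
chunks = [(14, 21, "NAME")]

iob_tags = to_iob(tokens, chunks)
print(list(zip([t[0] for t in tokens], iob_tags)))
```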

### **Update the Chunk Merging Strategy**

The inputs of the rule-based `ChunkMergeModel` (stage 35), which merges the entities produced by the pipeline's rule-based annotators, are updated to include the entities generated by the newly created `cpt_parser` and `specimen_parser`. This ensures that the entities found by the custom parsers are consolidated with those from the pretrained pipeline.

In [13]:
merger_input_cols = deid_pipeline.model.stages[35].getInputCols()
merger_input_cols

['entity_icd10',
 'entity_email',
 'entity_ip_address',
 'entity_age',
 'entity_medicalrecord',
 'entity_ssn',
 'entity_account',
 'entity_vin',
 'entity_date',
 'entity_phone',
 'entity_phone2',
 'entity_country',
 'entity_state',
 'entity_zip',
 'entity_plate',
 'entity_dln',
 'entity_license']

In [14]:
merger_input_cols = deid_pipeline.model.stages[35].getInputCols()

chunk_merge_rulebase = deid_pipeline.model.stages[35]\
      .setInputCols(["entity_cpt", "entity_specimen"] + merger_input_cols)

### Update the De-identification Blacklist

In [15]:
deid_pipeline.model.stages[38]

ChunkMergeModel_5a3f1e608447

In [16]:
deid_pipeline.model.stages[38] = deid_pipeline.model.stages[38]\
                                      .setBlackList(['CPT_CODE'])
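
Stage 38 is a `ChunkMergeModel`, and its blacklist is set to `CPT_CODE`. The apparent intent is that CPT codes are detected but kept out of the merged chunks, so they are not masked or obfuscated (in the sample result further down, `88305` stays in clear text). The sketch below illustrates that idea in plain Python under the assumption that a blacklist simply drops chunks whose entity label is listed; the function `apply_blacklist` and the dictionary layout are hypothetical and are not part of the library.

```python
def apply_blacklist(chunks, blacklist):
    """Keep only chunks whose entity label is not blacklisted (illustrative helper)."""
    blocked = set(blacklist)
    return [chunk for chunk in chunks if chunk["entity"] not in blocked]

# Chunk values taken from the sample text; the dict layout is an assumption.
merged_chunks = [
    {"chunk": "John Lee", "entity": "NAME"},
    {"chunk": "#RD23-4897", "entity": "IDNUM"},
    {"chunk": "88305", "entity": "CPT_CODE"},
]

kept = apply_blacklist(merged_chunks, ["CPT_CODE"])
print([c["chunk"] for c in kept])
```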

### Updated Stages

In [17]:
deid_pipeline.model.stages = (
    deid_pipeline.model.stages[:35]
    + [cpt_parser, specimen_parser, chunk_merge_rulebase]
    + deid_pipeline.model.stages[36:]
    + [iobTagger]
)
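
The slicing above swaps one stage for three and appends a final stage. A small plain-Python sketch of the same list operation follows, using placeholder strings instead of real annotators. Looking the target up by identity rather than by a hard-coded index is a design choice of this sketch, not something the notebook does; the helper `splice_stages` and all stage names are hypothetical.

```python
def splice_stages(stages, target, replacement, appended=()):
    """Replace `target` with the `replacement` stages and append extra stages at the end."""
    idx = stages.index(target)  # raises ValueError if the stage is missing
    return stages[:idx] + list(replacement) + stages[idx + 1:] + list(appended)

# Placeholder names standing in for pipeline stages
stages = ["tokenizer", "parser_a", "parser_b", "rule_merger", "final_merger", "deid"]

updated = splice_stages(
    stages,
    target="rule_merger",
    replacement=["cpt_parser", "specimen_parser", "rule_merger"],
    appended=["iob_tagger"],
)
print(updated)
```

The new parsers end up immediately before the merger that consumes their output columns, which is the ordering the pipeline needs.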

In [18]:
deid_pipeline.model.stages

[DocumentAssembler_ae0f203deedd,
 InternalDocumentSplitter_cc36578ceda6,
 REGEX_TOKENIZER_2e85686aea12,
 WORD_EMBEDDINGS_MODEL_9004b1d00302,
 MedicalNerModel_1a8637089929,
 NER_CONVERTER_1aef7e9d2de5,
 MedicalNerModel_d92d47622e85,
 MedicalNerModel_32184c1db80b,
 MedicalNerModel_ada39ac0d359,
 NER_CONVERTER_a99db4e6a79d,
 NER_CONVERTER_4a9436714344,
 NER_CONVERTER_ea6433988e18,
 PretrainedZeroShotNER_5f30ab9002f1,
 NER_CONVERTER_c97040caf7b3,
 MedicalNerModel_b8b167ec3114,
 NER_CONVERTER_06db473f3215,
 ContextualEntityRuler_11ff6711ef6b,
 ChunkMergeModel_95d6827691bb,
 CONTEXTUAL-PARSER_bf2a6abaf5fa,
 CONTEXTUAL-PARSER_ff6bad379d91,
 CONTEXTUAL-PARSER_89341cae7221,
 CONTEXTUAL-PARSER_c6b9eded8d31,
 CONTEXTUAL-PARSER_9480c24bd9f8,
 CONTEXTUAL-PARSER_3886bce391c8,
 CONTEXTUAL-PARSER_0bb3fb75cd01,
 ENTITY_EXTRACTOR_6792f2f6e85a,
 ENTITY_EXTRACTOR_74ace4be4f73,
 CONTEXTUAL-PARSER_dfb32adc7555,
 REGEX_MATCHER_5003669d6422,
 CONTEXTUAL-PARSER_746a25662aa6,
 CONTEXTUAL-PARSER_079220479a3d,
 C

## Save and Test the Modified Pipeline

In [19]:
empty_result = deid_pipeline.transform(spark.createDataFrame([[""]]).toDF("text"))

deid_pipeline.model.write().overwrite().save("modified_pipeline")

In [20]:
# Load the saved (modified) pipeline from disk with the `from_disk` method.
from sparknlp.pretrained import PretrainedPipeline

modified_pipeline = PretrainedPipeline.from_disk('modified_pipeline')

### Sample Result

In [21]:
samples_df = spark.createDataFrame([[text]]).toDF("text")

result = modified_pipeline.transform(samples_df).cache()

In [22]:
result.select(F.explode(F.arrays_zip(result.ner_chunk.result,
                                     result.ner_chunk.begin,
                                     result.ner_chunk.end,
                                     result.ner_chunk.metadata)).alias("cols")) \
      .select(F.expr("cols['0']").alias("chunk"),
              F.expr("cols['1']").alias("begin"),
              F.expr("cols['2']").alias("end"),
              F.expr("cols['3']['entity']").alias("ner_label"),
              F.expr("cols['3']['confidence']").alias("confidence")).show(50,truncate=False)

+----------------------------------+-----+----+---------+----------+
|chunk                             |begin|end |ner_label|confidence|
+----------------------------------+-----+----+---------+----------+
|John Lee                          |22   |29  |NAME     |0.9999912 |
|7789201                           |37   |43  |IDNUM    |0.72      |
|LERE                              |55   |58  |LOCATION |0.86184543|
|2025-05-12                        |75   |84  |DATE     |NULL      |
|#RD23-4897                        |101  |110 |IDNUM    |0.50      |
|Smith                             |232  |236 |NAME     |0.9992543 |
|Carter                            |243  |248 |NAME     |0.9988757 |
|2025-05-12                        |275  |284 |DATE     |NULL      |
|General Hospital                  |292  |307 |LOCATION |0.9980348 |
|Fan Gabriel                       |313  |323 |NAME     |0.98504215|
|90210                             |325  |329 |IDNUM    |0.5666    |
|General Hospital                 

In [23]:
pd.set_option("display.max_colwidth", 1000)

result_df = result.selectExpr("text",
                              "mask_entity.result as masked_result",
                              "obfuscated.result as obfuscated_result").toPandas()
result_df

Unnamed: 0,text,masked_result,obfuscated_result
0,"\n(NOTE) Patient Name: John Lee. MR#: 7789201 Location: LERE Date Reported: 2025-05-12 16:30\nSpecimen #RD23-4897 Clinical History: None Given. CLINICAL INFORMATION: Date of Last Menstrual Period: N/A\nElectronically Signed Out By Dr. Smith, Dr. Carter, CT(ASCP) Date Reported: 2025-05-12 16:30\nGeneral Hospital Dr. Fan Gabriel 90210 CPT Code(s) A: 88305\n\nGeneral Hospital in New York City Dr. Williams, NYC, NY\n(212) 555-7890 Patient Name: John Lee Accession #: GH-556672\nPatient ID #: 7789201 Collected: 2025-05-10 Address:\n123 Main Street, FALL RIVER\nNIAGARA FALLS, NY 14304\nReceived: 2025-05-10 Reported: 2025-05-12\nSoc. Sec. #: XXX-XX-1234 DOB/Age/Sex: 1973 (Age: 52) M\nPhysician(s): Dr. Jameson. Peripheral sequestration, i.e. splenomegaly or hepatomegaly should be excluded to be sure if peripheral sequestration is not present.\nThe following special studies were performed at Barstow Heights Christus Southeast, NY – St Elizabeth; New York City.\n· Chromosome analysis cytogene...","[\n(NOTE) Patient Name: <NAME>. MR#: <IDNUM> Location: <LOCATION> Date Reported: <DATE> 16:30\nSpecimen <IDNUM> Clinical History: None Given. CLINICAL INFORMATION: Date of Last Menstrual Period: N/A\nElectronically Signed Out By Dr. <NAME>, Dr. <NAME>, CT(ASCP) Date Reported: <DATE> 16:30\n<LOCATION> Dr. <NAME> <IDNUM> CPT Code(s) A: 88305\n\n<LOCATION> in <LOCATION> City Dr. <NAME>, <LOCATION>, <LOCATION>\n<CONTACT> Patient Name: <NAME> Accession #: <IDNUM>\nPatient ID #: <IDNUM> Collected: <DATE> Address:\n<LOCATION>, <LOCATION>\n<LOCATION>, <LOCATION> <LOCATION>\nReceived: <DATE> Reported: <DATE>\nSoc. Sec. #: <IDNUM> DOB/Age/Sex: <DATE> (Age: <AGE>) M\nPhysician(s): Dr. <NAME>. Peripheral sequestration, i.e. splenomegaly or hepatomegaly should be excluded to be sure if peripheral sequestration is not present.\nThe following special studies were performed at <LOCATION>, <LOCATION> – <LOCATION>; <LOCATION> City.\n· <LOCATION> analysis cytogenetics. 
(ADDENDUM REPORT TO FOLLOW.)\n·...","[\n(NOTE) Patient Name: Gillie Allan. MR#: 0074518 Location: 4500 MEMORIAL DRIVE Date Reported: 2025-06-29 16:30\nSpecimen #SA52-9740 Clinical History: None Given. CLINICAL INFORMATION: Date of Last Menstrual Period: N/A\nElectronically Signed Out By Dr. Wanna, Dr. Malvin, CT(ASCP) Date Reported: 2025-06-29 16:30\n310 Ellis Street Dr. Marcelo Danes 41581 CPT Code(s) A: 88305\n\n310 Ellis Street in 2000 Boise Ave City Dr. Duwaine, 427 GUY PARK AVE, 16100 SOUTH FREEWAY\n(585) 666-0741 Patient Name: Gillie Allan Accession #: PU-663305\nPatient ID #: 0074518 Collected: 2025-06-27 Address:\n3255 Independence Street, 302 W MCNEESE ST\n4101 NW 89TH BLVD, 16100 SOUTH FREEWAY 59 KOCH AVE\nReceived: 2025-06-27 Reported: 2025-06-29\nSoc. Sec. #: WWW-WW-8529 DOB/Age/Sex: 1974 (Age: 44) M\nPhysician(s): Dr. Marchelle. Peripheral sequestration, i.e. splenomegaly or hepatomegaly should be excluded to be sure if peripheral sequestration is not present.\nThe following special studies were performed..."
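
The masked output replaces each detected chunk with its entity label in angle brackets. As a rough illustration of that substitution, the sketch below applies labeled spans to a short string in plain Python. It is not the pipeline's de-identification logic; the helpers `span` and `mask` are hypothetical, and offsets use inclusive end positions to match the `begin`/`end` columns shown above.

```python
def span(text, chunk, label):
    """Locate `chunk` in `text` and return (begin, inclusive_end, label)."""
    begin = text.index(chunk)
    return (begin, begin + len(chunk) - 1, label)

def mask(text, chunks):
    """Replace each span with <LABEL>, working right to left so earlier offsets stay valid."""
    masked = text
    for begin, end, label in sorted(chunks, reverse=True):
        masked = masked[:begin] + f"<{label}>" + masked[end + 1:]
    return masked

sample = "Patient Name: John Lee. MR#: 7789201"
sample_chunks = [
    span(sample, "John Lee", "NAME"),
    span(sample, "7789201", "IDNUM"),
]

masked_sample = mask(sample, sample_chunks)
print(masked_sample)
```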


## Prepare Data for Custom NER Model Training

In [24]:
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/refs/heads/master/data/ner/eng.train -O eng.train

from sparknlp.training import CoNLL
data_conll = CoNLL(includeDocId=True,explodeSentences=True).readDataset(spark, "./eng.train")
data_conll.show(2)


+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|doc_id|                text|            document|            sentence|               token|                 pos|               label|
+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|     X|EU rejects German...|[{document, 0, 47...|[{document, 0, 47...|[{token, 0, 1, EU...|[{pos, 0, 1, NNP,...|[{named_entity, 0...|
|     X|     Peter Blackburn|[{document, 0, 14...|[{document, 0, 14...|[{token, 0, 4, Pe...|[{pos, 0, 4, NNP,...|[{named_entity, 0...|
+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
only showing top 2 rows



In [25]:
data_conll.count()

14041

In [26]:
input_spark_df = data_conll.select("doc_id", "text")
input_spark_df.show(2, truncate=50)

+------+------------------------------------------------+
|doc_id|                                            text|
+------+------------------------------------------------+
|     X|EU rejects German call to boycott British lamb .|
|     X|                                 Peter Blackburn|
+------+------------------------------------------------+
only showing top 2 rows



### Preprocess Data with the Modified Pipeline

Run the entire dataset through the modified pipeline. This generates the `splitter`, `token`, and `embeddings` annotations, along with the IOB labels in `ner_label`, that NER training requires downstream.

In [27]:
results = modified_pipeline.transform(input_spark_df)
results.columns

['doc_id',
 'text',
 'document',
 'splitter',
 'token',
 'embeddings',
 'ner_clinical_large',
 'ner_chunk_clinical_large',
 'ner_deid_generic_docwise',
 'ner_deid_docwise_subentity',
 'ner_deid_generic_docwise_merged_conll',
 'ner_chunk_generic_docwise',
 'ner_chunk_subentity_docwise',
 'ner_chunk_merged_docwise',
 'ner_zero_shot',
 'ner_chunk_zero_shot_raw',
 'ner_deid_subentity_docwise_new',
 'ner_chunk_subentity_docwise_new_chunk',
 'ner_chunk_zero_shot',
 'deid_merged_ner_chunk',
 'entity_icd10',
 'entity_ssn',
 'entity_account',
 'entity_dln',
 'entity_plate',
 'entity_vin',
 'entity_license',
 'entity_country',
 'entity_state',
 'entity_age',
 'entity_date',
 'entity_phone',
 'entity_phone2',
 'entity_zip',
 'entity_medicalrecord',
 'entity_email',
 'entity_ip_address',
 'entity_cpt',
 'entity_specimen',
 'deid_merged_ner_rulebased',
 'ner_chunk_raw',
 'ner_chunk_processed',
 'ner_chunk',
 'mask_entity',
 'obfuscated',
 'ner_label']

In [28]:
result_df = results.select('doc_id','text','document','splitter',
                          'token',"embeddings", 'ner_label')

In [29]:
result_df.show(2, truncate=40)

+------+----------------------------------------+----------------------------------------+----------------------------------------+----------------------------------------+----------------------------------------+----------------------------------------+
|doc_id|                                    text|                                document|                                splitter|                                   token|                              embeddings|                               ner_label|
+------+----------------------------------------+----------------------------------------+----------------------------------------+----------------------------------------+----------------------------------------+----------------------------------------+
|     X|EU rejects German call to boycott Bri...|[{document, 0, 47, EU rejects German ...|[{document, 0, 48, EU rejects German ...|[{token, 0, 1, EU, {sentence -> 0}, [...|[{word_embeddings, 0, 1, EU, {isOOV -...|[{named_entity, 0, 1, 

### Persist Preprocessed Data

Save the annotated DataFrame to Parquet format. This is an optimization step to speed up the training process by avoiding re-computation.

In [30]:
%%time

n_partitions = 48

# WRITING THE DATA
result_df.repartition(n_partitions).write.mode("overwrite").format("parquet")\
    .save(f"./data/result_df_{n_partitions}.parquet")


CPU times: user 7.52 s, sys: 1.44 s, total: 8.96 s
Wall time: 27min 13s


## Train a Custom Medical NER Model

In [31]:
# READING THE DATA
n_partitions = 48
result_df = spark.read \
    .parquet(f"./data/result_df_{n_partitions}.parquet")\
    .repartition(n_partitions)

In [32]:
result_df.count()

14041

In [33]:
result_df.show(2)

+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|doc_id|                text|            document|            splitter|               token|          embeddings|           ner_label|
+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|     X|Costa Rica - Rona...|[{document, 0, 52...|[{document, 0, 53...|[{token, 0, 4, Co...|[{word_embeddings...|[{named_entity, 0...|
|     X|These are leading...|[{document, 0, 79...|[{document, 0, 80...|[{token, 0, 4, Th...|[{word_embeddings...|[{named_entity, 0...|
+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
only showing top 2 rows



In [34]:
(train_df, test_df) = result_df.randomSplit([0.8, 0.2], seed = 42)

In [35]:
test_df.repartition(n_partitions).write.mode("overwrite").format("parquet")\
    .save(f"./data/test_df.parquet")

### Use MedicalNerDLGraphChecker for NER

The `MedicalNerDLGraphChecker` processes the dataset to extract the required graph parameters (tokens, labels, embedding dimensions).

In [36]:
embeddings = (WordEmbeddingsModel.pretrained("embeddings_clinical", "en", "clinical/models")
            .setInputCols(["splitter", "token"])
            .setOutputCol("embeddings"))

embeddings_clinical download started this may take some time.
Approximate size to download 1.6 GB
[OK!]


In [37]:
nerDLGraphChecker = MedicalNerDLGraphChecker()\
    .setInputCols(["splitter", "token"])\
    .setLabelColumn("ner_label")\
    .setEmbeddingsModel(embeddings)

### Configure and Run the MedicalNerApproach

In [38]:
nerTagger = MedicalNerApproach()\
    .setInputCols(["splitter", "token", "embeddings"])\
    .setLabelColumn("ner_label")\
    .setOutputCol("ner")\
    .setMaxEpochs(30)\
    .setBatchSize(8)\
    .setRandomSeed(0)\
    .setVerbose(1)\
    .setValidationSplit(0.2)\
    .setEvaluationLogExtended(True) \
    .setEnableOutputLogs(True)\
    .setIncludeConfidence(True)\
    .setOutputLogsPath('ner_logs')\
    .setEarlyStoppingCriterion(0.01)\
    .setEarlyStoppingPatience(5)\
    .setUseBestModel(False)
    # Optional settings (append to the chain above if needed):
    # .setTestDataset("./data/test_df.parquet")
    # .setEnableMemoryOptimizer(True)  # Train batch by batch when memory is limited and the dataset is large
    # .setDatasetInfo("NCBI_sample_short dataset")  # Free-text details about the dataset

ner_pipeline = Pipeline(
    stages=[
          nerDLGraphChecker,
          nerTagger
 ])

In [39]:
%%time
ner_model = ner_pipeline.fit(train_df)

CPU times: user 16.1 s, sys: 2.22 s, total: 18.3 s
Wall time: 54min 19s


In [40]:
ner_model.stages[-1].getTrainingClassDistribution()

{'I-NAME': 4392, 'I-CONTACT': 191, 'I-AGE': 35, 'I-IDNUM': 63, 'B-DATE': 3583, 'I-DATE': 494, 'I-LOCATION': 3740, 'B-NAME': 5149, 'B-AGE': 588, 'B-LOCATION': 10472, 'B-IDNUM': 154, 'O': 136085, 'B-CONTACT': 311}
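
The distribution is heavily skewed toward the `O` label. A short plain-Python sketch makes the proportions explicit, using the counts printed above. Treating each `B-` tag as the start of one entity chunk is an assumption of this sketch.

```python
# Counts copied from getTrainingClassDistribution() above
dist = {
    "I-NAME": 4392, "I-CONTACT": 191, "I-AGE": 35, "I-IDNUM": 63,
    "B-DATE": 3583, "I-DATE": 494, "I-LOCATION": 3740, "B-NAME": 5149,
    "B-AGE": 588, "B-LOCATION": 10472, "B-IDNUM": 154, "O": 136085,
    "B-CONTACT": 311,
}

total_tokens = sum(dist.values())
o_share = dist["O"] / total_tokens

# One B- tag per chunk, so B- counts approximate the number of chunks per entity type
chunk_counts = {label[2:]: n for label, n in dist.items() if label.startswith("B-")}

print(f"total tokens: {total_tokens}, share of O: {o_share:.1%}")
for entity, n in sorted(chunk_counts.items(), key=lambda kv: kv[1]):
    print(f"{entity:<9}{n:>6}")
```

`IDNUM` and `CONTACT` have the fewest training chunks, which may help explain their weaker scores in the evaluation later in this notebook.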

### Save the Trained NER Model and Review Logs

In [41]:
ner_model.stages[-1].write().overwrite().save('models/new_NER_model')

In [42]:
import os
log_file= os.listdir("ner_logs")[0]

with open (f"./ner_logs/{log_file}") as f:
    print(f.read())

Name of the selected graph: medical-ner-dl/blstm_100_200_128_100.pb
Training started - total epochs: 30 - lr: 0.001 - batch size: 8 - labels: 13 - chars: 84 - training examples: 11192


Epoch 1/30 started, lr: 0.001, dataset size: 11192


Epoch 1/30 - 98.42s - loss: 5278.641 - avg training loss: 4.7257304 - batches: 1117
Quality on validation dataset (20.0%), validation examples = 2238
time to finish evaluation: 16.48s
Total validation loss: 806.9229	Avg validation loss: 2.8214
label	 tp	 fp	 fn	 prec	 rec	 f1
I-NAME	 658	 104	 153	 0.86351705	 0.811344	 0.8366179
I-CONTACT	 18	 12	 24	 0.6	 0.42857143	 0.5
I-AGE	 0	 0	 7	 0.0	 0.0	 0.0
I-IDNUM	 0	 0	 18	 0.0	 0.0	 0.0
B-DATE	 662	 166	 69	 0.7995169	 0.9056088	 0.84926236
I-DATE	 68	 22	 16	 0.75555557	 0.8095238	 0.7816092
I-LOCATION	 375	 263	 391	 0.5877743	 0.48955613	 0.5341881
B-NAME	 744	 209	 244	 0.7806926	 0.75303644	 0.76661515
B-AGE	 57	 29	 62	 0.6627907	 0.4789916	 0.55609757
B-LOCATION	 1630	 360	 536	 0.8190955	 0.7525

## Evaluate the Newly Trained NER Model

In [43]:
pred_df = ner_model.stages[-1].transform(test_df).cache()

In [44]:
pred_df.show()

+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|doc_id|                text|            document|            splitter|               token|          embeddings|           ner_label|                 ner|
+------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|     X|" Investors unloa...|[{document, 0, 10...|[{document, 0, 10...|[{token, 0, 0, ",...|[{word_embeddings...|[{named_entity, 0...|[{named_entity, 0...|
|     X|" There were big ...|[{document, 0, 17...|[{document, 0, 17...|[{token, 0, 0, ",...|[{word_embeddings...|[{named_entity, 0...|[{named_entity, 0...|
|     X|" Up here in the ...|[{document, 0, 21...|[{document, 0, 21...|[{token, 0, 0, ",...|[{word_embeddings...|[{named_entity, 0...|[{named_entity, 0...|
|     X|* Conglomerate Bo...|[{document, 0, 26...|[{document, 0,

In [45]:
from pyspark.sql import functions as F

pred_token_df = pred_df.select(F.explode(F.arrays_zip(pred_df.ner_label.metadata,
                                                  pred_df.ner_label.begin,
                                                  pred_df.ner_label.end,
                                                  pred_df.ner_label.result,
                                                  pred_df.ner.result)).alias("cols")) \
          .select(F.expr("cols['0']['word']").alias("token"),
                  F.expr("cols['1']").alias("begin"),
                  F.expr("cols['2']").alias("end"),
                  F.expr("cols['3']").alias("gtruth"),
                  F.expr("cols['4']").alias("prediction"))\
          .toPandas()

pred_token_df

Unnamed: 0,token,begin,end,gtruth,prediction
0,"""",0,0,O,O
1,Investors,2,10,B-NAME,B-NAME
2,unloaded,12,19,O,O
3,their,21,25,O,O
4,shares,27,32,O,O
...,...,...,...,...,...
42194,2,18,18,O,O
42195,behind,0,5,O,O
42196,seeding,0,6,O,O
42197,),8,8,O,O


### Calculate Evaluation Metrics

Use the `NerDLMetrics` class to compute precision, recall, and F1-score for each entity type. The evaluation is run in both `full_chunk` and `partial_chunk_per_token` modes.

In [46]:
from sparknlp_jsl.eval import NerDLMetrics
import pyspark.sql.functions as F

evaler = NerDLMetrics(mode="full_chunk")

eval_result = evaler.computeMetricsFromDF(pred_df.select("ner_label","ner"),
                                          prediction_col="ner",
                                          label_col="ner_label",
                                          drop_o = True, case_sensitive = True).cache()

eval_result.withColumn("precision", F.round(eval_result["precision"],4))\
           .withColumn("recall", F.round(eval_result["recall"],4))\
           .withColumn("f1", F.round(eval_result["f1"],4)).show(100)

# show() prints the table itself and returns None, so it is not wrapped in print()
eval_result.selectExpr("avg(f1) as macro").show()
eval_result.selectExpr("sum(f1*total) as sumprod","sum(total) as sumtotal").selectExpr("sumprod/sumtotal as micro").show()

+--------+------+-----+-----+------+---------+------+------+
|  entity|    tp|   fp|   fn| total|precision|recall|    f1|
+--------+------+-----+-----+------+---------+------+------+
| CONTACT|  66.0| 21.0| 18.0|  84.0|   0.7586|0.7857|0.7719|
|    NAME|1216.0| 98.0| 97.0|1313.0|   0.9254|0.9261|0.9258|
|    DATE| 847.0| 21.0| 34.0| 881.0|   0.9758|0.9614|0.9686|
|   IDNUM|  21.0|  5.0| 14.0|  35.0|   0.8077|   0.6|0.6885|
|LOCATION|2424.0|209.0|218.0|2642.0|   0.9206|0.9175|0.9191|
|     AGE| 128.0| 12.0| 16.0| 144.0|   0.9143|0.8889|0.9014|
+--------+------+-----+-----+------+---------+------+------+

+------------------+
|             macro|
+------------------+
|0.8625398830857134|
+------------------+

+-----------------+
|            micro|
+-----------------+
|0.924830698074349|
+-----------------+
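
The per-entity figures can be reproduced by hand from the `tp`, `fp`, and `fn` columns. The sketch below does this in plain Python for the `full_chunk` table and also computes the two averages. One observation, offered as a reading of the expression rather than a statement about `NerDLMetrics`: the value labeled `micro` is calculated as `sum(f1*total)/sum(total)`, which is a support-weighted average of the per-entity F1 scores. A pooled micro-F1, computed from the summed counts, is included for comparison and comes out very close here.

```python
# (tp, fp, fn) per entity, copied from the full_chunk table
rows = {
    "CONTACT": (66, 21, 18),
    "NAME": (1216, 98, 97),
    "DATE": (847, 21, 34),
    "IDNUM": (21, 5, 14),
    "LOCATION": (2424, 209, 218),
    "AGE": (128, 12, 16),
}

def prf(tp, fp, fn):
    """Precision, recall, and F1 from raw counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * tp / (2 * tp + fp + fn)
    return precision, recall, f1

f1 = {entity: prf(*counts)[2] for entity, counts in rows.items()}
support = {entity: tp + fn for entity, (tp, fp, fn) in rows.items()}  # the "total" column

# Unweighted mean of per-entity F1
macro_f1 = sum(f1.values()) / len(f1)

# Support-weighted mean of per-entity F1 (the expression labeled "micro" above)
weighted_f1 = sum(f1[e] * support[e] for e in rows) / sum(support.values())

# Pooled micro-F1 from summed counts, for comparison
tp_all = sum(c[0] for c in rows.values())
fp_all = sum(c[1] for c in rows.values())
fn_all = sum(c[2] for c in rows.values())
pooled_micro_f1 = 2 * tp_all / (2 * tp_all + fp_all + fn_all)

print(f"macro={macro_f1:.4f} weighted={weighted_f1:.4f} pooled_micro={pooled_micro_f1:.4f}")
```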


In [47]:
evaler = NerDLMetrics(mode="partial_chunk_per_token")
eval_result_partial = evaler.computeMetricsFromDF(pred_df.select("ner_label","ner"), prediction_col="ner", label_col="ner_label", drop_o = True, case_sensitive = True).cache()

eval_result_partial.withColumn("precision", F.round(eval_result_partial["precision"],4))\
           .withColumn("recall", F.round(eval_result_partial["recall"],4))\
           .withColumn("f1", F.round(eval_result_partial["f1"],4)).sort("entity").show(100)
df_partial=eval_result_partial.toPandas()
print("partial_chunk_per_token")
# show() prints the table itself and returns None, so it is not wrapped in print()
eval_result_partial.selectExpr("avg(f1) as macro").show()
eval_result_partial.selectExpr("sum(f1*total) as sumprod","sum(total) as sumtotal").selectExpr("sumprod/sumtotal as micro").show()

+--------+------+-----+-----+------+---------+------+------+
|  entity|    tp|   fp|   fn| total|precision|recall|    f1|
+--------+------+-----+-----+------+---------+------+------+
|     AGE| 134.0| 12.0| 18.0| 152.0|   0.9178|0.8816|0.8993|
| CONTACT| 109.0| 29.0| 20.0| 129.0|   0.7899| 0.845|0.8165|
|    DATE| 965.0| 20.0| 34.0| 999.0|   0.9797| 0.966|0.9728|
|   IDNUM|  29.0|  5.0| 20.0|  49.0|   0.8529|0.5918|0.6988|
|LOCATION|3354.0|301.0|244.0|3598.0|   0.9176|0.9322|0.9249|
|    NAME|2275.0|130.0|110.0|2385.0|   0.9459|0.9539|0.9499|
+--------+------+-----+-----+------+---------+------+------+

partial_chunk_per_token
+------------------+
|             macro|
+------------------+
|0.8770233322717506|
+------------------+

+------------------+
|             micro|
+------------------+
|0.9356149945870833|
+------------------+


## Create the Final Pipeline with the Custom NER Model

In [48]:
# Reload the saved (modified) pipeline from disk with the `from_disk` method.
from sparknlp.pretrained import PretrainedPipeline

modified_pipeline = PretrainedPipeline.from_disk('modified_pipeline')

In [49]:
modified_pipeline.model.stages

[DocumentAssembler_ae0f203deedd,
 InternalDocumentSplitter_cc36578ceda6,
 REGEX_TOKENIZER_2e85686aea12,
 WORD_EMBEDDINGS_MODEL_9004b1d00302,
 MedicalNerModel_1a8637089929,
 NER_CONVERTER_1aef7e9d2de5,
 MedicalNerModel_d92d47622e85,
 MedicalNerModel_32184c1db80b,
 MedicalNerModel_ada39ac0d359,
 NER_CONVERTER_a99db4e6a79d,
 NER_CONVERTER_4a9436714344,
 NER_CONVERTER_ea6433988e18,
 PretrainedZeroShotNER_5f30ab9002f1,
 NER_CONVERTER_c97040caf7b3,
 MedicalNerModel_b8b167ec3114,
 NER_CONVERTER_06db473f3215,
 ContextualEntityRuler_11ff6711ef6b,
 ChunkMergeModel_95d6827691bb,
 CONTEXTUAL-PARSER_bf2a6abaf5fa,
 CONTEXTUAL-PARSER_ff6bad379d91,
 CONTEXTUAL-PARSER_89341cae7221,
 CONTEXTUAL-PARSER_c6b9eded8d31,
 CONTEXTUAL-PARSER_9480c24bd9f8,
 CONTEXTUAL-PARSER_3886bce391c8,
 CONTEXTUAL-PARSER_0bb3fb75cd01,
 ENTITY_EXTRACTOR_6792f2f6e85a,
 ENTITY_EXTRACTOR_74ace4be4f73,
 CONTEXTUAL-PARSER_dfb32adc7555,
 REGEX_MATCHER_5003669d6422,
 CONTEXTUAL-PARSER_746a25662aa6,
 CONTEXTUAL-PARSER_079220479a3d,
 C

### New Stages

In [50]:
# Newly trained custom NER model, loaded from disk
ner_deid_new = MedicalNerModel.load("models/new_NER_model")\
    .setInputCols(["splitter", "token", "embeddings"])\
    .setOutputCol("ner_deid_new")

ner_deid_new_converter = NerConverter()\
    .setInputCols(["splitter", "token", "ner_deid_new"])\
    .setOutputCol("ner_chunk_new")

# Pretrained de-identification NER model, used alongside the custom model
ner_deid = MedicalNerModel.pretrained("ner_deid_subentity_docwise", "en", "clinical/models")\
    .setInputCols(["splitter", "token", "embeddings"])\
    .setOutputCol("ner_deid_subentity_docwise")

ner_deid_converter = NerConverter()\
    .setInputCols(["splitter", "token", "ner_deid_subentity_docwise"])\
    .setOutputCol("ner_chunk_subentity_docwise")

# Merge the chunks produced by both models into a single chunk column
chunk_merge_ner = ChunkMergeModel()\
    .setInputCols("ner_chunk_new", # New Trained Model
                  "ner_chunk_subentity_docwise")\
    .setOutputCol("deid_merged_ner_chunk")\
    .setOrderingFeatures(["ChunkLength","ChunkBegin"])\
    .setMergeOverlapping(True)\
    .setResetSentenceIndices(True)


ner_deid_subentity_docwise download started this may take some time.
Approximate size to download 8.9 MB
[OK!]
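
To give an intuition for what merging overlapping chunks with `ChunkLength` and `ChunkBegin` as ordering features might do, here is a simplified plain-Python sketch. It is not the `ChunkMergeModel` implementation, whose selection logic and tie-breaking may differ. The `Chunk` type, the `merge_overlapping` helper, the `PATIENT` label, and the sample spans are all hypothetical and exist only for this illustration.

```python
from typing import List, NamedTuple


class Chunk(NamedTuple):
    begin: int   # start offset of the chunk
    end: int     # end offset, inclusive
    label: str   # entity label
    source: str  # which model produced the chunk (illustrative only)


def merge_overlapping(chunks: List[Chunk]) -> List[Chunk]:
    """Keep a non-overlapping subset, preferring longer chunks, then earlier ones."""
    # Rank candidates: longest first, ties broken by the smaller begin offset.
    ranked = sorted(chunks, key=lambda c: (-(c.end - c.begin + 1), c.begin))
    kept: List[Chunk] = []
    for cand in ranked:
        # Accept a candidate only if it does not overlap any chunk already kept.
        if all(cand.end < k.begin or cand.begin > k.end for k in kept):
            kept.append(cand)
    # Return the survivors in document order.
    return sorted(kept, key=lambda c: c.begin)


# Hypothetical chunks coming from the two NER models.
chunks = [
    Chunk(22, 29, "NAME", "new_model"),
    Chunk(22, 25, "PATIENT", "docwise_model"),  # overlaps the longer NAME chunk
    Chunk(37, 43, "IDNUM", "new_model"),
    Chunk(75, 84, "DATE", "docwise_model"),
]

merged = merge_overlapping(chunks)
for c in merged:
    print(c)
```

In this toy input the shorter `PATIENT` span is dropped because it overlaps the longer `NAME` span, while the non-overlapping chunks from both sources are kept.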


### Update Stages

In [51]:
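# Keep stages 0-3 (document assembler, splitter, tokenizer, embeddings),
# replace stages 4-17 (the original NER models, including the zero-shot NER,
# their converters, the contextual entity ruler and the first chunk merger)
# with the new stages, and keep everything from index 18 onwards.
modified_pipeline.model.stages = (
    modified_pipeline.model.stages[:4]
    + [ner_deid_new,
       ner_deid_new_converter,
       ner_deid,
       ner_deid_converter,
       chunk_merge_ner]
    + modified_pipeline.model.stages[18:]
)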
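
The slice arithmetic used here can be checked on plain strings before touching the real pipeline. The sketch below is only an illustration: `splice_stages` is a hypothetical helper, the strings stand in for the actual stage objects, and the `original` list contains only the first 19 stage names from the listing above (the notebook output is truncated after that).

```python
def splice_stages(stages, start, stop, replacement):
    """Return a new list in which stages[start:stop] is replaced by `replacement`."""
    if not 0 <= start <= stop <= len(stages):
        raise ValueError("slice bounds fall outside the stage list")
    return stages[:start] + list(replacement) + stages[stop:]


# Stage names (without uid suffixes) in the order shown for the original pipeline.
original = [
    # indices 0-3: kept as-is
    "DocumentAssembler", "InternalDocumentSplitter", "REGEX_TOKENIZER",
    "WORD_EMBEDDINGS_MODEL",
    # indices 4-17: replaced
    "MedicalNerModel", "NER_CONVERTER", "MedicalNerModel", "MedicalNerModel",
    "MedicalNerModel", "NER_CONVERTER", "NER_CONVERTER", "NER_CONVERTER",
    "PretrainedZeroShotNER", "NER_CONVERTER", "MedicalNerModel", "NER_CONVERTER",
    "ContextualEntityRuler", "ChunkMergeModel",
    # index 18: first stage kept after the replaced block
    "CONTEXTUAL-PARSER",
]

new_stages = [
    "ner_deid_new", "ner_deid_new_converter",
    "ner_deid", "ner_deid_converter",
    "chunk_merge_ner",
]

updated = splice_stages(original, 4, 18, new_stages)

for index, name in enumerate(updated):
    print(index, name)
```

Hard-coded indices such as `4` and `18` are tied to one specific version of the pipeline, so it is worth re-checking them against the printed stage list whenever the base pipeline changes.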

In [52]:
modified_pipeline.model.stages

[DocumentAssembler_ae0f203deedd,
 InternalDocumentSplitter_cc36578ceda6,
 REGEX_TOKENIZER_2e85686aea12,
 WORD_EMBEDDINGS_MODEL_9004b1d00302,
 MedicalNerModel_ad2d1a1803c3,
 NerConverter_539ea7b222b9,
 MedicalNerModel_32184c1db80b,
 NerConverter_1b0e093bf6d6,
 ChunkMergeModel_0bc225e58090,
 CONTEXTUAL-PARSER_bf2a6abaf5fa,
 CONTEXTUAL-PARSER_ff6bad379d91,
 CONTEXTUAL-PARSER_89341cae7221,
 CONTEXTUAL-PARSER_c6b9eded8d31,
 CONTEXTUAL-PARSER_9480c24bd9f8,
 CONTEXTUAL-PARSER_3886bce391c8,
 CONTEXTUAL-PARSER_0bb3fb75cd01,
 ENTITY_EXTRACTOR_6792f2f6e85a,
 ENTITY_EXTRACTOR_74ace4be4f73,
 CONTEXTUAL-PARSER_dfb32adc7555,
 REGEX_MATCHER_5003669d6422,
 CONTEXTUAL-PARSER_746a25662aa6,
 CONTEXTUAL-PARSER_079220479a3d,
 CONTEXTUAL-PARSER_f8b8f9aafb9f,
 CONTEXTUAL-PARSER_7f824493eafc,
 REGEX_MATCHER_26934077fe57,
 REGEX_MATCHER_5fe3de8b5a4e,
 CONTEXTUAL-PARSER_9c0e3df6e2bf,
 CONTEXTUAL-PARSER_9e3272f6a015,
 MERGE_ddff59e8b14a,
 ChunkMergeModel_50feb5f97568,
 ContextualEntityRuler_08eeaa89c938,
 ChunkMe

### Reassemble and Save the Final Pipeline

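The pipeline's stages have now been rebuilt: the original NER components are replaced by the new custom NER model, the pretrained `ner_deid_subentity_docwise` model, and the reconfigured chunk merger. The next cell applies the modified pipeline to a DataFrame containing a single empty text and then saves the final pipeline to disk.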

In [53]:
empty_result = modified_pipeline.transform(spark.createDataFrame([[""]]).toDF("text"))

modified_pipeline.model.write().overwrite().save("new_pipeline")
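
When stages are swapped by hand, a common failure is a stage whose input column is no longer produced by any earlier stage. A check of the following kind could be run before the save above. It is a plain-Python sketch under stated assumptions: `find_unresolved_inputs` is a hypothetical helper, each stage is described by a `(name, input_columns, output_column)` tuple, and the column names of the first four stages (`document` in particular) are assumed rather than taken from the notebook.

```python
def find_unresolved_inputs(stages, initial_columns=("text",)):
    """Return (stage_name, missing_columns) for every stage with unmet inputs."""
    available = set(initial_columns)
    problems = []
    for name, inputs, output in stages:
        missing = [col for col in inputs if col not in available]
        if missing:
            problems.append((name, missing))
        # The stage's output becomes available to all later stages.
        available.add(output)
    return problems


# Simplified description of the first nine stages after the update.
stages = [
    ("document_assembler", ["text"], "document"),            # assumed column names
    ("splitter", ["document"], "splitter"),                  # assumed input column
    ("tokenizer", ["splitter"], "token"),                    # assumed input column
    ("embeddings", ["splitter", "token"], "embeddings"),     # assumed input columns
    ("ner_deid_new", ["splitter", "token", "embeddings"], "ner_deid_new"),
    ("ner_deid_new_converter", ["splitter", "token", "ner_deid_new"], "ner_chunk_new"),
    ("ner_deid", ["splitter", "token", "embeddings"], "ner_deid_subentity_docwise"),
    ("ner_deid_converter", ["splitter", "token", "ner_deid_subentity_docwise"],
     "ner_chunk_subentity_docwise"),
    ("chunk_merge_ner", ["ner_chunk_new", "ner_chunk_subentity_docwise"],
     "deid_merged_ner_chunk"),
]

print(find_unresolved_inputs(stages))

# Dropping the custom NER stage leaves its converter without an input column.
broken = [s for s in stages if s[0] != "ner_deid_new"]
print(find_unresolved_inputs(broken))
```

With real stages, the tuples could likely be built from each annotator's `getInputCols()` and `getOutputCol()`, although not every stage type necessarily exposes both.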

In [54]:
from sparknlp.pretrained import PretrainedPipeline

new_pipeline = PretrainedPipeline.from_disk('new_pipeline')

## Final Test of the New Pipeline

In [55]:
samples_df = spark.createDataFrame([[text]]).toDF("text")

result = new_pipeline.transform(samples_df).cache()

In [56]:
result.select(F.explode(F.arrays_zip(result.ner_chunk.result,
                                     result.ner_chunk.begin,
                                     result.ner_chunk.end,
                                     result.ner_chunk.metadata)).alias("cols")) \
      .select(F.expr("cols['0']").alias("chunk"),
              F.expr("cols['1']").alias("begin"),
              F.expr("cols['2']").alias("end"),
              F.expr("cols['3']['entity']").alias("ner_label"),
              F.expr("cols['3']['confidence']").alias("confidence")).show(50,truncate=False)

+----------------------------------+-----+----+---------+----------+
|chunk                             |begin|end |ner_label|confidence|
+----------------------------------+-----+----+---------+----------+
|John Lee                          |22   |29  |NAME     |0.9799    |
|7789201                           |37   |43  |IDNUM    |0.72      |
|2025-05-12                        |75   |84  |DATE     |NULL      |
|#RD23-4897                        |101  |110 |IDNUM    |0.50      |
|Smith                             |232  |236 |NAME     |0.9999    |
|Carter                            |243  |248 |NAME     |0.9997    |
|2025-05-12                        |275  |284 |DATE     |NULL      |
|Fan Gabriel                       |313  |323 |NAME     |0.80315   |
|90210                             |325  |329 |DATE     |0.9192    |
|New York                          |373  |380 |LOCATION |NULL      |
|NYC                               |401  |403 |LOCATION |0.992     |
|NY                               
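
The cell above combines `arrays_zip` and `explode`: the parallel arrays of chunk texts, offsets, and metadata are zipped element-wise, and each zipped element then becomes its own row. The plain-Python sketch below mirrors that idea without Spark. The `row` dictionary is a hand-made stand-in for a single `ner_chunk` value, using the first three chunks from the table; the missing `confidence` key for the `DATE` chunk is an assumption about why that row shows `NULL`.

```python
# Stand-in for the parallel arrays stored in one `ner_chunk` value.
row = {
    "result": ["John Lee", "7789201", "2025-05-12"],
    "begin": [22, 37, 75],
    "end": [29, 43, 84],
    "metadata": [
        {"entity": "NAME", "confidence": "0.9799"},
        {"entity": "IDNUM", "confidence": "0.72"},
        {"entity": "DATE"},  # assumed: no confidence recorded for this chunk
    ],
}

# arrays_zip: combine the arrays position by position into tuples.
zipped = list(zip(row["result"], row["begin"], row["end"], row["metadata"]))

# explode + select: one output record per zipped element.
records = [
    {
        "chunk": chunk,
        "begin": begin,
        "end": end,
        "ner_label": meta.get("entity"),
        "confidence": meta.get("confidence"),  # None when the key is absent
    }
    for chunk, begin, end, meta in zipped
]

for record in records:
    print(record)
```

In the Spark version the zipped fields are addressed by position (`cols['0']`, `cols['1']`, and so on), which plays the same role as the tuple unpacking here.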

In [57]:
pd.set_option("display.max_colwidth", 1000)
result_df = result.selectExpr("text","mask_entity.result as masked_result","obfuscated.result as obfuscated_result").toPandas()
result_df

Unnamed: 0,text,masked_result,obfuscated_result
0,"\n(NOTE) Patient Name: John Lee. MR#: 7789201 Location: LERE Date Reported: 2025-05-12 16:30\nSpecimen #RD23-4897 Clinical History: None Given. CLINICAL INFORMATION: Date of Last Menstrual Period: N/A\nElectronically Signed Out By Dr. Smith, Dr. Carter, CT(ASCP) Date Reported: 2025-05-12 16:30\nGeneral Hospital Dr. Fan Gabriel 90210 CPT Code(s) A: 88305\n\nGeneral Hospital in New York City Dr. Williams, NYC, NY\n(212) 555-7890 Patient Name: John Lee Accession #: GH-556672\nPatient ID #: 7789201 Collected: 2025-05-10 Address:\n123 Main Street, FALL RIVER\nNIAGARA FALLS, NY 14304\nReceived: 2025-05-10 Reported: 2025-05-12\nSoc. Sec. #: XXX-XX-1234 DOB/Age/Sex: 1973 (Age: 52) M\nPhysician(s): Dr. Jameson. Peripheral sequestration, i.e. splenomegaly or hepatomegaly should be excluded to be sure if peripheral sequestration is not present.\nThe following special studies were performed at Barstow Heights Christus Southeast, NY – St Elizabeth; New York City.\n· Chromosome analysis cytogene...","[\n(NOTE) Patient Name: <NAME>. MR#: <IDNUM> Location: LERE Date Reported: <DATE> 16:30\nSpecimen <IDNUM> Clinical History: None Given. CLINICAL INFORMATION: Date of Last Menstrual Period: N/A\nElectronically Signed Out By Dr. <NAME>, Dr. <NAME>, CT(ASCP) Date Reported: <DATE> 16:30\nGeneral Hospital Dr. <NAME> <DATE> CPT Code(s) A: 88305\n\nGeneral Hospital in <LOCATION> City Dr. Williams, <LOCATION>, <LOCATION>\n<CONTACT> Patient Name: <NAME> Accession #: <IDNUM>\nPatient ID #: <CONTACT> Collected: <DATE> Address:\n<LOCATION>, FALL <LOCATION>, <LOCATION> <LOCATION>\nReceived: <DATE> Reported: <DATE>\nSoc. Sec. #: XXX-XX-1234 DOB/Age/Sex: <DATE> (Age: <AGE>) <NAME>): Dr. <NAME>. Peripheral sequestration, i.e. splenomegaly or hepatomegaly should be excluded to be sure if peripheral sequestration is not present.\nThe following special studies were performed at <LOCATION>, <LOCATION> – <LOCATION>; <LOCATION> City.\n<LOCATION> analysis cytogenetics. 
(ADDENDUM REPORT TO FOLLOW.)\n· Leu...","[\n(NOTE) Patient Name: Gillie Allan. MR#: 0074518 Location: LERE Date Reported: 2025-06-29 16:30\nSpecimen #SA52-9740 Clinical History: None Given. CLINICAL INFORMATION: Date of Last Menstrual Period: N/A\nElectronically Signed Out By Dr. Wanna, Dr. Malvin, CT(ASCP) Date Reported: 2025-06-29 16:30\nGeneral Hospital Dr. Marcelo Danes 06-01-1980 CPT Code(s) A: 88305\n\nGeneral Hospital in 2000 Boise Ave City Dr. Williams, 427 GUY PARK AVE, 16100 SOUTH FREEWAY\n(585) 666-0741 Patient Name: Gillie Allan Accession #: PU-663305\nPatient ID #: 0074518 Collected: 2025-06-27 Address:\n3255 Independence Street, FALL 401 BICENTENNIAL WAY, 16100 SOUTH FREEWAY 59 KOCH AVE\nReceived: 2025-06-27 Reported: 2025-06-29\nSoc. Sec. #: XXX-XX-1234 DOB/Age/Sex: 1974 (Age: 44) TERETHA Sol): Dr. Marchelle. Peripheral sequestration, i.e. splenomegaly or hepatomegaly should be excluded to be sure if peripheral sequestration is not present.\nThe following special studies were performed at 103 North Street, ..."
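
The `masked_result` column can be understood as the original text with each detected chunk replaced by its entity label. The sketch below shows that idea in plain Python on the first part of the sample text, using the `begin` and `end` offsets (end inclusive) reported in the chunk table. It is an illustration only, not how the pipeline's de-identification stage is implemented, and `mask_chunks` is a hypothetical helper.

```python
def mask_chunks(text, chunks):
    """Replace each (begin, end, label) span with <label>; `end` is inclusive."""
    masked = text
    # Work from the last chunk to the first so earlier offsets stay valid.
    for begin, end, label in sorted(chunks, key=lambda c: c[0], reverse=True):
        masked = masked[:begin] + f"<{label}>" + masked[end + 1:]
    return masked


# First part of the sample text and its first two chunks from the table.
sample = "\n(NOTE) Patient Name: John Lee. MR#: 7789201 Location: LERE"
chunks = [
    (22, 29, "NAME"),   # "John Lee"
    (37, 43, "IDNUM"),  # "7789201"
]

masked_sample = mask_chunks(sample, chunks)
print(masked_sample)
```

Replacing spans from right to left avoids recomputing offsets after each substitution, since a label is usually not the same length as the text it replaces.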
