![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/healthcare-nlp/01.3.BertForTokenClassification_NER_SparkNLP_with_Transformers.ipynb)

# BertForTokenClassification NER Model Training with Transformers

This notebook shows how to train a BertForTokenClassification NER model with the Hugging Face `transformers` library and then import it into Spark NLP. (Spark NLP has no `Approach()` (trainable annotator) for this model type, so the training itself is done entirely in `transformers`.)

In [None]:
# %%capture  (uncomment to hide the pip output of this cell)
# Install the training-side dependencies: seqeval (unpinned) plus pinned
# versions of transformers and tensorflow. `-q` keeps pip quiet.
! pip install -q seqeval
! pip install -q transformers==4.25.1
! pip install -q tensorflow==2.11.0

In [None]:
# Install the pinned johnsnowlabs library, which gives access to Spark-OCR and
# Spark-NLP for Healthcare, Finance, and Legal.
! pip install -q johnsnowlabs==5.1.0

In [None]:
# Colab-only helper: `files.upload()` opens a file picker in the notebook UI.
from google.colab import files
# Prompt first so the user knows what the upload widget below is for.
print('Please Upload your John Snow Labs License using the button below')
# Keep the result of the upload in `license_keys`.
license_keys = files.upload()

In [None]:
from johnsnowlabs import nlp, medical

# After uploading your license, run this to install all licensed Python wheels
# and pre-download the jars for the Spark session JVM.
nlp.install()

In [None]:
from johnsnowlabs import nlp, medical
import pandas as pd
import warnings
# Silence all Python warnings for the rest of the notebook.
warnings.filterwarnings('ignore')

# Automatically load the license data and start a Spark session with all the
# jars the user has access to.
spark = nlp.start()

👌 Detected license file /content/spark_nlp_for_healthcare_spark_ocr_8283.json
👌 Launched [92mcpu optimized[39m session with with: 🚀Spark-NLP==5.1.0, 💊Spark-Healthcare==5.1.0, running on ⚡ PySpark==3.1.2


In [None]:
# Display the active Spark session as the cell output.
spark

In [None]:
import os
import json
import numpy as np
from tqdm import tqdm, trange
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from keras.utils import pad_sequences
#from keras_preprocessing.sequence import pad_sequences

import transformers
from transformers import BertTokenizer, BertConfig
from transformers import BertForTokenClassification, TFBertForTokenClassification, AdamW
from transformers import get_linear_schedule_with_warmup

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler


# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

# Show which accelerator is in use. `get_device_name(0)` raises when no CUDA
# device exists, so only query it when a GPU is actually available.
torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu (no CUDA device available)"

'Tesla T4'

## Download NCBI Disease CoNLL Dataset

In [None]:
# Download the NCBI Disease train and test files (CoNLL format) into the
# current working directory. `-q` suppresses wget's progress output.
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/NER_NCBIconlltrain.txt
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/master/tutorials/Certification_Trainings/Healthcare/data/NER_NCBIconlltest.txt

In [None]:
# Name of the output folder; checkpoints, logs and exported models go under it.
PROJECT_NAME = 'ner_disease_main'

# CoNLL files downloaded in the previous step.
train_set =  "NER_NCBIconlltrain.txt"
test_set = "NER_NCBIconlltest.txt"

# True: evaluate on the test set after every epoch.
# False: merge the test set into the training data and skip evaluation.
test_metrics = True

# Select any BERT model from >> https://huggingface.co/models?pipeline_tag=token-classification&sort=downloads&search=bert

MODEL_TO_TRAIN = 'dmis-lab/biobert-base-cased-v1.2'
# Alternative: emilyalsentzer/Bio_ClinicalBERT

# Key hyper-parameters for the training below
# (the trailing comments show alternative values).
MAX_LEN = 128 # 512
TRAIN_BATCH_SIZE = 64 # 8
VALID_BATCH_SIZE = 64 # 8
EPOCHS = 5
LEARNING_RATE = 2e-05

# Create the project folder and its logs sub-folder (no error if they exist).
!mkdir -p {PROJECT_NAME}
!mkdir -p {PROJECT_NAME}/logs

## Run the following cells without changes

In [None]:
from pyspark.sql import DataFrame
import pyspark.sql.functions as F
import pyspark.sql.types as T
import pyspark.sql as SQL
from pyspark import keyword_only

In [None]:
def get_conll_df(pth):
  """Read a CoNLL file with Spark NLP and flatten it to one row per token.

  Each sentence receives a `sentence_idx` from `monotonically_increasing_id`.
  Sentences whose labels contain only one distinct value are dropped.

  Args:
    pth: path of the CoNLL file to read.

  Returns:
    A pandas DataFrame with columns `sentence_idx`, `word`, `tag` and `pos`.
  """
  conll = nlp.CoNLL().readDataset(spark, pth)
  conll = conll.withColumn("sentence_idx", F.monotonically_increasing_id())

  # Keep only sentences that carry more than one distinct label value.
  conll = conll.withColumn('unique', F.array_distinct("label.result"))
  conll = conll.withColumn('c', F.size('unique'))
  conll = conll.filter(F.col('c')>1)

  # Zip the parallel token / label / POS arrays and explode to one row per token.
  zipped = F.arrays_zip(conll.token.result,
                        conll.label.result,
                        conll.pos.result)
  exploded = conll.select('sentence_idx', F.explode(zipped).alias("cols"))

  flat = exploded.select('sentence_idx',
                         F.expr("cols['0']").alias("word"),
                         F.expr("cols['1']").alias("tag"),
                         F.expr("cols['2']").alias("pos"))

  return flat.toPandas()


# Load both splits as token-level pandas frames.
train_data_df = get_conll_df(train_set)
test_data_df = get_conll_df(test_set)

# Print the tag frequency of each split.
for banner, frame in (('=== TRAINING SET DISTRIBUTION ===', train_data_df),
                      ('=== TEST SET DISTRIBUTION ===', test_data_df)):
  print(banner)
  print(frame['tag'].value_counts())

# Without a held-out evaluation, train on the test data as well.
if not test_metrics:
  train_data_df = pd.concat([train_data_df, test_data_df])


## convert conll file to sentences

class SentenceGetter(object):
    """Group a token-level DataFrame into sentences.

    The input frame must have the columns `sentence_idx`, `word`, `pos` and
    `tag`. Rows are grouped by `sentence_idx`; every sentence becomes a list
    of `(word, pos, tag)` tuples, available in `self.sentences`.
    """

    def __init__(self, dataset):
        # 1-based position of the next sentence returned by `get_next`.
        self.n_sent = 1
        self.dataset = dataset
        self.empty = False
        agg_func = lambda s: [(w,p, t) for w,p, t in zip(s["word"].values.tolist(),
                                                       s['pos'].values.tolist(),
                                                        s["tag"].values.tolist())]
        self.grouped = self.dataset.groupby("sentence_idx").apply(agg_func)
        self.sentences = [s for s in self.grouped]

    def get_next(self):
        """Return the next sentence, or None once all sentences were returned.

        Sentences are served by position. The groups are keyed by the integer
        `sentence_idx`, so a lookup by a "Sentence: N" label can never match.
        """
        if self.n_sent > len(self.sentences):
            return None
        s = self.sentences[self.n_sent - 1]
        self.n_sent += 1
        return s

# Group the token rows of each split into sentences.
train_getter = SentenceGetter(train_data_df)

if test_metrics:
  test_getter = SentenceGetter(test_data_df)


print('=== Getting sentences and labels ===')

# Sentences: the first element of every (word, pos, tag) tuple.
train_sentences = [[word for word, _pos, _tag in sentence]
                   for sentence in train_getter.sentences]
print("Example of train sentence:")
print(train_sentences[5])

if test_metrics:
  test_sentences = [[word for word, _pos, _tag in sentence]
                    for sentence in test_getter.sentences]
  print("Example of test sentence:")
  print(test_sentences[5])

# Labels: the third element of every (word, pos, tag) tuple.
train_labels = [[tag for _word, _pos, tag in sentence]
                for sentence in train_getter.sentences]
print("Example of train sentence:")
print(train_labels[5])

if test_metrics:
  test_labels = [[tag for _word, _pos, tag in sentence]
                 for sentence in test_getter.sentences]
  print("Example of test sentence:")
  print(test_labels[5])


# Build the label vocabulary. Sorting makes the label -> id mapping the same
# on every run; iterating a bare set of strings can change order between
# interpreter sessions, which would change the ids from run to run.
tag_values = sorted(set(train_data_df["tag"].values))
# "PAD" is the label given to padding positions; it always gets the last id.
tag_values.append("PAD")
tag2idx = {t: i for i, t in enumerate(tag_values)}

print(tag_values[:10])
print(tag2idx)

# Cased word-piece tokenizer that matches the checkpoint being fine-tuned.
tokenizer = BertTokenizer.from_pretrained(MODEL_TO_TRAIN, do_lower_case=False)


def tokenize_and_preserve_labels(sentence, text_labels):
    """Split each word into word pieces and repeat its label for every piece.

    Uses the module-level `tokenizer`.

    Args:
        sentence: sequence of words.
        text_labels: sequence of labels, aligned one-to-one with `sentence`.

    Returns:
        A `(pieces, piece_labels)` tuple of two equally long lists.
    """
    pieces, piece_labels = [], []

    for word, label in zip(sentence, text_labels):
        subwords = tokenizer.tokenize(word)
        pieces += subwords
        # Every sub-word inherits the label of the word it came from.
        piece_labels += [label] * len(subwords)

    return pieces, piece_labels


# Tokenize every training sentence, then split the (tokens, labels) pairs
# into two parallel lists.
train_tokenized_texts_and_labels = [
    tokenize_and_preserve_labels(words, tags)
    for words, tags in zip(train_sentences, train_labels)
]
train_tokenized_texts_tokens = [pair[0] for pair in train_tokenized_texts_and_labels]
train_tokenized_texts_labels = [pair[1] for pair in train_tokenized_texts_and_labels]

# Same for the test split, printing one example of each list.
if test_metrics:
  test_tokenized_texts_and_labels = [
      tokenize_and_preserve_labels(words, tags)
      for words, tags in zip(test_sentences, test_labels)
  ]

  test_tokenized_texts_tokens = [pair[0] for pair in test_tokenized_texts_and_labels]
  print(test_tokenized_texts_tokens[5])

  test_tokenized_texts_labels = [pair[1] for pair in test_tokenized_texts_and_labels]
  print(test_tokenized_texts_labels[5])



# Convert word pieces to vocabulary ids and pad/truncate (at the end) to
# MAX_LEN, using id 0 as filler.
train_input_ids = pad_sequences(
    [tokenizer.convert_tokens_to_ids(pieces) for pieces in train_tokenized_texts_tokens],
    maxlen=MAX_LEN, dtype="long", value=0.0,
    truncating="post", padding="post")

# Convert labels to ids and pad with the id of "PAD".
train_tags = pad_sequences(
    [[tag2idx.get(tag) for tag in tags] for tags in train_tokenized_texts_labels],
    maxlen=MAX_LEN, value=tag2idx["PAD"], padding="post",
    dtype="long", truncating="post")

# Mask is 1.0 where the input id is non-zero and 0.0 where it is 0.
train_attention_masks = [[float(token_id != 0.0) for token_id in row] for row in train_input_ids]

if test_metrics:

  test_input_ids = pad_sequences(
      [tokenizer.convert_tokens_to_ids(pieces) for pieces in test_tokenized_texts_tokens],
      maxlen=MAX_LEN, dtype="long", value=0.0,
      truncating="post", padding="post")

  test_tags = pad_sequences(
      [[tag2idx.get(tag) for tag in tags] for tags in test_tokenized_texts_labels],
      maxlen=MAX_LEN, value=tag2idx["PAD"], padding="post",
      dtype="long", truncating="post")

  test_attention_masks = [[float(token_id != 0.0) for token_id in row] for row in test_input_ids]




# Wrap the padded arrays in tensors.
tr_inputs = torch.tensor(train_input_ids)
tr_tags = torch.tensor(train_tags)
tr_masks = torch.tensor(train_attention_masks)

if test_metrics:

  val_inputs = torch.tensor(test_input_ids)
  val_tags = torch.tensor(test_tags)
  val_masks = torch.tensor(test_attention_masks)


# Training batches are drawn in random order.
train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=TRAIN_BATCH_SIZE)

if test_metrics:

  # Validation batches keep their original order and use the validation
  # batch size from the configuration.
  valid_data = TensorDataset(val_inputs, val_masks, val_tags)
  valid_sampler = SequentialSampler(valid_data)
  valid_dataloader = DataLoader(valid_data, sampler=valid_sampler, batch_size=VALID_BATCH_SIZE)



# Load the pretrained encoder with a token-classification head that has one
# output per entry of tag2idx (including "PAD").
model = BertForTokenClassification.from_pretrained(
    MODEL_TO_TRAIN,
    num_labels=len(tag2idx),
    output_attentions = False,
    output_hidden_states = False
)
model.to(device)

# True: update every parameter of the model. False: update only the classifier head.
FULL_FINETUNING = True
if FULL_FINETUNING:
    param_optimizer = list(model.named_parameters())
    # Parameters excluded from weight decay: biases and LayerNorm weights.
    # 'gamma' / 'beta' are kept for checkpoints that use the older names.
    no_decay = ['bias', 'LayerNorm.weight', 'gamma', 'beta']
    # AdamW reads the per-group key 'weight_decay'; other keys such as
    # 'weight_decay_rate' are ignored and leave the decay at its default.
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
else:
    param_optimizer = list(model.classifier.named_parameters())
    optimizer_grouped_parameters = [{"params": [p for n, p in param_optimizer]}]

# Use the learning rate from the configuration cell instead of a literal.
optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=LEARNING_RATE,
    eps=1e-8
)


epochs = EPOCHS
max_grad_norm = 1.0

# Total number of training steps is number of batches * number of epochs.
total_steps = len(train_dataloader) * epochs

# Linear learning-rate decay over all training steps, without warm-up.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)


## Store the average loss after each epoch so we can plot them.
loss_values, validation_loss_values = [], []

for EPOCH in trange(epochs, desc="Epoch"):
    # ---- Training ----
    model.train()
    total_loss = 0

    for batch in train_dataloader:
        # Move the batch to the training device.
        batch = tuple(t.to(device) for t in batch)
        b_input_ids, b_input_mask, b_labels = batch
        # Always clear previously accumulated gradients before a backward pass.
        model.zero_grad()
        # Because `labels` is provided, the first output is the loss.
        outputs = model(b_input_ids, token_type_ids=None,
                        attention_mask=b_input_mask, labels=b_labels)
        loss = outputs[0]
        loss.backward()
        total_loss += loss.item()
        # Clip the gradient norm to help prevent exploding gradients.
        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)
        optimizer.step()
        # Advance the learning-rate schedule once per optimizer step.
        scheduler.step()

    # Average loss over all training batches of this epoch.
    avg_train_loss = total_loss / len(train_dataloader)
    tr_loss = f"Average train loss: {str(avg_train_loss)}\n"

    # Save a partial model for this epoch (this creates the folder too).
    tokenizer.save_pretrained(f'{PROJECT_NAME}/{str(EPOCH)}/tokenizer/')
    model.save_pretrained(save_directory=f'{PROJECT_NAME}/{str(EPOCH)}/',
                          save_config=True, state_dict=model.state_dict())

    # Save a checkpoint so that work can be restored after a crash.
    torch.save({
        'epoch': EPOCH,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_train_loss,
        }, f'{PROJECT_NAME}/{str(EPOCH)}/checkpoint.pth')

    # Store the loss value for plotting the learning curve.
    loss_values.append(avg_train_loss)

    # ---- Evaluation (only when a held-out set is used) ----
    if test_metrics:
        model.eval()
        eval_loss = 0
        predictions, true_labels = [], []
        for batch in valid_dataloader:
            batch = tuple(t.to(device) for t in batch)
            b_input_ids, b_input_mask, b_labels = batch

            # No gradients are needed for evaluation; this saves memory and time.
            with torch.no_grad():
                # Labels are provided, so the outputs are (loss, logits).
                outputs = model(b_input_ids, token_type_ids=None,
                                attention_mask=b_input_mask, labels=b_labels)
            # Move logits and labels to the CPU.
            logits = outputs[1].detach().cpu().numpy()
            label_ids = b_labels.to('cpu').numpy()

            eval_loss += outputs[0].mean().item()
            # The predicted class of each position is the argmax over the label axis.
            predictions.extend([list(p) for p in np.argmax(logits, axis=2)])
            true_labels.extend(label_ids)

        eval_loss = eval_loss / len(valid_dataloader)
        validation_loss_values.append(eval_loss)

        val_loss = f"Validation loss: {str(eval_loss)}\n"

        # Token-level metrics, skipping positions whose gold label is "PAD".
        pred_tags = [tag_values[p_i] for p, l in zip(predictions, true_labels)
                                     for p_i, l_i in zip(p, l) if tag_values[l_i] != "PAD"]
        valid_tags = [tag_values[l_i] for l in true_labels
                                      for l_i in l if tag_values[l_i] != "PAD"]

        report = classification_report(valid_tags, pred_tags)

    # Append the losses of this epoch to its log file.
    with open(f'{PROJECT_NAME}/logs/epoch_' + str(EPOCH) + '_loss.log', 'a') as f:
        f.write(tr_loss)
        if test_metrics:
            f.write(val_loss)

    # The metrics only exist when the evaluation above ran; without this guard
    # the loop would fail on undefined names when test_metrics is False.
    if test_metrics:
        with open(f'{PROJECT_NAME}/logs/epoch_' + str(EPOCH) + '_metrics.log', 'a') as f:
            f.write(report)

    # Print to stdout as well.
    print(tr_loss)

    if test_metrics:
        print(val_loss)
        print(report)


=== TRAINING SET DISTRIBUTION ===
O            39427
I-Disease     3547
B-Disease     3093
Name: tag, dtype: int64
=== TEST SET DISTRIBUTION ===
O            9316
I-Disease     789
B-Disease     708
Name: tag, dtype: int64
=== Getting sentences and labels ===
Example of train sentence:
['A', 'common', 'MSH2', 'mutation', 'in', 'English', 'and', 'North', 'American', 'HNPCC', 'families', ':', 'origin', ',', 'phenotypic', 'expression', ',', 'and', 'sex', 'specific', 'differences', 'in', 'colorectal', 'cancer', '.']
Example of test sentence:
['Two', 'of', 'seventeen', 'mutated', 'T', '-', 'PLL', 'samples', 'had', 'a', 'previously', 'reported', 'A', '-', 'T', 'allele', '.']
Example of train sentence:
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Disease', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Disease', 'I-Disease', 'O']
Example of test sentence:
['O', 'O', 'O', 'O', 'B-Disease', 'I-Disease', 'I-Disease', 'O', 'O', 'O', 'O', 'O', 'B-Disease', 'I-Disease', 'I-Disea

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.11k [00:00<?, ?B/s]

['Two', 'of', 'seventeen', 'm', '##uta', '##ted', 'T', '-', 'P', '##LL', 'samples', 'had', 'a', 'previously', 'reported', 'A', '-', 'T', 'all', '##ele', '.']
['O', 'O', 'O', 'O', 'O', 'O', 'B-Disease', 'I-Disease', 'I-Disease', 'I-Disease', 'O', 'O', 'O', 'O', 'O', 'B-Disease', 'I-Disease', 'I-Disease', 'O', 'O', 'O']


Downloading pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.2 were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initi

Average train loss: 0.5129824501496775

Validation loss: 0.09469861537218094

              precision    recall  f1-score   support

   B-Disease       0.57      0.66      0.61      1718
   I-Disease       0.72      0.43      0.54      1560
           O       0.94      0.97      0.95     11654
         PAD       0.00      0.00      0.00         0

    accuracy                           0.88     14932
   macro avg       0.56      0.52      0.53     14932
weighted avg       0.88      0.88      0.87     14932



Epoch:  40%|████      | 2/5 [01:17<01:55, 38.44s/it]

Average train loss: 0.07679254064957301

Validation loss: 0.056468677307878225

              precision    recall  f1-score   support

   B-Disease       0.77      0.68      0.72      1718
   I-Disease       0.69      0.83      0.75      1560
           O       0.98      0.97      0.97     11654

    accuracy                           0.92     14932
   macro avg       0.81      0.83      0.82     14932
weighted avg       0.92      0.92      0.92     14932



Epoch:  60%|██████    | 3/5 [01:56<01:17, 38.86s/it]

Average train loss: 0.043312250050129716

Validation loss: 0.041627985292247364

              precision    recall  f1-score   support

   B-Disease       0.83      0.85      0.84      1718
   I-Disease       0.81      0.88      0.84      1560
           O       0.98      0.97      0.98     11654

    accuracy                           0.95     14932
   macro avg       0.87      0.90      0.89     14932
weighted avg       0.95      0.95      0.95     14932



Epoch:  80%|████████  | 4/5 [02:34<00:38, 38.68s/it]

Average train loss: 0.028706883666691958

Validation loss: 0.03743800920035158

              precision    recall  f1-score   support

   B-Disease       0.86      0.87      0.87      1718
   I-Disease       0.86      0.87      0.87      1560
           O       0.98      0.98      0.98     11654

    accuracy                           0.95     14932
   macro avg       0.90      0.91      0.90     14932
weighted avg       0.95      0.95      0.95     14932



Epoch: 100%|██████████| 5/5 [03:13<00:00, 38.73s/it]

Average train loss: 0.023591196647396794

Validation loss: 0.03817272585417543

              precision    recall  f1-score   support

   B-Disease       0.86      0.89      0.88      1718
   I-Disease       0.85      0.89      0.87      1560
           O       0.99      0.97      0.98     11654

    accuracy                           0.96     14932
   macro avg       0.90      0.92      0.91     14932
weighted avg       0.96      0.96      0.96     14932






In [None]:
import shutil

# Delete the per-epoch folders of every epoch except the last one. The paths
# are built from PROJECT_NAME and EPOCHS so the cleanup follows the
# configuration instead of a fixed folder name and epoch count.
# `ignore_errors=True` skips folders that do not exist, like `rm -rf` does.
for stale_epoch in range(EPOCHS - 1):
    shutil.rmtree(f'{PROJECT_NAME}/{stale_epoch}', ignore_errors=True)

## Load the model as TF and save properly


In [None]:
# `loss_values` gets one entry per completed epoch. With no entry there is
# nothing trained to export, so stop here instead of saving a model named
# "model_epoch_None_...".
if not loss_values:
  raise RuntimeError("No epochs finished successfully.")

last_successfull_epoch = len(loss_values) - 1
print(f"Last successfull epoch: {str(last_successfull_epoch)}")

# first save the model as pytorch model (we'll cast later)
MODEL_NAME_PYTORCH = 'model_epoch_'+str(last_successfull_epoch)+'_pytorch'
MODEL_NAME_TF = 'model_epoch_'+str(last_successfull_epoch)+'_tf'

print(MODEL_NAME_PYTORCH)
print(MODEL_NAME_TF)

tokenizer.save_pretrained(f'./{PROJECT_NAME}/{MODEL_NAME_PYTORCH}_tokenizer/')
# Plain PyTorch save. The TensorFlow options `saved_model` / `save_format`
# belong to the TF export done later, not to this call.
model.save_pretrained(f'./{PROJECT_NAME}/{MODEL_NAME_PYTORCH}')


Last successfull epoch: 4
model_epoch_4_pytorch
model_epoch_4_tf


In [None]:
import tensorflow as tf
from transformers import TFBertForTokenClassification

# Now load the saved PyTorch weights into the TensorFlow model class
# (`from_pt=True`) so the model can be exported as a TF SavedModel.

loaded_model = TFBertForTokenClassification.from_pretrained(f'./{PROJECT_NAME}/{MODEL_NAME_PYTORCH}', from_pt=True)

# Define the TF serving signature: one dict argument with three int32 inputs,
# each declared with two unconstrained (None) dimensions.
@tf.function(
  input_signature=[
      {
          "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"),
          "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"),
          "token_type_ids": tf.TensorSpec((None, None), tf.int32, name="token_type_ids"),
      }
  ]
)
def serving_fn(input):
    # Forward the dict of input tensors to the loaded TF model.
    return loaded_model(input)
# Export as a SavedModel, registering serving_fn as the "serving_default" signature.
loaded_model.save_pretrained(f'./{PROJECT_NAME}/{MODEL_NAME_TF}', saved_model=True, signatures={"serving_default": serving_fn})
# Label names ordered by their id, so that line i of labels.txt is the label with id i.
labels = sorted(tag2idx, key=tag2idx.get)

print (labels)

# Write the labels, one per line, into the assets folder of the SavedModel.
with open(f'./{PROJECT_NAME}/{MODEL_NAME_TF}/saved_model/1/assets/labels.txt', 'w') as f:
    f.write('\n'.join(labels))

vocab_pth = f"./{PROJECT_NAME}/{MODEL_NAME_PYTORCH}_tokenizer/vocab.txt"
saved_model_pth = f'./{PROJECT_NAME}/{MODEL_NAME_TF}/saved_model/1/assets/'

# Copy the tokenizer vocabulary next to labels.txt.
# NOTE(review): presumably the Spark NLP loader reads both files from this assets folder — confirm.
! cp $vocab_pth $saved_model_pth

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertForTokenClassification: ['bert.embeddings.position_ids']
- This IS expected if you are initializing TFBertForTokenClassification from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertForTokenClassification from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertForTokenClassification were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForTokenClassification for predictions without further training.


['B-Disease', 'I-Disease', 'O', 'PAD']


## Load the saved model in Spark NLP and save it properly


In [None]:
# Load the exported TF SavedModel into a Spark NLP annotator. The maximum
# sentence length comes from MAX_LEN, so it stays in step with the sequence
# length the inputs were padded/truncated to during training.
tokenClassifier = nlp.BertForTokenClassification.loadSavedModel(
      f'./{PROJECT_NAME}/{MODEL_NAME_TF}/saved_model/1',
      spark)\
    .setInputCols(["sentence",'token'])\
    .setOutputCol("ner")\
    .setCaseSensitive(True)\
    .setMaxSentenceLength(MAX_LEN)

# Persist the annotator in Spark NLP's own format, replacing any earlier export.
tokenClassifier.write().overwrite().save(f"./{PROJECT_NAME}/{MODEL_NAME_TF}_spark_nlp")

## Test the imported model in Spark NLP


In [None]:
# Raw text -> document annotation.
documentAssembler = nlp.DocumentAssembler()\
    .setInputCol("text")\
    .setOutputCol("document")

# Document -> sentences, using the default pretrained sentence detector.
sentenceDetector = nlp.SentenceDetectorDLModel.pretrained()\
    .setInputCols(["document"])\
    .setOutputCol("sentence")

# Sentences -> tokens. Named `spark_tokenizer` so it does not overwrite the
# Hugging Face `tokenizer` used by the training and export cells; otherwise
# re-running those cells after this one would fail.
spark_tokenizer = nlp.Tokenizer()\
    .setInputCols("sentence")\
    .setOutputCol("token")

# The fine-tuned token classifier saved in Spark NLP format.
tokenClassifier = nlp.BertForTokenClassification.load(f"./{PROJECT_NAME}/{MODEL_NAME_TF}_spark_nlp")\
    .setInputCols("token", "sentence")\
    .setOutputCol("label")\
    .setCaseSensitive(True)

# Token-level labels -> entity chunks.
ner_converter = medical.NerConverterInternal()\
    .setInputCols(["sentence","token","label"])\
    .setOutputCol("ner_chunk")


pipeline =  nlp.Pipeline(
    stages=[
        documentAssembler,
        sentenceDetector,
        spark_tokenizer,
        tokenClassifier,
        ner_converter
    ]
)

# Fit on a single empty row just to obtain a PipelineModel.
p_model = pipeline.fit(spark.createDataFrame([[""]]).toDF("text"))

sentence_detector_dl download started this may take some time.
Approximate size to download 354.6 KB
[OK!]


In [None]:
text = 'A 28-year-old female with a history of gestational diabetes mellitus diagnosed eight years prior to presentation and subsequent type two diabetes mellitus ( T2DM ), one prior episode of HTG-induced pancreatitis three years prior to presentation , associated with an acute hepatitis , and obesity with a body mass index ( BMI ) of 33.5 kg/m2 , presented with a one-week history of polyuria , polydipsia , poor appetite , and vomiting . Two weeks prior to presentation , she was treated with a five-day course of amoxicillin for a respiratory tract infection . She was on metformin , glipizide , and dapagliflozin for T2DM and atorvastatin and gemfibrozil for HTG . She had been on dapagliflozin for six months at the time of presentation . Physical examination on presentation was significant for dry oral mucosa ; significantly , her abdominal examination was benign with no tenderness , guarding , or rigidity . Pertinent laboratory findings on admission were : serum glucose 111 mg/dl , bicarbonate 18 mmol/l , anion gap 20 , creatinine 0.4 mg/dL , triglycerides 508 mg/dL , total cholesterol 122 mg/dL , glycated hemoglobin ( HbA1c ) 10% , and venous pH 7.27 . Serum lipase was normal at 43 U/L . Serum acetone levels could not be assessed as blood samples kept hemolyzing due to significant lipemia . The patient was initially admitted for starvation ketosis , as she reported poor oral intake for three days prior to admission . However , serum chemistry obtained six hours after presentation revealed her glucose was 186 mg/dL , the anion gap was still elevated at 21 , serum bicarbonate was 16 mmol/L , triglyceride level peaked at 2050 mg/dL , and lipase was 52 U/L . The β-hydroxybutyrate level was obtained and found to be elevated at 5.29 mmol/L - the original sample was centrifuged and the chylomicron layer removed prior to analysis due to interference from turbidity caused by lipemia again . 
The patient was treated with an insulin drip for euDKA and HTG with a reduction in the anion gap to 13 and triglycerides to 1400 mg/dL , within 24 hours . Her euDKA was thought to be precipitated by her respiratory tract infection in the setting of SGLT2 inhibitor use . The patient was seen by the endocrinology service and she was discharged on 40 units of insulin glargine at night , 12 units of insulin lispro with meals , and metformin 1000 mg two times a day . It was determined that all SGLT2 inhibitors should be discontinued indefinitely . She had close follow-up with endocrinology post discharge .'

# Put the sample text in a one-row DataFrame with a 'text' column and run it
# through the fitted pipeline.
result = p_model.transform(spark.createDataFrame([[text]]).toDF('text'))

In [None]:
# Show the classes (labels) reported by the loaded token classifier.
tokenClassifier.getClasses()

['PAD', 'O', 'I-Disease', 'B-Disease']

In [None]:
# Pair each token with its predicted label and print the first 50 pairs.
token_label_pairs = F.arrays_zip(result.token.result,
                                 result.label.result)
token_label_rows = result.select(F.explode(token_label_pairs).alias("cols"))
token_label_rows.select(F.expr("cols['0']").alias("token"),
                        F.expr("cols['1']").alias("label")).show(50, truncate=False)

+------------+---------+
|token       |label    |
+------------+---------+
|A           |O        |
|28-year-old |O        |
|female      |O        |
|with        |O        |
|a           |O        |
|history     |O        |
|of          |O        |
|gestational |B-Disease|
|diabetes    |I-Disease|
|mellitus    |I-Disease|
|diagnosed   |O        |
|eight       |O        |
|years       |O        |
|prior       |O        |
|to          |O        |
|presentation|O        |
|and         |O        |
|subsequent  |O        |
|type        |B-Disease|
|two         |B-Disease|
|diabetes    |I-Disease|
|mellitus    |I-Disease|
|(           |O        |
|T2DM        |B-Disease|
|),          |O        |
|one         |O        |
|prior       |O        |
|episode     |O        |
|of          |O        |
|HTG-induced |B-Disease|
|pancreatitis|I-Disease|
|three       |O        |
|years       |O        |
|prior       |O        |
|to          |O        |
|presentation|O        |
|,           |O        |


In [None]:
# Pair each extracted chunk with its metadata and print the chunk text
# together with the 'entity' value from that metadata.
chunk_pairs = F.arrays_zip(result.ner_chunk.result, result.ner_chunk.metadata)
chunk_rows = result.select(F.explode(chunk_pairs).alias("cols"))
chunk_rows.select(F.expr("cols['0']").alias("chunk"),
                  F.expr("cols['1']['entity']").alias("ner_label")).show(truncate=False)

+-----------------------------+---------+
|chunk                        |ner_label|
+-----------------------------+---------+
|gestational diabetes mellitus|Disease  |
|type                         |Disease  |
|two diabetes mellitus        |Disease  |
|T2DM                         |Disease  |
|HTG-induced pancreatitis     |Disease  |
|acute                        |Disease  |
|hepatitis                    |Disease  |
|obesity                      |Disease  |
|polyuria                     |Disease  |
|polydipsia                   |Disease  |
|poor                         |Disease  |
|appetite                     |Disease  |
|vomiting                     |Disease  |
|respiratory                  |Disease  |
|tract infection              |Disease  |
|T2DM                         |Disease  |
|HTG                          |Disease  |
|dry                          |Disease  |
|oral mucosa                  |Disease  |
|abdominal                    |Disease  |
+-----------------------------+---