![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/healthcare-nlp/1.5.BertForTokenClassification_NER_ONNX_SparkNLP_with_Transformers.ipynb)

# Medical NER: BertForTokenClassification → ONNX → Spark NLP JSL

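This notebook fine-tunes BioBERT (`dmis-lab/biobert-base-cased-v1.2`) with `BertForTokenClassification` for Named Entity Recognition on the NCBI Disease corpus, exports the model to ONNX, and packages it for **licensed** `johnsnowlabs` deployment with Spark NLP for Healthcare (`MedicalBertForTokenClassifier`).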

## Pipeline Steps:
1. Download NCBI CoNLL dataset
2. Train BertForTokenClassification model
3. Export to ONNX format
4. Test ONNX inference
5. Package for Spark NLP JSL

**Dataset:** NCBI Disease corpus in CoNLL format, from the John Snow Labs `spark-nlp-workshop` repository

## 1. Installation and Setup

In [None]:
# Install required packages
!pip install -q torch transformers datasets
!pip install -q onnx onnxruntime
!pip install -q seqeval scikit-learn

print("✅ All packages installed successfully!")

## 2. Import Libraries

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import (
    BertTokenizer,
    BertForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification
)
import numpy as np
import onnx
import onnxruntime
from typing import Dict, List, Tuple
import json
import os
from pathlib import Path
from seqeval.metrics import classification_report, f1_score
from sklearn.metrics import classification_report as sklearn_report
from sklearn.metrics import precision_recall_fscore_support
from collections import Counter

print("✅ Libraries imported successfully!")

✅ Libraries imported successfully!


## 3. Parse CoNLL Format Data

In [2]:
def parse_conll_file(file_path):
    """
    Parse CoNLL format file and filter sentences with multiple unique tags

    Format:
    token1 tag1
    token2 tag2
    (empty line = sentence boundary)

    Filters out sentences with only one unique tag (e.g., all "O" tags)
    """
    sentences = []
    tags = []

    current_tokens = []
    current_tags = []

    special_char_tokens = []

    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()

            if line == "" or line.startswith("-DOCSTART-"):
                if current_tokens:
                    unique_tags = set(current_tags)
                    if len(unique_tags) > 1:
                        sentences.append(current_tokens)
                        tags.append(current_tags)

                    current_tokens = []
                    current_tags = []
            else:
                parts = line.split()
                if len(parts) >= 2:
                    token = parts[0]
                    tag = parts[-1]
                    current_tokens.append(token)
                    current_tags.append(tag)

    if current_tokens:
        unique_tags = set(current_tags)
        if len(unique_tags) > 1:
            sentences.append(current_tokens)
            tags.append(current_tags)


    return sentences, tags


print("✅ Updated parse_conll_file defined")

✅ Updated parse_conll_file defined
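The filtering rule in `parse_conll_file` can be checked without downloading anything. The sketch below re-implements the same logic over an in-memory iterable of lines; the helper name `parse_conll_lines` and the sample sentences are illustrative, not taken from the NCBI files. It shows that an all-`O` sentence is dropped while a sentence containing a disease mention is kept.

```python
import io

def parse_conll_lines(lines):
    """Parse 'token ... tag' lines into (sentences, tags), one list per sentence.

    Same rule as parse_conll_file: keep a sentence only if its tags contain
    more than one distinct value, so all-'O' sentences are dropped.
    """
    sentences, tags = [], []
    cur_tokens, cur_tags = [], []

    def flush():
        # Close the current sentence, keeping it only if it has 2+ distinct tags
        if cur_tokens and len(set(cur_tags)) > 1:
            sentences.append(list(cur_tokens))
            tags.append(list(cur_tags))
        cur_tokens.clear()
        cur_tags.clear()

    for raw in lines:
        line = raw.strip()
        if line == "" or line.startswith("-DOCSTART-"):
            flush()  # blank line or document marker ends a sentence
            continue
        parts = line.split()
        if len(parts) >= 2:
            cur_tokens.append(parts[0])  # first column: token
            cur_tags.append(parts[-1])   # last column: NER tag
    flush()  # final sentence may lack a trailing blank line
    return sentences, tags

# Hypothetical sample in the two-column format described in the docstring
sample = io.StringIO(
    "-DOCSTART- O\n"
    "\n"
    "Familial B-Disease\n"
    "adenomatous I-Disease\n"
    "polyposis I-Disease\n"
    "\n"
    "The O\n"
    "gene O\n"
)
sents, tgs = parse_conll_lines(sample)
print(sents)  # [['Familial', 'adenomatous', 'polyposis']]
```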


## Load CoNLL Data

In [3]:
# Download NCBI CoNLL files
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/refs/heads/master/tutorials/Certification_Trainings/Healthcare/data/NER_NCBIconlltrain.txt -O train.conll
!wget -q https://raw.githubusercontent.com/JohnSnowLabs/spark-nlp-workshop/refs/heads/master/tutorials/Certification_Trainings/Healthcare/data/NER_NCBIconlltest.txt -O test.conll

print("✅ Files downloaded successfully!")
print("  - train.conll")
print("  - test.conll")

✅ Files downloaded successfully!
  - train.conll
  - test.conll


In [4]:
def load_conll_data():
    """Load CoNLL dataset"""

    print("\n📚 Loading CoNLL dataset...")

    # Parse train and test files
    train_sentences, train_tags = parse_conll_file("train.conll")
    test_sentences, test_tags = parse_conll_file("test.conll")

    # Get unique labels
    all_tags = set()
    for tags in train_tags + test_tags:
        all_tags.update(tags)

    label_list = sorted(list(all_tags))
    label2id = {label: i for i, label in enumerate(label_list)}
    id2label = {i: label for i, label in enumerate(label_list)}

    print(f"✓ Dataset loaded")
    print(f"  - Train sentences: {len(train_sentences)}")
    print(f"  - Test sentences: {len(test_sentences)}")
    print(f"  - Unique labels: {label_list}")
    print(f"  - Number of labels: {len(label_list)}")

    # Print label distribution
    train_tag_counts = Counter([tag for tags in train_tags for tag in tags])
    print(f"\n📊 Label distribution in training set:")
    for label, count in train_tag_counts.most_common():
        print(f"  {label}: {count}")

    test_tag_counts = Counter([tag for tags in test_tags for tag in tags])
    print(f"\n📊 Label distribution in test set:")
    for label, count in test_tag_counts.most_common():
        print(f"  {label}: {count}")

    return {
        'train': {'sentences': train_sentences, 'tags': train_tags},
        'test': {'sentences': test_sentences, 'tags': test_tags},
        'label_list': label_list,
        'label2id': label2id,
        'id2label': id2label
    }

# Load the data
data = load_conll_data()


📚 Loading CoNLL dataset...
✓ Dataset loaded
  - Train sentences: 1700
  - Test sentences: 392
  - Unique labels: ['B-Disease', 'I-Disease', 'O']
  - Number of labels: 3

📊 Label distribution in training set:
  O: 39427
  I-Disease: 3547
  B-Disease: 3093

📊 Label distribution in test set:
  O: 9316
  I-Disease: 789
  B-Disease: 708


## 4. Create PyTorch Dataset

In [5]:
class NERDataset(Dataset):
    """Custom NER Dataset for CoNLL format"""

    def __init__(self, sentences, tags, tokenizer, label2id, max_length=128):
        self.sentences = sentences
        self.tags = tags
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_length = max_length

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        tokens = self.sentences[idx]
        labels = self.tags[idx]

        # Tokenize with is_split_into_words=True
        encoding = self.tokenizer(
            tokens,
            is_split_into_words=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # Align labels with subword tokens
        word_ids = encoding.word_ids(batch_index=0)
        label_ids = []
        previous_word_idx = None

        for word_idx in word_ids:
            if word_idx is None:
                # Special tokens ([CLS], [SEP], [PAD]) get -100
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                # First subword of a word gets the original label
                label_ids.append(self.label2id[labels[word_idx]])
            else:
                # Continuation subwords get I- version of the label
                original_label = labels[word_idx]

                if original_label == 'O':
                    # O labels stay O
                    label_ids.append(self.label2id['O'])
                elif original_label.startswith('B-'):
                    # B- becomes I- for continuation subwords
                    entity_type = original_label[2:]  # Remove 'B-'
                    continuation_label = f'I-{entity_type}'
                    label_ids.append(self.label2id[continuation_label])
                elif original_label.startswith('I-'):
                    # I- stays I- for continuation subwords
                    label_ids.append(self.label2id[original_label])
                else:
                    # Fallback: use -100
                    label_ids.append(-100)

            previous_word_idx = word_idx

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label_ids)
        }

print("✅ NERDataset class defined")

✅ NERDataset class defined
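To see what the subword alignment in `NERDataset.__getitem__` produces, the sketch below applies the same rule to a hand-written `word_ids` list of the kind a fast tokenizer's `word_ids()` returns. It assumes a hypothetical WordPiece split of "Alzheimer" into `Alz` + `##heimer` (the real BioBERT split may differ), and `align_labels` is an illustrative helper, not part of the notebook. The label ids follow the sorted order printed earlier (`B-Disease`=0, `I-Disease`=1, `O`=2).

```python
def align_labels(word_ids, word_labels, label2id):
    """Project word-level BIO tags onto subword positions (same rule as NERDataset).

    word_ids: one entry per subword position, None for special/padding tokens,
              otherwise the index of the source word.
    """
    label_ids = []
    previous = None
    for word_idx in word_ids:
        if word_idx is None:
            label_ids.append(-100)  # [CLS]/[SEP]/[PAD]: ignored by the loss
        elif word_idx != previous:
            label_ids.append(label2id[word_labels[word_idx]])  # first subword
        else:
            tag = word_labels[word_idx]
            if tag.startswith("B-"):
                tag = "I-" + tag[2:]  # continuation of a B- word becomes I-
            label_ids.append(label2id.get(tag, -100))  # O and I- keep their tag
        previous = word_idx
    return label_ids

toy_label2id = {"B-Disease": 0, "I-Disease": 1, "O": 2}
toy_word_labels = ["B-Disease", "I-Disease", "O"]  # "Alzheimer disease risk"
# Hypothetical subwords: [CLS] Alz ##heimer disease risk [SEP] [PAD]
toy_word_ids = [None, 0, 0, 1, 2, None, None]
aligned = align_labels(toy_word_ids, toy_word_labels, toy_label2id)
print(aligned)  # [-100, 0, 1, 1, 2, -100, -100]
```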


## 5. Define Metrics

In [6]:
label_list = data['label_list']

# Initialize tracking variables
epoch_counter = {
    "current": 1,
    "best_epoch": 0,
    "best_f1": 0.0,
    "best_metrics": {}
}

def compute_metrics(pred):
    """Compute NER metrics using sklearn (token-level) for each epoch"""

    predictions, labels = pred
    predictions = np.argmax(predictions, axis=2)

    # Flatten predictions and labels, removing ignored index (-100)
    flat_predictions = []
    flat_labels = []

    for prediction, label in zip(predictions, labels):
        for p, l in zip(prediction, label):
            if l != -100:  # Skip special tokens
                flat_predictions.append(p)
                flat_labels.append(l)

    # Convert to label names
    pred_labels = [label_list[p] for p in flat_predictions]
    true_labels = [label_list[l] for l in flat_labels]

    # Support-weighted averages over all labelled subword positions (dominated by the frequent 'O' class)
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels,
        pred_labels,
        average='weighted',
        zero_division=0
    )

    # Track best epoch silently (only during training)
    if not epoch_counter.get('is_final', False) and f1 > epoch_counter['best_f1']:
        epoch_counter['best_epoch'] = epoch_counter['current']
        epoch_counter['best_f1'] = f1
        epoch_counter['best_metrics'] = {
            'precision': precision,
            'recall': recall,
            'f1': f1
        }

    # Print header based on mode
    if epoch_counter.get('is_final', False):
        header = "📊 FINAL TOKEN-LEVEL METRICS"
    else:
        header = f"📊 METRICS - Epoch {epoch_counter['current']}"

    # Print detailed sklearn classification report (token-level)
    print("\n" + "="*70)
    print(header)
    print("="*70)
    report = sklearn_report(
        true_labels,
        pred_labels,
        digits=4,
        zero_division=0
    )
    print(report)
    print("="*70 + "\n")

    # Increment counter only during training
    if not epoch_counter.get('is_final', False):
        epoch_counter['current'] += 1

    results = {
        "precision": precision,
        "recall": recall,
        "f1": f1
    }

    return results
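`compute_metrics` reports support-weighted averages over every labelled subword position, so the frequent `O` class dominates the headline number. A dependency-free sketch of the same computation follows: flatten while skipping `-100`, then weight each label's precision, recall and F1 by its true support, which is what scikit-learn's `average='weighted'` does. The helper names and the toy batch are illustrative.

```python
from collections import Counter

def flatten_ignoring(pred_ids, label_ids, label_list, ignore_index=-100):
    """Flatten batched id sequences to label names, skipping ignored positions."""
    y_true, y_pred = [], []
    for pred_row, label_row in zip(pred_ids, label_ids):
        for p, l in zip(pred_row, label_row):
            if l != ignore_index:
                y_true.append(label_list[l])
                y_pred.append(label_list[p])
    return y_true, y_pred

def weighted_prf(y_true, y_pred):
    """Per-label precision/recall/F1 averaged with weights = true support."""
    support = Counter(y_true)
    total = len(y_true)
    precision = recall = f1 = 0.0
    for label, n_true in support.items():
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == p == label)
        n_pred = sum(1 for p in y_pred if p == label)
        p_c = tp / n_pred if n_pred else 0.0  # zero_division=0
        r_c = tp / n_true
        f_c = 2 * p_c * r_c / (p_c + r_c) if (p_c + r_c) else 0.0
        weight = n_true / total
        precision += weight * p_c
        recall += weight * r_c
        f1 += weight * f_c
    return precision, recall, f1

toy_label_list = ["B-Disease", "I-Disease", "O"]
toy_labels = [[-100, 2, 2, 2, 0, 1, 2, -100]]  # -100 marks [CLS]/[SEP]
toy_preds = [[0, 2, 2, 0, 0, 2, 2, 1]]         # predictions at -100 slots are ignored
y_true, y_pred = flatten_ignoring(toy_preds, toy_labels, toy_label_list)
p, r, f = weighted_prf(y_true, y_pred)
print(round(p, 4), round(r, 4), round(f, 4))  # 0.5833 0.6667 0.6111
```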

## 6. Train BertForTokenClassification Model

In [7]:
# Reset tracking variables before training (compute_metrics reads this global dict)
epoch_counter = {
    "current": 1,
    "best_epoch": 0,
    "best_f1": 0.0,
    "best_metrics": {}
}

# Configuration
MODEL_NAME = 'dmis-lab/biobert-base-cased-v1.2'
OUTPUT_DIR = "./ncbi_ner_model"
NUM_EPOCHS = 3
BATCH_SIZE = 8
LEARNING_RATE = 3e-05

print("\n🔧 Initializing BertTokenizer and BertForTokenClassification...")
from transformers import BertTokenizerFast, BertForTokenClassification

# Configure the tokenizer explicitly to match the cased BioBERT vocabulary
tokenizer = BertTokenizerFast.from_pretrained(
    MODEL_NAME,
    do_lower_case=False,           # BioBERT v1.2 is cased: keep original casing
    strip_accents=None,            # None: follow do_lower_case (no stripping when cased)
    clean_text=True,               # Remove control characters
    tokenize_chinese_chars=True,   # Standard BERT behavior
    do_basic_tokenize=True,        # Split on whitespace/punctuation before WordPiece
    never_split=['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']  # Keep special tokens intact
)

model = BertForTokenClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label_list),
    id2label=data['id2label'],
    label2id=data['label2id']
)

print(f"✓ Model initialized with {len(label_list)} labels")

# Create datasets
print("\n📝 Creating PyTorch datasets...")
train_dataset = NERDataset(
    data['train']['sentences'],
    data['train']['tags'],
    tokenizer,
    data['label2id']
)

test_dataset = NERDataset(
    data['test']['sentences'],
    data['test']['tags'],
    tokenizer,
    data['label2id']
)

print(f"✓ Train dataset: {len(train_dataset)} samples")
print(f"✓ Test dataset: {len(test_dataset)} samples")

# Data collator
data_collator = DataCollatorForTokenClassification(tokenizer)

# Training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=LEARNING_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_EPOCHS,
    weight_decay=0.01,
    logging_dir=f"{OUTPUT_DIR}/logs",
    logging_steps=50,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
    push_to_hub=False,
    report_to="none"
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

# Train the model
print("🚀 We are starting training...")
print("=" * 70)
trainer.train()

# Display best model information
print("\n" + "="*70)
print("🎯 TRAINING COMPLETE - BEST MODEL SUMMARY")
print("="*70)
print(f"🏆 Best Epoch: {epoch_counter['best_epoch']}")
print(f"📊 Best F1 Score: {epoch_counter['best_f1']:.4f}")
print(f"📊 Best Precision: {epoch_counter['best_metrics']['precision']:.4f}")
print(f"📊 Best Recall: {epoch_counter['best_metrics']['recall']:.4f}")
print("="*70 + "\n")

# Save best epoch info with the model
best_epoch_info = {
    "best_epoch": epoch_counter['best_epoch'],
    "best_f1": float(epoch_counter['best_f1']),
    "best_precision": float(epoch_counter['best_metrics']['precision']),
    "best_recall": float(epoch_counter['best_metrics']['recall'])
}

with open(f"{OUTPUT_DIR}/best_epoch_info.json", "w") as f:
    json.dump(best_epoch_info, f, indent=2)

print(f"✅ Best epoch info saved to {OUTPUT_DIR}/best_epoch_info.json")


🔧 Initializing BertTokenizer and BertForTokenClassification...


Some weights of BertForTokenClassification were not initialized from the model checkpoint at dmis-lab/biobert-base-cased-v1.2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✓ Model initialized with 3 labels

📝 Creating PyTorch datasets...
✓ Train dataset: 1700 samples
✓ Test dataset: 392 samples
🚀 We are starting training...


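| Epoch | Training Loss | Validation Loss | Precision | Recall | F1 |
|------:|--------------:|----------------:|----------:|-------:|---:|
| 1 | 0.0862 | 0.084476 | 0.972302 | 0.971467 | 0.971728 |
| 2 | 0.0373 | 0.082123 | 0.974228 | 0.973677 | 0.973876 |
| 3 | 0.0103 | 0.102188 | 0.973502 | 0.973074 | 0.973229 |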



📊 METRICS - Epoch 1
              precision    recall  f1-score   support

   B-Disease     0.8898    0.9124    0.9010       708
   I-Disease     0.9140    0.9638    0.9383      2570
           O     0.9902    0.9767    0.9834     11652

    accuracy                         0.9715     14930
   macro avg     0.9313    0.9510    0.9409     14930
weighted avg     0.9723    0.9715    0.9717     14930



📊 METRICS - Epoch 2
              precision    recall  f1-score   support

   B-Disease     0.8814    0.9237    0.9021       708
   I-Disease     0.9305    0.9584    0.9442      2570
           O     0.9895    0.9801    0.9848     11652

    accuracy                         0.9737     14930
   macro avg     0.9338    0.9541    0.9437     14930
weighted avg     0.9742    0.9737    0.9739     14930



📊 METRICS - Epoch 3
              precision    recall  f1-score   support

   B-Disease     0.9000    0.9153    0.9076       708
   I-Disease     0.9258    0.9564    0.9409      2570
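With `load_best_model_at_end=True` and `metric_for_best_model="f1"`, the Trainer restores the checkpoint with the highest evaluation F1 rather than the last one. The sketch below applies that selection rule to the validation numbers in the training table above; it is a simplified illustration, not the Trainer's internal bookkeeping, and `pick_best` is a hypothetical helper. It selects epoch 2, which matches the final evaluation below; the rising validation loss at epoch 3 is consistent with the model starting to overfit.

```python
# Validation metrics per epoch, copied from the training table above
history = [
    {"epoch": 1, "eval_loss": 0.084476, "eval_f1": 0.971728},
    {"epoch": 2, "eval_loss": 0.082123, "eval_f1": 0.973876},
    {"epoch": 3, "eval_loss": 0.102188, "eval_f1": 0.973229},
]

def pick_best(history, metric="eval_f1", greater_is_better=True):
    """Return the epoch record with the best value of `metric`."""
    chooser = max if greater_is_better else min
    return chooser(history, key=lambda rec: rec[metric])

best = pick_best(history)
print(best["epoch"], best["eval_f1"])  # 2 0.973876
```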
          

## 7. Evaluate and Save Model

In [8]:
def evaluate_with_seqeval(model, dataset, tokenizer, label_list, batch_size=16):
    """Evaluate model using seqeval (chunk-level metrics)"""
    from seqeval.metrics import classification_report, precision_score, recall_score, f1_score
    from torch.utils.data import DataLoader

    model.eval()
    dataloader = DataLoader(dataset, batch_size=batch_size)
    all_predictions, all_labels = [], []

    # Get predictions
    for batch in dataloader:
        with torch.no_grad():
            input_ids = batch['input_ids'].to(model.device)
            attention_mask = batch['attention_mask'].to(model.device)
            labels = batch['labels']

            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            predictions = torch.argmax(outputs.logits, dim=-1).cpu().numpy()
            labels = labels.numpy()

            # Convert to label sequences
            for pred, label in zip(predictions, labels):
                pred_seq = [label_list[p] for p, l in zip(pred, label) if l != -100]
                label_seq = [label_list[l] for l in label if l != -100]
                if pred_seq:
                    all_predictions.append(pred_seq)
                    all_labels.append(label_seq)

    # Compute metrics once, then print and return them
    precision = precision_score(all_labels, all_predictions)
    recall = recall_score(all_labels, all_predictions)
    f1 = f1_score(all_labels, all_predictions)

    print("\n" + "="*70)
    print("🎯FINAL ENTITY-LEVEL EVALUATION")
    print("="*70)
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-Score:  {f1:.4f}\n")
    print(classification_report(all_labels, all_predictions, digits=4))
    print("="*70)

    return {"precision": precision, "recall": recall, "f1": f1}

print("✅ Seqeval evaluation function defined")

✅ Seqeval evaluation function defined
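seqeval scores whole entities: a prediction counts only if its type and span match a gold chunk exactly. A simplified pure-Python sketch of that idea follows. `get_chunks` uses a lenient, conlleval-style rule in which an `I-` tag that does not continue an open chunk starts a new one; this approximates seqeval's default mode but is not its code. In this notebook the sequences are over subword positions, so spans are measured in subword units.

```python
def get_chunks(tags):
    """Extract (type, start, end) entity spans from a BIO sequence; end is exclusive.

    Lenient rule: an I-X tag that does not continue an open X chunk starts one.
    """
    chunks = []
    start, ctype = None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # sentinel "O" closes the last chunk
        prefix, _, etype = tag.partition("-")
        continues = prefix == "I" and ctype == etype
        if ctype is not None and not continues:
            chunks.append((ctype, start, i))  # close the open chunk
            start, ctype = None, None
        if prefix in ("B", "I") and not continues:
            start, ctype = i, etype  # open a new chunk
    return chunks

def entity_prf(true_seqs, pred_seqs):
    """Micro precision/recall/F1 over exact-match (sentence, type, start, end) spans."""
    gold, pred = set(), set()
    for s, (t, p) in enumerate(zip(true_seqs, pred_seqs)):
        gold.update((s,) + c for c in get_chunks(t))
        pred.update((s,) + c for c in get_chunks(p))
    tp = len(gold & pred)
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

gold_tags = [["B-Disease", "I-Disease", "O", "B-Disease"]]
pred_tags = [["B-Disease", "O", "O", "B-Disease"]]  # first entity cut short
print(entity_prf(gold_tags, pred_tags))  # (0.5, 0.5, 0.5)
```

A truncated span counts as both a false positive and a false negative, which is why entity-level scores sit well below token-level ones.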


In [9]:
# Set flag for final evaluation (don't reset counter!)
epoch_counter['is_final'] = True

# Token-level evaluation
eval_results = trainer.evaluate()
print(f"F1:        {eval_results['eval_f1']:.4f}")
print(f"Precision: {eval_results['eval_precision']:.4f}")
print(f"Recall:    {eval_results['eval_recall']:.4f}")
print(f"Loss:      {eval_results['eval_loss']:.4f}")

# Entity-level evaluation
seqeval_results = evaluate_with_seqeval(model, test_dataset, tokenizer, label_list, batch_size=16)

# Show best epoch summary
print("\n" + "="*70)
print("🏆 BEST MODEL SUMMARY")
print("="*70)
print(f"Best Epoch: {epoch_counter['best_epoch']}")
print(f"Best F1:    {epoch_counter['best_f1']:.4f}")
print(f"Precision:  {epoch_counter['best_metrics']['precision']:.4f}")
print(f"Recall:     {epoch_counter['best_metrics']['recall']:.4f}")
print("="*70)

# Save everything
print("\n💾 Saving model and results...")
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# Save labels and mappings
with open(f"{OUTPUT_DIR}/tags.txt", "w") as f:
    f.write('\n'.join(label_list))

label_info = {
    "label2id": data['label2id'],
    "id2label": data['id2label'],
    "labels": label_list
}
with open(f"{OUTPUT_DIR}/label_mappings.json", "w") as f:
    json.dump(label_info, f, indent=2)

# Save evaluation results
eval_summary = {
    "best_epoch": {
        "epoch": epoch_counter['best_epoch'],
        "f1": float(epoch_counter['best_f1']),
        "precision": float(epoch_counter['best_metrics']['precision']),
        "recall": float(epoch_counter['best_metrics']['recall'])
    },
    "final_token_level": {
        "precision": float(eval_results['eval_precision']),
        "recall": float(eval_results['eval_recall']),
        "f1": float(eval_results['eval_f1']),
        "loss": float(eval_results['eval_loss'])
    },
    "final_entity_level": {
        "precision": float(seqeval_results['precision']),
        "recall": float(seqeval_results['recall']),
        "f1": float(seqeval_results['f1'])
    }
}
with open(f"{OUTPUT_DIR}/evaluation_results.json", "w") as f:
    json.dump(eval_summary, f, indent=2)

print(f"\n✅ All results saved to {OUTPUT_DIR}")
print("="*70)

# Reset flag
epoch_counter['is_final'] = False


📊 FINAL TOKEN-LEVEL METRICS
              precision    recall  f1-score   support

   B-Disease     0.8814    0.9237    0.9021       708
   I-Disease     0.9305    0.9584    0.9442      2570
           O     0.9895    0.9801    0.9848     11652

    accuracy                         0.9737     14930
   macro avg     0.9338    0.9541    0.9437     14930
weighted avg     0.9742    0.9737    0.9739     14930


F1:        0.9739
Precision: 0.9742
Recall:    0.9737
Loss:      0.0821

🎯FINAL ENTITY-LEVEL EVALUATION
Precision: 0.8206
Recall:    0.8983
F1-Score:  0.8577

              precision    recall  f1-score   support

     Disease     0.8206    0.8983    0.8577       708

   micro avg     0.8206    0.8983    0.8577       708
   macro avg     0.8206    0.8983    0.8577       708
weighted avg     0.8206    0.8983    0.8577       708


🏆 BEST MODEL SUMMARY
Best Epoch: 2
Best F1:    0.9739
Precision:  0.9742
Recall:     0.9737

💾 Saving model and results...

✅ All results saved to ./ncbi_ne
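The gap between the token-level F1 (0.9739) and the entity-level F1 (0.8577) follows from how each number is aggregated. Using the final token-level report above, the weighted average gives `O` about 78% of the weight (11652 of 14930 positions), while the entity-level F1 is the harmonic mean of entity precision and recall:

$$
F_1^{\text{token}} = \frac{\sum_c n_c\,F_{1,c}}{\sum_c n_c}
= \frac{708 \times 0.9021 + 2570 \times 0.9442 + 11652 \times 0.9848}{14930} \approx 0.9739
$$

$$
F_1^{\text{entity}} = \frac{2PR}{P + R} = \frac{2 \times 0.8206 \times 0.8983}{0.8206 + 0.8983} \approx 0.8577
$$

A model can therefore score well on the weighted token metric while still missing or truncating entity spans, which the entity-level number penalizes.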

## 8. Test Model Predictions

In [10]:
print("\n🧪 Testing model predictions...")

# Load model for inference
model = BertForTokenClassification.from_pretrained(OUTPUT_DIR)
tokenizer = BertTokenizer.from_pretrained(OUTPUT_DIR)
model.eval()

with open(f"{OUTPUT_DIR}/label_mappings.json", "r") as f:
    label_info = json.load(f)
id2label = {int(k): v for k, v in label_info['id2label'].items()}

# Test examples
test_texts = [
    "Breast cancer is a disease in which cells in the breast grow out of control.",
    "Patients with diabetes mellitus require insulin therapy.",
    "Alzheimer disease is characterized by progressive cognitive deterioration."
]

for text in test_texts:
    print(f"\n📝 Text: {text}")

    # Tokenize
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512
    )

    # Predict
    with torch.no_grad():
        outputs = model(**inputs)

    predictions = torch.argmax(outputs.logits, dim=2)

    # Get tokens and labels
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    pred_labels = [id2label[p.item()] for p in predictions[0]]

    # Extract entities with proper subword reconstruction
    entities = []
    current_entity = ""  # ← STRING, not list!
    current_label = None

    for token, label in zip(tokens, pred_labels):
        if token in ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']:
            continue

        if label.startswith('B-'):
            # Save previous entity
            if current_entity:
                entities.append((current_label, current_entity.strip()))

            # Start new entity
            current_label = label[2:]
            if token.startswith('##'):
                current_entity = token[2:]  # Remove ## without space
            else:
                current_entity = token

        elif label.startswith('I-') and current_label:
            # Continue entity
            if token.startswith('##'):
                current_entity += token[2:]  # WordPiece continuation: join without a space
            else:
                current_entity += " " + token  # New word: separate with a space

        else:
            # Non-entity token
            if current_entity:
                entities.append((current_label, current_entity.strip()))
            current_entity = ""
            current_label = None

    # Don't forget last entity
    if current_entity:
        entities.append((current_label, current_entity.strip()))

    # Print detected entities
    if entities:
        print("\n🎯 Detected Entities:")
        for label, entity in entities:
            print(f"  - {entity} ({label})")
    else:
        print("\n  No entities detected")

print("\n✅ Testing completed!")


🧪 Testing model predictions...

📝 Text: Breast cancer is a disease in which cells in the breast grow out of control.

🎯 Detected Entities:
  - Breast cancer (Disease)

📝 Text: Patients with diabetes mellitus require insulin therapy.

🎯 Detected Entities:
  - diabetes mellitus (Disease)

📝 Text: Alzheimer disease is characterized by progressive cognitive deterioration.

🎯 Detected Entities:
  - Alzheimer disease (Disease)
  - cognitive deterioration (Disease)

✅ Testing completed!
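The entity-reconstruction loop above can be factored into a reusable function, so the same post-processing applies to any token/label pair, for example labels decoded from the ONNX model. A sketch follows; `extract_entities` and the sample WordPiece split are illustrative assumptions, and the rules mirror the loop (B- starts an entity, I- extends the open one, anything else closes it).

```python
SPECIAL_TOKENS = {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}

def extract_entities(tokens, labels):
    """Rebuild (type, text) entities from WordPiece tokens and BIO labels."""
    entities = []
    text, etype = "", None
    for token, label in zip(tokens, labels):
        if token in SPECIAL_TOKENS:
            continue
        is_piece = token.startswith("##")
        piece = token[2:] if is_piece else token
        if label.startswith("B-"):
            if text:
                entities.append((etype, text))  # close the previous entity
            etype, text = label[2:], piece
        elif label.startswith("I-") and etype:
            # '##' pieces glue onto the previous piece; new words get a space
            text += piece if is_piece else " " + piece
        else:
            if text:
                entities.append((etype, text))
            text, etype = "", None
    if text:
        entities.append((etype, text))  # entity running to the end of the input
    return entities

# Hypothetical WordPiece split and predicted labels
toy_tokens = ["[CLS]", "Alz", "##heimer", "disease", "is", "common", "[SEP]"]
toy_labels = ["O", "B-Disease", "I-Disease", "I-Disease", "O", "O", "O"]
print(extract_entities(toy_tokens, toy_labels))  # [('Disease', 'Alzheimer disease')]
```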


## 9. Export to ONNX

In [11]:
# Export to ONNX for Spark NLP Healthcare
import onnx
from onnx import TensorProto

ONNX_PATH = "./ncbi_ner_model.onnx"
MAX_LENGTH = 512
OPSET_VERSION = 14

print("\n📦 Exporting to ONNX...")

# Load model and create dummy input
model = BertForTokenClassification.from_pretrained(OUTPUT_DIR)
tokenizer = BertTokenizer.from_pretrained(OUTPUT_DIR)
model.eval()

# Dummy input used only for tracing; dynamic_axes below keeps batch size and sequence length flexible
inputs = tokenizer(
    "Sample text for ONNX export",
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=MAX_LENGTH,
)

# Export to ONNX
print(f"🔄 Converting to ONNX (opset {OPSET_VERSION})...")
torch.onnx.export(
    model,
    (inputs["input_ids"], inputs["attention_mask"], inputs.get("token_type_ids")),
    ONNX_PATH,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "attention_mask": {0: "batch_size", 1: "sequence_length"},
        "token_type_ids": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size", 1: "sequence_length"}
    },
    opset_version=OPSET_VERSION,
    do_constant_folding=True
)

# Patch for Spark NLP Healthcare: INT64 inputs + rename for medical_input_ids
print("✓ Patching ONNX model for Spark NLP...")
onnx_model = onnx.load(ONNX_PATH)
onnx.checker.check_model(onnx_model)

# Rename input_ids -> medical_input_ids
name_map = {"input_ids": "medical_input_ids"}

# Force INT64 on all inputs
for ip in onnx_model.graph.input:
    if ip.type.tensor_type.elem_type != TensorProto.INT64:
        ip.type.tensor_type.elem_type = TensorProto.INT64
    if ip.name in name_map:
        new_name = name_map[ip.name]
        print(f"✓ Renaming: {ip.name} -> {new_name}")
        ip.name = new_name

# Update node references
for node in onnx_model.graph.node:
    node.input[:] = [name_map.get(n, n) for n in node.input]

# Update initializers
for init in onnx_model.graph.initializer:
    if init.name in name_map:
        init.name = name_map[init.name]

onnx.checker.check_model(onnx_model)  # Re-validate the graph after renaming
onnx.save(onnx_model, ONNX_PATH)
print(f"✅ ONNX model saved to {ONNX_PATH}")
print(f"   Inputs: {[i.name for i in onnx_model.graph.input]}")
print(f"   Outputs: {[o.name for o in onnx_model.graph.output]}")


📦 Exporting to ONNX...
🔄 Converting to ONNX (opset 14)...




✓ Patching ONNX model for Spark NLP...
✓ Renaming: input_ids -> medical_input_ids
✅ ONNX model saved to ./ncbi_ner_model.onnx
   Inputs: ['medical_input_ids', 'attention_mask', 'token_type_ids']
   Outputs: ['logits']
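Because the patched graph expects `medical_input_ids` instead of `input_ids`, any client has to translate tokenizer output keys to graph input names. The sketch below does that mapping generically; `build_feed`, the zero default for `token_type_ids`, and the token ids are hypothetical. In practice the values would also be converted to `int64` arrays, as the ONNX test cell in Section 10 does.

```python
def build_feed(encoding, expected_names, rename=None, zero_default=("token_type_ids",)):
    """Map tokenizer outputs onto the ONNX graph's input names.

    encoding: tokenizer outputs keyed by Hugging Face names (input_ids, ...)
    expected_names: input names reported by the ONNX Runtime session
    rename: tokenizer key -> graph input name, e.g. {"input_ids": "medical_input_ids"}
    """
    graph_to_key = {graph: key for key, graph in (rename or {}).items()}
    feed = {}
    for name in expected_names:
        key = graph_to_key.get(name, name)
        if key in encoding:
            feed[name] = encoding[key]
        elif key in zero_default:
            # Single-segment input: segment ids are all zero
            feed[name] = [[0] * len(row) for row in encoding["input_ids"]]
        else:
            raise KeyError(f"no tokenizer output for graph input {name!r}")
    return feed

# Hypothetical token ids; real values come from the BioBERT tokenizer
encoding = {"input_ids": [[101, 4567, 102]], "attention_mask": [[1, 1, 1]]}
feed = build_feed(
    encoding,
    ["medical_input_ids", "attention_mask", "token_type_ids"],
    rename={"input_ids": "medical_input_ids"},
)
print(sorted(feed))  # ['attention_mask', 'medical_input_ids', 'token_type_ids']
```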


## 10. Test ONNX Model

In [12]:
# Test ONNX model with dynamic input-name handling
import numpy as np
import onnxruntime as ort

print("\n🧪 Testing ONNX model...")

# Load inputs
test_text = "Breast cancer and diabetes are common diseases."
inputs = tokenizer(
    test_text,
    return_tensors="np",
    padding="max_length",
    truncation=True,
    max_length=512,
    return_token_type_ids=True  # Add this explicitly
)

# Start session and inspect expected names
session = ort.InferenceSession(ONNX_PATH)
expected = [i.name for i in session.get_inputs()]
print("Model expects inputs:", expected)

def as_i64(x):
    # Graph inputs were patched to INT64 during export
    return np.asarray(x, dtype=np.int64)

# Build feed dict based on expected naming
if set(expected) == set(["medical_input_ids", "attention_mask", "token_type_ids"]):
    ort_inputs = {
        "medical_input_ids": as_i64(inputs["input_ids"]),
        "attention_mask": as_i64(inputs["attention_mask"]),
        "token_type_ids": as_i64(inputs.get("token_type_ids", np.zeros_like(inputs["input_ids"])))
    }
else:
    raise ValueError(f"Unexpected input names in model: {expected}")

# Run inference
outputs = session.run(None, ort_inputs)
logits = outputs[0]
print("✅ ONNX forward pass OK. Logits shape:", logits.shape)



🧪 Testing ONNX model...
Model expects inputs: ['medical_input_ids', 'attention_mask', 'token_type_ids']
✅ ONNX forward pass OK. Logits shape: (1, 512, 3)
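The forward pass above only checks the logits shape. Turning logits into labels is an argmax over the last axis, keeping only positions where `attention_mask` is 1, since the input was padded to 512 positions. A NumPy-free sketch follows; `decode_logits` is a hypothetical helper and the logits values are made up.

```python
def decode_logits(logits, attention_mask, id2label):
    """Argmax over the label axis, keeping only positions with attention_mask == 1.

    logits: nested lists shaped [batch][seq_len][num_labels]
    attention_mask: nested lists shaped [batch][seq_len]
    """
    decoded = []
    for seq_logits, seq_mask in zip(logits, attention_mask):
        labels = []
        for scores, keep in zip(seq_logits, seq_mask):
            if keep:  # skip padding positions
                best = max(range(len(scores)), key=scores.__getitem__)
                labels.append(id2label[best])
        decoded.append(labels)
    return decoded

toy_id2label = {0: "B-Disease", 1: "I-Disease", 2: "O"}
# Made-up logits for 4 positions, the last one padding
toy_logits = [[[0.1, 0.2, 3.0], [2.5, 0.3, 0.1], [0.2, 1.9, 0.4], [0.0, 0.0, 0.0]]]
toy_mask = [[1, 1, 1, 0]]
print(decode_logits(toy_logits, toy_mask, toy_id2label))  # [['O', 'B-Disease', 'I-Disease']]
```

With NumPy the equivalent is `logits.argmax(axis=-1)` followed by masking. The `[CLS]` and `[SEP]` positions are still included and can be dropped by filtering special tokens, as the Section 8 loop does.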


## 11. Import the Model to Spark NLP Healthcare Library

### Upload License File

In [14]:
import json
import os

from google.colab import files

if 'spark_jsl.json' not in os.listdir():
    license_keys = files.upload()
    os.rename(list(license_keys.keys())[0], 'spark_jsl.json')

with open('spark_jsl.json') as f:
    license_keys = json.load(f)

# Expose the license key-value pairs as notebook variables and environment variables
locals().update(license_keys)
os.environ.update(license_keys)
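Later cells read `SECRET`, `PUBLIC_VERSION`, and `JSL_VERSION` from the license file, and `os.environ.update` raises `TypeError` for non-string values. A small pre-flight check can catch both problems before installation. The key list is inferred from the cells in this notebook rather than from license documentation, and the helper name and sample dicts are hypothetical.

```python
# Keys read later in this notebook (inferred from the install and start cells)
REQUIRED_LICENSE_KEYS = ("SECRET", "PUBLIC_VERSION", "JSL_VERSION")

def license_key_problems(license_keys, required=REQUIRED_LICENSE_KEYS):
    """List missing, empty, or non-string license entries (empty list if all OK)."""
    problems = []
    for key in required:
        value = license_keys.get(key)
        if not value:
            problems.append(f"{key}: missing or empty")
        elif not isinstance(value, str):
            # os.environ.update() only accepts string values
            problems.append(f"{key}: expected str, got {type(value).__name__}")
    return problems

# Hypothetical license dicts
print(license_key_problems({"SECRET": "xxx", "PUBLIC_VERSION": "6.1.3", "JSL_VERSION": "6.1.1"}))  # []
print(license_key_problems({"SECRET": "xxx", "PUBLIC_VERSION": 6}))
```

In the notebook this would run as `assert not license_key_problems(license_keys)` right after the JSON is loaded.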

### Install Necessary Libraries

In [None]:
# Installing pyspark and spark-nlp
! pip install --upgrade -q pyspark==3.4.1 spark-nlp==$PUBLIC_VERSION

# Installing Spark NLP Healthcare
! pip install --upgrade -q spark-nlp-jsl==$JSL_VERSION  --extra-index-url https://pypi.johnsnowlabs.com/$SECRET

# Installing Spark NLP Display Library for visualization
! pip install -q spark-nlp-display

### Import Libraries and Start Spark Session

In [15]:
import json
import os

import sparknlp
import sparknlp_jsl

from sparknlp.base import *
from sparknlp.annotator import *
from sparknlp_jsl.annotator import *

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.ml import Pipeline,PipelineModel

import warnings
warnings.filterwarnings('ignore')

print("Spark NLP Version :", sparknlp.version())
print("Spark NLP_JSL Version :", sparknlp_jsl.version())

spark = sparknlp_jsl.start(license_keys['SECRET'])

spark

Spark NLP Version : 6.1.3
Spark NLP_JSL Version : 6.1.1


## 12. Prepare Spark NLP JSL Model Package

In [16]:
import shutil
import onnx
import json

SPARK_NLP_PATH = "./spark_nlp_jsl_ncbi_ner"
MODEL_NAME = "ncbi_disease_ner_bert"

print("\n📋 Preparing Spark NLP JSL model...")

# 1. Create directory structure
os.makedirs(SPARK_NLP_PATH, exist_ok=True)
assets_path = os.path.join(SPARK_NLP_PATH, "assets")
os.makedirs(assets_path, exist_ok=True)

# 2. Copy ONNX model to ROOT as model.onnx
onnx_dest = os.path.join(SPARK_NLP_PATH, "model.onnx")
shutil.copy(ONNX_PATH, onnx_dest)
print(f"✓ Copied model.onnx to root")

# 3. Verify ONNX inputs
onnx_model = onnx.load(onnx_dest)
input_names = [i.name for i in onnx_model.graph.input]
print(f"  ONNX inputs: {input_names}")

# 4. Load model and extract labels
print("\n📝 Extracting labels from model...")
if 'model' not in dir():
    from transformers import BertForTokenClassification
    model = BertForTokenClassification.from_pretrained(OUTPUT_DIR)

labels_dict = model.config.label2id
labels_sorted = sorted(labels_dict, key=labels_dict.get)

# 5. Save labels to assets folder
with open(os.path.join(assets_path, 'labels.txt'), 'w') as f:
    f.write('\n'.join(labels_sorted))
print(f"✓ Saved labels.txt ({len(labels_sorted)} labels)")

# 6. Copy vocab.txt to assets
vocab_src = os.path.join(OUTPUT_DIR, "vocab.txt")
if os.path.exists(vocab_src):
    shutil.copy(vocab_src, os.path.join(assets_path, "vocab.txt"))
    print("✓ Copied vocab.txt to assets/")

# 7. Copy tokenizer files to ROOT directory (not assets!)
print("\n📁 Copying tokenizer config files to root...")
tokenizer_files = ["tokenizer_config.json", "special_tokens_map.json", "tokenizer.json"]
for file in tokenizer_files:
    src = os.path.join(OUTPUT_DIR, file)
    if os.path.exists(src):
        shutil.copy(src, os.path.join(SPARK_NLP_PATH, file))  # Copy to ROOT
        print(f"  ✓ Copied {file} to root")

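# 8. Create config.json in ROOT (sizes taken from the trained model, not hard-coded)
config = {
    "architectures": ["BertForTokenClassification"],
    "model_type": "bert",
    "max_position_embeddings": model.config.max_position_embeddings,
    "hidden_size": model.config.hidden_size,
    "num_labels": len(labels_sorted)
}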
with open(os.path.join(SPARK_NLP_PATH, "config.json"), "w") as f:
    json.dump(config, f, indent=2)
print("✓ Created config.json in root")

# 9. Verify final structure
print("\n📁 Final directory structure:")
for root, dirs, files in os.walk(SPARK_NLP_PATH):
    level = root.replace(SPARK_NLP_PATH, '').count(os.sep)
    indent = ' ' * 2 * level
    print(f'{indent}{os.path.basename(root)}/')
    subindent = ' ' * 2 * (level + 1)
    for file in sorted(files):
        file_path = os.path.join(root, file)
        file_size = os.path.getsize(file_path)
        print(f'{subindent}{file} ({file_size} bytes)')

print("\n✅ Spark NLP JSL model preparation complete!")


📋 Preparing Spark NLP JSL model...
✓ Copied model.onnx to root
  ONNX inputs: ['medical_input_ids', 'attention_mask', 'token_type_ids']

📝 Extracting labels from model...
✓ Saved labels.txt (3 labels)
✓ Copied vocab.txt to assets/

📁 Copying tokenizer config files to root...
  ✓ Copied tokenizer_config.json to root
  ✓ Copied special_tokens_map.json to root
  ✓ Copied tokenizer.json to root
✓ Created config.json in root

📁 Final directory structure:
spark_nlp_jsl_ncbi_ner/
  config.json (160 bytes)
  model.onnx (431144990 bytes)
  special_tokens_map.json (125 bytes)
  tokenizer.json (669188 bytes)
  tokenizer_config.json (1389 bytes)
  assets/
    labels.txt (21 bytes)
    vocab.txt (213450 bytes)

✅ Spark NLP JSL model preparation complete!
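Before loading the folder into Spark NLP, it can help to confirm that the package is internally consistent. The sketch below is a hypothetical helper (`check_package` and its `EXPECTED_FILES` list are illustrative names, not part of Spark NLP): it only checks for the files this notebook writes and that `num_labels` in `config.json` agrees with `assets/labels.txt`. It does not validate the ONNX graph itself.

```python
import json
import os

# Hypothetical list mirroring the layout printed above; this is an
# assumption based on this notebook, not an official Spark NLP requirement.
EXPECTED_FILES = [
    "model.onnx",
    "config.json",
    os.path.join("assets", "labels.txt"),
    os.path.join("assets", "vocab.txt"),
]


def check_package(package_dir: str) -> list:
    """Return a list of problems found in an exported package (empty if none)."""
    problems = []
    for rel in EXPECTED_FILES:
        if not os.path.isfile(os.path.join(package_dir, rel)):
            problems.append(f"missing {rel}")

    labels_path = os.path.join(package_dir, "assets", "labels.txt")
    config_path = os.path.join(package_dir, "config.json")
    if os.path.isfile(labels_path) and os.path.isfile(config_path):
        with open(labels_path) as f:
            labels = [line.strip() for line in f if line.strip()]
        with open(config_path) as f:
            num_labels = json.load(f).get("num_labels")
        # The label count must agree with the classifier head size
        if num_labels != len(labels):
            problems.append(
                f"config num_labels={num_labels} but labels.txt has {len(labels)} entries"
            )
        if len(set(labels)) != len(labels):
            problems.append("duplicate labels in labels.txt")
    return problems
```

For example, `check_package(SPARK_NLP_PATH)` should return an empty list for the folder listed above; any returned entry points at a step in the previous cell worth re-running.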


## 13. Load the Saved Model into Spark NLP

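In [17]:
from sparknlp_jsl.annotator import MedicalBertForTokenClassifier

print(f"\n📦 Loading model from {SPARK_NLP_PATH}...")

# loadSavedModel reads the ONNX package prepared in the previous step.
# setCaseSensitive should match the training tokenizer: True for a cased
# checkpoint, False for an uncased one.
tokenClassifier = MedicalBertForTokenClassifier\
    .loadSavedModel(SPARK_NLP_PATH, spark)\
    .setInputCols(["document", "token"])\
    .setOutputCol("ner")\
    .setCaseSensitive(False)\
    .setMaxSentenceLength(512)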

print("✓ Model loaded successfully into Spark NLP")

# Save the model in Spark NLP format
output_path = f"./{MODEL_NAME}_spark_nlp_onnx"
print(f"\n💾 Saving Spark NLP model to {output_path}...")

tokenClassifier.write().overwrite().save(output_path)

print(f"✅ Spark NLP model saved to {output_path}")
print("\n📋 Final output locations:")
print(f"  1. ONNX Export: {SPARK_NLP_PATH}/")
print(f"  2. Spark NLP Model: {output_path}/")


📦 Loading model from ./spark_nlp_jsl_ncbi_ner...
✓ Model loaded successfully into Spark NLP

💾 Saving Spark NLP model to ./ncbi_disease_ner_bert_spark_nlp_onnx...
✅ Spark NLP model saved to ./ncbi_disease_ner_bert_spark_nlp_onnx

📋 Final output locations:
  1. ONNX Export: ./spark_nlp_jsl_ncbi_ner/
  2. Spark NLP Model: ./ncbi_disease_ner_bert_spark_nlp_onnx/
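`setCaseSensitive` should agree with how the model was tokenized during training. Hugging Face BERT tokenizers store their lower-casing choice as `do_lower_case` in `tokenizer_config.json`, which was copied into the package root above. The sketch below (a hypothetical helper, `suggested_case_sensitive`) reads that flag and suggests a value; falling back to `True` when the file or key is missing is an assumption made here, not a library rule.

```python
import json
import os


def suggested_case_sensitive(package_dir: str, default: bool = True) -> bool:
    """Suggest a setCaseSensitive() value from the exported tokenizer config."""
    config_path = os.path.join(package_dir, "tokenizer_config.json")
    if not os.path.isfile(config_path):
        return default  # assumption: no config, keep case
    with open(config_path) as f:
        cfg = json.load(f)
    if "do_lower_case" not in cfg:
        return default  # assumption: flag not recorded, keep case
    # An uncased tokenizer lowercases input, so case need not be preserved
    return not cfg["do_lower_case"]
```

Comparing `suggested_case_sensitive(SPARK_NLP_PATH)` with the `setCaseSensitive(False)` used above is a quick way to catch a cased/uncased mismatch.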


## 14. Test the Spark NLP Pipeline

In [18]:
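from pyspark.ml import Pipeline
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import Tokenizer
from sparknlp_jsl.annotator import MedicalBertForTokenClassifier, NerConverterInternal

document_assembler = DocumentAssembler()\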
    .setInputCol("text")\
    .setOutputCol("document")

tokenizer = Tokenizer()\
    .setInputCols(["document"])\
    .setOutputCol("token")

# Load the saved Spark NLP model
ner_model = MedicalBertForTokenClassifier.load(output_path)\
    .setInputCols(["document", "token"])\
    .setOutputCol("ner")\
    .setCaseSensitive(False)\
    .setMaxSentenceLength(512)

ner_converter = NerConverterInternal() \
    .setInputCols(["document", "token", "ner"]) \
    .setOutputCol("ner_chunk")

pipeline = Pipeline(stages=[
    document_assembler,
    tokenizer,
    ner_model,
    ner_converter
])

In [19]:
ner_model.getClasses()

['B-Disease', 'I-Disease', 'O']
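`getClasses()` returns the labels in the order this notebook wrote them to `assets/labels.txt`, sorted by `label2id`. That order matters because, for the Hugging Face model, logit column *i* corresponds to label id *i*. The sketch below illustrates the mapping with NumPy on toy logits (`decode_token_logits` is a hypothetical name and the numbers are made up); it is a conceptual illustration, not Spark NLP's internal decoding code.

```python
import numpy as np

# Label order as returned by getClasses() above (sorted by label2id)
LABELS = ["B-Disease", "I-Disease", "O"]


def decode_token_logits(logits: np.ndarray, labels: list) -> list:
    """Map a (num_tokens, num_labels) logit matrix to one label per token."""
    if logits.ndim != 2 or logits.shape[1] != len(labels):
        raise ValueError(f"expected shape (n, {len(labels)}), got {logits.shape}")
    # The highest-scoring column index is looked up in the label list
    return [labels[i] for i in logits.argmax(axis=-1)]


# Toy logits for three tokens (illustrative values, not model output)
toy_logits = np.array([
    [4.0, 0.5, 0.1],  # max at index 0 -> B-Disease
    [0.2, 3.1, 0.4],  # max at index 1 -> I-Disease
    [0.1, 0.3, 2.9],  # max at index 2 -> O
])
print(decode_token_logits(toy_logits, LABELS))
# ['B-Disease', 'I-Disease', 'O']
```

If `labels.txt` were written in a different order than the label ids, the same logits would decode to the wrong tags, which is why step 4 sorts by `label2id`.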

In [20]:
import pyspark.sql.functions as F

print("\n🧪 Testing Spark NLP model...")

# Test data
test_texts = [
    "Breast cancer is a disease in which cells in the breast grow out of control.",
    "Patients with diabetes mellitus require insulin therapy.",
    "Alzheimer disease is characterized by progressive cognitive deterioration."
]

# Create DataFrame
test_df = spark.createDataFrame([[text] for text in test_texts], ["text"])

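# Fit and transform (named pipeline_model so it does not overwrite the
# Hugging Face `model` used in the export step)
pipeline_model = pipeline.fit(test_df)
result = pipeline_model.transform(test_df)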

# Show results
print("\n📊 Results:")

# One row per token: token text, character offsets and predicted BIO label
result.select(
    F.explode(
        F.arrays_zip(
            result.token.result,
            result.token.begin,
            result.token.end,
            result.ner.result
        )
    ).alias("cols")
).select(
    F.expr("cols['0']").alias("token"),
    F.expr("cols['1']").alias("char_begin"),
    F.expr("cols['2']").alias("char_end"),
    F.expr("cols['3']").alias("label")
).show(50, truncate=False)


🧪 Testing Spark NLP model...

📊 Results:
+-------------+----------+--------+---------+
|token        |char_begin|char_end|label    |
+-------------+----------+--------+---------+
|Breast       |0         |5       |B-Disease|
|cancer       |7         |12      |I-Disease|
|is           |14        |15      |O        |
|a            |17        |17      |O        |
|disease      |19        |25      |O        |
|in           |27        |28      |O        |
|which        |30        |34      |O        |
|cells        |36        |40      |O        |
|in           |42        |43      |O        |
|the          |45        |47      |O        |
|breast       |49        |54      |O        |
|grow         |56        |59      |O        |
|out          |61        |63      |O        |
|of           |65        |66      |O        |
|control      |68        |74      |O        |
|.            |75        |75      |O        |
|Patients     |0         |7       |O        |
|with         |9         |12      |O  

In [21]:
result.select(
    F.explode(
        F.arrays_zip(result.ner_chunk.result, result.ner_chunk.metadata)
    ).alias("cols")
).select(
    F.expr("cols['0']").alias("chunk"),
    F.expr("cols['1']['entity']").alias("ner_label")
).show(truncate=False)

+-----------------------+---------+
|chunk                  |ner_label|
+-----------------------+---------+
|Breast cancer          |Disease  |
|diabetes mellitus      |Disease  |
|Alzheimer disease      |Disease  |
|cognitive deterioration|Disease  |
+-----------------------+---------+
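The chunk table is derived from the token-level BIO tags: a `B-` tag opens an entity, following `I-` tags of the same type extend it, and `O` closes it. The pure-Python sketch below reproduces that grouping for the first test sentence. `bio_to_chunks` is a hypothetical helper, and treating a stray `I-` tag as the start of a new chunk is one possible choice among several; it is a simplified stand-in, not the `NerConverterInternal` implementation. Offsets are treated as inclusive, matching `char_begin`/`char_end` in the token table (`Breast` spans 0 to 5).

```python
from typing import List, Tuple

Token = Tuple[str, int, int, str]  # (text, char_begin, char_end, bio_label)


def bio_to_chunks(tokens: List[Token], text: str) -> List[Tuple[str, str, int, int]]:
    """Group BIO-tagged tokens into (chunk_text, entity, begin, end) spans."""
    chunks = []
    current = None  # [entity, begin, end] of the chunk being built
    for _, begin, end, label in tokens:
        if label == "O":
            if current:
                chunks.append(current)
            current = None
            continue
        prefix, entity = label.split("-", 1)
        if prefix == "I" and current and current[0] == entity:
            current[2] = end  # extend the open chunk
        else:
            # B- tag, or an I- tag that does not continue the open chunk
            if current:
                chunks.append(current)
            current = [entity, begin, end]
    if current:
        chunks.append(current)
    # Inclusive end offsets, hence the +1 when slicing
    return [(text[b:e + 1], ent, b, e) for ent, b, e in chunks]


sentence = "Breast cancer is a disease in which cells in the breast grow out of control."
tokens = [("Breast", 0, 5, "B-Disease"), ("cancer", 7, 12, "I-Disease"), ("is", 14, 15, "O")]
print(bio_to_chunks(tokens, sentence))
# [('Breast cancer', 'Disease', 0, 12)]
```

The `Disease` entity name comes from stripping the `B-`/`I-` prefix, which is why the `ner_label` column above shows `Disease` rather than the raw tags.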

