# Task 2: Named Entity Recognition (NER)

In [39]:
# 1. Imports
import pandas as pd
import matplotlib.pyplot as plt
import random
import numpy as np
import json, requests
import re
import os
wandb_api_key = "1a489be8947a5bbf94fcc41855e40beaab35312e" # replace with your own wandb_api_key
os.environ['WANDB_API_KEY'] = wandb_api_key


## Data Exploration

In [40]:
file_path_drive = "ner_data.csv"
ner_df = pd.read_csv(file_path_drive)
ner_df[0:9]

Unnamed: 0,sentence_id,word,tag
0,1,Patients,O
1,1,experienced,O
2,1,cough,B-SYMPTOM
3,1,after,O
4,1,administration,O
5,1,of,O
6,1,75mg,B-DOSAGE
7,1,of,O
8,1,DrugZ,O


In [41]:
# Count unique entities
def count_unique_entities(df):
    # Filter only entity rows (excluding 'O' tags)
    entity_df = df[df['tag'] != 'O']

    # Extract entity type from tag (B-DRUG -> DRUG)
    entity_df['entity_type'] = entity_df['tag'].str.split('-').str[1]

    # Group by entity type and count unique words
    unique_counts = entity_df.groupby('entity_type')['word'].nunique()

    return unique_counts

# Get counts
entity_counts = count_unique_entities(ner_df)
print("Unique entity counts:")
print(entity_counts)

# Alternative visualization
print("\nDetailed breakdown:")
for entity_type in ['DOSAGE', 'DRUG', 'SYMPTOM']:
    unique_words = ner_df[ner_df['tag'] == f'B-{entity_type}']['word'].unique()
    print(f"\n{entity_type} (Total unique: {len(unique_words)}):")
    print(unique_words)

Unique entity counts:
entity_type
DOSAGE     8
SYMPTOM    8
Name: word, dtype: int64

Detailed breakdown:

DOSAGE (Total unique: 8):
['75mg' '500mg' '200mg' '100mg' '10mg' '50mg' '250mg' '5mg']

DRUG (Total unique: 0):
[]

SYMPTOM (Total unique: 8):
['cough' 'fever' 'rash' 'headache' 'dizziness' 'vomiting' 'fatigue'
 'nausea']


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  entity_df['entity_type'] = entity_df['tag'].str.split('-').str[1]


Although there are 1000 sentences, there is very little variability in the data. I plan to combine the sentences, remove duplicates, revert back to the IOB form, then generate more data in a way that I believe this synthetic dataset was made. I'm going to keep 1-word entities since the original dataset also only had 1-word entities (no I-SYMPTOM or I-DOSAGE)

In [42]:
# 1. Combine words by sentence_id to form complete sentences
sentences = ner_df.groupby('sentence_id')['word'].apply(lambda x: ' '.join(x)).reset_index()
sentences.columns = ['sentence_id', 'sentence']

# 2. Remove duplicate sentences (keeping first occurrence)
unique_sentences = sentences.drop_duplicates(subset='sentence', keep='first')

# 3. Convert back to IOB format
iob_data = []
for _, row in unique_sentences.iterrows():
    words = row['sentence'].split()
    # Get the original tags for this sentence
    original_tags = ner_df[ner_df['sentence_id'] == row['sentence_id']]['tag'].tolist()
    
    # Ensure we have tags for all words (in case some were lost in processing)
    if len(original_tags) == len(words):
        tags = original_tags
    else:
        # If tags don't match, default to 'O' (shouldn't happen with this approach)
        tags = ['O'] * len(words)
    
    for word, tag in zip(words, tags):
        iob_data.append({
            'sentence_id': row['sentence_id'],
            'word': word,
            'tag': tag
        })

# Create new DataFrame
deduplicated_df = pd.DataFrame(iob_data)

# Save to CSV
deduplicated_df.to_csv('deduplicated_medical_data.csv', index=False)

print("Original number of sentences:", ner_df['sentence_id'].nunique())
print("Number of unique sentences:", deduplicated_df['sentence_id'].nunique())
print("\nSample of cleaned data:")
print(deduplicated_df.head(20))

Original number of sentences: 1000
Number of unique sentences: 825

Sample of cleaned data:
    sentence_id            word        tag
0             1        Patients          O
1             1     experienced          O
2             1           cough  B-SYMPTOM
3             1           after          O
4             1  administration          O
5             1              of          O
6             1            75mg   B-DOSAGE
7             1              of          O
8             1           DrugZ          O
9             2        Patients          O
10            2     experienced          O
11            2           fever  B-SYMPTOM
12            2           after          O
13            2          taking          O
14            2            75mg   B-DOSAGE
15            2              of          O
16            2           DrugE          O
17            3        Patients          O
18            3     experienced          O
19            3            rash  B-SYMPTOM


Notice that the `DrugZ` is tagged as O (outside, not important, not an entity to extract). Since we're training on this dataset, the model will learn that the drug names are NOT entities. 

However, "the reasoning behind this is that we can obfuscate PII or PHI during training, but we should be able to replace the tag with the correct drug name during inference."

I will still replace Drug[letter] with a random drug name to ensure that the model knows to recognize drugs. This keeps patient data protected since we are using random drugs anyways.

In [43]:
url = "https://gist.githubusercontent.com/ddbeck/56d54331a9c2526ff754/raw/fb8beb235bc97227a983a6d4e9a0067ecf9d29b5/drugs.json"
response = requests.get(url)
drug_names = response.json()['drugs']

# keep only drug names that are 1 words (avoid the complication of I-Drug data)
clean_drugs = [d for d in drug_names if (' ' not in d) and ('-' not in d)]

# Replace Drug[letter] with random real drugs
drug_mask = deduplicated_df['word'].str.match(r'Drug[A-Z]$')  # Finds DrugA, DrugB, etc.

# select rows where drug_mask is true in the column 'word', replaces drug placeholder with drug name in those cells
deduplicated_df.loc[drug_mask, 'word'] = np.random.choice(clean_drugs, size=drug_mask.sum())

# Update tags to B-DRUG
deduplicated_df.loc[drug_mask, 'tag'] = 'B-DRUG'

# Save corrected data
deduplicated_df.to_csv("ner_data_drugged.csv", index=False)
print(deduplicated_df[deduplicated_df['tag'] == 'B-DRUG'].sample(5))


      sentence_id       word     tag
6076          849    Torisel  B-DRUG
4241          547  Antivenin  B-DRUG
5478          753  Entecavir  B-DRUG
6218          876    Cubicin  B-DRUG
2857          359     Extina  B-DRUG


Make data

In [45]:
# Fetch real drug names from the provided URL
url = "https://gist.githubusercontent.com/ddbeck/56d54331a9c2526ff754/raw/fb8beb235bc97227a983a6d4e9a0067ecf9d29b5/drugs.json"
response = requests.get(url)
drug_names = response.json()['drugs']
# keep only drug names that are 1 words (avoid the complication of I-Drug data)
clean_drugs = [d for d in drug_names if (' ' not in d) and ('-' not in d)]
drug_names = clean_drugs

# Define other components with medical variations
subjects = ["Patients", "Subjects", "Individuals", "Participants", "Volunteers", "Cases", "Users"]
verbs = ["experienced", "reported", "developed", "presented with", "exhibited", "showed", "complained of"]
symptoms = [
    "diarrhea", "constipation", "sweating", "itching", "swelling", 
    "weakness", "tremors", "cramps", "bruising", "bleeding",
    "palpitations", "confusion", "insomnia", "anxiety", "depression",
    "bloating", "gas", "dehydration", "syncope", "vertigo"
]
prepositions = ["after", "following", "within hours of", "within days of", "post", "subsequent to"]
actions = ["administration", "ingestion", "consumption", "receipt", "intake", "dosage", "treatment with"]
dosages = [
    "20mg", "40mg", "80mg", "120mg", "150mg", "300mg", "400mg", 
    "600mg", "750mg", "1g", "2.5mg", "12.5mg", "15mg", "30mg", 
    "45mg", "60mg", "0.5mg", "0.25mg", "5ml", "10ml"
]
timeframes = ["1 hour", "2 days", "24 hours", "3 weeks", "several minutes", "a few days", "48 hours", "5 days"]

# More diverse sentence structures
structures = [
    lambda: f"{random.choice(subjects)} {random.choice(verbs)} {random.choice(symptoms)} {random.choice(prepositions)} {random.choice(actions)} of {random.choice(dosages)} of {random.choice(clean_drugs)}",
    lambda: f"{random.choice(subjects)} {random.choice(verbs)} {random.choice(symptoms)} {random.choice(prepositions)} taking {random.choice(dosages)} of {random.choice(clean_drugs)}",
    lambda: f"Following {random.choice(actions)} of {random.choice(dosages)} {random.choice(clean_drugs)}, {random.choice(subjects)} {random.choice(verbs)} {random.choice(symptoms)}",
    lambda: f"{random.choice(symptoms).capitalize()} was {random.choice(['observed', 'reported'])} in {random.choice(subjects)} {random.choice(prepositions)} {random.choice(actions)} of {random.choice(clean_drugs)} ({random.choice(dosages)})",
    lambda: f"{random.choice(subjects)} on {random.choice(dosages)} {random.choice(clean_drugs)} {random.choice(verbs)} {random.choice(symptoms)} within {random.choice(timeframes)}",
    lambda: f"{random.choice(clean_drugs)} at {random.choice(dosages)} {random.choice(['caused', 'induced', 'led to'])} {random.choice(symptoms)} in {random.choice(subjects)}",
    lambda: f"{random.choice(subjects)} {random.choice(['started', 'began'])} {random.choice(verbs)} {random.choice(symptoms)} {random.choice(prepositions)} {random.choice(clean_drugs)} therapy ({random.choice(dosages)})",
    lambda: f"{random.choice(['The', 'A'])} {random.choice(symptoms)} {random.choice(['occurred', 'appeared'])} {random.choice(prepositions)} {random.choice(clean_drugs)} use ({random.choice(dosages)})"
]

def clean_word(word):
    """Remove surrounding punctuation from a word"""
    return re.sub(r'^[^a-zA-Z0-9]*|[^a-zA-Z0-9]*$', '', word)

def generate_sentence():
    sentence = random.choice(structures)()
    # Clean up punctuation and spacing
    sentence = re.sub(r'\s([?.!,"](?:\s|$))', r'\1', sentence)
    return sentence

def tag_sentence(sentence):
    words = sentence.split()
    tags = ['O'] * len(words)
    
    # Find and tag symptoms (only B- tag as we're not using I- tags)
    for i, word in enumerate(words):
        lower_word = clean_word(word).lower()
        if lower_word in [s.lower() for s in symptoms]:
            tags[i] = 'B-SYMPTOM'
    
    # Find and tag dosages (pattern: number + mg, even with punctuation)
    for i, word in enumerate(words):
        clean_w = clean_word(word)
        # Matches: 45mg, (45mg), 1g, (1g), etc.
        if re.match(r'^\(?\d+\.?\d*[a-zA-Z]+\)?$', clean_w.lower()):
            tags[i] = 'B-DOSAGE'
    
    # Find and tag drugs (case-sensitive match with cleaned word)
    for i, word in enumerate(words):
        clean_w = clean_word(word)
        if clean_w in clean_drugs:
            tags[i] = 'B-DRUG'
    
    return words, tags

def generate_data(num_sentences, start_id=1):
    data = []
    sentence_id = start_id
    
    for _ in range(num_sentences):
        sentence = generate_sentence()
        words, tags = tag_sentence(sentence)
        
        for word, tag in zip(words, tags):
            data.append({
                'sentence_id': sentence_id,
                'word': word,
                'tag': tag
            })
        
        sentence_id += 1
    
    return pd.DataFrame(data)

# Generate 200 sentences for better variety
df = generate_data(500, start_id=1001)

# Verify we only have the specified tags
assert set(df['tag'].unique()) == {'O', 'B-SYMPTOM', 'B-DOSAGE', 'B-DRUG'}, "Invalid tags detected"

# Save to CSV
df.to_csv('synthetic_medical_ner_data_iob.csv', index=False)

print("Sample generated data:")
print(df.head(20))
print("\nUnique tags in dataset:", df['tag'].unique())
print("\nSample drug names used:", random.sample(drug_names, 10))

# Append to your existing data
combined_df = pd.concat([deduplicated_df, df], ignore_index=True)
combined_df.to_csv('deduplicated_suppl_ner_data.csv', index=False)

Sample generated data:
    sentence_id          word        tag
0          1001     Following          O
1          1001   consumption          O
2          1001            of          O
3          1001         750mg   B-DOSAGE
4          1001    Streptase,     B-DRUG
5          1001      Subjects          O
6          1001   experienced          O
7          1001       vertigo  B-SYMPTOM
8          1002      Bleeding  B-SYMPTOM
9          1002           was          O
10         1002      reported          O
11         1002            in          O
12         1002  Participants          O
13         1002          post          O
14         1002       receipt          O
15         1002            of          O
16         1002         BiDil     B-DRUG
17         1002        (80mg)   B-DOSAGE
18         1003  Participants          O
19         1003            on          O

Unique tags in dataset: ['O' 'B-DOSAGE' 'B-DRUG' 'B-SYMPTOM']

Sample drug names used: ['Acthrel', 'Riomet', 'Atrop

Now that the data is different enough, we can begin work on the NER model

# DEEPSEEK GENERATED CODE

In [46]:
# 1. Imports
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, precision_recall_fscore_support
from transformers import (
    DistilBertTokenizerFast,
    DistilBertForTokenClassification,
    Trainer,
    TrainingArguments,
    DataCollatorForTokenClassification
)

import torch
from datasets import Dataset
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')


def load_conll_data(file_path):
    """Load CoNLL format data and convert to sentences with labels"""
    # Read the data
    ner_df = pd.read_csv(file_path)

    # Group by sentence_id to reconstruct sentences
    sentences = []
    labels = []

    for sentence_id in ner_df['sentence_id'].unique():
        sentence_data = ner_df[ner_df['sentence_id'] == sentence_id]
        sentence_words = sentence_data['word'].tolist()
        sentence_tags = sentence_data['tag'].tolist()

        sentences.append(sentence_words)
        labels.append(sentence_tags)

    return sentences, labels


sentences, labels = load_conll_data("deduplicated_suppl_ner_data.csv")

# 3. Analyze Label Distribution
def analyze_labels(labels):
    """Analyze the distribution of entity labels"""
    all_labels = [label for sentence_labels in labels for label in sentence_labels]
    label_counts = pd.Series(all_labels).value_counts()

    print("\n📈 Label Distribution:")
    for label, count in label_counts.items():
        print(f"  {label}: {count}")

    return label_counts

label_counts = analyze_labels(labels)


📈 Label Distribution:
  O: 7545
  B-SYMPTOM: 1325
  B-DOSAGE: 1325
  B-DRUG: 1325


In [47]:
# 4. Create Label Mappings
def create_label_mappings(labels):
    """Create mappings between labels and ids"""
    unique_labels = set(label for sentence_labels in labels for label in sentence_labels)
    label_to_id = {label: i for i, label in enumerate(sorted(unique_labels))}
    id_to_label = {i: label for label, i in label_to_id.items()}

    print(f"\n🏷️  Found {len(unique_labels)} unique labels:")
    for label, id in sorted(label_to_id.items()):
        print(f"  {label}: {id}")

    return label_to_id, id_to_label

label_to_id, id_to_label = create_label_mappings(labels)


🏷️  Found 4 unique labels:
  B-DOSAGE: 0
  B-DRUG: 1
  B-SYMPTOM: 2
  O: 3


In [48]:
# 5. Tokenization and Alignment
def tokenize_and_align_labels(sentences, labels, tokenizer, label_to_id):
    """Tokenize sentences and align labels with subword tokens"""
    tokenized_inputs = tokenizer(
        sentences,
        truncation=True,
        padding=True,
        is_split_into_words=True,
        return_tensors="pt"
    )

    aligned_labels = []

    for i, sentence_labels in enumerate(labels):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = []

        previous_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                # Special tokens get -100 label (ignored in loss)
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                # First token of a word gets the actual label
                label_ids.append(label_to_id[sentence_labels[word_idx]])
            else:
                # Subsequent tokens of the same word get -100 (ignored)
                label_ids.append(-100)
            previous_word_idx = word_idx

        aligned_labels.append(label_ids)

    return tokenized_inputs, aligned_labels

# Initialize tokenizer
print("\n🔤 Initializing tokenizer...")
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

# 6. Train-Test Split
print("✂️  Splitting data...")
train_sentences, val_sentences, train_labels, val_labels = train_test_split(
    sentences, labels, test_size=0.2, random_state=42
)

print(f"🚂 Training sentences: {len(train_sentences)}")
print(f"🔍 Validation sentences: {len(val_sentences)}")


# 7. Tokenize and align labels
print("🔤 Tokenizing and aligning labels...")
train_tokenized, train_aligned_labels = tokenize_and_align_labels(
    train_sentences, train_labels, tokenizer, label_to_id
)
val_tokenized, val_aligned_labels = tokenize_and_align_labels(
    val_sentences, val_labels, tokenizer, label_to_id
)


# 8. Create Dataset Class
class NERDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

    def __len__(self):
        return len(self.labels)

# Create datasets
train_dataset = NERDataset(train_tokenized, train_aligned_labels)
val_dataset = NERDataset(val_tokenized, val_aligned_labels)

# 9. Initialize Model
print("🤖 Initializing DistilBERT model...")
model = DistilBertForTokenClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=len(label_to_id),
    id2label=id_to_label,
    label2id=label_to_id
)

# 10. Training Arguments
training_args = TrainingArguments(
    output_dir="/tmp/ner_results",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="/tmp/ner_logs",
    logging_steps=10,
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

# 11. Data Collator
data_collator = DataCollatorForTokenClassification(tokenizer)

# 12. Custom Metrics Function
def compute_metrics(eval_pred):
    """Compute entity-level precision, recall, and F1 score"""
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=2)

    # Remove ignored index (special tokens)
    true_predictions = [
        [id_to_label[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [id_to_label[l] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    # Flatten for sklearn metrics
    flat_true_labels = [label for sentence in true_labels for label in sentence]
    flat_predictions = [pred for sentence in true_predictions for pred in sentence]

    # Calculate metrics
    precision, recall, f1, _ = precision_recall_fscore_support(
        flat_true_labels, flat_predictions, average='weighted', zero_division=0
    )

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

# 13. Initialize Trainer
print("👨‍🏫 Initializing trainer...")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)


🔤 Initializing tokenizer...
✂️  Splitting data...
🚂 Training sentences: 1060
🔍 Validation sentences: 265
🔤 Tokenizing and aligning labels...
🤖 Initializing DistilBERT model...


Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


👨‍🏫 Initializing trainer...


In [None]:
# 14. Train the Model
print("🚀 Starting training...")
trainer.train()

# 15. Final Evaluation
print("\n📊 Final Evaluation:")
final_metrics = trainer.evaluate()
for key, value in final_metrics.items():
    print(f"  {key}: {value:.4f}")

🚀 Starting training...


[34m[1mwandb[0m: Currently logged in as: [33mmichellew[0m ([33mmichellew-mcmaster-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss,Precision,Recall,F1
1,0.012,0.00491,0.999569,0.999568,0.999568
2,0.0027,0.002577,0.999569,0.999568,0.999568
3,0.0021,0.002258,0.999569,0.999568,0.999568



📊 Final Evaluation:


  eval_loss: 0.0023
  eval_precision: 0.9996
  eval_recall: 0.9996
  eval_f1: 0.9996
  eval_runtime: 1.2744
  eval_samples_per_second: 207.9340
  eval_steps_per_second: 13.3390
  epoch: 3.0000


In [50]:
# 16. Entity-Level Evaluation Function
def evaluate_entities(sentences, true_labels, predictions, id_to_label):
    """Extract and evaluate entities at the entity level"""
    def extract_entities(tokens, labels):
        entities = []
        current_entity = []
        current_label = None

        for token, label in zip(tokens, labels):
            if label.startswith('B-'):  # Beginning of entity
                if current_entity:  # Save previous entity
                    entities.append((' '.join(current_entity), current_label))
                current_entity = [token]
                current_label = label[2:]  # Remove B- prefix
            elif label.startswith('I-') and current_label == label[2:]:  # Inside entity
                current_entity.append(token)
            else:  # Outside entity or different entity
                if current_entity:
                    entities.append((' '.join(current_entity), current_label))
                current_entity = []
                current_label = None

        if current_entity:  # Don't forget the last entity
            entities.append((' '.join(current_entity), current_label))

        return entities

    true_entities = []
    pred_entities = []

    for sent_tokens, sent_true, sent_pred in zip(sentences, true_labels, predictions):
        # Convert predictions back to labels
        sent_pred_labels = [id_to_label[p] for p in sent_pred if p != -100]
        sent_true_labels = [l for l in sent_true if l != 'O']  # Filter out O labels for entity extraction

        true_ents = extract_entities(sent_tokens, sent_true)
        pred_ents = extract_entities(sent_tokens, sent_pred_labels)

        true_entities.extend(true_ents)
        pred_entities.extend(pred_ents)

    return true_entities, pred_entities


In [51]:
# 17. Prediction Function
# Outputs just extracted entities
def predict_entities(text, model, tokenizer, label_to_id, id_to_label):
    """Predict entities in a given text"""
    # Tokenize the input
    tokens = text.split()
    inputs = tokenizer(tokens, is_split_into_words=True, return_tensors="pt", truncation=True, padding=True)

    # Get predictions
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=2)

    # Align predictions with original tokens
    word_ids = inputs.word_ids()
    aligned_predictions = []

    previous_word_idx = None
    for word_idx, pred_id in zip(word_ids, predictions[0]):
        if word_idx is not None and word_idx != previous_word_idx:
            aligned_predictions.append(id_to_label[pred_id.item()])
        previous_word_idx = word_idx

    # Extract entities
    entities = []
    current_entity = []
    current_label = None

    for token, label in zip(tokens, aligned_predictions):
        if label.startswith('B-'):
            if current_entity:
                entities.append({
                    'text': ' '.join(current_entity),
                    'label': current_label,
                    'start': tokens.index(current_entity[0]),
                    'end': tokens.index(current_entity[-1]) + 1
                })
            current_entity = [token]
            current_label = label[2:]
        elif label.startswith('I-') and current_label == label[2:]:
            current_entity.append(token)
        else:
            if current_entity:
                entities.append({
                    'text': ' '.join(current_entity),
                    'label': current_label,
                    'start': tokens.index(current_entity[0]),
                    'end': tokens.index(current_entity[-1]) + 1
                })
            current_entity = []
            current_label = None

    if current_entity:
        entities.append({
            'text': ' '.join(current_entity),
            'label': current_label,
            'start': tokens.index(current_entity[0]),
            'end': tokens.index(current_entity[-1]) + 1
        })

    return entities

In [52]:
# 18. Test the Model with Example
print("\n🧪 Testing with example sentence:")
example_text = "Patients were given 50mg of Aspirin and developed rash"
predicted_entities = predict_entities(example_text, model, tokenizer, label_to_id, id_to_label)

print(f"Input: {example_text}")
print("Predicted Entities:")
for entity in predicted_entities:
    print(f"  - {entity['label']}: {entity['text']}")


🧪 Testing with example sentence:
Input: Patients were given 50mg of Aspirin and developed rash
Predicted Entities:
  - DOSAGE: 50mg
  - DRUG: Aspirin
  - SYMPTOM: rash


In [53]:

# 19. Detailed Entity-Level Evaluation
print("\n📈 Detailed Entity-Level Evaluation:")

# Get predictions for validation set
val_predictions = trainer.predict(val_dataset)
val_pred_labels = np.argmax(val_predictions.predictions, axis=2)

# Filter out -100 labels and convert to entity format
filtered_predictions = []
filtered_true_labels = []

for i, (pred_seq, true_seq) in enumerate(zip(val_pred_labels, val_aligned_labels)):
    pred_filtered = [id_to_label[p] for p, t in zip(pred_seq, true_seq) if t != -100]
    true_filtered = [id_to_label[t] for t in true_seq if t != -100]

    filtered_predictions.append(pred_filtered)
    filtered_true_labels.append(true_filtered)

# Calculate entity-level metrics by entity type
entity_types = ['DRUG', 'SYMPTOM', 'DOSAGE']
print("\nEntity-Level Metrics by Type:")

for entity_type in entity_types:
    true_entities = []
    pred_entities = []

    for sent_idx, (sent_tokens, true_labels, pred_labels) in enumerate(
        zip(val_sentences, val_labels, filtered_predictions)
    ):
        # Extract entities of this type
        def extract_entities_of_type(tokens, labels, target_type):
            entities = []
            current_entity = []

            for token, label in zip(tokens, labels):
                if label == f'B-{target_type}':
                    if current_entity:
                        entities.append(' '.join(current_entity))
                    current_entity = [token]
                elif label == f'I-{target_type}' and current_entity:
                    current_entity.append(token)
                else:
                    if current_entity:
                        entities.append(' '.join(current_entity))
                        current_entity = []

            if current_entity:
                entities.append(' '.join(current_entity))

            return entities

        true_ents = extract_entities_of_type(sent_tokens, true_labels, entity_type)
        pred_ents = extract_entities_of_type(sent_tokens, pred_labels, entity_type)

        true_entities.extend([(ent, sent_idx) for ent in true_ents])
        pred_entities.extend([(ent, sent_idx) for ent in pred_ents])

    # Calculate precision, recall, F1
    true_set = set(true_entities)
    pred_set = set(pred_entities)

    if len(pred_set) > 0:
        precision = len(true_set & pred_set) / len(pred_set)
    else:
        precision = 0.0

    if len(true_set) > 0:
        recall = len(true_set & pred_set) / len(true_set)
    else:
        recall = 0.0

    if precision + recall > 0:
        f1 = 2 * (precision * recall) / (precision + recall)
    else:
        f1 = 0.0

    print(f"  {entity_type}:")
    print(f"    Precision: {precision:.4f}")
    print(f"    Recall: {recall:.4f}")
    print(f"    F1-Score: {f1:.4f}")
    print(f"    True entities: {len(true_set)}")
    print(f"    Predicted entities: {len(pred_set)}")

print("\n✅ NER Model Training Complete!")
print("🎯 Model successfully trained to extract Drug Names, Symptoms, and Dosages")


📈 Detailed Entity-Level Evaluation:

Entity-Level Metrics by Type:
  DRUG:
    Precision: 0.9962
    Recall: 1.0000
    F1-Score: 0.9981
    True entities: 265
    Predicted entities: 266
  SYMPTOM:
    Precision: 1.0000
    Recall: 1.0000
    F1-Score: 1.0000
    True entities: 265
    Predicted entities: 265
  DOSAGE:
    Precision: 1.0000
    Recall: 1.0000
    F1-Score: 1.0000
    True entities: 265
    Predicted entities: 265

✅ NER Model Training Complete!
🎯 Model successfully trained to extract Drug Names, Symptoms, and Dosages


Example Sentences

In [56]:
examples = [
    {
        "text": "500mg Tylenol every 6h for fever, but dizziness occurred"
    },
    {
        "text": "Patients experienced cough after administration of 75mg of Multaq"
    },
    {
        "text": "Aspirin 81mg daily caused GI bleeding and tinnitus"
    },
    {
        "text": "Overdose on CaCO3 (calcium carbonate): vomiting, drowsiness"
    },
    {
        "text": "Lisinoprol 10mg led to dry cough and fatigue"
    },
    {
        "text": "Janumet XR 50mg/1000mg BID caused diarrhea"
    },
    {
        "text": "Street drug 'Molly' induced hyperthermia and seizures"
    },
    {
        "text": "No ibuprofen use, but naproxen 250mg caused dyspepsia"
    }
]

for example in examples:
    print("\n" + "="*50)
    print(f"Input Text: {example['text']}")
    
    # Get predictions
    predicted_entities = predict_entities(
        text=example["text"],
        model=model,
        tokenizer=tokenizer,
        label_to_id=label_to_id,  # Ensure this is defined
        id_to_label=id_to_label   # Ensure this is defined
    )
    
    # Print predicted entities
    print("\nPredicted Entities:")
    for ent in predicted_entities:
        print(f"- {ent['label']}: '{ent['text']}' (positions {ent['start']}-{ent['end']})")
    
    print("="*50)


Input Text: 500mg Tylenol every 6h for fever, but dizziness occurred

Predicted Entities:
- DOSAGE: '500mg' (positions 0-1)
- DRUG: 'Tylenol' (positions 1-2)
- DOSAGE: '6h' (positions 3-4)
- SYMPTOM: 'fever,' (positions 5-6)
- SYMPTOM: 'dizziness' (positions 7-8)

Input Text: Patients experienced cough after administration of 75mg of Multaq

Predicted Entities:
- SYMPTOM: 'cough' (positions 2-3)
- DOSAGE: '75mg' (positions 6-7)
- DRUG: 'Multaq' (positions 8-9)

Input Text: Aspirin 81mg daily caused GI bleeding and tinnitus

Predicted Entities:
- DRUG: 'Aspirin' (positions 0-1)
- DOSAGE: '81mg' (positions 1-2)
- SYMPTOM: 'GI' (positions 4-5)
- SYMPTOM: 'bleeding' (positions 5-6)
- DRUG: 'tinnitus' (positions 7-8)

Input Text: Overdose on CaCO3 (calcium carbonate): vomiting, drowsiness

Predicted Entities:
- DRUG: 'CaCO3' (positions 2-3)
- DRUG: 'carbonate):' (positions 4-5)
- SYMPTOM: 'vomiting,' (positions 5-6)
- SYMPTOM: 'drowsiness' (positions 6-7)

Input Text: Lisinoprol 10mg led t

In [55]:
# Save the model
model.save_pretrained("clinical_ner_model_v2")
tokenizer.save_pretrained("clinical_ner_model_v2")

('clinical_ner_model_v2\\tokenizer_config.json',
 'clinical_ner_model_v2\\special_tokens_map.json',
 'clinical_ner_model_v2\\vocab.txt',
 'clinical_ner_model_v2\\added_tokens.json',
 'clinical_ner_model_v2\\tokenizer.json')