## Setup and Imports

In [None]:
import json
import os
import numpy as np
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForTokenClassification, 
    TrainingArguments, 
    Trainer,
    DataCollatorForTokenClassification
)
from torch.utils.data import Dataset
import pandas as pd
from tqdm import tqdm

# Fix the NumPy and PyTorch RNG seeds so repeated runs draw the same numbers
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Report the environment this notebook is running in
for status_line in (
    "Setup complete",
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
):
    print(status_line)

In [None]:
# Entity types annotated in the corpus
ENTITY_LABELS = [
    "anatomical location",
    "animal",
    "bacteria",
    "biomedical technique",
    "chemical",
    "DDF",
    "dietary supplement",
    "drug",
    "food",
    "gene",
    "human",
    "microbiome",
    "statistical technique"
]

# BIO tag inventory: 'O' (outside) first, then a B- (beginning) / I- (inside)
# pair for every entity type, in ENTITY_LABELS order
label_list = ['O'] + [
    f'{prefix}-{entity_type}'
    for entity_type in ENTITY_LABELS
    for prefix in ('B', 'I')
]

# Tag <-> integer id lookups; an id is the tag's position in label_list
label2id = dict(zip(label_list, range(len(label_list))))
id2label = dict(enumerate(label_list))

print(f"Total labels: {len(label_list)}")
print(f"\nFirst 10 labels: {label_list[:10]}")

# Pre-trained checkpoint to fine-tune and the directory the result is saved to
model_name = "dmis-lab/biobert-v1.1"  # BioBERT for biomedical text
output_model_dir = "models/bert_biomedbert_ner"

print(f"\nModel: {model_name}")
print(f"Output directory: {output_model_dir}")

## Data Loading Functions

In [None]:
def load_ner_data(file_paths):
    """
    Read several JSON annotation files and merge them into a single dict.

    Files are processed in the given order and merged with dict.update, so a
    key that appears in more than one file keeps the value from the last file
    containing it. Paths that do not exist are skipped after printing a
    warning.

    Args:
        file_paths: Iterable of paths to JSON files whose top level is a dict.

    Returns:
        The merged dict (empty when nothing could be loaded).
    """
    merged = {}

    for path in file_paths:
        # Guard clause: a missing file is reported and skipped, not fatal
        if not os.path.exists(path):
            print(f"Warning: {path} not found")
            continue

        with open(path, 'r', encoding='utf-8') as handle:
            loaded = json.load(handle)

        merged.update(loaded)
        print(f"Loaded {len(loaded)} documents from {os.path.basename(path)}")

    return merged


def prepare_documents_for_ner(data):
    """
    Flatten articles into one record per text segment.

    Every article yields two records, in this order: its title, then its
    abstract. Each record carries the entities whose 'location' matches the
    segment.

    Args:
        data: Dict mapping pmid -> article; an article provides
            article['metadata']['title'], article['metadata']['abstract']
            and article['entities'] (dicts with a 'location' key).

    Returns:
        List of dicts with keys 'pmid', 'location', 'text', 'entities'.
    """
    segments = []

    for pmid, article in data.items():
        # The metadata key and the entity 'location' value share the same name
        for segment_name in ('title', 'abstract'):
            segment_text = article['metadata'][segment_name]
            segment_entities = [
                ent for ent in article['entities']
                if ent['location'] == segment_name
            ]
            segments.append({
                'pmid': pmid,
                'location': segment_name,
                'text': segment_text,
                'entities': segment_entities
            })

    return segments


print("✓ Data loading functions defined")

## Load Training and Dev Data

In [None]:
# Load training data from three quality levels
train_files = [
    "../data/Annotations/Train/gold_quality/json_format/train_gold.json",
    "../data/Annotations/Train/platinum_quality/json_format/train_platinum.json",
    "../data/Annotations/Train/silver_quality/json_format/train_silver.json"
]

train_data = load_ner_data(train_files)
train_documents = prepare_documents_for_ner(train_data)

# train_data has one entry per article, while train_documents has one entry
# per title/abstract segment, so the two counts are reported separately
print(f"\nTotal training documents: {len(train_data)}")
print(f"Total training text segments: {len(train_documents)}")

In [None]:
# Load dev data
dev_data = load_ner_data(["../data/Annotations/Dev/json_format/dev.json"])
dev_documents = prepare_documents_for_ner(dev_data)

# dev_data has one entry per article; dev_documents has one entry per
# title/abstract segment
print(f"Total dev documents: {len(dev_data)}")
print(f"Total dev text segments: {len(dev_documents)}")

In [None]:
# Peek at one prepared segment and its first few gold annotations
example_doc = train_documents[10]

summary_lines = [
    "Example document:",
    f"  PMID: {example_doc['pmid']}",
    f"  Location: {example_doc['location']}",
    f"  Text: {example_doc['text'][:200]}...",
    f"  Number of entities: {len(example_doc['entities'])}",
    "\nFirst 3 entities:",
]
print("\n".join(summary_lines))

for entity in example_doc['entities'][:3]:
    char_range = f"{entity['start_idx']}-{entity['end_idx']}"
    print(f"    - '{entity['text_span']}' [{entity['label']}] @ {char_range}")

## Initialize BERT Model and Tokenizer

In [None]:
# Initialize tokenizer and model
print("Initializing BERT tokenizer and model...")
# Tokenizer and model are both loaded from the checkpoint named by model_name
tokenizer = AutoTokenizer.from_pretrained(model_name)
# num_labels is the size of the BIO tag set; the two mappings are passed so
# the loaded model is configured with the tag names used in this notebook
model = AutoModelForTokenClassification.from_pretrained(
    model_name, 
    num_labels=len(label_list), 
    id2label=id2label, 
    label2id=label2id
)

print(f"✓ Tokenizer loaded: {tokenizer.__class__.__name__}")
print(f"✓ Model loaded with {model.num_labels} labels")
print(f"  Vocabulary size: {tokenizer.vocab_size}")

# Test tokenization
# Quick sanity check: print how one sample sentence is split into tokens
sample_text = "The gut microbiome plays a role in Parkinson's disease."
tokens = tokenizer.tokenize(sample_text)
print(f"\nSample tokenization: {tokens}")

## BIO Tag Generation for Training Data

In [None]:
def align_labels_with_tokens(text, entities, tokenizer, label2id, max_length=512):
    """
    Create BIO tags for tokenized text based on character-level entity annotations.

    The text is tokenized with offset mapping, and every token whose character
    range intersects an entity's [start_idx, end_idx) range is tagged: the
    first such token gets 'B-<label>', the following ones 'I-<label>'.

    Overlapping entities: entities are visited by ascending start (longer
    first on ties). An entity is skipped entirely if any of its tokens was
    already claimed by an earlier entity, so the output never contains an
    'I-' run without its 'B-' tag.

    NOTE(review): the overlap test treats end_idx as exclusive - confirm
    against the annotation schema.

    Args:
        text: Raw text of the segment.
        entities: List of dicts with 'start_idx', 'end_idx' and 'label'.
        tokenizer: Callable tokenizer supporting return_offsets_mapping and
            exposing convert_ids_to_tokens.
        label2id: Dict mapping BIO tag strings to integer ids. Tags missing
            from the mapping fall back to the id of 'O'.
        max_length: Truncation length passed to the tokenizer (default 512).

    Returns:
        Dict with 'input_ids', 'attention_mask', 'labels' (list of tag ids,
        one per token, special tokens tagged 'O') and 'tokens'.
    """
    # Tokenize and get offset mapping
    encoding = tokenizer(
        text,
        return_offsets_mapping=True,
        add_special_tokens=True,
        truncation=True,
        max_length=max_length
    )

    tokens = tokenizer.convert_ids_to_tokens(encoding['input_ids'])
    offset_mapping = encoding['offset_mapping']

    # Initialize all labels as 'O' (outside)
    labels = ['O'] * len(tokens)

    # Earliest start first; on equal starts the longer entity wins
    sorted_entities = sorted(
        entities,
        key=lambda e: (e['start_idx'], -(e['end_idx'] - e['start_idx']))
    )

    # Token positions already claimed by an entity
    labeled_positions = set()

    for entity in sorted_entities:
        entity_start = entity['start_idx']
        entity_end = entity['end_idx']
        entity_label = entity['label']

        # Tokens whose character range intersects the entity; (0, 0) offsets
        # mark special tokens and are ignored
        overlapping = [
            idx
            for idx, (token_start, token_end) in enumerate(offset_mapping)
            if not (token_start == 0 and token_end == 0)
            and token_start < entity_end
            and token_end > entity_start
        ]
        if not overlapping:
            # Entity lies outside the (possibly truncated) token sequence
            continue

        token_span = range(overlapping[0], overlapping[-1] + 1)

        # Skip the whole entity on any conflict: tagging only its unclaimed
        # tail would produce I- tags that are not preceded by a B- tag
        if any(idx in labeled_positions for idx in token_span):
            continue

        for idx in token_span:
            prefix = 'B' if idx == token_span[0] else 'I'
            labels[idx] = f'{prefix}-{entity_label}'
        labeled_positions.update(token_span)

    # Convert labels to IDs
    label_ids = [label2id.get(label, label2id['O']) for label in labels]

    return {
        'input_ids': encoding['input_ids'],
        'attention_mask': encoding['attention_mask'],
        'labels': label_ids,
        'tokens': tokens
    }


print("✓ BIO tag generation function defined")

## Process Training and Dev Data with BIO Tags

In [None]:
# Tokenize every training segment and attach its BIO tag ids
print("Processing training data...")
processed_train = []

for doc in tqdm(train_documents, desc="Processing train"):
    record = align_labels_with_tokens(doc['text'], doc['entities'], tokenizer, label2id)
    # Keep the source fields next to the encoded features
    for field in ('pmid', 'location', 'text', 'entities'):
        record[field] = doc[field]
    processed_train.append(record)

print(f"✓ Training data processed: {len(processed_train)} segments")

In [None]:
# Tokenize every dev segment and attach its BIO tag ids
print("Processing dev data...")
processed_dev = []

for doc in tqdm(dev_documents, desc="Processing dev"):
    record = align_labels_with_tokens(doc['text'], doc['entities'], tokenizer, label2id)
    # Keep the source fields next to the encoded features
    for field in ('pmid', 'location', 'text', 'entities'):
        record[field] = doc[field]
    processed_dev.append(record)

print(f"✓ Dev data processed: {len(processed_dev)} segments")

In [None]:
# Display one processed training segment as a token/tag table
example_idx = 10
example = processed_train[example_idx]

print("Example from training data:")
print(f"  Text: {example['text'][:150]}...")
print(f"  Entities: {len(example['entities'])}")
print("\nToken-Label pairs (first 30):")

# Pair each of the first 30 tokens with its tag name
token_label_pairs = [
    (token, id2label[label_id])
    for token, label_id in zip(example['tokens'][:30], example['labels'][:30])
]

df = pd.DataFrame(token_label_pairs, columns=['Token', 'Label'])
print(df.to_string(index=False))

## Prepare Dataset for BERT Training

In [None]:
class NERDataset(Dataset):
    """
    Custom dataset for NER token classification.

    Wraps a list of pre-tokenized examples (dicts with 'input_ids',
    'attention_mask' and 'labels' lists) and returns each one as fixed-length
    tensors of size max_length.
    """

    def __init__(self, processed_data, max_length=512, pad_token_id=None):
        """
        Args:
            processed_data: List of dicts with 'input_ids', 'attention_mask'
                and 'labels' (equal-length lists of ints).
            max_length: Length every example is truncated or padded to.
            pad_token_id: Id used to pad 'input_ids'. When None, the pad id
                of the notebook-level `tokenizer` is looked up at access
                time, which is the previous behaviour.
        """
        self.data = processed_data
        self.max_length = max_length
        self.pad_token_id = pad_token_id

    def __len__(self):
        """Return the number of examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return example `idx` as a dict of torch.long tensors of length max_length."""
        item = self.data[idx]

        # Truncate to max_length
        input_ids = item['input_ids'][:self.max_length]
        attention_mask = item['attention_mask'][:self.max_length]
        labels = item['labels'][:self.max_length]

        # Pad if necessary
        padding_length = self.max_length - len(input_ids)
        if padding_length > 0:
            pad_id = self.pad_token_id
            if pad_id is None:
                # Backward-compatible fallback to the module-level tokenizer
                pad_id = tokenizer.pad_token_id
            input_ids = input_ids + [pad_id] * padding_length
            attention_mask = attention_mask + [0] * padding_length
            labels = labels + [-100] * padding_length  # -100 is ignored by loss

        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'labels': torch.tensor(labels, dtype=torch.long)
        }


print("✓ Custom dataset class defined")

In [None]:
# Create datasets
# Wrap the processed train/dev segments in NERDataset (default max_length)
print("Creating training datasets...")

train_dataset = NERDataset(processed_train)
dev_dataset = NERDataset(processed_dev)

print(f"✓ Training dataset: {len(train_dataset)} examples")
print(f"✓ Dev dataset: {len(dev_dataset)} examples")

## Configure Training Arguments

In [None]:
# Setup data collator for token classification
# NOTE(review): NERDataset.__getitem__ already pads every example to its
# max_length, so padding=True here presumably has nothing left to pad - confirm.
data_collator = DataCollatorForTokenClassification(
    tokenizer=tokenizer,
    padding=True,
    return_tensors="pt"
)
print("✓ Data collator initialized")

In [None]:
# Define training arguments
# Checkpoints are written to output_model_dir. Evaluation and checkpoint
# saving both use the "epoch" strategy.
# NOTE(review): no metric_for_best_model is set, so load_best_model_at_end
# presumably selects by evaluation loss - confirm against the transformers docs.
training_args = TrainingArguments(
    output_dir=output_model_dir,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
    logging_steps=100,
    save_total_limit=2,
    seed=42,
    fp16=torch.cuda.is_available(),  # mixed precision only when a CUDA device is present
    report_to="none"
)

print("✓ Training configuration ready")
print(f"  Batch size: {training_args.per_device_train_batch_size}")
print(f"  Epochs: {training_args.num_train_epochs}")
print(f"  Learning rate: {training_args.learning_rate}")

## Train BERT Model

Note: This cell might take several minutes to hours depending on dataset size and hardware.

**Additional configurations to test:**
- Change hyperparameters (*learning_rate*, *batch_size*, *num_train_epochs*, *weight_decay*)
- Try different pre-trained models (e.g., "allenai/scibert_scivocab_uncased", "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
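- Experiment with *max_length* (BERT-style encoders such as BioBERT accept at most 512 tokens, so longer texts need chunking or a sliding window rather than a larger *max_length*)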

In [None]:
# Initialize Trainer
# Trains `model` on train_dataset and evaluates on dev_dataset.
# NOTE(review): newer transformers releases rename Trainer's `tokenizer`
# argument to `processing_class` - confirm against the installed version.
print("Initializing Trainer...")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator
)

print("✓ Trainer initialized")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Evaluation samples: {len(dev_dataset)}")

In [None]:
# Start training
print("="*60)
print("Starting model training...")
print("="*60)

# Wall-clock timing around the training call
import time
training_start_time = time.time()

train_result = trainer.train()

# Elapsed seconds; reported in minutes below
training_duration = time.time() - training_start_time

print("\n" + "="*60)
print("✓ TRAINING COMPLETED!")
print("="*60)
print(f"Training time: {training_duration/60:.2f} minutes")

## Save Trained Model

In [None]:
# Save the trained model
# The model and the tokenizer are written to the same directory so both can
# be reloaded from output_model_dir
print("Saving trained model...")

os.makedirs(output_model_dir, exist_ok=True)
trainer.save_model(output_model_dir)
tokenizer.save_pretrained(output_model_dir)

print(f"✓ Model saved to: {output_model_dir}")

## Load Model for Inference

In [None]:
# Load the trained model for inference
# Reload model and tokenizer from the directory written by the save cell
print("Loading trained model for inference...")

inference_model = AutoModelForTokenClassification.from_pretrained(output_model_dir)
inference_tokenizer = AutoTokenizer.from_pretrained(output_model_dir)
# Switch the model to evaluation mode
inference_model.eval()

# Move the model to the GPU when one is available
if torch.cuda.is_available():
    inference_model = inference_model.cuda()

print(f"✓ Model loaded from: {output_model_dir}")

## Inference Function

In [None]:
def predict_entities(model, tokenizer, text, id2label, max_length=512):
    """
    Perform NER inference on a single text.

    The text is tokenized (truncated to max_length tokens), the model
    predicts one BIO tag per token, and consecutive 'B-X' / 'I-X' tags are
    merged into character-level entity spans using the tokenizer offsets.

    Decoding rules:
      * 'B-X' always starts a new entity (closing any open one).
      * 'I-X' extends the open entity only if that entity has label X.
      * 'O', an 'I-' tag with a different label, or an 'I-' tag with no open
        entity closes the open entity (if any) and starts nothing.

    Args:
        model: Token-classification model; called with the encoded inputs
            and expected to return an object with a `logits` tensor.
        tokenizer: Tokenizer supporting return_offsets_mapping.
        text: Text to annotate.
        id2label: Dict mapping integer class ids to BIO tag strings.
        max_length: Truncation length passed to the tokenizer (default 512).
            Text beyond this limit is not annotated.

    Returns:
        List of dicts with 'start_idx', 'end_idx' (built-in ints; end_idx is
        the exclusive end offset of the last token), 'label' and 'text_span'.
    """
    # Tokenize
    encoding = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        return_offsets_mapping=True,
        max_length=max_length
    )

    # tolist() yields built-in ints, which keeps the returned entities
    # JSON-serializable (NumPy integer scalars are rejected by json.dump)
    offset_mapping = encoding.pop('offset_mapping')[0].tolist()

    # Run on whatever device the model lives on, rather than assuming the
    # model is on the GPU whenever CUDA is available
    device = getattr(model, 'device', None)
    if device is None and hasattr(model, 'parameters'):
        first_param = next(iter(model.parameters()), None)
        device = None if first_param is None else first_param.device
    if device is not None:
        encoding = {k: v.to(device) for k, v in encoding.items()}

    # Predict
    with torch.no_grad():
        outputs = model(**encoding)
        predictions = torch.argmax(outputs.logits, dim=-1)[0].tolist()

    # Convert predictions to labels
    predicted_labels = [id2label[pred] for pred in predictions]

    # Extract entities from BIO tags
    entities = []
    current_entity = None

    for label, (start_char, end_char) in zip(predicted_labels, offset_mapping):
        # Skip special tokens, which have (0, 0) offsets
        if start_char == 0 and end_char == 0:
            continue

        if label.startswith('B-'):
            # Save previous entity if exists
            if current_entity is not None:
                entities.append(current_entity)

            # Start new entity
            current_entity = {
                'start_idx': start_char,
                'end_idx': end_char,
                'label': label[2:],  # Remove 'B-' prefix
                'text_span': text[start_char:end_char]
            }

        elif (
            label.startswith('I-')
            and current_entity is not None
            and label[2:] == current_entity['label']
        ):
            # Extend current entity
            current_entity['end_idx'] = end_char
            current_entity['text_span'] = text[current_entity['start_idx']:end_char]

        else:
            # Outside, orphan I- tag, or label mismatch: close the open
            # entity so that a later I- tag cannot extend it across the gap
            if current_entity is not None:
                entities.append(current_entity)
                current_entity = None

    # Save last entity if exists
    if current_entity is not None:
        entities.append(current_entity)

    return entities


print("✓ Inference function defined")

## Predict on Dev Set

In [None]:
# Annotate every dev segment and collect the entities per document (pmid)
print("Running inference on dev set...")

predictions = {}

for doc in tqdm(dev_documents, desc="Predicting"):
    pmid, location, text = doc['pmid'], doc['location'], doc['text']

    predicted_entities = predict_entities(inference_model, inference_tokenizer, text, id2label)

    # Record which segment (title/abstract) each entity came from
    for entity in predicted_entities:
        entity['location'] = location

    # Title and abstract entities of one article share a single entry
    predictions.setdefault(pmid, {'entities': []})['entities'].extend(predicted_entities)

print(f"✓ Inference completed: {len(predictions)} documents")
total_entities = sum(len(doc_pred['entities']) for doc_pred in predictions.values())
print(f"  Total entities predicted: {total_entities}")

## Save Predictions

In [None]:
# Save predictions to file
output_path = "predictions/bert_ner_predictions.json"

# open(..., "w") does not create missing parent directories
os.makedirs(os.path.dirname(output_path), exist_ok=True)


def _json_default(value):
    """Convert NumPy scalars (e.g. int64 offsets) to built-in Python values."""
    if isinstance(value, np.generic):
        return value.item()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


with open(output_path, "w", encoding="utf-8") as f:
    json.dump(predictions, f, ensure_ascii=False, indent=2, default=_json_default)

print(f"Predictions saved to {output_path}")

## Evaluate Performance

In [None]:
# Load evaluation functions from evaluate.py concepts
def remove_duplicated_entities(predictions):
    """
    Drop repeated entities from every document, in place.

    Two entities count as duplicates when they share start_idx, end_idx and
    location (the label is not compared); the first occurrence is kept and
    the original order is preserved. Prints a summary when anything was
    removed. Returns None.
    """
    dropped = 0

    for pmid in list(predictions.keys()):
        entities = predictions[pmid]["entities"]

        # Insertion-ordered dict: setdefault keeps the first entity per key
        first_by_key = {}
        for ent in entities:
            first_by_key.setdefault(
                (ent["start_idx"], ent["end_idx"], ent["location"]), ent
            )

        unique = list(first_by_key.values())
        dropped += len(entities) - len(unique)
        predictions[pmid]["entities"] = unique

    if dropped > 0:
        print(f"Removed {dropped} duplicated entities from predictions")

def remove_overlapping_entities_eval(predictions):
    """
    Resolve overlapping entities in place, keeping the longest span.

    Per document and per location ('title' / 'abstract'), entities are sorted
    by start_idx and chained into clusters: an entity joins the current
    cluster when it starts before the cluster's furthest end_idx. From each
    cluster only the longest entity survives (the earliest in sorted order on
    ties). Surviving entities keep their original relative order. Prints a
    summary when anything was removed. Returns None.
    """
    total_removed = 0

    for pmid in list(predictions.keys()):
        entities = predictions[pmid]['entities']

        # Bucket by location; only these two locations are accepted
        by_location = {'title': [], 'abstract': []}
        for ent in entities:
            by_location[ent["location"]].append(ent)

        winners = set()
        for bucket in by_location.values():
            clusters = []
            furthest_end = None

            for ent in sorted(bucket, key=lambda e: e["start_idx"]):
                if clusters and ent["start_idx"] < furthest_end:
                    # Overlaps the running cluster: join it and extend its reach
                    clusters[-1].append(ent)
                    furthest_end = max(furthest_end, ent["end_idx"])
                else:
                    clusters.append([ent])
                    furthest_end = ent["end_idx"]

            for members in clusters:
                # max() returns the first maximal element, i.e. ties go to
                # the entity that comes first in sorted order
                best = max(members, key=lambda e: e["end_idx"] - e["start_idx"])
                winners.add((best["start_idx"], best["end_idx"], best["location"]))

        survivors = []
        for ent in entities:
            ident = (ent["start_idx"], ent["end_idx"], ent["location"])
            if ident in winners:
                survivors.append(ent)
                # Consume the key so at most one entity per key survives
                winners.remove(ident)

        predictions[pmid]["entities"] = survivors
        total_removed += len(entities) - len(survivors)

    if total_removed > 0:
        print(f"Removed {total_removed} overlapping entities")

print("✓ Evaluation helper functions defined")

In [None]:
def evaluate_ner(predictions, ground_truth):
    """
    Evaluate NER predictions against ground truth with exact matching.

    A prediction is a true positive only when its
    (start_idx, end_idx, location, text_span, label) tuple equals a gold
    tuple of the same document.

    NOTE: `predictions` is modified in place - duplicated and overlapping
    entities are removed before scoring.

    Args:
        predictions: Dict pmid -> {'entities': [entity dict, ...]}.
        ground_truth: Dict pmid -> article with an 'entities' list.

    Returns:
        Tuple (macro_precision, macro_recall, macro_f1,
               micro_precision, micro_recall, micro_f1).
        Macro scores average over the labels present in the ground truth;
        they are 0.0 when the ground truth contains no entities.
    """
    # Remove duplicated and overlapping entities
    remove_duplicated_entities(predictions)
    remove_overlapping_entities_eval(predictions)

    LEGAL_ENTITY_LABELS = {
        "anatomical location", "animal", "bacteria", "biomedical technique",
        "chemical", "DDF", "dietary supplement", "drug", "food", "gene",
        "human", "microbiome", "statistical technique"
    }

    # Gold entries per document, held in sets for O(1) membership tests
    ground_truth_NER = {}
    count_annotated_entities_per_label = {}

    for pmid, article in ground_truth.items():
        gold_entries = ground_truth_NER.setdefault(pmid, set())
        for entity in article['entities']:
            label = str(entity["label"])
            gold_entries.add((
                int(entity["start_idx"]),
                int(entity["end_idx"]),
                str(entity["location"]),
                str(entity["text_span"]),
                label
            ))
            # Every annotation counts, including repeated ones
            count_annotated_entities_per_label[label] = (
                count_annotated_entities_per_label.get(label, 0) + 1
            )

    count_predicted_entities_per_label = dict.fromkeys(count_annotated_entities_per_label, 0)
    count_true_positives_per_label = dict.fromkeys(count_annotated_entities_per_label, 0)

    for pmid, doc_prediction in predictions.items():
        gold_entries = ground_truth_NER.get(pmid, ())

        for entity in doc_prediction['entities']:
            label = str(entity["label"])

            if label not in LEGAL_ENTITY_LABELS:
                continue

            # NOTE(review): predictions whose (legal) label never occurs in
            # the ground truth are not counted as predicted - kept as in the
            # original scoring; confirm this is intended.
            if label in count_predicted_entities_per_label:
                count_predicted_entities_per_label[label] += 1

            entry = (
                int(entity["start_idx"]),
                int(entity["end_idx"]),
                str(entity["location"]),
                str(entity["text_span"]),
                label
            )
            if entry in gold_entries:
                count_true_positives_per_label[label] += 1

    count_annotated_entities = sum(count_annotated_entities_per_label.values())
    count_predicted_entities = sum(count_predicted_entities_per_label.values())
    count_true_positives = sum(count_true_positives_per_label.values())

    # 1e-10 keeps the ratios defined when a denominator is zero
    micro_precision = count_true_positives / (count_predicted_entities + 1e-10)
    micro_recall = count_true_positives / (count_annotated_entities + 1e-10)
    micro_f1 = 2 * ((micro_precision * micro_recall) / (micro_precision + micro_recall + 1e-10))

    n = len(count_annotated_entities_per_label)
    if n == 0:
        # No gold labels to average over; avoid dividing by zero below
        return 0.0, 0.0, 0.0, micro_precision, micro_recall, micro_f1

    precision, recall, f1 = 0, 0, 0
    for label, annotated in count_annotated_entities_per_label.items():
        true_positives = count_true_positives_per_label[label]
        current_precision = true_positives / (count_predicted_entities_per_label[label] + 1e-10)
        current_recall = true_positives / (annotated + 1e-10)

        precision += current_precision
        recall += current_recall
        f1 += 2 * ((current_precision * current_recall) / (current_precision + current_recall + 1e-10))

    precision = precision / n
    recall = recall / n
    f1 = f1 / n

    return precision, recall, f1, micro_precision, micro_recall, micro_f1


# Evaluate
# Score the dev predictions against the dev annotations. evaluate_ner removes
# duplicated/overlapping entities from `predictions` in place before scoring.
precision, recall, f1, micro_precision, micro_recall, micro_f1 = evaluate_ner(predictions, dev_data)

print("="*60)
print("BERT NER BASELINE RESULTS")
print("="*60)
# Macro scores are averaged per label
print("\nMacro-averaged Metrics:")
print(f"  Macro-Precision: {precision:.4f}")
print(f"  Macro-Recall:    {recall:.4f}")
print(f"  Macro-F1 Score:  {f1:.4f}")

# Micro scores are computed from the counts summed over all labels
print("\nMicro-averaged Metrics:")
print(f"  Micro-Precision: {micro_precision:.4f}")
print(f"  Micro-Recall:    {micro_recall:.4f}")
print(f"  Micro-F1 Score:  {micro_f1:.4f}")
print("="*60)

## Example Predictions

In [None]:
# Compare gold and predicted entities for the first few dev documents
print("Example Predictions:\n")

sample_pmids = list(dev_data.keys())[:5]

# Fields that must all agree for a prediction to count as correct
match_fields = ('start_idx', 'end_idx', 'location', 'text_span', 'label')

for pmid in sample_pmids:
    article = dev_data[pmid]
    pred = predictions[pmid]

    print(f"Document PMID: {pmid}")
    print(f"Title: {article['metadata']['title'][:100]}...")
    print(f"\nGold entities: {len(article['entities'])}")
    print(f"Predicted entities: {len(pred['entities'])}")

    # Show first few predicted entities
    print("\nSample predictions:")
    for entity in pred['entities'][:5]:
        print(f"  - '{entity['text_span']}' [{entity['label']}] in {entity['location']}")

    # Exact-match statistics via set algebra on the identifying tuples
    gold_set = {
        tuple(gold_entity[field] for field in match_fields)
        for gold_entity in article['entities']
    }
    pred_set = {
        tuple(pred_entity[field] for field in match_fields)
        for pred_entity in pred['entities']
    }

    correct = len(gold_set & pred_set)
    missed = len(gold_set - pred_set)
    wrong = len(pred_set - gold_set)

    print(f"\n✓ Correct: {correct}")
    print(f"✗ Missed: {missed}")
    print(f"✗ Wrong: {wrong}")
    print("-" * 80)
    print()

## Analysis: Entity Distribution by Label

In [None]:
from collections import Counter

# Per-label entity counts in the predictions
pred_label_counts = Counter(
    entity['label']
    for pred in predictions.values()
    for entity in pred['entities']
)

# Per-label entity counts in the gold standard
gold_label_counts = Counter(
    entity['label']
    for article in dev_data.values()
    for entity in article['entities']
)

separator = "-"*60

print("Entity Distribution by Label:")
print("="*60)
print(f"{'Label':<25} {'Gold':<10} {'Predicted':<10}")
print(separator)

# A label missing from one side reads as 0 from its Counter
all_labels = set(gold_label_counts) | set(pred_label_counts)
for label in sorted(all_labels):
    print(f"{label:<25} {gold_label_counts[label]:<10} {pred_label_counts[label]:<10}")

print(separator)
print(f"{'TOTAL':<25} {sum(gold_label_counts.values()):<10} {sum(pred_label_counts.values()):<10}")