# 🔗 Shared-Head Multi-Task BioNER with Evaluation + Span Analysis

In [None]:
# !pip install datasets transformers seqeval

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer, AutoModelForTokenClassification,
    Trainer, TrainingArguments, DataCollatorForTokenClassification
)
from sklearn.metrics import classification_report, precision_recall_fscore_support


In [None]:
def read_ner_tsv(path):
    """Read a CoNLL-style token/tag file into parallel sentence lists.

    Each non-blank line holds one token and its tag; blank lines separate
    sentences. Columns are split on tabs when the line contains a tab
    (so a token may itself contain non-tab whitespace such as a
    non-breaking space), otherwise on any whitespace. The first column is
    the token and the last column is the tag, so extra middle columns are
    ignored.

    Args:
        path: Path of the file to read. It is decoded as UTF-8.

    Returns:
        ``{"tokens": [[str, ...], ...], "tags": [[str, ...], ...]}`` with one
        inner list per sentence, in file order.

    Raises:
        ValueError: If a non-blank line does not contain both a token and a
            tag.
    """
    sentences, labels = [], []
    tokens, tags = [], []
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(path, 'r', encoding='utf-8') as f:
        for line_number, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                # Sentence boundary; consecutive blank lines add nothing.
                if tokens:
                    sentences.append(tokens)
                    labels.append(tags)
                    tokens, tags = [], []
                continue

            fields = line.split('\t') if '\t' in line else line.split()
            token = fields[0].strip()
            tag = fields[-1].strip()
            if len(fields) < 2 or not token or not tag:
                raise ValueError(
                    f"{path}:{line_number}: expected a token and a tag, "
                    f"got {raw_line!r}"
                )
            tokens.append(token)
            tags.append(tag)

    # The last sentence may not be followed by a blank line.
    if tokens:
        sentences.append(tokens)
        labels.append(tags)
    return {"tokens": sentences, "tags": labels}


In [None]:
base_path = "/content"  # adjust this path

# Training split of every corpus, keyed by corpus name.
datasets = {
    name: read_ner_tsv(f"{base_path}/{name}-IOBES/train.tsv")
    for name in ("BC5CDR", "BC4CHEMD", "NCBI", "JNLPBA")
}

# Combine: pool the sentences of all corpora into one training set,
# keeping tokens and tags in the same order so they stay parallel.
all_tokens = [sentence for corpus in datasets.values() for sentence in corpus["tokens"]]
all_tags = [tag_seq for corpus in datasets.values() for tag_seq in corpus["tags"]]

train_split = Dataset.from_dict({"tokens": all_tokens, "tags": all_tags})
dataset = DatasetDict({"train": train_split})

# Shared label inventory: every tag seen in the pooled training data,
# sorted so that the id assignment is deterministic.
label_list = sorted({label for seq in all_tags for label in seq})
label_to_id = {label: index for index, label in enumerate(label_list)}
id_to_label = dict(enumerate(label_list))


In [None]:
# Name of the pretrained checkpoint; the tokenizer is loaded from it here.
model_ckpt = "dmis-lab/biobert-base-cased-v1.1"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

def tokenize_and_align_labels(example, label_all_tokens=False):
    """Tokenize a batch of pre-split sentences and align word tags to word-pieces.

    Written for ``Dataset.map(..., batched=True)``: ``example["tokens"]`` is a
    list of sentences (each a list of words) and ``example["tags"]`` the
    parallel list of tag sequences. Uses the module-level ``tokenizer`` and
    ``label_to_id``.

    Args:
        example: Batch mapping with ``"tokens"`` and ``"tags"`` entries.
        label_all_tokens: When False (default) only the first word-piece of
            each word carries the word's label id and every continuation
            piece is set to -100, so each word has exactly one non-ignored
            label. When True every word-piece of a word receives the word's
            label id.

    Returns:
        The tokenizer output with an added ``"labels"`` entry holding one
        list of label ids per sentence; special tokens are set to -100.

    Raises:
        KeyError: If a tag is not present in ``label_to_id``.
    """
    tokenized = tokenizer(example["tokens"], truncation=True, is_split_into_words=True)
    labels = []
    for i, word_tags in enumerate(example["tags"]):
        word_ids = tokenized.word_ids(batch_index=i)
        label_ids = []
        previous_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                # Special tokens ([CLS], [SEP], ...) belong to no word.
                label_ids.append(-100)
            elif word_idx != previous_word_idx or label_all_tokens:
                label_ids.append(label_to_id[word_tags[word_idx]])
            else:
                # Continuation word-piece: masked so that consumers which
                # expect one label per word stay aligned with the words.
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized["labels"] = labels
    return tokenized

# Tokenize every split of `dataset` in batches and attach the aligned label ids.
tokenized_dataset = dataset.map(tokenize_and_align_labels, batched=True)


In [None]:
# Token-classification model loaded from the checkpoint, with one output
# class per entry of the shared label inventory.
model = AutoModelForTokenClassification.from_pretrained(
    model_ckpt,
    label2id=label_to_id,
    id2label=id_to_label,
    num_labels=len(label_list),
)

# Training hyper-parameters; save_strategy="no" disables checkpoint saving.
training_config = {
    "output_dir": "./shared_head_mtl_model",
    "num_train_epochs": 3,
    "per_device_train_batch_size": 16,
    "learning_rate": 2e-5,
    "weight_decay": 0.01,
    "logging_dir": "./logs",
    "logging_steps": 10,
    "save_strategy": "no",
}
args = TrainingArguments(**training_config)

# The collator is built from the same tokenizer that produced the features.
collator = DataCollatorForTokenClassification(tokenizer)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_dataset["train"],
    tokenizer=tokenizer,
    data_collator=collator,
)
trainer.train()


In [None]:
# Load test sets
def load_test_sets(base_path="/content", task_names=("BC5CDR", "BC4CHEMD", "NCBI", "JNLPBA")):
    """Read the test split of each corpus.

    Args:
        base_path: Directory containing one ``<name>-IOBES`` folder per corpus.
        task_names: Corpus names to load; defaults to the four corpora used
            for training.

    Returns:
        A dict mapping each corpus name to the result of ``read_ner_tsv`` on
        ``<base_path>/<name>-IOBES/test.tsv``, in the order of ``task_names``.
    """
    return {
        name: read_ner_tsv(f"{base_path}/{name}-IOBES/test.tsv")
        for name in task_names
    }

def tokenize_test_set(test_data):
    """Build a ``Dataset`` from ``test_data`` and tokenize it in batches.

    ``test_data`` is passed to ``Dataset.from_dict`` and then mapped through
    ``tokenize_and_align_labels``, which reads its ``"tokens"`` and ``"tags"``
    columns.
    """
    return Dataset.from_dict(test_data).map(tokenize_and_align_labels, batched=True)

# Raw test data per corpus, plus the tokenized counterpart under the same key.
test_sets = load_test_sets()
tokenized_test_sets = {
    task_name: tokenize_test_set(raw_split)
    for task_name, raw_split in test_sets.items()
}

def evaluate_model_on_testset(task_name, raw_data, tokenized_data):
    """Score the module-level ``trainer`` on one test set and print the result.

    Predictions are the argmax over the last axis of the model output.
    Positions whose gold label id is -100 are skipped, and for each sentence
    at most ``len(raw_data["tokens"][i])`` of the remaining positions are
    kept, in order. Kept gold and predicted tags of all sentences are
    flattened and scored per tag with macro-averaged precision, recall and F1
    (``zero_division=0``).

    NOTE(review): capping at the word count only lines up with the words when
    each word contributes exactly one non-ignored label; if continuation
    word-pieces also carry labels, the kept positions are just the first
    ones of the sentence. Confirm against how the labels were aligned.

    Args:
        task_name: Name printed in the summary line.
        raw_data: Mapping whose ``"tokens"`` entry lists the words per sentence.
        tokenized_data: Tokenized dataset passed to ``trainer.predict``.

    Returns:
        ``(precision, recall, f1)`` as returned by
        ``precision_recall_fscore_support``.
    """
    output = trainer.predict(tokenized_data)
    predicted_ids = np.argmax(output.predictions, axis=2)
    gold_ids = output.label_ids
    sentences = raw_data["tokens"]

    y_true, y_pred = [], []
    for row, gold_row in enumerate(gold_ids):
        # Columns with a real label, limited to the number of words.
        scored_cols = [col for col, gold in enumerate(gold_row) if gold != -100]
        scored_cols = scored_cols[:len(sentences[row])]
        y_true.extend(id_to_label[gold_row[col]] for col in scored_cols)
        y_pred.extend(id_to_label[predicted_ids[row][col]] for col in scored_cols)

    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="macro", zero_division=0
    )
    print(f"📊 {task_name} — P: {precision:.2f} R: {recall:.2f} F1: {f1:.2f}")
    return precision, recall, f1

# Evaluate every corpus and store its three scores under the corpus name.
results = {}
for task, raw_split in test_sets.items():
    scores = evaluate_model_on_testset(task, raw_split, tokenized_test_sets[task])
    results[task] = dict(zip(("precision", "recall", "f1"), scores))

# Plot results: one group of three bars (precision, recall, F1) per task.
tasks = list(results)
prec = [results[t]["precision"] for t in tasks]
rec = [results[t]["recall"] for t in tasks]
f1s = [results[t]["f1"] for t in tasks]

x = range(len(tasks))
plt.figure(figsize=(10, 6))
plt.bar(x, prec, width=0.2, label='Precision', align='center')
# The remaining series are shifted right by one bar width each.
for offset, scores, series_label in ((0.2, rec, 'Recall'), (0.4, f1s, 'F1 Score')):
    plt.bar([pos + offset for pos in x], scores, width=0.2, label=series_label, align='center')
# Tick labels sit under the middle bar of each group.
plt.xticks([pos + 0.2 for pos in x], tasks)
plt.xlabel("Task")
plt.ylabel("Score")
plt.title("📈 Evaluation on Test Sets (Shared-Head MTL)")
plt.legend()
plt.grid(True)
plt.show()


In [None]:
# 📏 Span-Length Sensitivity Analysis

def extract_spans(tokens, labels):
    """Collect entity spans from an IOBES-tagged sentence.

    Args:
        tokens: Words of the sentence.
        labels: Tags such as ``B-X``, ``I-X``, ``E-X``, ``S-X`` or ``O``;
            ``labels[i]`` belongs to ``tokens[i]``.

    Returns:
        A list of ``(text, length, entity_type)`` tuples in order of
        appearance. ``text`` is the span's tokens joined by single spaces and
        ``length`` is the number of tokens in the span.

    Tag sequences that are not well-formed IOBES (common in model output)
    are handled leniently:

    * ``B-`` and ``S-`` first close any span that is still open, so the open
      span is reported rather than dropped.
    * ``E-`` closes the span it ends.
    * ``I-``/``E-`` with no open span start a new span typed by that tag.
    * ``I-``/``E-`` inside an open span extend it regardless of their type;
      the span keeps the type of its first tag.
    * Any other tag closes the open span.
    """
    spans = []
    current = []
    current_type = None

    def close_current():
        # Emit the open span (if any) and reset the state, so a later
        # orphan tag can never inherit a stale entity type.
        nonlocal current, current_type
        if current:
            spans.append((" ".join(current), len(current), current_type))
        current, current_type = [], None

    for i, label in enumerate(labels):
        prefix, entity_type = label[:2], label[2:]
        if prefix == "B-":
            close_current()
            current, current_type = [tokens[i]], entity_type
        elif prefix in ("I-", "E-"):
            if not current:
                current_type = entity_type
            current.append(tokens[i])
            if prefix == "E-":
                close_current()
        elif prefix == "S-":
            close_current()
            spans.append((tokens[i], 1, entity_type))
        else:
            close_current()

    close_current()
    return spans

def span_length_evaluation(task_name, raw_data, tokenized_data):
    """Report and plot span-level precision/recall/F1 by entity length.

    Runs the module-level ``trainer`` on ``tokenized_data``, converts gold and
    predicted label ids back to tags, extracts entity spans with
    ``extract_spans`` and groups them into the length bins 1, 2, 3 and '4+'.
    A predicted span matches a gold span when sentence index, span text and
    entity type are all equal. Spans are counted as a multiset, so a mention
    that occurs several times in one sentence contributes once per
    occurrence to support, true positives and errors.

    NOTE: span identity does not include the position in the sentence, so a
    predicted span with the same text and type at a different position of the
    same sentence still counts as a match.

    NOTE(review): positions with gold label id -100 are skipped and at most
    one label per word is kept; this lines up with the words only when each
    word contributes exactly one non-ignored label. Confirm against how the
    labels were aligned.

    Args:
        task_name: Name used in the printed header and the plot title.
        raw_data: Mapping whose ``"tokens"`` entry lists the words per sentence.
        tokenized_data: Tokenized dataset passed to ``trainer.predict``.

    Returns:
        None. A table is printed and a bar chart is shown.
    """
    # Imported locally so this block needs no change to the file-level imports.
    from collections import Counter

    predictions = trainer.predict(tokenized_data)
    pred_ids = np.argmax(predictions.predictions, axis=2)
    labels = predictions.label_ids
    tokens_list = raw_data["tokens"]

    true_labels_list, pred_labels_list = [], []
    for i, label_row in enumerate(labels):
        # Columns with a real label, limited to the number of words.
        kept = [j for j, label_id in enumerate(label_row) if label_id != -100]
        kept = kept[:len(tokens_list[i])]
        true_labels_list.append([id_to_label[label_row[j]] for j in kept])
        pred_labels_list.append([id_to_label[pred_ids[i][j]] for j in kept])

    length_bins = [1, 2, 3, '4+']

    def bin_for(length):
        # Lengths above 3 share one open-ended bin.
        return length if length <= 3 else '4+'

    # Counters rather than sets: identical mentions in one sentence must not
    # collapse into a single entry.
    true_by_len = defaultdict(Counter)
    pred_by_len = defaultdict(Counter)

    for idx, sentence in enumerate(tokens_list):
        for text, length, entity_type in extract_spans(sentence, true_labels_list[idx]):
            true_by_len[bin_for(length)][(idx, text, entity_type)] += 1
        for text, length, entity_type in extract_spans(sentence, pred_labels_list[idx]):
            pred_by_len[bin_for(length)][(idx, text, entity_type)] += 1

    # Metrics per span length
    print(f"\n📊 Span Length Sensitivity for Task: {task_name}")
    print("Length\tPrec\tRec\tF1\tSupport")
    precs, recs, f1s = [], [], []
    for bin_key in length_bins:
        gold = true_by_len[bin_key]
        pred = pred_by_len[bin_key]
        # Counter "&" keeps the minimum count per key and "-" keeps only
        # positive differences, i.e. multiset intersection and difference.
        tp = sum((gold & pred).values())
        fp = sum((pred - gold).values())
        fn = sum((gold - pred).values())

        # The small epsilon keeps empty bins at 0.0 instead of dividing by zero.
        precision = tp / (tp + fp + 1e-8)
        recall = tp / (tp + fn + 1e-8)
        f1 = 2 * precision * recall / (precision + recall + 1e-8)
        support = sum(gold.values())
        precs.append(precision)
        recs.append(recall)
        f1s.append(f1)
        print(f"{bin_key}\t{precision:.2f}\t{recall:.2f}\t{f1:.2f}\t{support}")

    # Plot
    x = range(len(length_bins))
    plt.figure(figsize=(10, 6))
    plt.bar(x, precs, width=0.2, label='Precision', align='center')
    plt.bar([p + 0.2 for p in x], recs, width=0.2, label='Recall', align='center')
    plt.bar([p + 0.4 for p in x], f1s, width=0.2, label='F1 Score', align='center')
    plt.xticks([p + 0.2 for p in x], [str(l) for l in length_bins])
    plt.xlabel("Entity Span Length (tokens)")
    plt.ylabel("Score")
    plt.title(f"Span Length Sensitivity — {task_name}")
    plt.ylim(0, 1.05)
    plt.legend()
    plt.grid(axis='y')
    plt.tight_layout()
    plt.show()

# 🔁 Run span analysis per task
for task, raw_split in test_sets.items():
    span_length_evaluation(task, raw_split, tokenized_test_sets[task])
