# Healthcare NER Fine-tuning Pipeline
## NCBI Disease (EN) + Quaero French Med (FR)

**Setup:** Runtime → Change runtime type → T4 GPU → Save

In [1]:
# Install dependencies
!pip install -q datasets transformers accelerate seqeval scikit-learn

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone


In [2]:
# Imports
import torch
from datasets import load_dataset, concatenate_datasets, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification
)
from seqeval.metrics import classification_report, f1_score, precision_score, recall_score
import numpy as np
from collections import Counter

print(f"GPU Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Device: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

GPU Available: True
Device: Tesla T4
Memory: 15.83 GB


## 1. Load Datasets (FIXED)

Using correct dataset paths that don't require loading scripts.

In [3]:
!wget -O quaero.zip https://quaerofrenchmed.limsi.fr/QUAERO_FrenchMed_brat.zip
!unzip -o quaero.zip -d quaero_data


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: quaero_data/QUAERO_FrenchMed/corpus/dev/MEDLINE/12146146.ann  
 extracting: quaero_data/QUAERO_FrenchMed/corpus/dev/MEDLINE/12146146.txt  
  inflating: quaero_data/QUAERO_FrenchMed/corpus/dev/MEDLINE/1215988.ann  
  inflating: quaero_data/QUAERO_FrenchMed/corpus/dev/MEDLINE/1215988.txt  
  inflating: quaero_data/QUAERO_FrenchMed/corpus/dev/MEDLINE/1217891.ann  
  inflating: quaero_data/QUAERO_FrenchMed/corpus/dev/MEDLINE/1217891.txt  
  inflating: quaero_data/QUAERO_FrenchMed/corpus/dev/MEDLINE/1219280.ann  
  inflating: quaero_data/QUAERO_FrenchMed/corpus/dev/MEDLINE/1219280.txt  
  inflating: quaero_data/QUAERO_FrenchMed/corpus/dev/MEDLINE/1223130.ann  
  inflating: quaero_data/QUAERO_FrenchMed/corpus/dev/MEDLINE/1223130.ann~  
  inflating: quaero_data/QUAERO_FrenchMed/corpus/dev/MEDLINE/1223130.txt  
  inflating: quaero_data/QUAERO_FrenchMed/corpus/dev/MEDLINE/1224062.ann  
  inflating: quaero_data/QUAERO_

In [4]:
!wget https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/NCBI_corpus.zip -O ncbi_disease.zip
!unzip -o ncbi_disease.zip -d ncbi_disease


--2026-01-22 10:54:18--  https://www.ncbi.nlm.nih.gov/CBBresearch/Dogan/DISEASE/NCBI_corpus.zip
Resolving www.ncbi.nlm.nih.gov (www.ncbi.nlm.nih.gov)... 130.14.29.110, 2607:f220:41e:4290::110
Connecting to www.ncbi.nlm.nih.gov (www.ncbi.nlm.nih.gov)|130.14.29.110|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 368345 (360K) [application/zip]
Saving to: ‘ncbi_disease.zip’


2026-01-22 10:54:19 (1.27 MB/s) - ‘ncbi_disease.zip’ saved [368345/368345]

Archive:  ncbi_disease.zip
  inflating: ncbi_disease/NCBI_corpus_development.txt  
  inflating: ncbi_disease/NCBI_corpus_testing.txt  
  inflating: ncbi_disease/NCBI_corpus_training.txt  


In [9]:
import re

def parse_ncbi_with_tags(filepath):
    """Parse NCBI format: PMID\ttext with <category="type">entity</category> tags"""
    examples = []

    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            # Split PMID and text
            parts = line.split('\t', 1)
            if len(parts) < 2:
                continue

            pmid = parts[0]
            text_with_tags = parts[1]

            # Find all entities with their positions
            # Pattern: <category="Type">entity text</category>
            pattern = r'<category="[^"]*">([^<]+)</category>'

            # First pass: extract plain text and entity spans
            entities = []
            for match in re.finditer(pattern, text_with_tags):
                entity_text = match.group(1)
                start_in_tagged = match.start()
                entities.append((start_in_tagged, entity_text))

            # Remove all tags to get clean text
            clean_text = re.sub(r'<category="[^"]*">', '', text_with_tags)
            clean_text = re.sub(r'</category>', '', clean_text)

            # Tokenize (simple whitespace split)
            tokens = clean_text.split()
            tags = ['O'] * len(tokens)

            # Map character positions to token indices
            char_to_token = {}
            char_pos = 0
            for token_idx, token in enumerate(tokens):
                for i in range(len(token)):
                    char_to_token[char_pos + i] = token_idx
                char_pos += len(token) + 1  # +1 for space

            # Mark entities in BIO format
            for entity_start_tagged, entity_text in entities:
                # Find position in clean text
                # Count characters before this entity (excluding tags)
                text_before = text_with_tags[:entity_start_tagged]
                clean_before = re.sub(r'<category="[^"]*">', '', text_before)
                clean_before = re.sub(r'</category>', '', clean_before)
                start_pos = len(clean_before)
                end_pos = start_pos + len(entity_text)

                # Find token indices
                start_token = char_to_token.get(start_pos)
                end_token = char_to_token.get(end_pos - 1)

                if start_token is not None:
                    tags[start_token] = 'B-Disease'
                    if end_token is not None and end_token > start_token:
                        for i in range(start_token + 1, min(end_token + 1, len(tags))):
                            tags[i] = 'I-Disease'

            if tokens:
                examples.append({'tokens': tokens, 'ner_tags': tags})

    return examples

# Parse NCBI files
print("Parsing NCBI Disease corpus...")
ncbi_train_data = parse_ncbi_with_tags('/content/ncbi_disease/NCBI_corpus_training.txt')
ncbi_test_data = parse_ncbi_with_tags('/content/ncbi_disease/NCBI_corpus_testing.txt')

print(f"✓ NCBI Train: {len(ncbi_train_data)} examples")
print(f"✓ NCBI Test: {len(ncbi_test_data)} examples")

if ncbi_train_data:
    print(f"\nFirst example:")
    print(f"  Tokens: {ncbi_train_data[0]['tokens'][:15]}")
    print(f"  Tags: {ncbi_train_data[0]['ner_tags'][:15]}")

    # Count entities
    train_entities = sum(1 for ex in ncbi_train_data for tag in ex['ner_tags'] if tag.startswith('B-'))
    print(f"\nTotal disease entities in train: {train_entities}")

Parsing NCBI Disease corpus...
✓ NCBI Train: 593 examples
✓ NCBI Test: 100 examples

First example:
  Tokens: ['Identification', 'of', 'APC2,', 'a', 'homologue', 'of', 'the', 'adenomatous', 'polyposis', 'coli', 'tumour', 'suppressor', '.', 'The', 'adenomatous']
  Tags: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Disease', 'I-Disease', 'I-Disease', 'I-Disease', 'O', 'O', 'O', 'B-Disease']

Total disease entities in train: 5130


In [16]:
import re
import os

def parse_brat_folder_correct(folder_path):
    """Parse BRAT - extract DISO (disorder) entities only"""
    examples = []

    txt_files = [f for f in os.listdir(folder_path) if f.endswith('.txt')]

    for txt_file in txt_files:
        txt_path = os.path.join(folder_path, txt_file)
        ann_path = os.path.join(folder_path, txt_file.replace('.txt', '.ann'))

        with open(txt_path, 'r', encoding='utf-8') as f:
            text = f.read()

        # Tokenize preserving character positions
        tokens = []
        token_spans = []
        for match in re.finditer(r"\S+", text):
            tokens.append(match.group())
            token_spans.append((match.start(), match.end()))

        tags = ['O'] * len(tokens)

        # Read annotations - ONLY DISO entities
        if os.path.exists(ann_path):
            with open(ann_path, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.startswith('T'):
                        parts = line.strip().split('\t')
                        if len(parts) >= 2:
                            ann_info = parts[1].split()
                            if len(ann_info) >= 3:
                                entity_type = ann_info[0]

                                # ONLY keep DISO (disorder/disease)
                                if entity_type == 'DISO':
                                    try:
                                        start = int(ann_info[1])
                                        end_str = ann_info[2].split(';')[0]  # Handle discontinuous
                                        end = int(end_str)

                                        # Tag tokens in this span
                                        first_token = True
                                        for token_idx, (tok_start, tok_end) in enumerate(token_spans):
                                            # Token overlaps entity
                                            if tok_start < end and tok_end > start:
                                                if first_token:
                                                    tags[token_idx] = 'B-Disease'
                                                    first_token = False
                                                else:
                                                    tags[token_idx] = 'I-Disease'
                                    except (ValueError, IndexError):
                                        continue

        if tokens:
            examples.append({'tokens': tokens, 'ner_tags': tags})

    return examples

# Parse Quaero folders
print("Re-parsing Quaero with DISO entities...")
quaero_emea = parse_brat_folder_correct('/content/quaero_data/QUAERO_FrenchMed/corpus/train/EMEA')
quaero_medline = parse_brat_folder_correct('/content/quaero_data/QUAERO_FrenchMed/corpus/train/MEDLINE')

quaero_all = quaero_emea + quaero_medline

print(f"✓ EMEA: {len(quaero_emea)} examples")
print(f"✓ MEDLINE: {len(quaero_medline)} examples")
print(f"✓ Total: {len(quaero_all)} examples")

# Count entities BEFORE conversion
entity_count = sum(1 for ex in quaero_all for tag in ex['ner_tags'] if tag == 'B-Disease')
print(f"\nDISO entities (B-Disease tags): {entity_count}")

# Show sample WITH entity
print("\nSample with entity:")
for ex in quaero_all:
    if 'B-Disease' in ex['ner_tags']:
        idx = ex['ner_tags'].index('B-Disease')
        print(f"  Tokens around entity: {ex['tokens'][max(0,idx-3):idx+5]}")
        print(f"  Tags: {ex['ner_tags'][max(0,idx-3):idx+5]}")
        break

# Split 80/20
from sklearn.model_selection import train_test_split
quaero_train_data, quaero_test_data = train_test_split(quaero_all, test_size=0.2, random_state=42)

print(f"\n✓ Quaero Train: {len(quaero_train_data)}")
print(f"✓ Quaero Test: {len(quaero_test_data)}")

# Convert tags to IDs
label_list = ['O', 'B-Disease', 'I-Disease']
label2id = {label: i for i, label in enumerate(label_list)}

def convert_tags_to_ids(examples):
    for example in examples:
        example['ner_tags'] = [label2id[tag] for tag in example['ner_tags']]
    return examples

quaero_train_data = convert_tags_to_ids(quaero_train_data)
quaero_test_data = convert_tags_to_ids(quaero_test_data)

# Create HF Dataset
from datasets import Dataset, DatasetDict

quaero = DatasetDict({
    'train': Dataset.from_dict({
        'tokens': [ex['tokens'] for ex in quaero_train_data],
        'ner_tags': [ex['ner_tags'] for ex in quaero_train_data]
    }),
    'test': Dataset.from_dict({
        'tokens': [ex['tokens'] for ex in quaero_test_data],
        'ner_tags': [ex['ner_tags'] for ex in quaero_test_data]
    })
})

print("\n" + "="*70)
print("QUAERO RE-LOADED WITH DISO ENTITIES")
print("="*70)
print(f"Quaero: {quaero}")

# Final verification
quaero_entities = sum(1 for ex in quaero['train'] for tag in ex['ner_tags'] if tag == 1)
print(f"\nQuaero train B-Disease entities: {quaero_entities}")

# Show sample
for ex in quaero['train']:
    if 1 in ex['ner_tags']:
        print(f"\nSample: {ex['tokens'][:20]}")
        print(f"Tags: {ex['ner_tags'][:20]}")
        break

Re-parsing Quaero with DISO entities...
✓ EMEA: 11 examples
✓ MEDLINE: 833 examples
✓ Total: 844 examples

DISO entities (B-Disease tags): 1380

Sample with entity:
  Tokens around entity: ['?', 'Prialt', 'est', 'indiqué', 'pour', 'le', 'traitement', 'des']
  Tags: ['O', 'O', 'O', 'B-Disease', 'O', 'O', 'O', 'O']

✓ Quaero Train: 675
✓ Quaero Test: 169

QUAERO RE-LOADED WITH DISO ENTITIES
Quaero: DatasetDict({
    train: Dataset({
        features: ['tokens', 'ner_tags'],
        num_rows: 675
    })
    test: Dataset({
        features: ['tokens', 'ner_tags'],
        num_rows: 169
    })
})

Quaero train B-Disease entities: 1023

Sample: ['Hernie', 'de', 'Bochdalek', '.']
Tags: [1, 2, 2, 0]


In [15]:
# Verify entity counts
print("Entity verification:")
print("="*70)

# NCBI
ncbi_train_entities = sum(1 for ex in ncbi['train'] for tag in ex['ner_tags'] if tag == 1)
ncbi_test_entities = sum(1 for ex in ncbi['test'] for tag in ex['ner_tags'] if tag == 1)
print(f"NCBI train B-Disease entities: {ncbi_train_entities}")
print(f"NCBI test B-Disease entities: {ncbi_test_entities}")

# Quaero
quaero_train_entities = sum(1 for ex in quaero['train'] for tag in ex['ner_tags'] if tag == 1)
quaero_test_entities = sum(1 for ex in quaero['test'] for tag in ex['ner_tags'] if tag == 1)
print(f"\nQuaero train B-Disease entities: {quaero_train_entities}")
print(f"Quaero test B-Disease entities: {quaero_test_entities}")

# Show a Quaero example WITH entities
print("\n" + "="*70)
print("Quaero example WITH entities:")
for ex in quaero['train']:
    if 1 in ex['ner_tags']:  # Has B-Disease tag
        # Show tagged entities
        entities = []
        for i, (token, tag) in enumerate(zip(ex['tokens'], ex['ner_tags'])):
            if tag == 1:  # B-Disease
                entity = [token]
                j = i + 1
                while j < len(ex['ner_tags']) and ex['ner_tags'][j] == 2:  # I-Disease
                    entity.append(ex['tokens'][j])
                    j += 1
                entities.append(' '.join(entity))

        print(f"Tokens: {ex['tokens'][:20]}")
        print(f"Tags: {ex['ner_tags'][:20]}")
        print(f"Entities found: {entities[:5]}")
        break

Entity verification:
NCBI train B-Disease entities: 5130
NCBI test B-Disease entities: 955

Quaero train B-Disease entities: 0
Quaero test B-Disease entities: 0

Quaero example WITH entities:


In [18]:
# ============================================================================
# COMPLETE DATA LOADING & PREPROCESSING
# Run this ONCE before Cell 4 (Model Setup)
# ============================================================================

import re
import os
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict

# ----------------------------------------------------------------------------
# 1. PARSE NCBI DISEASE (English) - XML tag format
# ----------------------------------------------------------------------------

def parse_ncbi_with_tags(filepath):
    """Parse NCBI format: PMID\ttext with <category="type">entity</category> tags"""
    examples = []

    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            parts = line.split('\t', 1)
            if len(parts) < 2:
                continue

            pmid = parts[0]
            text_with_tags = parts[1]

            # Pattern: <category="Type">entity text</category>
            pattern = r'<category="[^"]*">([^<]+)</category>'

            # Extract entities with positions
            entities = []
            for match in re.finditer(pattern, text_with_tags):
                entity_text = match.group(1)
                start_in_tagged = match.start()
                entities.append((start_in_tagged, entity_text))

            # Remove all tags to get clean text
            clean_text = re.sub(r'<category="[^"]*">', '', text_with_tags)
            clean_text = re.sub(r'</category>', '', clean_text)

            # Tokenize
            tokens = clean_text.split()
            tags = ['O'] * len(tokens)

            # Map character positions to token indices
            char_to_token = {}
            char_pos = 0
            for token_idx, token in enumerate(tokens):
                for i in range(len(token)):
                    char_to_token[char_pos + i] = token_idx
                char_pos += len(token) + 1

            # Mark entities in BIO format
            for entity_start_tagged, entity_text in entities:
                text_before = text_with_tags[:entity_start_tagged]
                clean_before = re.sub(r'<category="[^"]*">', '', text_before)
                clean_before = re.sub(r'</category>', '', clean_before)
                start_pos = len(clean_before)
                end_pos = start_pos + len(entity_text)

                start_token = char_to_token.get(start_pos)
                end_token = char_to_token.get(end_pos - 1)

                if start_token is not None:
                    tags[start_token] = 'B-Disease'
                    if end_token is not None and end_token > start_token:
                        for i in range(start_token + 1, min(end_token + 1, len(tags))):
                            tags[i] = 'I-Disease'

            if tokens:
                examples.append({'tokens': tokens, 'ner_tags': tags})

    return examples

print("Parsing NCBI Disease corpus...")
ncbi_train_data = parse_ncbi_with_tags('/content/ncbi_disease/NCBI_corpus_training.txt')
ncbi_test_data = parse_ncbi_with_tags('/content/ncbi_disease/NCBI_corpus_testing.txt')
print(f"✓ NCBI Train: {len(ncbi_train_data)} examples")
print(f"✓ NCBI Test: {len(ncbi_test_data)} examples")

# ----------------------------------------------------------------------------
# 2. PARSE QUAERO (French) - BRAT format with DISO entities
# ----------------------------------------------------------------------------

def parse_brat_folder_correct(folder_path):
    """Parse BRAT - extract DISO (disorder) entities only"""
    examples = []

    txt_files = [f for f in os.listdir(folder_path) if f.endswith('.txt')]

    for txt_file in txt_files:
        txt_path = os.path.join(folder_path, txt_file)
        ann_path = os.path.join(folder_path, txt_file.replace('.txt', '.ann'))

        with open(txt_path, 'r', encoding='utf-8') as f:
            text = f.read()

        # Tokenize preserving character positions
        tokens = []
        token_spans = []
        for match in re.finditer(r"\S+", text):
            tokens.append(match.group())
            token_spans.append((match.start(), match.end()))

        tags = ['O'] * len(tokens)

        # Read annotations - ONLY DISO entities
        if os.path.exists(ann_path):
            with open(ann_path, 'r', encoding='utf-8') as f:
                for line in f:
                    if line.startswith('T'):
                        parts = line.strip().split('\t')
                        if len(parts) >= 2:
                            ann_info = parts[1].split()
                            if len(ann_info) >= 3:
                                entity_type = ann_info[0]

                                if entity_type == 'DISO':
                                    try:
                                        start = int(ann_info[1])
                                        end_str = ann_info[2].split(';')[0]
                                        end = int(end_str)

                                        # Tag tokens in this span
                                        first_token = True
                                        for token_idx, (tok_start, tok_end) in enumerate(token_spans):
                                            if tok_start < end and tok_end > start:
                                                if first_token:
                                                    tags[token_idx] = 'B-Disease'
                                                    first_token = False
                                                else:
                                                    tags[token_idx] = 'I-Disease'
                                    except (ValueError, IndexError):
                                        continue

        if tokens:
            examples.append({'tokens': tokens, 'ner_tags': tags})

    return examples

print("\nParsing Quaero corpus...")
quaero_emea = parse_brat_folder_correct('/content/quaero_data/QUAERO_FrenchMed/corpus/train/EMEA')
quaero_medline = parse_brat_folder_correct('/content/quaero_data/QUAERO_FrenchMed/corpus/train/MEDLINE')
quaero_all = quaero_emea + quaero_medline
print(f"✓ Quaero Total: {len(quaero_all)} examples")

# Split 80/20
quaero_train_data, quaero_test_data = train_test_split(quaero_all, test_size=0.2, random_state=42)
print(f"✓ Quaero Train: {len(quaero_train_data)}, Test: {len(quaero_test_data)}")

# ----------------------------------------------------------------------------
# 3. CONVERT TO HUGGINGFACE DATASETS
# ----------------------------------------------------------------------------

label_list = ['O', 'B-Disease', 'I-Disease']
label2id = {label: i for i, label in enumerate(label_list)}
id2label = {i: label for i, label in enumerate(label_list)}

def convert_tags_to_ids(examples):
    """Convert string BIO tags to numeric IDs"""
    for example in examples:
        example['ner_tags'] = [label2id[tag] for tag in example['ner_tags']]
    return examples

ncbi_train_data = convert_tags_to_ids(ncbi_train_data)
ncbi_test_data = convert_tags_to_ids(ncbi_test_data)
quaero_train_data = convert_tags_to_ids(quaero_train_data)
quaero_test_data = convert_tags_to_ids(quaero_test_data)

# Create HF Datasets
ncbi = DatasetDict({
    'train': Dataset.from_dict({
        'tokens': [ex['tokens'] for ex in ncbi_train_data],
        'ner_tags': [ex['ner_tags'] for ex in ncbi_train_data]
    }),
    'test': Dataset.from_dict({
        'tokens': [ex['tokens'] for ex in ncbi_test_data],
        'ner_tags': [ex['ner_tags'] for ex in ncbi_test_data]
    })
})

quaero = DatasetDict({
    'train': Dataset.from_dict({
        'tokens': [ex['tokens'] for ex in quaero_train_data],
        'ner_tags': [ex['ner_tags'] for ex in quaero_train_data]
    }),
    'test': Dataset.from_dict({
        'tokens': [ex['tokens'] for ex in quaero_test_data],
        'ner_tags': [ex['ner_tags'] for ex in quaero_test_data]
    })
})

# ----------------------------------------------------------------------------
# 4. VERIFICATION
# ----------------------------------------------------------------------------

print("\n" + "="*70)
print("DATASETS READY!")
print("="*70)
print(f"NCBI: {ncbi}")
print(f"Quaero: {quaero}")
print(f"\nLabel schema: {label_list}")
print(f"label2id: {label2id}")
print(f"id2label: {id2label}")

# Entity counts
ncbi_entities = sum(1 for ex in ncbi['train'] for tag in ex['ner_tags'] if tag == 1)
quaero_entities = sum(1 for ex in quaero['train'] for tag in ex['ner_tags'] if tag == 1)
print(f"\nNCBI train entities: {ncbi_entities}")
print(f"Quaero train entities: {quaero_entities}")

print("\n✅ Ready to proceed to Cell 4 (Model Setup)")

Parsing NCBI Disease corpus...
✓ NCBI Train: 593 examples
✓ NCBI Test: 100 examples

Parsing Quaero corpus...
✓ Quaero Total: 844 examples
✓ Quaero Train: 675, Test: 169

DATASETS READY!
NCBI: DatasetDict({
    train: Dataset({
        features: ['tokens', 'ner_tags'],
        num_rows: 593
    })
    test: Dataset({
        features: ['tokens', 'ner_tags'],
        num_rows: 100
    })
})
Quaero: DatasetDict({
    train: Dataset({
        features: ['tokens', 'ner_tags'],
        num_rows: 675
    })
    test: Dataset({
        features: ['tokens', 'ner_tags'],
        num_rows: 169
    })
})

Label schema: ['O', 'B-Disease', 'I-Disease']
label2id: {'O': 0, 'B-Disease': 1, 'I-Disease': 2}
id2label: {0: 'O', 1: 'B-Disease', 2: 'I-Disease'}

NCBI train entities: 5130
Quaero train entities: 1023

✅ Ready to proceed to Cell 4 (Model Setup)


In [None]:
# # FIXED: Use tner namespace for NCBI Disease (no loading script)
# print("Loading NCBI Disease dataset...")
# try:
#     ncbi = load_dataset("tner/ncbi_disease")
#     print("✓ NCBI loaded successfully from tner/ncbi_disease")
# except Exception as e:
#     print(f"Error loading NCBI: {e}")
#     print("Trying alternative source...")
#     # Fallback to direct HF hub
#     ncbi = load_dataset("ncbi/ncbi_disease", trust_remote_code=False)

# print("\nLoading Quaero French Med dataset...")
# try:
#     quaero = load_dataset("qanastek/QUAERO")
#     print("✓ Quaero loaded successfully from qanastek/QUAERO")
# except Exception as e:
#     print(f"Error loading Quaero: {e}")
#     print("Trying alternative...")
#     # Try the path you mentioned
#     quaero = load_dataset("mnaguib/QuaeroFrenchMed")

## 2. Inspect Dataset Structure

**CRITICAL:** Run this cell and examine the output before proceeding.

In [19]:
print("="*70)
print("NCBI DATASET STRUCTURE")
print("="*70)
print(f"Splits: {list(ncbi.keys())}")
print(f"Train size: {len(ncbi['train'])}")
print(f"Test size: {len(ncbi['test'])}")
print(f"\nFeatures: {ncbi['train'].features}")
print(f"\nFirst example:")
print(ncbi['train'][0])

# Extract label names
if 'ner_tags' in ncbi['train'].features:
    ncbi_label_feature = ncbi['train'].features['ner_tags']
    if hasattr(ncbi_label_feature, 'feature'):
        ncbi_labels = ncbi_label_feature.feature.names
        print(f"\nNCBI Label names: {ncbi_labels}")
    else:
        print(f"\nNCBI Label feature type: {type(ncbi_label_feature)}")
        # Try to extract from data
        sample_tags = [ncbi['train'][i]['ner_tags'] for i in range(min(10, len(ncbi['train'])))]
        unique_tags = sorted(set([tag for tags in sample_tags for tag in tags]))
        print(f"Unique tag IDs in first 10 examples: {unique_tags}")

NCBI DATASET STRUCTURE
Splits: ['train', 'test']
Train size: 593
Test size: 100

Features: {'tokens': List(Value('string')), 'ner_tags': List(Value('int64'))}

First example:
{'tokens': ['Identification', 'of', 'APC2,', 'a', 'homologue', 'of', 'the', 'adenomatous', 'polyposis', 'coli', 'tumour', 'suppressor', '.', 'The', 'adenomatous', 'polyposis', 'coli', '(', 'APC', ')', 'tumour-suppressor', 'protein', 'controls', 'the', 'Wnt', 'signalling', 'pathway', 'by', 'forming', 'a', 'complex', 'with', 'glycogen', 'synthase', 'kinase', '3beta', '(', 'GSK-3beta', ')', ',', 'axin', '/', 'conductin', 'and', 'betacatenin', '.', 'Complex', 'formation', 'induces', 'the', 'rapid', 'degradation', 'of', 'betacatenin', '.', 'In', 'colon', 'carcinoma', 'cells', ',', 'loss', 'of', 'APC', 'leads', 'to', 'the', 'accumulation', 'of', 'betacatenin', 'in', 'the', 'nucleus', ',', 'where', 'it', 'binds', 'to', 'and', 'activates', 'the', 'Tcf-4', 'transcription', 'factor', '(', 'reviewed', 'in', '[', '1', ']', '[

AttributeError: 'Value' object has no attribute 'names'

In [None]:
print("="*70)
print("QUAERO DATASET STRUCTURE")
print("="*70)
print(f"Splits: {list(quaero.keys())}")
print(f"Train size: {len(quaero['train'])}")
if 'test' in quaero:
    print(f"Test size: {len(quaero['test'])}")
elif 'validation' in quaero:
    print(f"Validation size: {len(quaero['validation'])}")

print(f"\nFeatures: {quaero['train'].features}")
print(f"\nFirst example:")
print(quaero['train'][0])

# Extract label names
if 'ner_tags' in quaero['train'].features:
    quaero_label_feature = quaero['train'].features['ner_tags']
    if hasattr(quaero_label_feature, 'feature'):
        quaero_labels = quaero_label_feature.feature.names
        print(f"\nQuaero Label names: {quaero_labels}")
    else:
        print(f"\nQuaero Label feature type: {type(quaero_label_feature)}")
        sample_tags = [quaero['train'][i]['ner_tags'] for i in range(min(10, len(quaero['train'])))]
        unique_tags = sorted(set([tag for tags in sample_tags for tag in tags]))
        print(f"Unique tag IDs in first 10 examples: {unique_tags}")

## 3. Define Label Schema

**IMPORTANT:** Modify this cell based on the output above.

Expected formats:
- **NCBI:** `[O, B-Disease, I-Disease]` (BIO tagging)
- **Quaero:** `[O, B-DISORDER, I-DISORDER, B-ANATOMY, I-ANATOMY, ...]`

Strategy: Map Quaero's DISORDER → Disease, ignore ANATOMY/PROCEDURE

In [None]:
# Define unified label schema
label_list = ["O", "B-Disease", "I-Disease"]
label2id = {label: i for i, label in enumerate(label_list)}
id2label = {i: label for i, label in enumerate(label_list)}

print(f"Unified label schema: {label_list}")
print(f"Label to ID: {label2id}")

# TODO: Based on the output above, define mappings
# Example for NCBI (if labels are ['O', 'B-Disease', 'I-Disease']):
label_mapping_ncbi = {
    0: 0,  # O → O
    1: 1,  # B-Disease → B-Disease
    2: 2,  # I-Disease → I-Disease
}

# Example for Quaero (MODIFY THIS based on actual labels):
# If Quaero has: ['O', 'B-DISORDER', 'I-DISORDER', 'B-ANATOMY', 'I-ANATOMY']
label_mapping_quaero = {
    0: 0,  # O → O
    1: 1,  # B-DISORDER → B-Disease
    2: 2,  # I-DISORDER → I-Disease
    3: 0,  # B-ANATOMY → O (ignore)
    4: 0,  # I-ANATOMY → O (ignore)
    # Add more mappings as needed based on actual labels
}

print(f"\nNCBI mapping: {label_mapping_ncbi}")
print(f"Quaero mapping: {label_mapping_quaero}")
print("\n⚠️ VERIFY these mappings match your dataset output above!")

## 4. Model & Tokenizer Setup

In [20]:
# Using XLM-RoBERTa-base (better multilingual than mBERT)
model_name = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)

print(f"✓ Loaded tokenizer: {model_name}")
print(f"Vocab size: {len(tokenizer)}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/615 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.10M [00:00<?, ?B/s]

✓ Loaded tokenizer: xlm-roberta-base
Vocab size: 250002


In [21]:
# Verify label schema (already defined in previous cell)
print(f"\nLabel schema: {label_list}")
print(f"label2id: {label2id}")
print(f"id2label: {id2label}")


Label schema: ['O', 'B-Disease', 'I-Disease']
label2id: {'O': 0, 'B-Disease': 1, 'I-Disease': 2}
id2label: {0: 'O', 1: 'B-Disease', 2: 'I-Disease'}


## 5. Tokenization & Label Alignment

In [22]:
def tokenize_and_align(examples, label_mapping):
    """
    Tokenize text and align labels to subword tokens.
    Args:
        examples: Batch of examples with 'tokens' and 'ner_tags'
        label_mapping: Dict mapping original labels to unified schema
    """
    tokenized = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True,
        max_length=512,
        padding=False  # Data collator handles padding
    )

    labels = []
    for i, label in enumerate(examples["ner_tags"]):
        word_ids = tokenized.word_ids(batch_index=i)
        label_ids = []
        previous_word_idx = None

        for word_idx in word_ids:
            if word_idx is None:
                # Special tokens (CLS, SEP, PAD) → ignore in loss
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                # First subword token of a word → use mapped label
                orig_label = label[word_idx]
                mapped_label = label_mapping.get(orig_label, 0)  # Default to O
                label_ids.append(mapped_label)
            else:
                # Continuation of subword tokens → ignore in loss
                label_ids.append(-100)
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized["labels"] = labels
    return tokenized

print("✓ Tokenization function defined")

✓ Tokenization function defined


## 6. Prepare Training Data

In [23]:
# # Sample 500 from each (with shuffling for diversity)
# ncbi_train_sample = ncbi['train'].shuffle(seed=42).select(range(min(500, len(ncbi['train']))))
# quaero_train_sample = quaero['train'].shuffle(seed=42).select(range(min(500, len(quaero['train']))))

# print(f"Sampled {len(ncbi_train_sample)} English (NCBI) examples")
# print(f"Sampled {len(quaero_train_sample)} French (Quaero) examples")

# # Tokenize with dataset-specific label mappings
# print("\nTokenizing NCBI dataset...")
# ncbi_tokenized = ncbi_train_sample.map(
#     lambda x: tokenize_and_align(x, label_mapping_ncbi),
#     batched=True,
#     remove_columns=ncbi_train_sample.column_names
# )

# print("Tokenizing Quaero dataset...")
# quaero_tokenized = quaero_train_sample.map(
#     lambda x: tokenize_and_align(x, label_mapping_quaero),
#     batched=True,
#     remove_columns=quaero_train_sample.column_names
# )

# # Combine and shuffle
# train_dataset = concatenate_datasets([ncbi_tokenized, quaero_tokenized]).shuffle(seed=42)
# print(f"\n✓ Combined training set: {len(train_dataset)} samples")

# # Prepare test sets (keep separate by language)
# print("\nPreparing test sets...")
# ncbi_test = ncbi['test'].map(
#     lambda x: tokenize_and_align(x, label_mapping_ncbi),
#     batched=True,
#     remove_columns=ncbi['test'].column_names
# )

# # Handle different split names for Quaero
# if 'test' in quaero:
#     quaero_test = quaero['test'].map(
#         lambda x: tokenize_and_align(x, label_mapping_quaero),
#         batched=True,
#         remove_columns=quaero['test'].column_names
#     )
# elif 'validation' in quaero:
#     quaero_test = quaero['validation'].map(
#         lambda x: tokenize_and_align(x, label_mapping_quaero),
#         batched=True,
#         remove_columns=quaero['validation'].column_names
#     )

# print(f"✓ Test sets ready: EN={len(ncbi_test)}, FR={len(quaero_test)}")

Sampled 500 English (NCBI) examples
Sampled 500 French (Quaero) examples

Tokenizing NCBI dataset...


Map:   0%|          | 0/500 [00:00<?, ? examples/s]

NameError: name 'label_mapping_ncbi' is not defined

In [24]:
# ============================================================================
# TOKENIZATION FUNCTION
# ============================================================================

def tokenize_and_align(examples):
    """Tokenize and align labels - works with numeric IDs directly"""
    tokenized = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True,
        max_length=512,
        padding=False  # Data collator handles padding
    )

    labels = []
    for i, label in enumerate(examples["ner_tags"]):
        word_ids = tokenized.word_ids(batch_index=i)
        label_ids = []
        previous_word_idx = None

        for word_idx in word_ids:
            if word_idx is None:
                # Special tokens (CLS, SEP, PAD) → ignore in loss
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                # First subword token of a word → use label
                label_ids.append(label[word_idx])
            else:
                # Continuation of subword tokens → ignore in loss
                label_ids.append(-100)
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized["labels"] = labels
    return tokenized

print("✓ Tokenization function defined")

# ============================================================================
# PREPARE TRAINING DATA
# ============================================================================

# Sample 500 from each (with shuffling for diversity)
ncbi_train_sample = ncbi['train'].shuffle(seed=42).select(range(min(500, len(ncbi['train']))))
quaero_train_sample = quaero['train'].shuffle(seed=42).select(range(min(500, len(quaero['train']))))

print(f"\nSampled {len(ncbi_train_sample)} English (NCBI) examples")
print(f"Sampled {len(quaero_train_sample)} French (Quaero) examples")

# Tokenize
print("\nTokenizing NCBI dataset...")
ncbi_tokenized = ncbi_train_sample.map(
    tokenize_and_align,  # No mapping needed - already numeric IDs!
    batched=True,
    remove_columns=ncbi_train_sample.column_names
)

print("Tokenizing Quaero dataset...")
quaero_tokenized = quaero_train_sample.map(
    tokenize_and_align,  # No mapping needed!
    batched=True,
    remove_columns=quaero_train_sample.column_names
)

# Combine and shuffle
from datasets import concatenate_datasets
train_dataset = concatenate_datasets([ncbi_tokenized, quaero_tokenized]).shuffle(seed=42)
print(f"\n✓ Combined training set: {len(train_dataset)} samples")

# Prepare test sets (keep separate by language)
print("\nPreparing test sets...")
ncbi_test = ncbi['test'].map(
    tokenize_and_align,
    batched=True,
    remove_columns=ncbi['test'].column_names
)

quaero_test = quaero['test'].map(
    tokenize_and_align,
    batched=True,
    remove_columns=quaero['test'].column_names
)

print(f"✓ Test sets ready: EN={len(ncbi_test)}, FR={len(quaero_test)}")

✓ Tokenization function defined

Sampled 500 English (NCBI) examples
Sampled 500 French (Quaero) examples

Tokenizing NCBI dataset...


Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Tokenizing Quaero dataset...


Map:   0%|          | 0/500 [00:00<?, ? examples/s]


✓ Combined training set: 1000 samples

Preparing test sets...


Map:   0%|          | 0/100 [00:00<?, ? examples/s]

Map:   0%|          | 0/169 [00:00<?, ? examples/s]

✓ Test sets ready: EN=100, FR=169


## 7. Initialize Model

In [25]:
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True  # New classification head
)

print(f"✓ Model initialized: {model_name}")
print(f"   Labels: {len(label_list)}")
print(f"   Parameters: {model.num_parameters() / 1e6:.1f}M")

model.safetensors:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✓ Model initialized: xlm-roberta-base
   Labels: 3
   Parameters: 277.5M


## 8. Define Evaluation Metrics

In [26]:
def compute_metrics(pred):
    """
    Compute token-level F1, Precision, Recall using seqeval.
    """
    predictions, labels = pred
    predictions = np.argmax(predictions, axis=2)

    # Convert IDs to labels, filtering out ignored tokens (-100)
    true_labels = [
        [id2label[l] for l in label if l != -100]
        for label in labels
    ]
    true_preds = [
        [id2label[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]

    return {
        "f1": f1_score(true_labels, true_preds),
        "precision": precision_score(true_labels, true_preds),
        "recall": recall_score(true_labels, true_preds)
    }

print("✓ Metrics function ready")

✓ Metrics function ready


## 9. Training Configuration

In [29]:
# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",  # Changed from evaluation_strategy
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    warmup_steps=100,
    logging_dir='./logs',
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    fp16=torch.cuda.is_available(),  # Mixed precision if GPU available
    push_to_hub=False,
    report_to="none"  # Disable wandb/tensorboard
)

data_collator = DataCollatorForTokenClassification(tokenizer)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=ncbi_test,  # Evaluate on English during training
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

print("✓ Trainer initialized")
print(f"   Total training steps: {len(train_dataset) // training_args.per_device_train_batch_size * training_args.num_train_epochs}")

  trainer = Trainer(


✓ Trainer initialized
   Total training steps: 186


## 10. Train Model

In [30]:
print("Starting training...\n")
print("="*70)
train_result = trainer.train()
print("="*70)
print("\n✓ Training complete!")
print(f"Training time: {train_result.metrics['train_runtime']:.2f}s")

Starting training...



Epoch,Training Loss,Validation Loss,F1,Precision,Recall
1,0.7308,0.15397,0.086486,0.277457,0.051227
2,0.1553,0.069744,0.748387,0.699443,0.804696
3,0.0799,0.06436,0.780662,0.746109,0.81857



✓ Training complete!
Training time: 236.13s


## 11. Evaluate on Both Languages

In [31]:
print("\n" + "="*70)
print("ENGLISH (NCBI) TEST SET EVALUATION")
print("="*70)
ncbi_results = trainer.evaluate(ncbi_test)
print(f"\nF1 Score:  {ncbi_results['eval_f1']:.4f}")
print(f"Precision: {ncbi_results['eval_precision']:.4f}")
print(f"Recall:    {ncbi_results['eval_recall']:.4f}")


ENGLISH (NCBI) TEST SET EVALUATION



F1 Score:  0.7807
Precision: 0.7461
Recall:    0.8186


In [32]:
print("\n" + "="*70)
print("FRENCH (QUAERO) TEST SET EVALUATION")
print("="*70)
quaero_results = trainer.evaluate(quaero_test)
print(f"\nF1 Score:  {quaero_results['eval_f1']:.4f}")
print(f"Precision: {quaero_results['eval_precision']:.4f}")
print(f"Recall:    {quaero_results['eval_recall']:.4f}")


FRENCH (QUAERO) TEST SET EVALUATION



F1 Score:  0.4742
Precision: 0.4742
Recall:    0.4742


## 12. Detailed Per-Entity Type Reports

In [33]:
# English detailed report
print("\n" + "="*70)
print("ENGLISH - PER ENTITY TYPE CLASSIFICATION REPORT")
print("="*70)

predictions_ncbi = trainer.predict(ncbi_test)
preds_ncbi = np.argmax(predictions_ncbi.predictions, axis=2)
labels_ncbi = predictions_ncbi.label_ids

true_labels_ncbi = [
    [id2label[l] for l in label if l != -100]
    for label in labels_ncbi
]
true_preds_ncbi = [
    [id2label[p] for (p, l) in zip(pred, label) if l != -100]
    for pred, label in zip(preds_ncbi, labels_ncbi)
]

print(classification_report(true_labels_ncbi, true_preds_ncbi, digits=4))


ENGLISH - PER ENTITY TYPE CLASSIFICATION REPORT


              precision    recall  f1-score   support

     Disease     0.7461    0.8186    0.7807       937

   micro avg     0.7461    0.8186    0.7807       937
   macro avg     0.7461    0.8186    0.7807       937
weighted avg     0.7461    0.8186    0.7807       937



In [34]:
# French detailed report
print("\n" + "="*70)
print("FRENCH - PER ENTITY TYPE CLASSIFICATION REPORT")
print("="*70)

predictions_quaero = trainer.predict(quaero_test)
preds_quaero = np.argmax(predictions_quaero.predictions, axis=2)
labels_quaero = predictions_quaero.label_ids

true_labels_quaero = [
    [id2label[l] for l in label if l != -100]
    for label in labels_quaero
]
true_preds_quaero = [
    [id2label[p] for (p, l) in zip(pred, label) if l != -100]
    for pred, label in zip(preds_quaero, labels_quaero)
]

print(classification_report(true_labels_quaero, true_preds_quaero, digits=4))


FRENCH - PER ENTITY TYPE CLASSIFICATION REPORT


              precision    recall  f1-score   support

     Disease     0.4742    0.4742    0.4742       213

   micro avg     0.4742    0.4742    0.4742       213
   macro avg     0.4742    0.4742    0.4742       213
weighted avg     0.4742    0.4742    0.4742       213



## 13. Cross-Lingual Transfer Analysis (Zero-Shot)

Train on English only → Test on French (no French training data)

In [35]:
print("\n" + "="*70)
print("ZERO-SHOT CROSS-LINGUAL TRANSFER: EN → FR")
print("="*70)

# Train new model on English only
model_zeroshot = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True
)

trainer_zeroshot = Trainer(
    model=model_zeroshot,
    args=training_args,
    train_dataset=ncbi_tokenized,  # English only!
    eval_dataset=ncbi_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

print("Training on English only...")
trainer_zeroshot.train()

print("\nTesting on French (zero-shot)...")
zeroshot_results = trainer_zeroshot.evaluate(quaero_test)
print(f"\nZero-shot FR F1:  {zeroshot_results['eval_f1']:.4f}")
print(f"Zero-shot FR Precision: {zeroshot_results['eval_precision']:.4f}")
print(f"Zero-shot FR Recall:    {zeroshot_results['eval_recall']:.4f}")

# Compare with bilingual model
print("\n" + "="*70)
print("COMPARISON: Bilingual vs Zero-Shot on French")
print("="*70)
print(f"Bilingual model (EN+FR training): F1 = {quaero_results['eval_f1']:.4f}")
print(f"Zero-shot model (EN only):         F1 = {zeroshot_results['eval_f1']:.4f}")
print(f"Improvement from FR training:      {(quaero_results['eval_f1'] - zeroshot_results['eval_f1']):.4f}")


ZERO-SHOT CROSS-LINGUAL TRANSFER: EN → FR


Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer_zeroshot = Trainer(


Training on English only...


Epoch,Training Loss,Validation Loss,F1,Precision,Recall
1,No log,0.363896,0.0,0.0,0.0
2,0.589200,0.155738,0.0,0.0,0.0
3,0.589200,0.0788,0.6917,0.643974,0.747065



Testing on French (zero-shot)...



Zero-shot FR F1:  0.3902
Zero-shot FR Precision: 0.3697
Zero-shot FR Recall:    0.4131

COMPARISON: Bilingual vs Zero-Shot on French
Bilingual model (EN+FR training): F1 = 0.4742
Zero-shot model (EN only):         F1 = 0.3902
Improvement from FR training:      0.0839


## 14. Save Model & Results

In [36]:
# Save best model
output_dir = "/content/output_dir"
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

print(f"✓ Model saved to {output_dir}")

# Save results summary
import json

results_summary = {
    "model": model_name,
    "training_samples": len(train_dataset),
    "english_train": len(ncbi_tokenized),
    "french_train": len(quaero_tokenized),
    "label_schema": label_list,
    "results": {
        "english_test": {
            "f1": ncbi_results['eval_f1'],
            "precision": ncbi_results['eval_precision'],
            "recall": ncbi_results['eval_recall']
        },
        "french_test": {
            "f1": quaero_results['eval_f1'],
            "precision": quaero_results['eval_precision'],
            "recall": quaero_results['eval_recall']
        },
        "zero_shot_french": {
            "f1": zeroshot_results['eval_f1'],
            "precision": zeroshot_results['eval_precision'],
            "recall": zeroshot_results['eval_recall']
        }
    }
}

with open(f"{output_dir}/results_summary.json", "w") as f:
    json.dump(results_summary, f, indent=2)

print("✓ Results summary saved")
print("\n📁 To download: Files → right-click 'multilingual_disease_ner' → Download")

✓ Model saved to /content/output_dir
✓ Results summary saved

📁 To download: Files → right-click 'multilingual_disease_ner' → Download


## 15. Example Predictions

In [37]:
# Test on custom examples
def predict_entities(text, lang="en"):
    """Predict entities in a text string"""
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    predictions = torch.argmax(outputs.logits, dim=2)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # Extract entities
    entities = []
    current_entity = []
    current_label = None

    for token, pred in zip(tokens, predictions[0]):
        label = id2label[pred.item()]

        if token in ["<s>", "</s>", "<pad>"]:
            continue

        if label.startswith("B-"):
            if current_entity:
                entities.append((" ".join(current_entity), current_label))
            current_entity = [token]
            current_label = label[2:]
        elif label.startswith("I-") and current_label:
            current_entity.append(token)
        else:
            if current_entity:
                entities.append((" ".join(current_entity), current_label))
            current_entity = []
            current_label = None

    if current_entity:
        entities.append((" ".join(current_entity), current_label))

    return entities

# Test examples
print("\n" + "="*70)
print("EXAMPLE PREDICTIONS")
print("="*70)

en_example = "The patient was diagnosed with hypertension and diabetes mellitus."
fr_example = "Le patient souffre d'hypertension et de diabète de type 2."

print(f"\nEnglish: {en_example}")
en_entities = predict_entities(en_example, "en")
print(f"Entities: {en_entities}")

print(f"\nFrench: {fr_example}")
fr_entities = predict_entities(fr_example, "fr")
print(f"Entities: {fr_entities}")


EXAMPLE PREDICTIONS

English: The patient was diagnosed with hypertension and diabetes mellitus.
Entities: [('▁hyper tension', 'Disease'), ('▁diabetes ▁mell itus', 'Disease')]

French: Le patient souffre d'hypertension et de diabète de type 2.
Entities: [('hy', 'Disease'), ('per tension', 'Disease'), ('▁di', 'Disease'), ('ab ète ▁de ▁type ▁2.', 'Disease')]
