In [1]:
import pandas as pd
import sys
import gc
import os
import spacy
import scispacy
from dateutil import parser
from spacy.matcher import Matcher
import re
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
import json
import os, gc, json, re, joblib, warnings
from pathlib import Path
from collections import Counter

import numpy as np
import pandas as pd

from rapidfuzz import fuzz

# HF / seqeval / sklearn
import datasets as ds
from transformers import (
    AutoTokenizer, AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    TrainingArguments, Trainer, EarlyStoppingCallback, pipeline
)
from seqeval.metrics import precision_score, recall_score, f1_score, classification_report
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report as skl_report

import torch

warnings.filterwarnings("ignore", category=UserWarning)


In [2]:
print(spacy.__version__)
print(scispacy.__version__)
import en_core_web_sm
print(en_core_web_sm.__version__)
import sys
print(sys.executable)

3.4.4
0.5.1
3.4.0
/opt/anaconda3/envs/nlp_clinical/bin/python


In [3]:
def get_data_path(filename):
    """
    Returns the path to a file in the data/clean directory.
    
    Args:
        filename (str): Name of the file (including extension)
    
    Returns:
        str: Full path to the file
    """
    cwd = os.getcwd()
    parent_dir = os.path.dirname(cwd)
    file_path = os.path.join(parent_dir, 'data', 'clean', filename)
    return file_path

In [4]:
df_medical_expanded = get_data_path('df_medical_expanded.csv')
df_medical = pd.read_csv(df_medical_expanded, index_col=0)

In [5]:
df_diagnosis_expanded = get_data_path('df_diagnosis_expanded.csv')
df_diagnosis = pd.read_csv(df_diagnosis_expanded, index_col=0)

In [6]:
df_surgery_expanded = get_data_path('df_surgery_expanded.csv')
df_surgery = pd.read_csv(df_surgery_expanded, index_col=0)

In [7]:
df_symptoms_expanded = get_data_path('df_symptoms_expanded.csv')
df_symptoms = pd.read_csv(df_symptoms_expanded, index_col=0)

In [8]:
df_treatments_expanded = get_data_path('df_treatments_expanded.csv')
df_treatments = pd.read_csv(df_treatments_expanded, index_col=0)

In [9]:
df_info_expanded = get_data_path('df_info_expanded.csv')
df_info = pd.read_csv(df_info_expanded, index_col=0)

### Medical Entity Extractor

In [10]:
class MedicalEntityExtractor:
    def __init__(self, model_name="en_core_sci_sm", batch_size=500):
        """
        Initialize the medical entity extractor.

        Args:
            model_name: ScispaCy model to use. Options:
                       - "en_core_sci_sm" (smaller, faster, generic entities)
                       - "en_core_sci_md" (medium)
                       - "en_ner_bc5cdr_md" (disease/chemical focused - BEST FOR MEDICAL)
                       - "en_ner_bionlp13cg_md" (multiple bio entity types)
            batch_size: Number of texts to process at once
        """
        self.batch_size = batch_size
        self.model_name = model_name
        self.nlp = None

    def load_model(self):
        """# Load a spaCy language model and assign it to self.nlp"""
        print(f"Loading {self.model_name}...")
        try:
            self.nlp = spacy.load(self.model_name)
        except:
            print(f"Model {self.model_name} not found. Installing...")
            import os
            # Fallback to install the appropriate model if fail to load
            if self.model_name == "en_ner_bc5cdr_md":
                os.system("pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_ner_bc5cdr_md-0.5.1.tar.gz")
            elif self.model_name == "en_ner_bionlp13cg_md":
                os.system("pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_ner_bionlp13cg_md-0.5.1.tar.gz")
            else:
                os.system(f"pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/{self.model_name}-0.5.1.tar.gz")
            self.nlp = spacy.load(self.model_name)

        # Disable unnecessary pipeline components to save memory
        pipes_to_disable = ["tagger", "parser", "attribute_ruler", "lemmatizer"]
        for pipe in pipes_to_disable:
            if pipe in self.nlp.pipe_names:
                self.nlp.disable_pipe(pipe)

        print(f"Model loaded. Active pipes: {self.nlp.pipe_names}")

    def process_batch(self, texts, ids):  
        batch_results = []
        docs = list(self.nlp.pipe(texts, batch_size=50)) # Stream & batch process with nlp.pipe is faster than looped nlp(text)
        for doc, original_text, rid in zip(docs, texts, ids):
            for ent in doc.ents: # doc.ents yields entity objects
                batch_results.append({
                    'text': ent.text,
                    'label': ent.label_,
                    'start': ent.start_char,
                    'end': ent.end_char,
                    'original_text': original_text,
                    'row_idx': rid,  
                })
        # Drop large list, nudge Python’s garbage collector 
        del docs 
        gc.collect()
        return batch_results

    def extract_entities_from_df(
        self, df, text_column, *, # * forces keyword-only args after it
        save_checkpoints=True, checkpoint_dir='./ner_checkpoints',
        id_column='idx'                    
    ):
        if self.nlp is None:
            self.load_model()

        texts = df[text_column].fillna('').astype(str).tolist()

        # --- Choose row IDs from the given column; fallback to df.index ---
        if id_column is not None and id_column in df.columns:
            ids_series = df[id_column]
            # if there are NaNs in idx, fall back to positional index for those rows
            ids_series = ids_series.where(pd.notna(ids_series), df.index)
            ids = ids_series.tolist()
            print(f"Stamping row identifier from column: '{id_column}'")
        else:
            ids = df.index.tolist()
            print("Stamping row identifier from DataFrame index")

        total_batches = (len(texts) + self.batch_size - 1) // self.batch_size # Ceiling division: compute total batch count without importing math
        all_results = []

        if save_checkpoints:
            import os
            os.makedirs(checkpoint_dir, exist_ok=True)

        print(f"Processing {len(texts)} texts in {total_batches} batches...")
        print(f"Using model: {self.model_name} for column: {text_column}")

        # Status bar with tqdm
        for batch_start in tqdm(range(0, len(texts), self.batch_size), desc="Processing batches"):
            batch_end   = batch_start + self.batch_size
            batch_texts = texts[batch_start:batch_end]
            batch_ids   = ids[batch_start:batch_end]

            try:
                batch_results = self.process_batch(batch_texts, batch_ids)
                all_results.extend(batch_results)
                # # Periodic checkpoints: every 10th batch (0-based), write all results so far to CSV
                if save_checkpoints and ((batch_start // self.batch_size) % 10 == 0): 
                    checkpoint_df = pd.DataFrame(all_results)
                    safe_column_name = text_column.replace(' ', '_').replace('/', '_')
                    checkpoint_df.to_csv(
                        f"{checkpoint_dir}/checkpoint_{safe_column_name}_{batch_start}.csv",
                        index=False
                    )
                    print(f"\nCheckpoint saved at batch {batch_start}")
            except Exception as e:
                print(f"\nError processing batch {batch_start}: {e}")
                continue

            if (batch_start // self.batch_size) % 5 == 0:
                gc.collect() # Occasional GC hint (every 5 batches) to keep memory tidy

        results_df = pd.DataFrame(all_results)
        results_df['source_column'] = text_column
        return results_df

    def get_entity_summary(self, entities_df):
        """Generate summary statistics of extracted entities."""
        summary = {
            'total_entities': len(entities_df),
            'unique_entities': entities_df['text'].nunique(),
            'entity_types': entities_df['label'].value_counts().to_dict(),
            'top_entities': entities_df['text'].value_counts().head(20).to_dict()
        }
        return summary



In [11]:
# Helper functions for analysis
def analyze_medical_entities(entities_df, min_frequency=5):
    """Analyze extracted entities and create labeling functions."""
    
    # Find frequent medical conditions
    frequent_entities = entities_df['text'].value_counts()
    frequent_entities = frequent_entities[frequent_entities >= min_frequency]
    
    print(f"Found {len(frequent_entities)} entities appearing >= {min_frequency} times")
    
    # Group by entity type
    entity_groups = {}
    for label in entities_df['label'].unique():
        label_entities = entities_df[entities_df['label'] == label]['text'].value_counts()
        entity_groups[label] = label_entities.head(10).to_dict()
    
    return frequent_entities, entity_groups

def create_entity_based_rules(entities_df, target_labels=['DISEASE', 'CHEMICAL'], source_column=None):
    """Create labeling functions based on extracted entities."""
    
    # Get the source column from the entities_df if not provided
    if source_column is None:
        if 'source_column' in entities_df.columns:
            source_column = entities_df['source_column'].iloc[0]
        else:
            raise ValueError("source_column must be provided or available in entities_df")
    
    # Get top entities for each label type
    rules = {}
    
    for label in target_labels:
        if label in entities_df['label'].unique():
            top_entities = entities_df[entities_df['label'] == label]['text'].value_counts().head(20)
            
            # Create a labeling function for this entity type
            def make_labeling_function(entity_list, label_name, column_name):
                def lf(row):
                    # Handle column names with spaces
                    if column_name in row:
                        text = str(row[column_name]).lower()
                    else:
                        # # Try alternative column names
                        # alt_column = column_name.replace('_', ' ')
                        # if alt_column in row:
                        #     text = str(row[alt_column]).lower()
                        # else:
                            return 'ABSTAIN'
                    
                    for entity in entity_list:
                        if entity.lower() in text:
                            return label_name
                    return 'ABSTAIN'
                
                # Create safe function name
                safe_column_name = column_name.replace(' ', '_').replace('/', '_')
                lf.__name__ = f"lf_{label_name.lower()}_{safe_column_name}"
                return lf
            
            rules[f'lf_{label.lower()}_{source_column.replace(" ", "_")}'] = make_labeling_function(
                top_entities.index.tolist(),
                label,
                source_column
            )
    
    return rules

# Main execution function for single column
def run_medical_ner_extraction(df, text_column,
                               model_name="en_ner_bc5cdr_md", batch_size=500,id_column='idx'):
    """
    Complete pipeline to extract medical entities from dataframe.
    
    Args:
        df: dataframe with medical text
        text_column: Column containing the text (NOW REQUIRED, NO DEFAULT)
        model_name: Which ScispaCy model to use
        batch_size: Batch size for processing
    
    Returns:
        entities_df: DataFrame with all extracted entities
        summary: Summary statistics
        rules: Generated labeling functions
    """

    if text_column not in df.columns:
        raise ValueError(f"Column '{text_column}' not found in dataframe.")
    # Wire stages together: instantiate, extract, summarize, build rules, return
    extractor = MedicalEntityExtractor(model_name=model_name, batch_size=batch_size)
    entities_df = extractor.extract_entities_from_df(
        df, text_column, id_column=id_column
    )
    summary = extractor.get_entity_summary(entities_df)
    frequent_entities, entity_groups = analyze_medical_entities(entities_df)
    rules = create_entity_based_rules(entities_df, source_column=text_column)
    return entities_df, summary, rules


#### Extracting Entities From Patient Medical Record

In [12]:
df_medical.head(30)

Unnamed: 0,idx,has_medical_history,physiological context,psychological context,vaccination history,allergies,exercise frequency,nutrition,sexual history,alcohol consumption,drug usage,smoking status
0,155216,True,,Diagnosed with bipolar affective disorder at t...,,,,,,,,
2,133948,True,,Intensifying feelings of helplessness,,,,,,,,
3,80176,True,History of left elbow arthrodesis performed fo...,,,,,,,,,
4,72232,True,,,,,,,,,,
5,31864,True,"Inability to walk since babyhood, did not walk...",,,,,,Got married at the age of 15 and became pregna...,,,
6,26809,True,"Normal Apgar score, no resuscitation required ...",,,,,,,,,
7,149866,True,"Coxa vara deformity of bilateral hips, bilater...",,,,,,,,,
8,87064,True,,Patient could not realize that his symptoms mi...,,,,,,,,
9,123006,True,,,,,,,,,,
10,119317,True,Born at full term by spontaneous vaginal deliv...,Mentally healthy,,,,,,,,


In [13]:
df_medical['psychological context'].value_counts()

psychological context
Depression                                                               85
Bipolar disorder                                                         30
Schizophrenia                                                            29
Anxiety                                                                  27
No significant psychosocial history                                      14
                                                                         ..
Paranoid schizophrenia and bipolar disorder                               1
Mentally retarded with aggressive behavior                                1
Mental retardation grade 2, intelligence quotient of 23                   1
mild phantom limb pain and very frequent nonpainful phantom sensation     1
Exhibiting a tortured expression                                          1
Name: count, Length: 2231, dtype: int64

##### Physiological Contexts

In [14]:
if __name__ == "__main__":
    # Run extraction for the 'physiological context' column
    df_physiological_entities, disease_summary, rules = run_medical_ner_extraction(
        df_medical,
        text_column='physiological context',
        model_name="en_ner_bc5cdr_md",
        batch_size=500,
        id_column='idx')
    
    # Test the generated rules
    print("\nTesting generated labeling functions on df_physiological...")
    
    # Select a sample row from the original df_medical to test the rules
    # Ensure df_medical is not empty before accessing iloc[0]
    if not df_medical.empty:
        sample_row = df_medical.iloc[0]
        
        for rule_name, rule_func in rules.items():
            # Apply the rule function to a row from the original df_medical
            try:
                test_result = rule_func(sample_row)
                print(f"{rule_name} applied to row 0: {test_result}")
            except KeyError as e:
                print(f"Error applying rule {rule_name}: {e}. Make sure the column '{e}' exists in df_medical.")
    else:
        print("df_medical is empty, cannot test rules.")
    
    print("\n=== DEBUG: Check extracted entities dataframe ===")
    print(f"Shape of df_physiological_entities: {df_physiological_entities.shape}")
    print(f"\nColumn names: {df_physiological_entities.columns.tolist()}")
    print(f"\nFirst few rows:")
    print(df_physiological_entities.head())
    
    # Check what labels ScispaCy actually found in the entities dataframe
    print("\n=== Entity Labels Found in df_physiological_entities ===")
    if not df_physiological_entities.empty:
        print(df_physiological_entities['label'].value_counts())
    else:
        print("df_physiological_entities is empty.")
    
    # ============= CUSTOM ANATOMY EXTRACTION =============
    print("\n" + "="*60)
    print("CUSTOM ANATOMY EXTRACTION FOR PHYSIOLOGICAL CONTEXT")
    print("="*60)
    # Comprehensive anatomy patterns
    anatomy_patterns = {
            'organs': [
                'heart', 'lung', 'lungs', 'liver', 'kidney', 'kidneys', 'brain', 
                'pancreas', 'spleen', 'stomach', 'intestine', 'colon', 'bladder',
                'gallbladder', 'thyroid', 'prostate', 'uterus', 'ovary', 'ovaries','esophagus',
                'trachea','larynx','pharynx','adrenal glands','pituitary','hypothalamus','pineal gland',
                'testes','appendix','rectum','ear','eye'
            ],
            'bones': [
                'bone', 'femur', 'tibia', 'fibula', 'humerus', 'radius', 'ulna',
                'skull', 'spine', 'vertebra', 'vertebrae', 'rib', 'ribs', 'pelvis',
                'clavicle', 'scapula', 'sternum', 'patella','metacarpal','metatarsal','phalanges','mandible','maxilla',
                'sacrum','coccyx','carpals','tarsals'
            ],
            'joints': [
                'joint', 'hip', 'knee', 'shoulder', 'elbow', 'wrist', 'ankle',
                'knuckle', 'finger', 'toe', 'neck', 'back'
            ],
            'cardiovascular': [
                'artery', 'arteries', 'vein', 'veins', 'vessel', 'vessels',
                'aorta', 'carotid', 'coronary', 'pulmonary', 'cardiac',
                'ventricle', 'atrium', 'valve','capillary','myocardium','pericardium','endocardium','septum',
                'vena cava', 'jugular','femoral','subclavian'
            ],
            'neurological': [
                'nerve', 'nerves', 'neural', 'spinal cord', 'brainstem',
                'cerebral', 'cerebellum', 'cortex', 'lobe', 'ganglion','neuron','axon','dendrite','hypothalamus','thalamus','hippocampus',
                'plexus','peripheral nerve','mengines','dura', 'pia mater'
            ],
            'muscles': [
                'muscle', 'muscles', 'tendon', 'ligament', 'fascia',
                'biceps', 'triceps', 'quadriceps', 'hamstring','deltoid','pectoral','gluteus','abdominal','calf','gastrocnemius','diaphragm','flexor','extensor'
            ],
            'regions': [
                'chest', 'abdomen', 'pelvis', 'thorax', 'cranium',
                'extremity', 'limb', 'upper extremity', 'lower extremity','groin','axilla','mediastinum','peritoneum','retroperitoneum'
            ],
            'tissues': [
                'skin', 'tissue', 'membrane', 'mucosa', 'epithelium',
                'cartilage', 'marrow', 'lymph node', 'gland','blood',
                'plasma','connective tissue','adipose','fascia'
            ]
        }
    # Laterality terms
    laterality_terms = ['left', 'right', 'bilateral', 'unilateral']
    # Direction terms
    direction_terms = ['anterior', 'posterior', 'superior', 'inferior', 'medial', 'lateral', 'proximal', 'distal','ventral', 'dorsal', 
        'cranial', 'caudal', 'superficial', 'deep', 'central', 'peripheral']    
    def extract_anatomy_from_physiological_context(df_medical, text_column, id_column='idx'):
        """Hard-coded weak labelers: Extract anatomy entities from physiological context that BC5CDR might miss"""
         # build the id series once
        if id_column is not None and id_column in df_medical.columns:
            ids_series = df_medical[id_column].where(pd.notna(df_medical[id_column]), df_medical.index)
        else:
            ids_series = df_medical.index

        texts = df_medical[text_column].fillna('').astype(str)
        all_anatomy = []
        for category, terms in anatomy_patterns.items():
            # add all terms to anatomy list & sort by length
            all_anatomy.extend(terms)
        all_anatomy.sort(key=len, reverse=True)
        # Build pattern components
        lat_pattern = '|'.join(laterality_terms) # Escape special regex characters and join with |
        dir_pattern = '|'.join(direction_terms)
    
        # Handle multi-word anatomy terms
        anatomy_pattern = '|'.join(re.escape(term).replace(r'\ ', r'\s+') for term in all_anatomy)
        # Build comprehensive regex patterns
        combined_pattern_regex = re.compile(rf'\b(({lat_pattern})\s+)?(({dir_pattern})\s+)?({anatomy_pattern})\b', re.IGNORECASE)        

        # Process each row
        anatomy_entities = []
        for (df_index, text_original), idx in zip(texts.items(), ids_series): # Parallel iteration: texts.items() yields (index, value) pairs. Paired with the IDs via zip
            if text_original:
                text = text_original.lower()

                # Extract anatomy terms
                for match in combined_pattern_regex.finditer(text): 
                    full_term = match.group(0)
                    laterality = match.group(2)  # Laterality
                    direction = match.group(4)  # Direction

                # Normalize multi-space anatomy terms
                anatomy = match.group(5)  
                anatomy_normalized = ' '.join(anatomy.split())

                # Determine category
                category = None
                for cat, terms in anatomy_patterns.items():
                    if any(anatomy_normalized == term.lower() for term in terms):
                        category = cat
                        break
                
                # Determine label based on modifiers present
                if laterality and direction:
                    label = 'ANATOMY_WITH_LATERALITY_AND_DIRECTION'
                elif laterality:
                    label = 'ANATOMY_WITH_LATERALITY'
                elif direction:
                    label = 'ANATOMY_WITH_DIRECTION'
                else:
                    label = 'ANATOMY'
                
                # Build entity record
                entity = {
                    'text': full_term,
                    'label': label,
                    'category': category,
                    'anatomy': anatomy_normalized,
                    'start': match.start(),
                    'end': match.end(),
                    'original_text': text_original,  # Use original case
                    'source': 'custom_anatomy_extraction',
                    'row_idx': idx
                }
                
                # Add modifiers if present
                if laterality:
                    entity['laterality'] = laterality
                if direction:
                    entity['direction'] = direction
                
                anatomy_entities.append(entity)
    
        return pd.DataFrame(anatomy_entities)

                            
    
    # Extract custom anatomy entities
    print("\nExtracting anatomy entities...")
    df_anatomy_custom = extract_anatomy_from_physiological_context(df_medical, text_column='physiological context', id_column='idx')
    
    print(f"\nCustom anatomy extraction found {len(df_anatomy_custom)} anatomy mentions")
    
    if not df_anatomy_custom.empty:
        print("\n=== Anatomy Entity Distribution ===")
        print(df_anatomy_custom['label'].value_counts())
        
        print("\n=== Top Anatomy Terms ===")
        print(df_anatomy_custom['text'].value_counts().head(30))
        
        print("\n=== Anatomy by Category ===")
        print(df_anatomy_custom['category'].value_counts())
        
        # Show examples of laterality
        laterality_examples = df_anatomy_custom[df_anatomy_custom['label'] == 'ANATOMY_WITH_LATERALITY']
        if not laterality_examples.empty:
            print("\n=== Examples of Anatomy with Laterality ===")
            print(laterality_examples[['text', 'laterality', 'anatomy']].head(10))
    
    # Combine BC5CDR entities with custom anatomy entities
    print("\n=== COMBINING BC5CDR AND CUSTOM ENTITIES ===")
    df_all_physiological_entities = pd.concat([
        df_physiological_entities,
        df_anatomy_custom[['text', 'label', 'start', 'end', 'original_text']]
    ], ignore_index=True)
    
    print(f"\nTotal combined entities: {len(df_all_physiological_entities)}")
    print("\nCombined entity distribution:")
    print(df_all_physiological_entities['label'].value_counts())
    
    # Create anatomy-specific labeling functions
    print("\n=== Creating Anatomy-Specific Labeling Functions ===")
    
    def create_anatomy_labeling_functions():
        """Create labeling functions for anatomy patterns"""
        cardiac_terms = anatomy_patterns['cardiovascular']
        msk_terms = anatomy_patterns['muscles'] + anatomy_patterns['bones'] + anatomy_patterns['joints']
        neuro_terms = anatomy_patterns['neurological']  
        laterality = laterality_terms
        def lf_cardiac_anatomy(row):
            text = str(row.get('physiological context', '')).lower()
            if any(term in text for term in cardiac_terms):
                return 'CARDIAC_ANATOMY'
            return 'ABSTAIN'
        
        def lf_musculoskeletal(row):
            text = str(row.get('physiological context', '')).lower()
            msk_terms = ['bone', 'joint', 'muscle', 'tendon', 'ligament', 'cartilage', 'diaphysis', 'femur', 'tibia', 'fibula', 'humerus', 'radius', 'ulna', 'skull', 
            'spine', 'vertebra', 'vertebrae', 'rib', 'ribs', 'pelvis', 'flexor', 'extensor', 'meniscus']
            if any(term in text for term in msk_terms):
                return 'MUSCULOSKELETAL'
            return 'ABSTAIN'
        
        def lf_bilateral_anatomy(row):
            text = str(row.get('physiological context', '')).lower()
            if 'bilateral' in text:
                return 'BILATERAL_CONDITION'
            return 'ABSTAIN'
        
        def lf_neurological_anatomy(row):
            text = str(row.get('physiological context', '')).lower()
            neuro_terms = ['brain', 'nerve', 'neural', 'spinal', 'cerebral', 'cerebrum', 'brainstem', 'neuron', 'lobe', 'cerebellum', 'cortex', 'ganglion']
            if any(term in text for term in neuro_terms):
                return 'NEUROLOGICAL_ANATOMY'
            return 'ABSTAIN'
        
        return [lf_cardiac_anatomy, lf_musculoskeletal, lf_bilateral_anatomy, lf_neurological_anatomy]
    
    # Test anatomy labeling functions
    anatomy_lfs = create_anatomy_labeling_functions()
    
    print("\nTesting anatomy labeling functions:")
if not df_medical.empty:
    print("\nApplying labeling functions to all rows...")
    
    for lf in anatomy_lfs:
        # Create cleaner column name
        label_name = lf.__name__.replace('lf_', '')
        
        # Apply to all rows
        df_medical[label_name] = df_medical.apply(lf, axis=1)
        
        # Show statistics
        value_counts = df_medical[label_name].value_counts()
        non_abstain = value_counts.drop('ABSTAIN', errors='ignore')
        
        if not non_abstain.empty:
            print(f"\n{label_name}:")
            print(f"  - Total labeled: {non_abstain.sum()}")
            print(f"  - Distribution: {non_abstain.to_dict()}")
    keep_cols = ['text','label','start','end','original_text','row_idx','category']
    df_all_physiological_entities = pd.concat(
    [df_physiological_entities, df_anatomy_custom[keep_cols]],
    ignore_index=True
)
    
    # Save combined results
    df_all_physiological_entities.to_csv('physiological_entities_comprehensive.csv', index=False)
    df_anatomy_custom.to_csv('anatomy_entities_detailed.csv', index=False)

Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 29755 texts in 60 batches...
Using model: en_ner_bc5cdr_md for column: physiological context


Processing batches:   2%|▏         | 1/60 [00:01<01:15,  1.28s/it]


Checkpoint saved at batch 0


Processing batches:  18%|█▊        | 11/60 [00:11<00:49,  1.00s/it]


Checkpoint saved at batch 5000


Processing batches:  35%|███▌      | 21/60 [00:20<00:38,  1.00it/s]


Checkpoint saved at batch 10000


Processing batches:  52%|█████▏    | 31/60 [00:30<00:28,  1.01it/s]


Checkpoint saved at batch 15000


Processing batches:  68%|██████▊   | 41/60 [00:42<00:27,  1.47s/it]


Checkpoint saved at batch 20000


Processing batches:  85%|████████▌ | 51/60 [00:53<00:09,  1.04s/it]


Checkpoint saved at batch 25000


Processing batches: 100%|██████████| 60/60 [01:00<00:00,  1.01s/it]


Found 1249 entities appearing >= 5 times

Testing generated labeling functions on df_physiological...
lf_disease_physiological_context applied to row 0: ABSTAIN
lf_chemical_physiological_context applied to row 0: ABSTAIN

=== DEBUG: Check extracted entities dataframe ===
Shape of df_physiological_entities: (36272, 7)

Column names: ['text', 'label', 'start', 'end', 'original_text', 'row_idx', 'source_column']

First few rows:
                      text    label  start  end  \
0  posttraumatic arthritis  DISEASE     48   71   
1                     pain  DISEASE    116  120   
2                 fracture  DISEASE    151  159   
3      Coxa vara deformity  DISEASE      0   19   
4                 fracture  DISEASE     75   83   

                                       original_text  row_idx  \
0  History of left elbow arthrodesis performed fo...    80176   
1  Inability to walk since babyhood, did not walk...    31864   
2  Inability to walk since babyhood, did not walk...    31864   
3  

In [15]:
def _first_hit(text, terms, case_sensitive=False, word_boundaries=True):
    """Helper function to find first matching term in text"""
    if not case_sensitive:
        text = text.lower()
        terms = [t.lower() for t in terms]
    
    for t in terms:
        if word_boundaries:
            import re
            pattern = r'\b' + re.escape(t) + r'\b'
            if re.search(pattern, text):
                return t
        else:
            if t in text:
                return t
    return None

In [16]:

def extract_anatomy_from_physiological_context(df_medical, text_column, id_column='idx'):
    """Extract anatomy entities from physiological context"""
    
    # Build the id series once
    if id_column is not None and id_column in df_medical.columns:
        ids_series = df_medical[id_column].where(pd.notna(df_medical[id_column]), df_medical.index)
    else:
        ids_series = df_medical.index

    texts = df_medical[text_column].fillna('').astype(str)
    
    # Comprehensive anatomy patterns
    anatomy_patterns = {
        'organs': [
            'heart', 'lung', 'lungs', 'liver', 'kidney', 'kidneys', 'brain', 
            'pancreas', 'spleen', 'stomach', 'intestine', 'colon', 'bladder',
            'gallbladder', 'thyroid', 'prostate', 'uterus', 'ovary', 'ovaries','esophagus',
            'trachea','larynx','pharynx','adrenal glands','pituitary','hypothalamus','pineal gland',
            'testes','appendix','rectum','ear','eye', 'spinal cord', 'lymph node'
        ],
        'bones': [
            'bone', 'femur', 'tibia', 'fibula', 'humerus', 'radius', 'ulna',
            'skull', 'spine', 'vertebra', 'vertebrae', 'rib', 'ribs', 'pelvis',
            'clavicle', 'scapula', 'sternum', 'patella','metacarpal','metatarsal','phalanges',
            'mandible','maxilla', 'sacrum','coccyx','carpals','tarsals'
        ],
        'joints': [
            'joint', 'hip', 'knee', 'shoulder', 'elbow', 'wrist', 'ankle',
            'knuckle', 'finger', 'toe', 'neck', 'back'
        ],
        'cardiovascular': [
            'artery', 'arteries', 'vein', 'veins', 'vessel', 'vessels',
            'aorta', 'carotid', 'coronary', 'pulmonary', 'cardiac',
            'ventricle', 'atrium', 'valve','capillary','myocardium','pericardium',
            'endocardium','septum', 'vena cava', 'jugular','femoral','subclavian'
        ],
        'neurological': [
            'nerve', 'nerves', 'neural', 'brainstem', 'cerebral', 'cerebellum', 
            'cortex', 'lobe', 'ganglion','neuron','axon','dendrite','hypothalamus',
            'thalamus','hippocampus', 'plexus','peripheral nerve','meninges','dura', 'pia mater'
        ],
        'muscles': [
            'muscle', 'muscles', 'tendon', 'ligament', 'fascia',
            'biceps', 'triceps', 'quadriceps', 'hamstring','deltoid','pectoral',
            'gluteus','abdominal','calf','gastrocnemius','diaphragm','flexor','extensor'
        ],
        'regions': [
            'chest', 'abdomen', 'pelvis', 'thorax', 'cranium', 'chest wall',
            'extremity', 'limb', 'upper extremity', 'lower extremity','groin',
            'axilla','mediastinum','peritoneum','retroperitoneum'
        ],
        'tissues': [
            'skin', 'tissue', 'membrane', 'mucosa', 'epithelium',
            'cartilage', 'marrow', 'gland','blood',
            'plasma','connective tissue','adipose'
        ]
    }
    
    # Modifier terms
    laterality_terms = ['left', 'right', 'bilateral', 'unilateral']
    direction_terms = ['anterior', 'posterior', 'superior', 'inferior', 'medial', 
                    'lateral', 'proximal', 'distal','ventral', 'dorsal', 
                    'cranial', 'caudal', 'superficial', 'deep', 'central', 'peripheral']
    
    # Build comprehensive patterns 
    all_anatomy = []
    for category, terms in anatomy_patterns.items():
        all_anatomy.extend(terms)
    all_anatomy.sort(key=len, reverse=True)
    
    # Build pattern components
    lat_pattern = '|'.join(laterality_terms)
    dir_pattern = '|'.join(direction_terms)
    
    # Handle multi-word anatomy terms
    anatomy_pattern = '|'.join(re.escape(term).replace(r'\ ', r'\s+') for term in all_anatomy)
    
    # Build comprehensive regex pattern with named groups
    combined_pattern_regex = re.compile(
        rf'\b(?:(?P<lat>{lat_pattern})\s+)?(?:(?P<dir>{dir_pattern})\s+)?(?P<anatomy>{anatomy_pattern})\b', 
        re.IGNORECASE
    )
    
    # Process each row
    anatomy_entities = []
    
    for (df_index, text_original), idx in zip(texts.items(), ids_series):
        if text_original:
            text = text_original.lower()
            
            # Extract anatomy terms
            for match in combined_pattern_regex.finditer(text):
                full_term = match.group(0)
                laterality = match.group('lat')
                direction = match.group('dir')
                anatomy = match.group('anatomy')
                
                # Normalize multi-space anatomy terms
                anatomy_normalized = ' '.join(anatomy.split())
                
                # Determine category
                category = None
                for cat, terms in anatomy_patterns.items():
                    if any(anatomy_normalized == term.lower() for term in terms):
                        category = cat
                        break
                
                # Determine label based on modifiers present
                if laterality and direction:
                    label = 'ANATOMY_WITH_LATERALITY_AND_DIRECTION'
                elif laterality:
                    label = 'ANATOMY_WITH_LATERALITY'
                elif direction:
                    label = 'ANATOMY_WITH_DIRECTION'
                else:
                    label = 'ANATOMY'
                
                # Build entity record
                entity = {
                    'text': full_term,
                    'label': label,
                    'category': category,
                    'anatomy': anatomy_normalized,
                    'start': match.start(),
                    'end': match.end(),
                    'original_text': text_original,
                    'source': 'custom_anatomy_extraction',
                    'row_idx': idx
                }
                
                # Add modifiers if present
                if laterality:
                    entity['laterality'] = laterality
                if direction:
                    entity['direction'] = direction
                
                anatomy_entities.append(entity)
    
    return pd.DataFrame(anatomy_entities)


def create_anatomy_labeling_functions():
    """Create labeling functions that return evidence for span creation"""
    
    def lf_cardiac_anatomy(row):
        text = str(row.get('physiological context', '')).lower()
        cardiac_terms = ['heart', 'cardiac', 'coronary', 'ventricle', 'atrium', 'valve', 'aorta', 
                        'artery', 'arteries', 'vein', 'veins', 'vessel', 'vessels', 
                        'myocardium', 'pericardium', 'endocardium']
        
        hit = _first_hit(text, cardiac_terms)
        return {
            'label': 'CARDIAC_ANATOMY',
            'column': 'physiological context',
            'match': hit,
            'category': 'cardiovascular'
        } if hit else {'label': 'ABSTAIN'}
    
    
    def lf_musculoskeletal(row):
        text = str(row.get('physiological context', '')).lower()
        msk_terms = ['bone', 'joint', 'muscle', 'tendon', 'ligament', 'cartilage', 'diaphysis', 
                    'femur', 'tibia', 'fibula', 'humerus', 'radius', 'ulna', 'skull', 
                    'spine', 'vertebra', 'vertebrae', 'rib', 'ribs', 'pelvis', 
                    'flexor', 'extensor', 'meniscus']
        
        hit = _first_hit(text,  msk_terms)
        return {
            'label': 'MUSCULOSKELETAL',
            'column': 'physiological context',
            'match': hit,
            'category': 'musculoskeletal'
        } if hit else {'label': 'ABSTAIN'}
    
    def lf_bilateral_anatomy(row):
        text = str(row.get('physiological context', '')).lower()
        if 'bilateral' in text:
            # Find what follows bilateral
            bilateral_pattern = r'bilateral\s+(\w+)'
            match = re.search(bilateral_pattern, text)
            if match:
                return {
                    'label': 'BILATERAL_CONDITION',
                    'column': 'physiological context',
                    'match': f'bilateral {match.group(1)}',
                    'category': 'laterality'
                }
            else:
                return {
                    'label': 'BILATERAL_CONDITION',
                    'column': 'physiological context',
                    'match': 'bilateral',
                    'category': 'laterality'
                }
        return {'label': 'ABSTAIN'}
        
    def lf_neurological_anatomy(row):
        text = str(row.get('physiological context', '')).lower()
        neuro_terms = ['brain', 'nerve', 'neural', 'spinal', 'cerebral', 'cerebrum', 'brainstem', 
                    'neuron', 'lobe', 'cerebellum', 'cortex', 'ganglion']
        
        hit = _first_hit(text,  neuro_terms)
        return {
            'label': 'NEUROLOGICAL_ANATOMY',
            'column': 'physiological context',
            'match': hit,
            'category': 'neurological'
        } if hit else {'label': 'ABSTAIN'}
    
    return [lf_cardiac_anatomy, lf_musculoskeletal, lf_bilateral_anatomy, lf_neurological_anatomy]


def materialize_lf_spans(df, labeling_functions, id_column='idx'):
    """Convert labeling function results into span entities"""
    lf_entities = []
    
    # Build the id series once
    if id_column is not None and id_column in df.columns:
        ids_series = df[id_column].where(pd.notna(df[id_column]), df.index)
    else:
        ids_series = df.index
    
    for (df_index, row), idx in zip(df.iterrows(), ids_series):
        for lf in labeling_functions:
            result = lf(row)
            
            if result.get('label') != 'ABSTAIN':
                # Get the text from the specified column
                column_name = result.get('column', 'physiological context')
                full_text = str(row.get(column_name, ''))
                match_term = result['match']
                
                # Find all occurrences of the match term
                pattern = r'\b' + re.escape(match_term) + r'\b'
                
                for match in re.finditer(pattern, full_text.lower()):
                    # Extract the actual text (preserving original case)
                    actual_text = full_text[match.start():match.end()]
                    
                    lf_entities.append({
                        'text': actual_text,
                        'label': result['label'],
                        'category': result.get('category', 'unknown'),
                        'start': match.start(),
                        'end': match.end(),
                        'original_text': full_text,
                        'source': f'lf:{lf.__name__}',
                        'row_idx': idx
                    })
    
    return pd.DataFrame(lf_entities)


# Main execution
if __name__ == "__main__":
    # Extract custom anatomy entities using regex patterns
    print("\nExtracting anatomy entities...")
    df_anatomy_custom = extract_anatomy_from_physiological_context(
        df_medical, 
        text_column='physiological context', 
        id_column='idx'
    )
    
    print(f"\nCustom anatomy extraction found {len(df_anatomy_custom)} anatomy mentions")
    
    # Create and apply labeling functions to generate additional spans
    print("\n=== Creating and Applying Labeling Functions ===")
    anatomy_lfs = create_anatomy_labeling_functions()
    
    # Materialize spans from labeling functions
    df_physiological_lf_spans = materialize_lf_spans(df_medical, anatomy_lfs, id_column='idx')
    print(f"\nLabeling functions generated {len(df_physiological_lf_spans)} span entities")
    
    if not df_physiological_lf_spans.empty:
        print("\n=== LF-Generated Spans Distribution ===")
        print(df_physiological_lf_spans['label'].value_counts())
        print("\n=== Sample LF Spans ===")
        print(df_physiological_lf_spans[['text', 'label', 'source']].head(10))
    
    # Also apply labeling functions for row-level analysis
    print("\n=== Row-Level Analysis from Labeling Functions ===")
    coverage_results = {}
    for lf in anatomy_lfs:
        labeled_count = 0
        for _, row in df_medical.iterrows():
            result = lf(row)
            if result.get('label') != 'ABSTAIN':
                labeled_count += 1
        
        coverage = (labeled_count / len(df_medical) * 100) if len(df_medical) > 0 else 0
        coverage_results[lf.__name__] = {
            'labeled': labeled_count,
            'coverage': coverage
        }
    
    print("\nLabeling function coverage:")
    for lf_name, stats in sorted(coverage_results.items(), key=lambda x: x[1]['coverage'], reverse=True):
        print(f"  {lf_name}: {stats['labeled']} rows ({stats['coverage']:.1f}% coverage)")
    
    # Combine all entity sources
    print("\n=== COMBINING ALL ENTITY SOURCES ===")
    
    # Prepare common columns for concatenation
    common_columns = ['text', 'label', 'start', 'end', 'original_text', 'source']
    if 'row_idx' in df_anatomy_custom.columns:
        common_columns.append('row_idx')
    if 'category' in df_anatomy_custom.columns:
        common_columns.append('category')
    
    # Ensure all dataframes have required columns
    for col in common_columns:
        if col not in df_physiological_entities.columns:
            df_physiological_entities[col] = 'bc5cdr' if col == 'source' else None
        if col not in df_anatomy_custom.columns:
            df_anatomy_custom[col] = 'custom_extraction' if col == 'source' else None
        if col not in df_physiological_lf_spans.columns:
            df_physiological_lf_spans[col] = None
    
    df_all_physiological_entities = pd.concat([
        df_physiological_entities[common_columns],     # BC5CDR entities
        df_anatomy_custom[common_columns],             # Custom regex extraction
        df_physiological_lf_spans[common_columns]                    # LF-generated spans
    ], ignore_index=True)
    
    print(f"\nTotal combined entities: {len(df_all_physiological_entities)}")
    print("\nCombined entity distribution by label:")
    print(df_all_physiological_entities['label'].value_counts())
    print("\nCombined entity distribution by source:")
    print(df_all_physiological_entities['source'].value_counts())
    
    # Analysis of different entity types
    if not df_all_physiological_entities.empty:
        print("\n=== Analysis by Entity Type ===")
        
        # Analyze CARDIAC_ANATOMY entities
        cardiac = df_all_physiological_entities[
            df_all_physiological_entities['label'] == 'CARDIAC_ANATOMY'
        ]
        if not cardiac.empty:
            print(f"\n=== CARDIAC_ANATOMY Analysis ===")
            print(f"Total mentions: {len(cardiac)}")
            print(f"Unique terms: {cardiac['text'].nunique()}")
            print("\nTop cardiac terms:")
            print(cardiac['text'].value_counts().head(10))
        
        # Analyze MUSCULOSKELETAL entities
        msk = df_all_physiological_entities[
            df_all_physiological_entities['label'] == 'MUSCULOSKELETAL'
        ]
        if not msk.empty:
            print(f"\n=== MUSCULOSKELETAL Analysis ===")
            print(f"Total mentions: {len(msk)}")
            print(f"Unique terms: {msk['text'].nunique()}")
            print("\nTop musculoskeletal terms:")
            print(msk['text'].value_counts().head(10))
        
        # Compare sources
        print("\n=== Entity Source Comparison ===")
        source_label_dist = pd.crosstab(
            df_all_physiological_entities['source'],
            df_all_physiological_entities['label']
        )
        print(source_label_dist)
    
    # Optional: Add row-level phenotypes to original dataframe
    print("\n=== Adding Row-Level Phenotypes ===")
    for lf in anatomy_lfs:
        phenotype_name = f"has_{lf.__name__.replace('lf_', '')}"
        df_medical[phenotype_name] = df_medical.apply(
            lambda row: lf(row).get('label') != 'ABSTAIN',
            axis=1
        )
    
    # Show phenotype distribution
    phenotype_cols = [col for col in df_medical.columns if col.startswith('has_')]
    if phenotype_cols:
        print("\nRow-level phenotypes added:")
        for col in phenotype_cols:
            true_count = df_medical[col].sum()
            print(f"  {col}: {true_count} rows")
    
    # Save results
    df_all_physiological_entities.to_csv('physiological_entities_comprehensive.csv', index=False)
    df_anatomy_custom.to_csv('anatomy_entities_detailed.csv', index=False)
    df_physiological_lf_spans.to_csv('physiological_lf_generated_spans.csv', index=False)
    
    # Save dataframe with phenotypes
    df_medical.to_csv('medical_data_with_phenotypes.csv', index=False)
    
    print("\n\nAnatomy entity extraction complete!")
    print(f"Total entities extracted: {len(df_all_physiological_entities)}")
    print(f"  - BC5CDR: {len(df_physiological_entities)}")
    print(f"  - Custom extraction: {len(df_anatomy_custom)}")
    print(f"  - LF-generated: {len(df_physiological_lf_spans)}")
    print(f"\nRow-level phenotypes added: {len(phenotype_cols)}")
    print("\nAll results saved to CSV files.")


Extracting anatomy entities...

Custom anatomy extraction found 11614 anatomy mentions

=== Creating and Applying Labeling Functions ===

Labeling functions generated 3873 span entities

=== LF-Generated Spans Distribution ===
label
CARDIAC_ANATOMY         1971
MUSCULOSKELETAL          802
NEUROLOGICAL_ANATOMY     567
BILATERAL_CONDITION      533
Name: count, dtype: int64

=== Sample LF Spans ===
             text                label                   source
0           femur      MUSCULOSKELETAL    lf:lf_musculoskeletal
1  bilateral hips  BILATERAL_CONDITION  lf:lf_bilateral_anatomy
2          muscle      MUSCULOSKELETAL    lf:lf_musculoskeletal
3           heart      CARDIAC_ANATOMY    lf:lf_cardiac_anatomy
4           heart      CARDIAC_ANATOMY    lf:lf_cardiac_anatomy
5           spine      MUSCULOSKELETAL    lf:lf_musculoskeletal
6        Coronary      CARDIAC_ANATOMY    lf:lf_cardiac_anatomy
7        coronary      CARDIAC_ANATOMY    lf:lf_cardiac_anatomy
8        coronary      

In [17]:
df_physiological_lf_spans

Unnamed: 0,text,label,category,start,end,original_text,source,row_idx
0,femur,MUSCULOSKELETAL,musculoskeletal,87,92,"Coxa vara deformity of bilateral hips, bilater...",lf:lf_musculoskeletal,149866
1,bilateral hips,BILATERAL_CONDITION,laterality,23,37,"Coxa vara deformity of bilateral hips, bilater...",lf:lf_bilateral_anatomy,149866
2,muscle,MUSCULOSKELETAL,musculoskeletal,75,81,Diagnosed with a type of muscular dystrophy at...,lf:lf_musculoskeletal,80892
3,heart,CARDIAC_ANATOMY,cardiovascular,9,14,"Complete heart block, had dual chamber pacemak...",lf:lf_cardiac_anatomy,80727
4,heart,CARDIAC_ANATOMY,cardiovascular,88,93,"Complete heart block, had dual chamber pacemak...",lf:lf_cardiac_anatomy,80727
...,...,...,...,...,...,...,...,...
3868,cerebral,NEUROLOGICAL_ANATOMY,neurological,88,96,"Chronic rheumatic heart disease, atrial fibril...",lf:lf_neurological_anatomy,132307
3869,vein,CARDIAC_ANATOMY,cardiovascular,52,56,"Gastroesophageal reflux disease, hypertension,...",lf:lf_cardiac_anatomy,114112
3870,Coronary,CARDIAC_ANATOMY,cardiovascular,0,8,"Coronary arteriosclerosis, spinal canal stenos...",lf:lf_cardiac_anatomy,86992
3871,spinal,NEUROLOGICAL_ANATOMY,neurological,27,33,"Coronary arteriosclerosis, spinal canal stenos...",lf:lf_neurological_anatomy,86992


In [18]:
df_anatomy_custom

Unnamed: 0,text,label,category,anatomy,start,end,original_text,source,row_idx,laterality,direction
0,left elbow,ANATOMY_WITH_LATERALITY,joints,elbow,11,21,History of left elbow arthrodesis performed fo...,custom_anatomy_extraction,80176,left,
1,femur,ANATOMY,bones,femur,87,92,"Coxa vara deformity of bilateral hips, bilater...",custom_anatomy_extraction,149866,,
2,neck,ANATOMY,joints,neck,93,97,"Coxa vara deformity of bilateral hips, bilater...",custom_anatomy_extraction,149866,,
3,rectum,ANATOMY,organs,rectum,32,38,Colorectal carcinoma treated by rectum and ile...,custom_anatomy_extraction,92167,,
4,pancreas,ANATOMY,organs,pancreas,144,152,"Past medical history significant for RCC, init...",custom_anatomy_extraction,43555,,
...,...,...,...,...,...,...,...,...,...,...,...
11609,colon,ANATOMY,organs,colon,61,66,"Coronary arteriosclerosis, spinal canal stenos...",custom_anatomy_extraction,86992,,
11610,kidney,ANATOMY,organs,kidney,0,6,"Kidney stone lithotripsy, hypertension treated...",custom_anatomy_extraction,157822,,
11611,pulmonary,ANATOMY,cardiovascular,pulmonary,20,29,"Chronic obstructive pulmonary disease, high bl...",custom_anatomy_extraction,77450,,
11612,blood,ANATOMY,tissues,blood,44,49,"Chronic obstructive pulmonary disease, high bl...",custom_anatomy_extraction,77450,,


In [19]:
df_all_physiological_entities

Unnamed: 0,text,label,start,end,original_text,source,row_idx,category
0,posttraumatic arthritis,DISEASE,48,71,History of left elbow arthrodesis performed fo...,bc5cdr,80176,
1,pain,DISEASE,116,120,"Inability to walk since babyhood, did not walk...",bc5cdr,31864,
2,fracture,DISEASE,151,159,"Inability to walk since babyhood, did not walk...",bc5cdr,31864,
3,Coxa vara deformity,DISEASE,0,19,"Coxa vara deformity of bilateral hips, bilater...",bc5cdr,149866,
4,fracture,DISEASE,75,83,"Coxa vara deformity of bilateral hips, bilater...",bc5cdr,149866,
...,...,...,...,...,...,...,...,...
51754,cerebral,NEUROLOGICAL_ANATOMY,88,96,"Chronic rheumatic heart disease, atrial fibril...",lf:lf_neurological_anatomy,132307,neurological
51755,vein,CARDIAC_ANATOMY,52,56,"Gastroesophageal reflux disease, hypertension,...",lf:lf_cardiac_anatomy,114112,cardiovascular
51756,Coronary,CARDIAC_ANATOMY,0,8,"Coronary arteriosclerosis, spinal canal stenos...",lf:lf_cardiac_anatomy,86992,cardiovascular
51757,spinal,NEUROLOGICAL_ANATOMY,27,33,"Coronary arteriosclerosis, spinal canal stenos...",lf:lf_neurological_anatomy,86992,neurological


##### Psychological Contexts

In [20]:

def extract_psychological_entities_custom(df_medical, text_column='psychological context', id_column='idx'):
    """Extract psychological-specific entities beyond what BC5CDR captures"""
    
    # Build the id series once
    if id_column is not None and id_column in df_medical.columns:
        ids_series = df_medical[id_column].where(pd.notna(df_medical[id_column]), df_medical.index)
    else:
        ids_series = df_medical.index
    
    texts = df_medical[text_column].fillna('').astype(str)
    
    # Psychological patterns
    psychological_patterns = {
        'mood': ['depression', 'anxiety', 'euphoria', 'dysphoria', 'irritability', 'apathy'],
        'cognitive': ['confusion', 'disorientation', 'amnesia', 'dementia', 'delirium'],
        'psychotic': ['hallucination', 'delusion', 'paranoia', 'psychosis'],
        'behavioral': ['aggression', 'agitation', 'withdrawal', 'impulsivity'],
        'sleep': ['insomnia', 'hypersomnia', 'nightmares', 'sleep disturbance'],
        'trauma': ['ptsd', 'trauma', 'flashback', 'hypervigilance']
    }
    
    psychological_entities = []
    
    for (df_index, text_original), idx in zip(texts.items(), ids_series):
        if text_original:
            text = text_original.lower()
            
            # Extract psychological terms
            for category, terms in psychological_patterns.items():
                for term in terms:
                    pattern = r'\b' + re.escape(term) + r'\b'
                    for match in re.finditer(pattern, text):
                        psychological_entities.append({
                            'text': text_original[match.start():match.end()],
                            'label': 'PSYCHOLOGICAL_CONDITION',
                            'category': category,
                            'start': match.start(),
                            'end': match.end(),
                            'original_text': text_original,
                            'source': 'custom_psychological_extraction',
                            'row_idx': idx
                        })
    
    return pd.DataFrame(psychological_entities)


def create_psychological_labeling_functions():
    """Create labeling functions that return evidence for psychological span creation"""
    
    def lf_mood_disorder(row):
            text = str(row.get('psychological context', ''))
            mood_terms = ['depression', 'depressed', 'anxiety', 'anxious', 'manic', 'bipolar', 
                        'mood disorder', 'dysthymia', 'euphoria']
            
            hit = _first_hit(text, mood_terms)
            return {
                'label': 'MOOD_DISORDER',
                'column': 'psychological context',
                'match': hit,
                'category': 'mood'
            } if hit else {'label': 'ABSTAIN'}
    
    def lf_cognitive_impairment(row):
        text = str(row.get('psychological context', ''))
        cognitive_terms = ['memory loss', 'confusion', 'disoriented', 'cognitive decline', 
                          'dementia', 'alzheimer', 'amnesia', 'forgetful']
        
        hit = _first_hit(text, cognitive_terms)
        return {
            'label': 'COGNITIVE_IMPAIRMENT',
            'column': 'psychological context',
            'match': hit,
            'category': 'cognitive'
        } if hit else {'label': 'ABSTAIN'}
    
    def lf_psychotic_symptoms(row):
        text = str(row.get('psychological context', '')).lower()
        psychotic_terms = ['hallucination', 'delusion', 'paranoid', 'psychosis', 'psychotic',
                        'hearing voices', 'seeing things']
        
        hit = _first_hit(text, psychotic_terms)
        return {
            'label': 'PSYCHOTIC_SYMPTOMS',
            'column': 'psychological context',
            'match': hit,
            'category': 'psychotic'
        } if hit else {'label': 'ABSTAIN'}
    
    def lf_trauma_related(row):
        text = str(row.get('psychological context', '')).lower()
        trauma_terms = ['ptsd', 'post-traumatic', 'trauma', 'flashback', 'nightmare',
                    'hypervigilance', 'traumatic stress']
        
        hit = _first_hit(text, trauma_terms)
        return {
            'label': 'TRAUMA_RELATED',
            'column': 'psychological context',
            'match': hit,
            'category': 'trauma'
        } if hit else {'label': 'ABSTAIN'}
    
    def lf_sleep_disorder(row):
        text = str(row.get('psychological context', '')).lower()
        sleep_terms = ['insomnia', 'sleep disorder', 'cant sleep', 'sleep disturbance',
                    'nightmares', 'night terrors', 'sleep apnea']
        
        hit = _first_hit(text, sleep_terms)
        return {
            'label': 'SLEEP_DISORDER',
            'column': 'psychological context',
            'match': hit,
            'category': 'sleep'
        } if hit else {'label': 'ABSTAIN'}
    
    
    return [lf_mood_disorder, lf_cognitive_impairment, lf_psychotic_symptoms, 
            lf_trauma_related, lf_sleep_disorder]


def materialize_psychological_lf_spans(df, labeling_functions, id_column='idx'):
    """Convert psychological labeling function results into span entities"""
    lf_entities = []
    
    # Build the id series once
    if id_column is not None and id_column in df.columns:
        ids_series = df[id_column].where(pd.notna(df[id_column]), df.index)
    else:
        ids_series = df.index
    
    for (df_index, row), idx in zip(df.iterrows(), ids_series):
        for lf in labeling_functions:
            result = lf(row)
            
            if result.get('label') != 'ABSTAIN':
                # Get the text from the specified column
                column_name = result.get('column', 'psychological context')
                full_text = str(row.get(column_name, ''))
                match_term = result['match']
                
                # Find all occurrences of the match term
                pattern = r'\b' + re.escape(match_term) + r'\b'
                
                for match in re.finditer(pattern, full_text.lower()):
                    # Extract the actual text (preserving original case)
                    actual_text = full_text[match.start():match.end()]
                    
                    lf_entities.append({
                        'text': actual_text,
                        'label': result['label'],
                        'category': result.get('category', 'unknown'),
                        'start': match.start(),
                        'end': match.end(),
                        'original_text': full_text,
                        'source': f'lf:{lf.__name__}',
                        'row_idx': idx
                    })
    
    return pd.DataFrame(lf_entities)


# Main execution
if __name__ == "__main__":
    # Run extraction for the 'psychological context' column with BC5CDR
    df_psychological_entities, disease_summary, rules = run_medical_ner_extraction(
        df_medical,
        text_column='psychological context',
        model_name="en_ner_bc5cdr_md",
        batch_size=500,
        id_column='idx'
    )
    
    # Test the generated rules from BC5CDR
    print("\nTesting BC5CDR-generated labeling functions on df_medical...")
    if not df_medical.empty:
        sample_row = df_medical.iloc[0]
        
        for rule_name, rule_func in rules.items():
            try:
                test_result = rule_func(sample_row)
                if test_result != 'ABSTAIN':
                    print(f"{rule_name} applied to row 0: {test_result}")
            except KeyError as e:
                print(f"Error applying rule {rule_name}: {e}. Make sure the column '{e}' exists in df_medical.")
    else:
        print("df_medical is empty, cannot test rules.")
    
    print("\n=== BC5CDR Extraction Results ===")
    print(f"Shape of df_psychological_entities: {df_psychological_entities.shape}")
    print(f"\nColumn names: {df_psychological_entities.columns.tolist()}")
    
    if not df_psychological_entities.empty:
        print("\n=== BC5CDR Entity Labels Found ===")
        print(df_psychological_entities['label'].value_counts())
    else:
        print("df_psychological_entities is empty.")
    
    # Extract custom psychological entities
    print("\n=== CUSTOM PSYCHOLOGICAL ENTITY EXTRACTION ===")
    df_custom_psychological = extract_psychological_entities_custom(
        df_medical,
        text_column='psychological context',
        id_column='idx'
    )
    
    print(f"\nCustom extraction found {len(df_custom_psychological)} psychological entities")
    if not df_custom_psychological.empty:
        print("\n=== Custom Entity Distribution ===")
        print(df_custom_psychological['label'].value_counts())
        print("\n=== Custom Entity Categories ===")
        print(df_custom_psychological['category'].value_counts())
    
    # Create and apply psychological labeling functions
    print("\n=== Creating and Applying Psychological Labeling Functions ===")
    psychological_lfs = create_psychological_labeling_functions()
    
    # Materialize spans from labeling functions
    df_psychological_lf_spans = materialize_psychological_lf_spans(
        df_medical, 
        psychological_lfs, 
        id_column='idx'
    )
    
    print(f"\nLabeling functions generated {len(df_psychological_lf_spans)} span entities")
    
    if not df_psychological_lf_spans.empty:
        print("\n=== LF-Generated Spans Distribution ===")
        print(df_psychological_lf_spans['label'].value_counts())
        print("\n=== Sample LF Spans ===")
        print(df_psychological_lf_spans[['text', 'label', 'source']].head(10))
    
    # Row-level coverage analysis
    print("\n=== Row-Level Coverage Analysis ===")
    coverage_results = {}
    for lf in psychological_lfs:
        labeled_count = 0
        for _, row in df_medical.iterrows():
            result = lf(row)
            if result.get('label') != 'ABSTAIN':
                labeled_count += 1
        
        coverage = (labeled_count / len(df_medical) * 100) if len(df_medical) > 0 else 0
        coverage_results[lf.__name__] = {
            'labeled': labeled_count,
            'coverage': coverage
        }
    
    print("\nLabeling function coverage:")
    for lf_name, stats in sorted(coverage_results.items(), key=lambda x: x[1]['coverage'], reverse=True):
        print(f"  {lf_name}: {stats['labeled']} rows ({stats['coverage']:.1f}% coverage)")
    
    # Combine all sources
    print("\n=== COMBINING ALL PSYCHOLOGICAL ENTITY SOURCES ===")
    
    # Prepare common columns
    common_columns = ['text', 'label', 'start', 'end', 'original_text', 'source']
    if 'row_idx' in df_psychological_entities.columns:
        common_columns.append('row_idx')
    if 'category' in df_custom_psychological.columns:
        common_columns.append('category')
    
    # Ensure all dataframes have required columns
    for col in common_columns:
        if col not in df_psychological_entities.columns:
            df_psychological_entities[col] = 'bc5cdr' if col == 'source' else None
        if col not in df_custom_psychological.columns:
            df_custom_psychological[col] = 'custom_extraction' if col == 'source' else None
        if col not in df_psychological_lf_spans.columns:
            df_psychological_lf_spans[col] = None
    
    df_all_psychological_entities = pd.concat([
        df_psychological_entities[common_columns],
        df_custom_psychological[common_columns],
        df_psychological_lf_spans[common_columns]
    ], ignore_index=True)
    
    print(f"\nTotal combined entities: {len(df_all_psychological_entities)}")
    print("\nCombined entity distribution by label:")
    print(df_all_psychological_entities['label'].value_counts())
    print("\nCombined entity distribution by source:")
    print(df_all_psychological_entities['source'].value_counts())
    
    # Analysis by psychological categories
    if not df_all_psychological_entities.empty:
        print("\n=== Analysis by Psychological Category ===")
        
        # Mood disorders
        mood_entities = df_all_psychological_entities[
            df_all_psychological_entities['label'].isin(['MOOD_DISORDER', 'PSYCHOLOGICAL_CONDITION'])
        ]
        if not mood_entities.empty and 'category' in mood_entities.columns:
            mood_specific = mood_entities[mood_entities['category'] == 'mood']
            if not mood_specific.empty:
                print(f"\n=== Mood-Related Terms ===")
                print(f"Total mentions: {len(mood_specific)}")
                print("Top terms:")
                print(mood_specific['text'].value_counts().head(10))
        
        # Cognitive issues
        cognitive_entities = df_all_psychological_entities[
            df_all_psychological_entities['label'] == 'COGNITIVE_IMPAIRMENT'
        ]
        if not cognitive_entities.empty:
            print(f"\n=== Cognitive Impairment Terms ===")
            print(f"Total mentions: {len(cognitive_entities)}")
            print("Top terms:")
            print(cognitive_entities['text'].value_counts().head(10))
    
    # Add row-level psychological phenotypes
    print("\n=== Adding Row-Level Psychological Phenotypes ===")
    for lf in psychological_lfs:
        phenotype_name = f"has_{lf.__name__.replace('lf_', '')}"
        df_medical[phenotype_name] = df_medical.apply(
            lambda row: lf(row).get('label') != 'ABSTAIN',
            axis=1
        )
    
    # Show phenotype distribution
    phenotype_cols = [col for col in df_medical.columns if col.startswith('has_') and 'mood' in col or 'cognitive' in col or 'psychotic' in col or 'trauma' in col or 'sleep' in col]
    if phenotype_cols:
        print("\nPsychological phenotypes added:")
        for col in phenotype_cols:
            true_count = df_medical[col].sum()
            print(f"  {col}: {true_count} rows")
    
    # Save all results
    df_all_psychological_entities.to_csv('psychological_entities_comprehensive.csv', index=False)
    df_custom_psychological.to_csv('psychological_entities_custom.csv', index=False)
    df_psychological_lf_spans.to_csv('psychological_lf_generated_spans.csv', index=False)
    
    print("\n\nPsychological entity extraction complete!")
    print(f"Total entities extracted: {len(df_all_psychological_entities)}")
    print(f"  - BC5CDR: {len(df_psychological_entities)}")
    print(f"  - Custom extraction: {len(df_custom_psychological)}")
    print(f"  - LF-generated: {len(df_psychological_lf_spans)}")
    print("\nAll results saved to CSV files.")

Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 29755 texts in 60 batches...
Using model: en_ner_bc5cdr_md for column: psychological context


Processing batches:   2%|▏         | 1/60 [00:01<00:59,  1.01s/it]


Checkpoint saved at batch 0


Processing batches:  18%|█▊        | 11/60 [00:03<00:13,  3.65it/s]


Checkpoint saved at batch 5000


Processing batches:  35%|███▌      | 21/60 [00:05<00:10,  3.84it/s]


Checkpoint saved at batch 10000


Processing batches:  52%|█████▏    | 31/60 [00:08<00:07,  3.66it/s]


Checkpoint saved at batch 15000


Processing batches:  68%|██████▊   | 41/60 [00:11<00:05,  3.54it/s]


Checkpoint saved at batch 20000


Processing batches:  85%|████████▌ | 51/60 [00:13<00:02,  3.69it/s]


Checkpoint saved at batch 25000


Processing batches: 100%|██████████| 60/60 [00:15<00:00,  3.84it/s]


Found 136 entities appearing >= 5 times

Testing BC5CDR-generated labeling functions on df_medical...

=== BC5CDR Extraction Results ===
Shape of df_psychological_entities: (3402, 7)

Column names: ['text', 'label', 'start', 'end', 'original_text', 'row_idx', 'source_column']

=== BC5CDR Entity Labels Found ===
label
DISEASE     3240
CHEMICAL     162
Name: count, dtype: int64

=== CUSTOM PSYCHOLOGICAL ENTITY EXTRACTION ===

Custom extraction found 1107 psychological entities

=== Custom Entity Distribution ===
label
PSYCHOLOGICAL_CONDITION    1107
Name: count, dtype: int64

=== Custom Entity Categories ===
category
mood          848
behavioral     71
psychotic      59
cognitive      53
sleep          46
trauma         30
Name: count, dtype: int64

=== Creating and Applying Psychological Labeling Functions ===

Labeling functions generated 1174 span entities

=== LF-Generated Spans Distribution ===
label
MOOD_DISORDER           931
PSYCHOTIC_SYMPTOMS       99
COGNITIVE_IMPAIRMENT     57

In [21]:
df_all_psychological_entities

Unnamed: 0,text,label,start,end,original_text,source,row_idx,category
0,bipolar affective disorder,DISEASE,15,41,Diagnosed with bipolar affective disorder at t...,bc5cdr,155216,
1,mania,DISEASE,90,95,Diagnosed with bipolar affective disorder at t...,bc5cdr,155216,
2,Parental distress,DISEASE,0,17,Parental distress,bc5cdr,90928,
3,depression,DISEASE,68,78,Known to local mental health services for 20 y...,bc5cdr,45433,
4,anxiety,DISEASE,83,90,Known to local mental health services for 20 y...,bc5cdr,45433,
...,...,...,...,...,...,...,...,...
5678,depression,MOOD_DISORDER,11,21,History of depression diagnosed about 2 years ...,lf:lf_mood_disorder,136465,mood
5679,anxiety,MOOD_DISORDER,22,29,Developed generalized anxiety disorder,lf:lf_mood_disorder,76671,mood
5680,Bipolar,MOOD_DISORDER,0,7,Bipolar disorder,lf:lf_mood_disorder,87937,mood
5681,Bipolar,MOOD_DISORDER,0,7,Bipolar disorder,lf:lf_mood_disorder,113022,mood


##### Vaccination History

In [22]:
def extract_vaccination_entities_custom(df_medical, text_column='vaccination history', id_column='idx'):
    """Extract vaccination-specific entities beyond what BC5CDR captures"""
    
    # Build the id series once
    if id_column is not None and id_column in df_medical.columns:
        ids_series = df_medical[id_column].where(pd.notna(df_medical[id_column]), df_medical.index)
    else:
        ids_series = df_medical.index
    
    texts = df_medical[text_column].fillna('').astype(str)
    
    # Vaccination patterns
    vaccination_patterns = {
        'covid_vaccines': [
            'covid', 'covid-19', 'coronavirus', 'sars-cov-2', 'pfizer', 'moderna',
            'astrazeneca', 'johnson & johnson', 'j&j', 'mrna-1273', 'bnt162b2',
            'covaxin', 'sputnik', 'sinovac', 'sinopharm'
        ],
        'routine_vaccines': [
            'mmr', 'measles', 'mumps', 'rubella', 'varicella', 'chickenpox',
            'polio', 'dtap', 'diphtheria', 'tetanus', 'pertussis', 'whooping cough',
            'hib', 'hepatitis a', 'hepatitis b', 'hep a', 'hep b', 'rotavirus',
            'pcv', 'ipv', 'opv', 'bcg', 'tuberculosis'
        ],
        'adult_vaccines': [
            'influenza', 'flu vaccine', 'flu shot', 'pneumococcal', 'pneumonia vaccine',
            'shingles', 'zoster', 'hpv', 'human papillomavirus', 'tdap', 'td'
        ],
        'travel_vaccines': [
            'yellow fever', 'typhoid', 'japanese encephalitis', 'rabies',
            'meningococcal', 'cholera', 'malaria prophylaxis'
        ],
        'vaccine_brands': [
            'havrix', 'engerix', 'prevnar', 'pneumovax', 'fluzone', 'flumist',
            'gardasil', 'cervarix', 'boostrix', 'adacel', 'shingrix', 'zostavax'
        ]
    }
    
    # Timing patterns
    timing_patterns = [
        r'\d+\s*(?:year|month|week|day)s?\s*ago',
        r'(?:last|this)\s*(?:year|month|week)',
        r'(?:january|february|march|april|may|june|july|august|september|october|november|december)\s*\d{4}',
        r'\d{1,2}[/-]\d{1,2}[/-]\d{2,4}',
        r'recently', r'childhood', r'infancy', r'adolescence'
    ]
    
    # Dose patterns
    dose_patterns = [
        r'(?:first|second|third|1st|2nd|3rd)\s*(?:dose|shot)',
        r'booster\s*(?:dose|shot)?',
        r'single\s*dose',
        r'(?:completed|complete)\s*series',
        r'fully\s*vaccinated',
        r'partially\s*vaccinated'
    ]
    
    # Status patterns
    status_patterns = {
        'up_to_date': ['up to date', 'up-to-date', 'current', 'complete'],
        'not_vaccinated': ['not vaccinated', 'unvaccinated', 'declined', 'refused', 'no history'],
        'overdue': ['overdue', 'due for', 'needs', 'recommended'],
        'contraindicated': ['contraindicated', 'allergic', 'cannot receive']
    }
    
    vaccination_entities = []
    
    for (df_index, text_original), idx in zip(texts.items(), ids_series):
        if text_original:
            text = text_original.lower()
            
            # Extract vaccine names
            for category, vaccines in vaccination_patterns.items():
                for vaccine in vaccines:
                    pattern = r'\b' + re.escape(vaccine) + r'\b'
                    for match in re.finditer(pattern, text):
                        vaccination_entities.append({
                            'text': text_original[match.start():match.end()],
                            'label': 'VACCINE',
                            'category': category,
                            'start': match.start(),
                            'end': match.end(),
                            'original_text': text_original,
                            'source': 'custom_vaccination_extraction',
                            'row_idx': idx
                        })
            
            # Extract timing information
            for pattern in timing_patterns:
                for match in re.finditer(pattern, text, re.IGNORECASE):
                    vaccination_entities.append({
                        'text': text_original[match.start():match.end()],
                        'label': 'VACCINATION_TIMING',
                        'category': 'temporal',
                        'start': match.start(),
                        'end': match.end(),
                        'original_text': text_original,
                        'source': 'custom_vaccination_extraction',
                        'row_idx': idx
                    })
            
            # Extract dose information
            for pattern in dose_patterns:
                for match in re.finditer(pattern, text, re.IGNORECASE):
                    vaccination_entities.append({
                        'text': text_original[match.start():match.end()],
                        'label': 'VACCINE_DOSE',
                        'category': 'dosage',
                        'start': match.start(),
                        'end': match.end(),
                        'original_text': text_original,
                        'source': 'custom_vaccination_extraction',
                        'row_idx': idx
                    })
            
            # Extract vaccination status
            for status_type, terms in status_patterns.items():
                for term in terms:
                    pattern = r'\b' + re.escape(term) + r'\b'
                    for match in re.finditer(pattern, text):
                        vaccination_entities.append({
                            'text': text_original[match.start():match.end()],
                            'label': 'VACCINATION_STATUS',
                            'category': status_type,
                            'start': match.start(),
                            'end': match.end(),
                            'original_text': text_original,
                            'source': 'custom_vaccination_extraction',
                            'row_idx': idx
                        })
    
    return pd.DataFrame(vaccination_entities)


def materialize_lf_spans(df, labeling_functions, id_column='idx'):
    """Convert labeling function results into span entities"""
    lf_entities = []
    
    # Build the id series once
    if id_column is not None and id_column in df.columns:
        ids_series = df[id_column].where(pd.notna(df[id_column]), df.index)
    else:
        ids_series = df.index
    
    for (df_index, row), idx in zip(df.iterrows(), ids_series):
        for lf in labeling_functions:
            result = lf(row)
            
            if result.get('label') != 'ABSTAIN':
                # Get the text from the specified column
                column_name = result.get('column', 'vaccination history')
                full_text = str(row.get(column_name, ''))
                match_term = result.get('match')
                
                if match_term:
                    # Find all occurrences of the match term
                    pattern = r'\b' + re.escape(match_term) + r'\b'
                    
                    for match in re.finditer(pattern, full_text.lower()):
                        # Extract the actual text (preserving original case)
                        actual_text = full_text[match.start():match.end()]
                        
                        lf_entities.append({
                            'text': actual_text,
                            'label': result['label'],
                            'category': result.get('category', 'unknown'),
                            'start': match.start(),
                            'end': match.end(),
                            'original_text': full_text,
                            'source': f'lf:{lf.__name__}',
                            'row_idx': idx
                        })
    
    return pd.DataFrame(lf_entities)


# Main execution
if __name__ == "__main__":
    # BC5CDR model will identify vaccines (CHEMICAL) and diseases (DISEASE)
    df_vaccination_entities, vaccination_summary, vaccination_rules = run_medical_ner_extraction(
        df_medical, 
        text_column='vaccination history',
        model_name="en_ner_bc5cdr_md",  # This model recognizes DISEASE and CHEMICAL entities
        batch_size=300,
        id_column='idx'  # Use 'idx' column for row identifiers
    )

    # Test the generated rules
    print("\nTesting generated labeling functions for vaccination history...")

    # Select a sample row from the original df_medical to test the rules
    if not df_medical.empty:
        sample_row = df_medical.iloc[0]

        for rule_name, rule_func in vaccination_rules.items():
            try:
                test_result = rule_func(sample_row)
                if test_result != 'ABSTAIN':
                    print(f"{rule_name} applied to row 0: {test_result}")
            except KeyError as e:
                print(f"Error applying rule {rule_name}: {e}. Make sure the column 'vaccination_history' exists in df_medical.")
    else:
        print("df_medical is empty, cannot test rules.")

    # Debug: Check extracted entities
    print("\n=== BC5CDR Extraction Results ===")
    print(f"Shape of df_vaccination_entities: {df_vaccination_entities.shape}")
    print(f"\nColumn names: {df_vaccination_entities.columns.tolist()}")

    # Check entity label distribution
    print("\n=== Entity Labels Found in Vaccination History ===")
    if not df_vaccination_entities.empty:
        print(df_vaccination_entities['label'].value_counts())

        # Look at CHEMICAL entities (likely vaccines)
        if 'CHEMICAL' in df_vaccination_entities['label'].values:
            print("\n=== Top CHEMICAL/Vaccine entities ===")
            vaccine_entities = df_vaccination_entities[df_vaccination_entities['label'] == 'CHEMICAL']['text'].value_counts().head(20)
            print(vaccine_entities)

        # Look at DISEASE entities (conditions vaccines prevent)
        if 'DISEASE' in df_vaccination_entities['label'].values:
            print("\n=== Top DISEASE entities (conditions) ===")
            disease_entities = df_vaccination_entities[df_vaccination_entities['label'] == 'DISEASE']['text'].value_counts().head(20)
            print(disease_entities)
    else:
        print("df_vaccination_entities is empty.")
    
    # Extract custom vaccination entities
    print("\n=== CUSTOM VACCINATION ENTITY EXTRACTION ===")
    df_custom_vaccination = extract_vaccination_entities_custom(
        df_medical,
        text_column='vaccination history',
        id_column='idx'
    )
    
    print(f"\nCustom extraction found {len(df_custom_vaccination)} vaccination entities")
    if not df_custom_vaccination.empty:
        print("\n=== Custom Entity Distribution by Label ===")
        print(df_custom_vaccination['label'].value_counts())
        print("\n=== Custom Entity Distribution by Category ===")
        print(df_custom_vaccination['category'].value_counts())
        print("\n=== Top Vaccine Names ===")
        vaccine_names = df_custom_vaccination[df_custom_vaccination['label'] == 'VACCINE']
        if not vaccine_names.empty:
            print(vaccine_names['text'].value_counts().head(20))

    # Create labeling functions
    def create_vaccination_labeling_functions(entities_df: pd.DataFrame):
        """Spanified vaccination LFs: return {'label','column','match','category?'}."""
        COL = 'vaccination history'

        _ = entities_df['text'].value_counts().head(50) if not entities_df.empty else pd.Series(dtype=int)

        def _first_hit(text, terms):
            text_l = text.lower()
            for t in terms:
                if t in text_l:
                    return t
            return None

        def lf_covid_vaccination(row):
            text = str(row.get(COL, ''))
            terms = ['covid', 'coronavirus', 'sars-cov-2', 'pfizer', 'moderna',
                    'astrazeneca', 'johnson', 'mrna-1273', 'bnt162b2']
            hit = _first_hit(text, terms)
            return {'label': 'COVID_VACCINE', 'column': COL, 'match': hit, 'category': 'vaccine'} if hit else {'label': 'ABSTAIN'}

        def lf_childhood_vaccines(row):
            text = str(row.get(COL, ''))
            terms = ['mmr', 'measles', 'mumps', 'rubella', 'varicella', 'chickenpox',
                    'polio', 'dtap', 'diphtheria', 'tetanus', 'pertussis', 'whooping',
                    'hib', 'hepatitis b', 'rotavirus', 'pcv', 'ipv']
            hit = _first_hit(text, terms)
            return {'label': 'CHILDHOOD_VACCINES', 'column': COL, 'match': hit, 'category': 'vaccine'} if hit else {'label': 'ABSTAIN'}

        def lf_influenza_vaccination(row):
            text = str(row.get(COL, ''))
            terms = ['influenza', 'flu vaccine', 'flu shot', 'seasonal flu', 'h1n1']
            hit = _first_hit(text, terms)
            return {'label': 'FLU_VACCINE', 'column': COL, 'match': hit, 'category': 'vaccine'} if hit else {'label': 'ABSTAIN'}

        def lf_hepatitis_vaccination(row):
            text = str(row.get(COL, ''))
            terms = ['hepatitis a', 'hepatitis b', 'hep a', 'hep b', 'havrix', 'engerix']
            hit = _first_hit(text, terms)
            return {'label': 'HEPATITIS_VACCINE', 'column': COL, 'match': hit, 'category': 'vaccine'} if hit else {'label': 'ABSTAIN'}

        def lf_tetanus_vaccination(row):
            text = str(row.get(COL, ''))
            terms = ['tetanus', 'tdap', 'td ', 'boostrix', 'adacel']
            hit = _first_hit(text, terms)
            return {'label': 'TETANUS_VACCINE', 'column': COL, 'match': hit, 'category': 'vaccine'} if hit else {'label': 'ABSTAIN'}

        def lf_pneumococcal_vaccination(row):
            text = str(row.get(COL, ''))
            terms = ['pneumococcal', 'pneumonia vaccine', 'prevnar', 'pneumovax']
            hit = _first_hit(text, terms)
            return {'label': 'PNEUMO_VACCINE', 'column': COL, 'match': hit, 'category': 'vaccine'} if hit else {'label': 'ABSTAIN'}

        def lf_travel_vaccines(row):
            text = str(row.get(COL, ''))
            terms = ['yellow fever', 'typhoid', 'japanese encephalitis', 'rabies', 'meningococcal', 'cholera']
            hit = _first_hit(text, terms)
            return {'label': 'TRAVEL_VACCINES', 'column': COL, 'match': hit, 'category': 'vaccine'} if hit else {'label': 'ABSTAIN'}

        def lf_vaccination_timing(row):
            text = str(row.get(COL, ''))
            terms = ['booster', 'dose', 'series', 'schedule', 'up to date',
                    'fully vaccinated', 'partially vaccinated']
            hit = _first_hit(text, terms)
            return {'label': 'VACCINATION_TIMING', 'column': COL, 'match': hit, 'category': 'timing'} if hit else {'label': 'ABSTAIN'}

        def lf_no_vaccination(row):
            text = str(row.get(COL, ''))
            terms = ['no vaccination', 'not vaccinated', 'unvaccinated',
                    'declined', 'refused', 'no history of vaccination']
            hit = _first_hit(text, terms)
            return {'label': 'UNVACCINATED', 'column': COL, 'match': hit, 'category': 'status'} if hit else {'label': 'ABSTAIN'}

        def lf_vaccine_reaction(row):
            text = str(row.get(COL, ''))
            terms = ['reaction', 'allergy', 'side effect', 'adverse', 'anaphylaxis']
            hit = _first_hit(text, terms)
            return {'label': 'VACCINE_REACTION', 'column': COL, 'match': hit, 'category': 'reaction'} if hit else {'label': 'ABSTAIN'}

        def lf_recent_vaccination(row):
            text = str(row.get(COL, ''))
            terms = ['weeks ago', 'months ago', 'recently', 'last month', 'last week', 'this year']
            hit = _first_hit(text, terms)
            return {'label': 'RECENT_VACCINATION', 'column': COL, 'match': hit, 'category': 'recency'} if hit else {'label': 'ABSTAIN'}

        def lf_historical_vaccination(row):
            text = str(row.get(COL, ''))
            if 'history' in text.lower() and any(k in text.lower() for k in ['vaccin', 'immuniz']):
                return {'label': 'VACCINATION_HISTORY', 'column': COL, 'match': 'history', 'category': 'history'}
            return {'label': 'ABSTAIN'}

        return [
            lf_covid_vaccination, lf_childhood_vaccines, lf_influenza_vaccination,
            lf_hepatitis_vaccination, lf_tetanus_vaccination, lf_pneumococcal_vaccination,
            lf_travel_vaccines, lf_vaccination_timing, lf_no_vaccination,
            lf_vaccine_reaction, lf_recent_vaccination, lf_historical_vaccination
        ]

    # Apply labeling functions
    vaccination_lfs = create_vaccination_labeling_functions(df_vaccination_entities)
    
    df_vaccination_lf_spans = materialize_lf_spans(df_medical, vaccination_lfs, id_column='idx')
    print(f"\n=== LF-Generated Vaccination Spans ===")
    print(f"Total LF-generated spans: {len(df_vaccination_lf_spans)}")
    if not df_vaccination_lf_spans.empty:
        print("\nLF span distribution:")
        print(df_vaccination_lf_spans['label'].value_counts())
    
    # Row-level coverage analysis
    print("\n=== Row-Level Coverage Analysis ===")
    coverage_results = {}
    for lf in vaccination_lfs:
        labeled_count = 0
        for _, row in df_medical.iterrows():
            result = lf(row)
            if result.get('label') != 'ABSTAIN':
                labeled_count += 1
        
        coverage = (labeled_count / len(df_medical) * 100) if len(df_medical) > 0 else 0
        coverage_results[lf.__name__] = {
            'labeled': labeled_count,
            'coverage': coverage
        }
    
    print("\nLabeling function coverage:")
    for lf_name, stats in sorted(coverage_results.items(), key=lambda x: x[1]['coverage'], reverse=True):
        print(f"  {lf_name}: {stats['labeled']} rows ({stats['coverage']:.1f}% coverage)")

    # Combine all sources
    print("\n=== COMBINING ALL VACCINATION ENTITY SOURCES ===")
    
    # Ensure common columns exist for concat
    _v_cols = ['text','label','start','end','original_text','source','row_idx','category']
    for col in _v_cols:
        if col not in df_vaccination_entities.columns:
            df_vaccination_entities[col] = 'bc5cdr' if col=='source' else None
        if col not in df_custom_vaccination.columns:
            df_custom_vaccination[col] = 'custom_extraction' if col=='source' else None
        if col not in df_vaccination_lf_spans.columns:
            df_vaccination_lf_spans[col] = None

    df_all_vaccination_entities = pd.concat(
        [df_vaccination_entities[_v_cols], 
         df_custom_vaccination[_v_cols],
         df_vaccination_lf_spans[_v_cols]],
        ignore_index=True
    )
    
    print(f"\nTotal vaccination entities: {len(df_all_vaccination_entities)}")
    print(f"  - BC5CDR: {len(df_vaccination_entities)}")
    print(f"  - Custom extraction: {len(df_custom_vaccination)}")
    print(f"  - LF-generated: {len(df_vaccination_lf_spans)}")
    
    print("\nCombined entity distribution by label:")
    print(df_all_vaccination_entities['label'].value_counts())
    print("\nCombined entity distribution by source:")
    print(df_all_vaccination_entities['source'].value_counts())
    
    # Analysis by vaccine type
    if not df_all_vaccination_entities.empty:
        print("\n=== Vaccine Type Analysis ===")
        
        # COVID vaccines
        covid_vaccines = df_all_vaccination_entities[
            df_all_vaccination_entities['label'].isin(['COVID_VACCINE', 'VACCINE'])
        ]
        if not covid_vaccines.empty and 'category' in covid_vaccines.columns:
            covid_specific = covid_vaccines[covid_vaccines['category'] == 'covid_vaccines']
            if not covid_specific.empty:
                print(f"\n=== COVID Vaccine Mentions ===")
                print(f"Total: {len(covid_specific)}")
                print("Top terms:")
                print(covid_specific['text'].value_counts().head(10))
    
    # Add row-level vaccination phenotypes
    print("\n=== Adding Row-Level Vaccination Phenotypes ===")
    for lf in vaccination_lfs:
        phenotype_name = f"has_{lf.__name__.replace('lf_', '')}"
        df_medical[phenotype_name] = df_medical.apply(
            lambda row: lf(row).get('label') != 'ABSTAIN',
            axis=1
        )
    
    # Show phenotype distribution
    phenotype_cols = [col for col in df_medical.columns if col.startswith('has_') and ('vaccine' in col.lower() or 'vaccination' in col.lower())]
    if phenotype_cols:
        print("\nVaccination phenotypes added:")
        for col in phenotype_cols:
            true_count = df_medical[col].sum()
            print(f"  {col}: {true_count} rows")
    
    # Save results
    df_vaccination_lf_spans.to_csv('vaccination_lf_generated_spans.csv', index=False)
    df_all_vaccination_entities.to_csv('vaccination_entities_comprehensive.csv', index=False)
    df_custom_vaccination.to_csv('vaccination_entities_custom.csv', index=False)
    
    print("\n\nVaccination entity extraction complete!")
    print("All results saved to CSV files.")

Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 29755 texts in 100 batches...
Using model: en_ner_bc5cdr_md for column: vaccination history


Processing batches:   2%|▏         | 2/100 [00:00<00:36,  2.68it/s]


Checkpoint saved at batch 0


Processing batches:  12%|█▏        | 12/100 [00:02<00:15,  5.74it/s]


Checkpoint saved at batch 3000


Processing batches:  22%|██▏       | 22/100 [00:03<00:12,  6.43it/s]


Checkpoint saved at batch 6000


Processing batches:  32%|███▏      | 32/100 [00:05<00:11,  5.94it/s]


Checkpoint saved at batch 9000


Processing batches:  42%|████▏     | 42/100 [00:07<00:09,  6.13it/s]


Checkpoint saved at batch 12000


Processing batches:  52%|█████▏    | 52/100 [00:08<00:07,  6.30it/s]


Checkpoint saved at batch 15000


Processing batches:  62%|██████▏   | 62/100 [00:10<00:06,  6.24it/s]


Checkpoint saved at batch 18000


Processing batches:  72%|███████▏  | 72/100 [00:11<00:04,  5.89it/s]


Checkpoint saved at batch 21000


Processing batches:  82%|████████▏ | 82/100 [00:13<00:02,  6.33it/s]


Checkpoint saved at batch 24000


Processing batches:  92%|█████████▏| 92/100 [00:14<00:01,  6.23it/s]


Checkpoint saved at batch 27000


Processing batches: 100%|██████████| 100/100 [00:16<00:00,  6.22it/s]


Found 6 entities appearing >= 5 times

Testing generated labeling functions for vaccination history...

=== BC5CDR Extraction Results ===
Shape of df_vaccination_entities: (129, 7)

Column names: ['text', 'label', 'start', 'end', 'original_text', 'row_idx', 'source_column']

=== Entity Labels Found in Vaccination History ===
label
DISEASE     106
CHEMICAL     23
Name: count, dtype: int64

=== Top CHEMICAL/Vaccine entities ===
text
vitamin K                6
Calmette                 3
Calmette-Guérin          2
tetanus                  2
Vitamin K                2
Guérin                   2
benzathine penicillin    1
penicillin               1
DTaP                     1
DPT                      1
Vaccine                  1
diphtheria pertussis     1
Name: count, dtype: int64

=== Top DISEASE entities (conditions) ===
text
tetanus                     28
Tetanus                     13
hepatitis B                  8
Anti-D                       5
left buttock                 5
varicella   

In [23]:
df_vaccination_lf_spans

Unnamed: 0,text,label,category,start,end,original_text,source,row_idx
0,history,VACCINATION_HISTORY,history,3,10,No history of recent vaccination,lf:lf_historical_vaccination,87135
1,rabies,TRAVEL_VACCINES,vaccine,11,17,Given anti rabies vaccine and immunoglobulin a...,lf:lf_travel_vaccines,132574
2,Tetanus,CHILDHOOD_VACCINES,vaccine,0,7,Tetanus vaccination with tetanus immunoglobuli...,lf:lf_childhood_vaccines,119386
3,tetanus,CHILDHOOD_VACCINES,vaccine,25,32,Tetanus vaccination with tetanus immunoglobuli...,lf:lf_childhood_vaccines,119386
4,Tetanus,TETANUS_VACCINE,vaccine,0,7,Tetanus vaccination with tetanus immunoglobuli...,lf:lf_tetanus_vaccination,119386
...,...,...,...,...,...,...,...,...
318,Tetanus,CHILDHOOD_VACCINES,vaccine,0,7,Tetanus vaccination administered during curren...,lf:lf_childhood_vaccines,68442
319,Tetanus,TETANUS_VACCINE,vaccine,0,7,Tetanus vaccination administered during curren...,lf:lf_tetanus_vaccination,68442
320,history,VACCINATION_HISTORY,history,3,10,No history of any immunizations in the recent ...,lf:lf_historical_vaccination,156722
321,influenza,FLU_VACCINE,vaccine,17,26,"Lacked a current influenza vaccination, all ot...",lf:lf_influenza_vaccination,74172


In [24]:
df_all_vaccination_entities

Unnamed: 0,text,label,start,end,original_text,source,row_idx,category
0,Tetanus,DISEASE,0,7,Tetanus vaccination with tetanus immunoglobuli...,bc5cdr,119386,
1,tetanus,DISEASE,25,32,Tetanus vaccination with tetanus immunoglobuli...,bc5cdr,119386,
2,hyposplenism,DISEASE,39,51,Vaccinated post-treatment for presumed hypospl...,bc5cdr,13774,
3,tetanus,DISEASE,14,21,No history of tetanus vaccination or tetanus i...,bc5cdr,157338,
4,tetanus infection,DISEASE,37,54,No history of tetanus vaccination or tetanus i...,bc5cdr,157338,
...,...,...,...,...,...,...,...,...
810,Tetanus,CHILDHOOD_VACCINES,0,7,Tetanus vaccination administered during curren...,lf:lf_childhood_vaccines,68442,vaccine
811,Tetanus,TETANUS_VACCINE,0,7,Tetanus vaccination administered during curren...,lf:lf_tetanus_vaccination,68442,vaccine
812,history,VACCINATION_HISTORY,3,10,No history of any immunizations in the recent ...,lf:lf_historical_vaccination,156722,history
813,influenza,FLU_VACCINE,17,26,"Lacked a current influenza vaccination, all ot...",lf:lf_influenza_vaccination,74172,vaccine


##### Allergies

In [25]:
# Process allergies using BC5CDR model (recognizes DISEASE and CHEMICAL)
if __name__ == "__main__":
    # Run extraction for the 'allergies' column
    # BC5CDR model will identify allergens (often CHEMICAL entities)
    df_allergies_entities, allergies_summary, allergies_rules = run_medical_ner_extraction(
        df_medical, 
        text_column='allergies',
        model_name="en_ner_bc5cdr_md",  # This model recognizes DISEASE and CHEMICAL entities
        batch_size=300,
        id_column='idx' 
    )

    # Test the generated rules
    print("\nTesting generated labeling functions for allergies...")

    # Select a sample row from the original df_medical to test the rules
    if not df_medical.empty:
        sample_row = df_medical.iloc[0]

        for rule_name, rule_func in allergies_rules.items():
            try:
                test_result = rule_func(sample_row)
                print(f"{rule_name} applied to row 0: {test_result}")
            except KeyError as e:
                print(f"Error applying rule {rule_name}: {e}. Make sure the column 'allergies' exists in df_medical.")
    else:
        print("df_medical is empty, cannot test rules.")

    # Debug: Check extracted entities
    print("\n=== DEBUG: Check extracted allergy entities ===")
    print(f"Shape of df_allergies_entities: {df_allergies_entities.shape}")
    print(f"\nColumn names: {df_allergies_entities.columns.tolist()}")
    print(f"\nFirst few rows:")
    print(df_allergies_entities.head())

    # Check entity label distribution
    print("\n=== Entity Labels Found in Allergies ===")
    if not df_allergies_entities.empty:
        print(df_allergies_entities['label'].value_counts())

        # Let's look at the top entities
        print("\n=== Top Allergy-related Entities ===")
        allergy_entities = df_allergies_entities['text'].value_counts().head(30)
        print(allergy_entities)

    # Create specific labeling functions for allergy data
    def create_allergy_labeling_functions(entities_df):
        """Create specific labeling functions for allergies"""

        # Analyze entities to understand patterns
        top_entities = entities_df['text'].value_counts().head(100)
        print(f"\nAnalyzing allergy entities for patterns...")

        # Common allergen categories
        def lf_drug_allergy(row):
            """Detect drug/medication allergies"""
            text = str(row['allergies']).lower()
            drug_patterns = [
                'penicillin', 'amoxicillin', 'ampicillin', 'antibiotic',
                'sulfa', 'aspirin', 'nsaid', 'ibuprofen', 'morphine',
                'codeine', 'contrast', 'iodine', 'latex', 'adhesive'
            ]
            if any(drug in text for drug in drug_patterns):
                return 'DRUG_ALLERGY'
            return 'ABSTAIN'

        def lf_food_allergy(row):
            """Detect food allergies"""
            text = str(row['allergies']).lower()
            food_patterns = [
                'peanut', 'nut', 'shellfish', 'fish', 'milk', 'dairy',
                'egg', 'wheat', 'gluten', 'soy', 'sesame', 'food'
            ]
            if any(food in text for food in food_patterns):
                return 'FOOD_ALLERGY'
            return 'ABSTAIN'

        def lf_environmental_allergy(row):
            """Detect environmental allergies"""
            text = str(row['allergies']).lower()
            env_patterns = [
                'pollen', 'dust', 'mold', 'grass', 'tree', 'ragweed',
                'cat', 'dog', 'animal', 'dander', 'environmental'
            ]
            if any(env in text for env in env_patterns):
                return 'ENVIRONMENTAL_ALLERGY'
            return 'ABSTAIN'

        def lf_no_allergies(row):
            """Detect absence of allergies"""
            text = str(row['allergies']).lower()
            no_allergy_patterns = [
                'no known allergies', 'no allergies', 'nka', 'nkda',
                'no known drug allergies', 'denies allergies', 'none'
            ]
            if any(pattern in text for pattern in no_allergy_patterns):
                return 'NO_ALLERGIES'
            return 'ABSTAIN'

        def lf_allergy_severity(row):
            """Detect severe allergic reactions"""
            text = str(row['allergies']).lower()
            severity_patterns = [
                'anaphylaxis', 'anaphylactic', 'severe', 'life-threatening',
                'epipen', 'epinephrine', 'emergency'
            ]
            if any(pattern in text for pattern in severity_patterns):
                return 'SEVERE_ALLERGY'
            return 'ABSTAIN'

        def lf_allergy_reaction_type(row):
            """Detect specific reaction types"""
            text = str(row['allergies']).lower()
            reaction_patterns = [
                'rash', 'hives', 'swelling', 'itching', 'breathing',
                'wheezing', 'nausea', 'vomiting', 'throat'
            ]
            if any(reaction in text for reaction in reaction_patterns):
                return 'ALLERGIC_REACTION_DESCRIBED'
            return 'ABSTAIN'

        def lf_seasonal_allergy(row):
            """Detect seasonal allergies"""
            text = str(row['allergies']).lower()
            if any(season in text for season in ['seasonal', 'spring', 'fall', 'hay fever']):
                return 'SEASONAL_ALLERGY'
            return 'ABSTAIN'

        def lf_chemical_sensitivity(row):
            """Detect chemical sensitivities"""
            text = str(row['allergies']).lower()
            chemical_patterns = [
                'chemical', 'perfume', 'fragrance', 'smoke', 'detergent',
                'cleaning', 'formaldehyde'
            ]
            if any(chem in text for chem in chemical_patterns):
                return 'CHEMICAL_SENSITIVITY'
            return 'ABSTAIN'

        def lf_multiple_allergies(row):
            """Detect multiple allergies"""
            text = str(row['allergies']).lower()
            # Count commas or "and" as indicators of multiple allergies
            if (text.count(',') >= 2 or text.count(' and ') >= 2) and 'no known' not in text:
                return 'MULTIPLE_ALLERGIES'
            return 'ABSTAIN'

        def lf_allergy_testing(row):
            """Detect allergy testing mentions"""
            text = str(row['allergies']).lower()
            if any(test in text for test in ['tested', 'skin test', 'patch test', 'ige']):
                return 'ALLERGY_TESTED'
            return 'ABSTAIN'

        return [
            lf_drug_allergy,
            lf_food_allergy,
            lf_environmental_allergy,
            lf_no_allergies,
            lf_allergy_severity,
            lf_allergy_reaction_type,
            lf_seasonal_allergy,
            lf_chemical_sensitivity,
            lf_multiple_allergies,
            lf_allergy_testing
        ]

    # Apply the allergy-specific labeling functions
    allergy_lfs = create_allergy_labeling_functions(df_allergies_entities)

    print("\n=== Testing Allergy Labeling Functions ===")
    if not df_medical.empty and 'allergies' in df_medical.columns:
        # Test on multiple rows to see coverage
        test_rows = min(10, len(df_medical))

        for i in range(test_rows):
            row = df_medical.iloc[i]
            if pd.notna(row['allergies']):
                print(f"\nRow {i} allergies: {row['allergies'][:100]}...")
                for lf in allergy_lfs:
                    result = lf(row)
                    if result != 'ABSTAIN':
                        print(f"  {lf.__name__}: {result}")

    # Analyze coverage
    print("\n=== Allergy Labeling Function Coverage ===")
    total_non_na = df_medical['allergies'].notna().sum()
    for lf in allergy_lfs:
        labeled_count = sum(lf(row) != 'ABSTAIN' for _, row in df_medical.iterrows()
                           if pd.notna(row.get('allergies', '')))
        coverage = (labeled_count / total_non_na * 100) if total_non_na > 0 else 0
        print(f"{lf.__name__}: {coverage:.1f}% coverage ({labeled_count}/{total_non_na})")

    # Analyze specific patterns in allergies
    print("\n=== Common Allergy Patterns ===")
    if 'allergies' in df_medical.columns:
        # Count specific allergen mentions
        penicillin_mentions = df_medical['allergies'].str.contains('penicillin|Penicillin', na=False).sum()
        no_allergy_mentions = df_medical['allergies'].str.contains('no known|NKA|NKDA', na=False, regex=True).sum()
        food_allergy_mentions = df_medical['allergies'].str.contains('peanut|shellfish|milk|egg', na=False, regex=True).sum()

        print(f"Penicillin allergy mentions: {penicillin_mentions}")
        print(f"No known allergies mentions: {no_allergy_mentions}")
        print(f"Food allergy mentions: {food_allergy_mentions}")
        print(f"Total allergy records: {df_medical['allergies'].notna().sum()}")

    # Create a summary of allergen types found
    print("\n=== Allergen Type Summary ===")
    allergen_summary = {
        'drugs': [],
        'foods': [],
        'environmental': [],
        'other': []
    }

    # Categorize top entities
    for entity, count in df_allergies_entities['text'].value_counts().head(50).items():
        entity_lower = entity.lower()
        if any(drug in entity_lower for drug in ['cillin', 'mycin', 'zole', 'statin']):
            allergen_summary['drugs'].append(entity)
        elif any(food in entity_lower for food in ['nut', 'milk', 'egg', 'fish']):
            allergen_summary['foods'].append(entity)
        elif any(env in entity_lower for env in ['pollen', 'dust', 'grass']):
            allergen_summary['environmental'].append(entity)
        else:
            allergen_summary['other'].append(entity)

    for category, items in allergen_summary.items():
        if items:
            print(f"\n{category.upper()}: {items[:10]}")  # Show top 10 in each category
# Save result
df_allergies_entities.to_csv('allergies_entities_comprehensive.csv', index=False)

Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 29755 texts in 100 batches...
Using model: en_ner_bc5cdr_md for column: allergies


Processing batches:   2%|▏         | 2/100 [00:00<00:38,  2.54it/s]


Checkpoint saved at batch 0


Processing batches:  12%|█▏        | 12/100 [00:02<00:14,  5.97it/s]


Checkpoint saved at batch 3000


Processing batches:  22%|██▏       | 22/100 [00:04<00:12,  6.14it/s]


Checkpoint saved at batch 6000


Processing batches:  32%|███▏      | 32/100 [00:05<00:12,  5.65it/s]


Checkpoint saved at batch 9000


Processing batches:  41%|████      | 41/100 [00:07<00:10,  5.40it/s]


Checkpoint saved at batch 12000


Processing batches:  52%|█████▏    | 52/100 [00:09<00:08,  5.69it/s]


Checkpoint saved at batch 15000


Processing batches:  62%|██████▏   | 62/100 [00:11<00:07,  5.43it/s]


Checkpoint saved at batch 18000


Processing batches:  72%|███████▏  | 72/100 [00:12<00:05,  5.52it/s]


Checkpoint saved at batch 21000


Processing batches:  82%|████████▏ | 82/100 [00:14<00:03,  5.80it/s]


Checkpoint saved at batch 24000


Processing batches:  92%|█████████▏| 92/100 [00:16<00:01,  5.77it/s]


Checkpoint saved at batch 27000


Processing batches: 100%|██████████| 100/100 [00:17<00:00,  5.74it/s]


Found 32 entities appearing >= 5 times

Testing generated labeling functions for allergies...
lf_disease_allergies applied to row 0: ABSTAIN
lf_chemical_allergies applied to row 0: ABSTAIN

=== DEBUG: Check extracted allergy entities ===
Shape of df_allergies_entities: (934, 7)

Column names: ['text', 'label', 'start', 'end', 'original_text', 'row_idx', 'source_column']

First few rows:
             text     label  start  end  \
0       allergies   DISEASE      9   18   
1  drug allergies   DISEASE      9   23   
2         allergy   DISEASE     18   25   
3         Allergy   DISEASE      0    7   
4     amoxicillin  CHEMICAL     11   22   

                           original_text  row_idx source_column  
0                     No known allergies    32488     allergies  
1                No known drug allergies    77061     allergies  
2  No sensitivity or allergy to any drug   149806     allergies  
3                 Allergy to amoxicillin    83662     allergies  
4                 All

In [26]:
def extract_allergy_entities_custom(df_medical, text_column='allergies', id_column='idx'):
    """Extract allergy-specific entities beyond what BC5CDR captures"""
    
    # Build the id series once
    if id_column is not None and id_column in df_medical.columns:
        ids_series = df_medical[id_column].where(pd.notna(df_medical[id_column]), df_medical.index)
    else:
        ids_series = df_medical.index
    
    texts = df_medical[text_column].fillna('').astype(str)
    
    # Allergy patterns
    allergy_patterns = {
        'drug_allergens': [
            'penicillin', 'amoxicillin', 'ampicillin', 'antibiotic', 'cephalosporin',
            'sulfa', 'sulfamethoxazole', 'bactrim', 'aspirin', 'nsaid', 'ibuprofen',
            'morphine', 'codeine', 'opioid', 'contrast', 'iodine', 'latex',
            'adhesive', 'lidocaine', 'vancomycin', 'tetracycline'
        ],
        'food_allergens': [
            'peanut', 'tree nut', 'almond', 'cashew', 'walnut', 'shellfish',
            'shrimp', 'lobster', 'fish', 'milk', 'dairy', 'lactose',
            'egg', 'wheat', 'gluten', 'soy', 'sesame', 'corn'
        ],
        'environmental_allergens': [
            'pollen', 'dust', 'dust mite', 'mold', 'grass', 'tree pollen',
            'ragweed', 'cat', 'dog', 'animal dander', 'pet dander',
            'bee sting', 'wasp', 'insect'
        ],
        'other_allergens': [
            'perfume', 'fragrance', 'smoke', 'detergent', 'nickel',
            'formaldehyde', 'rubber', 'wool', 'chemical'
        ]
    }
    
    # Reaction patterns
    reaction_patterns = [
        'rash', 'hives', 'urticaria', 'swelling', 'angioedema', 'itching',
        'pruritus', 'breathing difficulty', 'wheezing', 'anaphylaxis',
        'nausea', 'vomiting', 'throat swelling', 'flushing'
    ]
    
    # Status patterns
    status_patterns = {
        'no_allergies': ['no known allergies', 'no allergies', 'nka', 'nkda', 
                        'no known drug allergies', 'denies allergies', 'none'],
        'severe': ['anaphylaxis', 'anaphylactic', 'severe', 'life-threatening',
                  'epipen', 'epinephrine'],
        'tested': ['tested', 'skin test', 'patch test', 'ige', 'allergy testing']
    }
    
    allergy_entities = []
    
    for (df_index, text_original), idx in zip(texts.items(), ids_series):
        if text_original:
            text = text_original.lower()
            
            # Extract allergens
            for category, allergens in allergy_patterns.items():
                for allergen in allergens:
                    pattern = r'\b' + re.escape(allergen) + r'\b'
                    for match in re.finditer(pattern, text):
                        allergy_entities.append({
                            'text': text_original[match.start():match.end()],
                            'label': 'ALLERGEN',
                            'category': category,
                            'start': match.start(),
                            'end': match.end(),
                            'original_text': text_original,
                            'source': 'custom_allergy_extraction',
                            'row_idx': idx
                        })
            
            # Extract reactions
            for reaction in reaction_patterns:
                pattern = r'\b' + re.escape(reaction) + r'\b'
                for match in re.finditer(pattern, text):
                    allergy_entities.append({
                        'text': text_original[match.start():match.end()],
                        'label': 'ALLERGIC_REACTION',
                        'category': 'reaction',
                        'start': match.start(),
                        'end': match.end(),
                        'original_text': text_original,
                        'source': 'custom_allergy_extraction',
                        'row_idx': idx
                    })
            
            # Extract allergy status
            for status_type, terms in status_patterns.items():
                for term in terms:
                    if term in text:
                        start_idx = text.find(term)
                        allergy_entities.append({
                            'text': text_original[start_idx:start_idx + len(term)],
                            'label': 'ALLERGY_STATUS',
                            'category': status_type,
                            'start': start_idx,
                            'end': start_idx + len(term),
                            'original_text': text_original,
                            'source': 'custom_allergy_extraction',
                            'row_idx': idx
                        })
    
    return pd.DataFrame(allergy_entities)


# Process allergies using BC5CDR model (recognizes DISEASE and CHEMICAL)
if __name__ == "__main__":
    # Run extraction for the 'allergies' column
    df_allergies_entities, allergies_summary, allergies_rules = run_medical_ner_extraction(
        df_medical, 
        text_column='allergies',
        model_name="en_ner_bc5cdr_md",
        batch_size=300,
        id_column='idx' 
    )

    # Test the generated rules
    print("\nTesting BC5CDR-generated labeling functions for allergies...")
    if not df_medical.empty:
        sample_row = df_medical.iloc[0]
        for rule_name, rule_func in allergies_rules.items():
            try:
                test_result = rule_func(sample_row)
                if test_result != 'ABSTAIN':
                    print(f"{rule_name} applied to row 0: {test_result}")
            except KeyError as e:
                print(f"Error applying rule {rule_name}: {e}. Make sure the column 'allergies' exists in df_medical.")
    else:
        print("df_medical is empty, cannot test rules.")

    # Check BC5CDR results
    print("\n=== BC5CDR Extraction Results ===")
    print(f"Shape of df_allergies_entities: {df_allergies_entities.shape}")
    if not df_allergies_entities.empty:
        print("\nEntity Labels Found:")
        print(df_allergies_entities['label'].value_counts())
        print("\nTop Entities:")
        print(df_allergies_entities['text'].value_counts().head(30))

    # Extract custom allergy entities
    print("\n=== CUSTOM ALLERGY ENTITY EXTRACTION ===")
    df_custom_allergies = extract_allergy_entities_custom(
        df_medical,
        text_column='allergies',
        id_column='idx'
    )
    
    print(f"\nCustom extraction found {len(df_custom_allergies)} allergy entities")
    if not df_custom_allergies.empty:
        print("\nCustom Entity Distribution:")
        print(df_custom_allergies['label'].value_counts())
        print("\nCustom Entity Categories:")
        print(df_custom_allergies['category'].value_counts())

    # Create specific labeling functions for allergy data (spanified)
    def create_allergy_labeling_functions(entities_df):
        """Create spanified labeling functions for allergies"""
        COL = 'allergies'

        def lf_drug_allergy(row):
            """Detect drug/medication allergies"""
            text = str(row.get(COL, ''))
            drug_patterns = [
                'penicillin', 'amoxicillin', 'ampicillin', 'antibiotic',
                'sulfa', 'aspirin', 'nsaid', 'ibuprofen', 'morphine',
                'codeine', 'contrast', 'iodine', 'latex', 'adhesive'
            ]
            hit = _first_hit(text, drug_patterns)
            return {'label': 'DRUG_ALLERGY', 'column': COL, 'match': hit, 'category': 'drug'} if hit else {'label': 'ABSTAIN'}

        def lf_food_allergy(row):
            """Detect food allergies"""
            text = str(row.get(COL, ''))
            food_patterns = [
                'peanut', 'nut', 'shellfish', 'fish', 'milk', 'dairy',
                'egg', 'wheat', 'gluten', 'soy', 'sesame', 'food'
            ]
            hit = _first_hit(text, food_patterns)
            return {'label': 'FOOD_ALLERGY', 'column': COL, 'match': hit, 'category': 'food'} if hit else {'label': 'ABSTAIN'}

        def lf_environmental_allergy(row):
            """Detect environmental allergies"""
            text = str(row.get(COL, ''))
            env_patterns = [
                'pollen', 'dust', 'mold', 'grass', 'tree', 'ragweed',
                'cat', 'dog', 'animal', 'dander', 'environmental'
            ]
            hit = _first_hit(text, env_patterns)
            return {'label': 'ENVIRONMENTAL_ALLERGY', 'column': COL, 'match': hit, 'category': 'environmental'} if hit else {'label': 'ABSTAIN'}

        def lf_no_allergies(row):
            """Detect absence of allergies"""
            text = str(row.get(COL, ''))
            no_allergy_patterns = [
                'no known allergies', 'no allergies', 'nka', 'nkda',
                'no known drug allergies', 'denies allergies', 'none'
            ]
            hit = _first_hit(text, no_allergy_patterns)
            return {'label': 'NO_ALLERGIES', 'column': COL, 'match': hit, 'category': 'status'} if hit else {'label': 'ABSTAIN'}

        def lf_allergy_severity(row):
            """Detect severe allergic reactions"""
            text = str(row.get(COL, ''))
            severity_patterns = [
                'anaphylaxis', 'anaphylactic', 'severe', 'life-threatening',
                'epipen', 'epinephrine', 'emergency'
            ]
            hit = _first_hit(text, severity_patterns)
            return {'label': 'SEVERE_ALLERGY', 'column': COL, 'match': hit, 'category': 'severity'} if hit else {'label': 'ABSTAIN'}

        def lf_allergy_reaction_type(row):
            """Detect specific reaction types"""
            text = str(row.get(COL, ''))
            reaction_patterns = [
                'rash', 'hives', 'swelling', 'itching', 'breathing',
                'wheezing', 'nausea', 'vomiting', 'throat'
            ]
            hit = _first_hit(text, reaction_patterns)
            return {'label': 'ALLERGIC_REACTION_DESCRIBED', 'column': COL, 'match': hit, 'category': 'reaction'} if hit else {'label': 'ABSTAIN'}

        def lf_seasonal_allergy(row):
            """Detect seasonal allergies"""
            text = str(row.get(COL, ''))
            seasonal_terms = ['seasonal', 'spring', 'fall', 'hay fever']
            hit = _first_hit(text, seasonal_terms)
            return {'label': 'SEASONAL_ALLERGY', 'column': COL, 'match': hit, 'category': 'seasonal'} if hit else {'label': 'ABSTAIN'}

        def lf_chemical_sensitivity(row):
            """Detect chemical sensitivities"""
            text = str(row.get(COL, ''))
            chemical_patterns = [
                'chemical', 'perfume', 'fragrance', 'smoke', 'detergent',
                'cleaning', 'formaldehyde'
            ]
            hit = _first_hit(text, chemical_patterns)
            return {'label': 'CHEMICAL_SENSITIVITY', 'column': COL, 'match': hit, 'category': 'chemical'} if hit else {'label': 'ABSTAIN'}

        def lf_multiple_allergies(row):
            """Detect multiple allergies"""
            text = str(row.get(COL, ''))
            if (text.count(',') >= 2 or text.count(' and ') >= 2) and 'no known' not in text:
                # For multiple allergies, just return the first allergen found
                allergens = ['penicillin', 'peanut', 'shellfish', 'dust', 'pollen']
                hit = _first_hit(text, allergens)
                return {'label': 'MULTIPLE_ALLERGIES', 'column': COL, 'match': hit or 'multiple', 'category': 'multiple'} if hit or text else {'label': 'ABSTAIN'}
            return {'label': 'ABSTAIN'}

        def lf_allergy_testing(row):
            """Detect allergy testing mentions"""
            text = str(row.get(COL, ''))
            test_terms = ['tested', 'skin test', 'patch test', 'ige']
            hit = _first_hit(text, test_terms)
            return {'label': 'ALLERGY_TESTED', 'column': COL, 'match': hit, 'category': 'testing'} if hit else {'label': 'ABSTAIN'}

        return [
            lf_drug_allergy, lf_food_allergy, lf_environmental_allergy,
            lf_no_allergies, lf_allergy_severity, lf_allergy_reaction_type,
            lf_seasonal_allergy, lf_chemical_sensitivity, lf_multiple_allergies,
            lf_allergy_testing
        ]

    # Apply the allergy-specific labeling functions
    allergy_lfs = create_allergy_labeling_functions(df_allergies_entities)
    
    # Materialize spans from labeling functions
    df_allergy_lf_spans = materialize_lf_spans(df_medical, allergy_lfs, id_column='idx')
    print(f"\n=== LF-Generated Allergy Spans ===")
    print(f"Total LF-generated spans: {len(df_allergy_lf_spans)}")
    if not df_allergy_lf_spans.empty:
        print("\nLF span distribution:")
        print(df_allergy_lf_spans['label'].value_counts())

    # Row-level coverage analysis
    print("\n=== Row-Level Coverage Analysis ===")
    total_non_na = df_medical['allergies'].notna().sum()
    for lf in allergy_lfs:
        labeled_count = sum(lf(row).get('label') != 'ABSTAIN' for _, row in df_medical.iterrows()
                           if pd.notna(row.get('allergies', '')))
        coverage = (labeled_count / total_non_na * 100) if total_non_na > 0 else 0
        print(f"{lf.__name__}: {coverage:.1f}% coverage ({labeled_count}/{total_non_na})")

    # Combine all sources
    print("\n=== COMBINING ALL ALLERGY ENTITY SOURCES ===")
    
    # Ensure common columns
    common_cols = ['text', 'label', 'start', 'end', 'original_text', 'source', 'row_idx', 'category']
    for col in common_cols:
        if col not in df_allergies_entities.columns:
            df_allergies_entities[col] = 'bc5cdr' if col == 'source' else None
        if col not in df_custom_allergies.columns:
            df_custom_allergies[col] = 'custom_extraction' if col == 'source' else None
        if col not in df_allergy_lf_spans.columns:
            df_allergy_lf_spans[col] = None
    
    df_all_allergy_entities = pd.concat([
        df_allergies_entities[common_cols],
        df_custom_allergies[common_cols],
        df_allergy_lf_spans[common_cols]
    ], ignore_index=True)
    
    print(f"\nTotal allergy entities: {len(df_all_allergy_entities)}")
    print(f"  - BC5CDR: {len(df_allergies_entities)}")
    print(f"  - Custom extraction: {len(df_custom_allergies)}")
    print(f"  - LF-generated: {len(df_allergy_lf_spans)}")
    
    # Analysis
    if not df_all_allergy_entities.empty:
        print("\nCombined entity distribution by label:")
        print(df_all_allergy_entities['label'].value_counts())
        print("\nCombined entity distribution by source:")
        print(df_all_allergy_entities['source'].value_counts())
        
        # Allergen type analysis
        allergen_entities = df_all_allergy_entities[
            df_all_allergy_entities['label'].isin(['ALLERGEN', 'DRUG_ALLERGY', 'FOOD_ALLERGY', 'ENVIRONMENTAL_ALLERGY'])
        ]
        if not allergen_entities.empty:
            print("\n=== Top Allergens by Type ===")
            for cat in ['drug_allergens', 'food_allergens', 'environmental_allergens']:
                cat_entities = allergen_entities[allergen_entities['category'] == cat]
                if not cat_entities.empty:
                    print(f"\n{cat.upper()}:")
                    print(cat_entities['text'].value_counts().head(10))

    # Add row-level allergy phenotypes
    print("\n=== Adding Row-Level Allergy Phenotypes ===")
    for lf in allergy_lfs:
        phenotype_name = f"has_{lf.__name__.replace('lf_', '')}"
        df_medical[phenotype_name] = df_medical.apply(
            lambda row: lf(row).get('label') != 'ABSTAIN',
            axis=1
        )

    # Save results
    df_allergies_entities.to_csv('allergies_entities_comprehensive.csv', index=False)
    df_custom_allergies.to_csv('allergies_entities_custom.csv', index=False)
    df_allergy_lf_spans.to_csv('allergies_lf_generated_spans.csv', index=False)
    
    print("\n\nAllergy entity extraction complete!")
    print("All results saved to CSV files.")

Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 29755 texts in 100 batches...
Using model: en_ner_bc5cdr_md for column: allergies


Processing batches:   2%|▏         | 2/100 [00:00<00:38,  2.52it/s]


Checkpoint saved at batch 0


Processing batches:  12%|█▏        | 12/100 [00:02<00:14,  6.08it/s]


Checkpoint saved at batch 3000


Processing batches:  22%|██▏       | 22/100 [00:04<00:12,  6.28it/s]


Checkpoint saved at batch 6000


Processing batches:  32%|███▏      | 32/100 [00:05<00:11,  6.08it/s]


Checkpoint saved at batch 9000


Processing batches:  42%|████▏     | 42/100 [00:07<00:09,  5.85it/s]


Checkpoint saved at batch 12000


Processing batches:  52%|█████▏    | 52/100 [00:08<00:08,  5.70it/s]


Checkpoint saved at batch 15000


Processing batches:  62%|██████▏   | 62/100 [00:10<00:07,  5.20it/s]


Checkpoint saved at batch 18000


Processing batches:  72%|███████▏  | 72/100 [00:12<00:05,  5.45it/s]


Checkpoint saved at batch 21000


Processing batches:  82%|████████▏ | 82/100 [00:14<00:03,  5.75it/s]


Checkpoint saved at batch 24000


Processing batches:  92%|█████████▏| 92/100 [00:15<00:01,  5.91it/s]


Checkpoint saved at batch 27000


Processing batches: 100%|██████████| 100/100 [00:17<00:00,  5.87it/s]


Found 32 entities appearing >= 5 times

Testing BC5CDR-generated labeling functions for allergies...

=== BC5CDR Extraction Results ===
Shape of df_allergies_entities: (934, 7)

Entity Labels Found:
label
DISEASE     739
CHEMICAL    195
Name: count, dtype: int64

Top Entities:
text
allergies                        291
drug allergies                    92
allergy                           73
Allergy                           24
penicillin                        24
allergic reaction                 18
Allergic                          16
Penicillin                        16
allergic                          14
Allergic rhinitis                 12
hypersensitivity                  12
Seasonal allergies                11
Allergic reaction                 10
Penicillin allergy                 9
allergic rhinitis                  9
rash                               9
anaphylaxis                        7
vancomycin                         7
penicillin allergy                 7
paracetamol   

In [27]:
df_allergy_lf_spans

Unnamed: 0,text,label,category,start,end,original_text,source,row_idx
0,No known allergies,NO_ALLERGIES,status,0,18,No known allergies,lf:lf_no_allergies,32488
1,No known drug allergies,NO_ALLERGIES,status,0,23,No known drug allergies,lf:lf_no_allergies,77061
2,amoxicillin,DRUG_ALLERGY,drug,11,22,Allergy to amoxicillin,lf:lf_drug_allergy,83662
3,contrast,DRUG_ALLERGY,drug,44,52,"numerous drugs and diagnostic agents (e.g., co...",lf:lf_drug_allergy,67018
4,aspirin,DRUG_ALLERGY,drug,23,30,Unknown if allergic to aspirin,lf:lf_drug_allergy,91449
...,...,...,...,...,...,...,...,...
547,dust,ENVIRONMENTAL_ALLERGY,environmental,5,9,Mild dust allergy,lf:lf_environmental_allergy,157940
548,No known drug allergies,NO_ALLERGIES,status,0,23,No known drug allergies,lf:lf_no_allergies,135610
549,Severe,SEVERE_ALLERGY,severity,0,6,Severe allergic reaction to docetaxel chemothe...,lf:lf_allergy_severity,40349
550,No known drug allergies,NO_ALLERGIES,status,0,23,No known drug allergies,lf:lf_no_allergies,135761


In [28]:
df_allergies_entities

Unnamed: 0,text,label,start,end,original_text,row_idx,source_column,source,category
0,allergies,DISEASE,9,18,No known allergies,32488,allergies,bc5cdr,
1,drug allergies,DISEASE,9,23,No known drug allergies,77061,allergies,bc5cdr,
2,allergy,DISEASE,18,25,No sensitivity or allergy to any drug,149806,allergies,bc5cdr,
3,Allergy,DISEASE,0,7,Allergy to amoxicillin,83662,allergies,bc5cdr,
4,amoxicillin,CHEMICAL,11,22,Allergy to amoxicillin,83662,allergies,bc5cdr,
...,...,...,...,...,...,...,...,...,...
929,docetaxel,CHEMICAL,28,37,Severe allergic reaction to docetaxel chemothe...,40349,allergies,bc5cdr,
930,drug allergies,DISEASE,9,23,No known drug allergies,135761,allergies,bc5cdr,
931,Allergic,DISEASE,0,8,Allergic to penicillin,138116,allergies,bc5cdr,
932,penicillin,CHEMICAL,12,22,Allergic to penicillin,138116,allergies,bc5cdr,


##### Drug Usage

In [29]:
def extract_drug_usage_entities_custom(df_medical, text_column='drug usage', id_column='idx'):
    """Extract drug usage-specific entities beyond what BC5CDR captures"""
    
    # Build the id series once
    if id_column is not None and id_column in df_medical.columns:
        ids_series = df_medical[id_column].where(pd.notna(df_medical[id_column]), df_medical.index)
    else:
        ids_series = df_medical.index
    
    texts = df_medical[text_column].fillna('').astype(str)
    
    # Drug/substance patterns
    substance_patterns = {
        'alcohol': [
            'alcohol', 'drinking', 'beer', 'wine', 'liquor', 'spirits', 
            'ethanol', 'etoh', 'alcoholism', 'alcoholic'
        ],
        'tobacco': [
            'tobacco', 'smoking', 'cigarette', 'cigarettes', 'nicotine', 'pack',
            'cigar', 'chewing tobacco', 'vaping', 'e-cigarette', 'vape'
        ],
        'cannabis': [
            'cannabis', 'marijuana', 'thc', 'weed', 'pot', 'hemp',
            'mary jane', 'ganja', 'hash', 'cannabinoid', 'cbd'
        ],
        'opioids': [
            'opioid', 'opiate', 'heroin', 'morphine', 'oxycodone', 'hydrocodone',
            'fentanyl', 'codeine', 'tramadol', 'methadone', 'percocet', 
            'vicodin', 'oxycontin', 'dilaudid', 'demerol'
        ],
        'stimulants': [
            'cocaine', 'crack', 'amphetamine', 'methamphetamine', 'meth',
            'speed', 'crystal', 'adderall', 'ritalin', 'mdma', 'ecstasy',
            'molly', 'crank'
        ],
        'benzodiazepines': [
            'benzodiazepine', 'xanax', 'alprazolam', 'valium', 'diazepam',
            'ativan', 'lorazepam', 'klonopin', 'clonazepam'
        ],
        'hallucinogens': [
            'lsd', 'acid', 'mushrooms', 'psilocybin', 'pcp', 'ketamine',
            'dmt', 'mescaline', 'peyote'
        ]
    }
    
    # Usage patterns
    usage_patterns = [
        r'\b(\d+)\s*(?:pack|packs)\s*(?:per|/)\s*(?:day|week|year)',
        r'\b(\d+)\s*(?:drink|drinks|beer|beers)\s*(?:per|/)\s*(?:day|week)',
        r'\b(\d+)\s*years?\s*(?:of\s*)?(?:use|usage|history)',
        r'(?:daily|weekly|occasional|social|heavy|moderate|light)\s*(?:use|usage|user)',
        r'(?:former|past|current|active|recovering)\s*(?:user|addict|alcoholic)'
    ]
    
    # Status patterns
    status_patterns = {
        'denial': ['no drug', 'denies', 'denied', 'no history', 'no illicit',
                   'no substance', 'no recreational', 'never used', 'none'],
        'recovery': ['rehab', 'recovery', 'aa', 'na', 'sober', 'clean',
                     'abstinent', 'in recovery', 'treatment'],
        'active': ['current', 'active', 'ongoing', 'continues', 'daily',
                   'regular', 'frequent', 'occasional']
    }
    
    drug_usage_entities = []
    
    for (df_index, text_original), idx in zip(texts.items(), ids_series):
        if text_original:
            text = text_original.lower()
            
            # Extract substances
            for category, substances in substance_patterns.items():
                for substance in substances:
                    pattern = r'\b' + re.escape(substance) + r'\b'
                    for match in re.finditer(pattern, text):
                        drug_usage_entities.append({
                            'text': text_original[match.start():match.end()],
                            'label': 'SUBSTANCE',
                            'category': category,
                            'start': match.start(),
                            'end': match.end(),
                            'original_text': text_original,
                            'source': 'custom_drug_usage_extraction',
                            'row_idx': idx
                        })
            
            # Extract usage patterns
            for pattern in usage_patterns:
                for match in re.finditer(pattern, text, re.IGNORECASE):
                    drug_usage_entities.append({
                        'text': text_original[match.start():match.end()],
                        'label': 'USAGE_PATTERN',
                        'category': 'pattern',
                        'start': match.start(),
                        'end': match.end(),
                        'original_text': text_original,
                        'source': 'custom_drug_usage_extraction',
                        'row_idx': idx
                    })
            
            # Extract status
            for status_type, terms in status_patterns.items():
                for term in terms:
                    if term in text:
                        start_idx = text.find(term)
                        drug_usage_entities.append({
                            'text': text_original[start_idx:start_idx + len(term)],
                            'label': 'USAGE_STATUS',
                            'category': status_type,
                            'start': start_idx,
                            'end': start_idx + len(term),
                            'original_text': text_original,
                            'source': 'custom_drug_usage_extraction',
                            'row_idx': idx
                        })
    
    return pd.DataFrame(drug_usage_entities)


# Process drug usage using BC5CDR model (recognizes DISEASE and CHEMICAL)
if __name__ == "__main__":
    # Run extraction for the 'drug usage' column
    df_drug_usage_entities, drug_usage_summary, drug_usage_rules = run_medical_ner_extraction(
        df_medical,  
        text_column='drug usage',
        model_name="en_ner_bc5cdr_md",
        batch_size=300, 
        id_column='idx'
    )

    # Test the generated rules
    print("\nTesting BC5CDR-generated labeling functions for drug usage...")
    if not df_medical.empty:
        sample_row = df_medical.iloc[0]
        for rule_name, rule_func in drug_usage_rules.items():
            try:
                test_result = rule_func(sample_row)
                if test_result != 'ABSTAIN':
                    print(f"{rule_name} applied to row 0: {test_result}")
            except KeyError as e:
                print(f"Error applying rule {rule_name}: {e}. Make sure the column 'drug usage' exists in df_medical.")
    else:
        print("df_medical is empty, cannot test rules.")

    # Check BC5CDR results
    print("\n=== BC5CDR Extraction Results ===")
    print(f"Shape of df_drug_usage_entities: {df_drug_usage_entities.shape}")
    if not df_drug_usage_entities.empty:
        print("\nEntity Labels Found:")
        print(df_drug_usage_entities['label'].value_counts())
        print("\nTop Entities:")
        print(df_drug_usage_entities['text'].value_counts().head(30))

    # Extract custom drug usage entities
    print("\n=== CUSTOM DRUG USAGE ENTITY EXTRACTION ===")
    df_custom_drug_usage = extract_drug_usage_entities_custom(
        df_medical,
        text_column='drug usage',
        id_column='idx'
    )
    
    print(f"\nCustom extraction found {len(df_custom_drug_usage)} drug usage entities")
    if not df_custom_drug_usage.empty:
        print("\nCustom Entity Distribution:")
        print(df_custom_drug_usage['label'].value_counts())
        print("\nCustom Entity Categories:")
        print(df_custom_drug_usage['category'].value_counts())

    # Create specific labeling functions for drug usage data
    def create_drug_usage_labeling_functions(entities_df: pd.DataFrame):
        """Spanified drug usage LFs."""
        COL = 'drug usage'

        # Note: _first_hit is now defined globally

        def lf_no_drug_use(row):
            text = str(row.get(COL, ''))
            terms = ['no drug', 'denies', 'denied', 'no history', 'no illicit',
                    'no substance', 'no recreational', 'never', 'none',
                    'no personal history', 'negative']
            hit = _first_hit(text, terms)
            return {'label': 'NO_DRUG_USE', 'column': COL, 'match': hit, 'category': 'status'} if hit else {'label': 'ABSTAIN'}

        def lf_alcohol_use(row):
            text = str(row.get(COL, ''))
            terms = ['alcohol', 'drinking', 'beer', 'wine', 'liquor', 'spirits', 'ethanol', 'etoh', 'drinks per', 'social drinking']
            hit = _first_hit(text, terms)
            return {'label': 'ALCOHOL_USE', 'column': COL, 'match': hit, 'category': 'alcohol'} if hit else {'label': 'ABSTAIN'}

        def lf_tobacco_use(row):
            text = str(row.get(COL, ''))
            terms = ['tobacco', 'smoking', 'cigarette', 'nicotine', 'pack', 'cigar', 'chewing tobacco', 'vaping', 'e-cigarette']
            hit = _first_hit(text, terms)
            return {'label': 'TOBACCO_USE', 'column': COL, 'match': hit, 'category': 'tobacco'} if hit else {'label': 'ABSTAIN'}

        def lf_cannabis_use(row):
            text = str(row.get(COL, ''))
            terms = ['cannabis', 'marijuana', 'thc', 'weed', 'pot', 'hemp', 'mary jane', 'ganja', 'hash', 'cannabinoid']
            hit = _first_hit(text, terms)
            return {'label': 'CANNABIS_USE', 'column': COL, 'match': hit, 'category': 'cannabis'} if hit else {'label': 'ABSTAIN'}

        def lf_opioid_use(row):
            text = str(row.get(COL, ''))
            terms = ['opioid', 'opiate', 'heroin', 'morphine', 'oxycodone', 'hydrocodone', 'fentanyl', 'codeine', 'tramadol', 'methadone',
                    'percocet', 'vicodin', 'oxycontin']
            hit = _first_hit(text, terms)
            return {'label': 'OPIOID_USE', 'column': COL, 'match': hit, 'category': 'opioid'} if hit else {'label': 'ABSTAIN'}

        def lf_stimulant_use(row):
            text = str(row.get(COL, ''))
            terms = ['cocaine', 'crack', 'amphetamine', 'methamphetamine', 'meth', 'speed', 'crystal', 'adderall', 'ritalin', 'mdma', 'ecstasy']
            hit = _first_hit(text, terms)
            return {'label': 'STIMULANT_USE', 'column': COL, 'match': hit, 'category': 'stimulant'} if hit else {'label': 'ABSTAIN'}

        def lf_iv_drug_use(row):
            text = str(row.get(COL, ''))
            terms = ['iv drug', 'intravenous', 'injection', 'needle', 'inject', 'ivdu', 'shooting up']
            hit = _first_hit(text, terms)
            return {'label': 'IV_DRUG_USE', 'column': COL, 'match': hit, 'category': 'route'} if hit else {'label': 'ABSTAIN'}

        def lf_prescription_abuse(row):
            text = str(row.get(COL, ''))
            if ('prescription' in text.lower() or 'prescribed' in text.lower()) and \
               any(k in text.lower() for k in ['abuse', 'misuse', 'dependency', 'addiction']):
                return {'label': 'PRESCRIPTION_ABUSE', 'column': COL, 'match': 'prescription', 'category': 'abuse'}
            return {'label': 'ABSTAIN'}

        def lf_polysubstance_use(row):
            text = str(row.get(COL, '')).lower()
            subs = ['alcohol', 'tobacco', 'cannabis', 'cocaine', 'heroin', 'meth']
            found_subs = [s for s in subs if s in text]
            if len(found_subs) >= 2:
                return {'label': 'POLYSUBSTANCE_USE', 'column': COL, 'match': found_subs[0], 'category': 'multiple'}
            return {'label': 'ABSTAIN'}

        def lf_past_drug_use(row):
            text = str(row.get(COL, ''))
            past_terms = ['former', 'past', 'history of', 'previously', 'quit', 'stopped', 'used to', 'years ago', 'in recovery', 'sober']
            if any(t in text.lower() for t in past_terms) and not any(t in text.lower() for t in ['current', 'active', 'ongoing']):
                hit = _first_hit(text, past_terms)
                return {'label': 'PAST_DRUG_USE', 'column': COL, 'match': hit, 'category': 'temporal'}
            return {'label': 'ABSTAIN'}

        def lf_current_drug_use(row):
            text = str(row.get(COL, ''))
            terms = ['current', 'active', 'ongoing', 'continues', 'daily', 'regular', 'frequent', 'occasional']
            hit = _first_hit(text, terms)
            return {'label': 'CURRENT_DRUG_USE', 'column': COL, 'match': hit, 'category': 'temporal'} if hit else {'label': 'ABSTAIN'}

        def lf_drug_treatment(row):
            text = str(row.get(COL, ''))
            terms = ['rehab', 'treatment', 'recovery', 'aa', 'na', 'methadone clinic', 'suboxone', 'detox', 'counseling', 'therapy']
            hit = _first_hit(text, terms)
            return {'label': 'DRUG_TREATMENT', 'column': COL, 'match': hit, 'category': 'treatment'} if hit else {'label': 'ABSTAIN'}

        def lf_drug_screen_result(row):
            text = str(row.get(COL, ''))
            terms = ['drug screen', 'urine test', 'tested positive', 'tested negative']
            hit = _first_hit(text, terms)
            return {'label': 'DRUG_SCREEN_MENTIONED', 'column': COL, 'match': hit, 'category': 'testing'} if hit else {'label': 'ABSTAIN'}

        return [
            lf_no_drug_use, lf_alcohol_use, lf_tobacco_use, lf_cannabis_use, lf_opioid_use,
            lf_stimulant_use, lf_iv_drug_use, lf_prescription_abuse, lf_polysubstance_use,
            lf_past_drug_use, lf_current_drug_use, lf_drug_treatment, lf_drug_screen_result
        ]

    # Apply the drug usage-specific labeling functions
    drug_usage_lfs = create_drug_usage_labeling_functions(df_drug_usage_entities)
    
    # Materialize spans from labeling functions
    df_drug_usage_lf_spans = materialize_lf_spans(df_medical, drug_usage_lfs, id_column='idx')
    print(f"\n=== LF-Generated Drug Usage Spans ===")
    print(f"Total LF-generated spans: {len(df_drug_usage_lf_spans)}")
    if not df_drug_usage_lf_spans.empty:
        print("\nLF span distribution:")
        print(df_drug_usage_lf_spans['label'].value_counts())

    # Row-level coverage analysis
    print("\n=== Row-Level Coverage Analysis ===")
    total_non_na = df_medical['drug usage'].notna().sum() if 'drug usage' in df_medical.columns else 0
    if total_non_na > 0:
        for lf in drug_usage_lfs:
            labeled_count = sum(lf(row).get('label') != 'ABSTAIN' for _, row in df_medical.iterrows()
                               if pd.notna(row.get('drug usage', '')))
            coverage = (labeled_count / total_non_na * 100)
            print(f"{lf.__name__}: {coverage:.1f}% coverage ({labeled_count}/{total_non_na})")

    # Combine all sources
    print("\n=== COMBINING ALL DRUG USAGE ENTITY SOURCES ===")
    
    _d_cols = ['text','label','start','end','original_text','source','row_idx','category']
    for col in _d_cols:
        if col not in df_drug_usage_entities.columns:
            df_drug_usage_entities[col] = 'bc5cdr' if col=='source' else None
        if col not in df_custom_drug_usage.columns:
            df_custom_drug_usage[col] = 'custom_extraction' if col=='source' else None
        if col not in df_drug_usage_lf_spans.columns:
            df_drug_usage_lf_spans[col] = None

    df_all_drug_usage_entities = pd.concat([
        df_drug_usage_entities[_d_cols],
        df_custom_drug_usage[_d_cols],
        df_drug_usage_lf_spans[_d_cols]
    ], ignore_index=True)
    
    print(f"\nTotal drug usage entities: {len(df_all_drug_usage_entities)}")
    print(f"  - BC5CDR: {len(df_drug_usage_entities)}")
    print(f"  - Custom extraction: {len(df_custom_drug_usage)}")
    print(f"  - LF-generated: {len(df_drug_usage_lf_spans)}")
    
    # Analysis
    if not df_all_drug_usage_entities.empty:
        print("\nCombined entity distribution by label:")
        print(df_all_drug_usage_entities['label'].value_counts())
        print("\nCombined entity distribution by source:")
        print(df_all_drug_usage_entities['source'].value_counts())
        
        # Substance analysis
        substance_entities = df_all_drug_usage_entities[
            df_all_drug_usage_entities['label'].isin(['SUBSTANCE', 'ALCOHOL_USE', 'TOBACCO_USE', 
                                                     'CANNABIS_USE', 'OPIOID_USE', 'STIMULANT_USE'])
        ]
        if not substance_entities.empty:
            print("\n=== Top Substances Mentioned ===")
            print(substance_entities['text'].value_counts().head(20))

    # Add row-level drug usage phenotypes
    print("\n=== Adding Row-Level Drug Usage Phenotypes ===")
    for lf in drug_usage_lfs:
        phenotype_name = f"has_{lf.__name__.replace('lf_', '')}"
        df_medical[phenotype_name] = df_medical.apply(
            lambda row: lf(row).get('label') != 'ABSTAIN',
            axis=1
        )

    # Save results
    df_drug_usage_lf_spans.to_csv('drug_usage_lf_generated_spans.csv', index=False)
    df_all_drug_usage_entities.to_csv('drug_usage_entities_comprehensive.csv', index=False)
    df_custom_drug_usage.to_csv('drug_usage_entities_custom.csv', index=False)
    
    print("\n\nDrug usage entity extraction complete!")
    print("All results saved to CSV files.")

Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 29755 texts in 100 batches...
Using model: en_ner_bc5cdr_md for column: drug usage


Processing batches:   2%|▏         | 2/100 [00:00<00:36,  2.66it/s]


Checkpoint saved at batch 0


Processing batches:  12%|█▏        | 12/100 [00:02<00:16,  5.32it/s]


Checkpoint saved at batch 3000


Processing batches:  22%|██▏       | 22/100 [00:04<00:13,  5.65it/s]


Checkpoint saved at batch 6000


Processing batches:  32%|███▏      | 32/100 [00:06<00:11,  5.75it/s]


Checkpoint saved at batch 9000


Processing batches:  42%|████▏     | 42/100 [00:07<00:10,  5.72it/s]


Checkpoint saved at batch 12000


Processing batches:  52%|█████▏    | 52/100 [00:09<00:08,  5.69it/s]


Checkpoint saved at batch 15000


Processing batches:  62%|██████▏   | 62/100 [00:11<00:06,  5.49it/s]


Checkpoint saved at batch 18000


Processing batches:  72%|███████▏  | 72/100 [00:12<00:05,  5.31it/s]


Checkpoint saved at batch 21000


Processing batches:  82%|████████▏ | 82/100 [00:14<00:03,  5.33it/s]


Checkpoint saved at batch 24000


Processing batches:  92%|█████████▏| 92/100 [00:16<00:01,  5.55it/s]


Checkpoint saved at batch 27000


Processing batches: 100%|██████████| 100/100 [00:17<00:00,  5.61it/s]


Found 27 entities appearing >= 5 times

Testing BC5CDR-generated labeling functions for drug usage...

=== BC5CDR Extraction Results ===
Shape of df_drug_usage_entities: (715, 7)

Entity Labels Found:
label
CHEMICAL    447
DISEASE     268
Name: count, dtype: int64

Top Entities:
text
drug abuse                             113
cocaine                                 92
substance abuse                         47
cannabis                                39
methamphetamine                         31
Cocaine                                 27
steroid                                 17
alcohol                                 15
Heroin                                  14
Cannabis                                10
steroids                                10
ecstasy                                  9
methadone                                8
benzodiazepines                          7
amphetamines                             7
caffeine                                 6
LSD                        

In [30]:
df_drug_usage_lf_spans

Unnamed: 0,text,label,category,start,end,original_text,source,row_idx
0,tobacco,TOBACCO_USE,tobacco,22,29,Chewing pan of Indian tobacco for the last 15 ...,lf:lf_tobacco_use,40451
1,No history,NO_DRUG_USE,status,0,10,No history of tobacco or drug intake,lf:lf_no_drug_use,149806
2,tobacco,TOBACCO_USE,tobacco,14,21,No history of tobacco or drug intake,lf:lf_tobacco_use,149806
3,history of,PAST_DRUG_USE,temporal,3,13,No history of tobacco or drug intake,lf:lf_past_drug_use,149806
4,methadone,OPIOID_USE,opioid,13,22,Diazepam and methadone overdose,lf:lf_opioid_use,163624
...,...,...,...,...,...,...,...,...
1793,smoking,TOBACCO_USE,tobacco,13,20,occasionally smoking of marijuana,lf:lf_tobacco_use,100070
1794,marijuana,CANNABIS_USE,cannabis,24,33,occasionally smoking of marijuana,lf:lf_cannabis_use,100070
1795,Denied,NO_DRUG_USE,status,0,6,"Denied any intravenous drug use, urine drug sc...",lf:lf_no_drug_use,97973
1796,intravenous,IV_DRUG_USE,route,11,22,"Denied any intravenous drug use, urine drug sc...",lf:lf_iv_drug_use,97973


In [31]:
df_drug_usage_entities

Unnamed: 0,text,label,start,end,original_text,row_idx,source_column,source,category
0,Diazepam,CHEMICAL,0,8,Diazepam and methadone overdose,163624,drug usage,bc5cdr,
1,methadone overdose,CHEMICAL,13,31,Diazepam and methadone overdose,163624,drug usage,bc5cdr,
2,zolpidem,CHEMICAL,9,17,Abuse of zolpidem,90815,drug usage,bc5cdr,
3,drug abuse,DISEASE,11,21,History of drug abuse,43921,drug usage,bc5cdr,
4,drug abuse,DISEASE,12,22,Intravenous drug abuse,84350,drug usage,bc5cdr,
...,...,...,...,...,...,...,...,...,...
710,bipolar disorder,DISEASE,34,50,Remote history of lithium use for bipolar diso...,113022,drug usage,bc5cdr,
711,cocaine,CHEMICAL,24,31,"Occasional recreational cocaine use, most rece...",137819,drug usage,bc5cdr,
712,cannabis,CHEMICAL,93,101,Heavy e-cigarette use for the previous 2 years...,99885,drug usage,bc5cdr,
713,neck pain,DISEASE,61,70,"Used heroin, street bought oral opiates to sel...",97753,drug usage,bc5cdr,


### Extracting Surgeries

In [32]:
df_surgery

Unnamed: 0,idx,has_surgery,reason,Type,time,outcome,details
0,155216,False,,,,,
1,133948,True,Idiopathic osteonecrosis of the femoral head,Total Hip Arthroplasty (THA),After diagnosis,Discharged in good condition without specific ...,First THA on the left hip
2,133948,True,Pain and limited ROM in the contralateral hip ...,Total Hip Arthroplasty (THA),One year after the first THA,Discharged in good condition without specific ...,Second THA on the contralateral hip
3,80176,True,Posttraumatic arthritis,Left elbow arthrodesis,At the age of 18,,Elbow was fused at 90 degrees
4,80176,True,Hypertrophic nonunion of ulnar shaft fracture ...,Repair of nonunion and conversion of elbow art...,Three months after the fall and subsequent con...,,The stem of the ulnar component would act as a...
...,...,...,...,...,...,...,...
35859,98004,True,Inferior segment elevation (ST) elevation myoc...,Primary percutaneous coronary intervention (dr...,,Successful treatment of right coronary artery ...,Procedure complicated by Ventricular Fibrillat...
35860,133320,True,Leiomyosarcoma,Wide tumor resection,,Successful with no adjuvant chemotherapy and r...,
35861,133320,True,Lung nodules,Excisional biopsy,One year and 3 months postoperatively,Histopathological diagnosis was consistent wit...,
35862,133320,True,Bone metastasis of the right femur,Cryoablation under CT guidance,,,Ablation needles were inserted into the proxim...


In [33]:
class TemporalStandardizer:
    """Extract and standardize temporal expressions from medical text"""

    def __init__(self):
        # Load spaCy model
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except:
            self.nlp = spacy.load("en_core_web_sm")

        # Conversion mappings - Define this BEFORE adding patterns
        self.time_units = {
            'day': 1, 'days': 1, 'd': 1,
            'week': 7, 'weeks': 7, 'wk': 7, 'wks': 7,
            'month': 30, 'months': 30, 'mo': 30, 'mos': 30,
            'year': 365, 'years': 365, 'yr': 365, 'yrs': 365,
            'hour': 1/24, 'hours': 1/24, 'hr': 1/24, 'hrs': 1/24,
            'minute': 1/1440, 'minutes': 1/1440, 'min': 1/1440, 'mins': 1/1440
        }

        # Initialize matcher for temporal patterns
        self.matcher = Matcher(self.nlp.vocab)
        # Now call _add_temporal_patterns AFTER time_units is defined
        self._add_temporal_patterns()


    def _add_temporal_patterns(self):
        """Add temporal patterns to spaCy matcher"""

        # Pattern: "X days/weeks/months"
        pattern1 = [
            {"LIKE_NUM": True},
            # Access self.time_units here is now safe
            {"LOWER": {"IN": list(self.time_units.keys())}}
        ]
        self.matcher.add("DURATION", [pattern1])

        # Pattern: "past X days/weeks"
        pattern2 = [
            {"LOWER": {"IN": ["past", "last", "previous"]}},
            {"LIKE_NUM": True},
            {"LOWER": {"IN": list(self.time_units.keys())}}
        ]
        self.matcher.add("PAST_DURATION", [pattern2])

        # Pattern: "X days/weeks ago"
        pattern3 = [
            {"LIKE_NUM": True},
            {"LOWER": {"IN": list(self.time_units.keys())}},
            {"LOWER": "ago"}
        ]
        self.matcher.add("AGO_DURATION", [pattern3])

    def extract_all_temporal_info(self, text):
        """Extract all temporal information from text"""
        if pd.isna(text) or text == '':
            return {
                'duration_days': None,
                'is_ongoing': False,
                'has_date': False,
                'temporal_type': None,
                'original_text': text
            }

        text = str(text)
        info = {
            'duration_days': None,
            'is_ongoing': False,
            'has_date': False,
            'temporal_type': None,
            'original_text': text
        }

        # Check for ongoing conditions
        ongoing_patterns = [
            'persisting', 'continuing', 'ongoing', 'current',
            'still', 'continues', 'persistent', 'chronic'
        ]
        info['is_ongoing'] = any(pattern in text.lower() for pattern in ongoing_patterns)

        # Try to extract specific dates
        dates = self._extract_dates(text)
        if dates:
            info['has_date'] = True
            info['extracted_dates'] = dates

        # Extract duration
        duration = self._extract_duration(text)
        if duration:
            info['duration_days'] = duration

        # Classify temporal type
        info['temporal_type'] = self._classify_temporal_type(text)

        return info

    def _extract_dates(self, text):
        """Extract actual dates from text"""
        dates = []

        # Common date patterns
        date_patterns = [
            r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b',  # MM/DD/YYYY or MM-DD-YYYY
            r'\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b',    # YYYY-MM-DD
            r'\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{1,2},?\s+\d{4}\b',  # Month DD, YYYY
            r'\b\d{1,2}\s+(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{4}\b',     # DD Month YYYY
        ]

        for pattern in date_patterns:
            matches = re.finditer(pattern, text, re.IGNORECASE)
            for match in matches:
                try:
                    parsed_date = parser.parse(match.group(), fuzzy=False)
                    dates.append(parsed_date)
                except:
                    continue

        return dates

    def _extract_duration(self, text):
        """Extract duration in days from temporal expressions"""
        text_lower = text.lower()

        # Use regex to find duration patterns
        patterns = [
            # "X days/weeks/months"
            r'(\d+)\s*(day|days|week|weeks|month|months|year|years|hour|hours)',
            # "a few days/weeks"
            r'(a few|several|couple of)\s*(day|days|week|weeks|month|months)',
            # Written numbers
            r'(one|two|three|four|five|six|seven|eight|nine|ten)\s*(day|days|week|weeks|month|months|year|years)',
        ]

        for pattern in patterns:
            match = re.search(pattern, text_lower)
            if match:
                # Extract number
                number_text = match.group(1)
                unit = match.group(2)

                # Convert to number
                if number_text.isdigit():
                    number = int(number_text)
                elif number_text in ['a few', 'several']:
                    number = 3  # Approximate
                elif number_text == 'couple of':
                    number = 2
                else:
                    # Convert written numbers
                    number_map = {
                        'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5,
                        'six': 6, 'seven': 7, 'eight': 8, 'nine': 9, 'ten': 10
                    }
                    number = number_map.get(number_text, 1)

                # Convert to days
                if unit in self.time_units:
                    return number * self.time_units[unit]

        return None

    def _classify_temporal_type(self, text):
        """Classify the type of temporal expression"""
        text_lower = text.lower()

        if any(word in text_lower for word in ['ago', 'before', 'prior', 'previously']):
            return 'past_reference'
        elif any(word in text_lower for word in ['since', 'from', 'started']):
            return 'onset_reference'
        elif any(word in text_lower for word in ['for', 'duration', 'lasted']):
            return 'duration_reference'
        elif any(word in text_lower for word in ['until', 'through', 'to']):
            return 'range_reference'
        elif any(word in text_lower for word in ['after', 'following', 'post']):
            return 'post_event'
        elif re.search(r'\d{4}', text):  # Contains year
            return 'absolute_date'
        else:
            return 'unspecified'

    def standardize_temporal_column(self, df, column_name):
        """Standardize an entire temporal column"""

        print(f"\nProcessing temporal column: {column_name}")

        # Apply extraction to all values
        temporal_data = df[column_name].apply(self.extract_all_temporal_info)

        # Convert to DataFrame
        temporal_df = pd.DataFrame(temporal_data.tolist())

        # Add prefix to column names
        temporal_df.columns = [f"{column_name}_{col}" for col in temporal_df.columns]

        # Concatenate with original dataframe
        result_df = pd.concat([df, temporal_df], axis=1)

        # Generate report
        report = self._generate_temporal_report(temporal_df, column_name)

        return result_df, report

    def _generate_temporal_report(self, temporal_df, column_name):
        """Generate report on temporal extraction"""

        duration_col = f"{column_name}_duration_days"
        type_col = f"{column_name}_temporal_type"

        report = {
            'column': column_name,
            'total_entries': len(temporal_df),
            'extracted_durations': temporal_df[duration_col].notna().sum(),
            'ongoing_conditions': temporal_df[f"{column_name}_is_ongoing"].sum(),
            'has_specific_dates': temporal_df[f"{column_name}_has_date"].sum(),
            'temporal_types': temporal_df[type_col].value_counts().to_dict()
        }

        # Duration statistics
        if temporal_df[duration_col].notna().any():
            report['duration_stats'] = {
                'min_days': temporal_df[duration_col].min(),
                'max_days': temporal_df[duration_col].max(),
                'mean_days': temporal_df[duration_col].mean(),
                'median_days': temporal_df[duration_col].median()
            }

        return report

In [34]:
def extract_surgical_entities_custom(df_surgery):
    """
    Custom extraction for anatomy and procedures not caught by the model.
    Uses 'idx' column for row identity when available; falls back to df.index otherwise.
    Adds 'row_idx' to every extracted entity.
    Uses regex word-boundary matching; allows simple plurals (e.g., hips, arteries).
    """

    custom_entities = []

    # Anatomy vocab by category
    anatomy_patterns = {
        'hip':        ['hip', 'acetabul', 'femoral head'],
        'knee':       ['knee', 'meniscus', 'patella', 'tibia'],
        'bone':       ['bone', 'femur', 'humerus', 'radius', 'ulna'],
        'joint':      ['joint', 'articulation'],
        'spine':      ['spine', 'vertebra', 'disc', 'lumbar', 'cervical'],
        'organ':      ['kidney', 'liver', 'heart', 'lung', 'spleen'],
        'vessel':     ['artery', 'vein', 'vessel', 'vascular'],
        'nerve':      ['nerve', 'neural', 'plexus'],
        'muscle':     ['muscle', 'tendon', 'ligament']
    }

    # Procedure vocab by category
    procedure_patterns = {
        'replacement':   ['arthroplasty', 'replacement', 'implant'],
        'repair':        ['repair', 'reconstruction', 'fixation'],
        'removal':       ['removal', 'excision', 'resection', 'ectomy'],
        'fusion':        ['fusion', 'arthrodesis', 'osteotomy'],
        'diagnostic':    ['biopsy', 'exploration', 'diagnostic'],
        'decompression': ['decompression', 'release', 'neurolysis'],
        'transplant':    ['transplant', 'graft', 'implantation']
    }

    laterality_terms = ['left', 'right', 'bilateral', 'unilateral']

    for df_idx, row in df_surgery.iterrows():
        # ---- choose the correct row id from 'idx' with fallback to the DataFrame index
        if 'idx' in row and pd.notna(row['idx']):
            row_id = row['idx']
        else:
            row_id = df_idx

        # ---- text (keep original for offsets; search on lowercased copy)
        original_text = str(row.get('combined_text', ''))
        text = original_text.lower()

        # ---- Anatomy extraction (regex with word boundaries; allow simple plural 's')
        for category, terms in anatomy_patterns.items():
            for term in terms:
                # allow plural 's' for basic nouns (skip stems like 'acetabul' where 's' doesn't apply cleanly)
                plural_opt = 's?' if term.isalpha() and not term.endswith(('al','el','ul')) else ''
                pattern = r'\b' + re.escape(term.lower()) + plural_opt + r'\b'
                for m in re.finditer(pattern, text):
                    custom_entities.append({
                        'text': original_text[m.start():m.end()],
                        'label': 'ANATOMY',
                        'category': category,
                        'start': m.start(),
                        'end': m.end(),
                        'original_text': original_text,
                        'source': 'custom_extraction',
                        'row_idx': row_id,   # <-- correct ID retained
                    })

        # ---- Procedure extraction (also boundary-based)
        for category, terms in procedure_patterns.items():
            for term in terms:
                plural_opt = 's?' if term.isalpha() else ''
                pattern = r'\b' + re.escape(term.lower()) + plural_opt + r'\b'
                for m in re.finditer(pattern, text):
                    custom_entities.append({
                        'text': original_text[m.start():m.end()],
                        'label': 'PROCEDURE',
                        'category': category,
                        'start': m.start(),
                        'end': m.end(),
                        'original_text': original_text,
                        'source': 'custom_extraction',
                        'row_idx': row_id,   # <-- correct ID retained
                    })

        # ---- Laterality extraction
        for lat in laterality_terms:
            pattern = r'\b' + re.escape(lat) + r'\b'
            for m in re.finditer(pattern, text):
                custom_entities.append({
                    'text': original_text[m.start():m.end()],
                    'label': 'LATERALITY',
                    'category': 'laterality',
                    'start': m.start(),
                    'end': m.end(),
                    'original_text': original_text,
                    'source': 'custom_extraction',
                    'row_idx': row_id,   # <-- correct ID retained
                })

    return pd.DataFrame(custom_entities)

if __name__ == "__main__":
    # Apply temporal standardization
    temporal_standardizer = TemporalStandardizer()
    df_surgery_processed, temporal_report = temporal_standardizer.standardize_temporal_column(
        df_surgery, 
        'time'
    )
   
    print(f"\nTemporal extraction results:")
    print(f"Successfully extracted duration: {temporal_report['extracted_durations']}")
    print(f"Temporal types: {temporal_report['temporal_types']}")

    # Combine surgery text columns
    # (use str() to be tolerant to NaNs)
    df_surgery['combined_text'] = df_surgery.apply(
        lambda row: f"{str(row.get('reason',''))} {str(row.get('Type',''))} {str(row.get('details',''))} {str(row.get('outcome',''))}",
        axis=1
    )

    # Copy combined_text to processed dataframe (keeps alignment with df_surgery by index)
    df_surgery_processed['combined_text'] = df_surgery['combined_text']

    # Add temporal features to combined text
    df_surgery_processed['combined_text_enriched'] = df_surgery_processed.apply(
        lambda row: f"{row['combined_text']} {'lasting ' + str(row.get('time_duration_days', '')) + ' days' if pd.notna(row.get('time_duration_days')) else ''}",
        axis=1
    )

    # NER over the (non-enriched) combined_text in df_surgery
    df_surgery_entities, surgery_summary, surgery_rules = run_medical_ner_extraction(
        df_surgery,
        text_column='combined_text',
        model_name='en_ner_bionlp13cg_md', 
        batch_size=300,
        id_column='idx'
    )

    print(f"\n=== Entity Types Found ===")
    print(surgery_summary['entity_types'])


    # Extract custom entities
    df_custom_entities = extract_surgical_entities_custom(df_surgery)

    print(f"\nCustom extraction found:")
    if not df_custom_entities.empty and 'label' in df_custom_entities.columns:
        print(df_custom_entities['label'].value_counts())
    else:
        print("No custom entities found.")

    # Combine model entities and custom entities
    # (keep row_idx/category when present for traceability)
    df_all_surgery_entities = pd.concat(
        [df_surgery_entities, df_custom_entities],
        ignore_index=True,
        sort=False
    )

    print(f"\n=== COMBINED Entity Distribution ===")
    if not df_all_surgery_entities.empty and 'label' in df_all_surgery_entities.columns:
        print(df_all_surgery_entities['label'].value_counts())
    else:
        print("No combined entities to show.")

    # Analyze anatomy entities
    anatomy_entities = df_all_surgery_entities[df_all_surgery_entities['label'] == 'ANATOMY'] if not df_all_surgery_entities.empty else pd.DataFrame()
    if not anatomy_entities.empty:
        print(f"\n=== Top Anatomical Entities ===")
        print(anatomy_entities['text'].str.lower().value_counts().head(20))

        if 'category' in anatomy_entities.columns:
            print(f"\n=== Anatomy by Category ===")
            print(anatomy_entities['category'].value_counts())

    # Analyze procedure entities
    procedure_entities = df_all_surgery_entities[df_all_surgery_entities['label'] == 'PROCEDURE'] if not df_all_surgery_entities.empty else pd.DataFrame()
    if not procedure_entities.empty:
        print(f"\n=== Top Surgical Procedures ===")
        print(procedure_entities['text'].str.lower().value_counts().head(20))

        if 'category' in procedure_entities.columns:
            print(f"\n=== Procedures by Category ===")
            print(procedure_entities['category'].value_counts())

    # ========== Labeling functions (unchanged logic) ==========

    def create_surgical_labeling_functions():
        COL = 'combined_text'   # main free text
        REASON = 'reason'       # for emergency logic

        def lf_hip_surgery(row):
            text = str(row.get(COL,''))
            if ('hip' in text.lower() and _first_hit(text, ['arthroplasty','replacement','repair'])):
                return {'label':'HIP_SURGERY','column':COL,'match':'hip','category':'joint'}
            return {'label':'ABSTAIN'}

        def lf_knee_surgery(row):
            text = str(row.get(COL,''))
            if ('knee' in text.lower() and _first_hit(text, ['arthroplasty','replacement','arthroscopy'])):
                return {'label':'KNEE_SURGERY','column':COL,'match':'knee','category':'joint'}
            return {'label':'ABSTAIN'}

        def lf_fracture_surgery(row):
            text = str(row.get(COL,''))
            if ('fracture' in text.lower() and _first_hit(text, ['fixation','repair','reduction'])):
                return {'label':'FRACTURE_SURGERY','column':COL,'match':'fracture','category':'bone'}
            return {'label':'ABSTAIN'}

        def lf_bilateral_procedure(row):
            text = str(row.get(COL,''))
            hit = _first_hit(text, ['bilateral'])
            return {'label':'BILATERAL_PROCEDURE','column':COL,'match':hit,'category':'laterality'} if hit else {'label':'ABSTAIN'}

        def lf_minimally_invasive(row):
            text = str(row.get(COL,''))
            hit = _first_hit(text, ['arthroscopic','endoscopic','laparoscopic','minimally invasive'])
            return {'label':'MINIMALLY_INVASIVE','column':COL,'match':hit,'category':'approach'} if hit else {'label':'ABSTAIN'}

        def lf_emergency_procedure(row):
            reason = str(row.get(REASON,''))
            hit = _first_hit(reason, ['emergency','urgent','acute','trauma'])
            return {'label':'EMERGENCY_SURGERY','column':REASON,'match':hit,'category':'urgency'} if hit else {'label':'ABSTAIN'}

        return [lf_hip_surgery, lf_knee_surgery, lf_fracture_surgery,
                lf_bilateral_procedure, lf_minimally_invasive, lf_emergency_procedure]

    # --- materialize spans ---
    surgical_lfs_span = create_surgical_labeling_functions()
    df_surgery_lf_spans = materialize_lf_spans(df_surgery, surgical_lfs_span, id_column='idx')
    print(f"\nLF-generated surgery spans: {len(df_surgery_lf_spans)}")
    if not df_surgery_lf_spans.empty:
        print(df_surgery_lf_spans['label'].value_counts())

    # --- combine model + custom + LF spans ---
    _cols = ['text','label','start','end','original_text','source','row_idx','category']
    for col in _cols:
        if col not in df_surgery_entities.columns: df_surgery_entities[col] = 'bionlp13cg' if col=='source' else None
        if col not in df_custom_entities.columns: df_custom_entities[col] = 'custom_extraction' if col=='source' else None
        if col not in df_surgery_lf_spans.columns: df_surgery_lf_spans[col] = None

    df_all_surgery_entities = pd.concat(
        [df_surgery_entities[_cols], df_custom_entities[_cols], df_surgery_lf_spans[_cols]],
        ignore_index=True
    )
    print(f"Total surgery entities (model+custom+LF): {len(df_all_surgery_entities)}")
    df_surgery_lf_spans.to_csv('surgery_lf_generated_spans.csv', index=False)



Processing temporal column: time

Temporal extraction results:
Successfully extracted duration: 6388
Temporal types: {'unspecified': 3554, 'absolute_date': 2836, 'post_event': 2806, 'past_reference': 2754, 'range_reference': 654, 'duration_reference': 261, 'onset_reference': 191}
Loading en_ner_bionlp13cg_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 35864 texts in 120 batches...
Using model: en_ner_bionlp13cg_md for column: combined_text


Processing batches:   1%|          | 1/120 [00:02<04:12,  2.12s/it]


Checkpoint saved at batch 0


Processing batches:   9%|▉         | 11/120 [00:16<02:41,  1.48s/it]


Checkpoint saved at batch 3000


Processing batches:  18%|█▊        | 21/120 [00:30<02:19,  1.41s/it]


Checkpoint saved at batch 6000


Processing batches:  26%|██▌       | 31/120 [00:44<02:09,  1.45s/it]


Checkpoint saved at batch 9000


Processing batches:  34%|███▍      | 41/120 [00:57<01:53,  1.43s/it]


Checkpoint saved at batch 12000


Processing batches:  42%|████▎     | 51/120 [01:11<01:43,  1.51s/it]


Checkpoint saved at batch 15000


Processing batches:  51%|█████     | 61/120 [01:25<01:28,  1.49s/it]


Checkpoint saved at batch 18000


Processing batches:  59%|█████▉    | 71/120 [01:38<01:14,  1.52s/it]


Checkpoint saved at batch 21000


Processing batches:  68%|██████▊   | 81/120 [01:52<00:57,  1.48s/it]


Checkpoint saved at batch 24000


Processing batches:  76%|███████▌  | 91/120 [02:05<00:43,  1.49s/it]


Checkpoint saved at batch 27000


Processing batches:  84%|████████▍ | 101/120 [02:19<00:30,  1.59s/it]


Checkpoint saved at batch 30000


Processing batches:  92%|█████████▎| 111/120 [02:32<00:14,  1.56s/it]


Checkpoint saved at batch 33000


Processing batches: 100%|██████████| 120/120 [02:43<00:00,  1.36s/it]


Found 3452 entities appearing >= 5 times

=== Entity Types Found ===
{'MULTI_TISSUE_STRUCTURE': 32237, 'CANCER': 17003, 'TISSUE': 16665, 'ORGAN': 14854, 'PATHOLOGICAL_FORMATION': 14321, 'ORGANISM_SUBDIVISION': 4084, 'GENE_OR_GENE_PRODUCT': 3097, 'CELL': 2784, 'SIMPLE_CHEMICAL': 2742, 'ORGANISM': 2378, 'IMMATERIAL_ANATOMICAL_ENTITY': 1800, 'CELLULAR_COMPONENT': 1777, 'ORGANISM_SUBSTANCE': 1306, 'AMINO_ACID': 111, 'ANATOMICAL_SYSTEM': 48, 'DEVELOPING_ANATOMICAL_STRUCTURE': 1}

Custom extraction found:
label
PROCEDURE     24770
ANATOMY       20351
LATERALITY    19618
Name: count, dtype: int64

=== COMBINED Entity Distribution ===
label
MULTI_TISSUE_STRUCTURE             32237
PROCEDURE                          24770
ANATOMY                            20351
LATERALITY                         19618
CANCER                             17003
TISSUE                             16665
ORGAN                              14854
PATHOLOGICAL_FORMATION             14321
ORGANISM_SUBDIVISION           

In [35]:
df_all_surgery_entities

Unnamed: 0,text,label,start,end,original_text,source,row_idx,category
0,femoral head,PATHOLOGICAL_FORMATION,32,44,Idiopathic osteonecrosis of the femoral head T...,bionlp13cg,133948,
1,left hip,PATHOLOGICAL_FORMATION,91,99,Idiopathic osteonecrosis of the femoral head T...,bionlp13cg,133948,
2,joint Total,MULTI_TISSUE_STRUCTURE,46,57,Pain and limited ROM in the contralateral hip ...,bionlp13cg,133948,
3,Left elbow,TISSUE,24,34,Posttraumatic arthritis Left elbow arthrodesis...,bionlp13cg,80176,
4,ulnar shaft,MULTI_TISSUE_STRUCTURE,25,36,Hypertrophic nonunion of ulnar shaft fracture ...,bionlp13cg,80176,
...,...,...,...,...,...,...,...,...
186186,bilateral,BILATERAL_PROCEDURE,237,246,"Suprasellar, hemorrhagic mass with optic chias...",lf:lf_bilateral_procedure,77772,laterality
186187,Endoscopic,MINIMALLY_INVASIVE,60,70,"Suprasellar, hemorrhagic mass with optic chias...",lf:lf_minimally_invasive,77772,approach
186188,Laparoscopic,MINIMALLY_INVASIVE,20,32,Internal herniation Laparoscopic surgery First...,lf:lf_minimally_invasive,138087,approach
186189,Minimally invasive,MINIMALLY_INVASIVE,94,112,Incidental mass in the right kidney CT-guided ...,lf:lf_minimally_invasive,157822,approach


### Extracting Symptoms

In [36]:
df_symptoms

Unnamed: 0,idx,has_symptom,name of symptom,intensity of symptom,location,time,temporalisation,behaviours affecting the symptom,details
0,155216,True,"Discomfort in the neck and lower back, restric...",,Neck and lower back,Past four months,,Standing up from a sitting position,Head turned to the right and upwards due to su...
1,133948,True,Pain,Severe,Left hip joint,Persisting for two months,Increased over the following three weeks,Aggravated by hip joint flexion or rotation,Also complained of pain and limited ROM in the...
2,133948,True,Restricted range of motion,,Left hip joint,Persisting for two months,,,
3,133948,True,Gait disturbance,Severe,,,,Secondary to hip pain,Continued for two months and increased over th...
4,133948,True,Moderate moon face,Moderate,Face,At the time of the second surgery,,,Initially overlooked as weight gain
...,...,...,...,...,...,...,...,...,...
54939,137017,True,Left-sided weakness,,Left side,,,,
54940,98004,True,Chest pain,,Chest,,,,Cardiac sounding
54941,133320,True,Mass in right thigh,,Lateral side of the right thigh,Noticed four years prior to presentation,,,"Diameter of 4 cm, no adhesion with skin and no..."
54942,97973,True,Crushing substernal chest pressure,Acute onset,Substernal,,Following 1-week-long febrile illness,,Accompanied by dyspnea and profuse sweating


In [37]:
def extract_symptom_entities_custom(df_symptoms):
    """
    Extract symptom-specific entities not caught by BC5CDR.
    Uses 'idx' for row identity when available; falls back to df.index otherwise.
    Adds 'row_idx' to every extracted entity.
    Regex with word boundaries is used to avoid substring false positives.
    """
    custom_entities = []

    # Symptom type patterns
    symptom_patterns = {
        'pain':         ['pain', 'ache', 'soreness', 'tenderness', 'discomfort'],
        'neurological': ['numbness', 'tingling', 'weakness', 'paralysis', 'tremor'],
        'mobility':     ['inability to walk', 'gait disturbance', 'limited range of motion', 'stiffness'],
        'swelling':     ['swelling', 'edema', 'inflammation', 'mass', 'lump'],
        'sensory':      ['vision', 'hearing', 'taste', 'smell', 'sensation'],
        'systemic':     ['fever', 'fatigue', 'malaise', 'weight loss', 'night sweats'],
        'respiratory':  ['dyspnea', 'cough', 'wheezing', 'shortness of breath'],
        'gi':           ['nausea', 'vomiting', 'diarrhea', 'constipation', 'abdominal'],
        'skin':         ['rash', 'itching', 'lesion', 'discoloration', 'bruising']
    }

    # Anatomical location patterns
    anatomy_patterns = {
        'head_neck':      ['head', 'neck', 'scalp', 'face', 'throat', 'cervical'],
        'upper_extremity': ['shoulder', 'arm', 'elbow', 'forearm', 'wrist', 'hand', 'finger'],
        'lower_extremity': ['hip', 'thigh', 'knee', 'leg', 'ankle', 'foot', 'toe'],
        'trunk':           ['chest', 'back', 'abdomen', 'pelvis', 'spine', 'lumbar'],
        'joint':           ['joint', 'articulation'],
        'internal':        ['heart', 'lung', 'liver', 'kidney', 'stomach']
    }

    # Severity/Intensity (column-driven)
    severity_values = {'mild', 'moderate', 'severe', 'extreme', 'minimal', 'significant'}

    # Temporal patterns
    temporal_patterns = {
        'acute':        ['acute', 'sudden', 'abrupt', 'rapid'],
        'chronic':      ['chronic', 'persistent', 'ongoing', 'continuous'],
        'intermittent': ['intermittent', 'episodic', 'recurrent', 'periodic'],
        'progressive':  ['progressive', 'worsening', 'increasing', 'deteriorating']
    }

    laterality_terms = ['left', 'right', 'bilateral', 'both']

    for df_index, row in df_symptoms.iterrows():
        # ---- choose the correct row id from 'idx' with fallback to the DataFrame index
        row_id = row['idx'] if ('idx' in row and pd.notna(row['idx'])) else df_index

        original_text = str(row.get('combined_text', ''))
        combined = original_text.lower()

        # Columns (original + lower)
        symptom_name_orig = str(row.get('name of symptom', ''))
        symptom_name = symptom_name_orig.lower()
        location_orig = str(row.get('location', ''))
        location = location_orig.lower()
        intensity_orig = str(row.get('intensity of symptom', ''))
        intensity = intensity_orig.lower()
        temporalisation_orig = str(row.get('temporalisation', ''))
        temporalisation = temporalisation_orig.lower()

        # ---- Primary SYMPTOM from name column (if present)
        if pd.notna(row.get('name of symptom')) and symptom_name.strip() != '' and symptom_name != 'nan':
            # try to find position in combined; if not found, set to 0
            start = combined.find(symptom_name)
            if start < 0:
                start = 0
            end = start + len(symptom_name_orig)
            end = min(end, len(original_text))
            custom_entities.append({
                'text': symptom_name_orig,
                'label': 'SYMPTOM',
                'category': 'primary_symptom',
                'start': start,
                'end': end,
                'original_text': original_text,
                'source': 'symptom_name_column',
                'row_idx': row_id
            })

        # ---- Symptom types (regex boundary matches)
        for category, terms in symptom_patterns.items():
            for term in terms:
                pattern = r'\b' + re.escape(term.lower()) + r'\b'
                for m in re.finditer(pattern, combined):
                    custom_entities.append({
                        'text': original_text[m.start():m.end()],
                        'label': 'SYMPTOM_TYPE',
                        'category': category,
                        'start': m.start(),
                        'end': m.end(),
                        'original_text': original_text,
                        'source': 'pattern_matching',
                        'row_idx': row_id
                    })

        # ---- Anatomical location from location column
        if pd.notna(row.get('location')) and location.strip() != '' and location != 'nan':
            start = combined.find(location)
            if start >= 0:
                end = start + len(location)
                custom_entities.append({
                    'text': location_orig,
                    'label': 'ANATOMY',
                    'category': 'symptom_location',
                    'start': start,
                    'end': end,
                    'original_text': original_text,
                    'source': 'location_column',
                    'row_idx': row_id
                })

        # ---- Additional anatomy from patterns
        for category, terms in anatomy_patterns.items():
            for term in terms:
                # allow simple plural 's' for nouns
                plural_opt = 's?' if term.isalpha() else ''
                pattern = r'\b' + re.escape(term.lower()) + plural_opt + r'\b'
                for m in re.finditer(pattern, combined):
                    custom_entities.append({
                        'text': original_text[m.start():m.end()],
                        'label': 'ANATOMY',
                        'category': category,
                        'start': m.start(),
                        'end': m.end(),
                        'original_text': original_text,
                        'source': 'pattern_matching',
                        'row_idx': row_id
                    })

        # ---- Severity from intensity column (exact value match)
        if pd.notna(row.get('intensity of symptom')) and intensity in severity_values:
            start = combined.find(intensity)
            if start >= 0:
                custom_entities.append({
                    'text': intensity_orig,
                    'label': 'SEVERITY',
                    'category': 'intensity',
                    'start': start,
                    'end': start + len(intensity_orig),
                    'original_text': original_text,
                    'source': 'intensity_column',
                    'row_idx': row_id
                })

        # ---- Temporal patterns (look in temporalisation column and combined text)
        for category, terms in temporal_patterns.items():
            for term in terms:
                found = False
                # search in combined first (to get offsets)
                m = re.search(r'\b' + re.escape(term) + r'\b', combined)
                if m:
                    custom_entities.append({
                        'text': original_text[m.start():m.end()],
                        'label': 'TEMPORAL_PATTERN',
                        'category': category,
                        'start': m.start(),
                        'end': m.end(),
                        'original_text': original_text,
                        'source': 'temporal_extraction',
                        'row_idx': row_id
                    })
                    found = True
                # if not found in combined but present in temporalisation text, add without exact offset
                if not found and term in temporalisation:
                    custom_entities.append({
                        'text': term,
                        'label': 'TEMPORAL_PATTERN',
                        'category': category,
                        'start': 0,
                        'end': len(term),
                        'original_text': original_text,
                        'source': 'temporal_extraction',
                        'row_idx': row_id
                    })

        # ---- Laterality (location or combined)
        for lat in laterality_terms:
            pattern = r'\b' + re.escape(lat) + r'\b'
            for m in re.finditer(pattern, combined):
                custom_entities.append({
                    'text': original_text[m.start():m.end()],
                    'label': 'LATERALITY',
                    'category': 'laterality',
                    'start': m.start(),
                    'end': m.end(),
                    'original_text': original_text,
                    'source': 'laterality_extraction',
                    'row_idx': row_id
                })

    return pd.DataFrame(custom_entities)

if __name__ == "__main__":
    # Apply temporal standardization
    temporal_standardizer = TemporalStandardizer()
    df_symptoms_processed, temporal_report = temporal_standardizer.standardize_temporal_column(
        df_symptoms, 
        'time'
    )
   
    print(f"\nTemporal extraction results:")
    print(f"Successfully extracted duration: {temporal_report['extracted_durations']}")
    print(f"Temporal types: {temporal_report['temporal_types']}")

    # Combine symptom information into comprehensive text (robust to NaNs)
    df_symptoms['combined_text'] = df_symptoms.apply(
        lambda row: f"{str(row.get('name of symptom',''))} with {str(row.get('intensity of symptom',''))} "
                    f"intensity located in {str(row.get('location',''))} lasting {str(row.get('time',''))} "
                    f"{str(row.get('temporalisation',''))} {str(row.get('details',''))}",
        axis=1
    )

    # Copy combined_text to processed dataframe (corrected to use df_symptoms)
    df_symptoms_processed['combined_text'] = df_symptoms['combined_text']

    # Add temporal features to combined text (uses time_duration_days if present)
    df_symptoms_processed['combined_text_enriched'] = df_symptoms_processed.apply(
        lambda row: f"{row['combined_text']} "
                    f"{'lasting ' + str(row.get('time_duration_days', '')) + ' days' if pd.notna(row.get('time_duration_days')) else ''}",
        axis=1
    )

    # Run NER extraction with BC5CDR over combined_text
    df_symptoms_entities, symptoms_summary, symptoms_rules = run_medical_ner_extraction(
        df_symptoms,
        text_column='combined_text',
        model_name="en_ner_bc5cdr_md",
        batch_size=300,
        id_column='idx'
    )

    print(f"\n=== BC5CDR Entity Types Found ===")
    print(symptoms_summary['entity_types'])


    # Extract custom entities
    df_custom_symptom_entities = extract_symptom_entities_custom(df_symptoms)

    print(f"\nCustom extraction found:")
    if not df_custom_symptom_entities.empty and 'label' in df_custom_symptom_entities.columns:
        print(df_custom_symptom_entities['label'].value_counts())
    else:
        print("No custom symptom entities found.")

    # Combine all entities (preserve row_idx/category when present)
    df_all_symptom_entities = pd.concat(
        [df_symptoms_entities, df_custom_symptom_entities],
        ignore_index=True,
        sort=False
    )

    print(f"\n=== COMBINED Entity Distribution ===")
    if not df_all_symptom_entities.empty and 'label' in df_all_symptom_entities.columns:
        print(df_all_symptom_entities['label'].value_counts())
    else:
        print("No combined entities to show.")

    # Analyze symptom entities
    symptom_entities = df_all_symptom_entities[df_all_symptom_entities['label'] == 'SYMPTOM'] if not df_all_symptom_entities.empty else pd.DataFrame()
    if not symptom_entities.empty:
        print(f"\n=== Top Symptoms ===")
        print(symptom_entities['text'].str.lower().value_counts().head(20))

    # Analyze anatomical locations
    anatomy_entities = df_all_symptom_entities[df_all_symptom_entities['label'] == 'ANATOMY'] if not df_all_symptom_entities.empty else pd.DataFrame()
    if not anatomy_entities.empty:
        print(f"\n=== Top Symptom Locations ===")
        print(anatomy_entities['text'].str.lower().value_counts().head(20))

    # Analyze severity
    severity_entities = df_all_symptom_entities[df_all_symptom_entities['label'] == 'SEVERITY'] if not df_all_symptom_entities.empty else pd.DataFrame()
    if not severity_entities.empty:
        print(f"\n=== Severity Distribution ===")
        print(severity_entities['text'].str.lower().value_counts())

    # ================= Symptom-specific labeling functions =================
    print("\n=== Creating Symptom-Specific Labeling Functions ===")

    def create_symptom_labeling_functions():
        COL = 'combined_text'
        LOC = 'location'
        TEMP = 'temporalisation'

        def lf_severe_pain(row):
            text = str(row.get(COL,''))
            if 'pain' in text.lower() and _first_hit(text, ['severe']):
                return {'label':'SEVERE_PAIN','column':COL,'match':'severe','category':'severity'}
            return {'label':'ABSTAIN'}

        def lf_chronic_symptom(row):
            t = ' '.join([str(row.get(TEMP,'')), str(row.get('time',''))])
            hit = _first_hit(t, ['chronic','persistent','ongoing','months','years'])
            return {'label':'CHRONIC_SYMPTOM','column':TEMP if hit and hit in str(row.get(TEMP,'')).lower() else 'time','match':hit,'category':'temporal'} if hit else {'label':'ABSTAIN'}

        def lf_neurological_symptom(row):
            text = str(row.get(COL,''))
            hit = _first_hit(text, ['numbness','tingling','weakness','paralysis','sensation','tremor'])
            return {'label':'NEUROLOGICAL','column':COL,'match':hit,'category':'neuro'} if hit else {'label':'ABSTAIN'}

        def lf_bilateral_symptom(row):
            loc = str(row.get(LOC,''))
            hit = _first_hit(loc, ['bilateral','both','left and right'])
            return {'label':'BILATERAL_SYMPTOM','column':LOC,'match':hit,'category':'laterality'} if hit else {'label':'ABSTAIN'}

        def lf_acute_onset(row):
            text = str(row.get(COL,''))
            hit = _first_hit(text, ['sudden','acute','abrupt','rapid onset'])
            return {'label':'ACUTE_ONSET','column':COL,'match':hit,'category':'temporal'} if hit else {'label':'ABSTAIN'}

        def lf_progressive_symptom(row):
            text = str(row.get(COL,''))
            hit = _first_hit(text, ['worsening','progressive','increasing'])
            return {'label':'PROGRESSIVE_SYMPTOM','column':COL,'match':hit,'category':'temporal'} if hit else {'label':'ABSTAIN'}

        def lf_mobility_issue(row):
            text = str(row.get(COL,''))
            hit = _first_hit(text, ['walk','gait','mobility','movement'])
            return {'label':'MOBILITY_ISSUE','column':COL,'match':hit,'category':'function'} if hit else {'label':'ABSTAIN'}

        def lf_pain_with_location(row):
            loc = str(row.get(LOC,''))
            if not loc: return {'label':'ABSTAIN'}
            for area, lab in [('hip','HIP_PAIN'), ('knee','KNEE_PAIN'), ('back','BACK_PAIN')]:
                if area in loc.lower():
                    return {'label':lab,'column':LOC,'match':area,'category':'localized'}
            if loc.strip().lower() not in ('', 'nan'):
                return {'label':'LOCALIZED_PAIN','column':LOC,'match':loc.split()[0].lower(),'category':'localized'}
            return {'label':'ABSTAIN'}

        def lf_systemic_symptom(row):
            text = str(row.get(COL,''))
            hit = _first_hit(text, ['fever','fatigue','weight loss','malaise'])
            return {'label':'SYSTEMIC_SYMPTOM','column':COL,'match':hit,'category':'systemic'} if hit else {'label':'ABSTAIN'}

        return [lf_severe_pain, lf_chronic_symptom, lf_neurological_symptom, lf_bilateral_symptom,
                lf_acute_onset, lf_progressive_symptom, lf_mobility_issue, lf_pain_with_location,
                lf_systemic_symptom]

    # --- materialize spans ---
    symptom_lfs_span = create_symptom_labeling_functions()
    df_symptom_lf_spans = materialize_lf_spans(df_symptoms, symptom_lfs_span, id_column='idx')
    print(f"\nLF-generated symptom spans: {len(df_symptom_lf_spans)}")
    if not df_symptom_lf_spans.empty:
        print(df_symptom_lf_spans['label'].value_counts())

    # --- combine model + custom + LF spans ---
    _cols = ['text','label','start','end','original_text','source','row_idx','category']
    for col in _cols:
        if col not in df_symptoms_entities.columns: df_symptoms_entities[col] = 'bc5cdr' if col=='source' else None
        if col not in df_custom_symptom_entities.columns: df_custom_symptom_entities[col] = 'custom_extraction' if col=='source' else None
        if col not in df_symptom_lf_spans.columns: df_symptom_lf_spans[col] = None

    df_all_symptom_entities = pd.concat(
        [df_symptoms_entities[_cols], df_custom_symptom_entities[_cols], df_symptom_lf_spans[_cols]],
        ignore_index=True
    )
    print(f"Total symptom entities (model+custom+LF): {len(df_all_symptom_entities)}")
    df_symptom_lf_spans.to_csv('symptom_lf_generated_spans.csv', index=False)


Processing temporal column: time

Temporal extraction results:
Successfully extracted duration: 18893
Temporal types: {'unspecified': 13201, 'post_event': 6146, 'range_reference': 5023, 'past_reference': 4704, 'duration_reference': 4063, 'onset_reference': 2904, 'absolute_date': 1164}
Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 54944 texts in 184 batches...
Using model: en_ner_bc5cdr_md for column: combined_text


Processing batches:   1%|          | 1/184 [00:01<05:49,  1.91s/it]


Checkpoint saved at batch 0


Processing batches:   6%|▌         | 11/184 [00:13<03:32,  1.23s/it]


Checkpoint saved at batch 3000


Processing batches:  11%|█▏        | 21/184 [00:25<03:16,  1.20s/it]


Checkpoint saved at batch 6000


Processing batches:  17%|█▋        | 31/184 [00:37<02:59,  1.17s/it]


Checkpoint saved at batch 9000


Processing batches:  22%|██▏       | 41/184 [00:48<02:51,  1.20s/it]


Checkpoint saved at batch 12000


Processing batches:  28%|██▊       | 51/184 [00:59<02:35,  1.17s/it]


Checkpoint saved at batch 15000


Processing batches:  33%|███▎      | 61/184 [01:11<02:28,  1.21s/it]


Checkpoint saved at batch 18000


Processing batches:  39%|███▊      | 71/184 [01:22<02:13,  1.18s/it]


Checkpoint saved at batch 21000


Processing batches:  44%|████▍     | 81/184 [01:33<02:02,  1.19s/it]


Checkpoint saved at batch 24000


Processing batches:  49%|████▉     | 91/184 [01:45<01:49,  1.17s/it]


Checkpoint saved at batch 27000


Processing batches:  55%|█████▍    | 101/184 [01:56<01:40,  1.21s/it]


Checkpoint saved at batch 30000


Processing batches:  60%|██████    | 111/184 [02:07<01:28,  1.21s/it]


Checkpoint saved at batch 33000


Processing batches:  66%|██████▌   | 121/184 [02:18<01:16,  1.21s/it]


Checkpoint saved at batch 36000


Processing batches:  71%|███████   | 131/184 [02:30<01:04,  1.22s/it]


Checkpoint saved at batch 39000


Processing batches:  77%|███████▋  | 141/184 [02:41<00:53,  1.24s/it]


Checkpoint saved at batch 42000


Processing batches:  82%|████████▏ | 151/184 [02:52<00:40,  1.23s/it]


Checkpoint saved at batch 45000


Processing batches:  88%|████████▊ | 161/184 [03:03<00:28,  1.25s/it]


Checkpoint saved at batch 48000


Processing batches:  93%|█████████▎| 171/184 [03:15<00:15,  1.23s/it]


Checkpoint saved at batch 51000


Processing batches:  98%|█████████▊| 181/184 [03:26<00:03,  1.26s/it]


Checkpoint saved at batch 54000


Processing batches: 100%|██████████| 184/184 [03:28<00:00,  1.14s/it]


Found 1857 entities appearing >= 5 times

=== BC5CDR Entity Types Found ===
{'DISEASE': 84179, 'CHEMICAL': 2286}

Custom extraction found:
label
ANATOMY             67817
SYMPTOM_TYPE        56789
SYMPTOM             53580
LATERALITY          31971
TEMPORAL_PATTERN    14118
SEVERITY             5306
Name: count, dtype: int64

=== COMBINED Entity Distribution ===
label
DISEASE             84179
ANATOMY             67817
SYMPTOM_TYPE        56789
SYMPTOM             53580
LATERALITY          31971
TEMPORAL_PATTERN    14118
SEVERITY             5306
CHEMICAL             2286
Name: count, dtype: int64

=== Top Symptoms ===
text
pain                    1796
abdominal pain          1182
swelling                1146
fever                    825
headache                 722
chest pain               576
vomiting                 575
dyspnea                  519
shortness of breath      501
weight loss              459
nausea                   399
dysphagia                313
fatigue             

In [38]:
df_all_symptom_entities

Unnamed: 0,text,label,start,end,original_text,source,row_idx,category
0,Pain,DISEASE,0,4,Pain with Severe intensity located in Left hip...,bc5cdr,133948,
1,pain,DISEASE,147,151,Pain with Severe intensity located in Left hip...,bc5cdr,133948,
2,weight gain,DISEASE,129,140,Moderate moon face with Moderate intensity loc...,bc5cdr,133948,
3,Central obesity,DISEASE,0,15,Central obesity with nan intensity located in ...,bc5cdr,133948,
4,Muscle mass reduction,DISEASE,0,21,Muscle mass reduction with nan intensity locat...,bc5cdr,133948,
...,...,...,...,...,...,...,...,...
383858,Chest,LOCALIZED_PAIN,0,5,Chest,lf:lf_pain_with_location,98004,localized
383859,years,CHRONIC_SYMPTOM,13,18,Noticed four years prior to presentation,lf:lf_chronic_symptom,133320,temporal
383860,Lateral,LOCALIZED_PAIN,0,7,Lateral side of the right thigh,lf:lf_pain_with_location,133320,localized
383861,Acute,ACUTE_ONSET,40,45,Crushing substernal chest pressure with Acute ...,lf:lf_acute_onset,97973,temporal


### Extracting Diagnosis

In [39]:
df_diagnosis

Unnamed: 0,idx,has_diagnosis,test,severity,result,condition,time,details
0,155216,False,,,,,,
1,133948,True,Magnetic resonance imaging (MRI) scan,,Increased amount of joint fluid and bone marro...,Idiopathic osteonecrosis of the femoral head,,Patient did not complain of any pain on the co...
2,133948,True,Repeat MRI,,Similar findings to those noted previously in ...,,One year after the initial surgery and symptom...,
3,80176,True,Radiographs,Minimally displaced,Proximal ulnar shaft fracture,"Proximal ulnar shaft fracture, hypertrophic no...",,Elbow arthrodesis at 90 degrees with retained ...
4,72232,True,MRI,Moderate-sized,Focal area of marrow edema/contusion involving...,Bone marrow edema,"September 2016, three months later, April 2017...",Involvement of medial femoral condyle in mid a...
...,...,...,...,...,...,...,...,...
61350,133320,True,Histopathological examination,,Consistent with lung metastasis of leiomyosarcoma,Lung metastasis of leiomyosarcoma,One year and 3 months postoperatively,
61351,97973,True,Electrocardiogram (ECG),,Diffuse ST depressions in all precordial leads,Consistent with an acute coronary syndrome,,
61352,97973,True,Transthoracic echocardiogram,Ejection fraction (EF) of 45% with severe aort...,Torn right coronary cusp,Severe aortic insufficiency,,Emergent transthoracic echocardiogram performed
61353,97973,True,Blood cultures,,Positive for S.\nlugdunensis in both bottles,,,


In [40]:
def extract_diagnosis_entities_custom(df_diagnosis):
        """
        Extract diagnosis-specific entities not (fully) caught by BC5CDR.
        Uses 'idx' for row identity when available; falls back to df.index otherwise.
        Adds 'row_idx' and character offsets for all extracted entities.
        Regex with word boundaries is used to avoid substring false positives.
        """
        custom_entities = []

        # Test/Procedure patterns
        test_patterns = {
            'imaging': [
                'mri', 'magnetic resonance', 'ct', 'computed tomography',
                'x-ray', 'radiograph', 'ultrasound', 'sonography', 'scan',
                'pet', 'spect', 'angiography', 'mammography'
            ],
            'laboratory': [
                'blood test', 'serum', 'plasma', 'biochemical', 'hematology',
                'urinalysis', 'culture', 'biopsy', 'cytology', 'pathology'
            ],
            'functional': [
                'ecg', 'ekg', 'electrocardiogram', 'eeg', 'electroencephalogram',
                'emg', 'electromyography', 'spirometry', 'pulmonary function'
            ],
            'endoscopy': [
                'endoscopy', 'colonoscopy', 'gastroscopy', 'bronchoscopy',
                'cystoscopy', 'arthroscopy', 'laparoscopy'
            ]
        }

        # Finding/Result patterns
        finding_patterns = {
            'structural': [
                'fracture', 'lesion', 'mass', 'tumor', 'cyst', 'nodule',
                'stenosis', 'occlusion', 'herniation', 'displacement'
            ],
            'inflammatory': [
                'inflammation', 'edema', 'swelling', 'effusion', 'congestion',
                'infiltration', 'consolidation'
            ],
            'degenerative': [
                'degeneration', 'atrophy', 'necrosis', 'fibrosis', 'sclerosis',
                'osteoarthritis', 'spondylosis'
            ],
            'vascular': [
                'ischemia', 'infarction', 'hemorrhage', 'aneurysm', 'thrombosis',
                'embolism', 'vasculitis'
            ],
            'neoplastic': [
                'malignant', 'benign', 'metastasis', 'carcinoma', 'sarcoma',
                'lymphoma', 'adenoma'
            ]
        }

        # Anatomical patterns specific to diagnosis
        anatomy_patterns = {
            'bone':   ['femur', 'tibia', 'fibula', 'humerus', 'radius', 'ulna', 'vertebra'],
            'joint':  ['hip', 'knee', 'shoulder', 'elbow', 'ankle', 'wrist'],
            'organ':  ['liver', 'kidney', 'heart', 'lung', 'brain', 'pancreas', 'spleen'],
            'vessel': ['artery', 'vein', 'aorta', 'carotid', 'coronary'],
            'region': ['parietal', 'temporal', 'frontal', 'occipital', 'cervical', 'lumbar']
        }

        # Severity/Grade patterns (from column)
        severity_patterns = {
            'mild':     ['mild', 'minimal', 'slight', 'minor'],
            'moderate': ['moderate', 'medium', 'intermediate'],
            'severe':   ['severe', 'significant', 'marked', 'extensive'],
            'critical': ['critical', 'life-threatening', 'emergency']
        }

        # Laterality
        laterality_terms = ['left', 'right', 'bilateral']

        for df_index, row in df_diagnosis.iterrows():
            # ---- choose the correct row id from 'idx' with fallback to the DataFrame index
            row_id = row['idx'] if ('idx' in row and pd.notna(row['idx'])) else df_index

            original_text = str(row.get('combined_text', ''))
            combined = original_text.lower()

            test_text_orig = str(row.get('test', ''))
            test_text = test_text_orig.lower()
            result_text = str(row.get('result', '')).lower()
            condition_text_orig = str(row.get('condition', ''))
            condition_text = condition_text_orig.lower()
            severity_text_orig = str(row.get('severity', ''))
            severity_text = severity_text_orig.lower()

            # ---- TEST from test column (place in the combined text if possible)
            if test_text and test_text != 'nan':
                start = combined.find(test_text)
                if start < 0:
                    start = 0
                end = min(start + len(test_text_orig), len(original_text))
                custom_entities.append({
                    'text': test_text_orig,
                    'label': 'TEST',
                    'category': 'diagnostic_test',
                    'start': start,
                    'end': end,
                    'original_text': original_text,
                    'source': 'test_column',
                    'row_idx': row_id
                })
                # TEST_TYPE via patterns (regex boundaries)
                for category, terms in test_patterns.items():
                    for term in terms:
                        pattern = r'\b' + re.escape(term.lower()) + r'\b'
                        for m in re.finditer(pattern, combined):
                            custom_entities.append({
                                'text': original_text[m.start():m.end()],
                                'label': 'TEST_TYPE',
                                'category': category,
                                'start': m.start(),
                                'end': m.end(),
                                'original_text': original_text,
                                'source': 'pattern_matching',
                                'row_idx': row_id
                            })

            # ---- FINDINGS from result (search in combined for offsets)
            if result_text and result_text != 'nan':
                for category, terms in finding_patterns.items():
                    for term in terms:
                        pattern = r'\b' + re.escape(term.lower()) + r'\b'
                        for m in re.finditer(pattern, combined):
                            custom_entities.append({
                                'text': original_text[m.start():m.end()],
                                'label': 'FINDING',
                                'category': category,
                                'start': m.start(),
                                'end': m.end(),
                                'original_text': original_text,
                                'source': 'result_extraction',
                                'row_idx': row_id
                            })

            # ---- CONDITION from condition column
            if condition_text and condition_text != 'nan':
                start = combined.find(condition_text)
                if start < 0:
                    start = 0
                end = min(start + len(condition_text_orig), len(original_text))
                custom_entities.append({
                    'text': condition_text_orig,
                    'label': 'CONDITION',
                    'category': 'diagnosis',
                    'start': start,
                    'end': end,
                    'original_text': original_text,
                    'source': 'condition_column',
                    'row_idx': row_id
                })

            # ---- ANATOMY patterns (allow simple plural 's')
            for category, terms in anatomy_patterns.items():
                for term in terms:
                    plural_opt = 's?' if term.isalpha() else ''
                    pattern = r'\b' + re.escape(term.lower()) + plural_opt + r'\b'
                    for m in re.finditer(pattern, combined):
                        custom_entities.append({
                            'text': original_text[m.start():m.end()],
                            'label': 'ANATOMY',
                            'category': category,
                            'start': m.start(),
                            'end': m.end(),
                            'original_text': original_text,
                            'source': 'anatomy_extraction',
                            'row_idx': row_id
                        })

            # ---- SEVERITY from severity column
            if severity_text and severity_text != 'nan':
                for category, terms in severity_patterns.items():
                    if severity_text in terms:
                        start = combined.find(severity_text)
                        if start < 0:
                            start = 0
                        end = min(start + len(severity_text_orig), len(original_text))
                        custom_entities.append({
                            'text': severity_text_orig,
                            'label': 'SEVERITY',
                            'category': category,
                            'start': start,
                            'end': end,
                            'original_text': original_text,
                            'source': 'severity_column',
                            'row_idx': row_id
                        })
                        break

            # ---- LATERALITY
            for lat in laterality_terms:
                pattern = r'\b' + re.escape(lat) + r'\b'
                for m in re.finditer(pattern, combined):
                    custom_entities.append({
                        'text': original_text[m.start():m.end()],
                        'label': 'LATERALITY',
                        'category': 'laterality',
                        'start': m.start(),
                        'end': m.end(),
                        'original_text': original_text,
                        'source': 'laterality_extraction',
                        'row_idx': row_id
                    })

            # ---- MEASUREMENTS (capture full token with units)
            # e.g., "12 mm", "3.5 cm", "45%", "10 mg", "30 ml"
            measurement_pattern = r'\b\d+(?:\.\d+)?\s*(?:mm|cm|ml|mg|%)\b'
            for m in re.finditer(measurement_pattern, combined):
                custom_entities.append({
                    'text': original_text[m.start():m.end()],
                    'label': 'MEASUREMENT',
                    'category': 'quantitative',
                    'start': m.start(),
                    'end': m.end(),
                    'original_text': original_text,
                    'source': 'measurement_extraction',
                    'row_idx': row_id
                })

        return pd.DataFrame(custom_entities)
if __name__ == "__main__":
    # Apply temporal standardization
    temporal_standardizer = TemporalStandardizer()
    df_diagnosis_processed, temporal_report = temporal_standardizer.standardize_temporal_column(
        df_diagnosis, 
        'time'
    )
   
    print(f"\nTemporal extraction results:")
    print(f"Successfully extracted duration: {temporal_report['extracted_durations']}")
    print(f"Temporal types: {temporal_report['temporal_types']}")

    # Combine diagnosis text columns (robust to NaNs)
    df_diagnosis['combined_text'] = df_diagnosis.apply(
        lambda row: f"{str(row.get('test',''))} performed with {str(row.get('severity',''))} severity "
                    f"showed {str(row.get('result',''))} indicating {str(row.get('condition',''))} "
                    f"{str(row.get('details',''))} at {str(row.get('time',''))}",
        axis=1        
    )

    # Copy combined_text to processed dataframe
    df_diagnosis_processed['combined_text'] = df_diagnosis['combined_text']

    # Add temporal features to combined text
    df_diagnosis_processed['combined_text_enriched'] = df_diagnosis_processed.apply(
        lambda row: f"{row['combined_text']} "
                    f"{'lasting ' + str(row.get('time_duration_days', '')) + ' days' if pd.notna(row.get('time_duration_days')) else ''}",
        axis=1
    )

    # Run NER extraction with BC5CDR
    df_diagnosis_entities, diagnosis_summary, diagnosis_rules = run_medical_ner_extraction(
        df_diagnosis,
        text_column='combined_text',
        model_name="en_ner_bc5cdr_md",
        batch_size=300, 
        id_column='idx'
    )

    print(f"\n=== BC5CDR Entity Types Found ===")
    print(diagnosis_summary['entity_types'])
    

    # Extract custom entities
    df_custom_diagnosis_entities = extract_diagnosis_entities_custom(df_diagnosis)

    print(f"\nCustom extraction found:")
    if not df_custom_diagnosis_entities.empty and 'label' in df_custom_diagnosis_entities.columns:
        print(df_custom_diagnosis_entities['label'].value_counts())
    else:
        print("No custom diagnosis entities found.")

    # Combine all entities (preserve row_idx/category when present)
    df_all_diagnosis_entities = pd.concat(
        [df_diagnosis_entities, df_custom_diagnosis_entities],
        ignore_index=True,
        sort=False
    )

    print(f"\n=== COMBINED Entity Distribution ===")
    if not df_all_diagnosis_entities.empty and 'label' in df_all_diagnosis_entities.columns:
        print(df_all_diagnosis_entities['label'].value_counts())
    else:
        print("No combined entities to show.")

    # Analyze test entities
    test_entities = df_all_diagnosis_entities[df_all_diagnosis_entities['label'] == 'TEST'] if not df_all_diagnosis_entities.empty else pd.DataFrame()
    if not test_entities.empty:
        print(f"\n=== Top Diagnostic Tests ===")
        print(test_entities['text'].str.lower().value_counts().head(20))

    # Analyze findings
    finding_entities = df_all_diagnosis_entities[df_all_diagnosis_entities['label'] == 'FINDING'] if not df_all_diagnosis_entities.empty else pd.DataFrame()
    if not finding_entities.empty:
        print(f"\n=== Top Diagnostic Findings ===")
        print(finding_entities['text'].str.lower().value_counts().head(20))

    # Analyze conditions
    condition_entities = df_all_diagnosis_entities[df_all_diagnosis_entities['label'] == 'CONDITION'] if not df_all_diagnosis_entities.empty else pd.DataFrame()
    if not condition_entities.empty:
        print(f"\n=== Top Diagnosed Conditions ===")
        print(condition_entities['text'].str.lower().value_counts().head(20))

    # ================= Diagnosis-specific labeling functions =================
    print("\n=== Creating Diagnosis-Specific Labeling Functions ===")

    def create_diagnosis_labeling_functions():
        COL = 'combined_text'

        def lf_imaging_test(row):
            test = str(row.get('test',''))
            hit = _first_hit(test, ['mri','ct','x-ray','radiograph','ultrasound','scan'])
            return {'label':'IMAGING_TEST','column':'test','match':hit,'category':'test'} if hit else {'label':'ABSTAIN'}

        def lf_fracture_diagnosis(row):
            res = str(row.get('result',''))
            cond = str(row.get('condition',''))
            if 'no fracture' in res.lower() or 'no fracture' in cond.lower():
                col = 'result' if 'no fracture' in res.lower() else 'condition'
                return {'label':'NO_FRACTURE','column':col,'match':'no fracture','category':'finding'}
            if 'fracture' in res.lower() or 'fracture' in cond.lower():
                col = 'result' if 'fracture' in res.lower() else 'condition'
                return {'label':'FRACTURE_PRESENT','column':col,'match':'fracture','category':'finding'}
            return {'label':'ABSTAIN'}

        def lf_neoplastic_finding(row):
            text = str(row.get(COL,''))
            hit = _first_hit(text, ['tumor','mass','lesion','malignant','benign','metastasis'])
            return {'label':'NEOPLASTIC_FINDING','column':COL,'match':hit,'category':'neoplastic'} if hit else {'label':'ABSTAIN'}

        def lf_normal_finding(row):
            res = str(row.get('result',''))
            hit = _first_hit(res, ['normal','negative','no abnormality','unremarkable'])
            return {'label':'NORMAL_FINDING','column':'result','match':hit,'category':'normal'} if hit else {'label':'ABSTAIN'}

        def lf_critical_finding(row):
            sev = str(row.get('severity',''))
            det = str(row.get('details',''))
            for col, txt in [('severity', sev), ('details', det), (COL, str(row.get(COL,'')))]:
                hit = _first_hit(txt, ['critical','emergency','urgent','life-threatening'])
                if hit:
                    return {'label':'CRITICAL_FINDING','column':col,'match':hit,'category':'critical'}
            return {'label':'ABSTAIN'}

        def lf_bone_pathology(row):
            text = str(row.get(COL,''))
            if any(b in text.lower() for b in ['bone','osseous','fracture','osteo','marrow']):
                hit = _first_hit(text, ['lesion','edema','necrosis','fracture'])
                if hit:
                    return {'label':'BONE_PATHOLOGY','column':COL,'match':hit,'category':'bone'}
            return {'label':'ABSTAIN'}

        def lf_vascular_finding(row):
            text = str(row.get(COL,''))
            hit = _first_hit(text, ['vascular','artery','vein','aneurysm','stenosis','occlusion'])
            return {'label':'VASCULAR_FINDING','column':COL,'match':hit,'category':'vascular'} if hit else {'label':'ABSTAIN'}

        def lf_inflammatory_finding(row):
            text = str(row.get(COL,''))
            hit = _first_hit(text, ['inflammation','inflammatory','edema','effusion','swelling'])
            return {'label':'INFLAMMATORY_FINDING','column':COL,'match':hit,'category':'inflammatory'} if hit else {'label':'ABSTAIN'}

        def lf_bilateral_finding(row):
            text = str(row.get(COL,''))
            hit = _first_hit(text, ['bilateral'])
            return {'label':'BILATERAL_FINDING','column':COL,'match':hit,'category':'laterality'} if hit else {'label':'ABSTAIN'}

        def lf_followup_needed(row):
            det = str(row.get('details',''))
            hit = _first_hit(det, ['follow-up','followup','repeat','monitor','reassess'])
            return {'label':'FOLLOWUP_NEEDED','column':'details','match':hit,'category':'plan'} if hit else {'label':'ABSTAIN'}

        return [lf_imaging_test, lf_fracture_diagnosis, lf_neoplastic_finding, lf_normal_finding,
                lf_critical_finding, lf_bone_pathology, lf_vascular_finding,
                lf_inflammatory_finding, lf_bilateral_finding, lf_followup_needed]

    # --- materialize spans ---
    diagnosis_lfs_span = create_diagnosis_labeling_functions()
    df_diagnosis_lf_spans = materialize_lf_spans(df_diagnosis, diagnosis_lfs_span, id_column='idx')
    print(f"\nLF-generated diagnosis spans: {len(df_diagnosis_lf_spans)}")
    if not df_diagnosis_lf_spans.empty:
        print(df_diagnosis_lf_spans['label'].value_counts())

    # --- combine model + custom + LF spans ---
    _cols = ['text','label','start','end','original_text','source','row_idx','category']
    for col in _cols:
        if col not in df_diagnosis_entities.columns: df_diagnosis_entities[col] = 'bc5cdr' if col=='source' else None
        if col not in df_custom_diagnosis_entities.columns: df_custom_diagnosis_entities[col] = 'custom_extraction' if col=='source' else None
        if col not in df_diagnosis_lf_spans.columns: df_diagnosis_lf_spans[col] = None

    df_all_diagnosis_entities = pd.concat(
        [df_diagnosis_entities[_cols], df_custom_diagnosis_entities[_cols], df_diagnosis_lf_spans[_cols]],
        ignore_index=True
    )
    print(f"Total diagnosis entities (model+custom+LF): {len(df_all_diagnosis_entities)}")
    df_diagnosis_lf_spans.to_csv('diagnosis_lf_generated_spans.csv', index=False)




Processing temporal column: time

Temporal extraction results:
Successfully extracted duration: 3879
Temporal types: {'unspecified': 4734, 'post_event': 3354, 'absolute_date': 1766, 'past_reference': 1280, 'range_reference': 1258, 'duration_reference': 299, 'onset_reference': 155}
Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 61355 texts in 205 batches...
Using model: en_ner_bc5cdr_md for column: combined_text


Processing batches:   0%|          | 1/205 [00:02<06:59,  2.05s/it]


Checkpoint saved at batch 0


Processing batches:   5%|▌         | 11/205 [00:16<04:39,  1.44s/it]


Checkpoint saved at batch 3000


Processing batches:  10%|█         | 21/205 [00:30<04:37,  1.51s/it]


Checkpoint saved at batch 6000


Processing batches:  15%|█▌        | 31/205 [00:44<04:16,  1.47s/it]


Checkpoint saved at batch 9000


Processing batches:  20%|██        | 41/205 [00:58<03:54,  1.43s/it]


Checkpoint saved at batch 12000


Processing batches:  25%|██▍       | 51/205 [01:12<03:37,  1.41s/it]


Checkpoint saved at batch 15000


Processing batches:  30%|██▉       | 61/205 [01:26<03:27,  1.44s/it]


Checkpoint saved at batch 18000


Processing batches:  35%|███▍      | 71/205 [01:40<03:16,  1.46s/it]


Checkpoint saved at batch 21000


Processing batches:  40%|███▉      | 81/205 [01:54<02:58,  1.44s/it]


Checkpoint saved at batch 24000


Processing batches:  44%|████▍     | 91/205 [02:08<02:48,  1.48s/it]


Checkpoint saved at batch 27000


Processing batches:  49%|████▉     | 101/205 [02:22<02:34,  1.48s/it]


Checkpoint saved at batch 30000


Processing batches:  54%|█████▍    | 111/205 [02:36<02:18,  1.48s/it]


Checkpoint saved at batch 33000


Processing batches:  59%|█████▉    | 121/205 [02:50<02:04,  1.49s/it]


Checkpoint saved at batch 36000


Processing batches:  64%|██████▍   | 131/205 [03:05<01:52,  1.52s/it]


Checkpoint saved at batch 39000


Processing batches:  69%|██████▉   | 141/205 [03:19<01:37,  1.53s/it]


Checkpoint saved at batch 42000


Processing batches:  74%|███████▎  | 151/205 [03:33<01:22,  1.53s/it]


Checkpoint saved at batch 45000


Processing batches:  79%|███████▊  | 161/205 [03:47<01:06,  1.52s/it]


Checkpoint saved at batch 48000


Processing batches:  83%|████████▎ | 171/205 [04:01<00:52,  1.54s/it]


Checkpoint saved at batch 51000


Processing batches:  88%|████████▊ | 181/205 [04:15<00:36,  1.52s/it]


Checkpoint saved at batch 54000


Processing batches:  93%|█████████▎| 191/205 [04:29<00:21,  1.55s/it]


Checkpoint saved at batch 57000


Processing batches:  98%|█████████▊| 201/205 [04:44<00:06,  1.57s/it]


Checkpoint saved at batch 60000


Processing batches: 100%|██████████| 205/205 [04:48<00:00,  1.41s/it]


Found 2745 entities appearing >= 5 times

=== BC5CDR Entity Types Found ===
{'DISEASE': 77150, 'CHEMICAL': 8577}

Custom extraction found:
label
TEST           56150
TEST_TYPE      51197
FINDING        40916
CONDITION      36529
ANATOMY        27435
LATERALITY     25373
MEASUREMENT    10044
SEVERITY        1054
Name: count, dtype: int64

=== COMBINED Entity Distribution ===
label
DISEASE        77150
TEST           56150
TEST_TYPE      51197
FINDING        40916
CONDITION      36529
ANATOMY        27435
LATERALITY     25373
MEASUREMENT    10044
CHEMICAL        8577
SEVERITY        1054
Name: count, dtype: int64

=== Top Diagnostic Tests ===
text
mri                                 1005
ct scan                              786
biopsy                               663
magnetic resonance imaging (mri)     655
histopathological examination        577
chest x-ray                          537
computed tomography (ct) scan        402
laboratory tests                     397
computed tomograph

In [41]:
df_all_diagnosis_entities

Unnamed: 0,text,label,start,end,original_text,source,row_idx,category
0,bone marrow edema,DISEASE,109,126,Magnetic resonance imaging (MRI) scan performe...,bc5cdr,133948,
1,femoral head necrosis,DISEASE,148,169,Magnetic resonance imaging (MRI) scan performe...,bc5cdr,133948,
2,Idiopathic osteonecrosis of the femoral head P...,DISEASE,207,263,Magnetic resonance imaging (MRI) scan performe...,bc5cdr,133948,
3,pain,DISEASE,284,288,Magnetic resonance imaging (MRI) scan performe...,bc5cdr,133948,
4,fracture,DISEASE,84,92,Radiographs performed with Minimally displaced...,bc5cdr,80176,
...,...,...,...,...,...,...,...,...
400806,CT,IMAGING_TEST,39,41,Low-dose thoracic computed tomography (CT),lf:lf_imaging_test,137017,test
400807,lesion,NEOPLASTIC_FINDING,86,92,Coronary angiography performed with nan severi...,lf:lf_neoplastic_finding,98004,neoplastic
400808,artery,VASCULAR_FINDING,71,77,Coronary angiography performed with nan severi...,lf:lf_vascular_finding,98004,vascular
400809,metastasis,NEOPLASTIC_FINDING,86,96,Histopathological examination performed with n...,lf:lf_neoplastic_finding,133320,neoplastic


### Extracting Treatments

In [42]:
df_treatments

Unnamed: 0,idx,has_treatments,name,related condition,dosage,time,frequency,duration,reason for taking,reaction to treatment,details
0,155216,True,Olanzapine tablets,Bipolar affective disorder,5 mg per day,Past four months,Daily,,Control of exacerbated mental illness,"Pain and discomfort in neck, sustained and abn...",Previously managed with olanzapine tablets in ...
1,155216,True,Trihexyphenidyl,Rigidity in upper limbs,4 mg per day,Brief period of around three weeks,Daily,,Rigidity in upper limbs,Good response,
2,133948,False,,,,,,,,,
3,80176,True,Closed treatment in a cast,Proximal ulnar shaft fracture,,Initially after the fall,,,To treat the ulnar shaft fracture,Developed a hypertrophic nonunion,
4,80176,True,Conservative treatment,Ulna nonunion,,Three months after the fall,,An additional three months,To treat the ulna nonunion,Worsening motion through the nonunion site,
...,...,...,...,...,...,...,...,...,...,...,...
50421,98004,True,Hypovolaemic shock treatment,Haemodynamic instability and hypovolaemic shock,,,,,To maintain blood pressure and treat shock,Required large doses of vasopressor and blood ...,Treatment given after becoming haemodynamicall...
50422,133320,True,Systemic chemotherapy,Lung and bone metastases,,,,,To treat lung and bone metastases,,Using doxorubicin and ifosfamide
50423,97973,True,Rapid sequence intubation,Cardiogenic shock and flash pulmonary edema,,,,,To manage suspected cardiogenic shock and flas...,,
50424,97973,True,Advanced cardiac life support (ACLS) protocol,Cardiac arrest,,13 min,,,To restore return of spontaneous circulation a...,Return of spontaneous circulation was restored,


In [43]:
def create_treatment_text(row):
    """Create comprehensive text from treatment row"""
    parts = []
    
    if pd.notna(row.get('name')) and str(row['name']) != 'NaN':
        parts.append(f"Treatment: {row['name']}")
    
    if pd.notna(row.get('related condition')) and str(row['related condition']) != 'NaN':
        parts.append(f"for {row['related condition']}")
    
    if pd.notna(row.get('dosage')) and str(row['dosage']) != 'NaN':
        parts.append(f"dosage {row['dosage']}")
    
    if pd.notna(row.get('frequency')) and str(row['frequency']) != 'NaN':
        parts.append(f"frequency {row['frequency']}")
    
    if pd.notna(row.get('time')) and str(row['time']) != 'NaN':
        parts.append(f"time {row['time']}")
    
    if pd.notna(row.get('duration')) and str(row['duration']) != 'NaN':
        parts.append(f"duration {row['duration']}")
    
    if pd.notna(row.get('reason for taking')) and str(row['reason for taking']) != 'NaN':
        parts.append(f"reason: {row['reason for taking']}")
    
    if pd.notna(row.get('reaction to treatment')) and str(row['reaction to treatment']) != 'NaN':
        parts.append(f"reaction: {row['reaction to treatment']}")
    
    if pd.notna(row.get('details')) and str(row['details']) != 'NaN':
        parts.append(f"details: {row['details']}")
    
    return " ".join(parts)

In [44]:
def extract_treatment_entities_custom(df_treatments):
    """Extract treatment-specific entities not caught by BC5CDR.
        Uses 'idx' for row identity when available; falls back to df.index otherwise.
        Adds 'row_idx' to every extracted entity.
    """
    custom_entities = []
    
    # Treatment type patterns
    treatment_type_patterns = {
        'medication': ['tablet', 'tablets', 'pill', 'pills', 'capsule', 'injection', 'infusion', 'medication', 'drug'],
        'procedure': ['surgery', 'surgical', 'operation', 'procedure', 'therapy', 'intubation', 'resection', 'removal', 'repair'],
        'supportive': ['support', 'life support', 'ventilation', 'oxygen', 'fluid', 'nutrition', 'acls', 'protocol'],
        'chemotherapy': ['chemotherapy', 'chemo', 'cytotoxic', 'antineoplastic', 'systemic chemotherapy'],
        'conservative': ['conservative', 'non-operative', 'non-surgical', 'observation', 'closed treatment'],
        'emergency': ['emergency', 'urgent', 'rapid', 'resuscitation', 'life-saving'],
        'diagnostic': ['biopsy', 'exploration', 'diagnostic', 'assessment']
    }
    
    # Drug/medication patterns
    medication_patterns = {
        'antipsychotic': ['olanzapine', 'risperidone', 'quetiapine', 'haloperidol', 'clozapine'],
        'muscle_relaxant': ['trihexyphenidyl', 'baclofen', 'tizanidine', 'cyclobenzaprine'],
        'antibiotic': ['nafcillin', 'vancomycin', 'ceftriaxone', 'penicillin', 'amoxicillin'],
        'cardiovascular': ['vasopressor', 'inotrope', 'beta blocker', 'ace inhibitor', 'hypovolaemic'],
        'analgesic': ['morphine', 'fentanyl', 'oxycodone', 'acetaminophen', 'ibuprofen']
    }
    
    # Condition patterns (treated conditions)
    condition_patterns = {
        'psychiatric': ['bipolar', 'affective disorder', 'psychosis', 'mania', 'depression', 'mental illness'],
        'orthopedic': ['fracture', 'joint', 'bone', 'hip', 'knee', 'spine', 'ulnar shaft'],
        'cardiovascular': ['cardiac', 'heart', 'arrhythmia', 'shock', 'arrest', 'hypovolaemic', 'endocarditis'],
        'oncological': ['cancer', 'metastases', 'tumor', 'malignancy', 'carcinoma'],
        'neurological': ['rigidity', 'tremor', 'paralysis', 'nerve', 'neurological'],
        'infectious': ['infection', 'sepsis', 'endocarditis', 'abscess', 'pneumonia']
    }
    
    # Dosage unit patterns
    dosage_patterns = {
        'mg': r'\b(\d+\.?\d*)\s*mg\b',
        'mcg': r'\b(\d+\.?\d*)\s*mcg\b',
        'units': r'\b(\d+\.?\d*)\s*units?\b',
        'ml': r'\b(\d+\.?\d*)\s*ml\b',
        'per_day': r'\b(?:per\s*day|/day|daily)\b',
        'min': r'\b(\d+\.?\d*)\s*min(?:utes?)?\b'
    }
    
    # Frequency patterns
    frequency_terms = ['daily', 'twice daily', 'three times', 'qid', 'bid', 'tid', 'prn', 
                        'as needed', 'every', 'once', 'continuous', 'intermittent']
    
    # Route of administration patterns
    route_terms = ['oral', 'intravenous', 'iv', 'im', 'intramuscular', 'subcutaneous', 
                    'topical', 'inhaled', 'nasal', 'rectal']
    
    # Treatment response/outcome patterns
    response_patterns = {
        'positive': ['good response', 'improved', 'resolved', 'successful', 'effective', 
                        'restored', 'return of', 'recovered'],
        'negative': ['worsening', 'failed', 'no response', 'adverse', 'side effect', 
                        'complication', 'deterioration'],
        'neutral': ['no change', 'stable', 'maintained', 'sustained', 'continued'],
        'partial': ['partial response', 'some improvement', 'limited response']
    }
    
    # Temporal patterns specific to treatments
    temporal_patterns = {
        'acute': ['acute', 'sudden', 'rapid', 'emergency', 'immediate'],
        'chronic': ['chronic', 'long-term', 'maintenance', 'ongoing', 'continuous'],
        'perioperative': ['preoperative', 'postoperative', 'intraoperative', 'perioperative'],
        'duration': ['months', 'weeks', 'days', 'hours', 'years']
    }
    
    for df_index, row in df_treatments.iterrows():
        # ---- choose the correct row id from 'idx' with fallback to the DataFrame index
        row_id = row['idx'] if ('idx' in row and pd.notna(row['idx'])) else df_index

        # Get relevant text fields with safe string conversion
        name_orig = str(row.get('name', ''))
        name = name_orig.lower() if pd.notna(row.get('name')) else ''
        condition_orig = str(row.get('related condition', ''))
        condition = condition_orig.lower() if pd.notna(row.get('related condition')) else ''
        dosage_orig = str(row.get('dosage', ''))
        dosage = dosage_orig.lower() if pd.notna(row.get('dosage')) else ''
        frequency_orig = str(row.get('frequency', ''))
        frequency = frequency_orig.lower() if pd.notna(row.get('frequency')) else ''
        time_orig = str(row.get('time', ''))
        time_l = time_orig.lower() if pd.notna(row.get('time')) else ''
        duration_orig = str(row.get('duration', ''))
        duration_l = duration_orig.lower() if pd.notna(row.get('duration')) else ''
        reason_orig = str(row.get('reason for taking', ''))
        reason = reason_orig.lower() if pd.notna(row.get('reason for taking')) else ''
        reaction_orig = str(row.get('reaction to treatment', ''))
        reaction = reaction_orig.lower() if pd.notna(row.get('reaction to treatment')) else ''
        details_orig = str(row.get('details', ''))
        details = details_orig.lower() if pd.notna(row.get('details')) else ''
        combined_orig = str(row.get('combined_text', ''))
        combined = combined_orig.lower()

        # Primary TREATMENT from name column (position in combined if possible)
        if name and name != 'nan':
            start = combined.find(name)
            if start < 0:
                start = 0
            end = min(start + len(name_orig), len(combined_orig))
            custom_entities.append({
                'text': name_orig,
                'label': 'TREATMENT',
                'category': 'primary_treatment',
                'start': start,
                'end': end,
                'original_text': combined_orig if combined_orig else name_orig,
                'source': 'name_column',
                'row_idx': row_id
            })
        
        # CONDITION (indication)
        if condition and condition != 'nan':
            start = combined.find(condition)
            if start < 0:
                start = 0
            end = min(start + len(condition_orig), len(combined_orig))
            custom_entities.append({
                'text': condition_orig,
                'label': 'CONDITION',
                'category': 'treatment_indication',
                'start': start,
                'end': end,
                'original_text': combined_orig if combined_orig else condition_orig,
                'source': 'condition_column',
                'row_idx': row_id
            })
        
        # TREATMENT_TYPE (search in combined)
        for category, terms in treatment_type_patterns.items():
            for term in terms:
                pattern = r'\b' + re.escape(term) + r'\b'
                for m in re.finditer(pattern, combined):
                    custom_entities.append({
                        'text': combined_orig[m.start():m.end()],
                        'label': 'TREATMENT_TYPE',
                        'category': category,
                        'start': m.start(),
                        'end': m.end(),
                        'original_text': combined_orig,
                        'source': 'pattern_matching',
                        'row_idx': row_id
                    })
        
        # MEDICATION names/classes (search in combined)
        for drug_class, drugs in medication_patterns.items():
            for drug in drugs:
                pattern = r'\b' + re.escape(drug) + r'\b'
                for m in re.finditer(pattern, combined):
                    custom_entities.append({
                        'text': combined_orig[m.start():m.end()],
                        'label': 'MEDICATION',
                        'category': drug_class,
                        'start': m.start(),
                        'end': m.end(),
                        'original_text': combined_orig,
                        'source': 'medication_pattern',
                        'row_idx': row_id
                    })
        
        # DOSAGE (from dosage column)
        if dosage and dosage != 'nan':
            for unit, pattern in dosage_patterns.items():
                for match in re.finditer(pattern, dosage, re.IGNORECASE):
                    custom_entities.append({
                        'text': match.group(0),
                        'label': 'DOSAGE',
                        'category': unit,
                        'start': match.start(),
                        'end': match.end(),
                        'original_text': dosage_orig,
                        'source': 'dosage_column',
                        'row_idx': row_id
                    })
        
        # FREQUENCY (from frequency column)
        if frequency and frequency != 'nan':
            for freq_term in frequency_terms:
                pattern = r'\b' + re.escape(freq_term) + r'\b'
                for m in re.finditer(pattern, frequency):
                    custom_entities.append({
                        'text': frequency_orig[m.start():m.end()],
                        'label': 'FREQUENCY',
                        'category': 'dosing_frequency',
                        'start': m.start(),
                        'end': m.end(),
                        'original_text': frequency_orig,
                        'source': 'frequency_column',
                        'row_idx': row_id
                    })
        
        # ROUTE (search across combined context)
        for route in route_terms:
            pattern = r'\b' + re.escape(route) + r'\b'
            for m in re.finditer(pattern, combined):
                custom_entities.append({
                    'text': combined_orig[m.start():m.end()],
                    'label': 'ROUTE',
                    'category': 'administration_route',
                    'start': m.start(),
                    'end': m.end(),
                    'original_text': combined_orig,
                    'source': 'route_extraction',
                    'row_idx': row_id
                })
        
        # TREATMENT_RESPONSE (reaction + details + combined)
        response_text_orig = ' '.join([reaction_orig, details_orig, combined_orig])
        response_text = response_text_orig.lower()
        for response_type, patterns in response_patterns.items():
            for phrase in patterns:
                pattern = r'\b' + re.escape(phrase) + r'\b'
                for m in re.finditer(pattern, response_text):
                    custom_entities.append({
                        'text': response_text_orig[m.start():m.end()],
                        'label': 'TREATMENT_RESPONSE',
                        'category': response_type,
                        'start': m.start(),
                        'end': m.end(),
                        'original_text': response_text_orig,
                        'source': 'response_extraction',
                        'row_idx': row_id
                    })
        
        # CONDITION_TYPE categories (condition or reason)
        for condition_type, condition_terms in condition_patterns.items():
            for term in condition_terms:
                pattern = r'\b' + re.escape(term) + r'\b'
                # search in condition text first, then reason, else combined
                found = False
                for text_orig, text_l, source in [
                    (condition_orig, condition, 'condition_pattern'),
                    (reason_orig, reason, 'reason_pattern'),
                    (combined_orig, combined, 'combined_pattern'),
                ]:
                    m = re.search(pattern, text_l)
                    if m:
                        custom_entities.append({
                            'text': text_orig[m.start():m.end()],
                            'label': 'CONDITION_TYPE',
                            'category': condition_type,
                            'start': m.start(),
                            'end': m.end(),
                            'original_text': text_orig,
                            'source': source,
                            'row_idx': row_id
                        })
                        found = True
                        break
        
        # TEMPORAL_PATTERN
        temporal_text_orig = ' '.join([time_orig, duration_orig, details_orig])
        temporal_text = temporal_text_orig.lower()
        for temp_category, temp_terms in temporal_patterns.items():
            for term in temp_terms:
                pattern = r'\b' + re.escape(term) + r'\b'
                for m in re.finditer(pattern, temporal_text):
                    custom_entities.append({
                        'text': temporal_text_orig[m.start():m.end()],
                        'label': 'TEMPORAL_PATTERN',
                        'category': temp_category,
                        'start': m.start(),
                        'end': m.end(),
                        'original_text': temporal_text_orig,
                        'source': 'temporal_extraction',
                        'row_idx': row_id
                    })
        
        # TREATMENT_REASON (from reason column specific patterns)
        if reason and reason != 'nan':
            reason_patterns = [
                (r'to\s+treat\s+(\w+(?:\s+\w+)*)', 'treatment_goal'),
                (r'to\s+manage\s+(\w+(?:\s+\w+)*)', 'management_goal'),
                (r'control\s+of\s+(\w+(?:\s+\w+)*)', 'control_goal'),
                (r'for\s+(\w+(?:\s+\w+)*)', 'indication')
            ]
            for pattern, category in reason_patterns:
                for match in re.finditer(pattern, reason, re.IGNORECASE):
                    custom_entities.append({
                        'text': match.group(0),
                        'label': 'TREATMENT_REASON',
                        'category': category,
                        'start': match.start(),
                        'end': match.end(),
                        'original_text': reason_orig,
                        'source': 'reason_column',
                        'row_idx': row_id
                    })
    
    return pd.DataFrame(custom_entities)
if __name__ == "__main__":
    # Use the global TemporalStandardizer class 
    temporal_standardizer = TemporalStandardizer()

    # Process temporal information using the standardize_temporal_column method
    df_treatments_processed, temporal_report = temporal_standardizer.standardize_temporal_column(
        df_treatments, 
        'time'
    )
    
    # Process duration column as well
    df_treatments_processed, duration_report = temporal_standardizer.standardize_temporal_column(
        df_treatments_processed, 
        'duration'
    )

    print(f"\nTemporal extraction results:")
    print(f"Duration extracted from 'time' column: {temporal_report['extracted_durations']} rows")
    print(f"Temporal types: {temporal_report['temporal_types']}")

    # Combine treatments information into comprehensive text
    df_treatments_processed['combined_text'] = df_treatments_processed.apply(
        lambda row: create_treatment_text(row),
        axis=1
    )

    # Run NER extraction with BC5CDR
    df_treatments_entities, treatments_summary, treatments_rules = run_medical_ner_extraction(
        df_treatments_processed,
        text_column='combined_text',
        model_name="en_ner_bc5cdr_md",
        batch_size=300,
        id_column='idx'
    )

    print(f"\n=== BC5CDR Entity Types Found ===")
    print(treatments_summary['entity_types'])

    # Custom entity extraction for treatment-specific entities (keeps correct row_idx)
    print("\n=== CUSTOM TREATMENT ENTITY EXTRACTION ===")


    # Extract custom entities
    df_custom_treatments_entities = extract_treatment_entities_custom(df_treatments_processed)

    # Combine all entities (preserve row_idx/category when present)
    df_all_treatments_entities = pd.concat(
        [df_treatments_entities, df_custom_treatments_entities],
        ignore_index=True,
        sort=False
    )

    print(f"\n=== COMBINED Entity Distribution ===")
    if not df_all_treatments_entities.empty and 'label' in df_all_treatments_entities.columns:
        print(df_all_treatments_entities['label'].value_counts())
    else:
        print("No combined entities to show.")

    # Analyze treatment entities
    treatment_entities = df_all_treatments_entities[df_all_treatments_entities['label'] == 'TREATMENT'] if not df_all_treatments_entities.empty else pd.DataFrame()
    if not treatment_entities.empty:
        print(f"\n=== Top Treatments ===")
        print(treatment_entities['text'].str.lower().value_counts().head(20))

    # Analyze medications
    medication_entities = df_all_treatments_entities[df_all_treatments_entities['label'] == 'MEDICATION'] if not df_all_treatments_entities.empty else pd.DataFrame()
    if not medication_entities.empty:
        print(f"\n=== Top Medications ===")
        print(medication_entities['text'].str.lower().value_counts().head(20))

    # Analyze conditions being treated
    condition_entities = df_all_treatments_entities[df_all_treatments_entities['label'] == 'CONDITION'] if not df_all_treatments_entities.empty else pd.DataFrame()
    if not condition_entities.empty:
        print(f"\n=== Top Conditions Treated ===")
        print(condition_entities['text'].str.lower().value_counts().head(20))

    # Analyze dosages
    dosage_entities = df_all_treatments_entities[df_all_treatments_entities['label'] == 'DOSAGE'] if not df_all_treatments_entities.empty else pd.DataFrame()
    if not dosage_entities.empty:
        print(f"\n=== Dosage Distribution ===")
        print(dosage_entities['text'].str.lower().value_counts().head(15))

    # Analyze treatment responses
    response_entities = df_all_treatments_entities[df_all_treatments_entities['label'] == 'TREATMENT_RESPONSE'] if not df_all_treatments_entities.empty else pd.DataFrame()
    if not response_entities.empty:
        print(f"\n=== Treatment Response Distribution ===")
        print(response_entities['category'].value_counts())

    # Analyze treatment types
    treatment_type_entities = df_all_treatments_entities[df_all_treatments_entities['label'] == 'TREATMENT_TYPE'] if not df_all_treatments_entities.empty else pd.DataFrame()
    if not treatment_type_entities.empty:
        print(f"\n=== Treatment Type Distribution ===")
        print(treatment_type_entities['category'].value_counts())

    # Create treatment-specific labeling functions (unchanged logic)
    print("\n=== Creating Treatment-Specific Labeling Functions ===")

    def create_treatment_labeling_functions():
        COL = 'combined_text'

        def lf_has_treatment(row):
            if row.get('has_treatments') is True:
                # anchor on treatment name if present
                nm = str(row.get('name',''))
                if nm:
                    return {'label':'HAS_TREATMENT','column':'name','match':nm.split()[0].lower(),'category':'presence'}
            elif row.get('has_treatments') is False:
                hit = _first_hit(str(row.get(COL,'')), ['no treatment','none'])
                if hit:
                    return {'label':'NO_TREATMENT','column':COL,'match':hit,'category':'absence'}
            return {'label':'ABSTAIN'}

        def lf_medication_treatment(row):
            name = str(row.get('name','')); dose = str(row.get('dosage','')); text = ' '.join([name,dose,str(row.get(COL,''))])
            hit = _first_hit(text, ['tablet','tablets','pill','mg','mcg','capsule','injection','infusion'])
            if hit:
                # choose best column for offsets
                col = 'name' if hit in name.lower() else ('dosage' if hit in dose.lower() else COL)
                return {'label':'MEDICATION_TREATMENT','column':col,'match':hit,'category':'medication'}
            return {'label':'ABSTAIN'}

        def lf_surgical_treatment(row):
            name = str(row.get('name','')); det = str(row.get('details',''))
            text = name + ' ' + det
            hit = _first_hit(text, ['surgery','surgical','operation','resection','removal','repair','intubation'])
            if hit:
                col = 'name' if hit in name.lower() else 'details'
                return {'label':'SURGICAL_TREATMENT','column':col,'match':hit,'category':'procedure'}
            return {'label':'ABSTAIN'}

        def lf_emergency_treatment(row):
            text = ' '.join([str(row.get('name','')), str(row.get('related condition','')), str(row.get('details',''))])
            hit = _first_hit(text, ['emergency','urgent','cardiac arrest','shock','life support','acls','resuscitation','rapid sequence'])
            if hit:
                # pick the column containing the hit
                for col in ['name','related condition','details']:
                    if hit in str(row.get(col,'')).lower():
                        return {'label':'EMERGENCY_TREATMENT','column':col,'match':hit,'category':'urgency'}
            return {'label':'ABSTAIN'}

        def lf_cancer_treatment(row):
            text = ' '.join([str(row.get('name','')), str(row.get('related condition','')), str(row.get('reason for taking',''))])
            hit = _first_hit(text, ['chemotherapy','cancer','metastases','tumor','oncology','malignant','carcinoma'])
            if hit:
                for col in ['name','related condition','reason for taking']:
                    if hit in str(row.get(col,'')).lower():
                        return {'label':'CANCER_TREATMENT','column':col,'match':hit,'category':'oncology'}
            return {'label':'ABSTAIN'}

        def lf_psychiatric_treatment(row):
            name = str(row.get('name','')); cond = str(row.get('related condition',''))
            drug_hit = _first_hit(name, ['olanzapine','risperidone','haloperidol','quetiapine','trihexyphenidyl'])
            cond_hit = _first_hit(cond, ['bipolar','psychosis','mania','depression','anxiety','affective disorder'])
            if drug_hit:
                return {'label':'PSYCHIATRIC_TREATMENT','column':'name','match':drug_hit,'category':'psychiatry'}
            if cond_hit:
                return {'label':'PSYCHIATRIC_TREATMENT','column':'related condition','match':cond_hit,'category':'psychiatry'}
            return {'label':'ABSTAIN'}

        def lf_chronic_treatment(row):
            t = str(row.get('time','')); d = str(row.get('duration',''))
            hit = _first_hit(t + ' ' + d, ['months','years','chronic','long-term','maintenance'])
            if hit:
                col = 'duration' if hit in d.lower() else 'time'
                return {'label':'CHRONIC_TREATMENT','column':col,'match':hit,'category':'temporal'}
            return {'label':'ABSTAIN'}

        def lf_daily_medication(row):
            freq = str(row.get('frequency',''))
            hit = _first_hit(freq, ['daily','every day','per day'])
            return {'label':'DAILY_MEDICATION','column':'frequency','match':hit,'category':'frequency'} if hit else {'label':'ABSTAIN'}

        def lf_positive_response(row):
            rxn = str(row.get('reaction to treatment','')); det = str(row.get('details',''))
            for col, txt in [('reaction to treatment', rxn), ('details', det)]:
                hit = _first_hit(txt, ['good response','improved','resolved','successful','effective','restored','return of','recovered'])
                if hit:
                    return {'label':'POSITIVE_RESPONSE','column':col,'match':hit,'category':'response'}
            return {'label':'ABSTAIN'}

        def lf_conservative_treatment(row):
            name = str(row.get('name',''))
            hit = _first_hit(name, ['conservative','non-operative','closed treatment'])
            return {'label':'CONSERVATIVE_TREATMENT','column':'name','match':hit,'category':'conservative'} if hit else {'label':'ABSTAIN'}

        def lf_infection_treatment(row):
            text = ' '.join([str(row.get('name','')), str(row.get('related condition','')), str(row.get('reason for taking',''))])
            hit = _first_hit(text, ['antibiotic','infection','endocarditis','sepsis','nafcillin','antimicrobial'])
            if hit:
                for col in ['name','related condition','reason for taking']:
                    if hit in str(row.get(col,'')).lower():
                        return {'label':'INFECTION_TREATMENT','column':col,'match':hit,'category':'infectious'}
            return {'label':'ABSTAIN'}

        def lf_cardiovascular_treatment(row):
            text = ' '.join([str(row.get('name','')), str(row.get('related condition',''))])
            hit = _first_hit(text, ['cardiac','heart','hypovolaemic','shock','arrest','vasopressor','arrhythmia'])
            if hit:
                for col in ['name','related condition']:
                    if hit in str(row.get(col,'')).lower():
                        return {'label':'CARDIOVASCULAR_TREATMENT','column':col,'match':hit,'category':'cardio'}
            return {'label':'ABSTAIN'}

        def lf_treatment_duration(row):
            # anchor on explicit duration tokens if present
            dur = str(row.get('duration','')); time = str(row.get('time',''))
            hit = _first_hit(dur + ' ' + time, ['days','weeks','months','years','hours'])
            if hit:
                col = 'duration' if hit in dur.lower() else 'time'
                return {'label':'TREATMENT_DURATION_MENTIONED','column':col,'match':hit,'category':'duration'}
            return {'label':'ABSTAIN'}

        return [lf_has_treatment, lf_medication_treatment, lf_surgical_treatment, lf_emergency_treatment,
                lf_cancer_treatment, lf_psychiatric_treatment, lf_chronic_treatment, lf_daily_medication,
                lf_positive_response, lf_conservative_treatment, lf_infection_treatment, lf_cardiovascular_treatment,
                lf_treatment_duration]

    # --- materialize spans (use the processed DF which holds combined_text) ---
    treatment_lfs_span = create_treatment_labeling_functions()
    df_treatments_lf_spans = materialize_lf_spans(df_treatments_processed, treatment_lfs_span, id_column='idx')
    print(f"\nLF-generated treatment spans: {len(df_treatments_lf_spans)}")
    if not df_treatments_lf_spans.empty:
        print(df_treatments_lf_spans['label'].value_counts())

    # --- combine model + custom + LF spans ---
    _cols = ['text','label','start','end','original_text','source','row_idx','category']
    for col in _cols:
        if col not in df_treatments_entities.columns: df_treatments_entities[col] = 'bc5cdr' if col=='source' else None
        if col not in df_custom_treatments_entities.columns: df_custom_treatments_entities[col] = 'custom_extraction' if col=='source' else None
        if col not in df_treatments_lf_spans.columns: df_treatments_lf_spans[col] = None

    df_all_treatments_entities = pd.concat(
        [df_treatments_entities[_cols], df_custom_treatments_entities[_cols], df_treatments_lf_spans[_cols]],
        ignore_index=True
    )
    print(f"Total treatment entities (model+custom+LF): {len(df_all_treatments_entities)}")
    df_treatments_lf_spans.to_csv('treatments_lf_generated_spans.csv', index=False)



Processing temporal column: time

Processing temporal column: duration

Temporal extraction results:
Duration extracted from 'time' column: 5037 rows
Temporal types: {'unspecified': 6645, 'post_event': 5777, 'range_reference': 2315, 'past_reference': 2193, 'onset_reference': 2007, 'absolute_date': 1267, 'duration_reference': 914}
Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 50426 texts in 169 batches...
Using model: en_ner_bc5cdr_md for column: combined_text


Processing batches:   1%|          | 1/169 [00:02<05:54,  2.11s/it]


Checkpoint saved at batch 0


Processing batches:   7%|▋         | 11/169 [00:15<03:36,  1.37s/it]


Checkpoint saved at batch 3000


Processing batches:  12%|█▏        | 21/169 [00:28<03:23,  1.38s/it]


Checkpoint saved at batch 6000


Processing batches:  18%|█▊        | 31/169 [00:42<03:11,  1.39s/it]


Checkpoint saved at batch 9000


Processing batches:  24%|██▍       | 41/169 [00:55<02:55,  1.37s/it]


Checkpoint saved at batch 12000


Processing batches:  30%|███       | 51/169 [01:08<02:41,  1.37s/it]


Checkpoint saved at batch 15000


Processing batches:  36%|███▌      | 61/169 [01:21<02:31,  1.40s/it]


Checkpoint saved at batch 18000


Processing batches:  42%|████▏     | 71/169 [01:34<02:18,  1.41s/it]


Checkpoint saved at batch 21000


Processing batches:  48%|████▊     | 81/169 [01:47<02:02,  1.39s/it]


Checkpoint saved at batch 24000


Processing batches:  54%|█████▍    | 91/169 [02:00<01:50,  1.42s/it]


Checkpoint saved at batch 27000


Processing batches:  60%|█████▉    | 101/169 [02:14<01:38,  1.45s/it]


Checkpoint saved at batch 30000


Processing batches:  66%|██████▌   | 111/169 [02:27<01:23,  1.44s/it]


Checkpoint saved at batch 33000


Processing batches:  72%|███████▏  | 121/169 [02:40<01:09,  1.44s/it]


Checkpoint saved at batch 36000


Processing batches:  78%|███████▊  | 131/169 [02:54<00:59,  1.57s/it]


Checkpoint saved at batch 39000


Processing batches:  83%|████████▎ | 141/169 [03:08<00:42,  1.51s/it]


Checkpoint saved at batch 42000


Processing batches:  89%|████████▉ | 151/169 [03:22<00:28,  1.56s/it]


Checkpoint saved at batch 45000


Processing batches:  95%|█████████▌| 161/169 [03:35<00:12,  1.55s/it]


Checkpoint saved at batch 48000


Processing batches: 100%|██████████| 169/169 [03:44<00:00,  1.33s/it]


Found 3696 entities appearing >= 5 times

=== BC5CDR Entity Types Found ===
{'DISEASE': 80940, 'CHEMICAL': 37907}

=== CUSTOM TREATMENT ENTITY EXTRACTION ===

=== COMBINED Entity Distribution ===
label
DISEASE               80940
TREATMENT             44086
CONDITION             43157
CHEMICAL              37907
TREATMENT_TYPE        33978
TREATMENT_REASON      16765
CONDITION_TYPE        14634
TREATMENT_RESPONSE    13219
TEMPORAL_PATTERN      12523
DOSAGE                12305
ROUTE                  8352
FREQUENCY              6813
MEDICATION             3117
Name: count, dtype: int64

=== Top Treatments ===
text
chemotherapy               522
antibiotics                520
surgery                    285
conservative treatment     282
aspirin                    280
surgical excision          265
blood transfusion          255
conservative management    250
prednisone                 220
intravenous antibiotics    170
surgical resection         170
radiotherapy               168
adjuvan

In [45]:
df_all_treatments_entities

Unnamed: 0,text,label,start,end,original_text,source,row_idx,category
0,Olanzapine,CHEMICAL,11,21,Treatment: Olanzapine tablets for Bipolar affe...,bc5cdr,155216,
1,Bipolar affective disorder,DISEASE,34,60,Treatment: Olanzapine tablets for Bipolar affe...,bc5cdr,155216,
2,mental illness reaction,DISEASE,150,173,Treatment: Olanzapine tablets for Bipolar affe...,bc5cdr,155216,
3,Pain,DISEASE,175,179,Treatment: Olanzapine tablets for Bipolar affe...,bc5cdr,155216,
4,olanzapine,CHEMICAL,326,336,Treatment: Olanzapine tablets for Bipolar affe...,bc5cdr,155216,
...,...,...,...,...,...,...,...,...
416557,Cardiac arrest,EMERGENCY_TREATMENT,0,14,Cardiac arrest,lf:lf_emergency_treatment,97973,urgency
416558,restored,POSITIVE_RESPONSE,38,46,Return of spontaneous circulation was restored,lf:lf_positive_response,97973,response
416559,cardiac,CARDIOVASCULAR_TREATMENT,9,16,Advanced cardiac life support (ACLS) protocol,lf:lf_cardiovascular_treatment,97973,cardio
416560,Intravenous,HAS_TREATMENT,0,11,Intravenous nafcillin,lf:lf_has_treatment,97973,presence


### Extracting Info

In [46]:
df_info

Unnamed: 0,idx,age,sex
0,155216,Sixteen years old,Female
2,133948,36 years old,Female
3,80176,49,male
4,72232,47,Male
5,31864,24 years,Female
...,...,...,...
29995,39279,28,male
29996,137017,82,Male
29997,98004,54,Male
29998,133320,49,Woman


In [47]:
df_info['sex'].value_counts()

sex
Female                                                             10077
Male                                                                9974
male                                                                4588
Woman                                                               2486
female                                                               993
man                                                                  497
woman                                                                391
boy                                                                  190
Boy                                                                   78
Girl                                                                  77
girl                                                                  72
Man                                                                   68
Gentleman                                                              7
Trans man                                      

In [48]:
def standardize_sex(sex_str):
    """
    Standardize sex/gender values based on the variations in your data.
    
    Special cases handled:
    - Veterinary cases (neutered/castrated)
    - Multiple patients in one record
    - Trans individuals
    - Various capitalizations and terms
    """
    if pd.isna(sex_str):
        return None
    
    sex_str = str(sex_str).strip().lower()
    
    # Handle multiple patients first
    if 'both' in sex_str or 'second patient' in sex_str or 'patient case' in sex_str:
        return 'Multiple_Patients'
    
    # Map variations to standard values
    female_terms = ['female', 'woman', 'girl', 'lady']
    male_terms = ['male', 'man', 'boy', 'gentleman']
    
    # Check for trans individuals
    if 'trans' in sex_str:
        if 'trans man' in sex_str or 'transitioned to male' in sex_str:
            return 'Trans_Male'
        elif 'trans woman' in sex_str:
            return 'Trans_Female'
    
    # Check for assigned at birth
    if 'assigned female at birth' in sex_str:
        return 'AFAB'
    
    # Check for phenotype mentions
    if 'phenotype' in sex_str:
        if 'female' in sex_str:
            return 'Female_Phenotype'
    
    # Check for veterinary cases (neutered/castrated)
    if 'neutered' in sex_str or 'castrated' in sex_str or 'entire' in sex_str:
        if any(term in sex_str for term in female_terms):
            if 'neutered' in sex_str:
                return 'Female_Neutered'
            else:
                return 'Female_Intact'
        elif any(term in sex_str for term in male_terms):
            if 'neutered' in sex_str or 'castrated' in sex_str:
                return 'Male_Neutered'
            else:
                return 'Male_Intact'
    
    # Standard cases
    for term in female_terms:
        if term == sex_str:
            return 'Female'
    
    for term in male_terms:
        if term == sex_str:
            return 'Male'
    
    # If we can't classify, return as unclassified
    return 'Unclassified'


def standardize_sex_simple(sex_str):
    """
    Simplified version that maps to just Male/Female/Other categories
    """
    if pd.isna(sex_str):
        return None
    
    sex_str = str(sex_str).strip().lower()
    
    # Handle multiple patients
    if 'both' in sex_str or 'second patient' in sex_str or 'patient case' in sex_str:
        return 'Multiple_Records'
    
    # Simple mapping
    female_terms = ['female', 'woman', 'girl', 'lady', 'trans woman', 'female phenotype']
    male_terms = ['male', 'man', 'boy', 'gentleman', 'trans man']
    
    # Check for main terms
    for term in female_terms:
        if term in sex_str:
            return 'Female'
    
    for term in male_terms:
        if term in sex_str:
            return 'Male'
    
    return 'Other'


def extract_age_from_text(age_str):
    """
    Extract numeric age from various text formats in your data.
    
    Handles cases like:
    - Simple numbers: "62", "35"
    - Years old format: "18 yr old", "37-years old"
    - Written numbers: "Sixteen years old", "Almost three-year old"
    - Complex cases: "Initially 21 years old, 33 years old at last mention"
    - Age ranges: "29 at first admission, 55 at the time of the last mentioned clinical examination"
    """
    if pd.isna(age_str):
        return None
    
    age_str = str(age_str).strip()
    
    # First check if it's already a simple number
    if age_str.isdigit():
        return int(age_str)
    
    # Convert written numbers to digits
    written_numbers = {
        'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5,
        'six': 6, 'seven': 7, 'eight': 8, 'nine': 9, 'ten': 10,
        'eleven': 11, 'twelve': 12, 'thirteen': 13, 'fourteen': 14,
        'fifteen': 15, 'sixteen': 16, 'seventeen': 17, 'eighteen': 18,
        'nineteen': 19, 'twenty': 20, 'thirty': 30, 'forty': 40,
        'fifty': 50, 'sixty': 60, 'seventy': 70, 'eighty': 80,
        'ninety': 90
    }
    
    # Replace written numbers with digits
    age_str_lower = age_str.lower()
    for word, num in written_numbers.items():
        age_str_lower = age_str_lower.replace(word, str(num))
    
    # Handle compound written numbers (e.g., "twenty-one")
    age_str_lower = re.sub(r'(\d+)\s*-\s*(\d+)', lambda m: str(int(m.group(1)) + int(m.group(2))), age_str_lower)
    
    # Extract all numbers from the text
    numbers = re.findall(r'\d+', age_str_lower)
    
    if not numbers:
        return None
    
    # For multiple ages (patient history), typically want the first mentioned age
    # You might want to change this logic based on your needs
    if 'initially' in age_str_lower or 'first' in age_str_lower:
        # Return the first number
        return int(numbers[0])
    elif 'last' in age_str_lower or 'current' in age_str_lower:
        # Return the last number
        return int(numbers[-1])
    else:
        # Default to first number found
        return int(numbers[0])


def get_age_category(age):
    """
    Categorize age into standard medical categories
    """
    if pd.isna(age):
        return 'Unknown'
    
    if age < 2:
        return 'Infant'
    elif age < 13:
        return 'Child'
    elif age < 18:
        return 'Adolescent'
    elif age < 65:
        return 'Adult'
    else:
        return 'Elderly'


def standardize_demographics(df):
    """
    Apply all standardization to the dataframe
    """
    # Create a copy to avoid modifying original
    df_clean = df.copy()
    
    # Standardize sex - both detailed and simple versions
    df_clean['sex_detailed'] = df_clean['sex'].apply(standardize_sex)
    df_clean['sex_standardized'] = df_clean['sex'].apply(standardize_sex_simple)
    
    # Extract numeric age
    df_clean['age_numeric'] = df_clean['age'].apply(extract_age_from_text)
    
    # Add age category
    df_clean['age_category'] = df_clean['age_numeric'].apply(get_age_category)
    
    # Create flags for special cases
    df_clean['is_veterinary'] = df_clean['sex_detailed'].str.contains('Neutered|Intact', na=False)
    df_clean['is_multiple_patients'] = df_clean['sex_detailed'] == 'Multiple_Patients'
    df_clean['has_complex_age'] = df_clean['age'].str.contains('initially|first|last|mention', case=False, na=False)
    
    return df_clean


    
# Apply to your dataframe:
df_info_clean = standardize_demographics(df_info)

# View results:
print("\nStandardized Sex Distribution:")
print(df_info_clean['sex_standardized'].value_counts())
print("\nAge Distribution:")
print(df_info_clean['age_category'].value_counts())
print("\nSpecial Cases:")
print(f"Veterinary cases: {df_info_clean['is_veterinary'].sum()}")
print(f"Multiple patient records: {df_info_clean['is_multiple_patients'].sum()}")
print(f"Complex age descriptions: {df_info_clean['has_complex_age'].sum()}")


Standardized Sex Distribution:
sex_standardized
Male                15429
Female              14129
Multiple_Records        5
Other                   3
Name: count, dtype: int64

Age Distribution:
age_category
Adult         19376
Elderly        6153
Child          2566
Adolescent     1489
Infant           89
Unknown          82
Name: count, dtype: int64

Special Cases:
Veterinary cases: 13
Multiple patient records: 5
Complex age descriptions: 120


In [49]:
df_info_clean

Unnamed: 0,idx,age,sex,sex_detailed,sex_standardized,age_numeric,age_category,is_veterinary,is_multiple_patients,has_complex_age
0,155216,Sixteen years old,Female,Female,Female,6.0,Child,False,False,False
2,133948,36 years old,Female,Female,Female,36.0,Adult,False,False,False
3,80176,49,male,Male,Male,49.0,Adult,False,False,False
4,72232,47,Male,Male,Male,47.0,Adult,False,False,False
5,31864,24 years,Female,Female,Female,24.0,Adult,False,False,False
...,...,...,...,...,...,...,...,...,...,...
29995,39279,28,male,Male,Male,28.0,Adult,False,False,False
29996,137017,82,Male,Male,Male,82.0,Elderly,False,False,False
29997,98004,54,Male,Male,Male,54.0,Adult,False,False,False
29998,133320,49,Woman,Female,Female,49.0,Adult,False,False,False


In [50]:
# df_info_clean has: idx, age, sex, sex_detailed, sex_standardized, age_numeric, ...

info = df_info_clean.rename(columns={"idx":"row_idx"}).copy()

def age_variants(n: int):
    n = int(n)
    return [
        f"{n} years old", f"{n} year old",
        f"{n}-year-old", f"{n} yo", f"{n} y/o", f"aged {n}"
    ]

SEX_SYNONYMS = {
    "female": ["female", "woman", "female patient", "women"],
    "male":   ["male", "man", "male patient", "men"],
}

rows = []
for r in info.itertuples(index=False):
    rid = r.row_idx

    # 1) Use the raw age text if present
    if isinstance(r.age, str) and r.age.strip():
        rows.append({"row_idx": rid, "text": r.age.strip(), "label": "Age", "table": "info"})

    # 2) Generate common age variants from numeric
    if pd.notna(r.age_numeric):
        for t in age_variants(r.age_numeric):
            rows.append({"row_idx": rid, "text": t, "label": "Age", "table": "info"})

    # 3) Sex synonyms (driven by standardized sex if available)
    sex_std = None
    for s in (getattr(r, "sex_standardized", None), getattr(r, "sex_detailed", None), getattr(r, "sex", None)):
        if isinstance(s, str) and s.strip():
            sex_std = s.strip().lower()
            break
    if sex_std in SEX_SYNONYMS:
        for t in SEX_SYNONYMS[sex_std]:
            rows.append({"row_idx": rid, "text": t, "label": "Sex", "table": "info"})

        # 4) Age+sex combos (very common in narratives, e.g., "36-year-old female")
        if pd.notna(r.age_numeric):
            for av in age_variants(r.age_numeric):
                for sx in SEX_SYNONYMS[sex_std]:
                    rows.append({"row_idx": rid, "text": f"{av} {sx}", "label": "AgeSex", "table": "info"})

df_info_entities = pd.DataFrame(rows).drop_duplicates(["row_idx","text","label"])
print(df_info_entities.shape, df_info_entities.label.value_counts())


(1029121, 4) label
AgeSex    707568
Age       203321
Sex       118232
Name: count, dtype: int64


In [51]:
all_entities = pd.concat([
    df_all_physiological_entities.assign(table='physiological'),
    df_psychological_entities.assign(table='psychological'),
    df_vaccination_entities.assign(table='vaccination'),
    df_allergies_entities.assign(table='allergies'),
    df_drug_usage_entities.assign(table='drug_usage'),
    df_all_surgery_entities.assign(table='surgery'),
    df_all_symptom_entities.assign(table='symptoms'),
    df_all_diagnosis_entities.assign(table='diagnosis'),
    df_all_treatments_entities.assign(table='treatments'),
    df_info_entities.assign(table='info')
], ignore_index=True)

print(f"Total entities across all tables: {len(all_entities):,}")
print("\nEntity distribution by table:")
print(all_entities['table'].value_counts())
print("\nEntity types found:")
print(all_entities['label'].value_counts())

Total entities across all tables: 2,473,487

Entity distribution by table:
table
info             1029121
treatments        416562
diagnosis         400811
symptoms          383863
surgery           186191
physiological      51759
psychological       3402
allergies            934
drug_usage           715
vaccination          129
Name: count, dtype: int64

Entity types found:
label
AgeSex                                   707568
DISEASE                                  281269
Age                                      203321
ANATOMY                                  125649
Sex                                      118232
                                          ...  
AMINO_ACID                                  111
ANATOMY_WITH_LATERALITY_AND_DIRECTION        52
ANATOMICAL_SYSTEM                            48
NO_FRACTURE                                  19
DEVELOPING_ANATOMICAL_STRUCTURE               1
Name: count, Length: 91, dtype: int64


In [52]:
all_entities['label'].unique()

array(['DISEASE', 'CHEMICAL', 'ANATOMY_WITH_LATERALITY', 'ANATOMY',
       'ANATOMY_WITH_DIRECTION', 'ANATOMY_WITH_LATERALITY_AND_DIRECTION',
       'MUSCULOSKELETAL', 'BILATERAL_CONDITION', 'CARDIAC_ANATOMY',
       'NEUROLOGICAL_ANATOMY', 'PATHOLOGICAL_FORMATION',
       'MULTI_TISSUE_STRUCTURE', 'TISSUE', 'ORGANISM_SUBDIVISION', 'CELL',
       'ORGAN', 'CANCER', 'ORGANISM_SUBSTANCE', 'ORGANISM',
       'IMMATERIAL_ANATOMICAL_ENTITY', 'GENE_OR_GENE_PRODUCT',
       'SIMPLE_CHEMICAL', 'CELLULAR_COMPONENT', 'AMINO_ACID',
       'ANATOMICAL_SYSTEM', 'DEVELOPING_ANATOMICAL_STRUCTURE',
       'PROCEDURE', 'LATERALITY', 'HIP_SURGERY', 'FRACTURE_SURGERY',
       'MINIMALLY_INVASIVE', 'KNEE_SURGERY', 'BILATERAL_PROCEDURE',
       'EMERGENCY_SURGERY', 'SYMPTOM', 'SYMPTOM_TYPE', 'SEVERITY',
       'TEMPORAL_PATTERN', 'CHRONIC_SYMPTOM', 'BACK_PAIN', 'SEVERE_PAIN',
       'HIP_PAIN', 'MOBILITY_ISSUE', 'LOCALIZED_PAIN',
       'BILATERAL_SYMPTOM', 'KNEE_PAIN', 'NEUROLOGICAL',
       'PROGRESSIVE_

In [53]:
ROOT = Path.cwd().parent
all_entities_filepath = ROOT/"data"/ "clean"/'all_entities.csv'
all_entities.to_csv(all_entities_filepath, index=False)

In [54]:
all_entities.head()

Unnamed: 0,text,label,start,end,original_text,source,row_idx,category,table,source_column
0,posttraumatic arthritis,DISEASE,48.0,71.0,History of left elbow arthrodesis performed fo...,bc5cdr,80176,,physiological,
1,pain,DISEASE,116.0,120.0,"Inability to walk since babyhood, did not walk...",bc5cdr,31864,,physiological,
2,fracture,DISEASE,151.0,159.0,"Inability to walk since babyhood, did not walk...",bc5cdr,31864,,physiological,
3,Coxa vara deformity,DISEASE,0.0,19.0,"Coxa vara deformity of bilateral hips, bilater...",bc5cdr,149866,,physiological,
4,fracture,DISEASE,75.0,83.0,"Coxa vara deformity of bilateral hips, bilater...",bc5cdr,149866,,physiological,


### NER Model Training

#### Stage 1: Configuration for Data Ingestion

In [None]:
TEXT_COL = "note"        # main note column
MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
MAX_LEN = 256

In [56]:
df_lean_file = ROOT / "data" / "clean" / "augmented_notes_30K.csv"
df_lean = pd.read_csv(df_lean_file).reset_index(drop=True)
df_lean.head()
df_lean = df_lean[["idx", TEXT_COL]].dropna(subset=[TEXT_COL]).copy()
df_lean[TEXT_COL] = df_lean[TEXT_COL].astype(str)

In [57]:
all_entities = all_entities.rename(columns={"row_idx": "idx"})
all_entities = all_entities[["idx","text", "label", "start","end", "table", "original_text"]].copy()

##### Label Mapping for Silver Spanning

In [58]:
LABEL_MAP = {
    # Problems
    "DISEASE":"Problem","SYMPTOM":"Problem","CONDITION":"Problem","CONDITION_TYPE":"Problem","CANCER":"Problem","FINDING":"Problem",
    # Anatomy
    "ANATOMY":"Anatomy","ANATOMY_WITH_LATERALITY":"Anatomy","TISSUE":"Anatomy","MULTI_TISSUE_STRUCTURE":"Anatomy","ORGAN":"Anatomy",
    "ANATOMICAL_SYSTEM":"Anatomy","ORGANISM_SUBDIVISION":"Anatomy","DEVELOPING_ANATOMICAL_STRUCTURE":"Anatomy","IMMATERIAL_ANATOMICAL_ENTITY":"Anatomy",
    "CELL":"Anatomy","CELLULAR_COMPONENT":"Anatomy",
    # Meds / substances
    "CHEMICAL":"Medication","SIMPLE_CHEMICAL":"Medication","MEDICATION":"Medication",
    "GENE_OR_GENE_PRODUCT":"Substance","AMINO_ACID":"Substance","ORGANISM":"Substance","ORGANISM_SUBSTANCE":"Substance",
    # Procedures/tests
    "PROCEDURE":"Procedure","TEST":"TestName","TEST_TYPE":"TestType",
    # Modifiers
    "LATERALITY":"Laterality","SEVERITY":"Severity","TEMPORAL_PATTERN":"TemporalPattern","MEASUREMENT":"Measurement",
    # Treatment related
    "TREATMENT":"TreatmentName","TREATMENT_TYPE":"TreatmentType","TREATMENT_RESPONSE":"TreatmentResponse","TREATMENT_REASON":"TreatmentReason",
    # Dosing
    "DOSAGE":"Dosage","FREQUENCY":"Frequency","ROUTE":"Route",
    # Symptom subtype
    "SYMPTOM_TYPE":"Problem",
    # Info
    "AGE":"Age","SEX":"Sex","AGESEX":"AgeSex","AGE_SEX":"AgeSex",
}
ALLOWED_LABELS = {
    "Problem","Anatomy","Medication","Procedure","TestName","TestType","Laterality","Severity","TemporalPattern",
    "Measurement","Dosage","Frequency","Route","TreatmentName","TreatmentType","TreatmentResponse","TreatmentReason","Age","Sex","AgeSex"
}

def normalize_label(lbl: str) -> str:
    if not isinstance(lbl, str):
        return ""
    base = lbl.strip().upper()
    base = LABEL_MAP.get(base, base)
    # normalize exact cases for final schema
    if base in {"AGE","SEX","AGESEX"}:
        return {"AGE":"Age","SEX":"Sex","AGESEX":"AgeSex"}[base]
    return base

def normalize_phrase(s: str) -> str:
    return re.sub(r"\s+", " ", str(s).strip())

#### Stage 2: Labeling Functions & Probes

In [59]:
# =========================
# Stage 2 — robust LF signals + probes
# =========================

ABSTAIN = "ABSTAIN"
USE_PREFILTER = False  

def _is_abstain(x):
    if isinstance(x, dict):
        return x.get('label', 'ABSTAIN').upper() == 'ABSTAIN'
    return (not isinstance(x, str)) or (x.strip() == "") or (x.strip().upper() == "ABSTAIN")

def _norm_bool(s):
    """Treat True/1/yes/'true' as True; works if the column is string or bool."""
    return s.astype(str).str.lower().isin(["true","1","yes","y"])

def _probe_keywords(df, col, patterns, title, n_show=8):
    if col not in df.columns:
        print(f"[PROBE] {title}: '{col}' missing")
        return
    s = df[col].astype(str).str.lower()
    print(f"[PROBE] {title}: rows={len(df)}")
    found_any = False
    for p in patterns:
        m = s.str.contains(p, na=False, regex=True)
        cnt = int(m.sum())
        print(f"  - contains /{p}/ : {cnt}")
        if cnt and not found_any:
            # show a few examples
            print("    examples:")
            for t in s[m].head(n_show):
                print("     ·", t[:120], "…")
            found_any = True
    if not found_any:
        print("  (no examples matched any probe pattern)")

def build_signals_table_dynamic(df, lfs, namespace: str, id_col='idx'):
    """
    Iterate full df, collect fired labels dynamically, and materialize a wide signals table.
    Adds 'lf_provenance' list per idx (lf_name=>label).
    """
    import math
    from collections import defaultdict
    if df is None or len(df) == 0:
        return pd.DataFrame({id_col: []}), []

    assert id_col in df.columns, f"{namespace} df must contain '{id_col}'"

    idx2cols   = defaultdict(set)
    idx2prov   = defaultdict(list)
    all_cols   = set()

    # Iterate rows (no sampling)
    for _, row in df.iterrows():
        idx = row[id_col]
        for lf in lfs:
            try:
                out = lf(row)
            except Exception:
                continue
            
            # Extract label from dict if needed
            if isinstance(out, dict):
                label = out.get('label', 'ABSTAIN')
            else:
                label = out
                
            if not _is_abstain(label):
                col = f"{namespace}.{label}"
                all_cols.add(col)
                idx2cols[idx].add(col)
                idx2prov[idx].append(f"{lf.__name__}=>{label}")

    sig = pd.DataFrame({id_col: df[id_col].drop_duplicates().values})
    for col in sorted(all_cols):
        sig[col] = sig[id_col].map(lambda i: 1 if col in idx2cols.get(i, set()) else 0).astype(int)
    sig["lf_provenance"] = sig[id_col].map(lambda i: idx2prov.get(i, []))
    labels = sorted({c.split(".", 1)[1] for c in all_cols})
    return sig, labels

def unify_signals_safe(*signals, id_col='idx'):
    frames = [s for s in signals if s is not None and len(s) > 0]
    if not frames:
        return pd.DataFrame({id_col: []})

    prov_parts, numeric_frames = [], []
    for s in frames:
        s = s.copy()
        if 'lf_provenance' in s.columns:
            prov_parts.append(s[[id_col, 'lf_provenance']])
            s = s.drop(columns=['lf_provenance'])
        numeric_frames.append(s)

    base = numeric_frames[0]
    for s in numeric_frames[1:]:
        base = base.merge(s, on=id_col, how='outer')

    for c in base.columns:
        if c != id_col and pd.api.types.is_numeric_dtype(base[c]):
            base[c] = base[c].fillna(0).astype(int)

    if prov_parts:
        prov = pd.concat(prov_parts, ignore_index=True)
        prov = prov.groupby(id_col, as_index=False)['lf_provenance'] \
                   .agg(lambda L: sum((v if isinstance(v, list) else [] for v in L), []))
        base = base.merge(prov, on=id_col, how='left')
        base['lf_provenance'] = base['lf_provenance'].apply(lambda v: v if isinstance(v, list) else [])
    return base

def _debug_cov(df, lfs, name, n=2000):
    cov = {}
    sample = df.head(n) if len(df) > n else df
    for lf in lfs:
        hits = 0
        for _, r in sample.iterrows():
            try:
                out = lf(r)
            except Exception:
                out = ABSTAIN
            if not _is_abstain(out):
                hits += 1
        cov[lf.__name__] = f"{hits}/{len(sample)} = {100*hits/max(1,len(sample)):.1f}%"
    return {name: cov}

surg_lfs  = create_surgical_labeling_functions()
symp_lfs  = create_symptom_labeling_functions()
diag_lfs  = create_diagnosis_labeling_functions()
treat_lfs = create_treatment_labeling_functions()

signals = []
coverage_report = {}

# ----- SURGERY -----
dfS = df_surgery.copy()
if USE_PREFILTER and 'has_surgery' in dfS.columns:
    mask = _norm_bool(dfS['has_surgery'])
    dfS = dfS[mask]
# ensure combined_text exists
if 'combined_text' not in dfS.columns:
    dfS['combined_text'] = dfS.apply(
        lambda r: f"{str(r.get('reason',''))} {str(r.get('Type',''))} {str(r.get('details',''))} {str(r.get('outcome',''))}",
        axis=1
    )
# Probes for what the surgery LFs look for
_probe_keywords(
    dfS, "combined_text",
    patterns=[r"\bhip\b", r"\barthroplasty\b", r"\breplacement\b", r"\bfracture\b", r"\bfixation\b", r"\brepair\b", r"\bbilateral\b"],
    title="Surgery combined_text probes"
)
sig_surg, labs_surg = build_signals_table_dynamic(dfS, surg_lfs, "surgery", id_col='idx')
signals.append(sig_surg)
coverage_report.update(_debug_cov(dfS, surg_lfs, "Surgery"))

# ----- SYMPTOMS -----
dfY = df_symptoms.copy()
if USE_PREFILTER and 'has_symptom' in dfY.columns:
    dfY = dfY[_norm_bool(dfY['has_symptom'])]
if 'combined_text' not in dfY.columns:
    dfY['combined_text'] = dfY.apply(
        lambda r: f"{str(r.get('name of symptom',''))} with {str(r.get('intensity of symptom',''))} "
                  f"located in {str(r.get('location',''))} lasting {str(r.get('time',''))} "
                  f"{str(r.get('temporalisation',''))} {str(r.get('details',''))}",
        axis=1
    )
_probe_keywords(
    dfY, "combined_text",
    patterns=[r"\bpain\b", r"\bsevere\b", r"\bchronic\b", r"\bacute\b", r"\bbilateral\b", r"\bhip\b", r"\bknee\b", r"\bback\b"],
    title="Symptoms combined_text probes"
)
sig_symp, labs_symp = build_signals_table_dynamic(dfY, symp_lfs, "symptoms", id_col='idx')
signals.append(sig_symp)
coverage_report.update(_debug_cov(dfY, symp_lfs, "Symptoms"))

# ----- DIAGNOSIS -----
dfD = df_diagnosis.copy()
if USE_PREFILTER and 'has_diagnosis' in dfD.columns:
    dfD = dfD[_norm_bool(dfD['has_diagnosis'])]
if 'combined_text' not in dfD.columns:
    dfD['combined_text'] = dfD.apply(
        lambda r: f"{str(r.get('test',''))} {str(r.get('result',''))} {str(r.get('condition',''))} "
                  f"{str(r.get('details',''))} {str(r.get('time',''))}",
        axis=1
    )
_probe_keywords(
    dfD, "combined_text",
    patterns=[r"\bmri\b", r"\bct\b", r"\bx-ray\b", r"\bradiograph\b", r"\bfracture\b", r"\btumor\b", r"\bnormal\b", r"\bcritical\b"],
    title="Diagnosis combined_text probes"
)
sig_diag, labs_diag = build_signals_table_dynamic(dfD, diag_lfs, "diagnosis", id_col='idx')
signals.append(sig_diag)
coverage_report.update(_debug_cov(dfD, diag_lfs, "Diagnosis"))

# ----- TREATMENTS -----
dfT = df_treatments_processed.copy()
if USE_PREFILTER and 'has_treatments' in dfT.columns:
    dfT = dfT[_norm_bool(dfT['has_treatments'])]
if 'combined_text' not in dfT.columns:
    dfT['combined_text'] = dfT.apply(
        lambda r: " ".join(
            str(x) for x in [
                r.get('name',''), r.get('related condition',''),
                r.get('dosage',''), r.get('frequency',''),
                r.get('route',''), r.get('time',''),
                r.get('duration',''), r.get('reason for taking',''),
                r.get('reaction to treatment',''), r.get('details','')
            ]
        ),
        axis=1
    )
_probe_keywords(
    dfT, "combined_text",
    patterns=[r"\bemergency\b", r"\bshock\b", r"\bchemotherapy\b", r"\bantibiotic\b", r"\bmorphine\b", r"\bsurgery\b"],
    title="Treatments combined_text probes"
)
sig_treat, labs_treat = build_signals_table_dynamic(dfT, treat_lfs, "treatments", id_col='idx')
signals.append(sig_treat)
coverage_report.update(_debug_cov(dfT, treat_lfs, "Treatment"))

# ----- unify & show -----
signals_idx = unify_signals_safe(sig_surg, sig_symp, sig_diag, sig_treat)
signals_idx.to_parquet("lf_signals_index.parquet", index=False)

print(f"[Stage 2] Signals index: {signals_idx.shape}")
print("Any surgery.* flags?", any(c.startswith("surgery.") for c in signals_idx.columns))
print("Any symptoms.* flags?", any(c.startswith("symptoms.") for c in signals_idx.columns))
print("Any diagnosis.* flags?", any(c.startswith("diagnosis.") for c in signals_idx.columns))
print("Any treatments.* flags?", any(c.startswith("treatments.") for c in signals_idx.columns))

# show a few rows that actually fired
if 'lf_provenance' in signals_idx.columns:
    demo = signals_idx[signals_idx['lf_provenance'].str.len() > 0].head(10)
    print("[Stage 2] Examples with provenance (first 10):")
    display(demo)

print("Coverage (sampled):")
for k, v in coverage_report.items():
    print(f"  {k:<9}: {v}")


[PROBE] Surgery combined_text probes: rows=35864
  - contains /\bhip\b/ : 506
    examples:
     · idiopathic osteonecrosis of the femoral head total hip arthroplasty (tha) first tha on the left hip discharged in good c …
     · pain and limited rom in the contralateral hip joint total hip arthroplasty (tha) second tha on the contralateral hip dis …
     · femoral neck fracture with dislocation of the femoral head into the pelvis hip surgery with lateral approach and anterio …
     · fistula injury of the ipsilateral ureter urological evaluation and surgery moore prosthesis extracted, hip debrided and  …
     · severe osteoarthritis pain total left hip arthroplasty nan nan …
     · extensive femoral bone loss with displacement of the femoral component, femoral pseudo-tumor revision total hip arthropl …
     · severe osteoarthritis pain that hindered baseline activities total left hip arthroplasty nan nan …
     · extensive femoral bone loss with displacement of the femoral component, f

Unnamed: 0,idx,surgery.BILATERAL_PROCEDURE,surgery.EMERGENCY_SURGERY,surgery.FRACTURE_SURGERY,surgery.HIP_SURGERY,surgery.KNEE_SURGERY,surgery.MINIMALLY_INVASIVE,symptoms.ACUTE_ONSET,symptoms.BACK_PAIN,symptoms.BILATERAL_SYMPTOM,...,treatments.DAILY_MEDICATION,treatments.EMERGENCY_TREATMENT,treatments.HAS_TREATMENT,treatments.INFECTION_TREATMENT,treatments.MEDICATION_TREATMENT,treatments.POSITIVE_RESPONSE,treatments.PSYCHIATRIC_TREATMENT,treatments.SURGICAL_TREATMENT,treatments.TREATMENT_DURATION_MENTIONED,lf_provenance
0,14,0,0,0,0,0,0,0,0,0,...,0,0,1,0,0,0,0,0,0,"[lf_pain_with_location=>LOCALIZED_PAIN, lf_ima..."
1,21,0,0,0,0,0,0,1,0,0,...,0,0,1,0,0,0,0,0,1,"[lf_chronic_symptom=>CHRONIC_SYMPTOM, lf_acute..."
2,41,0,0,0,0,0,0,0,0,1,...,0,0,1,0,0,0,0,0,1,"[lf_bilateral_symptom=>BILATERAL_SYMPTOM, lf_m..."
3,48,0,0,0,0,0,0,0,0,1,...,0,0,1,0,0,0,0,0,0,"[lf_neurological_symptom=>NEUROLOGICAL, lf_neu..."
4,91,0,0,0,0,0,0,0,0,0,...,0,0,1,0,0,0,0,0,0,"[lf_has_treatment=>HAS_TREATMENT, lf_cardiovas..."
5,94,1,0,0,0,0,0,0,0,0,...,0,0,1,0,0,0,0,0,0,"[lf_bilateral_procedure=>BILATERAL_PROCEDURE, ..."
6,104,0,0,0,0,0,0,0,0,0,...,0,0,1,0,1,0,0,1,0,"[lf_has_treatment=>HAS_TREATMENT, lf_medicatio..."
7,156,0,0,0,0,0,0,0,0,0,...,1,0,1,0,1,0,0,0,0,"[lf_chronic_symptom=>CHRONIC_SYMPTOM, lf_has_t..."
8,169,0,0,0,0,0,0,0,0,0,...,0,0,1,0,0,0,0,0,0,"[lf_pain_with_location=>LOCALIZED_PAIN, lf_pai..."
9,171,0,0,0,0,0,0,0,0,0,...,0,0,1,0,0,1,0,0,0,"[lf_pain_with_location=>LOCALIZED_PAIN, lf_neu..."


Coverage (sampled):
  Surgery  : {'lf_hip_surgery': '34/2000 = 1.7%', 'lf_knee_surgery': '29/2000 = 1.4%', 'lf_fracture_surgery': '86/2000 = 4.3%', 'lf_bilateral_procedure': '71/2000 = 3.5%', 'lf_minimally_invasive': '98/2000 = 4.9%', 'lf_emergency_procedure': '35/2000 = 1.8%'}
  Symptoms : {'lf_severe_pain': '127/2000 = 6.3%', 'lf_chronic_symptom': '440/2000 = 22.0%', 'lf_neurological_symptom': '158/2000 = 7.9%', 'lf_bilateral_symptom': '57/2000 = 2.9%', 'lf_acute_onset': '103/2000 = 5.2%', 'lf_progressive_symptom': '189/2000 = 9.4%', 'lf_mobility_issue': '54/2000 = 2.7%', 'lf_pain_with_location': '1399/2000 = 70.0%', 'lf_systemic_symptom': '112/2000 = 5.6%'}
  Diagnosis: {'lf_imaging_test': '638/2000 = 31.9%', 'lf_fracture_diagnosis': '91/2000 = 4.5%', 'lf_neoplastic_finding': '440/2000 = 22.0%', 'lf_normal_finding': '232/2000 = 11.6%', 'lf_critical_finding': '28/2000 = 1.4%', 'lf_bone_pathology': '132/2000 = 6.6%', 'lf_vascular_finding': '242/2000 = 12.1%', 'lf_inflammatory_finding'

#### Stage 3: Create Silver Spans for NER

In [60]:
# STAGE 3 — Silver spans for NER 
# ------------------------------------------------------
ae = all_entities.dropna(subset=["text","label"]).copy()
ae["text"]  = ae["text"].astype(str).map(normalize_phrase)
ae["label"] = ae["label"].astype(str).map(normalize_label)
ae = ae[(ae["text"].str.len() >= 2) & (ae["label"].isin(ALLOWED_LABELS))]
ae = ae[["idx","text","label"]].drop_duplicates()

print("[Stage 3] Candidate phrases:", ae.shape)

# Pre-group for speed: idx -> [{'text','label'}, ...]
ents_by_idx = {k: g[["text","label"]].to_dict("records") for k,g in ae.groupby("idx")}

def find_all(text: str, phrase: str):
    return [(m.start(), m.end()) for m in re.finditer(re.escape(phrase), text, flags=re.IGNORECASE)]

def dedupe_overlaps(spans):
    spans = sorted(spans, key=lambda x: (x["label"], x["start"], -(x["end"]-x["start"])))
    kept = []
    for sp in spans:
        conflict = False
        for kp in kept:
            if sp["label"] == kp["label"] and not (sp["end"] <= kp["start"] or sp["start"] >= kp["end"]):
                conflict = True
                if (sp["end"]-sp["start"]) > (kp["end"]-kp["start"]):
                    kp.update(sp)
                break
        if not conflict:
            kept.append(sp)
    return kept

silver_rows, matched_notes = [], 0
for r in df_lean.itertuples(index=False):
    note_id = r.idx
    text    = getattr(r, TEXT_COL)
    spans = []
    for ent in ents_by_idx.get(note_id, []):
        for (s,e) in find_all(text, ent["text"]):
            spans.append({"start": s, "end": e, "label": ent["label"], "text": text[s:e]})
    spans = dedupe_overlaps(spans)
    matched_notes += int(len(spans) > 0)
    silver_rows.append({"idx": note_id, "text": text, "silver_spans": spans})

silver_df = pd.DataFrame(silver_rows)
silver_df["n_spans"] = silver_df["silver_spans"].str.len()
print(f"[Stage 3] Notes with ≥1 span: {matched_notes}/{len(silver_df)} = {matched_notes/len(silver_df):.1%}")

# Light confidence filter 
rows = []
for r in silver_df.itertuples(index=False):
    for sp in r.silver_spans:
        rows.append({"idx": r.idx, "note_text": r.text, **sp})
cand = pd.DataFrame(rows)

if len(cand):
    freq = (cand.groupby(["label","text"]).size()/len(silver_df)).to_dict()
    def conf_score(row):
        s = 0.0
        s += 0.45 * freq.get((row["label"], row["text"]), 0)
        s += 0.25 * (1.0 if row["note_text"][row["start"]:row["end"]] == row["text"] else 0.0)
        local = row["note_text"][max(0,row["start"]-40):min(len(row["note_text"]),row["end"]+40)]
        s += 0.30 * (fuzz.partial_ratio(row["text"].lower(), local.lower())/100.0)
        return min(s,1.0)
    cand["conf"] = cand.apply(conf_score, axis=1)
    cand = cand[cand["conf"] >= 0.35]

    def rebuild(group):
        spans = group[["start","end","label","text"]].to_dict("records")
        return pd.Series({"text": group["note_text"].iloc[0], "silver_spans": spans})

    silver_df = cand.groupby("idx").apply(rebuild).reset_index()
    silver_df["n_spans"] = silver_df["silver_spans"].str.len()

silver_df.to_parquet("silver_spans_filtered.parquet", index=False)
print(f"[Stage 3] Silver spans saved. Notes with ≥1 span: {(silver_df['n_spans']>0).sum()}")

[Stage 3] Candidate phrases: (1876570, 3)
[Stage 3] Notes with ≥1 span: 29755/30000 = 99.2%
[Stage 3] Silver spans saved. Notes with ≥1 span: 29755


#### Stage 4: Train NER Model on Silver Spans

In [61]:
# ======================================================
# Stage 4 — Build HF dataset & train NER on Silver (fixed)
# ======================================================

from seqeval.metrics import precision_score, recall_score, f1_score
import gc, numpy as np

# label space for BIO tags
entity_labels = sorted({sp["label"] for L in silver_df["silver_spans"] for sp in L}) if len(silver_df) else []
id2label_list = ["O"] + sum(([f"B-{l}", f"I-{l}"] for l in entity_labels), [])
label2id = {l:i for i,l in enumerate(id2label_list)}
id2label = {i:l for i,l in enumerate(id2label_list)}

MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"
MAX_LEN = 256
tok = AutoTokenizer.from_pretrained(MODEL_NAME)

def to_bio(example):
    text, spans = example["text"], example["silver_spans"]
    enc = tok(text, truncation=True, max_length=MAX_LEN, return_offsets_mapping=True)
    # start with -100 for all tokens (ignored by loss/metrics)
    labels = [-100] * len(enc["offset_mapping"])

    # mark non-special tokens as "O"
    for i, (a, b) in enumerate(enc["offset_mapping"]):
        if a != b:  # real token span
            labels[i] = label2id["O"]

    # paint spans
    for sp in spans:
        s, e, lab = sp["start"], sp["end"], sp["label"]
        began = False
        for i, (a, b) in enumerate(enc["offset_mapping"]):
            if a == b: 
                continue
            if a >= s and b <= e:
                labels[i] = label2id[f"I-{lab}" if began else f"B-{lab}"]
                began = True

    enc["labels"] = labels
    enc.pop("offset_mapping", None)
    enc["idx"] = example["idx"]
    return enc

# ---- Dataset build ----
if len(silver_df):
    hf = ds.Dataset.from_pandas(silver_df[["idx","text","silver_spans"]], preserve_index=False)
    hf = hf.map(to_bio, remove_columns=["text","silver_spans"], desc="Tokenize + BIO")
    splits = hf.train_test_split(test_size=0.15, seed=42)
    splits.save_to_disk("ds_ner_silver_with_idx")
    with open("ner_label_space.json","w") as f:
        json.dump({"id2label": id2label_list, "label2id": label2id}, f, indent=2)
    print("[Stage 4] Saved HF dataset + label space")
else:
    print("[Stage 4] Skipped (no silver spans)")

# ---- Training ----
if len(silver_df):
    tokenizer = tok
    model = AutoModelForTokenClassification.from_pretrained(
        MODEL_NAME, num_labels=len(id2label), id2label=id2label, label2id=label2id
    )
    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

    TA_FIELDS = set(TrainingArguments.__dataclass_fields__.keys())
    ta = dict(
        output_dir="ckpt/ner",
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        learning_rate=2e-5,
        num_train_epochs=3,
        weight_decay=0.01,
        logging_steps=50,
        report_to="none"
    )
    if "warmup_ratio" in TA_FIELDS: ta["warmup_ratio"] = 0.1
    if "fp16" in TA_FIELDS: ta["fp16"] = torch.cuda.is_available()

    # ---- Robust save/eval strategy handling (fixes your error) ----
    eval_key = "eval_strategy" if "eval_strategy" in TA_FIELDS else ("evaluation_strategy" if "evaluation_strategy" in TA_FIELDS else None)
    if eval_key:
        ta[eval_key] = "epoch"
    if "save_strategy" in TA_FIELDS:
        ta["save_strategy"] = "epoch"

    # remove step-based settings so they can't force 'steps'
    for k in ("save_steps", "eval_steps"):
        if k in ta:
            ta.pop(k)

    if "load_best_model_at_end" in TA_FIELDS:
        ta["load_best_model_at_end"] = True
        if "metric_for_best_model" in TA_FIELDS: ta["metric_for_best_model"] = "eval_f1"
        if "greater_is_better" in TA_FIELDS:    ta["greater_is_better"] = True

    args = TrainingArguments(**ta)

    def align_and_decode(preds, labels):
        pred_labels, true_labels = [], []
        for p, t in zip(preds, labels):
            p = np.array(p); t = np.array(t)
            mask = t != -100
            p = p[mask]; t = t[mask]
            pred_labels.append([id2label[i] for i in p])
            true_labels.append([id2label[i] for i in t])
        return pred_labels, true_labels

    def compute_metrics(p):
        logits, labels = p.predictions, p.label_ids
        preds = np.argmax(logits, axis=-1)
        pred_tags, true_tags = align_and_decode(preds, labels)
        return {
            "precision": precision_score(true_tags, pred_tags),
            "recall":    recall_score(true_tags, pred_tags),
            "f1":        f1_score(true_tags, pred_tags),
        }

    trainer = Trainer(
        model=model, args=args,
        train_dataset=splits["train"],
        eval_dataset=splits.get("validation", splits.get("test")),
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
    )
    trainer.train()
    eval_output = trainer.evaluate()
    print("[Stage 4] NER eval:", eval_output)

    trainer.save_model("model_ner_best")
    tokenizer.save_pretrained("model_ner_best")
    print("[Stage 4] Saved model_ner_best")

    del trainer, model; gc.collect()


Tokenize + BIO: 100%|██████████| 29755/29755 [01:07<00:00, 437.79 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 25291/25291 [00:00<00:00, 127857.67 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 4464/4464 [00:00<00:00, 121410.06 examples/s]


[Stage 4] Saved HF dataset + label space


Some weights of BertForTokenClassification were not initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Precision,Recall,F1
1,0.3066,0.284814,0.547895,0.676287,0.605358
2,0.2701,0.261943,0.57471,0.70448,0.633013
3,0.2432,0.259012,0.592299,0.711481,0.646443


[Stage 4] NER eval: {'eval_loss': 0.25901228189468384, 'eval_precision': 0.5922989807474519, 'eval_recall': 0.7114811774363785, 'eval_f1': 0.6464426898528488, 'eval_runtime': 277.987, 'eval_samples_per_second': 16.058, 'eval_steps_per_second': 1.004, 'epoch': 3.0}
[Stage 4] Saved model_ner_best


#### Stage 5: Weak-Supervised Row Classifier

In [62]:
# ------------------------------------------------------
# STAGE 5 — Weak→supervised row classifier
# ------------------------------------------------------
def make_target_from_signals(signals_idx: pd.DataFrame, positive_any: list, negative_any: list = None):
    df = signals_idx.copy()
    for col in positive_any:
        if col not in df.columns: df[col] = 0
    pos = np.zeros(len(df), dtype=int)
    for col in positive_any:
        pos |= df[col].astype(int).values
    if negative_any:
        for col in negative_any:
            if col not in df.columns: df[col] = 0
        neg = np.zeros(len(df), dtype=int)
        for col in negative_any:
            neg |= df[col].astype(int).values
        y = np.where((pos == 1) & (neg == 0), 1, 0)
    else:
        y = pos
    return df[["idx"]].assign(y=y)

POSITIVE_SIGNALS = [
    "surgery.EMERGENCY_SURGERY",
    "treatments.EMERGENCY_TREATMENT",
    "diagnosis.CRITICAL_FINDING",
]
NEGATIVE_SIGNALS = []

def train_row_classifier(df_notes, signals_idx, positive_cols, negative_cols=None, text_col=TEXT_COL):
    present = [c for c in positive_cols if c in signals_idx.columns and signals_idx[c].sum() > 0]
    missing = [c for c in positive_cols if c not in signals_idx.columns]
    print(f"[Stage 5] POS signals present: {present if present else 'None'}")
    if missing:
        print("[Stage 5] POS signals missing (not in signals_idx):", missing)

    if not present:
        print("[Stage 5] No positive signal columns found. Skipping classifier training.")
        return None

    y_df = make_target_from_signals(signals_idx, positive_cols, negative_cols)
    train_df = df_notes.merge(y_df, on="idx", how="inner")

    cls_counts = train_df["y"].value_counts().to_dict()
    if len(cls_counts) < 2:
        print(f"[Stage 5] Only one class present {cls_counts}. Skipping classifier.")
        return None

    clf_pipe = Pipeline([
        ("tfidf", TfidfVectorizer(max_features=50000, ngram_range=(1,2))),
        ("clf",   LogisticRegression(max_iter=1000, class_weight="balanced", n_jobs=None))
    ])
    clf_pipe.fit(train_df[text_col], train_df["y"])
    joblib.dump(clf_pipe, "row_classifier_emergency.joblib")
    print("[Stage 5] Row classifier saved → row_classifier_emergency.joblib")

    pred = clf_pipe.predict(train_df[text_col])
    print("[Stage 5] Weakly-supervised in-sample report:\n", skl_report(train_df["y"], pred, digits=3))
    return clf_pipe

row_clf = train_row_classifier(df_lean, signals_idx, POSITIVE_SIGNALS, NEGATIVE_SIGNALS, text_col=TEXT_COL)

[Stage 5] POS signals present: ['surgery.EMERGENCY_SURGERY', 'treatments.EMERGENCY_TREATMENT', 'diagnosis.CRITICAL_FINDING']
[Stage 5] Row classifier saved → row_classifier_emergency.joblib
[Stage 5] Weakly-supervised in-sample report:
               precision    recall  f1-score   support

           0      0.999     0.935     0.966     27637
           1      0.540     0.992     0.699      2118

    accuracy                          0.939     29755
   macro avg      0.769     0.964     0.833     29755
weighted avg      0.967     0.939     0.947     29755



#### Stage 6 & 7: Information Retrieval & Classifier Use Cases

In [63]:
# ------------------------------------------------------
# STAGE 6 — IR / Filters (signals & classifier)
# ------------------------------------------------------
def filter_by_signals(df_notes, signals_idx, require: dict, id_col="idx"):
    joined = df_notes.merge(signals_idx, on=id_col, how="left").fillna(0)
    mask = np.ones(len(joined), dtype=bool)
    for col, val in require.items():
        if col not in joined.columns:
            mask &= False
        else:
            mask &= (joined[col].astype(int) == int(bool(val)))
    return joined.loc[mask, [id_col, TEXT_COL]].copy()

def search_within(df, text_query=None, text_col=TEXT_COL, top_k=50):
    if not text_query:
        return df.head(top_k)
    q = str(text_query).lower()
    pri = df[text_col].str.lower().str.contains(q, na=False).astype(int)
    V = TfidfVectorizer(max_features=20000)
    X = V.fit_transform(df[text_col].fillna(""))
    qv = V.transform([text_query])
    sim = (X @ qv.T).toarray().ravel()
    scored = df.assign(_pri=pri, _sim=sim).sort_values(["_pri","_sim"], ascending=[False,False])
    return scored.head(top_k).drop(columns=["_pri","_sim"])

# Example: signal-only filter
if len(signals_idx):
    emer_only = filter_by_signals(df_lean[["idx", TEXT_COL]], signals_idx, {"surgery.EMERGENCY_SURGERY": 1})
    print("[Stage 6] Emergency surgery notes (signal-only):", len(emer_only))

# ------------------------------------------------------
# STAGE 7 — Runtime: prefilter → classifier gate → NER
# ------------------------------------------------------
# Load NER 
ner_pipe = None
if Path("model_ner_best").exists():
    device = 0 if torch.cuda.is_available() else -1
    tok_infer = AutoTokenizer.from_pretrained("model_ner_best")
    model_infer = AutoModelForTokenClassification.from_pretrained("model_ner_best")
    ner_pipe = pipeline("token-classification", model=model_infer, tokenizer=tok_infer,
                        aggregation_strategy="simple", device=device)

def merge_overlaps(ents):
    ents = sorted(ents, key=lambda e: (e["start"], e["end"]))
    merged = []
    for e in ents:
        if merged and e["start"] <= merged[-1]["end"] and e["entity_group"] == merged[-1]["entity_group"]:
            merged[-1]["end"]   = max(merged[-1]["end"], e["end"])
            merged[-1]["score"] = max(merged[-1]["score"], e["score"])
        else:
            merged.append(e)
    return merged

def ner_with_windows(text, tok, pipe, max_length=512, stride=128):
    enc = tok(text, return_offsets_mapping=True, return_overflowing_tokens=True,
              truncation=True, max_length=max_length, stride=stride)
    all_ents = []
    for offsets in enc["offset_mapping"]:
        valid = [(a, b) for (a, b) in offsets if b > a]
        if not valid: 
            continue
        start_char = valid[0][0]; end_char = valid[-1][1]
        chunk = text[start_char:end_char]
        ents = pipe(chunk)
        for e in ents:
            e["start"] += start_char
            e["end"]   += start_char
        all_ents.extend(ents)
    return merge_overlaps(all_ents)

def run_ner(df_notes, text_col, strategy="auto", max_length=512, stride=128):
    if ner_pipe is None:
        raise RuntimeError("NER pipeline not loaded. Train or place model_ner_best first.")
    rows = []
    for r in df_notes.itertuples(index=False):
        text = getattr(r, text_col) or ""
        if strategy == "truncate":
            ents = ner_pipe(text)
        elif strategy == "window":
            ents = ner_with_windows(text, tok_infer, ner_pipe, max_length=max_length, stride=stride)
        else:
            approx_len = len(tok_infer(text, add_special_tokens=False)["input_ids"])
            if approx_len > max_length - 8:
                ents = ner_with_windows(text, tok_infer, ner_pipe, max_length=max_length, stride=stride)
            else:
                ents = ner_pipe(text)
        rows.append({"idx": r.idx, "text": text, "pred_entities": ents})
    return pd.DataFrame(rows)

# Example runtime flow:
# 1) LF prefilter (e.g., emergency across domains)
prefilter = filter_by_signals(
    df_notes=df_lean[["idx", TEXT_COL]],
    signals_idx=signals_idx,
    require={"surgery.EMERGENCY_SURGERY": 1}
) if len(signals_idx) else df_lean[["idx", TEXT_COL]]

# 2) Optional classifier gate (boost precision if the classifier exists)
if Path("row_classifier_emergency.joblib").exists() and len(prefilter):
    gate = joblib.load("row_classifier_emergency.joblib")
    mask = gate.predict(prefilter[TEXT_COL]) == 1
    gated = prefilter.loc[mask]
else:
    gated = prefilter
print(f"[Stage 7] Notes after LF+gate: {len(gated)}")

# 3) NER over filtered set
if len(gated) and ner_pipe is not None:
    pred_df = run_ner(gated, TEXT_COL, strategy="auto") 
    pred_df.to_parquet("ner_predictions_prefiltered.parquet", index=False)
    print("[Stage 7] Saved ner_predictions_prefiltered.parquet")
else:
    print("[Stage 7] Skipped NER inference (no gated set or no NER model)")

# ===== Extended NER runs =====
# Run the full pipeline on your dataset
print("\n[Stage 7 - Extended] Running NER on full dataset...")
if ner_pipe is not None and len(df_lean) > 0:
    # Option 1: Run on all data
    full_predictions = run_ner(df_lean[["idx", TEXT_COL]], TEXT_COL, strategy="auto")
    full_predictions.to_parquet("ner_predictions_full.parquet", index=False)
    print(f"[Stage 7 - Extended] Saved full predictions for {len(full_predictions)} notes")
    
    # Show sample results
    print("\nSample predictions:")
    for _, row in full_predictions.head(3).iterrows():
        if row['pred_entities']:
            print(f"\nNote {row['idx']} entities:")
            for ent in row['pred_entities']:
                print(f"  - {ent['word']} [{ent['entity_group']}] (conf: {ent['score']:.2f})")

Device set to use cpu


[Stage 6] Emergency surgery notes (signal-only): 565


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


[Stage 7] Notes after LF+gate: 555
[Stage 7] Saved ner_predictions_prefiltered.parquet

[Stage 7 - Extended] Running NER on full dataset...
[Stage 7 - Extended] Saved full predictions for 30000 notes

Sample predictions:

Note 155216 entities:
  - discomfort [Problem] (conf: 0.99)
  - neck [Anatomy] (conf: 0.56)
  - lower back [Problem] (conf: 0.46)
  - right [Laterality] (conf: 0.81)
  - sustained [TemporalPattern] (conf: 0.54)
  - neck [Anatomy] (conf: 0.93)
  - back [Anatomy] (conf: 0.88)
  - lumbar [Anatomy] (conf: 0.86)
  - back [Anatomy] (conf: 0.78)
  - neck [Anatomy] (conf: 0.94)
  - neck [Anatomy] (conf: 0.94)
  - lumbar [Anatomy] (conf: 0.87)
  - daily [Frequency] (conf: 0.77)
  - months [TemporalPattern] (conf: 0.56)
  - olanzapine [TreatmentName] (conf: 0.84)
  - tablets [TreatmentType] (conf: 0.90)
  - mental illness [Problem] (conf: 0.59)
  - years [TemporalPattern] (conf: 0.55)
  - bipolar affective disorder [Problem] (conf: 0.97)
  - affective disorder [Problem] (conf: 

#### Model Evaluation

In [64]:
# If you still have the trainer and test set available:
if 'trainer' in locals() and 'splits' in locals() and "test" in splits:
    pred = trainer.predict(splits["test"])
    pred_ids = pred.predictions.argmax(-1)
    
    def align(ids, gold):
        p_tags, t_tags = [], []
        for p, t in zip(ids, gold):
            mask = (t != -100)
            p = p[mask]; t = t[mask]
            p_tags.append([id2label[i] for i in p])
            t_tags.append([id2label[i] for i in t])
        return p_tags, t_tags
    
    p_tags, t_tags = align(pred_ids, pred.label_ids)
    
    # Use seqeval for proper NER evaluation
    from seqeval.metrics import classification_report
    print("\nDetailed NER Evaluation Report:")
    print(classification_report(t_tags, p_tags, digits=3))

#### Analyze Entity Distribution

In [65]:
# Analyze what entities were found
def analyze_predictions(pred_df):
    all_entities = []
    for _, row in pred_df.iterrows():
        for ent in row['pred_entities']:
            all_entities.append({
                'entity': ent['word'],
                'label': ent['entity_group'],
                'score': ent['score']
            })
    
    entity_df = pd.DataFrame(all_entities)
    
    print("\nEntity Label Distribution:")
    print(entity_df['label'].value_counts())
    
    print("\nTop entities by label:")
    for label in entity_df['label'].unique():
        print(f"\n{label}:")
        top_ents = entity_df[entity_df['label']==label]['entity'].value_counts().head(10)
        print(top_ents)
    
    return entity_df

if 'full_predictions' in locals():
    entity_analysis = analyze_predictions(full_predictions)


Entity Label Distribution:
label
Problem              419506
Anatomy              230550
Laterality           109845
TestType              78293
TreatmentName         58265
TestName              53893
TreatmentType         40949
TemporalPattern       40344
Sex                   33242
Medication            32112
Procedure             25353
AgeSex                17917
Dosage                13665
Age                   12233
Measurement           10795
Severity               9450
Route                  8557
Frequency              7385
TreatmentResponse      5605
TreatmentReason         427
Name: count, dtype: int64

Top entities by label:

Problem:
entity
mass              20826
pain              15555
tumor             15028
lesion            10824
swelling           7498
abdominal          5023
fracture           4913
hypertension       3941
fever              3495
abdominal pain     3200
Name: count, dtype: int64

Anatomy:
entity
artery      5631
eye         5068
liver       4613
lung 

#### Quality Checks and Post-processing

In [66]:
# Filter by confidence threshold
def filter_by_confidence(pred_df, min_score=0.8):
    filtered_rows = []
    for _, row in pred_df.iterrows():
        filtered_ents = [e for e in row['pred_entities'] if e['score'] >= min_score]
        filtered_rows.append({
            'idx': row['idx'],
            'text': row['text'],
            'pred_entities': filtered_ents
        })
    return pd.DataFrame(filtered_rows)

high_conf_predictions = filter_by_confidence(full_predictions, min_score=0.85)

#### Final Entities for Downstream Tasks

In [67]:
# Create final entity extraction for downstream use
def create_final_entities(pred_df):
    rows = []
    for _, row in pred_df.iterrows():
        for ent in row['pred_entities']:
            rows.append({
                'idx': row['idx'],
                'entity_text': ent['word'],
                'entity_label': ent['entity_group'],
                'start': ent['start'],
                'end': ent['end'],
                'confidence': ent['score']
            })
    return pd.DataFrame(rows)

final_entities = create_final_entities(full_predictions)
final_entities.to_parquet("final_extracted_entities.parquet", index=False)

# Merge with original data if needed
enriched_df = df_lean.merge(
    final_entities.groupby('idx').agg(
        entities=('entity_text', list),
        labels=('entity_label', list),
        n_entities=('entity_text', 'count')
    ).reset_index(),
    on='idx',
    how='left'
)

In [69]:
enriched_df

Unnamed: 0,idx,note,entities,labels,n_entities
0,155216,"A a sixteen year-old girl, presented to our Ou...","[discomfort, neck, lower back, right, sustaine...","[Problem, Anatomy, Problem, Laterality, Tempor...",46
1,77465,This is the case of a 56-year-old man that was...,"[56 - year - old, man, dump pain, right, back,...","[AgeSex, Sex, Problem, Laterality, Anatomy, Pr...",44
2,133948,A 36-year old female patient visited our hospi...,"[36, female patient, pain and restricted range...","[Age, Sex, Problem, Problem, Laterality, Probl...",56
3,80176,A 49-year-old male presented with a complaint ...,"[49 - year - old, male, pain, left, proximal f...","[AgeSex, Sex, Problem, Laterality, Anatomy, La...",49
4,72232,A 47-year-old male patient was referred to the...,"[47 - year - old, male patient, recurrent, pai...","[AgeSex, Sex, TemporalPattern, Problem, Latera...",71
...,...,...,...,...,...
29995,39279,A 28-year-old male was admitted to the emergen...,"[28 - year - old, male, left, nipple, left, ch...","[AgeSex, Sex, Laterality, Anatomy, Laterality,...",35
29996,137017,"An 82-year-old man (64.5 kg, 175 cm) diagnosed...","[82 - year - old, man, falcine, men, ##ingioma...","[AgeSex, Sex, Problem, Sex, Problem, Problem, ...",49
29997,98004,A 54 year-old man with no past medical history...,"[54, man, cardiac, chest pain, coronary, myoca...","[Age, Sex, Problem, Problem, Anatomy, Problem,...",36
29998,133320,A 49-year-old woman visited the clinic due to ...,"[49 - year - old, woman, mass, right, thigh, m...","[AgeSex, Sex, Problem, Laterality, Anatomy, Pr...",54
