In [149]:
import pandas as pd
import sys
import gc
import os
import spacy
import scispacy
from dateutil import parser
from spacy.matcher import Matcher
import re
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')
import json


In [2]:
print(spacy.__version__)
print(scispacy.__version__)
import en_core_web_sm
print(en_core_web_sm.__version__)
import sys
print(sys.executable)

3.4.4
0.5.1
3.4.0
/opt/anaconda3/envs/nlp_clinical/bin/python


In [3]:
def get_data_path(filename):
    """
    Returns the path to a file in the data/clean directory.
    
    Args:
        filename (str): Name of the file (including extension)
    
    Returns:
        str: Full path to the file
    """
    cwd = os.getcwd()
    parent_dir = os.path.dirname(cwd)
    file_path = os.path.join(parent_dir, 'data', 'clean', filename)
    return file_path

In [4]:
df_medical_expanded = get_data_path('df_medical_expanded.csv')
df_medical = pd.read_csv(df_medical_expanded, index_col=0)

In [5]:
df_diagnosis_expanded = get_data_path('df_diagnosis_expanded.csv')
df_diagnosis = pd.read_csv(df_diagnosis_expanded, index_col=0)

In [6]:
df_surgery_expanded = get_data_path('df_surgery_expanded.csv')
df_surgery = pd.read_csv(df_surgery_expanded, index_col=0)

In [7]:
df_symptoms_expanded = get_data_path('df_symptoms_expanded.csv')
df_symptoms = pd.read_csv(df_symptoms_expanded, index_col=0)

In [8]:
df_treatments_expanded = get_data_path('df_treatments_expanded.csv')
df_treatments = pd.read_csv(df_treatments_expanded, index_col=0)

In [9]:
df_info_expanded = get_data_path('df_info_expanded.csv')
df_info = pd.read_csv(df_info_expanded, index_col=0)

### Medical Entity Extractor

In [None]:
class MedicalEntityExtractor:
    def __init__(self, model_name="en_core_sci_sm", batch_size=500):
        """
        Initialize the medical entity extractor.

        Args:
            model_name: ScispaCy model to use. Options:
                       - "en_core_sci_sm" (smaller, faster, generic entities)
                       - "en_core_sci_md" (medium)
                       - "en_ner_bc5cdr_md" (disease/chemical focused - BEST FOR MEDICAL)
                       - "en_ner_bionlp13cg_md" (multiple bio entity types)
            batch_size: Number of texts to process at once
        """
        self.batch_size = batch_size
        self.model_name = model_name
        self.nlp = None

    def load_model(self):
        """Load the SpaCy model with optimizations for memory efficiency."""
        print(f"Loading {self.model_name}...")
        try:
            self.nlp = spacy.load(self.model_name)
        except:
            print(f"Model {self.model_name} not found. Installing...")
            import os
            # Install the appropriate model
            if self.model_name == "en_ner_bc5cdr_md":
                os.system("pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_ner_bc5cdr_md-0.5.1.tar.gz")
            elif self.model_name == "en_ner_bionlp13cg_md":
                os.system("pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_ner_bionlp13cg_md-0.5.1.tar.gz")
            else:
                os.system(f"pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/{self.model_name}-0.5.1.tar.gz")
            self.nlp = spacy.load(self.model_name)

        # Disable unnecessary pipeline components to save memory
        pipes_to_disable = ["tagger", "parser", "attribute_ruler", "lemmatizer"]
        for pipe in pipes_to_disable:
            if pipe in self.nlp.pipe_names:
                self.nlp.disable_pipe(pipe)

        print(f"Model loaded. Active pipes: {self.nlp.pipe_names}")

    def process_batch(self, texts, ids):  # renamed 'indices' -> 'ids' for clarity
        batch_results = []
        docs = list(self.nlp.pipe(texts, batch_size=50))
        for doc, original_text, rid in zip(docs, texts, ids):
            for ent in doc.ents:
                batch_results.append({
                    'text': ent.text,
                    'label': ent.label_,
                    'start': ent.start_char,
                    'end': ent.end_char,
                    'original_text': original_text,
                    'row_idx': rid,   # <- comes from the provided id column
                })
        del docs
        gc.collect()
        return batch_results

    def extract_entities_from_df(
        self, df, text_column, *,
        save_checkpoints=True, checkpoint_dir='./ner_checkpoints',
        id_column='idx'                    
    ):
        if self.nlp is None:
            self.load_model()

        texts = df[text_column].fillna('').astype(str).tolist()

        # --- NEW: choose row IDs from the given column; fallback to df.index ---
        if id_column is not None and id_column in df.columns:
            ids_series = df[id_column]
            # if there are NaNs in idx, fall back to positional index for those rows
            ids_series = ids_series.where(pd.notna(ids_series), df.index)
            ids = ids_series.tolist()
            print(f"Stamping row identifier from column: '{id_column}'")
        else:
            ids = df.index.tolist()
            print("Stamping row identifier from DataFrame index")

        total_batches = (len(texts) + self.batch_size - 1) // self.batch_size
        all_results = []

        if save_checkpoints:
            import os
            os.makedirs(checkpoint_dir, exist_ok=True)

        print(f"Processing {len(texts)} texts in {total_batches} batches...")
        print(f"Using model: {self.model_name} for column: {text_column}")

        for batch_start in tqdm(range(0, len(texts), self.batch_size), desc="Processing batches"):
            batch_end   = batch_start + self.batch_size
            batch_texts = texts[batch_start:batch_end]
            batch_ids   = ids[batch_start:batch_end]

            try:
                batch_results = self.process_batch(batch_texts, batch_ids)
                all_results.extend(batch_results)

                if save_checkpoints and ((batch_start // self.batch_size) % 10 == 0):
                    checkpoint_df = pd.DataFrame(all_results)
                    safe_column_name = text_column.replace(' ', '_').replace('/', '_')
                    checkpoint_df.to_csv(
                        f"{checkpoint_dir}/checkpoint_{safe_column_name}_{batch_start}.csv",
                        index=False
                    )
                    print(f"\nCheckpoint saved at batch {batch_start}")
            except Exception as e:
                print(f"\nError processing batch {batch_start}: {e}")
                continue

            if (batch_start // self.batch_size) % 5 == 0:
                gc.collect()

        results_df = pd.DataFrame(all_results)
        results_df['source_column'] = text_column
        return results_df

    def get_entity_summary(self, entities_df):
        """Generate summary statistics of extracted entities."""
        summary = {
            'total_entities': len(entities_df),
            'unique_entities': entities_df['text'].nunique(),
            'entity_types': entities_df['label'].value_counts().to_dict(),
            'top_entities': entities_df['text'].value_counts().head(20).to_dict()
        }
        return summary



In [91]:
# Helper functions for analysis
def analyze_medical_entities(entities_df, min_frequency=5):
    """Analyze extracted entities and create labeling functions."""
    
    # Find frequent medical conditions
    frequent_entities = entities_df['text'].value_counts()
    frequent_entities = frequent_entities[frequent_entities >= min_frequency]
    
    print(f"Found {len(frequent_entities)} entities appearing >= {min_frequency} times")
    
    # Group by entity type
    entity_groups = {}
    for label in entities_df['label'].unique():
        label_entities = entities_df[entities_df['label'] == label]['text'].value_counts()
        entity_groups[label] = label_entities.head(10).to_dict()
    
    return frequent_entities, entity_groups

def create_entity_based_rules(entities_df, target_labels=['DISEASE', 'CHEMICAL'], source_column=None):
    """Create labeling functions based on extracted entities."""
    
    # Get the source column from the entities_df if not provided
    if source_column is None:
        if 'source_column' in entities_df.columns:
            source_column = entities_df['source_column'].iloc[0]
        else:
            raise ValueError("source_column must be provided or available in entities_df")
    
    # Get top entities for each label type
    rules = {}
    
    for label in target_labels:
        if label in entities_df['label'].unique():
            top_entities = entities_df[entities_df['label'] == label]['text'].value_counts().head(20)
            
            # Create a labeling function for this entity type
            def make_labeling_function(entity_list, label_name, column_name):
                def lf(row):
                    # Handle column names with spaces
                    if column_name in row:
                        text = str(row[column_name]).lower()
                    else:
                        # Try alternative column names
                        alt_column = column_name.replace('_', ' ')
                        if alt_column in row:
                            text = str(row[alt_column]).lower()
                        else:
                            return 'ABSTAIN'
                    
                    for entity in entity_list:
                        if entity.lower() in text:
                            return label_name
                    return 'ABSTAIN'
                
                # Create safe function name
                safe_column_name = column_name.replace(' ', '_').replace('/', '_')
                lf.__name__ = f"lf_{label_name.lower()}_{safe_column_name}"
                return lf
            
            rules[f'lf_{label.lower()}_{source_column.replace(" ", "_")}'] = make_labeling_function(
                top_entities.index.tolist(),
                label,
                source_column
            )
    
    return rules

# Main execution function for single column
def run_medical_ner_extraction(df, text_column,
                               model_name="en_ner_bc5cdr_md", batch_size=500,id_column='idx'):
    """
    Complete pipeline to extract medical entities from dataframe.
    
    Args:
        df: dataframe with medical text
        text_column: Column containing the text (NOW REQUIRED, NO DEFAULT)
        model_name: Which ScispaCy model to use
        batch_size: Batch size for processing
    
    Returns:
        entities_df: DataFrame with all extracted entities
        summary: Summary statistics
        rules: Generated labeling functions
    """
    if text_column not in df.columns:
        raise ValueError(f"Column '{text_column}' not found in dataframe.")

    extractor = MedicalEntityExtractor(model_name=model_name, batch_size=batch_size)
    entities_df = extractor.extract_entities_from_df(
        df, text_column, id_column=id_column
    )
    summary = extractor.get_entity_summary(entities_df)
    frequent_entities, entity_groups = analyze_medical_entities(entities_df)
    rules = create_entity_based_rules(entities_df, source_column=text_column)
    return entities_df, summary, rules


#### Extracting Entities From Patient Medical Record

In [130]:
df_medical.head(30)

Unnamed: 0,idx,has_medical_history,physiological context,psychological context,vaccination history,allergies,exercise frequency,nutrition,sexual history,alcohol consumption,drug usage,smoking status
0,155216,True,,Diagnosed with bipolar affective disorder at t...,,,,,,,,
2,133948,True,,Intensifying feelings of helplessness,,,,,,,,
3,80176,True,History of left elbow arthrodesis performed fo...,,,,,,,,,
4,72232,True,,,,,,,,,,
5,31864,True,"Inability to walk since babyhood, did not walk...",,,,,,Got married at the age of 15 and became pregna...,,,
6,26809,True,"Normal Apgar score, no resuscitation required ...",,,,,,,,,
7,149866,True,"Coxa vara deformity of bilateral hips, bilater...",,,,,,,,,
8,87064,True,,Patient could not realize that his symptoms mi...,,,,,,,,
9,123006,True,,,,,,,,,,
10,119317,True,Born at full term by spontaneous vaginal deliv...,Mentally healthy,,,,,,,,


In [84]:
df_medical['psychological context'].value_counts()

psychological context
Depression                                                               85
Bipolar disorder                                                         30
Schizophrenia                                                            29
Anxiety                                                                  27
No significant psychosocial history                                      14
                                                                         ..
Paranoid schizophrenia and bipolar disorder                               1
Mentally retarded with aggressive behavior                                1
Mental retardation grade 2, intelligence quotient of 23                   1
mild phantom limb pain and very frequent nonpainful phantom sensation     1
Exhibiting a tortured expression                                          1
Name: count, Length: 2231, dtype: int64

##### Physiological Contexts

In [138]:
if __name__ == "__main__":
    # Run extraction for the 'physiological context' column
    df_physiological_entities, disease_summary, rules = run_medical_ner_extraction(
        df_medical,
        text_column='physiological context',
        model_name="en_ner_bc5cdr_md",
        batch_size=500,
        id_column='idx')
    
    # Test the generated rules
    print("\nTesting generated labeling functions on df_medical...")
    
    # Select a sample row from the original df_medical to test the rules
    # Ensure df_medical is not empty before accessing iloc[0]
    if not df_medical.empty:
        sample_row = df_medical.iloc[0]
        
        for rule_name, rule_func in rules.items():
            # Apply the rule function to a row from the original df_medical
            try:
                test_result = rule_func(sample_row)
                print(f"{rule_name} applied to row 0: {test_result}")
            except KeyError as e:
                print(f"Error applying rule {rule_name}: {e}. Make sure the column '{e}' exists in df_medical.")
    else:
        print("df_medical is empty, cannot test rules.")
    
    print("\n=== DEBUG: Check extracted entities dataframe ===")
    print(f"Shape of df_physiological_entities: {df_physiological_entities.shape}")
    print(f"\nColumn names: {df_physiological_entities.columns.tolist()}")
    print(f"\nFirst few rows:")
    print(df_physiological_entities.head())
    
    # Check what labels ScispaCy actually found in the entities dataframe
    print("\n=== ACTUAL Entity Labels Found in df_physiological_entities ===")
    if not df_physiological_entities.empty:
        print(df_physiological_entities['label'].value_counts())
        
        # Check for 'ENTITY' label if present
        if 'ENTITY' in df_physiological_entities['label'].values:
            print("\n=== Top ENTITY type examples from df_physiological_entities ===")
            physiological_entity_examples = df_physiological_entities[df_physiological_entities['label'] == 'ENTITY']['text'].value_counts().head(20)
            print(physiological_entity_examples)
    else:
        print("df_physiological_entities is empty.")
    
    # ============= CUSTOM ANATOMY EXTRACTION =============
    print("\n" + "="*60)
    print("CUSTOM ANATOMY EXTRACTION FOR PHYSIOLOGICAL CONTEXT")
    print("="*60)
    
    def extract_anatomy_from_physiological_context(df_medical, text_column, id_column='idx'):
        """Extract anatomy entities from physiological context that BC5CDR might miss"""
        
        anatomy_entities = []
         # build the id series once
        if id_column is not None and id_column in df_medical.columns:
            ids_series = df_medical[id_column].where(pd.notna(df_medical[id_column]), df_medical.index)
        else:
            ids_series = df_medical.index

        texts = df_medical[text_column].fillna('').astype(str)
        
        # Comprehensive anatomy patterns
        anatomy_patterns = {
            'organs': [
                'heart', 'lung', 'lungs', 'liver', 'kidney', 'kidneys', 'brain', 
                'pancreas', 'spleen', 'stomach', 'intestine', 'colon', 'bladder',
                'gallbladder', 'thyroid', 'prostate', 'uterus', 'ovary', 'ovaries'
            ],
            'bones': [
                'bone', 'femur', 'tibia', 'fibula', 'humerus', 'radius', 'ulna',
                'skull', 'spine', 'vertebra', 'vertebrae', 'rib', 'ribs', 'pelvis',
                'clavicle', 'scapula', 'sternum', 'patella'
            ],
            'joints': [
                'joint', 'hip', 'knee', 'shoulder', 'elbow', 'wrist', 'ankle',
                'knuckle', 'finger', 'toe', 'neck', 'back'
            ],
            'cardiovascular': [
                'artery', 'arteries', 'vein', 'veins', 'vessel', 'vessels',
                'aorta', 'carotid', 'coronary', 'pulmonary', 'cardiac',
                'ventricle', 'atrium', 'valve'
            ],
            'neurological': [
                'nerve', 'nerves', 'neural', 'spinal cord', 'brainstem',
                'cerebral', 'cerebellum', 'cortex', 'lobe', 'ganglion'
            ],
            'muscles': [
                'muscle', 'muscles', 'tendon', 'ligament', 'fascia',
                'biceps', 'triceps', 'quadriceps', 'hamstring'
            ],
            'regions': [
                'chest', 'abdomen', 'pelvis', 'thorax', 'cranium',
                'extremity', 'limb', 'upper extremity', 'lower extremity'
            ],
            'tissues': [
                'skin', 'tissue', 'membrane', 'mucosa', 'epithelium',
                'cartilage', 'marrow', 'lymph node', 'gland'
            ]
        }
        
        # Laterality terms
        laterality_terms = ['left', 'right', 'bilateral', 'unilateral']
        
        # Process each row
        for (df_index, text_original), idx in zip(texts.items(), ids_series):
            if text_original:
                text = text_original.lower()
                
                # Extract anatomy terms
                for category, terms in anatomy_patterns.items():
                    for term in terms:
                        if term in text:
                            # Find all occurrences
                            import re
                            for match in re.finditer(r'\b' + re.escape(term) + r'\b', text):
                                anatomy_entities.append({
                                    'text': term,
                                    'label': 'ANATOMY',
                                    'category': category,
                                    'start': match.start(),
                                    'end': match.end(),
                                    'original_text': df_medical.at[df_index, text_column],
                                    'source': 'custom_anatomy_extraction',
                                    'row_idx': idx
                                })
                
                # Extract laterality + anatomy combinations
                for lat_term in laterality_terms:
                    # Pattern: "left hip", "bilateral knees", etc.
                    pattern = rf'\b{lat_term}\s+(\w+)\b'
                    matches = re.finditer(pattern, text)
                    for match in matches:
                        full_term = match.group(0)
                        anatomy_part = match.group(1)
                        
                        # Check if the anatomy part is in our patterns
                        for category, terms in anatomy_patterns.items():
                            if any(anatomy_part.startswith(term) for term in terms):
                                anatomy_entities.append({
                                    'text': full_term,
                                    'label': 'ANATOMY_WITH_LATERALITY',
                                    'category': category,
                                    'laterality': lat_term,
                                    'anatomy': anatomy_part,
                                    'start': match.start(),
                                    'end': match.end(),
                                    'original_text': df_medical.at[df_index, text_column],
                                    'source': 'custom_anatomy_extraction',
                                    'row_idx': idx
                                })
                                break
        
        return pd.DataFrame(anatomy_entities)
    
    # Extract custom anatomy entities
    print("\nExtracting anatomy entities...")
    df_anatomy_custom = extract_anatomy_from_physiological_context(df_medical, text_column='physiological context', id_column='idx')
    
    print(f"\nCustom anatomy extraction found {len(df_anatomy_custom)} anatomy mentions")
    
    if not df_anatomy_custom.empty:
        print("\n=== Anatomy Entity Distribution ===")
        print(df_anatomy_custom['label'].value_counts())
        
        print("\n=== Top Anatomy Terms ===")
        print(df_anatomy_custom['text'].value_counts().head(30))
        
        print("\n=== Anatomy by Category ===")
        print(df_anatomy_custom['category'].value_counts())
        
        # Show examples of laterality
        laterality_examples = df_anatomy_custom[df_anatomy_custom['label'] == 'ANATOMY_WITH_LATERALITY']
        if not laterality_examples.empty:
            print("\n=== Examples of Anatomy with Laterality ===")
            print(laterality_examples[['text', 'laterality', 'anatomy']].head(10))
    
    # Combine BC5CDR entities with custom anatomy entities
    print("\n=== COMBINING BC5CDR AND CUSTOM ENTITIES ===")
    df_all_physiological_entities = pd.concat([
        df_physiological_entities,
        df_anatomy_custom[['text', 'label', 'start', 'end', 'original_text']]
    ], ignore_index=True)
    
    print(f"\nTotal combined entities: {len(df_all_physiological_entities)}")
    print("\nCombined entity distribution:")
    print(df_all_physiological_entities['label'].value_counts())
    
    # Create anatomy-specific labeling functions
    print("\n=== Creating Anatomy-Specific Labeling Functions ===")
    
    def create_anatomy_labeling_functions():
        """Create labeling functions for anatomy patterns"""
        
        def lf_cardiac_anatomy(row):
            text = str(row.get('physiological context', '')).lower()
            cardiac_terms = ['heart', 'cardiac', 'coronary', 'ventricle', 'atrium', 'valve']
            if any(term in text for term in cardiac_terms):
                return 'CARDIAC_ANATOMY'
            return 'ABSTAIN'
        
        def lf_musculoskeletal(row):
            text = str(row.get('physiological context', '')).lower()
            msk_terms = ['bone', 'joint', 'muscle', 'tendon', 'ligament', 'cartilage']
            if any(term in text for term in msk_terms):
                return 'MUSCULOSKELETAL'
            return 'ABSTAIN'
        
        def lf_bilateral_anatomy(row):
            text = str(row.get('physiological context', '')).lower()
            if 'bilateral' in text:
                return 'BILATERAL_CONDITION'
            return 'ABSTAIN'
        
        def lf_neurological_anatomy(row):
            text = str(row.get('physiological context', '')).lower()
            neuro_terms = ['brain', 'nerve', 'neural', 'spinal', 'cerebral']
            if any(term in text for term in neuro_terms):
                return 'NEUROLOGICAL_ANATOMY'
            return 'ABSTAIN'
        
        return [lf_cardiac_anatomy, lf_musculoskeletal, lf_bilateral_anatomy, lf_neurological_anatomy]
    
    # Test anatomy labeling functions
    anatomy_lfs = create_anatomy_labeling_functions()
    
    print("\nTesting anatomy labeling functions:")
    if not df_medical.empty:
        for lf in anatomy_lfs:
            result = lf(df_medical.iloc[0])
            if result != 'ABSTAIN':
                print(f"  {lf.__name__}: {result}")
    keep_cols = ['text','label','start','end','original_text','row_idx','category']
    df_all_physiological_entities = pd.concat(
    [df_physiological_entities, df_anatomy_custom[keep_cols]],
    ignore_index=True
)
    
    # Save combined results
    df_all_physiological_entities.to_csv('physiological_entities_comprehensive.csv', index=False)
    df_anatomy_custom.to_csv('anatomy_entities_detailed.csv', index=False)
    
    print("\n\nEntity extraction complete! Saved:")
    print("- physiological_entities_with_anatomy.csv (combined BC5CDR + anatomy)")
    print("- anatomy_entities_detailed.csv (detailed anatomy with categories)")

Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 29755 texts in 60 batches...
Using model: en_ner_bc5cdr_md for column: physiological context


Processing batches:   2%|▏         | 1/60 [00:01<01:42,  1.73s/it]


Checkpoint saved at batch 0


Processing batches:  18%|█▊        | 11/60 [00:16<01:05,  1.34s/it]


Checkpoint saved at batch 5000


Processing batches:  35%|███▌      | 21/60 [00:29<00:51,  1.33s/it]


Checkpoint saved at batch 10000


Processing batches:  52%|█████▏    | 31/60 [00:41<00:37,  1.28s/it]


Checkpoint saved at batch 15000


Processing batches:  68%|██████▊   | 41/60 [00:54<00:25,  1.32s/it]


Checkpoint saved at batch 20000


Processing batches:  85%|████████▌ | 51/60 [01:06<00:11,  1.33s/it]


Checkpoint saved at batch 25000


Processing batches: 100%|██████████| 60/60 [01:17<00:00,  1.29s/it]


Found 1249 entities appearing >= 5 times

Testing generated labeling functions on df_medical...
lf_disease_physiological_context applied to row 0: ABSTAIN
lf_chemical_physiological_context applied to row 0: ABSTAIN

=== DEBUG: Check extracted entities dataframe ===
Shape of df_physiological_entities: (36272, 7)

Column names: ['text', 'label', 'start', 'end', 'original_text', 'row_idx', 'source_column']

First few rows:
                      text    label  start  end  \
0  posttraumatic arthritis  DISEASE     48   71   
1                     pain  DISEASE    116  120   
2                 fracture  DISEASE    151  159   
3      Coxa vara deformity  DISEASE      0   19   
4                 fracture  DISEASE     75   83   

                                       original_text  row_idx  \
0  History of left elbow arthrodesis performed fo...    80176   
1  Inability to walk since babyhood, did not walk...    31864   
2  Inability to walk since babyhood, did not walk...    31864   
3  Coxa v

In [139]:
df_all_physiological_entities

Unnamed: 0,text,label,start,end,original_text,row_idx,source_column,category
0,posttraumatic arthritis,DISEASE,48,71,History of left elbow arthrodesis performed fo...,80176,physiological context,
1,pain,DISEASE,116,120,"Inability to walk since babyhood, did not walk...",31864,physiological context,
2,fracture,DISEASE,151,159,"Inability to walk since babyhood, did not walk...",31864,physiological context,
3,Coxa vara deformity,DISEASE,0,19,"Coxa vara deformity of bilateral hips, bilater...",149866,physiological context,
4,fracture,DISEASE,75,83,"Coxa vara deformity of bilateral hips, bilater...",149866,physiological context,
...,...,...,...,...,...,...,...,...
47377,colon,ANATOMY,61,66,"Coronary arteriosclerosis, spinal canal stenos...",86992,,organs
47378,coronary,ANATOMY,0,8,"Coronary arteriosclerosis, spinal canal stenos...",86992,,cardiovascular
47379,kidney,ANATOMY,0,6,"Kidney stone lithotripsy, hypertension treated...",157822,,organs
47380,pulmonary,ANATOMY,20,29,"Chronic obstructive pulmonary disease, high bl...",77450,,cardiovascular


In [140]:
df_anatomy_custom

Unnamed: 0,text,label,category,start,end,original_text,source,row_idx,laterality,anatomy
0,elbow,ANATOMY,joints,16,21,History of left elbow arthrodesis performed fo...,custom_anatomy_extraction,80176,,
1,left elbow,ANATOMY_WITH_LATERALITY,joints,11,21,History of left elbow arthrodesis performed fo...,custom_anatomy_extraction,80176,left,elbow
2,femur,ANATOMY,bones,87,92,"Coxa vara deformity of bilateral hips, bilater...",custom_anatomy_extraction,149866,,
3,neck,ANATOMY,joints,93,97,"Coxa vara deformity of bilateral hips, bilater...",custom_anatomy_extraction,149866,,
4,bilateral hips,ANATOMY_WITH_LATERALITY,joints,23,37,"Coxa vara deformity of bilateral hips, bilater...",custom_anatomy_extraction,149866,bilateral,hips
...,...,...,...,...,...,...,...,...,...,...
11105,colon,ANATOMY,organs,61,66,"Coronary arteriosclerosis, spinal canal stenos...",custom_anatomy_extraction,86992,,
11106,coronary,ANATOMY,cardiovascular,0,8,"Coronary arteriosclerosis, spinal canal stenos...",custom_anatomy_extraction,86992,,
11107,kidney,ANATOMY,organs,0,6,"Kidney stone lithotripsy, hypertension treated...",custom_anatomy_extraction,157822,,
11108,pulmonary,ANATOMY,cardiovascular,20,29,"Chronic obstructive pulmonary disease, high bl...",custom_anatomy_extraction,77450,,


##### Psychological Contexts

In [94]:
if __name__ == "__main__":
    # Run extraction for the 'psychological context' column
    df_psychological_entities, disease_summary, rules = run_medical_ner_extraction(
        df_medical,
        text_column='psychological context',
        model_name="en_ner_bc5cdr_md",
        batch_size=500,
        id_column='idx')

    # Test the generated rules
    print("\nTesting generated labeling functions on df_medical...")

    # Select a sample row from the original df_medical to test the rules
    # Ensure df_medical is not empty before accessing iloc[0]
    if not df_medical.empty:
        sample_row = df_medical.iloc[0]

        for rule_name, rule_func in rules.items():
            # Apply the rule function to a row from the original df_medical
            try:
                test_result = rule_func(sample_row)
                print(f"{rule_name} applied to row 0: {test_result}")
            except KeyError as e:
                print(f"Error applying rule {rule_name}: {e}. Make sure the column '{e}' exists in df_medical.")
    else:
        print("df_medical is empty, cannot test rules.")


print("\n=== DEBUG: Check extracted entities dataframe ===")
print(f"Shape of df_medical_entities: {df_psychological_entities.shape}")
print(f"\nColumn names: {df_psychological_entities.columns.tolist()}")
print(f"\nFirst few rows:")
print(df_psychological_entities.head())

# Check what labels ScispaCy actually found in the entities dataframe
print("\n=== ACTUAL Entity Labels Found in df_physiological_entities ===")
if not df_psychological_entities.empty:
    print(df_psychological_entities['label'].value_counts())

    # Check for 'ENTITY' label if present
    if 'ENTITY' in df_psychological_entities['label'].values:
        print("\n=== Top ENTITY type examples from df_medical_entities ===")
        psychological_entity_examples = df_psychological_entities[df_psychological_entities['label'] == 'ENTITY']['text'].value_counts().head(20)
        print(psychological_entity_examples)
else:
    print("df_psychological_entities is empty.")

    
df_psychological_entities.to_csv('psychological_entities_comprehensive.csv', index=False)

Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 29755 texts in 60 batches...
Using model: en_ner_bc5cdr_md for column: psychological context


Processing batches:   2%|▏         | 1/60 [00:01<01:09,  1.17s/it]


Checkpoint saved at batch 0


Processing batches:  18%|█▊        | 11/60 [00:03<00:14,  3.47it/s]


Checkpoint saved at batch 5000


Processing batches:  35%|███▌      | 21/60 [00:06<00:10,  3.63it/s]


Checkpoint saved at batch 10000


Processing batches:  52%|█████▏    | 31/60 [00:08<00:07,  3.63it/s]


Checkpoint saved at batch 15000


Processing batches:  68%|██████▊   | 41/60 [00:11<00:05,  3.37it/s]


Checkpoint saved at batch 20000


Processing batches:  85%|████████▌ | 51/60 [00:14<00:02,  3.50it/s]


Checkpoint saved at batch 25000


Processing batches: 100%|██████████| 60/60 [00:16<00:00,  3.62it/s]

Found 136 entities appearing >= 5 times

Testing generated labeling functions on df_medical...
lf_disease_psychological_context applied to row 0: ABSTAIN
lf_chemical_psychological_context applied to row 0: ABSTAIN

=== DEBUG: Check extracted entities dataframe ===
Shape of df_medical_entities: (3402, 7)

Column names: ['text', 'label', 'start', 'end', 'original_text', 'row_idx', 'source_column']

First few rows:
                         text    label  start  end  \
0  bipolar affective disorder  DISEASE     15   41   
1                       mania  DISEASE     90   95   
2           Parental distress  DISEASE      0   17   
3                  depression  DISEASE     68   78   
4                     anxiety  DISEASE     83   90   

                                       original_text  row_idx  \
0  Diagnosed with bipolar affective disorder at t...   155216   
1  Diagnosed with bipolar affective disorder at t...   155216   
2                                  Parental distress    90928   




In [95]:
df_psychological_entities

Unnamed: 0,text,label,start,end,original_text,row_idx,source_column
0,bipolar affective disorder,DISEASE,15,41,Diagnosed with bipolar affective disorder at t...,155216,psychological context
1,mania,DISEASE,90,95,Diagnosed with bipolar affective disorder at t...,155216,psychological context
2,Parental distress,DISEASE,0,17,Parental distress,90928,psychological context
3,depression,DISEASE,68,78,Known to local mental health services for 20 y...,45433,psychological context
4,anxiety,DISEASE,83,90,Known to local mental health services for 20 y...,45433,psychological context
...,...,...,...,...,...,...,...
3397,Bipolar disorder,DISEASE,0,16,Bipolar disorder,87937,psychological context
3398,Bipolar disorder,DISEASE,0,16,Bipolar disorder,113022,psychological context
3399,psychiatric,DISEASE,3,14,No psychiatric symptoms or previous psychiatri...,160392,psychological context
3400,psychiatric illness,DISEASE,36,55,No psychiatric symptoms or previous psychiatri...,160392,psychological context


##### Vaccination History

In [96]:
if __name__ == "__main__":
    # BC5CDR model will identify vaccines (CHEMICAL) and diseases (DISEASE)
    df_vaccination_entities, vaccination_summary, vaccination_rules = run_medical_ner_extraction(
        df_medical, 
        text_column='vaccination history',
        model_name="en_ner_bc5cdr_md",  # This model recognizes DISEASE and CHEMICAL entities
        batch_size=300,
        id_column='idx'  # Use 'idx' column for row identifiers
    )

    # Test the generated rules
    print("\nTesting generated labeling functions for vaccination history...")

    # Select a sample row from the original df_medical to test the rules
    if not df_medical.empty:
        sample_row = df_medical.iloc[0]

        for rule_name, rule_func in vaccination_rules.items():
            try:
                test_result = rule_func(sample_row)
                print(f"{rule_name} applied to row 0: {test_result}")
            except KeyError as e:
                print(f"Error applying rule {rule_name}: {e}. Make sure the column 'vaccination_history' exists in df_medical.")
    else:
        print("df_medical is empty, cannot test rules.")

    # Debug: Check extracted entities
    print("\n=== DEBUG: Check extracted vaccination entities ===")
    print(f"Shape of df_vaccination_entities: {df_vaccination_entities.shape}")
    print(f"\nColumn names: {df_vaccination_entities.columns.tolist()}")
    print(f"\nFirst few rows:")
    print(df_vaccination_entities.head())

    # Check entity label distribution
    print("\n=== Entity Labels Found in Vaccination History ===")
    if not df_vaccination_entities.empty:
        print(df_vaccination_entities['label'].value_counts())

        # Look at CHEMICAL entities (likely vaccines)
        if 'CHEMICAL' in df_vaccination_entities['label'].values:
            print("\n=== Top CHEMICAL/Vaccine entities ===")
            vaccine_entities = df_vaccination_entities[df_vaccination_entities['label'] == 'CHEMICAL']['text'].value_counts().head(20)
            print(vaccine_entities)

        # Look at DISEASE entities (conditions vaccines prevent)
        if 'DISEASE' in df_vaccination_entities['label'].values:
            print("\n=== Top DISEASE entities (conditions) ===")
            disease_entities = df_vaccination_entities[df_vaccination_entities['label'] == 'DISEASE']['text'].value_counts().head(20)
            print(disease_entities)
    else:
        print("df_vaccination_entities is empty.")

def create_vaccination_labeling_functions(entities_df):
    """Create specific labeling functions for vaccination history based on actual entities found"""

    # First, let's see what vaccine-related entities were actually extracted
    vaccine_related_terms = []
    disease_related_terms = []

    # Analyze the top entities to identify vaccine patterns
    top_entities = entities_df['text'].value_counts().head(100)

    for entity, count in top_entities.items():
        entity_lower = entity.lower()

        # Identify vaccine-related terms
        if any(term in entity_lower for term in ['vaccin', 'immuniz', 'shot', 'injection', 'dose']):
            vaccine_related_terms.append(entity)

        # Identify disease/condition terms that vaccines prevent
        if any(term in entity_lower for term in ['tetanus', 'hepatitis', 'measles', 'influenza',
                                                  'covid', 'polio', 'pertussis', 'mumps', 'rubella',
                                                  'pneumococcal', 'meningitis', 'hpv', 'varicella',
                                                  'diphtheria', 'rotavirus']):
            disease_related_terms.append(entity)

    print(f"\nVaccine-related entities found: {vaccine_related_terms[:10]}")
    print(f"Disease-related entities found: {disease_related_terms[:10]}")

    # Create labeling functions based on patterns in data
    def lf_covid_vaccination(row):
        """Detect COVID-19 vaccination"""
        text = str(row['vaccination history']).lower()
        covid_patterns = [
            'covid', 'coronavirus', 'sars-cov-2', 'pfizer', 'moderna',
            'astrazeneca', 'johnson', 'mrna-1273', 'bnt162b2'
        ]
        if any(pattern in text for pattern in covid_patterns):
            return 'COVID_VACCINE'
        return 'ABSTAIN'

    def lf_childhood_vaccines(row):
        """Detect standard childhood vaccinations"""
        text = str(row['vaccination history']).lower()
        childhood_vaccines = [
            'mmr', 'measles', 'mumps', 'rubella', 'varicella', 'chickenpox',
            'polio', 'dtap', 'diphtheria', 'tetanus', 'pertussis', 'whooping',
            'hib', 'hepatitis b', 'rotavirus', 'pcv', 'ipv'
        ]
        if any(vaccine in text for vaccine in childhood_vaccines):
            return 'CHILDHOOD_VACCINES'
        return 'ABSTAIN'

    def lf_influenza_vaccination(row):
        """Detect flu vaccination"""
        text = str(row['vaccination history']).lower()
        flu_patterns = ['influenza', 'flu vaccine', 'flu shot', 'seasonal flu', 'h1n1']
        if any(pattern in text for pattern in flu_patterns):
            return 'FLU_VACCINE'
        return 'ABSTAIN'

    def lf_hepatitis_vaccination(row):
        """Detect hepatitis vaccinations"""
        text = str(row['vaccination history']).lower()
        if any(hep in text for hep in ['hepatitis a', 'hepatitis b', 'hep a', 'hep b', 'havrix', 'engerix']):
            return 'HEPATITIS_VACCINE'
        return 'ABSTAIN'

    def lf_tetanus_vaccination(row):
        """Detect tetanus/Td/Tdap vaccinations"""
        text = str(row['vaccination history']).lower()
        if any(tet in text for tet in ['tetanus', 'tdap', 'td ', 'boostrix', 'adacel']):
            return 'TETANUS_VACCINE'
        return 'ABSTAIN'

    def lf_pneumococcal_vaccination(row):
        """Detect pneumococcal vaccinations"""
        text = str(row['vaccination history']).lower()
        if any(pneumo in text for pneumo in ['pneumococcal', 'pneumonia vaccine', 'prevnar', 'pneumovax']):
            return 'PNEUMO_VACCINE'
        return 'ABSTAIN'

    def lf_travel_vaccines(row):
        """Detect travel-related vaccinations"""
        text = str(row['vaccination history']).lower()
        travel_vaccines = ['yellow fever', 'typhoid', 'japanese encephalitis', 'rabies',
                          'meningococcal', 'cholera']
        if any(vaccine in text for vaccine in travel_vaccines):
            return 'TRAVEL_VACCINES'
        return 'ABSTAIN'

    def lf_vaccination_timing(row):
        """Detect vaccination timing information"""
        text = str(row['vaccination history']).lower()
        timing_patterns = ['booster', 'dose', 'series', 'schedule', 'up to date',
                          'fully vaccinated', 'partially vaccinated']
        if any(pattern in text for pattern in timing_patterns):
            return 'VACCINATION_TIMING'
        return 'ABSTAIN'

    def lf_no_vaccination(row):
        """Detect absence of vaccination"""
        text = str(row['vaccination history']).lower()
        no_vax_patterns = ['no vaccination', 'not vaccinated', 'unvaccinated',
                          'declined', 'refused', 'no history of vaccination']
        if any(pattern in text for pattern in no_vax_patterns):
            return 'UNVACCINATED'
        return 'ABSTAIN'

    def lf_vaccine_reaction(row):
        """Detect vaccine reactions/side effects"""
        text = str(row['vaccination history']).lower()
        reaction_patterns = ['reaction', 'allergy', 'side effect', 'adverse', 'anaphylaxis']
        if any(pattern in text for pattern in reaction_patterns):
            return 'VACCINE_REACTION'
        return 'ABSTAIN'

    # Since we have "years", "months", "weeks" as top entities, let's create time-based functions
    def lf_recent_vaccination(row):
        """Detect recent vaccinations"""
        text = str(row['vaccination history']).lower()
        # Look for patterns indicating recent vaccination
        if any(recent in text for recent in ['weeks ago', 'months ago', 'recently',
                                             'last month', 'last week', 'this year']):
            return 'RECENT_VACCINATION'
        return 'ABSTAIN'

    def lf_historical_vaccination(row):
        """Detect historical vaccination information"""
        text = str(row['vaccination history']).lower()
        if 'history' in text and any(vax in text for vax in ['vaccin', 'immuniz']):
            return 'VACCINATION_HISTORY'
        return 'ABSTAIN'

    return [
        lf_covid_vaccination,
        lf_childhood_vaccines,
        lf_influenza_vaccination,
        lf_hepatitis_vaccination,
        lf_tetanus_vaccination,
        lf_pneumococcal_vaccination,
        lf_travel_vaccines,
        lf_vaccination_timing,
        lf_no_vaccination,
        lf_vaccine_reaction,
        lf_recent_vaccination,
        lf_historical_vaccination
    ]

vaccination_lfs = create_vaccination_labeling_functions(df_vaccination_entities)

print("\n=== Testing Refined Vaccination Labeling Functions ===")
if not df_medical.empty and 'vaccination history' in df_medical.columns:
    # Test on multiple rows to see coverage
    test_rows = min(10, len(df_medical))
    results_summary = {lf.__name__: [] for lf in vaccination_lfs}

    for i in range(test_rows):
        row = df_medical.iloc[i]
        if pd.notna(row['vaccination history']):
            print(f"\nRow {i} vaccination history: {row['vaccination history'][:100]}...")
            for lf in vaccination_lfs:
                result = lf(row)
                if result != 'ABSTAIN':
                    print(f"  {lf.__name__}: {result}")
                    results_summary[lf.__name__].append(result)

# Analyze coverage
print("\n=== Labeling Function Coverage ===")
total_non_na = df_medical['vaccination history'].notna().sum()
for lf in vaccination_lfs:
    labeled_count = sum(lf(row) != 'ABSTAIN' for _, row in df_medical.iterrows()
                       if pd.notna(row.get('vaccination history', '')))
    coverage = (labeled_count / total_non_na * 100) if total_non_na > 0 else 0
    print(f"{lf.__name__}: {coverage:.1f}% coverage ({labeled_count}/{total_non_na})")

# Save result
df_vaccination_entities.to_csv('vaccination_entities_comprehensive.csv', index=False)


Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 29755 texts in 100 batches...
Using model: en_ner_bc5cdr_md for column: vaccination history


Processing batches:   2%|▏         | 2/100 [00:01<00:44,  2.22it/s]


Checkpoint saved at batch 0


Processing batches:  12%|█▏        | 12/100 [00:02<00:15,  5.51it/s]


Checkpoint saved at batch 3000


Processing batches:  22%|██▏       | 22/100 [00:04<00:13,  5.74it/s]


Checkpoint saved at batch 6000


Processing batches:  32%|███▏      | 32/100 [00:06<00:12,  5.47it/s]


Checkpoint saved at batch 9000


Processing batches:  41%|████      | 41/100 [00:07<00:11,  5.28it/s]


Checkpoint saved at batch 12000


Processing batches:  52%|█████▏    | 52/100 [00:09<00:08,  5.59it/s]


Checkpoint saved at batch 15000


Processing batches:  62%|██████▏   | 62/100 [00:11<00:06,  5.69it/s]


Checkpoint saved at batch 18000


Processing batches:  72%|███████▏  | 72/100 [00:12<00:04,  5.73it/s]


Checkpoint saved at batch 21000


Processing batches:  82%|████████▏ | 82/100 [00:14<00:03,  5.80it/s]


Checkpoint saved at batch 24000


Processing batches:  92%|█████████▏| 92/100 [00:16<00:01,  5.24it/s]


Checkpoint saved at batch 27000


Processing batches: 100%|██████████| 100/100 [00:17<00:00,  5.66it/s]


Found 6 entities appearing >= 5 times

Testing generated labeling functions for vaccination history...
lf_disease_vaccination_history applied to row 0: ABSTAIN
lf_chemical_vaccination_history applied to row 0: ABSTAIN

=== DEBUG: Check extracted vaccination entities ===
Shape of df_vaccination_entities: (129, 7)

Column names: ['text', 'label', 'start', 'end', 'original_text', 'row_idx', 'source_column']

First few rows:
                text    label  start  end  \
0            Tetanus  DISEASE      0    7   
1            tetanus  DISEASE     25   32   
2       hyposplenism  DISEASE     39   51   
3            tetanus  DISEASE     14   21   
4  tetanus infection  DISEASE     37   54   

                                       original_text  row_idx  \
0  Tetanus vaccination with tetanus immunoglobuli...   119386   
1  Tetanus vaccination with tetanus immunoglobuli...   119386   
2  Vaccinated post-treatment for presumed hypospl...    13774   
3  No history of tetanus vaccination or teta

In [97]:
df_vaccination_entities

Unnamed: 0,text,label,start,end,original_text,row_idx,source_column
0,Tetanus,DISEASE,0,7,Tetanus vaccination with tetanus immunoglobuli...,119386,vaccination history
1,tetanus,DISEASE,25,32,Tetanus vaccination with tetanus immunoglobuli...,119386,vaccination history
2,hyposplenism,DISEASE,39,51,Vaccinated post-treatment for presumed hypospl...,13774,vaccination history
3,tetanus,DISEASE,14,21,No history of tetanus vaccination or tetanus i...,157338,vaccination history
4,tetanus infection,DISEASE,37,54,No history of tetanus vaccination or tetanus i...,157338,vaccination history
...,...,...,...,...,...,...,...
124,Calmette,CHEMICAL,9,17,bacillus Calmette–Guérin (BCG),42118,vaccination history
125,Guérin,CHEMICAL,18,24,bacillus Calmette–Guérin (BCG),42118,vaccination history
126,pandemic,DISEASE,49,57,Had not received seasonal influenza or 2009 H1...,74950,vaccination history
127,tetanus,DISEASE,16,23,Vaccinated with tetanus toxoid once,198046,vaccination history


##### Allergies

In [98]:
# Process allergies using BC5CDR model (recognizes DISEASE and CHEMICAL)
if __name__ == "__main__":
    # Run extraction for the 'allergies' column
    # BC5CDR model will identify allergens (often CHEMICAL entities)
    df_allergies_entities, allergies_summary, allergies_rules = run_medical_ner_extraction(
        df_medical, 
        text_column='allergies',
        model_name="en_ner_bc5cdr_md",  # This model recognizes DISEASE and CHEMICAL entities
        batch_size=300,
        id_column='idx' 
    )

    # Test the generated rules
    print("\nTesting generated labeling functions for allergies...")

    # Select a sample row from the original df_medical to test the rules
    if not df_medical.empty:
        sample_row = df_medical.iloc[0]

        for rule_name, rule_func in allergies_rules.items():
            try:
                test_result = rule_func(sample_row)
                print(f"{rule_name} applied to row 0: {test_result}")
            except KeyError as e:
                print(f"Error applying rule {rule_name}: {e}. Make sure the column 'allergies' exists in df_medical.")
    else:
        print("df_medical is empty, cannot test rules.")

    # Debug: Check extracted entities
    print("\n=== DEBUG: Check extracted allergy entities ===")
    print(f"Shape of df_allergies_entities: {df_allergies_entities.shape}")
    print(f"\nColumn names: {df_allergies_entities.columns.tolist()}")
    print(f"\nFirst few rows:")
    print(df_allergies_entities.head())

    # Check entity label distribution
    print("\n=== Entity Labels Found in Allergies ===")
    if not df_allergies_entities.empty:
        print(df_allergies_entities['label'].value_counts())

        # Let's look at the top entities
        print("\n=== Top Allergy-related Entities ===")
        allergy_entities = df_allergies_entities['text'].value_counts().head(30)
        print(allergy_entities)

    # Create specific labeling functions for allergy data
    def create_allergy_labeling_functions(entities_df):
        """Create specific labeling functions for allergies"""

        # Analyze entities to understand patterns
        top_entities = entities_df['text'].value_counts().head(100)
        print(f"\nAnalyzing allergy entities for patterns...")

        # Common allergen categories
        def lf_drug_allergy(row):
            """Detect drug/medication allergies"""
            text = str(row['allergies']).lower()
            drug_patterns = [
                'penicillin', 'amoxicillin', 'ampicillin', 'antibiotic',
                'sulfa', 'aspirin', 'nsaid', 'ibuprofen', 'morphine',
                'codeine', 'contrast', 'iodine', 'latex', 'adhesive'
            ]
            if any(drug in text for drug in drug_patterns):
                return 'DRUG_ALLERGY'
            return 'ABSTAIN'

        def lf_food_allergy(row):
            """Detect food allergies"""
            text = str(row['allergies']).lower()
            food_patterns = [
                'peanut', 'nut', 'shellfish', 'fish', 'milk', 'dairy',
                'egg', 'wheat', 'gluten', 'soy', 'sesame', 'food'
            ]
            if any(food in text for food in food_patterns):
                return 'FOOD_ALLERGY'
            return 'ABSTAIN'

        def lf_environmental_allergy(row):
            """Detect environmental allergies"""
            text = str(row['allergies']).lower()
            env_patterns = [
                'pollen', 'dust', 'mold', 'grass', 'tree', 'ragweed',
                'cat', 'dog', 'animal', 'dander', 'environmental'
            ]
            if any(env in text for env in env_patterns):
                return 'ENVIRONMENTAL_ALLERGY'
            return 'ABSTAIN'

        def lf_no_allergies(row):
            """Detect absence of allergies"""
            text = str(row['allergies']).lower()
            no_allergy_patterns = [
                'no known allergies', 'no allergies', 'nka', 'nkda',
                'no known drug allergies', 'denies allergies', 'none'
            ]
            if any(pattern in text for pattern in no_allergy_patterns):
                return 'NO_ALLERGIES'
            return 'ABSTAIN'

        def lf_allergy_severity(row):
            """Detect severe allergic reactions"""
            text = str(row['allergies']).lower()
            severity_patterns = [
                'anaphylaxis', 'anaphylactic', 'severe', 'life-threatening',
                'epipen', 'epinephrine', 'emergency'
            ]
            if any(pattern in text for pattern in severity_patterns):
                return 'SEVERE_ALLERGY'
            return 'ABSTAIN'

        def lf_allergy_reaction_type(row):
            """Detect specific reaction types"""
            text = str(row['allergies']).lower()
            reaction_patterns = [
                'rash', 'hives', 'swelling', 'itching', 'breathing',
                'wheezing', 'nausea', 'vomiting', 'throat'
            ]
            if any(reaction in text for reaction in reaction_patterns):
                return 'ALLERGIC_REACTION_DESCRIBED'
            return 'ABSTAIN'

        def lf_seasonal_allergy(row):
            """Detect seasonal allergies"""
            text = str(row['allergies']).lower()
            if any(season in text for season in ['seasonal', 'spring', 'fall', 'hay fever']):
                return 'SEASONAL_ALLERGY'
            return 'ABSTAIN'

        def lf_chemical_sensitivity(row):
            """Detect chemical sensitivities"""
            text = str(row['allergies']).lower()
            chemical_patterns = [
                'chemical', 'perfume', 'fragrance', 'smoke', 'detergent',
                'cleaning', 'formaldehyde'
            ]
            if any(chem in text for chem in chemical_patterns):
                return 'CHEMICAL_SENSITIVITY'
            return 'ABSTAIN'

        def lf_multiple_allergies(row):
            """Detect multiple allergies"""
            text = str(row['allergies']).lower()
            # Count commas or "and" as indicators of multiple allergies
            if (text.count(',') >= 2 or text.count(' and ') >= 2) and 'no known' not in text:
                return 'MULTIPLE_ALLERGIES'
            return 'ABSTAIN'

        def lf_allergy_testing(row):
            """Detect allergy testing mentions"""
            text = str(row['allergies']).lower()
            if any(test in text for test in ['tested', 'skin test', 'patch test', 'ige']):
                return 'ALLERGY_TESTED'
            return 'ABSTAIN'

        return [
            lf_drug_allergy,
            lf_food_allergy,
            lf_environmental_allergy,
            lf_no_allergies,
            lf_allergy_severity,
            lf_allergy_reaction_type,
            lf_seasonal_allergy,
            lf_chemical_sensitivity,
            lf_multiple_allergies,
            lf_allergy_testing
        ]

    # Apply the allergy-specific labeling functions
    allergy_lfs = create_allergy_labeling_functions(df_allergies_entities)

    print("\n=== Testing Allergy Labeling Functions ===")
    if not df_medical.empty and 'allergies' in df_medical.columns:
        # Test on multiple rows to see coverage
        test_rows = min(10, len(df_medical))

        for i in range(test_rows):
            row = df_medical.iloc[i]
            if pd.notna(row['allergies']):
                print(f"\nRow {i} allergies: {row['allergies'][:100]}...")
                for lf in allergy_lfs:
                    result = lf(row)
                    if result != 'ABSTAIN':
                        print(f"  {lf.__name__}: {result}")

    # Analyze coverage
    print("\n=== Allergy Labeling Function Coverage ===")
    total_non_na = df_medical['allergies'].notna().sum()
    for lf in allergy_lfs:
        labeled_count = sum(lf(row) != 'ABSTAIN' for _, row in df_medical.iterrows()
                           if pd.notna(row.get('allergies', '')))
        coverage = (labeled_count / total_non_na * 100) if total_non_na > 0 else 0
        print(f"{lf.__name__}: {coverage:.1f}% coverage ({labeled_count}/{total_non_na})")

    # Analyze specific patterns in allergies
    print("\n=== Common Allergy Patterns ===")
    if 'allergies' in df_medical.columns:
        # Count specific allergen mentions
        penicillin_mentions = df_medical['allergies'].str.contains('penicillin|Penicillin', na=False).sum()
        no_allergy_mentions = df_medical['allergies'].str.contains('no known|NKA|NKDA', na=False, regex=True).sum()
        food_allergy_mentions = df_medical['allergies'].str.contains('peanut|shellfish|milk|egg', na=False, regex=True).sum()

        print(f"Penicillin allergy mentions: {penicillin_mentions}")
        print(f"No known allergies mentions: {no_allergy_mentions}")
        print(f"Food allergy mentions: {food_allergy_mentions}")
        print(f"Total allergy records: {df_medical['allergies'].notna().sum()}")

    # Create a summary of allergen types found
    print("\n=== Allergen Type Summary ===")
    allergen_summary = {
        'drugs': [],
        'foods': [],
        'environmental': [],
        'other': []
    }

    # Categorize top entities
    for entity, count in df_allergies_entities['text'].value_counts().head(50).items():
        entity_lower = entity.lower()
        if any(drug in entity_lower for drug in ['cillin', 'mycin', 'zole', 'statin']):
            allergen_summary['drugs'].append(entity)
        elif any(food in entity_lower for food in ['nut', 'milk', 'egg', 'fish']):
            allergen_summary['foods'].append(entity)
        elif any(env in entity_lower for env in ['pollen', 'dust', 'grass']):
            allergen_summary['environmental'].append(entity)
        else:
            allergen_summary['other'].append(entity)

    for category, items in allergen_summary.items():
        if items:
            print(f"\n{category.upper()}: {items[:10]}")  # Show top 10 in each category
# Save result
df_allergies_entities.to_csv('allergies_entities_comprehensive.csv', index=False)

Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 29755 texts in 100 batches...
Using model: en_ner_bc5cdr_md for column: allergies


Processing batches:   2%|▏         | 2/100 [00:01<00:44,  2.19it/s]


Checkpoint saved at batch 0


Processing batches:  12%|█▏        | 12/100 [00:02<00:15,  5.53it/s]


Checkpoint saved at batch 3000


Processing batches:  22%|██▏       | 22/100 [00:04<00:13,  5.71it/s]


Checkpoint saved at batch 6000


Processing batches:  32%|███▏      | 32/100 [00:06<00:12,  5.51it/s]


Checkpoint saved at batch 9000


Processing batches:  42%|████▏     | 42/100 [00:07<00:10,  5.65it/s]


Checkpoint saved at batch 12000


Processing batches:  52%|█████▏    | 52/100 [00:09<00:08,  5.75it/s]


Checkpoint saved at batch 15000


Processing batches:  62%|██████▏   | 62/100 [00:11<00:06,  5.54it/s]


Checkpoint saved at batch 18000


Processing batches:  72%|███████▏  | 72/100 [00:13<00:05,  5.37it/s]


Checkpoint saved at batch 21000


Processing batches:  82%|████████▏ | 82/100 [00:14<00:03,  5.59it/s]


Checkpoint saved at batch 24000


Processing batches:  92%|█████████▏| 92/100 [00:16<00:01,  5.49it/s]


Checkpoint saved at batch 27000


Processing batches: 100%|██████████| 100/100 [00:18<00:00,  5.54it/s]


Found 32 entities appearing >= 5 times

Testing generated labeling functions for allergies...
lf_disease_allergies applied to row 0: ABSTAIN
lf_chemical_allergies applied to row 0: ABSTAIN

=== DEBUG: Check extracted allergy entities ===
Shape of df_allergies_entities: (934, 7)

Column names: ['text', 'label', 'start', 'end', 'original_text', 'row_idx', 'source_column']

First few rows:
             text     label  start  end  \
0       allergies   DISEASE      9   18   
1  drug allergies   DISEASE      9   23   
2         allergy   DISEASE     18   25   
3         Allergy   DISEASE      0    7   
4     amoxicillin  CHEMICAL     11   22   

                           original_text  row_idx source_column  
0                     No known allergies    32488     allergies  
1                No known drug allergies    77061     allergies  
2  No sensitivity or allergy to any drug   149806     allergies  
3                 Allergy to amoxicillin    83662     allergies  
4                 All

In [99]:
df_allergies_entities

Unnamed: 0,text,label,start,end,original_text,row_idx,source_column
0,allergies,DISEASE,9,18,No known allergies,32488,allergies
1,drug allergies,DISEASE,9,23,No known drug allergies,77061,allergies
2,allergy,DISEASE,18,25,No sensitivity or allergy to any drug,149806,allergies
3,Allergy,DISEASE,0,7,Allergy to amoxicillin,83662,allergies
4,amoxicillin,CHEMICAL,11,22,Allergy to amoxicillin,83662,allergies
...,...,...,...,...,...,...,...
929,docetaxel,CHEMICAL,28,37,Severe allergic reaction to docetaxel chemothe...,40349,allergies
930,drug allergies,DISEASE,9,23,No known drug allergies,135761,allergies
931,Allergic,DISEASE,0,8,Allergic to penicillin,138116,allergies
932,penicillin,CHEMICAL,12,22,Allergic to penicillin,138116,allergies


##### Drug Usage

In [101]:
# Process drug usage using BC5CDR model (recognizes DISEASE and CHEMICAL)
if __name__ == "__main__":
    # Run extraction for the 'drug usage' column
    # BC5CDR model will identify drugs/substances (CHEMICAL entities)
    df_drug_usage_entities, drug_usage_summary, drug_usage_rules = run_medical_ner_extraction(
        df_medical,  
        text_column='drug usage',
        model_name="en_ner_bc5cdr_md",  # This model recognizes DISEASE and CHEMICAL entities
        batch_size=300, 
        id_column='idx'  )

    # Test the generated rules
    print("\nTesting generated labeling functions for drug usage...")

    # Select a sample row from the original df_medical to test the rules
    if not df_medical.empty:
        sample_row = df_medical.iloc[0]

        for rule_name, rule_func in drug_usage_rules.items():
            try:
                test_result = rule_func(sample_row)
                print(f"{rule_name} applied to row 0: {test_result}")
            except KeyError as e:
                print(f"Error applying rule {rule_name}: {e}. Make sure the column 'drug usage' exists in df_medical.")
    else:
        print("df_medical is empty, cannot test rules.")

    # Debug: Check extracted entities
    print("\n=== DEBUG: Check extracted drug usage entities ===")
    print(f"Shape of df_drug_usage_entities: {df_drug_usage_entities.shape}")
    print(f"\nColumn names: {df_drug_usage_entities.columns.tolist()}")
    print(f"\nFirst few rows:")
    print(df_drug_usage_entities.head())

    # Check entity label distribution
    print("\n=== Entity Labels Found in Drug Usage ===")
    if not df_drug_usage_entities.empty:
        print(df_drug_usage_entities['label'].value_counts())

        # Look at the top entities
        print("\n=== Top Drug Usage Entities ===")
        drug_entities = df_drug_usage_entities['text'].value_counts().head(30)
        print(drug_entities)

# Create specific labeling functions for drug usage data
def create_drug_usage_labeling_functions(entities_df):
        """Create specific labeling functions for drug usage"""

        # Analyze entities to understand patterns
        top_entities = entities_df['text'].value_counts().head(100)
        print(f"\nAnalyzing drug usage entities for patterns...")

        def lf_no_drug_use(row):
            """Detect absence of drug use"""
            text = str(row['drug usage']).lower()
            no_drug_patterns = [
                'no drug', 'denies', 'denied', 'no history', 'no illicit',
                'no substance', 'no recreational', 'never', 'none',
                'no personal history', 'negative'
            ]
            if any(pattern in text for pattern in no_drug_patterns):
                return 'NO_DRUG_USE'
            return 'ABSTAIN'

        def lf_alcohol_use(row):
            """Detect alcohol use patterns"""
            text = str(row['drug usage']).lower()
            alcohol_patterns = [
                'alcohol', 'drinking', 'beer', 'wine', 'liquor', 'spirits',
                'ethanol', 'etoh', 'drinks per', 'social drinking'
            ]
            if any(pattern in text for pattern in alcohol_patterns):
                return 'ALCOHOL_USE'
            return 'ABSTAIN'

        def lf_tobacco_use(row):
            """Detect tobacco/nicotine use"""
            text = str(row['drug usage']).lower()
            tobacco_patterns = [
                'tobacco', 'smoking', 'cigarette', 'nicotine', 'pack',
                'cigar', 'chewing tobacco', 'vaping', 'e-cigarette'
            ]
            if any(pattern in text for pattern in tobacco_patterns):
                return 'TOBACCO_USE'
            return 'ABSTAIN'

        def lf_cannabis_use(row):
            """Detect cannabis/marijuana use"""
            text = str(row['drug usage']).lower()
            cannabis_patterns = [
                'cannabis', 'marijuana', 'thc', 'weed', 'pot', 'hemp',
                'mary jane', 'ganja', 'hash', 'cannabinoid'
            ]
            if any(pattern in text for pattern in cannabis_patterns):
                return 'CANNABIS_USE'
            return 'ABSTAIN'

        def lf_opioid_use(row):
            """Detect opioid use"""
            text = str(row['drug usage']).lower()
            opioid_patterns = [
                'opioid', 'opiate', 'heroin', 'morphine', 'oxycodone',
                'hydrocodone', 'fentanyl', 'codeine', 'tramadol', 'methadone',
                'percocet', 'vicodin', 'oxycontin'
            ]
            if any(pattern in text for pattern in opioid_patterns):
                return 'OPIOID_USE'
            return 'ABSTAIN'

        def lf_stimulant_use(row):
            """Detect stimulant use"""
            text = str(row['drug usage']).lower()
            stimulant_patterns = [
                'cocaine', 'crack', 'amphetamine', 'methamphetamine', 'meth',
                'speed', 'crystal', 'adderall', 'ritalin', 'mdma', 'ecstasy'
            ]
            if any(pattern in text for pattern in stimulant_patterns):
                return 'STIMULANT_USE'
            return 'ABSTAIN'

        def lf_iv_drug_use(row):
            """Detect intravenous drug use"""
            text = str(row['drug usage']).lower()
            iv_patterns = [
                'iv drug', 'intravenous', 'injection', 'needle', 'inject',
                'ivdu', 'shooting up'
            ]
            if any(pattern in text for pattern in iv_patterns):
                return 'IV_DRUG_USE'
            return 'ABSTAIN'

        def lf_prescription_abuse(row):
            """Detect prescription drug abuse"""
            text = str(row['drug usage']).lower()
            if ('prescription' in text or 'prescribed' in text) and \
               any(abuse in text for abuse in ['abuse', 'misuse', 'dependency', 'addiction']):
                return 'PRESCRIPTION_ABUSE'
            return 'ABSTAIN'

        def lf_polysubstance_use(row):
            """Detect multiple substance use"""
            text = str(row['drug usage']).lower()
            # Count different drug mentions
            substances = ['alcohol', 'tobacco', 'cannabis', 'cocaine', 'heroin', 'meth']
            substance_count = sum(1 for sub in substances if sub in text)
            if substance_count >= 2:
                return 'POLYSUBSTANCE_USE'
            return 'ABSTAIN'

        def lf_past_drug_use(row):
            """Detect past/former drug use"""
            text = str(row['drug usage']).lower()
            past_patterns = [
                'former', 'past', 'history of', 'previously', 'quit',
                'stopped', 'used to', 'years ago', 'in recovery', 'sober'
            ]
            if any(pattern in text for pattern in past_patterns) and \
               not any(current in text for current in ['current', 'active', 'ongoing']):
                return 'PAST_DRUG_USE'
            return 'ABSTAIN'

        def lf_current_drug_use(row):
            """Detect current/active drug use"""
            text = str(row['drug usage']).lower()
            current_patterns = [
                'current', 'active', 'ongoing', 'continues', 'daily',
                'regular', 'frequent', 'occasional'
            ]
            if any(pattern in text for pattern in current_patterns):
                return 'CURRENT_DRUG_USE'
            return 'ABSTAIN'

        def lf_drug_treatment(row):
            """Detect drug treatment/rehabilitation"""
            text = str(row['drug usage']).lower()
            treatment_patterns = [
                'rehab', 'treatment', 'recovery', 'aa', 'na', 'methadone clinic',
                'suboxone', 'detox', 'counseling', 'therapy'
            ]
            if any(pattern in text for pattern in treatment_patterns):
                return 'DRUG_TREATMENT'
            return 'ABSTAIN'

        def lf_drug_screen_result(row):
            """Detect drug screening results"""
            text = str(row['drug usage']).lower()
            if any(screen in text for screen in ['drug screen', 'urine test', 'tested positive', 'tested negative']):
                return 'DRUG_SCREEN_MENTIONED'
            return 'ABSTAIN'

        return [
            lf_no_drug_use,
            lf_alcohol_use,
            lf_tobacco_use,
            lf_cannabis_use,
            lf_opioid_use,
            lf_stimulant_use,
            lf_iv_drug_use,
            lf_prescription_abuse,
            lf_polysubstance_use,
            lf_past_drug_use,
            lf_current_drug_use,
            lf_drug_treatment,
            lf_drug_screen_result
        ]
# Apply the drug usage-specific labeling functions
drug_usage_lfs = create_drug_usage_labeling_functions(df_drug_usage_entities)

print("\n=== Testing Drug Usage Labeling Functions ===")
if not df_medical.empty and 'drug usage' in df_medical.columns:
        # Test on multiple rows to see coverage
        test_rows = min(10, len(df_medical))

        for i in range(test_rows):
            row = df_medical.iloc[i]
            if pd.notna(row['drug usage']):
                print(f"\nRow {i} drug usage: {row['drug usage'][:100]}...")
                for lf in drug_usage_lfs:
                    result = lf(row)
                    if result != 'ABSTAIN':
                        print(f"  {lf.__name__}: {result}")

    # Analyze coverage
print("\n=== Drug Usage Labeling Function Coverage ===")
total_non_na = df_medical['drug usage'].notna().sum()
for lf in drug_usage_lfs:
    labeled_count = sum(lf(row) != 'ABSTAIN' for _, row in df_medical.iterrows()
                        if pd.notna(row.get('drug usage', '')))
    coverage = (labeled_count / total_non_na * 100) if total_non_na > 0 else 0
    print(f"{lf.__name__}: {coverage:.1f}% coverage ({labeled_count}/{total_non_na})")

# Analyze specific patterns in drug usage
print("\n=== Common Drug Usage Patterns ===")
if 'drug usage' in df_medical.columns:
    # Count specific substance mentions
    no_drug_mentions = df_medical['drug usage'].str.contains('no drug|denied|no history', na=False, regex=True).sum()
    alcohol_mentions = df_medical['drug usage'].str.contains('alcohol|drinking|beer|wine', na=False, regex=True).sum()
    tobacco_mentions = df_medical['drug usage'].str.contains('tobacco|smoking|cigarette', na=False, regex=True).sum()
    illicit_mentions = df_medical['drug usage'].str.contains('cocaine|heroin|meth|cannabis', na=False, regex=True).sum()

    print(f"No drug use mentions: {no_drug_mentions}")
    print(f"Alcohol mentions: {alcohol_mentions}")
    print(f"Tobacco mentions: {tobacco_mentions}")
    print(f"Illicit drug mentions: {illicit_mentions}")
    print(f"Total drug usage records: {df_medical['drug usage'].notna().sum()}")

# Create risk stratification based on drug usage
def stratify_drug_use_risk(row):
    """Stratify risk based on drug usage patterns"""
    if pd.isna(row.get('drug usage', '')):
        return 'UNKNOWN_RISK'

    text = str(row['drug usage']).lower()

    # High risk indicators
    high_risk = ['iv drug', 'heroin', 'cocaine', 'meth', 'overdose', 'daily use']
    if any(risk in text for risk in high_risk):
        return 'HIGH_RISK'

    # Moderate risk
    moderate_risk = ['alcohol', 'cannabis', 'prescription']
    if any(risk in text for risk in moderate_risk):
        return 'MODERATE_RISK'

    # Low risk
    if any(pattern in text for pattern in ['no drug', 'denied', 'none']):
        return 'LOW_RISK'

    return 'UNSPECIFIED_RISK'

# Apply risk stratification
print("\n=== Drug Use Risk Stratification ===")
if 'drug usage' in df_medical.columns:
    risk_levels = df_medical.apply(stratify_drug_use_risk, axis=1)
    print(risk_levels.value_counts())

# save result
df_drug_usage_entities.to_csv('drug_usage_entities_comprehensive.csv', index=False)


Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 29755 texts in 100 batches...
Using model: en_ner_bc5cdr_md for column: drug usage


Processing batches:   2%|▏         | 2/100 [00:00<00:39,  2.46it/s]


Checkpoint saved at batch 0


Processing batches:  12%|█▏        | 12/100 [00:02<00:17,  5.05it/s]


Checkpoint saved at batch 3000


Processing batches:  22%|██▏       | 22/100 [00:04<00:14,  5.30it/s]


Checkpoint saved at batch 6000


Processing batches:  32%|███▏      | 32/100 [00:06<00:12,  5.35it/s]


Checkpoint saved at batch 9000


Processing batches:  42%|████▏     | 42/100 [00:08<00:10,  5.39it/s]


Checkpoint saved at batch 12000


Processing batches:  52%|█████▏    | 52/100 [00:10<00:09,  5.32it/s]


Checkpoint saved at batch 15000


Processing batches:  62%|██████▏   | 62/100 [00:12<00:07,  5.12it/s]


Checkpoint saved at batch 18000


Processing batches:  72%|███████▏  | 72/100 [00:13<00:05,  5.03it/s]


Checkpoint saved at batch 21000


Processing batches:  82%|████████▏ | 82/100 [00:15<00:03,  5.15it/s]


Checkpoint saved at batch 24000


Processing batches:  92%|█████████▏| 92/100 [00:17<00:01,  5.17it/s]


Checkpoint saved at batch 27000


Processing batches: 100%|██████████| 100/100 [00:19<00:00,  5.23it/s]


Found 27 entities appearing >= 5 times

Testing generated labeling functions for drug usage...
lf_disease_drug_usage applied to row 0: ABSTAIN
lf_chemical_drug_usage applied to row 0: ABSTAIN

=== DEBUG: Check extracted drug usage entities ===
Shape of df_drug_usage_entities: (715, 7)

Column names: ['text', 'label', 'start', 'end', 'original_text', 'row_idx', 'source_column']

First few rows:
                 text     label  start  end                    original_text  \
0            Diazepam  CHEMICAL      0    8  Diazepam and methadone overdose   
1  methadone overdose  CHEMICAL     13   31  Diazepam and methadone overdose   
2            zolpidem  CHEMICAL      9   17                Abuse of zolpidem   
3          drug abuse   DISEASE     11   21            History of drug abuse   
4          drug abuse   DISEASE     12   22           Intravenous drug abuse   

   row_idx source_column  
0   163624    drug usage  
1   163624    drug usage  
2    90815    drug usage  
3    43921    

In [102]:
df_drug_usage_entities

Unnamed: 0,text,label,start,end,original_text,row_idx,source_column
0,Diazepam,CHEMICAL,0,8,Diazepam and methadone overdose,163624,drug usage
1,methadone overdose,CHEMICAL,13,31,Diazepam and methadone overdose,163624,drug usage
2,zolpidem,CHEMICAL,9,17,Abuse of zolpidem,90815,drug usage
3,drug abuse,DISEASE,11,21,History of drug abuse,43921,drug usage
4,drug abuse,DISEASE,12,22,Intravenous drug abuse,84350,drug usage
...,...,...,...,...,...,...,...
710,bipolar disorder,DISEASE,34,50,Remote history of lithium use for bipolar diso...,113022,drug usage
711,cocaine,CHEMICAL,24,31,"Occasional recreational cocaine use, most rece...",137819,drug usage
712,cannabis,CHEMICAL,93,101,Heavy e-cigarette use for the previous 2 years...,99885,drug usage
713,neck pain,DISEASE,61,70,"Used heroin, street bought oral opiates to sel...",97753,drug usage


### Extracting Surgeries

In [103]:
df_surgery

Unnamed: 0,idx,has_surgery,reason,Type,time,outcome,details,combined_text
0,155216,False,,,,,,nan nan nan nan
1,133948,True,Idiopathic osteonecrosis of the femoral head,Total Hip Arthroplasty (THA),After diagnosis,Discharged in good condition without specific ...,First THA on the left hip,Idiopathic osteonecrosis of the femoral head T...
2,133948,True,Pain and limited ROM in the contralateral hip ...,Total Hip Arthroplasty (THA),One year after the first THA,Discharged in good condition without specific ...,Second THA on the contralateral hip,Pain and limited ROM in the contralateral hip ...
3,80176,True,Posttraumatic arthritis,Left elbow arthrodesis,At the age of 18,,Elbow was fused at 90 degrees,Posttraumatic arthritis Left elbow arthrodesis...
4,80176,True,Hypertrophic nonunion of ulnar shaft fracture ...,Repair of nonunion and conversion of elbow art...,Three months after the fall and subsequent con...,,The stem of the ulnar component would act as a...,Hypertrophic nonunion of ulnar shaft fracture ...
...,...,...,...,...,...,...,...,...
35859,98004,True,Inferior segment elevation (ST) elevation myoc...,Primary percutaneous coronary intervention (dr...,,Successful treatment of right coronary artery ...,Procedure complicated by Ventricular Fibrillat...,Inferior segment elevation (ST) elevation myoc...
35860,133320,True,Leiomyosarcoma,Wide tumor resection,,Successful with no adjuvant chemotherapy and r...,,Leiomyosarcoma Wide tumor resection nan Succes...
35861,133320,True,Lung nodules,Excisional biopsy,One year and 3 months postoperatively,Histopathological diagnosis was consistent wit...,,Lung nodules Excisional biopsy nan Histopathol...
35862,133320,True,Bone metastasis of the right femur,Cryoablation under CT guidance,,,Ablation needles were inserted into the proxim...,Bone metastasis of the right femur Cryoablatio...


In [104]:
class TemporalStandardizer:
    """Extract and standardize temporal expressions from medical text"""

    def __init__(self):
        # Load spaCy model
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except:
            self.nlp = spacy.load("en_core_web_sm")

        # Conversion mappings - Define this BEFORE adding patterns
        self.time_units = {
            'day': 1, 'days': 1, 'd': 1,
            'week': 7, 'weeks': 7, 'wk': 7, 'wks': 7,
            'month': 30, 'months': 30, 'mo': 30, 'mos': 30,
            'year': 365, 'years': 365, 'yr': 365, 'yrs': 365,
            'hour': 1/24, 'hours': 1/24, 'hr': 1/24, 'hrs': 1/24,
            'minute': 1/1440, 'minutes': 1/1440, 'min': 1/1440, 'mins': 1/1440
        }

        # Initialize matcher for temporal patterns
        self.matcher = Matcher(self.nlp.vocab)
        # Now call _add_temporal_patterns AFTER time_units is defined
        self._add_temporal_patterns()


    def _add_temporal_patterns(self):
        """Add temporal patterns to spaCy matcher"""

        # Pattern: "X days/weeks/months"
        pattern1 = [
            {"LIKE_NUM": True},
            # Access self.time_units here is now safe
            {"LOWER": {"IN": list(self.time_units.keys())}}
        ]
        self.matcher.add("DURATION", [pattern1])

        # Pattern: "past X days/weeks"
        pattern2 = [
            {"LOWER": {"IN": ["past", "last", "previous"]}},
            {"LIKE_NUM": True},
            {"LOWER": {"IN": list(self.time_units.keys())}}
        ]
        self.matcher.add("PAST_DURATION", [pattern2])

        # Pattern: "X days/weeks ago"
        pattern3 = [
            {"LIKE_NUM": True},
            {"LOWER": {"IN": list(self.time_units.keys())}},
            {"LOWER": "ago"}
        ]
        self.matcher.add("AGO_DURATION", [pattern3])

    def extract_all_temporal_info(self, text):
        """Extract all temporal information from text"""
        if pd.isna(text) or text == '':
            return {
                'duration_days': None,
                'is_ongoing': False,
                'has_date': False,
                'temporal_type': None,
                'original_text': text
            }

        text = str(text)
        info = {
            'duration_days': None,
            'is_ongoing': False,
            'has_date': False,
            'temporal_type': None,
            'original_text': text
        }

        # Check for ongoing conditions
        ongoing_patterns = [
            'persisting', 'continuing', 'ongoing', 'current',
            'still', 'continues', 'persistent', 'chronic'
        ]
        info['is_ongoing'] = any(pattern in text.lower() for pattern in ongoing_patterns)

        # Try to extract specific dates
        dates = self._extract_dates(text)
        if dates:
            info['has_date'] = True
            info['extracted_dates'] = dates

        # Extract duration
        duration = self._extract_duration(text)
        if duration:
            info['duration_days'] = duration

        # Classify temporal type
        info['temporal_type'] = self._classify_temporal_type(text)

        return info

    def _extract_dates(self, text):
        """Extract actual dates from text"""
        dates = []

        # Common date patterns
        date_patterns = [
            r'\b\d{1,2}[-/]\d{1,2}[-/]\d{2,4}\b',  # MM/DD/YYYY or MM-DD-YYYY
            r'\b\d{4}[-/]\d{1,2}[-/]\d{1,2}\b',    # YYYY-MM-DD
            r'\b(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{1,2},?\s+\d{4}\b',  # Month DD, YYYY
            r'\b\d{1,2}\s+(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)[a-z]*\s+\d{4}\b',     # DD Month YYYY
        ]

        for pattern in date_patterns:
            matches = re.finditer(pattern, text, re.IGNORECASE)
            for match in matches:
                try:
                    parsed_date = parser.parse(match.group(), fuzzy=False)
                    dates.append(parsed_date)
                except:
                    continue

        return dates

    def _extract_duration(self, text):
        """Extract duration in days from temporal expressions"""
        text_lower = text.lower()

        # Use regex to find duration patterns
        patterns = [
            # "X days/weeks/months"
            r'(\d+)\s*(day|days|week|weeks|month|months|year|years|hour|hours)',
            # "a few days/weeks"
            r'(a few|several|couple of)\s*(day|days|week|weeks|month|months)',
            # Written numbers
            r'(one|two|three|four|five|six|seven|eight|nine|ten)\s*(day|days|week|weeks|month|months|year|years)',
        ]

        for pattern in patterns:
            match = re.search(pattern, text_lower)
            if match:
                # Extract number
                number_text = match.group(1)
                unit = match.group(2)

                # Convert to number
                if number_text.isdigit():
                    number = int(number_text)
                elif number_text in ['a few', 'several']:
                    number = 3  # Approximate
                elif number_text == 'couple of':
                    number = 2
                else:
                    # Convert written numbers
                    number_map = {
                        'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5,
                        'six': 6, 'seven': 7, 'eight': 8, 'nine': 9, 'ten': 10
                    }
                    number = number_map.get(number_text, 1)

                # Convert to days
                if unit in self.time_units:
                    return number * self.time_units[unit]

        return None

    def _classify_temporal_type(self, text):
        """Classify the type of temporal expression"""
        text_lower = text.lower()

        if any(word in text_lower for word in ['ago', 'before', 'prior', 'previously']):
            return 'past_reference'
        elif any(word in text_lower for word in ['since', 'from', 'started']):
            return 'onset_reference'
        elif any(word in text_lower for word in ['for', 'duration', 'lasted']):
            return 'duration_reference'
        elif any(word in text_lower for word in ['until', 'through', 'to']):
            return 'range_reference'
        elif any(word in text_lower for word in ['after', 'following', 'post']):
            return 'post_event'
        elif re.search(r'\d{4}', text):  # Contains year
            return 'absolute_date'
        else:
            return 'unspecified'

    def standardize_temporal_column(self, df, column_name):
        """Standardize an entire temporal column"""

        print(f"\nProcessing temporal column: {column_name}")

        # Apply extraction to all values
        temporal_data = df[column_name].apply(self.extract_all_temporal_info)

        # Convert to DataFrame
        temporal_df = pd.DataFrame(temporal_data.tolist())

        # Add prefix to column names
        temporal_df.columns = [f"{column_name}_{col}" for col in temporal_df.columns]

        # Concatenate with original dataframe
        result_df = pd.concat([df, temporal_df], axis=1)

        # Generate report
        report = self._generate_temporal_report(temporal_df, column_name)

        return result_df, report

    def _generate_temporal_report(self, temporal_df, column_name):
        """Generate report on temporal extraction"""

        duration_col = f"{column_name}_duration_days"
        type_col = f"{column_name}_temporal_type"

        report = {
            'column': column_name,
            'total_entries': len(temporal_df),
            'extracted_durations': temporal_df[duration_col].notna().sum(),
            'ongoing_conditions': temporal_df[f"{column_name}_is_ongoing"].sum(),
            'has_specific_dates': temporal_df[f"{column_name}_has_date"].sum(),
            'temporal_types': temporal_df[type_col].value_counts().to_dict()
        }

        # Duration statistics
        if temporal_df[duration_col].notna().any():
            report['duration_stats'] = {
                'min_days': temporal_df[duration_col].min(),
                'max_days': temporal_df[duration_col].max(),
                'mean_days': temporal_df[duration_col].mean(),
                'median_days': temporal_df[duration_col].median()
            }

        return report

In [141]:
if __name__ == "__main__":
    # Apply temporal standardization
    temporal_standardizer = TemporalStandardizer()
    df_surgery_processed, temporal_report = temporal_standardizer.standardize_temporal_column(
        df_surgery, 
        'time'
    )
   
    print(f"\nTemporal extraction results:")
    print(f"Successfully extracted duration: {temporal_report['extracted_durations']}")
    print(f"Temporal types: {temporal_report['temporal_types']}")

    # Combine surgery text columns
    # (use str() to be tolerant to NaNs)
    df_surgery['combined_text'] = df_surgery.apply(
        lambda row: f"{str(row.get('reason',''))} {str(row.get('Type',''))} {str(row.get('details',''))} {str(row.get('outcome',''))}",
        axis=1
    )

    # Copy combined_text to processed dataframe (keeps alignment with df_surgery by index)
    df_surgery_processed['combined_text'] = df_surgery['combined_text']

    # Add temporal features to combined text
    df_surgery_processed['combined_text_enriched'] = df_surgery_processed.apply(
        lambda row: f"{row['combined_text']} {'lasting ' + str(row.get('time_duration_days', '')) + ' days' if pd.notna(row.get('time_duration_days')) else ''}",
        axis=1
    )

    # NER over the (non-enriched) combined_text in df_surgery
    df_surgery_entities, surgery_summary, surgery_rules = run_medical_ner_extraction(
        df_surgery,
        text_column='combined_text',
        model_name='en_ner_bionlp13cg_md', 
        batch_size=300,
        id_column='idx'
    )

    print(f"\n=== Entity Types Found ===")
    print(surgery_summary['entity_types'])

    # ========== CUSTOM ANATOMY & PROCEDURE EXTRACTION (with correct row_idx) ==========
    print("\n=== CUSTOM ANATOMY & PROCEDURE EXTRACTION ===")

    def extract_surgical_entities_custom(df_surgery):
        """
        Custom extraction for anatomy and procedures not caught by the model.
        Uses 'idx' column for row identity when available; falls back to df.index otherwise.
        Adds 'row_idx' to every extracted entity.
        Uses regex word-boundary matching; allows simple plurals (e.g., hips, arteries).
        """

        custom_entities = []

        # Anatomy vocab by category
        anatomy_patterns = {
            'hip':        ['hip', 'acetabul', 'femoral head'],
            'knee':       ['knee', 'meniscus', 'patella', 'tibia'],
            'bone':       ['bone', 'femur', 'humerus', 'radius', 'ulna'],
            'joint':      ['joint', 'articulation'],
            'spine':      ['spine', 'vertebra', 'disc', 'lumbar', 'cervical'],
            'organ':      ['kidney', 'liver', 'heart', 'lung', 'spleen'],
            'vessel':     ['artery', 'vein', 'vessel', 'vascular'],
            'nerve':      ['nerve', 'neural', 'plexus'],
            'muscle':     ['muscle', 'tendon', 'ligament']
        }

        # Procedure vocab by category
        procedure_patterns = {
            'replacement':   ['arthroplasty', 'replacement', 'implant'],
            'repair':        ['repair', 'reconstruction', 'fixation'],
            'removal':       ['removal', 'excision', 'resection', 'ectomy'],
            'fusion':        ['fusion', 'arthrodesis', 'osteotomy'],
            'diagnostic':    ['biopsy', 'exploration', 'diagnostic'],
            'decompression': ['decompression', 'release', 'neurolysis'],
            'transplant':    ['transplant', 'graft', 'implantation']
        }

        laterality_terms = ['left', 'right', 'bilateral', 'unilateral']

        for df_idx, row in df_surgery.iterrows():
            # ---- choose the correct row id from 'idx' with fallback to the DataFrame index
            if 'idx' in row and pd.notna(row['idx']):
                row_id = row['idx']
            else:
                row_id = df_idx

            # ---- text (keep original for offsets; search on lowercased copy)
            original_text = str(row.get('combined_text', ''))
            text = original_text.lower()

            # ---- Anatomy extraction (regex with word boundaries; allow simple plural 's')
            for category, terms in anatomy_patterns.items():
                for term in terms:
                    # allow plural 's' for basic nouns (skip stems like 'acetabul' where 's' doesn't apply cleanly)
                    plural_opt = 's?' if term.isalpha() and not term.endswith(('al','el','ul')) else ''
                    pattern = r'\b' + re.escape(term.lower()) + plural_opt + r'\b'
                    for m in re.finditer(pattern, text):
                        custom_entities.append({
                            'text': original_text[m.start():m.end()],
                            'label': 'ANATOMY',
                            'category': category,
                            'start': m.start(),
                            'end': m.end(),
                            'original_text': original_text,
                            'source': 'custom_extraction',
                            'row_idx': row_id,   # <-- correct ID retained
                        })

            # ---- Procedure extraction (also boundary-based)
            for category, terms in procedure_patterns.items():
                for term in terms:
                    plural_opt = 's?' if term.isalpha() else ''
                    pattern = r'\b' + re.escape(term.lower()) + plural_opt + r'\b'
                    for m in re.finditer(pattern, text):
                        custom_entities.append({
                            'text': original_text[m.start():m.end()],
                            'label': 'PROCEDURE',
                            'category': category,
                            'start': m.start(),
                            'end': m.end(),
                            'original_text': original_text,
                            'source': 'custom_extraction',
                            'row_idx': row_id,   # <-- correct ID retained
                        })

            # ---- Laterality extraction
            for lat in laterality_terms:
                pattern = r'\b' + re.escape(lat) + r'\b'
                for m in re.finditer(pattern, text):
                    custom_entities.append({
                        'text': original_text[m.start():m.end()],
                        'label': 'LATERALITY',
                        'category': 'laterality',
                        'start': m.start(),
                        'end': m.end(),
                        'original_text': original_text,
                        'source': 'custom_extraction',
                        'row_idx': row_id,   # <-- correct ID retained
                    })

        return pd.DataFrame(custom_entities)

    # Extract custom entities
    df_custom_entities = extract_surgical_entities_custom(df_surgery)

    print(f"\nCustom extraction found:")
    if not df_custom_entities.empty and 'label' in df_custom_entities.columns:
        print(df_custom_entities['label'].value_counts())
    else:
        print("No custom entities found.")

    # Combine model entities and custom entities
    # (keep row_idx/category when present for traceability)
    df_all_surgery_entities = pd.concat(
        [df_surgery_entities, df_custom_entities],
        ignore_index=True,
        sort=False
    )

    print(f"\n=== COMBINED Entity Distribution ===")
    if not df_all_surgery_entities.empty and 'label' in df_all_surgery_entities.columns:
        print(df_all_surgery_entities['label'].value_counts())
    else:
        print("No combined entities to show.")

    # Analyze anatomy entities
    anatomy_entities = df_all_surgery_entities[df_all_surgery_entities['label'] == 'ANATOMY'] if not df_all_surgery_entities.empty else pd.DataFrame()
    if not anatomy_entities.empty:
        print(f"\n=== Top Anatomical Entities ===")
        print(anatomy_entities['text'].str.lower().value_counts().head(20))

        if 'category' in anatomy_entities.columns:
            print(f"\n=== Anatomy by Category ===")
            print(anatomy_entities['category'].value_counts())

    # Analyze procedure entities
    procedure_entities = df_all_surgery_entities[df_all_surgery_entities['label'] == 'PROCEDURE'] if not df_all_surgery_entities.empty else pd.DataFrame()
    if not procedure_entities.empty:
        print(f"\n=== Top Surgical Procedures ===")
        print(procedure_entities['text'].str.lower().value_counts().head(20))

        if 'category' in procedure_entities.columns:
            print(f"\n=== Procedures by Category ===")
            print(procedure_entities['category'].value_counts())

    # ========== Labeling functions (unchanged logic) ==========
    def create_surgical_labeling_functions():
        """Create labeling functions specific to surgical data"""

        def lf_hip_surgery(row):
            text = str(row.get('combined_text','')).lower()
            if 'hip' in text and any(proc in text for proc in ['arthroplasty', 'replacement', 'repair']):
                return 'HIP_SURGERY'
            return 'ABSTAIN'

        def lf_knee_surgery(row):
            text = str(row.get('combined_text','')).lower()
            if 'knee' in text and any(proc in text for proc in ['arthroplasty', 'replacement', 'arthroscopy']):
                return 'KNEE_SURGERY'
            return 'ABSTAIN'

        def lf_fracture_surgery(row):
            text = str(row.get('combined_text','')).lower()
            if 'fracture' in text and any(proc in text for proc in ['fixation', 'repair', 'reduction']):
                return 'FRACTURE_SURGERY'
            return 'ABSTAIN'

        def lf_bilateral_procedure(row):
            text = str(row.get('combined_text','')).lower()
            if 'bilateral' in text:
                return 'BILATERAL_PROCEDURE'
            return 'ABSTAIN'

        def lf_minimally_invasive(row):
            text = str(row.get('combined_text','')).lower()
            if any(term in text for term in ['arthroscopic', 'endoscopic', 'laparoscopic', 'minimally invasive']):
                return 'MINIMALLY_INVASIVE'
            return 'ABSTAIN'

        def lf_emergency_procedure(row):
            reason = str(row.get('reason', '')).lower()
            if any(term in reason for term in ['emergency', 'urgent', 'acute', 'trauma']):
                return 'EMERGENCY_SURGERY'
            return 'ABSTAIN'

        return [lf_hip_surgery, lf_knee_surgery, lf_fracture_surgery,
                lf_bilateral_procedure, lf_minimally_invasive, lf_emergency_procedure]

    # Test labeling functions
    surgical_lfs = create_surgical_labeling_functions()

    print("\nTesting surgical labeling functions:")
    for i in range(min(5, len(df_surgery))):
        row = df_surgery.iloc[i]
        print(f"\nRow {i}: {str(row.get('Type',''))[:50]}...")
        for lf in surgical_lfs:
            result = lf(row)
            if result != 'ABSTAIN':
                print(f"  {lf.__name__}: {result}")

    # ========== Extract relations (now per row_idx to avoid cross-row mixing) ==========
    print("\n=== Extracting Surgical Relations ===")

    def extract_surgical_relations(df_entities):
        """
        Extract relations between anatomy and procedures within the same row (row_idx).
        Uses character proximity (< 50 chars) heuristic.
        """
        relations = []
        if df_entities.empty:
            return relations

        # Ensure row grouping key exists
        group_key = 'row_idx' if 'row_idx' in df_entities.columns else 'original_text'

        for key, group in df_entities.groupby(group_key):
            if 'label' not in group.columns or 'start' not in group.columns:
                continue

            anatomy_ents = group[group['label'] == 'ANATOMY']
            procedure_ents = group[group['label'] == 'PROCEDURE']
            if anatomy_ents.empty or procedure_ents.empty:
                continue

            # Pick a representative text (for preview only)
            sample_text = group['original_text'].iloc[0] if 'original_text' in group.columns else ''

            for _, anat in anatomy_ents.iterrows():
                for _, proc in procedure_ents.iterrows():
                    try:
                        if abs(int(anat['start']) - int(proc['start'])) < 50:
                            relations.append({
                                'type': 'PROCEDURE_ON_ANATOMY',
                                'procedure': str(proc['text']),
                                'anatomy': str(anat['text']),
                                'row_idx': key if group_key == 'row_idx' else None,
                                'text': sample_text[:120]
                            })
                    except Exception:
                        # skip malformed rows without numeric offsets
                        pass

        return relations

    surgical_relations = extract_surgical_relations(df_all_surgery_entities)

    print(f"Found {len(surgical_relations)} surgical relations")
    if surgical_relations:
        print("\nSample relations:")
        for rel in surgical_relations[:10]:
            rid = f" (row {rel['row_idx']})" if rel.get('row_idx') is not None else ''
            print(f"  {rel['procedure']} -> {rel['anatomy']}{rid}")

    # Save results
    df_all_surgery_entities.to_csv('surgery_entities_comprehensive.csv', index=False)
    print("\nSaved: surgery_entities_comprehensive.csv")


Processing temporal column: time

Temporal extraction results:
Successfully extracted duration: 6388
Temporal types: {'unspecified': 3554, 'absolute_date': 2836, 'post_event': 2806, 'past_reference': 2754, 'range_reference': 654, 'duration_reference': 261, 'onset_reference': 191}
Loading en_ner_bionlp13cg_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 35864 texts in 120 batches...
Using model: en_ner_bionlp13cg_md for column: combined_text


Processing batches:   1%|          | 1/120 [00:03<06:41,  3.37s/it]


Checkpoint saved at batch 0


Processing batches:   9%|▉         | 11/120 [00:22<03:26,  1.90s/it]


Checkpoint saved at batch 3000


Processing batches:  18%|█▊        | 21/120 [00:40<03:04,  1.86s/it]


Checkpoint saved at batch 6000


Processing batches:  26%|██▌       | 31/120 [00:58<02:50,  1.91s/it]


Checkpoint saved at batch 9000


Processing batches:  34%|███▍      | 41/120 [01:16<02:31,  1.92s/it]


Checkpoint saved at batch 12000


Processing batches:  42%|████▎     | 51/120 [01:34<02:20,  2.03s/it]


Checkpoint saved at batch 15000


Processing batches:  51%|█████     | 61/120 [01:55<02:13,  2.26s/it]


Checkpoint saved at batch 18000


Processing batches:  58%|█████▊    | 70/120 [02:12<01:36,  1.93s/it]


Checkpoint saved at batch 21000


Processing batches:  68%|██████▊   | 81/120 [02:36<01:23,  2.14s/it]


Checkpoint saved at batch 24000


Processing batches:  76%|███████▌  | 91/120 [02:54<00:58,  2.01s/it]


Checkpoint saved at batch 27000


Processing batches:  84%|████████▍ | 101/120 [03:13<00:40,  2.14s/it]


Checkpoint saved at batch 30000


Processing batches:  92%|█████████▎| 111/120 [03:32<00:18,  2.10s/it]


Checkpoint saved at batch 33000


Processing batches: 100%|██████████| 120/120 [03:49<00:00,  1.91s/it]


Found 3452 entities appearing >= 5 times

=== Entity Types Found ===
{'MULTI_TISSUE_STRUCTURE': 32237, 'CANCER': 17003, 'TISSUE': 16665, 'ORGAN': 14854, 'PATHOLOGICAL_FORMATION': 14321, 'ORGANISM_SUBDIVISION': 4084, 'GENE_OR_GENE_PRODUCT': 3097, 'CELL': 2784, 'SIMPLE_CHEMICAL': 2742, 'ORGANISM': 2378, 'IMMATERIAL_ANATOMICAL_ENTITY': 1800, 'CELLULAR_COMPONENT': 1777, 'ORGANISM_SUBSTANCE': 1306, 'AMINO_ACID': 111, 'ANATOMICAL_SYSTEM': 48, 'DEVELOPING_ANATOMICAL_STRUCTURE': 1}

=== CUSTOM ANATOMY & PROCEDURE EXTRACTION ===

Custom extraction found:
label
PROCEDURE     24770
ANATOMY       20351
LATERALITY    19618
Name: count, dtype: int64

=== COMBINED Entity Distribution ===
label
MULTI_TISSUE_STRUCTURE             32237
PROCEDURE                          24770
ANATOMY                            20351
LATERALITY                         19618
CANCER                             17003
TISSUE                             16665
ORGAN                              14854
PATHOLOGICAL_FORMATION   

In [142]:
df_all_surgery_entities

Unnamed: 0,text,label,start,end,original_text,row_idx,source_column,category,source
0,femoral head,PATHOLOGICAL_FORMATION,32,44,Idiopathic osteonecrosis of the femoral head T...,133948,combined_text,,
1,left hip,PATHOLOGICAL_FORMATION,91,99,Idiopathic osteonecrosis of the femoral head T...,133948,combined_text,,
2,joint Total,MULTI_TISSUE_STRUCTURE,46,57,Pain and limited ROM in the contralateral hip ...,133948,combined_text,,
3,Left elbow,TISSUE,24,34,Posttraumatic arthritis Left elbow arthrodesis...,80176,combined_text,,
4,ulnar shaft,MULTI_TISSUE_STRUCTURE,25,36,Hypertrophic nonunion of ulnar shaft fracture ...,80176,combined_text,,
...,...,...,...,...,...,...,...,...,...
179942,biopsy,PROCEDURE,24,30,Lung nodules Excisional biopsy nan Histopathol...,133320,,diagnostic,custom_extraction
179943,Bone,ANATOMY,0,4,Bone metastasis of the right femur Cryoablatio...,133320,,bone,custom_extraction
179944,femur,ANATOMY,29,34,Bone metastasis of the right femur Cryoablatio...,133320,,bone,custom_extraction
179945,right,LATERALITY,23,28,Bone metastasis of the right femur Cryoablatio...,133320,,laterality,custom_extraction


### Extracting Symptoms

In [107]:
df_symptoms

Unnamed: 0,idx,has_symptom,name of symptom,intensity of symptom,location,time,temporalisation,behaviours affecting the symptom,details,duration_days,temporal_pattern,combined_text
0,155216,True,"Discomfort in the neck and lower back, restric...",,Neck and lower back,Past four months,,Standing up from a sitting position,Head turned to the right and upwards due to su...,,,"Discomfort in the neck and lower back, restric..."
1,133948,True,Pain,Severe,Left hip joint,Persisting for two months,Increased over the following three weeks,Aggravated by hip joint flexion or rotation,Also complained of pain and limited ROM in the...,,Increased over the following three weeks,Pain with Severe intensity located in Left hip...
2,133948,True,Restricted range of motion,,Left hip joint,Persisting for two months,,,,,,Restricted range of motion with nan intensity ...
3,133948,True,Gait disturbance,Severe,,,,Secondary to hip pain,Continued for two months and increased over th...,,,Gait disturbance with Severe intensity located...
4,133948,True,Moderate moon face,Moderate,Face,At the time of the second surgery,,,Initially overlooked as weight gain,,,Moderate moon face with Moderate intensity loc...
...,...,...,...,...,...,...,...,...,...,...,...,...
54939,137017,True,Left-sided weakness,,Left side,,,,,,,Left-sided weakness with nan intensity located...
54940,98004,True,Chest pain,,Chest,,,,Cardiac sounding,,,Chest pain with nan intensity located in Chest...
54941,133320,True,Mass in right thigh,,Lateral side of the right thigh,Noticed four years prior to presentation,,,"Diameter of 4 cm, no adhesion with skin and no...",,,Mass in right thigh with nan intensity located...
54942,97973,True,Crushing substernal chest pressure,Acute onset,Substernal,,Following 1-week-long febrile illness,,Accompanied by dyspnea and profuse sweating,,Following 1-week-long febrile illness,Crushing substernal chest pressure with Acute ...


In [145]:
if __name__ == "__main__":
    # Apply temporal standardization
    temporal_standardizer = TemporalStandardizer()
    df_symptoms_processed, temporal_report = temporal_standardizer.standardize_temporal_column(
        df_symptoms, 
        'time'
    )
   
    print(f"\nTemporal extraction results:")
    print(f"Successfully extracted duration: {temporal_report['extracted_durations']}")
    print(f"Temporal types: {temporal_report['temporal_types']}")

    # Combine symptom information into comprehensive text (robust to NaNs)
    df_symptoms['combined_text'] = df_symptoms.apply(
        lambda row: f"{str(row.get('name of symptom',''))} with {str(row.get('intensity of symptom',''))} "
                    f"intensity located in {str(row.get('location',''))} lasting {str(row.get('time',''))} "
                    f"{str(row.get('temporalisation',''))} {str(row.get('details',''))}",
        axis=1
    )

    # Copy combined_text to processed dataframe (corrected to use df_symptoms)
    df_symptoms_processed['combined_text'] = df_symptoms['combined_text']

    # Add temporal features to combined text (uses time_duration_days if present)
    df_symptoms_processed['combined_text_enriched'] = df_symptoms_processed.apply(
        lambda row: f"{row['combined_text']} "
                    f"{'lasting ' + str(row.get('time_duration_days', '')) + ' days' if pd.notna(row.get('time_duration_days')) else ''}",
        axis=1
    )

    # Run NER extraction with BC5CDR over combined_text
    df_symptoms_entities, symptoms_summary, symptoms_rules = run_medical_ner_extraction(
        df_symptoms,
        text_column='combined_text',
        model_name="en_ner_bc5cdr_md",
        batch_size=300,
        id_column='idx'
    )

    print(f"\n=== BC5CDR Entity Types Found ===")
    print(symptoms_summary['entity_types'])

    # ================= CUSTOM SYMPTOM ENTITY EXTRACTION (keeps correct row_idx) =================
    print("\n=== CUSTOM SYMPTOM ENTITY EXTRACTION ===")

    def extract_symptom_entities_custom(df_symptoms):
        """
        Extract symptom-specific entities not caught by BC5CDR.
        Uses 'idx' for row identity when available; falls back to df.index otherwise.
        Adds 'row_idx' to every extracted entity.
        Regex with word boundaries is used to avoid substring false positives.
        """
        custom_entities = []

        # Symptom type patterns
        symptom_patterns = {
            'pain':         ['pain', 'ache', 'soreness', 'tenderness', 'discomfort'],
            'neurological': ['numbness', 'tingling', 'weakness', 'paralysis', 'tremor'],
            'mobility':     ['inability to walk', 'gait disturbance', 'limited range of motion', 'stiffness'],
            'swelling':     ['swelling', 'edema', 'inflammation', 'mass', 'lump'],
            'sensory':      ['vision', 'hearing', 'taste', 'smell', 'sensation'],
            'systemic':     ['fever', 'fatigue', 'malaise', 'weight loss', 'night sweats'],
            'respiratory':  ['dyspnea', 'cough', 'wheezing', 'shortness of breath'],
            'gi':           ['nausea', 'vomiting', 'diarrhea', 'constipation', 'abdominal'],
            'skin':         ['rash', 'itching', 'lesion', 'discoloration', 'bruising']
        }

        # Anatomical location patterns
        anatomy_patterns = {
            'head_neck':      ['head', 'neck', 'scalp', 'face', 'throat', 'cervical'],
            'upper_extremity': ['shoulder', 'arm', 'elbow', 'forearm', 'wrist', 'hand', 'finger'],
            'lower_extremity': ['hip', 'thigh', 'knee', 'leg', 'ankle', 'foot', 'toe'],
            'trunk':           ['chest', 'back', 'abdomen', 'pelvis', 'spine', 'lumbar'],
            'joint':           ['joint', 'articulation'],
            'internal':        ['heart', 'lung', 'liver', 'kidney', 'stomach']
        }

        # Severity/Intensity (column-driven)
        severity_values = {'mild', 'moderate', 'severe', 'extreme', 'minimal', 'significant'}

        # Temporal patterns
        temporal_patterns = {
            'acute':        ['acute', 'sudden', 'abrupt', 'rapid'],
            'chronic':      ['chronic', 'persistent', 'ongoing', 'continuous'],
            'intermittent': ['intermittent', 'episodic', 'recurrent', 'periodic'],
            'progressive':  ['progressive', 'worsening', 'increasing', 'deteriorating']
        }

        laterality_terms = ['left', 'right', 'bilateral', 'both']

        for df_index, row in df_symptoms.iterrows():
            # ---- choose the correct row id from 'idx' with fallback to the DataFrame index
            row_id = row['idx'] if ('idx' in row and pd.notna(row['idx'])) else df_index

            original_text = str(row.get('combined_text', ''))
            combined = original_text.lower()

            # Columns (original + lower)
            symptom_name_orig = str(row.get('name of symptom', ''))
            symptom_name = symptom_name_orig.lower()
            location_orig = str(row.get('location', ''))
            location = location_orig.lower()
            intensity_orig = str(row.get('intensity of symptom', ''))
            intensity = intensity_orig.lower()
            temporalisation_orig = str(row.get('temporalisation', ''))
            temporalisation = temporalisation_orig.lower()

            # ---- Primary SYMPTOM from name column (if present)
            if pd.notna(row.get('name of symptom')) and symptom_name.strip() != '' and symptom_name != 'nan':
                # try to find position in combined; if not found, set to 0
                start = combined.find(symptom_name)
                if start < 0:
                    start = 0
                end = start + len(symptom_name_orig)
                end = min(end, len(original_text))
                custom_entities.append({
                    'text': symptom_name_orig,
                    'label': 'SYMPTOM',
                    'category': 'primary_symptom',
                    'start': start,
                    'end': end,
                    'original_text': original_text,
                    'source': 'symptom_name_column',
                    'row_idx': row_id
                })

            # ---- Symptom types (regex boundary matches)
            for category, terms in symptom_patterns.items():
                for term in terms:
                    pattern = r'\b' + re.escape(term.lower()) + r'\b'
                    for m in re.finditer(pattern, combined):
                        custom_entities.append({
                            'text': original_text[m.start():m.end()],
                            'label': 'SYMPTOM_TYPE',
                            'category': category,
                            'start': m.start(),
                            'end': m.end(),
                            'original_text': original_text,
                            'source': 'pattern_matching',
                            'row_idx': row_id
                        })

            # ---- Anatomical location from location column
            if pd.notna(row.get('location')) and location.strip() != '' and location != 'nan':
                start = combined.find(location)
                if start >= 0:
                    end = start + len(location)
                    custom_entities.append({
                        'text': location_orig,
                        'label': 'ANATOMY',
                        'category': 'symptom_location',
                        'start': start,
                        'end': end,
                        'original_text': original_text,
                        'source': 'location_column',
                        'row_idx': row_id
                    })

            # ---- Additional anatomy from patterns
            for category, terms in anatomy_patterns.items():
                for term in terms:
                    # allow simple plural 's' for nouns
                    plural_opt = 's?' if term.isalpha() else ''
                    pattern = r'\b' + re.escape(term.lower()) + plural_opt + r'\b'
                    for m in re.finditer(pattern, combined):
                        custom_entities.append({
                            'text': original_text[m.start():m.end()],
                            'label': 'ANATOMY',
                            'category': category,
                            'start': m.start(),
                            'end': m.end(),
                            'original_text': original_text,
                            'source': 'pattern_matching',
                            'row_idx': row_id
                        })

            # ---- Severity from intensity column (exact value match)
            if pd.notna(row.get('intensity of symptom')) and intensity in severity_values:
                start = combined.find(intensity)
                if start >= 0:
                    custom_entities.append({
                        'text': intensity_orig,
                        'label': 'SEVERITY',
                        'category': 'intensity',
                        'start': start,
                        'end': start + len(intensity_orig),
                        'original_text': original_text,
                        'source': 'intensity_column',
                        'row_idx': row_id
                    })

            # ---- Temporal patterns (look in temporalisation column and combined text)
            for category, terms in temporal_patterns.items():
                for term in terms:
                    found = False
                    # search in combined first (to get offsets)
                    m = re.search(r'\b' + re.escape(term) + r'\b', combined)
                    if m:
                        custom_entities.append({
                            'text': original_text[m.start():m.end()],
                            'label': 'TEMPORAL_PATTERN',
                            'category': category,
                            'start': m.start(),
                            'end': m.end(),
                            'original_text': original_text,
                            'source': 'temporal_extraction',
                            'row_idx': row_id
                        })
                        found = True
                    # if not found in combined but present in temporalisation text, add without exact offset
                    if not found and term in temporalisation:
                        custom_entities.append({
                            'text': term,
                            'label': 'TEMPORAL_PATTERN',
                            'category': category,
                            'start': 0,
                            'end': len(term),
                            'original_text': original_text,
                            'source': 'temporal_extraction',
                            'row_idx': row_id
                        })

            # ---- Laterality (location or combined)
            for lat in laterality_terms:
                pattern = r'\b' + re.escape(lat) + r'\b'
                for m in re.finditer(pattern, combined):
                    custom_entities.append({
                        'text': original_text[m.start():m.end()],
                        'label': 'LATERALITY',
                        'category': 'laterality',
                        'start': m.start(),
                        'end': m.end(),
                        'original_text': original_text,
                        'source': 'laterality_extraction',
                        'row_idx': row_id
                    })

        return pd.DataFrame(custom_entities)

    # Extract custom entities
    df_custom_symptom_entities = extract_symptom_entities_custom(df_symptoms)

    print(f"\nCustom extraction found:")
    if not df_custom_symptom_entities.empty and 'label' in df_custom_symptom_entities.columns:
        print(df_custom_symptom_entities['label'].value_counts())
    else:
        print("No custom symptom entities found.")

    # Combine all entities (preserve row_idx/category when present)
    df_all_symptom_entities = pd.concat(
        [df_symptoms_entities, df_custom_symptom_entities],
        ignore_index=True,
        sort=False
    )

    print(f"\n=== COMBINED Entity Distribution ===")
    if not df_all_symptom_entities.empty and 'label' in df_all_symptom_entities.columns:
        print(df_all_symptom_entities['label'].value_counts())
    else:
        print("No combined entities to show.")

    # Analyze symptom entities
    symptom_entities = df_all_symptom_entities[df_all_symptom_entities['label'] == 'SYMPTOM'] if not df_all_symptom_entities.empty else pd.DataFrame()
    if not symptom_entities.empty:
        print(f"\n=== Top Symptoms ===")
        print(symptom_entities['text'].str.lower().value_counts().head(20))

    # Analyze anatomical locations
    anatomy_entities = df_all_symptom_entities[df_all_symptom_entities['label'] == 'ANATOMY'] if not df_all_symptom_entities.empty else pd.DataFrame()
    if not anatomy_entities.empty:
        print(f"\n=== Top Symptom Locations ===")
        print(anatomy_entities['text'].str.lower().value_counts().head(20))

    # Analyze severity
    severity_entities = df_all_symptom_entities[df_all_symptom_entities['label'] == 'SEVERITY'] if not df_all_symptom_entities.empty else pd.DataFrame()
    if not severity_entities.empty:
        print(f"\n=== Severity Distribution ===")
        print(severity_entities['text'].str.lower().value_counts())

    # ================= Symptom-specific labeling functions =================
    print("\n=== Creating Symptom-Specific Labeling Functions ===")

    def create_symptom_labeling_functions():
        """Create labeling functions for symptom patterns"""

        def lf_severe_pain(row):
            symptom = str(row.get('name of symptom', '')).lower()
            intensity = str(row.get('intensity of symptom', '')).lower()
            if 'pain' in symptom and intensity == 'severe':
                return 'SEVERE_PAIN'
            return 'ABSTAIN'

        def lf_chronic_symptom(row):
            temporal = str(row.get('temporalisation', '')).lower()
            time_text = str(row.get('time', '')).lower()
            chronic_indicators = ['chronic', 'persistent', 'ongoing', 'months', 'years']
            if any(ind in (temporal + " " + time_text) for ind in chronic_indicators):
                return 'CHRONIC_SYMPTOM'
            return 'ABSTAIN'

        def lf_neurological_symptom(row):
            symptom = str(row.get('name of symptom', '')).lower()
            details = str(row.get('details', '')).lower()
            neuro_terms = ['numbness', 'tingling', 'weakness', 'paralysis', 'sensation']
            if any(term in (symptom + " " + details) for term in neuro_terms):
                return 'NEUROLOGICAL'
            return 'ABSTAIN'

        def lf_bilateral_symptom(row):
            location = str(row.get('location', '')).lower()
            if any(term in location for term in ['bilateral', 'both', 'left and right']):
                return 'BILATERAL_SYMPTOM'
            return 'ABSTAIN'

        def lf_acute_onset(row):
            temporal = str(row.get('temporalisation', '')).lower()
            details = str(row.get('details', '')).lower()
            acute_terms = ['sudden', 'acute', 'abrupt', 'rapid onset']
            if any(term in (temporal + " " + details) for term in acute_terms):
                return 'ACUTE_ONSET'
            return 'ABSTAIN'

        def lf_progressive_symptom(row):
            temporal = str(row.get('temporalisation', '')).lower()
            details = str(row.get('details', '')).lower()
            if any(term in (temporal + " " + details) for term in ['worsening', 'progressive', 'increasing']):
                return 'PROGRESSIVE_SYMPTOM'
            return 'ABSTAIN'

        def lf_mobility_issue(row):
            symptom = str(row.get('name of symptom', '')).lower()
            if any(term in symptom for term in ['walk', 'gait', 'mobility', 'movement']):
                return 'MOBILITY_ISSUE'
            return 'ABSTAIN'

        def lf_pain_with_location(row):
            symptom = str(row.get('name of symptom', '')).lower()
            location = str(row.get('location', '')).lower()
            if 'pain' in symptom and location not in ('', 'nan'):
                if 'hip' in location:
                    return 'HIP_PAIN'
                elif 'knee' in location:
                    return 'KNEE_PAIN'
                elif 'back' in location:
                    return 'BACK_PAIN'
                else:
                    return 'LOCALIZED_PAIN'
            return 'ABSTAIN'

        def lf_systemic_symptom(row):
            symptom = str(row.get('name of symptom', '')).lower()
            systemic_terms = ['fever', 'fatigue', 'weight loss', 'malaise']
            if any(term in symptom for term in systemic_terms):
                return 'SYSTEMIC_SYMPTOM'
            return 'ABSTAIN'

        def lf_symptom_duration(row):
            # Accept either 'time_duration_days' or 'duration_days'
            duration = row.get('time_duration_days')
            if pd.isna(duration):
                duration = row.get('duration_days')
            if pd.notna(duration):
                try:
                    d = float(duration)
                except Exception:
                    return 'ABSTAIN'
                if d <= 7:
                    return 'ACUTE_DURATION'
                elif d <= 30:
                    return 'SUBACUTE_DURATION'
                else:
                    return 'CHRONIC_DURATION'
            return 'ABSTAIN'

        return [
            lf_severe_pain, lf_chronic_symptom, lf_neurological_symptom,
            lf_bilateral_symptom, lf_acute_onset, lf_progressive_symptom,
            lf_mobility_issue, lf_pain_with_location, lf_systemic_symptom,
            lf_symptom_duration
        ]

    # Test labeling functions
    symptom_lfs = create_symptom_labeling_functions()

    print("\nTesting symptom labeling functions:")
    for i in range(min(10, len(df_symptoms))):
        row = df_symptoms.iloc[i]
        print(f"\nRow {i}: {str(row.get('name of symptom',''))} - {str(row.get('location',''))}")
        for lf in symptom_lfs:
            result = lf(row)
            if result != 'ABSTAIN':
                print(f"  {lf.__name__}: {result}")

    # ================= Extract symptom relations (use correct row_idx) =================
    print("\n=== Extracting Symptom Relations ===")

    def extract_symptom_relations(df_symptoms):
        """
        Build row-level relations from structured columns.
        Uses 'idx' for row identity when available; falls back to df.index.
        """
        relations = []
        for df_index, row in df_symptoms.iterrows():
            row_id = row['idx'] if ('idx' in row and pd.notna(row['idx'])) else df_index
            symptom = row.get('name of symptom')
            location = row.get('location')
            intensity = row.get('intensity of symptom')
            duration = row.get('time_duration_days') if pd.notna(row.get('time_duration_days', pd.NA)) else row.get('duration_days')
            temporal = row.get('temporalisation')

            # SYMPTOM_LOCATED_IN
            if pd.notna(symptom) and pd.notna(location) and str(location).lower() != 'nan':
                relations.append({
                    'type': 'SYMPTOM_LOCATED_IN',
                    'symptom': symptom,
                    'location': location,
                    'row_idx': row_id
                })

            # SYMPTOM_HAS_SEVERITY
            if pd.notna(symptom) and pd.notna(intensity) and str(intensity).lower() != 'nan':
                relations.append({
                    'type': 'SYMPTOM_HAS_SEVERITY',
                    'symptom': symptom,
                    'severity': intensity,
                    'row_idx': row_id
                })

            # SYMPTOM_HAS_DURATION
            if pd.notna(symptom) and pd.notna(duration):
                relations.append({
                    'type': 'SYMPTOM_HAS_DURATION',
                    'symptom': symptom,
                    'duration_days': duration,
                    'row_idx': row_id
                })

            # SYMPTOM_HAS_TEMPORAL_PATTERN
            if pd.notna(symptom) and pd.notna(temporal) and str(temporal).lower() != 'nan':
                relations.append({
                    'type': 'SYMPTOM_HAS_TEMPORAL_PATTERN',
                    'symptom': symptom,
                    'pattern': temporal,
                    'row_idx': row_id
                })

        return relations

    symptom_relations = extract_symptom_relations(df_symptoms)

    print(f"Found {len(symptom_relations)} symptom relations")
    if symptom_relations:
        relation_types = pd.DataFrame(symptom_relations)['type'].value_counts()
        print("\nRelation type distribution:")
        print(relation_types)

        print("\nSample relations:")
        for rel in symptom_relations[:10]:
            if rel['type'] == 'SYMPTOM_LOCATED_IN':
                print(f"  {rel['symptom']} -> located in -> {rel['location']} (row {rel['row_idx']})")
            elif rel['type'] == 'SYMPTOM_HAS_SEVERITY':
                print(f"  {rel['symptom']} -> has severity -> {rel['severity']} (row {rel['row_idx']})")

    # Save results
    df_all_symptom_entities.to_csv('symptom_entities_comprehensive.csv', index=False)
    pd.DataFrame(symptom_relations).to_csv('symptom_relations.csv', index=False)

    print("\n\nSymptom entity extraction and labeling complete!")
    print(f"Saved {len(df_all_symptom_entities)} entities and {len(symptom_relations)} relations.")


Processing temporal column: time

Temporal extraction results:
Successfully extracted duration: 18893
Temporal types: {'unspecified': 13201, 'post_event': 6146, 'range_reference': 5023, 'past_reference': 4704, 'duration_reference': 4063, 'onset_reference': 2904, 'absolute_date': 1164}
Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 54944 texts in 184 batches...
Using model: en_ner_bc5cdr_md for column: combined_text


Processing batches:   1%|          | 1/184 [00:01<05:38,  1.85s/it]


Checkpoint saved at batch 0


Processing batches:   6%|▌         | 11/184 [00:20<05:04,  1.76s/it]


Checkpoint saved at batch 3000


Processing batches:  11%|█▏        | 21/184 [00:36<04:31,  1.67s/it]


Checkpoint saved at batch 6000


Processing batches:  17%|█▋        | 31/184 [00:58<06:17,  2.47s/it]


Checkpoint saved at batch 9000


Processing batches:  22%|██▏       | 41/184 [01:18<05:03,  2.12s/it]


Checkpoint saved at batch 12000


Processing batches:  28%|██▊       | 51/184 [01:37<04:28,  2.02s/it]


Checkpoint saved at batch 15000


Processing batches:  33%|███▎      | 61/184 [01:58<04:20,  2.12s/it]


Checkpoint saved at batch 18000


Processing batches:  39%|███▊      | 71/184 [02:17<03:54,  2.08s/it]


Checkpoint saved at batch 21000


Processing batches:  44%|████▍     | 81/184 [02:36<03:28,  2.03s/it]


Checkpoint saved at batch 24000


Processing batches:  49%|████▉     | 91/184 [02:56<03:12,  2.06s/it]


Checkpoint saved at batch 27000


Processing batches:  55%|█████▍    | 101/184 [03:12<02:22,  1.72s/it]


Checkpoint saved at batch 30000


Processing batches:  60%|██████    | 111/184 [03:30<02:18,  1.89s/it]


Checkpoint saved at batch 33000


Processing batches:  66%|██████▌   | 121/184 [03:49<02:21,  2.25s/it]


Checkpoint saved at batch 36000


Processing batches:  71%|███████   | 131/184 [04:05<01:26,  1.63s/it]


Checkpoint saved at batch 39000


Processing batches:  77%|███████▋  | 141/184 [04:20<01:17,  1.80s/it]


Checkpoint saved at batch 42000


Processing batches:  82%|████████▏ | 151/184 [04:37<01:00,  1.82s/it]


Checkpoint saved at batch 45000


Processing batches:  88%|████████▊ | 161/184 [04:54<00:42,  1.86s/it]


Checkpoint saved at batch 48000


Processing batches:  93%|█████████▎| 171/184 [05:11<00:23,  1.80s/it]


Checkpoint saved at batch 51000


Processing batches:  98%|█████████▊| 181/184 [05:32<00:06,  2.19s/it]


Checkpoint saved at batch 54000


Processing batches: 100%|██████████| 184/184 [05:36<00:00,  1.83s/it]


Found 1857 entities appearing >= 5 times

=== BC5CDR Entity Types Found ===
{'DISEASE': 84179, 'CHEMICAL': 2286}

=== CUSTOM SYMPTOM ENTITY EXTRACTION ===

Custom extraction found:
label
ANATOMY             67817
SYMPTOM_TYPE        56789
SYMPTOM             53580
LATERALITY          31971
TEMPORAL_PATTERN    14118
SEVERITY             5306
Name: count, dtype: int64

=== COMBINED Entity Distribution ===
label
DISEASE             84179
ANATOMY             67817
SYMPTOM_TYPE        56789
SYMPTOM             53580
LATERALITY          31971
TEMPORAL_PATTERN    14118
SEVERITY             5306
CHEMICAL             2286
Name: count, dtype: int64

=== Top Symptoms ===
text
pain                    1796
abdominal pain          1182
swelling                1146
fever                    825
headache                 722
chest pain               576
vomiting                 575
dyspnea                  519
shortness of breath      501
weight loss              459
nausea                   399
dysphag

In [147]:
df_all_symptom_entities

Unnamed: 0,text,label,start,end,original_text,row_idx,source_column,category,source
0,Pain,DISEASE,0,4,Pain with Severe intensity located in Left hip...,133948,combined_text,,
1,pain,DISEASE,147,151,Pain with Severe intensity located in Left hip...,133948,combined_text,,
2,weight gain,DISEASE,129,140,Moderate moon face with Moderate intensity loc...,133948,combined_text,,
3,Central obesity,DISEASE,0,15,Central obesity with nan intensity located in ...,133948,combined_text,,
4,Muscle mass reduction,DISEASE,0,21,Muscle mass reduction with nan intensity locat...,133948,combined_text,,
...,...,...,...,...,...,...,...,...,...
316041,chest,ANATOMY,20,25,Crushing substernal chest pressure with Acute ...,97973,,trunk,pattern_matching
316042,Acute,TEMPORAL_PATTERN,40,45,Crushing substernal chest pressure with Acute ...,97973,,acute,temporal_extraction
316043,Dyspnea,SYMPTOM,0,7,Dyspnea with Rapidly developed intensity locat...,97973,,primary_symptom,symptom_name_column
316044,Dyspnea,SYMPTOM_TYPE,0,7,Dyspnea with Rapidly developed intensity locat...,97973,,respiratory,pattern_matching


### Extracting Diagnosis

In [111]:
df_diagnosis

Unnamed: 0,idx,has_diagnosis,test,severity,result,condition,time,details,combined_text
0,155216,False,,,,,,,nan performed with nan severity showed nan ind...
1,133948,True,Magnetic resonance imaging (MRI) scan,,Increased amount of joint fluid and bone marro...,Idiopathic osteonecrosis of the femoral head,,Patient did not complain of any pain on the co...,Magnetic resonance imaging (MRI) scan performe...
2,133948,True,Repeat MRI,,Similar findings to those noted previously in ...,,One year after the initial surgery and symptom...,,Repeat MRI performed with nan severity showed ...
3,80176,True,Radiographs,Minimally displaced,Proximal ulnar shaft fracture,"Proximal ulnar shaft fracture, hypertrophic no...",,Elbow arthrodesis at 90 degrees with retained ...,Radiographs performed with Minimally displaced...
4,72232,True,MRI,Moderate-sized,Focal area of marrow edema/contusion involving...,Bone marrow edema,"September 2016, three months later, April 2017...",Involvement of medial femoral condyle in mid a...,MRI performed with Moderate-sized severity sho...
...,...,...,...,...,...,...,...,...,...
61350,133320,True,Histopathological examination,,Consistent with lung metastasis of leiomyosarcoma,Lung metastasis of leiomyosarcoma,One year and 3 months postoperatively,,Histopathological examination performed with n...
61351,97973,True,Electrocardiogram (ECG),,Diffuse ST depressions in all precordial leads,Consistent with an acute coronary syndrome,,,Electrocardiogram (ECG) performed with nan sev...
61352,97973,True,Transthoracic echocardiogram,Ejection fraction (EF) of 45% with severe aort...,Torn right coronary cusp,Severe aortic insufficiency,,Emergent transthoracic echocardiogram performed,Transthoracic echocardiogram performed with Ej...
61353,97973,True,Blood cultures,,Positive for S.\nlugdunensis in both bottles,,,,Blood cultures performed with nan severity sho...


In [148]:
if __name__ == "__main__":
    # Apply temporal standardization
    temporal_standardizer = TemporalStandardizer()
    df_diagnosis_processed, temporal_report = temporal_standardizer.standardize_temporal_column(
        df_diagnosis, 
        'time'
    )
   
    print(f"\nTemporal extraction results:")
    print(f"Successfully extracted duration: {temporal_report['extracted_durations']}")
    print(f"Temporal types: {temporal_report['temporal_types']}")

    # Combine diagnosis text columns (robust to NaNs)
    df_diagnosis['combined_text'] = df_diagnosis.apply(
        lambda row: f"{str(row.get('test',''))} performed with {str(row.get('severity',''))} severity "
                    f"showed {str(row.get('result',''))} indicating {str(row.get('condition',''))} "
                    f"{str(row.get('details',''))} at {str(row.get('time',''))}",
        axis=1        
    )

    # Copy combined_text to processed dataframe
    df_diagnosis_processed['combined_text'] = df_diagnosis['combined_text']

    # Add temporal features to combined text
    df_diagnosis_processed['combined_text_enriched'] = df_diagnosis_processed.apply(
        lambda row: f"{row['combined_text']} "
                    f"{'lasting ' + str(row.get('time_duration_days', '')) + ' days' if pd.notna(row.get('time_duration_days')) else ''}",
        axis=1
    )

    # Run NER extraction with BC5CDR
    df_diagnosis_entities, diagnosis_summary, diagnosis_rules = run_medical_ner_extraction(
        df_diagnosis,
        text_column='combined_text',
        model_name="en_ner_bc5cdr_md",
        batch_size=300, 
        id_column='idx'
    )

    print(f"\n=== BC5CDR Entity Types Found ===")
    print(diagnosis_summary['entity_types'])

    # ================= CUSTOM DIAGNOSIS ENTITY EXTRACTION (keeps correct row_idx) =================
    print("\n=== CUSTOM DIAGNOSIS ENTITY EXTRACTION ===")

    def extract_diagnosis_entities_custom(df_diagnosis):
        """
        Extract diagnosis-specific entities not (fully) caught by BC5CDR.
        Uses 'idx' for row identity when available; falls back to df.index otherwise.
        Adds 'row_idx' and character offsets for all extracted entities.
        Regex with word boundaries is used to avoid substring false positives.
        """
        custom_entities = []

        # Test/Procedure patterns
        test_patterns = {
            'imaging': [
                'mri', 'magnetic resonance', 'ct', 'computed tomography',
                'x-ray', 'radiograph', 'ultrasound', 'sonography', 'scan',
                'pet', 'spect', 'angiography', 'mammography'
            ],
            'laboratory': [
                'blood test', 'serum', 'plasma', 'biochemical', 'hematology',
                'urinalysis', 'culture', 'biopsy', 'cytology', 'pathology'
            ],
            'functional': [
                'ecg', 'ekg', 'electrocardiogram', 'eeg', 'electroencephalogram',
                'emg', 'electromyography', 'spirometry', 'pulmonary function'
            ],
            'endoscopy': [
                'endoscopy', 'colonoscopy', 'gastroscopy', 'bronchoscopy',
                'cystoscopy', 'arthroscopy', 'laparoscopy'
            ]
        }

        # Finding/Result patterns
        finding_patterns = {
            'structural': [
                'fracture', 'lesion', 'mass', 'tumor', 'cyst', 'nodule',
                'stenosis', 'occlusion', 'herniation', 'displacement'
            ],
            'inflammatory': [
                'inflammation', 'edema', 'swelling', 'effusion', 'congestion',
                'infiltration', 'consolidation'
            ],
            'degenerative': [
                'degeneration', 'atrophy', 'necrosis', 'fibrosis', 'sclerosis',
                'osteoarthritis', 'spondylosis'
            ],
            'vascular': [
                'ischemia', 'infarction', 'hemorrhage', 'aneurysm', 'thrombosis',
                'embolism', 'vasculitis'
            ],
            'neoplastic': [
                'malignant', 'benign', 'metastasis', 'carcinoma', 'sarcoma',
                'lymphoma', 'adenoma'
            ]
        }

        # Anatomical patterns specific to diagnosis
        anatomy_patterns = {
            'bone':   ['femur', 'tibia', 'fibula', 'humerus', 'radius', 'ulna', 'vertebra'],
            'joint':  ['hip', 'knee', 'shoulder', 'elbow', 'ankle', 'wrist'],
            'organ':  ['liver', 'kidney', 'heart', 'lung', 'brain', 'pancreas', 'spleen'],
            'vessel': ['artery', 'vein', 'aorta', 'carotid', 'coronary'],
            'region': ['parietal', 'temporal', 'frontal', 'occipital', 'cervical', 'lumbar']
        }

        # Severity/Grade patterns (from column)
        severity_patterns = {
            'mild':     ['mild', 'minimal', 'slight', 'minor'],
            'moderate': ['moderate', 'medium', 'intermediate'],
            'severe':   ['severe', 'significant', 'marked', 'extensive'],
            'critical': ['critical', 'life-threatening', 'emergency']
        }

        # Laterality
        laterality_terms = ['left', 'right', 'bilateral']

        for df_index, row in df_diagnosis.iterrows():
            # ---- choose the correct row id from 'idx' with fallback to the DataFrame index
            row_id = row['idx'] if ('idx' in row and pd.notna(row['idx'])) else df_index

            original_text = str(row.get('combined_text', ''))
            combined = original_text.lower()

            test_text_orig = str(row.get('test', ''))
            test_text = test_text_orig.lower()
            result_text = str(row.get('result', '')).lower()
            condition_text_orig = str(row.get('condition', ''))
            condition_text = condition_text_orig.lower()
            severity_text_orig = str(row.get('severity', ''))
            severity_text = severity_text_orig.lower()

            # ---- TEST from test column (place in the combined text if possible)
            if test_text and test_text != 'nan':
                start = combined.find(test_text)
                if start < 0:
                    start = 0
                end = min(start + len(test_text_orig), len(original_text))
                custom_entities.append({
                    'text': test_text_orig,
                    'label': 'TEST',
                    'category': 'diagnostic_test',
                    'start': start,
                    'end': end,
                    'original_text': original_text,
                    'source': 'test_column',
                    'row_idx': row_id
                })
                # TEST_TYPE via patterns (regex boundaries)
                for category, terms in test_patterns.items():
                    for term in terms:
                        pattern = r'\b' + re.escape(term.lower()) + r'\b'
                        for m in re.finditer(pattern, combined):
                            custom_entities.append({
                                'text': original_text[m.start():m.end()],
                                'label': 'TEST_TYPE',
                                'category': category,
                                'start': m.start(),
                                'end': m.end(),
                                'original_text': original_text,
                                'source': 'pattern_matching',
                                'row_idx': row_id
                            })

            # ---- FINDINGS from result (search in combined for offsets)
            if result_text and result_text != 'nan':
                for category, terms in finding_patterns.items():
                    for term in terms:
                        pattern = r'\b' + re.escape(term.lower()) + r'\b'
                        for m in re.finditer(pattern, combined):
                            custom_entities.append({
                                'text': original_text[m.start():m.end()],
                                'label': 'FINDING',
                                'category': category,
                                'start': m.start(),
                                'end': m.end(),
                                'original_text': original_text,
                                'source': 'result_extraction',
                                'row_idx': row_id
                            })

            # ---- CONDITION from condition column
            if condition_text and condition_text != 'nan':
                start = combined.find(condition_text)
                if start < 0:
                    start = 0
                end = min(start + len(condition_text_orig), len(original_text))
                custom_entities.append({
                    'text': condition_text_orig,
                    'label': 'CONDITION',
                    'category': 'diagnosis',
                    'start': start,
                    'end': end,
                    'original_text': original_text,
                    'source': 'condition_column',
                    'row_idx': row_id
                })

            # ---- ANATOMY patterns (allow simple plural 's')
            for category, terms in anatomy_patterns.items():
                for term in terms:
                    plural_opt = 's?' if term.isalpha() else ''
                    pattern = r'\b' + re.escape(term.lower()) + plural_opt + r'\b'
                    for m in re.finditer(pattern, combined):
                        custom_entities.append({
                            'text': original_text[m.start():m.end()],
                            'label': 'ANATOMY',
                            'category': category,
                            'start': m.start(),
                            'end': m.end(),
                            'original_text': original_text,
                            'source': 'anatomy_extraction',
                            'row_idx': row_id
                        })

            # ---- SEVERITY from severity column
            if severity_text and severity_text != 'nan':
                for category, terms in severity_patterns.items():
                    if severity_text in terms:
                        start = combined.find(severity_text)
                        if start < 0:
                            start = 0
                        end = min(start + len(severity_text_orig), len(original_text))
                        custom_entities.append({
                            'text': severity_text_orig,
                            'label': 'SEVERITY',
                            'category': category,
                            'start': start,
                            'end': end,
                            'original_text': original_text,
                            'source': 'severity_column',
                            'row_idx': row_id
                        })
                        break

            # ---- LATERALITY
            for lat in laterality_terms:
                pattern = r'\b' + re.escape(lat) + r'\b'
                for m in re.finditer(pattern, combined):
                    custom_entities.append({
                        'text': original_text[m.start():m.end()],
                        'label': 'LATERALITY',
                        'category': 'laterality',
                        'start': m.start(),
                        'end': m.end(),
                        'original_text': original_text,
                        'source': 'laterality_extraction',
                        'row_idx': row_id
                    })

            # ---- MEASUREMENTS (capture full token with units)
            # e.g., "12 mm", "3.5 cm", "45%", "10 mg", "30 ml"
            measurement_pattern = r'\b\d+(?:\.\d+)?\s*(?:mm|cm|ml|mg|%)\b'
            for m in re.finditer(measurement_pattern, combined):
                custom_entities.append({
                    'text': original_text[m.start():m.end()],
                    'label': 'MEASUREMENT',
                    'category': 'quantitative',
                    'start': m.start(),
                    'end': m.end(),
                    'original_text': original_text,
                    'source': 'measurement_extraction',
                    'row_idx': row_id
                })

        return pd.DataFrame(custom_entities)

    # Extract custom entities
    df_custom_diagnosis_entities = extract_diagnosis_entities_custom(df_diagnosis)

    print(f"\nCustom extraction found:")
    if not df_custom_diagnosis_entities.empty and 'label' in df_custom_diagnosis_entities.columns:
        print(df_custom_diagnosis_entities['label'].value_counts())
    else:
        print("No custom diagnosis entities found.")

    # Combine all entities (preserve row_idx/category when present)
    df_all_diagnosis_entities = pd.concat(
        [df_diagnosis_entities, df_custom_diagnosis_entities],
        ignore_index=True,
        sort=False
    )

    print(f"\n=== COMBINED Entity Distribution ===")
    if not df_all_diagnosis_entities.empty and 'label' in df_all_diagnosis_entities.columns:
        print(df_all_diagnosis_entities['label'].value_counts())
    else:
        print("No combined entities to show.")

    # Analyze test entities
    test_entities = df_all_diagnosis_entities[df_all_diagnosis_entities['label'] == 'TEST'] if not df_all_diagnosis_entities.empty else pd.DataFrame()
    if not test_entities.empty:
        print(f"\n=== Top Diagnostic Tests ===")
        print(test_entities['text'].str.lower().value_counts().head(20))

    # Analyze findings
    finding_entities = df_all_diagnosis_entities[df_all_diagnosis_entities['label'] == 'FINDING'] if not df_all_diagnosis_entities.empty else pd.DataFrame()
    if not finding_entities.empty:
        print(f"\n=== Top Diagnostic Findings ===")
        print(finding_entities['text'].str.lower().value_counts().head(20))

    # Analyze conditions
    condition_entities = df_all_diagnosis_entities[df_all_diagnosis_entities['label'] == 'CONDITION'] if not df_all_diagnosis_entities.empty else pd.DataFrame()
    if not condition_entities.empty:
        print(f"\n=== Top Diagnosed Conditions ===")
        print(condition_entities['text'].str.lower().value_counts().head(20))

    # ================= Diagnosis-specific labeling functions =================
    print("\n=== Creating Diagnosis-Specific Labeling Functions ===")

    def create_diagnosis_labeling_functions():
        """Create labeling functions for diagnosis patterns"""

        def lf_imaging_test(row):
            test = str(row.get('test', '')).lower()
            imaging_keywords = ['mri', 'ct', 'x-ray', 'radiograph', 'ultrasound', 'scan']
            if any(keyword in test for keyword in imaging_keywords):
                return 'IMAGING_TEST'
            return 'ABSTAIN'

        def lf_fracture_diagnosis(row):
            result = str(row.get('result', '')).lower()
            condition = str(row.get('condition', '')).lower()
            if 'fracture' in result or 'fracture' in condition:
                if 'no fracture' in result or 'no fracture' in condition:
                    return 'NO_FRACTURE'
                else:
                    return 'FRACTURE_PRESENT'
            return 'ABSTAIN'

        def lf_neoplastic_finding(row):
            combined = str(row.get('combined_text', '')).lower()
            neoplastic_terms = ['tumor', 'mass', 'lesion', 'malignant', 'benign', 'metastasis']
            if any(term in combined for term in neoplastic_terms):
                return 'NEOPLASTIC_FINDING'
            return 'ABSTAIN'

        def lf_normal_finding(row):
            result = str(row.get('result', '')).lower()
            normal_terms = ['normal', 'negative', 'no abnormality', 'unremarkable']
            if any(term in result for term in normal_terms):
                return 'NORMAL_FINDING'
            return 'ABSTAIN'

        def lf_critical_finding(row):
            severity = str(row.get('severity', '')).lower()
            details = str(row.get('details', '')).lower()
            critical_terms = ['critical', 'emergency', 'urgent', 'life-threatening']
            if any(term in (severity + details) for term in critical_terms):
                return 'CRITICAL_FINDING'
            return 'ABSTAIN'

        def lf_bone_pathology(row):
            combined = str(row.get('combined_text', '')).lower()
            bone_terms = ['bone', 'osseous', 'fracture', 'osteo', 'marrow']
            pathology_terms = ['lesion', 'edema', 'necrosis', 'fracture']
            if any(b in combined for b in bone_terms) and any(p in combined for p in pathology_terms):
                return 'BONE_PATHOLOGY'
            return 'ABSTAIN'

        def lf_vascular_finding(row):
            combined = str(row.get('combined_text', '')).lower()
            vascular_terms = ['vascular', 'artery', 'vein', 'aneurysm', 'stenosis', 'occlusion']
            if any(term in combined for term in vascular_terms):
                return 'VASCULAR_FINDING'
            return 'ABSTAIN'

        def lf_inflammatory_finding(row):
            combined = str(row.get('combined_text', '')).lower()
            inflammatory_terms = ['inflammation', 'inflammatory', 'edema', 'effusion', 'swelling']
            if any(term in combined for term in inflammatory_terms):
                return 'INFLAMMATORY_FINDING'
            return 'ABSTAIN'

        def lf_bilateral_finding(row):
            combined = str(row.get('combined_text', '')).lower()
            if 'bilateral' in combined:
                return 'BILATERAL_FINDING'
            return 'ABSTAIN'

        def lf_followup_needed(row):
            details = str(row.get('details', '')).lower()
            followup_terms = ['follow-up', 'followup', 'repeat', 'monitor', 'reassess']
            if any(term in details for term in followup_terms):
                return 'FOLLOWUP_NEEDED'
            return 'ABSTAIN'

        return [
            lf_imaging_test, lf_fracture_diagnosis, lf_neoplastic_finding,
            lf_normal_finding, lf_critical_finding, lf_bone_pathology,
            lf_vascular_finding, lf_inflammatory_finding, lf_bilateral_finding,
            lf_followup_needed
        ]

    # Test labeling functions
    diagnosis_lfs = create_diagnosis_labeling_functions()

    print("\nTesting diagnosis labeling functions:")
    for i in range(min(10, len(df_diagnosis))):
        row = df_diagnosis.iloc[i]
        print(f"\nRow {i}: {str(row.get('test',''))} - {str(row.get('condition',''))}")
        for lf in diagnosis_lfs:
            result = lf(row)
            if result != 'ABSTAIN':
                print(f"  {lf.__name__}: {result}")

    # ================= Extract diagnosis relations (use correct row_idx) =================
    print("\n=== Extracting Diagnosis Relations ===")

    def extract_diagnosis_relations(df_diagnosis):
        """
        Extract relations between tests, findings, and conditions at the row level.
        Uses 'idx' for row identity when available; falls back to df.index.
        """
        relations = []
        for df_index, row in df_diagnosis.iterrows():
            row_id = row['idx'] if ('idx' in row and pd.notna(row['idx'])) else df_index
            test = row.get('test')
            result = row.get('result')
            condition = row.get('condition')
            severity = row.get('severity')

            # TEST_REVEALS_FINDING
            if pd.notna(test) and pd.notna(result) and str(result).lower() != 'nan':
                relations.append({
                    'type': 'TEST_REVEALS',
                    'test': test,
                    'finding': result,
                    'row_idx': row_id
                })

            # TEST_DIAGNOSES_CONDITION
            if pd.notna(test) and pd.notna(condition) and str(condition).lower() != 'nan':
                relations.append({
                    'type': 'TEST_DIAGNOSES',
                    'test': test,
                    'condition': condition,
                    'row_idx': row_id
                })

            # FINDING_INDICATES_CONDITION
            if pd.notna(result) and pd.notna(condition) and str(result).lower() != 'nan' and str(condition).lower() != 'nan':
                relations.append({
                    'type': 'FINDING_INDICATES',
                    'finding': result,
                    'condition': condition,
                    'row_idx': row_id
                })

            # FINDING_HAS_SEVERITY
            if pd.notna(result) and pd.notna(severity) and str(severity).lower() != 'nan':
                relations.append({
                    'type': 'FINDING_HAS_SEVERITY',
                    'finding': result,
                    'severity': severity,
                    'row_idx': row_id
                })

            # Temporal relations (if available from temporal standardizer)
            temporal_info = row.get('temporal_info')
            if temporal_info and pd.notna(test):
                relations.append({
                    'type': 'TEST_PERFORMED_AT',
                    'test': test,
                    'time': row.get('time'),
                    'temporal_type': temporal_info.get('type'),
                    'row_idx': row_id
                })

        return relations

    diagnosis_relations = extract_diagnosis_relations(df_diagnosis)

    print(f"\nFound {len(diagnosis_relations)} diagnosis relations")
    if diagnosis_relations:
        relation_types = pd.DataFrame(diagnosis_relations)['type'].value_counts()
        print("\nRelation type distribution:")
        print(relation_types)

        print("\nSample relations:")
        for rel in diagnosis_relations[:10]:
            if rel['type'] == 'TEST_REVEALS':
                print(f"  {rel['test']} -> reveals -> {rel['finding']} (row {rel['row_idx']})")
            elif rel['type'] == 'TEST_DIAGNOSES':
                print(f"  {rel['test']} -> diagnoses -> {rel['condition']} (row {rel['row_idx']})")

    # Analyze test-finding patterns
    print("\n=== Test-Finding Pattern Analysis ===")
    test_finding_pairs = {}
    for rel in diagnosis_relations:
        if rel['type'] == 'TEST_REVEALS':
            pair = f"{rel['test']} -> {rel['finding']}"
            test_finding_pairs[pair] = test_finding_pairs.get(pair, 0) + 1

    print("\nCommon test-finding patterns:")
    for pair, count in sorted(test_finding_pairs.items(), key=lambda x: x[1], reverse=True)[:10]:
        print(f"  {pair}: {count} occurrences")

    # Coverage analysis
    print("\n=== Labeling Function Coverage Analysis ===")
    coverage_results = {}
    for lf in diagnosis_lfs:
        labeled = sum(1 for _, row in df_diagnosis.iterrows() if lf(row) != 'ABSTAIN')
        coverage = (labeled / len(df_diagnosis)) * 100 if len(df_diagnosis) else 0.0
        coverage_results[lf.__name__] = {'labeled': labeled, 'coverage': coverage}

    print("\nLabeling function coverage:")
    for lf_name, stats in sorted(coverage_results.items(), key=lambda x: x[1]['coverage'], reverse=True):
        print(f"  {lf_name}: {stats['labeled']} labels ({stats['coverage']:.1f}% coverage)")

    # Diagnostic summary (guard against empties)
    print("\n=== Diagnostic Summary ===")
    if not df_custom_diagnosis_entities.empty:
        if 'TEST_TYPE' in df_custom_diagnosis_entities['label'].values:
            test_types = df_custom_diagnosis_entities[
                df_custom_diagnosis_entities['label'] == 'TEST_TYPE'
            ]['category'].value_counts()
            print("\nTest types distribution:")
            print(test_types)

        if 'FINDING' in df_custom_diagnosis_entities['label'].values:
            finding_categories = df_custom_diagnosis_entities[
                df_custom_diagnosis_entities['label'] == 'FINDING'
            ]['category'].value_counts()
            print("\nFinding categories:")
            print(finding_categories)

    # Save results
    df_all_diagnosis_entities.to_csv('diagnosis_entities_comprehensive.csv', index=False)
    pd.DataFrame(diagnosis_relations).to_csv('diagnosis_relations.csv', index=False)

    # Export diagnostic patterns for future use
    diagnostic_patterns = {
        'common_tests': test_entities['text'].value_counts().head(20).to_dict() if not test_entities.empty else {},
        'common_findings': finding_entities['text'].value_counts().head(20).to_dict() if not finding_entities.empty else {},
        'test_finding_pairs': test_finding_pairs
    }
    with open('diagnostic_patterns.json', 'w') as f:
        json.dump(diagnostic_patterns, f, indent=2)

    print("\n\nDiagnosis entity extraction and labeling complete!")
    print(f"Saved {len(df_all_diagnosis_entities)} entities and {len(diagnosis_relations)} relations")
    print("Diagnostic patterns exported to diagnostic_patterns.json")



Processing temporal column: time

Temporal extraction results:
Successfully extracted duration: 3879
Temporal types: {'unspecified': 4734, 'post_event': 3354, 'absolute_date': 1766, 'past_reference': 1280, 'range_reference': 1258, 'duration_reference': 299, 'onset_reference': 155}
Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 61355 texts in 205 batches...
Using model: en_ner_bc5cdr_md for column: combined_text


Processing batches:   0%|          | 1/205 [00:03<10:23,  3.06s/it]


Checkpoint saved at batch 0


Processing batches:   5%|▌         | 11/205 [00:21<06:05,  1.88s/it]


Checkpoint saved at batch 3000


Processing batches:  10%|█         | 21/205 [00:39<05:51,  1.91s/it]


Checkpoint saved at batch 6000


Processing batches:  15%|█▌        | 31/205 [00:59<05:47,  2.00s/it]


Checkpoint saved at batch 9000


Processing batches:  20%|██        | 41/205 [01:17<05:11,  1.90s/it]


Checkpoint saved at batch 12000


Processing batches:  25%|██▍       | 51/205 [01:36<04:54,  1.91s/it]


Checkpoint saved at batch 15000


Processing batches:  30%|██▉       | 61/205 [01:54<04:36,  1.92s/it]


Checkpoint saved at batch 18000


Processing batches:  35%|███▍      | 71/205 [02:14<04:19,  1.94s/it]


Checkpoint saved at batch 21000


Processing batches:  40%|███▉      | 81/205 [02:38<04:58,  2.41s/it]


Checkpoint saved at batch 24000


Processing batches:  44%|████▍     | 91/205 [02:59<04:04,  2.15s/it]


Checkpoint saved at batch 27000


Processing batches:  49%|████▉     | 101/205 [03:18<03:22,  1.95s/it]


Checkpoint saved at batch 30000


Processing batches:  54%|█████▍    | 111/205 [03:37<03:03,  1.95s/it]


Checkpoint saved at batch 33000


Processing batches:  59%|█████▉    | 121/205 [03:55<02:45,  1.97s/it]


Checkpoint saved at batch 36000


Processing batches:  64%|██████▍   | 131/205 [04:14<02:26,  1.98s/it]


Checkpoint saved at batch 39000


Processing batches:  69%|██████▉   | 141/205 [04:34<02:23,  2.24s/it]


Checkpoint saved at batch 42000


Processing batches:  74%|███████▎  | 151/205 [04:52<01:47,  2.00s/it]


Checkpoint saved at batch 45000


Processing batches:  79%|███████▊  | 161/205 [05:12<01:28,  2.02s/it]


Checkpoint saved at batch 48000


Processing batches:  83%|████████▎ | 171/205 [05:30<01:10,  2.07s/it]


Checkpoint saved at batch 51000


Processing batches:  88%|████████▊ | 181/205 [05:49<00:47,  1.99s/it]


Checkpoint saved at batch 54000


Processing batches:  93%|█████████▎| 191/205 [06:08<00:28,  2.05s/it]


Checkpoint saved at batch 57000


Processing batches:  98%|█████████▊| 201/205 [06:27<00:08,  2.10s/it]


Checkpoint saved at batch 60000


Processing batches: 100%|██████████| 205/205 [06:34<00:00,  1.92s/it]


Found 2745 entities appearing >= 5 times

=== BC5CDR Entity Types Found ===
{'DISEASE': 77150, 'CHEMICAL': 8577}

=== CUSTOM DIAGNOSIS ENTITY EXTRACTION ===

Custom extraction found:
label
TEST           56150
TEST_TYPE      51197
FINDING        40916
CONDITION      36529
ANATOMY        27435
LATERALITY     25373
MEASUREMENT    10044
SEVERITY        1054
Name: count, dtype: int64

=== COMBINED Entity Distribution ===
label
DISEASE        77150
TEST           56150
TEST_TYPE      51197
FINDING        40916
CONDITION      36529
ANATOMY        27435
LATERALITY     25373
MEASUREMENT    10044
CHEMICAL        8577
SEVERITY        1054
Name: count, dtype: int64

=== Top Diagnostic Tests ===
text
mri                                 1005
ct scan                              786
biopsy                               663
magnetic resonance imaging (mri)     655
histopathological examination        577
chest x-ray                          537
computed tomography (ct) scan        402
laboratory test

In [150]:
df_all_diagnosis_entities

Unnamed: 0,text,label,start,end,original_text,row_idx,source_column,category,source
0,bone marrow edema,DISEASE,109,126,Magnetic resonance imaging (MRI) scan performe...,133948,combined_text,,
1,femoral head necrosis,DISEASE,148,169,Magnetic resonance imaging (MRI) scan performe...,133948,combined_text,,
2,Idiopathic osteonecrosis of the femoral head P...,DISEASE,207,263,Magnetic resonance imaging (MRI) scan performe...,133948,combined_text,,
3,pain,DISEASE,284,288,Magnetic resonance imaging (MRI) scan performe...,133948,combined_text,,
4,fracture,DISEASE,84,92,Radiographs performed with Minimally displaced...,80176,combined_text,,
...,...,...,...,...,...,...,...,...,...
334420,right,LATERALITY,128,133,Transthoracic echocardiogram performed with Ej...,97973,,laterality,laterality_extraction
334421,Blood cultures,TEST,0,14,Blood cultures performed with nan severity sho...,97973,,diagnostic_test,test_column
334422,Transesophageal echocardiogram,TEST,0,30,Transesophageal echocardiogram performed with ...,97973,,diagnostic_test,test_column
334423,Acute severe aortic insufficiency from endocar...,CONDITION,145,196,Transesophageal echocardiogram performed with ...,97973,,diagnosis,condition_column


### Extracting Treatments

In [37]:
df_treatments

Unnamed: 0,idx,has_treatments,name,related condition,dosage,time,frequency,duration,reason for taking,reaction to treatment,details
0,155216,True,Olanzapine tablets,Bipolar affective disorder,5 mg per day,Past four months,Daily,,Control of exacerbated mental illness,"Pain and discomfort in neck, sustained and abn...",Previously managed with olanzapine tablets in ...
1,155216,True,Trihexyphenidyl,Rigidity in upper limbs,4 mg per day,Brief period of around three weeks,Daily,,Rigidity in upper limbs,Good response,
2,133948,False,,,,,,,,,
3,80176,True,Closed treatment in a cast,Proximal ulnar shaft fracture,,Initially after the fall,,,To treat the ulnar shaft fracture,Developed a hypertrophic nonunion,
4,80176,True,Conservative treatment,Ulna nonunion,,Three months after the fall,,An additional three months,To treat the ulna nonunion,Worsening motion through the nonunion site,
...,...,...,...,...,...,...,...,...,...,...,...
50421,98004,True,Hypovolaemic shock treatment,Haemodynamic instability and hypovolaemic shock,,,,,To maintain blood pressure and treat shock,Required large doses of vasopressor and blood ...,Treatment given after becoming haemodynamicall...
50422,133320,True,Systemic chemotherapy,Lung and bone metastases,,,,,To treat lung and bone metastases,,Using doxorubicin and ifosfamide
50423,97973,True,Rapid sequence intubation,Cardiogenic shock and flash pulmonary edema,,,,,To manage suspected cardiogenic shock and flas...,,
50424,97973,True,Advanced cardiac life support (ACLS) protocol,Cardiac arrest,,13 min,,,To restore return of spontaneous circulation a...,Return of spontaneous circulation was restored,


In [114]:
def create_treatment_text(row):
    """Create comprehensive text from treatment row"""
    parts = []
    
    if pd.notna(row.get('name')) and str(row['name']) != 'NaN':
        parts.append(f"Treatment: {row['name']}")
    
    if pd.notna(row.get('related condition')) and str(row['related condition']) != 'NaN':
        parts.append(f"for {row['related condition']}")
    
    if pd.notna(row.get('dosage')) and str(row['dosage']) != 'NaN':
        parts.append(f"dosage {row['dosage']}")
    
    if pd.notna(row.get('frequency')) and str(row['frequency']) != 'NaN':
        parts.append(f"frequency {row['frequency']}")
    
    if pd.notna(row.get('time')) and str(row['time']) != 'NaN':
        parts.append(f"time {row['time']}")
    
    if pd.notna(row.get('duration')) and str(row['duration']) != 'NaN':
        parts.append(f"duration {row['duration']}")
    
    if pd.notna(row.get('reason for taking')) and str(row['reason for taking']) != 'NaN':
        parts.append(f"reason: {row['reason for taking']}")
    
    if pd.notna(row.get('reaction to treatment')) and str(row['reaction to treatment']) != 'NaN':
        parts.append(f"reaction: {row['reaction to treatment']}")
    
    if pd.notna(row.get('details')) and str(row['details']) != 'NaN':
        parts.append(f"details: {row['details']}")
    
    return " ".join(parts)

In [None]:
if __name__ == "__main__":
    # Use the global TemporalStandardizer class 
    temporal_standardizer = TemporalStandardizer()

    # Process temporal information using the standardize_temporal_column method
    df_treatments_processed, temporal_report = temporal_standardizer.standardize_temporal_column(
        df_treatments, 
        'time'
    )
    
    # Process duration column as well
    df_treatments_processed, duration_report = temporal_standardizer.standardize_temporal_column(
        df_treatments_processed, 
        'duration'
    )

    # Ensure original idx is preserved in the processed dataframe
    if ('idx' in df_treatments.columns) and ('idx' not in df_treatments_processed.columns):
        df_treatments_processed['idx'] = df_treatments['idx']

    print(f"\nTemporal extraction results:")
    print(f"Duration extracted from 'time' column: {temporal_report['extracted_durations']} rows")
    print(f"Temporal types: {temporal_report['temporal_types']}")

    # Combine treatments information into comprehensive text
    df_treatments_processed['combined_text'] = df_treatments_processed.apply(
        lambda row: create_treatment_text(row),
        axis=1
    )

    # Run NER extraction with BC5CDR
    df_treatments_entities, treatments_summary, treatments_rules = run_medical_ner_extraction(
        df_treatments_processed,
        text_column='combined_text',
        model_name="en_ner_bc5cdr_md",
        batch_size=300,
        id_column='idx'
    )

    print(f"\n=== BC5CDR Entity Types Found ===")
    print(treatments_summary['entity_types'])

    # Custom entity extraction for treatment-specific entities (keeps correct row_idx)
    print("\n=== CUSTOM TREATMENT ENTITY EXTRACTION ===")

    def extract_treatment_entities_custom(df_treatments):
        """Extract treatment-specific entities not caught by BC5CDR.
           Uses 'idx' for row identity when available; falls back to df.index otherwise.
           Adds 'row_idx' to every extracted entity.
        """
        custom_entities = []
        
        # Treatment type patterns
        treatment_type_patterns = {
            'medication': ['tablet', 'tablets', 'pill', 'pills', 'capsule', 'injection', 'infusion', 'medication', 'drug'],
            'procedure': ['surgery', 'surgical', 'operation', 'procedure', 'therapy', 'intubation', 'resection', 'removal', 'repair'],
            'supportive': ['support', 'life support', 'ventilation', 'oxygen', 'fluid', 'nutrition', 'acls', 'protocol'],
            'chemotherapy': ['chemotherapy', 'chemo', 'cytotoxic', 'antineoplastic', 'systemic chemotherapy'],
            'conservative': ['conservative', 'non-operative', 'non-surgical', 'observation', 'closed treatment'],
            'emergency': ['emergency', 'urgent', 'rapid', 'resuscitation', 'life-saving'],
            'diagnostic': ['biopsy', 'exploration', 'diagnostic', 'assessment']
        }
        
        # Drug/medication patterns
        medication_patterns = {
            'antipsychotic': ['olanzapine', 'risperidone', 'quetiapine', 'haloperidol', 'clozapine'],
            'muscle_relaxant': ['trihexyphenidyl', 'baclofen', 'tizanidine', 'cyclobenzaprine'],
            'antibiotic': ['nafcillin', 'vancomycin', 'ceftriaxone', 'penicillin', 'amoxicillin'],
            'cardiovascular': ['vasopressor', 'inotrope', 'beta blocker', 'ace inhibitor', 'hypovolaemic'],
            'analgesic': ['morphine', 'fentanyl', 'oxycodone', 'acetaminophen', 'ibuprofen']
        }
        
        # Condition patterns (treated conditions)
        condition_patterns = {
            'psychiatric': ['bipolar', 'affective disorder', 'psychosis', 'mania', 'depression', 'mental illness'],
            'orthopedic': ['fracture', 'joint', 'bone', 'hip', 'knee', 'spine', 'ulnar shaft'],
            'cardiovascular': ['cardiac', 'heart', 'arrhythmia', 'shock', 'arrest', 'hypovolaemic', 'endocarditis'],
            'oncological': ['cancer', 'metastases', 'tumor', 'malignancy', 'carcinoma'],
            'neurological': ['rigidity', 'tremor', 'paralysis', 'nerve', 'neurological'],
            'infectious': ['infection', 'sepsis', 'endocarditis', 'abscess', 'pneumonia']
        }
        
        # Dosage unit patterns
        dosage_patterns = {
            'mg': r'\b(\d+\.?\d*)\s*mg\b',
            'mcg': r'\b(\d+\.?\d*)\s*mcg\b',
            'units': r'\b(\d+\.?\d*)\s*units?\b',
            'ml': r'\b(\d+\.?\d*)\s*ml\b',
            'per_day': r'\b(?:per\s*day|/day|daily)\b',
            'min': r'\b(\d+\.?\d*)\s*min(?:utes?)?\b'
        }
        
        # Frequency patterns
        frequency_terms = ['daily', 'twice daily', 'three times', 'qid', 'bid', 'tid', 'prn', 
                           'as needed', 'every', 'once', 'continuous', 'intermittent']
        
        # Route of administration patterns
        route_terms = ['oral', 'intravenous', 'iv', 'im', 'intramuscular', 'subcutaneous', 
                       'topical', 'inhaled', 'nasal', 'rectal']
        
        # Treatment response/outcome patterns
        response_patterns = {
            'positive': ['good response', 'improved', 'resolved', 'successful', 'effective', 
                         'restored', 'return of', 'recovered'],
            'negative': ['worsening', 'failed', 'no response', 'adverse', 'side effect', 
                         'complication', 'deterioration'],
            'neutral': ['no change', 'stable', 'maintained', 'sustained', 'continued'],
            'partial': ['partial response', 'some improvement', 'limited response']
        }
        
        # Temporal patterns specific to treatments
        temporal_patterns = {
            'acute': ['acute', 'sudden', 'rapid', 'emergency', 'immediate'],
            'chronic': ['chronic', 'long-term', 'maintenance', 'ongoing', 'continuous'],
            'perioperative': ['preoperative', 'postoperative', 'intraoperative', 'perioperative'],
            'duration': ['months', 'weeks', 'days', 'hours', 'years']
        }
        
        for df_index, row in df_treatments.iterrows():
            # ---- choose the correct row id from 'idx' with fallback to the DataFrame index
            row_id = row['idx'] if ('idx' in row and pd.notna(row['idx'])) else df_index

            # Get relevant text fields with safe string conversion
            name_orig = str(row.get('name', ''))
            name = name_orig.lower() if pd.notna(row.get('name')) else ''
            condition_orig = str(row.get('related condition', ''))
            condition = condition_orig.lower() if pd.notna(row.get('related condition')) else ''
            dosage_orig = str(row.get('dosage', ''))
            dosage = dosage_orig.lower() if pd.notna(row.get('dosage')) else ''
            frequency_orig = str(row.get('frequency', ''))
            frequency = frequency_orig.lower() if pd.notna(row.get('frequency')) else ''
            time_orig = str(row.get('time', ''))
            time_l = time_orig.lower() if pd.notna(row.get('time')) else ''
            duration_orig = str(row.get('duration', ''))
            duration_l = duration_orig.lower() if pd.notna(row.get('duration')) else ''
            reason_orig = str(row.get('reason for taking', ''))
            reason = reason_orig.lower() if pd.notna(row.get('reason for taking')) else ''
            reaction_orig = str(row.get('reaction to treatment', ''))
            reaction = reaction_orig.lower() if pd.notna(row.get('reaction to treatment')) else ''
            details_orig = str(row.get('details', ''))
            details = details_orig.lower() if pd.notna(row.get('details')) else ''
            combined_orig = str(row.get('combined_text', ''))
            combined = combined_orig.lower()

            # Primary TREATMENT from name column (position in combined if possible)
            if name and name != 'nan':
                start = combined.find(name)
                if start < 0:
                    start = 0
                end = min(start + len(name_orig), len(combined_orig))
                custom_entities.append({
                    'text': name_orig,
                    'label': 'TREATMENT',
                    'category': 'primary_treatment',
                    'start': start,
                    'end': end,
                    'original_text': combined_orig if combined_orig else name_orig,
                    'source': 'name_column',
                    'row_idx': row_id
                })
            
            # CONDITION (indication)
            if condition and condition != 'nan':
                start = combined.find(condition)
                if start < 0:
                    start = 0
                end = min(start + len(condition_orig), len(combined_orig))
                custom_entities.append({
                    'text': condition_orig,
                    'label': 'CONDITION',
                    'category': 'treatment_indication',
                    'start': start,
                    'end': end,
                    'original_text': combined_orig if combined_orig else condition_orig,
                    'source': 'condition_column',
                    'row_idx': row_id
                })
            
            # TREATMENT_TYPE (search in combined)
            for category, terms in treatment_type_patterns.items():
                for term in terms:
                    pattern = r'\b' + re.escape(term) + r'\b'
                    for m in re.finditer(pattern, combined):
                        custom_entities.append({
                            'text': combined_orig[m.start():m.end()],
                            'label': 'TREATMENT_TYPE',
                            'category': category,
                            'start': m.start(),
                            'end': m.end(),
                            'original_text': combined_orig,
                            'source': 'pattern_matching',
                            'row_idx': row_id
                        })
            
            # MEDICATION names/classes (search in combined)
            for drug_class, drugs in medication_patterns.items():
                for drug in drugs:
                    pattern = r'\b' + re.escape(drug) + r'\b'
                    for m in re.finditer(pattern, combined):
                        custom_entities.append({
                            'text': combined_orig[m.start():m.end()],
                            'label': 'MEDICATION',
                            'category': drug_class,
                            'start': m.start(),
                            'end': m.end(),
                            'original_text': combined_orig,
                            'source': 'medication_pattern',
                            'row_idx': row_id
                        })
            
            # DOSAGE (from dosage column)
            if dosage and dosage != 'nan':
                for unit, pattern in dosage_patterns.items():
                    for match in re.finditer(pattern, dosage, re.IGNORECASE):
                        custom_entities.append({
                            'text': match.group(0),
                            'label': 'DOSAGE',
                            'category': unit,
                            'start': match.start(),
                            'end': match.end(),
                            'original_text': dosage_orig,
                            'source': 'dosage_column',
                            'row_idx': row_id
                        })
            
            # FREQUENCY (from frequency column)
            if frequency and frequency != 'nan':
                for freq_term in frequency_terms:
                    pattern = r'\b' + re.escape(freq_term) + r'\b'
                    for m in re.finditer(pattern, frequency):
                        custom_entities.append({
                            'text': frequency_orig[m.start():m.end()],
                            'label': 'FREQUENCY',
                            'category': 'dosing_frequency',
                            'start': m.start(),
                            'end': m.end(),
                            'original_text': frequency_orig,
                            'source': 'frequency_column',
                            'row_idx': row_id
                        })
            
            # ROUTE (search across combined context)
            for route in route_terms:
                pattern = r'\b' + re.escape(route) + r'\b'
                for m in re.finditer(pattern, combined):
                    custom_entities.append({
                        'text': combined_orig[m.start():m.end()],
                        'label': 'ROUTE',
                        'category': 'administration_route',
                        'start': m.start(),
                        'end': m.end(),
                        'original_text': combined_orig,
                        'source': 'route_extraction',
                        'row_idx': row_id
                    })
            
            # TREATMENT_RESPONSE (reaction + details + combined)
            response_text_orig = ' '.join([reaction_orig, details_orig, combined_orig])
            response_text = response_text_orig.lower()
            for response_type, patterns in response_patterns.items():
                for phrase in patterns:
                    pattern = r'\b' + re.escape(phrase) + r'\b'
                    for m in re.finditer(pattern, response_text):
                        custom_entities.append({
                            'text': response_text_orig[m.start():m.end()],
                            'label': 'TREATMENT_RESPONSE',
                            'category': response_type,
                            'start': m.start(),
                            'end': m.end(),
                            'original_text': response_text_orig,
                            'source': 'response_extraction',
                            'row_idx': row_id
                        })
            
            # CONDITION_TYPE categories (condition or reason)
            for condition_type, condition_terms in condition_patterns.items():
                for term in condition_terms:
                    pattern = r'\b' + re.escape(term) + r'\b'
                    # search in condition text first, then reason, else combined
                    found = False
                    for text_orig, text_l, source in [
                        (condition_orig, condition, 'condition_pattern'),
                        (reason_orig, reason, 'reason_pattern'),
                        (combined_orig, combined, 'combined_pattern'),
                    ]:
                        m = re.search(pattern, text_l)
                        if m:
                            custom_entities.append({
                                'text': text_orig[m.start():m.end()],
                                'label': 'CONDITION_TYPE',
                                'category': condition_type,
                                'start': m.start(),
                                'end': m.end(),
                                'original_text': text_orig,
                                'source': source,
                                'row_idx': row_id
                            })
                            found = True
                            break
            
            # TEMPORAL_PATTERN
            temporal_text_orig = ' '.join([time_orig, duration_orig, details_orig])
            temporal_text = temporal_text_orig.lower()
            for temp_category, temp_terms in temporal_patterns.items():
                for term in temp_terms:
                    pattern = r'\b' + re.escape(term) + r'\b'
                    for m in re.finditer(pattern, temporal_text):
                        custom_entities.append({
                            'text': temporal_text_orig[m.start():m.end()],
                            'label': 'TEMPORAL_PATTERN',
                            'category': temp_category,
                            'start': m.start(),
                            'end': m.end(),
                            'original_text': temporal_text_orig,
                            'source': 'temporal_extraction',
                            'row_idx': row_id
                        })
            
            # TREATMENT_REASON (from reason column specific patterns)
            if reason and reason != 'nan':
                reason_patterns = [
                    (r'to\s+treat\s+(\w+(?:\s+\w+)*)', 'treatment_goal'),
                    (r'to\s+manage\s+(\w+(?:\s+\w+)*)', 'management_goal'),
                    (r'control\s+of\s+(\w+(?:\s+\w+)*)', 'control_goal'),
                    (r'for\s+(\w+(?:\s+\w+)*)', 'indication')
                ]
                for pattern, category in reason_patterns:
                    for match in re.finditer(pattern, reason, re.IGNORECASE):
                        custom_entities.append({
                            'text': match.group(0),
                            'label': 'TREATMENT_REASON',
                            'category': category,
                            'start': match.start(),
                            'end': match.end(),
                            'original_text': reason_orig,
                            'source': 'reason_column',
                            'row_idx': row_id
                        })
        
        return pd.DataFrame(custom_entities)

    # Extract custom entities
    df_custom_treatments_entities = extract_treatment_entities_custom(df_treatments_processed)

    # Combine all entities (preserve row_idx/category when present)
    df_all_treatments_entities = pd.concat(
        [df_treatments_entities, df_custom_treatments_entities],
        ignore_index=True,
        sort=False
    )

    print(f"\n=== COMBINED Entity Distribution ===")
    if not df_all_treatments_entities.empty and 'label' in df_all_treatments_entities.columns:
        print(df_all_treatments_entities['label'].value_counts())
    else:
        print("No combined entities to show.")

    # Analyze treatment entities
    treatment_entities = df_all_treatments_entities[df_all_treatments_entities['label'] == 'TREATMENT'] if not df_all_treatments_entities.empty else pd.DataFrame()
    if not treatment_entities.empty:
        print(f"\n=== Top Treatments ===")
        print(treatment_entities['text'].str.lower().value_counts().head(20))

    # Analyze medications
    medication_entities = df_all_treatments_entities[df_all_treatments_entities['label'] == 'MEDICATION'] if not df_all_treatments_entities.empty else pd.DataFrame()
    if not medication_entities.empty:
        print(f"\n=== Top Medications ===")
        print(medication_entities['text'].str.lower().value_counts().head(20))

    # Analyze conditions being treated
    condition_entities = df_all_treatments_entities[df_all_treatments_entities['label'] == 'CONDITION'] if not df_all_treatments_entities.empty else pd.DataFrame()
    if not condition_entities.empty:
        print(f"\n=== Top Conditions Treated ===")
        print(condition_entities['text'].str.lower().value_counts().head(20))

    # Analyze dosages
    dosage_entities = df_all_treatments_entities[df_all_treatments_entities['label'] == 'DOSAGE'] if not df_all_treatments_entities.empty else pd.DataFrame()
    if not dosage_entities.empty:
        print(f"\n=== Dosage Distribution ===")
        print(dosage_entities['text'].str.lower().value_counts().head(15))

    # Analyze treatment responses
    response_entities = df_all_treatments_entities[df_all_treatments_entities['label'] == 'TREATMENT_RESPONSE'] if not df_all_treatments_entities.empty else pd.DataFrame()
    if not response_entities.empty:
        print(f"\n=== Treatment Response Distribution ===")
        print(response_entities['category'].value_counts())

    # Analyze treatment types
    treatment_type_entities = df_all_treatments_entities[df_all_treatments_entities['label'] == 'TREATMENT_TYPE'] if not df_all_treatments_entities.empty else pd.DataFrame()
    if not treatment_type_entities.empty:
        print(f"\n=== Treatment Type Distribution ===")
        print(treatment_type_entities['category'].value_counts())

    # Create treatment-specific labeling functions (unchanged logic)
    print("\n=== Creating Treatment-Specific Labeling Functions ===")

    def create_treatment_labeling_functions():
        """Create labeling functions for treatment patterns"""

        def lf_has_treatment(row):
            if row.get('has_treatments') is True:
                return 'HAS_TREATMENT'
            elif row.get('has_treatments') is False:
                return 'NO_TREATMENT'
            return 'ABSTAIN'

        def lf_medication_treatment(row):
            name = str(row.get('name', '')).lower()
            dosage = str(row.get('dosage', '')).lower()
            medication_terms = ['tablet', 'tablets', 'pill', 'mg', 'mcg', 'capsule', 'injection', 'infusion']
            if any(term in (name + " " + dosage) for term in medication_terms):
                return 'MEDICATION_TREATMENT'
            return 'ABSTAIN'

        def lf_surgical_treatment(row):
            name = str(row.get('name', '')).lower()
            details = str(row.get('details', '')).lower()
            surgical_terms = ['surgery', 'surgical', 'operation', 'resection', 'removal', 'repair', 'intubation']
            if any(term in (name + " " + details) for term in surgical_terms):
                return 'SURGICAL_TREATMENT'
            return 'ABSTAIN'

        def lf_emergency_treatment(row):
            name = str(row.get('name', '')).lower()
            condition = str(row.get('related condition', '')).lower()
            details = str(row.get('details', '')).lower()
            emergency_terms = ['emergency', 'urgent', 'cardiac arrest', 'shock', 'life support', 'acls', 'resuscitation', 'rapid sequence']
            if any(term in (name + " " + condition + " " + details) for term in emergency_terms):
                return 'EMERGENCY_TREATMENT'
            return 'ABSTAIN'

        def lf_cancer_treatment(row):
            name = str(row.get('name', '')).lower()
            condition = str(row.get('related condition', '')).lower()
            reason = str(row.get('reason for taking', '')).lower()
            cancer_terms = ['chemotherapy', 'cancer', 'metastases', 'tumor', 'oncology', 'malignant', 'carcinoma']
            if any(term in (name + " " + condition + " " + reason) for term in cancer_terms):
                return 'CANCER_TREATMENT'
            return 'ABSTAIN'

        def lf_psychiatric_treatment(row):
            name = str(row.get('name', '')).lower()
            condition = str(row.get('related condition', '')).lower()
            psych_drugs = ['olanzapine', 'risperidone', 'haloperidol', 'quetiapine', 'trihexyphenidyl']
            psych_conditions = ['bipolar', 'psychosis', 'mania', 'depression', 'anxiety', 'affective disorder']
            if any(drug in name for drug in psych_drugs) or any(cond in condition for cond in psych_conditions):
                return 'PSYCHIATRIC_TREATMENT'
            return 'ABSTAIN'

        def lf_chronic_treatment(row):
            time = str(row.get('time', '')).lower()
            duration = str(row.get('duration', '')).lower()
            chronic_indicators = ['months', 'years', 'chronic', 'long-term', 'maintenance', 'past four months']
            if any(indicator in (time + " " + duration) for indicator in chronic_indicators):
                return 'CHRONIC_TREATMENT'
            return 'ABSTAIN'

        def lf_daily_medication(row):
            frequency = str(row.get('frequency', '')).lower()
            if any(term in frequency for term in ['daily', 'every day', 'per day']):
                return 'DAILY_MEDICATION'
            return 'ABSTAIN'

        def lf_positive_response(row):
            reaction = str(row.get('reaction to treatment', '')).lower()
            details = str(row.get('details', '')).lower()
            positive_terms = ['good response', 'improved', 'resolved', 'successful', 'effective', 'restored', 'return of', 'recovered']
            if any(term in (reaction + " " + details) for term in positive_terms):
                return 'POSITIVE_RESPONSE'
            return 'ABSTAIN'

        def lf_conservative_treatment(row):
            name = str(row.get('name', '')).lower()
            if any(t in name for t in ['conservative', 'non-operative', 'closed treatment']):
                return 'CONSERVATIVE_TREATMENT'
            return 'ABSTAIN'

        def lf_infection_treatment(row):
            name = str(row.get('name', '')).lower()
            condition = str(row.get('related condition', '')).lower()
            reason = str(row.get('reason for taking', '')).lower()
            infection_terms = ['antibiotic', 'infection', 'endocarditis', 'sepsis', 'nafcillin', 'antimicrobial']
            if any(term in (name + " " + condition + " " + reason) for term in infection_terms):
                return 'INFECTION_TREATMENT'
            return 'ABSTAIN'

        def lf_cardiovascular_treatment(row):
            name = str(row.get('name', '')).lower()
            condition = str(row.get('related condition', '')).lower()
            cardio_terms = ['cardiac', 'heart', 'hypovolaemic', 'shock', 'arrest', 'vasopressor', 'arrhythmia']
            if any(term in (name + " " + condition) for term in cardio_terms):
                return 'CARDIOVASCULAR_TREATMENT'
            return 'ABSTAIN'

        def lf_treatment_duration(row):
            # Prefer standardized duration if present
            duration_days = row.get('duration_duration_days')
            if pd.isna(duration_days):
                duration_days = row.get('time_duration_days')
            if pd.notna(duration_days):
                try:
                    d = float(duration_days)
                except Exception:
                    return 'ABSTAIN'
                if d <= 7:
                    return 'SHORT_TERM_TREATMENT'
                elif d <= 30:
                    return 'MEDIUM_TERM_TREATMENT'
                else:
                    return 'LONG_TERM_TREATMENT'
            return 'ABSTAIN'

        return [
            lf_has_treatment, lf_medication_treatment, lf_surgical_treatment,
            lf_emergency_treatment, lf_cancer_treatment, lf_psychiatric_treatment,
            lf_chronic_treatment, lf_daily_medication, lf_positive_response,
            lf_conservative_treatment, lf_infection_treatment, lf_cardiovascular_treatment,
            lf_treatment_duration
        ]

    # Test labeling functions
    treatment_lfs = create_treatment_labeling_functions()

    print("\nTesting treatment labeling functions:")
    for i in range(min(10, len(df_treatments_processed))):
        row = df_treatments_processed.iloc[i]
        treatment_name = row.get('name', 'No treatment')
        condition = row.get('related condition', 'No condition')
        print(f"\nRow {i}: {treatment_name} for {condition}")
        for lf in treatment_lfs:
            result = lf(row)
            if result != 'ABSTAIN':
                print(f"  {lf.__name__}: {result}")

    # Extract treatment relations (use correct row_idx)
    print("\n=== Extracting Treatment Relations ===")

    def extract_treatment_relations(df_treatments):
        """Extract relations between treatments, conditions, and outcomes.
           Uses 'idx' for row identity when available; falls back to df.index.
        """
        relations = []

        for df_index, row in df_treatments.iterrows():
            row_id = row['idx'] if ('idx' in row and pd.notna(row['idx'])) else df_index

            treatment = row.get('name')
            condition = row.get('related condition')
            dosage = row.get('dosage')
            frequency = row.get('frequency')
            reaction = row.get('reaction to treatment')
            reason = row.get('reason for taking')

            # TREATMENT_FOR_CONDITION
            if pd.notna(treatment) and pd.notna(condition) and str(condition).lower() != 'nan':
                relations.append({
                    'type': 'TREATMENT_FOR_CONDITION',
                    'treatment': treatment,
                    'condition': condition,
                    'row_idx': row_id
                })

            # TREATMENT_HAS_DOSAGE
            if pd.notna(treatment) and pd.notna(dosage) and str(dosage).lower() != 'nan':
                relations.append({
                    'type': 'TREATMENT_HAS_DOSAGE',
                    'treatment': treatment,
                    'dosage': dosage,
                    'row_idx': row_id
                })

            # TREATMENT_HAS_FREQUENCY
            if pd.notna(treatment) and pd.notna(frequency) and str(frequency).lower() != 'nan':
                relations.append({
                    'type': 'TREATMENT_HAS_FREQUENCY',
                    'treatment': treatment,
                    'frequency': frequency,
                    'row_idx': row_id
                })

            # TREATMENT_HAS_RESPONSE
            if pd.notna(treatment) and pd.notna(reaction) and str(reaction).lower() != 'nan':
                relations.append({
                    'type': 'TREATMENT_HAS_RESPONSE',
                    'treatment': treatment,
                    'response': reaction,
                    'row_idx': row_id
                })

            # TREATMENT_HAS_DURATION (standardized days)
            duration_days = row.get('duration_duration_days') or row.get('time_duration_days')
            if pd.notna(treatment) and pd.notna(duration_days):
                relations.append({
                    'type': 'TREATMENT_HAS_DURATION',
                    'treatment': treatment,
                    'duration_days': duration_days,
                    'row_idx': row_id
                })

            # TREATMENT_HAS_REASON
            if pd.notna(treatment) and pd.notna(reason) and str(reason).lower() != 'nan':
                relations.append({
                    'type': 'TREATMENT_HAS_REASON',
                    'treatment': treatment,
                    'reason': reason,
                    'row_idx': row_id
                })

            # TREATMENT_TEMPORAL_PATTERN (raw time text)
            time_info = row.get('time', '')
            if pd.notna(treatment) and pd.notna(time_info) and str(time_info).lower() != 'nan':
                relations.append({
                    'type': 'TREATMENT_TEMPORAL_PATTERN',
                    'treatment': treatment,
                    'temporal_info': time_info,
                    'row_idx': row_id
                })

        return relations

    treatment_relations = extract_treatment_relations(df_treatments_processed)

    print(f"Found {len(treatment_relations)} treatment relations")
    if treatment_relations:
        relation_types = pd.DataFrame(treatment_relations)['type'].value_counts()
        print("\nRelation type distribution:")
        print(relation_types)

        # Sample relations
        print("\nSample relations:")
        for rel in treatment_relations[:10]:
            if rel['type'] == 'TREATMENT_FOR_CONDITION':
                print(f"  {rel['treatment']} -> treats -> {rel['condition']} (row {rel['row_idx']})")
            elif rel['type'] == 'TREATMENT_HAS_DOSAGE':
                print(f"  {rel['treatment']} -> dosage -> {rel['dosage']} (row {rel['row_idx']})")
            elif rel['type'] == 'TREATMENT_HAS_RESPONSE':
                print(f"  {rel['treatment']} -> response -> {rel['response']} (row {rel['row_idx']})")

    # Coverage analysis
    print("\n=== Labeling Function Coverage Analysis ===")
    coverage_results = {}
    for lf in treatment_lfs:
        labeled = sum(1 for _, row in df_treatments_processed.iterrows() if lf(row) != 'ABSTAIN')
        coverage = (labeled / len(df_treatments_processed)) * 100 if len(df_treatments_processed) > 0 else 0
        coverage_results[lf.__name__] = {'labeled': labeled, 'coverage': coverage}

    print("\nLabeling function coverage:")
    for lf_name, stats in sorted(coverage_results.items(), key=lambda x: x[1]['coverage'], reverse=True):
        print(f"  {lf_name}: {stats['labeled']} labels ({stats['coverage']:.1f}% coverage)")

    # Additional analysis for treatments
    print("\n=== Additional Treatment Analysis ===")
    
    if 'has_treatments' in df_treatments_processed.columns:
        print("\nHas treatments distribution:")
        print(df_treatments_processed['has_treatments'].value_counts())
    
    if 'frequency' in df_treatments_processed.columns and not df_all_treatments_entities.empty:
        freq_entities = df_all_treatments_entities[df_all_treatments_entities['label'] == 'FREQUENCY']
        if not freq_entities.empty:
            print("\n=== Frequency Distribution ===")
            print(freq_entities['text'].str.lower().value_counts().head(10))
    
    route_entities = df_all_treatments_entities[df_all_treatments_entities['label'] == 'ROUTE'] if not df_all_treatments_entities.empty else pd.DataFrame()
    if not route_entities.empty:
        print("\n=== Route of Administration ===")
        print(route_entities['text'].str.lower().value_counts())

    # Save results
    df_all_treatments_entities.to_csv('treatment_entities_comprehensive.csv', index=False)
    pd.DataFrame(treatment_relations).to_csv('treatment_relations.csv', index=False)

    print("\n\nTreatment entity extraction and labeling complete!")
    print(f"Saved {len(df_all_treatments_entities)} entities and {len(treatment_relations)} relations")



Processing temporal column: time

Processing temporal column: duration

Temporal extraction results:
Duration extracted from 'time' column: 5037 rows
Temporal types: {'unspecified': 6645, 'post_event': 5777, 'range_reference': 2315, 'past_reference': 2193, 'onset_reference': 2007, 'absolute_date': 1267, 'duration_reference': 914}
Loading en_ner_bc5cdr_md...
Model loaded. Active pipes: ['tok2vec', 'ner']
Stamping row identifier from column: 'idx'
Processing 50426 texts in 169 batches...
Using model: en_ner_bc5cdr_md for column: combined_text


Processing batches:   1%|          | 1/169 [00:02<06:22,  2.28s/it]


Checkpoint saved at batch 0


Processing batches:   7%|▋         | 11/169 [00:15<03:36,  1.37s/it]


Checkpoint saved at batch 3000


Processing batches:  12%|█▏        | 21/169 [00:28<03:24,  1.38s/it]


Checkpoint saved at batch 6000


Processing batches:  18%|█▊        | 31/169 [00:42<03:12,  1.39s/it]


Checkpoint saved at batch 9000


Processing batches:  24%|██▍       | 41/169 [00:55<02:57,  1.39s/it]


Checkpoint saved at batch 12000


Processing batches:  30%|███       | 51/169 [01:08<02:41,  1.37s/it]


Checkpoint saved at batch 15000


Processing batches:  36%|███▌      | 61/169 [01:21<02:25,  1.35s/it]


Checkpoint saved at batch 18000


Processing batches:  42%|████▏     | 71/169 [01:34<02:18,  1.41s/it]


Checkpoint saved at batch 21000


Processing batches:  48%|████▊     | 81/169 [01:47<02:03,  1.40s/it]


Checkpoint saved at batch 24000


Processing batches:  54%|█████▍    | 91/169 [02:00<01:51,  1.43s/it]


Checkpoint saved at batch 27000


Processing batches:  60%|█████▉    | 101/169 [02:14<01:39,  1.46s/it]


Checkpoint saved at batch 30000


Processing batches:  66%|██████▌   | 111/169 [02:27<01:23,  1.44s/it]


Checkpoint saved at batch 33000


Processing batches:  72%|███████▏  | 121/169 [02:40<01:10,  1.46s/it]


Checkpoint saved at batch 36000


Processing batches:  78%|███████▊  | 131/169 [02:53<00:56,  1.50s/it]


Checkpoint saved at batch 39000


Processing batches:  83%|████████▎ | 141/169 [03:07<00:42,  1.50s/it]


Checkpoint saved at batch 42000


Processing batches:  89%|████████▉ | 151/169 [03:21<00:27,  1.52s/it]


Checkpoint saved at batch 45000


Processing batches:  95%|█████████▌| 161/169 [03:34<00:12,  1.55s/it]


Checkpoint saved at batch 48000


Processing batches: 100%|██████████| 169/169 [03:43<00:00,  1.32s/it]


Found 3696 entities appearing >= 5 times

=== BC5CDR Entity Types Found ===
{'DISEASE': 80940, 'CHEMICAL': 37907}

=== CUSTOM SYMPTOM ENTITY EXTRACTION ===

=== COMBINED Entity Distribution ===
label
DISEASE               80940
TREATMENT             44086
CONDITION             43157
CHEMICAL              37907
TREATMENT_TYPE        23652
ROUTE                 18590
TREATMENT_REASON      16765
CONDITION_TYPE        13410
DOSAGE                12320
TEMPORAL_PATTERN      12264
TREATMENT_RESPONSE     7463
FREQUENCY              6705
MEDICATION             1921
Name: count, dtype: int64

=== Top Treatments ===
text
Antibiotics                502
Chemotherapy               499
Surgery                    272
Conservative treatment     271
Aspirin                    268
Surgical excision          263
Blood transfusion          251
Conservative management    244
Prednisone                 206
Surgical resection         170
Intravenous antibiotics    159
Radiotherapy               153
Adjuvant 

In [116]:
df_all_treatments_entities

Unnamed: 0,text,label,start,end,original_text,row_idx,source_column,category,source
0,Olanzapine,CHEMICAL,11,21,Treatment: Olanzapine tablets for Bipolar affe...,155216,combined_text,,
1,Bipolar affective disorder,DISEASE,34,60,Treatment: Olanzapine tablets for Bipolar affe...,155216,combined_text,,
2,mental illness reaction,DISEASE,150,173,Treatment: Olanzapine tablets for Bipolar affe...,155216,combined_text,,
3,Pain,DISEASE,175,179,Treatment: Olanzapine tablets for Bipolar affe...,155216,combined_text,,
4,olanzapine,CHEMICAL,326,336,Treatment: Olanzapine tablets for Bipolar affe...,155216,combined_text,,
...,...,...,...,...,...,...,...,...,...
319175,endocarditis,CONDITION_TYPE,0,12,endocarditis,50425,,cardiovascular,condition_pattern
319176,infection,CONDITION_TYPE,24,33,to treat s.\nlugdunensis infection,50425,,infectious,condition_pattern
319177,endocarditis,CONDITION_TYPE,0,12,endocarditis,50425,,infectious,condition_pattern
319178,postoperative,TEMPORAL_PATTERN,0,13,postoperative patient was discharged home on ...,50425,,perioperative,temporal_extraction


### Extracting Info

In [41]:
df_info

Unnamed: 0,idx,age,sex
0,155216,Sixteen years old,Female
2,133948,36 years old,Female
3,80176,49,male
4,72232,47,Male
5,31864,24 years,Female
...,...,...,...
29995,39279,28,male
29996,137017,82,Male
29997,98004,54,Male
29998,133320,49,Woman


In [42]:
df_info['sex'].value_counts()

sex
Female                                                             10077
Male                                                                9974
male                                                                4588
Woman                                                               2486
female                                                               993
man                                                                  497
woman                                                                391
boy                                                                  190
Boy                                                                   78
Girl                                                                  77
girl                                                                  72
Man                                                                   68
Gentleman                                                              7
Trans man                                      

In [43]:
df_info['age'].value_counts()

age
62                                                                                  474
65                                                                                  465
35                                                                                  432
63                                                                                  432
45                                                                                  429
                                                                                   ... 
Initially 21 years old, 33 years old at last mention                                  1
29 at first admission, 55 at the time of the last mentioned clinical examination      1
18 yr old                                                                             1
37-years old                                                                          1
Almost three-year old                                                                 1
Name: count, Length: 1296, d

In [46]:
def standardize_sex(sex_str):
    """
    Standardize sex/gender values based on the variations in your data.
    
    Special cases handled:
    - Veterinary cases (neutered/castrated)
    - Multiple patients in one record
    - Trans individuals
    - Various capitalizations and terms
    """
    if pd.isna(sex_str):
        return None
    
    sex_str = str(sex_str).strip().lower()
    
    # Handle multiple patients first
    if 'both' in sex_str or 'second patient' in sex_str or 'patient case' in sex_str:
        return 'Multiple_Patients'
    
    # Map variations to standard values
    female_terms = ['female', 'woman', 'girl', 'lady']
    male_terms = ['male', 'man', 'boy', 'gentleman']
    
    # Check for trans individuals
    if 'trans' in sex_str:
        if 'trans man' in sex_str or 'transitioned to male' in sex_str:
            return 'Trans_Male'
        elif 'trans woman' in sex_str:
            return 'Trans_Female'
    
    # Check for assigned at birth
    if 'assigned female at birth' in sex_str:
        return 'AFAB'
    
    # Check for phenotype mentions
    if 'phenotype' in sex_str:
        if 'female' in sex_str:
            return 'Female_Phenotype'
    
    # Check for veterinary cases (neutered/castrated)
    if 'neutered' in sex_str or 'castrated' in sex_str or 'entire' in sex_str:
        if any(term in sex_str for term in female_terms):
            if 'neutered' in sex_str:
                return 'Female_Neutered'
            else:
                return 'Female_Intact'
        elif any(term in sex_str for term in male_terms):
            if 'neutered' in sex_str or 'castrated' in sex_str:
                return 'Male_Neutered'
            else:
                return 'Male_Intact'
    
    # Standard cases
    for term in female_terms:
        if term == sex_str:
            return 'Female'
    
    for term in male_terms:
        if term == sex_str:
            return 'Male'
    
    # If we can't classify, return as unclassified
    return 'Unclassified'


def standardize_sex_simple(sex_str):
    """
    Simplified version that maps to just Male/Female/Other categories
    """
    if pd.isna(sex_str):
        return None
    
    sex_str = str(sex_str).strip().lower()
    
    # Handle multiple patients
    if 'both' in sex_str or 'second patient' in sex_str or 'patient case' in sex_str:
        return 'Multiple_Records'
    
    # Simple mapping
    female_terms = ['female', 'woman', 'girl', 'lady', 'trans woman', 'female phenotype']
    male_terms = ['male', 'man', 'boy', 'gentleman', 'trans man']
    
    # Check for main terms
    for term in female_terms:
        if term in sex_str:
            return 'Female'
    
    for term in male_terms:
        if term in sex_str:
            return 'Male'
    
    return 'Other'


def extract_age_from_text(age_str):
    """
    Extract numeric age from various text formats in your data.
    
    Handles cases like:
    - Simple numbers: "62", "35"
    - Years old format: "18 yr old", "37-years old"
    - Written numbers: "Sixteen years old", "Almost three-year old"
    - Complex cases: "Initially 21 years old, 33 years old at last mention"
    - Age ranges: "29 at first admission, 55 at the time of the last mentioned clinical examination"
    """
    if pd.isna(age_str):
        return None
    
    age_str = str(age_str).strip()
    
    # First check if it's already a simple number
    if age_str.isdigit():
        return int(age_str)
    
    # Convert written numbers to digits
    written_numbers = {
        'one': 1, 'two': 2, 'three': 3, 'four': 4, 'five': 5,
        'six': 6, 'seven': 7, 'eight': 8, 'nine': 9, 'ten': 10,
        'eleven': 11, 'twelve': 12, 'thirteen': 13, 'fourteen': 14,
        'fifteen': 15, 'sixteen': 16, 'seventeen': 17, 'eighteen': 18,
        'nineteen': 19, 'twenty': 20, 'thirty': 30, 'forty': 40,
        'fifty': 50, 'sixty': 60, 'seventy': 70, 'eighty': 80,
        'ninety': 90
    }
    
    # Replace written numbers with digits
    age_str_lower = age_str.lower()
    for word, num in written_numbers.items():
        age_str_lower = age_str_lower.replace(word, str(num))
    
    # Handle compound written numbers (e.g., "twenty-one")
    age_str_lower = re.sub(r'(\d+)\s*-\s*(\d+)', lambda m: str(int(m.group(1)) + int(m.group(2))), age_str_lower)
    
    # Extract all numbers from the text
    numbers = re.findall(r'\d+', age_str_lower)
    
    if not numbers:
        return None
    
    # For multiple ages (patient history), typically want the first mentioned age
    # You might want to change this logic based on your needs
    if 'initially' in age_str_lower or 'first' in age_str_lower:
        # Return the first number
        return int(numbers[0])
    elif 'last' in age_str_lower or 'current' in age_str_lower:
        # Return the last number
        return int(numbers[-1])
    else:
        # Default to first number found
        return int(numbers[0])


def get_age_category(age):
    """
    Categorize age into standard medical categories
    """
    if pd.isna(age):
        return 'Unknown'
    
    if age < 2:
        return 'Infant'
    elif age < 13:
        return 'Child'
    elif age < 18:
        return 'Adolescent'
    elif age < 65:
        return 'Adult'
    else:
        return 'Elderly'


def standardize_demographics(df):
    """
    Apply all standardization to the dataframe
    """
    # Create a copy to avoid modifying original
    df_clean = df.copy()
    
    # Standardize sex - both detailed and simple versions
    df_clean['sex_detailed'] = df_clean['sex'].apply(standardize_sex)
    df_clean['sex_standardized'] = df_clean['sex'].apply(standardize_sex_simple)
    
    # Extract numeric age
    df_clean['age_numeric'] = df_clean['age'].apply(extract_age_from_text)
    
    # Add age category
    df_clean['age_category'] = df_clean['age_numeric'].apply(get_age_category)
    
    # Create flags for special cases
    df_clean['is_veterinary'] = df_clean['sex_detailed'].str.contains('Neutered|Intact', na=False)
    df_clean['is_multiple_patients'] = df_clean['sex_detailed'] == 'Multiple_Patients'
    df_clean['has_complex_age'] = df_clean['age'].str.contains('initially|first|last|mention', case=False, na=False)
    
    return df_clean


    
# Apply to your dataframe:
df_info_clean = standardize_demographics(df_info)

# View results:
print("\nStandardized Sex Distribution:")
print(df_info_clean['sex_standardized'].value_counts())
print("\nAge Distribution:")
print(df_info_clean['age_category'].value_counts())
print("\nSpecial Cases:")
print(f"Veterinary cases: {df_info_clean['is_veterinary'].sum()}")
print(f"Multiple patient records: {df_info_clean['is_multiple_patients'].sum()}")
print(f"Complex age descriptions: {df_info_clean['has_complex_age'].sum()}")


Standardized Sex Distribution:
sex_standardized
Male                15429
Female              14129
Multiple_Records        5
Other                   3
Name: count, dtype: int64

Age Distribution:
age_category
Adult         19376
Elderly        6153
Child          2566
Adolescent     1489
Infant           89
Unknown          82
Name: count, dtype: int64

Special Cases:
Veterinary cases: 13
Multiple patient records: 5
Complex age descriptions: 120


In [151]:
df_info_clean

Unnamed: 0,idx,age,sex,sex_detailed,sex_standardized,age_numeric,age_category,is_veterinary,is_multiple_patients,has_complex_age
0,155216,Sixteen years old,Female,Female,Female,6.0,Child,False,False,False
2,133948,36 years old,Female,Female,Female,36.0,Adult,False,False,False
3,80176,49,male,Male,Male,49.0,Adult,False,False,False
4,72232,47,Male,Male,Male,47.0,Adult,False,False,False
5,31864,24 years,Female,Female,Female,24.0,Adult,False,False,False
...,...,...,...,...,...,...,...,...,...,...
29995,39279,28,male,Male,Male,28.0,Adult,False,False,False
29996,137017,82,Male,Male,Male,82.0,Elderly,False,False,False
29997,98004,54,Male,Male,Male,54.0,Adult,False,False,False
29998,133320,49,Woman,Female,Female,49.0,Adult,False,False,False


In [152]:
# df_info_clean has: idx, age, sex, sex_detailed, sex_standardized, age_numeric, ...

info = df_info_clean.rename(columns={"idx":"row_idx"}).copy()

def age_variants(n: int):
    n = int(n)
    return [
        f"{n} years old", f"{n} year old",
        f"{n}-year-old", f"{n} yo", f"{n} y/o", f"aged {n}"
    ]

SEX_SYNONYMS = {
    "female": ["female", "woman", "female patient", "women"],
    "male":   ["male", "man", "male patient", "men"],
}

rows = []
for r in info.itertuples(index=False):
    rid = r.row_idx

    # 1) Use the raw age text if present
    if isinstance(r.age, str) and r.age.strip():
        rows.append({"row_idx": rid, "text": r.age.strip(), "label": "Age", "table": "info"})

    # 2) Generate common age variants from numeric
    if pd.notna(r.age_numeric):
        for t in age_variants(r.age_numeric):
            rows.append({"row_idx": rid, "text": t, "label": "Age", "table": "info"})

    # 3) Sex synonyms (driven by standardized sex if available)
    sex_std = None
    for s in (getattr(r, "sex_standardized", None), getattr(r, "sex_detailed", None), getattr(r, "sex", None)):
        if isinstance(s, str) and s.strip():
            sex_std = s.strip().lower()
            break
    if sex_std in SEX_SYNONYMS:
        for t in SEX_SYNONYMS[sex_std]:
            rows.append({"row_idx": nid, "text": t, "label": "Sex", "table": "info"})

        # 4) Age+sex combos (very common in narratives, e.g., "36-year-old female")
        if pd.notna(r.age_numeric):
            for av in age_variants(r.age_numeric):
                for sx in SEX_SYNONYMS[sex_std]:
                    rows.append({"row_idx": rid, "text": f"{av} {sx}", "label": "AgeSex", "table": "info"})

df_info_entities = pd.DataFrame(rows).drop_duplicates(["row_idx","text","label"])
print(df_info_entities.shape, df_info_entities.label.value_counts())


(910897, 4) label
AgeSex    707568
Age       203321
Sex            8
Name: count, dtype: int64


In [157]:
all_entities = pd.concat([
    df_all_physiological_entities.assign(table='physiological'),
    df_psychological_entities.assign(table='psychological'),
    df_vaccination_entities.assign(table='vaccination'),
    df_allergies_entities.assign(table='allergies'),
    df_drug_usage_entities.assign(table='drug_usage'),
    df_all_surgery_entities.assign(table='surgery'),
    df_all_symptom_entities.assign(table='symptoms'),
    df_all_diagnosis_entities.assign(table='diagnosis'),
    df_all_treatments_entities.assign(table='treatments'),
    df_info_entities.assign(table='info')
], ignore_index=True)

print(f"Total entities across all tables: {len(all_entities):,}")
print("\nEntity distribution by table:")
print(all_entities['table'].value_counts())
print("\nEntity types found:")
print(all_entities['label'].value_counts())

Total entities across all tables: 2,113,057

Entity distribution by table:
table
info             910897
diagnosis        334425
treatments       319180
symptoms         316046
surgery          179947
physiological     47382
psychological      3402
allergies           934
drug_usage          715
vaccination         129
Name: count, dtype: int64

Entity types found:
label
AgeSex                             707568
DISEASE                            281269
Age                                203321
ANATOMY                            125820
CONDITION                           79686
LATERALITY                          76962
SYMPTOM_TYPE                        56789
TEST                                56150
SYMPTOM                             53580
CHEMICAL                            51222
TEST_TYPE                           51197
TREATMENT                           44086
FINDING                             40916
MULTI_TISSUE_STRUCTURE              32237
TEMPORAL_PATTERN                    26

In [158]:
all_entities

Unnamed: 0,text,label,start,end,original_text,row_idx,source_column,category,table,source
0,posttraumatic arthritis,DISEASE,48.0,71.0,History of left elbow arthrodesis performed fo...,80176,physiological context,,physiological,
1,pain,DISEASE,116.0,120.0,"Inability to walk since babyhood, did not walk...",31864,physiological context,,physiological,
2,fracture,DISEASE,151.0,159.0,"Inability to walk since babyhood, did not walk...",31864,physiological context,,physiological,
3,Coxa vara deformity,DISEASE,0.0,19.0,"Coxa vara deformity of bilateral hips, bilater...",149866,physiological context,,physiological,
4,fracture,DISEASE,75.0,83.0,"Coxa vara deformity of bilateral hips, bilater...",149866,physiological context,,physiological,
...,...,...,...,...,...,...,...,...,...,...
2113052,31 y/o men,AgeSex,,,,97973,,,info,
2113053,aged 31 male,AgeSex,,,,97973,,,info,
2113054,aged 31 man,AgeSex,,,,97973,,,info,
2113055,aged 31 male patient,AgeSex,,,,97973,,,info,


In [159]:
all_entities.to_csv('all_entities.csv', index=False)