In [None]:
"""
CADEC Dataset Entity Analysis Script
This script processes the CADEC dataset to extract and count distinct entities
for each label type (ADR, Drug, Disease, Symptom)
"""


# STEP 1: IMPORT REQUIRED LIBRARIES


import os
from collections import defaultdict, Counter
import glob
from typing import Dict, Set, List, Tuple


# STEP 2: MOUNT GOOGLE DRIVE (FOR COLAB)


# Mount Google Drive. google.colab only exists inside Colab, so the import is
# inside the try: elsewhere the mount is skipped instead of crashing the cell.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    print("Google Drive mounted successfully.")
except ImportError:
    print("google.colab is not available; skipping Google Drive mount.")
except Exception as e:
    print(f"Error mounting Google Drive: {e}")

# STEP 3: SET UP PATHS

# Root of the CADEC corpus on the mounted Drive.
BASE_PATH = '/content/drive/MyDrive/cadec'

# Sub-directories of the corpus, one constant per folder.
TEXT_DIR, ORIGINAL_DIR, SCT_DIR, MEDDRA_DIR = (
    os.path.join(BASE_PATH, sub_dir)
    for sub_dir in ('text', 'original', 'sct', 'meddra')
)


# STEP 4: DEFINE ENTITY CLASS

class Entity:
    """A single annotated span read from a CADEC ``original`` annotation file.

    Attributes:
        entity_id: annotation id from the file (e.g. ``T1``).
        label: entity label such as ``ADR`` or ``Drug``.
        start, end: character offsets as parsed from the file.
        text: the annotated text.
        filename: basename of the file the annotation came from.
    """

    def __init__(self, entity_id: str, label: str, start: int,
                 end: int, text: str, filename: str):
        # Identity and classification.
        self.entity_id, self.label = entity_id, label
        # Character span.
        self.start, self.end = start, end
        # Surface form and provenance.
        self.text, self.filename = text, filename

    def __repr__(self):
        """Compact form showing id, label, text and source file."""
        return f"Entity(id={self.entity_id}, label={self.label}, text='{self.text}', file={self.filename})"

# STEP 5: DEFINE ANNOTATION PARSER

def parse_original_file(filepath: str) -> List[Entity]:
    """Parse one CADEC ``original/*.ann`` file into Entity objects.

    Each annotation line is tab-separated, e.g.::

        T1<TAB>ADR 9 19<TAB>bit drowsy

    Discontinuous ranges (``ADR 9 12;20 25``) keep only their FIRST range for
    ``start``/``end``; ``text`` keeps the whole annotated text field (including
    any embedded tabs). Blank lines, lines starting with ``#`` and malformed
    lines are skipped silently.

    Args:
        filepath: path to the ``.ann`` file.

    Returns:
        Entities in file order. A missing or unreadable file is reported on
        stdout and yields whatever was parsed before the error (``[]`` for a
        missing file).
    """
    entities = []
    filename = os.path.basename(filepath)

    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith('#'):
                    continue

                parts = line.split('\t')
                if len(parts) < 3:
                    continue
                entity_id, label_and_ranges = parts[0], parts[1]
                # The text field may itself contain tabs; keep all of it.
                text_span = '\t'.join(parts[2:])

                label, separator, ranges = label_and_ranges.partition(' ')
                if not separator:
                    continue

                # '0 17;18 20' -> ['0', '17']: only the first range is used.
                first_range = ranges.split(';')[0].split()
                if len(first_range) < 2:
                    continue
                try:
                    start, end = int(first_range[0]), int(first_range[1])
                except ValueError:
                    continue

                entities.append(Entity(entity_id, label, start, end, text_span, filename))

    except FileNotFoundError:
        print(f"ERROR: File not found: {filepath}")
    except Exception as e:
        # Best effort: report and return what was parsed so far.
        print(f"AN UNEXPECTED ERROR OCCURRED while reading {filepath}: {e}")

    return entities


# STEP 6: PROCESS ALL FILES


def process_all_files(original_dir: str) -> List["Entity"]:
    """Parse every ``*.ann`` file directly inside *original_dir*.

    Files are processed in sorted filename order, so the progress output and
    the order of the returned entities are reproducible (raw ``glob`` order
    depends on the filesystem). The directory part is glob-escaped, so paths
    containing ``[``, ``*`` or ``?`` are matched literally.

    Args:
        original_dir: directory holding the CADEC ``original`` annotations.

    Returns:
        Flat list of all parsed entities, or ``[]`` (after a warning) when no
        annotation files are found.
    """
    file_pattern = os.path.join(glob.escape(original_dir), '*.ann')
    annotation_files = sorted(glob.glob(file_pattern))
    total_files = len(annotation_files)

    print(f"\nFound {total_files} annotation files to process in {original_dir} using pattern '{file_pattern}'")
    if not annotation_files:
        print("WARNING: No .ann files found. Double-check the actual filenames in your 'original' folder (e.g., '001.ann', '002.ann').")
        return []

    all_entities = []
    for done, filepath in enumerate(annotation_files, 1):
        all_entities.extend(parse_original_file(filepath))

        # Progress report every 100 files.
        if done % 100 == 0:
            print(f"Processed {done}/{total_files} files. Total entities found so far: {len(all_entities)}")

    print("✓ Completed processing all files")
    print(f"Total entities extracted: {len(all_entities)}")

    return all_entities


# STEP 7: ANALYZE ENTITIES BY LABEL TYPE

def analyze_entities(entities: List[Entity]) -> Dict[str, Set[str]]:
    """Group distinct entity texts by label.

    Each text is lower-cased and stripped, so case and surrounding-whitespace
    variants of the same mention collapse into one entry.

    Returns:
        A ``defaultdict(set)`` mapping label -> set of normalised texts.
    """
    grouped = defaultdict(set)
    for ent in entities:
        grouped[ent.label].add(ent.text.lower().strip())
    return grouped

# STEP 8: DISPLAY RESULTS

def display_results(label_to_entities: Dict[str, Set[str]],
                    label_order: Tuple[str, ...] = ('ADR', 'Drug', 'Disease', 'Symptom')):
    """Print every distinct entity per label, followed by a count summary.

    Args:
        label_to_entities: label -> set of normalised entity texts (as built by
            ``analyze_entities``).
        label_order: labels to report, in order (any sequence of str). Listed
            labels missing from ``label_to_entities`` are shown with count 0;
            labels not listed (e.g. ``Finding``) are not printed.

    The TOTAL line sums the per-label distinct counts of the listed labels, so
    a text that occurs under two labels is counted once for each.
    """
    counts = {label: len(label_to_entities.get(label, set())) for label in label_order}

    print("\n" + "="*80)
    print("CADEC DATASET ENTITY ANALYSIS RESULTS")
    print("="*80)

    for label in label_order:
        entities = label_to_entities.get(label, set())

        print(f"\n{'─'*80}")
        print(f"LABEL TYPE: {label}")
        print(f"{'─'*80}")
        print(f"Total Distinct Entities: {counts[label]}")
        print(f"\nAll Distinct {label} Entities:")
        print(f"{'-'*80}")

        if not entities:
            print("  No distinct entities found for this label type.")
        else:
            for i, entity_text in enumerate(sorted(entities), 1):
                print(f"{i:4d}. {entity_text}")

    print(f"\n{'='*80}")
    print("SUMMARY STATISTICS")
    print(f"{'='*80}")
    for label in label_order:
        print(f"{label:10s}: {counts[label]:6d} distinct entities")
    print(f"{'-'*80}")
    total_distinct = sum(counts[label] for label in label_order)
    print(f"{'TOTAL':10s}: {total_distinct:6d} distinct entities (summed over the labels above)")
    print(f"{'='*80}\n")


# STEP 9: CREATE SUMMARY STATISTICS


def create_statistics(entities: List["Entity"],
                      label_to_entities: Dict[str, Set[str]],
                      label_order: Tuple[str, ...] = ('ADR', 'Drug', 'Disease', 'Symptom'),
                      top_k: int = 10):
    """Print mention totals, the most frequent mentions and file coverage per label.

    Args:
        entities: every parsed mention; only ``label``, ``text`` and
            ``filename`` are read.
        label_to_entities: label -> set of distinct normalised texts.
        label_order: labels to report, in order.
        top_k: how many of the most frequent mentions to list per label.
    """
    # One pass over all mentions instead of rescanning the list per label/section.
    texts_by_label = defaultdict(list)
    files_by_label = defaultdict(set)
    for entity in entities:
        texts_by_label[entity.label].append(entity.text.lower().strip())
        files_by_label[entity.label].add(entity.filename)

    print("\n" + "="*80)
    print("ADDITIONAL STATISTICS")
    print("="*80)

    print("\n1. Total Mentions (including duplicates):")
    print(f"{'-'*80}")
    for label in label_order:
        total_mentions = len(texts_by_label.get(label, ()))
        distinct_entities_count = len(label_to_entities.get(label, set()))
        avg_mentions = total_mentions / distinct_entities_count if distinct_entities_count > 0 else 0
        print(f"{label:10s}: {total_mentions:6d} mentions, "
              f"{distinct_entities_count:6d} distinct, "
              f"avg {avg_mentions:.2f} mentions per entity")

    print("\n2. Most Frequently Mentioned Entities:")
    print(f"{'-'*80}")
    for label in label_order:
        most_common = Counter(texts_by_label.get(label, ())).most_common(top_k)
        if most_common:
            print(f"\n{label}:")
            for i, (text, count) in enumerate(most_common, 1):
                print(f"  {i:2d}. {text:40s} ({count} mentions)")
        else:
            print(f"\n{label}: No mentions found.")

    print("\n3. File Distribution:")
    print(f"{'-'*80}")
    for label in label_order:
        print(f"{label:10s}: appears in {len(files_by_label.get(label, ()))} files")


# STEP 10: MAIN EXECUTION


def main():
    """Run the whole CADEC analysis: parse the annotations, group them by label, report.

    Returns:
        ``(all_entities, label_to_entities)``; ``([], {})`` when nothing could
        be parsed, so the caller's tuple unpacking still works.
    """
    banner = "="*80
    print(banner)
    print("CADEC DATASET ENTITY EXTRACTION AND ANALYSIS")
    print(banner)

    print("\nStep 1: Processing annotation files...")
    all_entities = process_all_files(ORIGINAL_DIR)
    if not all_entities:
        print("\nNo entities were extracted. Please check the BASE_PATH and dataset integrity.")
        return [], {}

    print("\nStep 2: Analyzing entities by label type...")
    label_to_entities = analyze_entities(all_entities)

    print("\nStep 3: Generating results...")
    display_results(label_to_entities)
    create_statistics(all_entities, label_to_entities)

    return all_entities, label_to_entities



# Execute the main function. Its two return values (the flat list of parsed
# entities and the label -> distinct-texts mapping) are bound at module level.
all_entities, label_to_entities = main()



Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted successfully.
Checking if directories exist...
✓ text directory found with 1250 files at: /content/drive/MyDrive/cadec/text
✓ original directory found with 1250 files at: /content/drive/MyDrive/cadec/original
✓ sct directory found with 1250 files at: /content/drive/MyDrive/cadec/sct
✓ meddra directory found with 1250 files at: /content/drive/MyDrive/cadec/meddra
CADEC DATASET ENTITY EXTRACTION AND ANALYSIS

Step 1: Processing annotation files...

Found 1250 annotation files to process in /content/drive/MyDrive/cadec/original using pattern '/content/drive/MyDrive/cadec/original/*.ann'
Processed 100/1250 files. Total entities found so far: 763
Processed 200/1250 files. Total entities found so far: 1557
Processed 300/1250 files. Total entities found so far: 2318
Processed 400/1250 files. Total entities found so far: 3082
Processed 500/1250 f

In [None]:
import os
import re
from typing import List, Tuple, Dict
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# Root of the CADEC corpus on the mounted Drive.
BASE_PATH = '/content/drive/MyDrive/cadec'
# Input posts, and the folder generated annotations are written to.
TEXT_DIR, OUTPUT_DIR = (
    os.path.join(BASE_PATH, sub_dir)
    for sub_dir in ('text', 'generated_annotations')
)


class CADECNERLabeler:
    """Label free text with CADEC entity types using a pretrained NER model.

    Pipeline: model NER -> map model labels to CADEC types (Drug / Disease /
    Symptom) -> merge adjacent spans -> drop bare severity modifiers ->
    whitespace-token BIO tags -> brat-style ``.ann`` lines.
    All offsets are character offsets into the input text, end exclusive.
    """

    def __init__(self, model_name: str = "d4data/biomedical-ner-all"):
        """Load *model_name* and wrap it in an aggregated NER pipeline (GPU 0 if CUDA is available)."""
        print(f"Loading model: {model_name}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForTokenClassification.from_pretrained(model_name)
        self.ner_pipeline = pipeline(
            "ner",
            model=self.model,
            tokenizer=self.tokenizer,
            aggregation_strategy="simple",
            device=0 if torch.cuda.is_available() else -1
        )
        print(f"✓ Model loaded\n")

    def extract_entities(self, text: str) -> List[Tuple[str, str, int, int]]:
        """Run NER on *text* and return CADEC entities as ``(text, type, start, end)``.

        After merging, each entity's text is re-sliced from *text*. It therefore
        keeps the source casing and spacing instead of the model's ``word``
        field, which may be lower-cased or built from sub-word pieces.
        """
        entities = []
        for result in self.ner_pipeline(text):
            cadec_type = self._map_to_cadec_type(result['entity_group'], result['word'])
            if cadec_type:
                entities.append((result['word'], cadec_type, result['start'], result['end']))

        entities = self._merge_adjacent_entities(entities)
        entities = [(text[start:end], ent_type, start, end) for _, ent_type, start, end in entities]
        entities = self._filter_modifiers(entities)

        return entities

    def _map_to_cadec_type(self, bio_type: str, text: str) -> str:
        """Map a model entity label onto a CADEC type.

        Matching is by case-insensitive substring of the model label. Returns
        'Drug', 'Symptom' or 'Disease', or None when the label has no CADEC
        counterpart (callers drop those). Disease-like labels whose text
        contains a common symptom word are relabelled 'Symptom'.
        """
        bio_lower = bio_type.lower()
        text_lower = text.lower()

        # 'medication': d4data/biomedical-ner-all appears to tag drug names as
        # "Medication" (confirm via model.config.id2label); without it such
        # mentions were silently dropped.
        if any(key in bio_lower for key in ('therapeutic', 'drug', 'medication')):
            return 'Drug'

        if 'symptom' in bio_lower:
            return 'Symptom'

        if 'diagnostic' in bio_lower or 'disease' in bio_lower:
            symptom_words = ('pain', 'ache', 'nausea', 'fever', 'rash', 'weakness', 'fatigue')
            if any(sw in text_lower for sw in symptom_words):
                return 'Symptom'
            return 'Disease'

        # Severity words and body parts become Symptom, presumably so they merge
        # with a neighbouring symptom ("severe" + "muscle" + "pain").
        if 'severity' in bio_lower or 'biological_structure' in bio_lower:
            return 'Symptom'

        # Everything else (lab values, durations, ages, dosages, ...) is ignored.
        return None

    def _merge_adjacent_entities(self, entities: List[Tuple[str, str, int, int]]) -> List[Tuple[str, str, int, int]]:
        """Merge spans separated by at most one character when their types are compatible.

        Types are compatible when they are equal, or both are Symptom/Disease;
        a merged Symptom/Disease span becomes Symptom. Overlapping or nested
        spans never shrink the merged end.
        """
        if not entities:
            return []

        ordered = sorted(entities, key=lambda x: x[2])
        merged = []
        i = 0

        while i < len(ordered):
            cur_text, cur_type, cur_start, cur_end = ordered[i]

            while i + 1 < len(ordered):
                nxt_text, nxt_type, nxt_start, nxt_end = ordered[i + 1]
                gap = nxt_start - cur_end
                compatible = (cur_type == nxt_type or
                              (cur_type in ('Symptom', 'Disease') and nxt_type in ('Symptom', 'Disease')))
                if gap > 1 or not compatible:
                    break

                # gap <= 0 means touching/overlapping pieces (e.g. sub-words): no separator.
                cur_text = cur_text + (' ' if gap > 0 else '') + nxt_text
                cur_type = 'Symptom' if 'Symptom' in (cur_type, nxt_type) else cur_type
                cur_end = max(cur_end, nxt_end)
                i += 1

            merged.append((cur_text, cur_type, cur_start, cur_end))
            i += 1

        return merged

    def _filter_modifiers(self, entities: List[Tuple[str, str, int, int]]) -> List[Tuple[str, str, int, int]]:
        """Drop entities whose whole text is a lone severity/temporal modifier (e.g. 'severe')."""
        modifiers = {'severe', 'mild', 'moderate', 'high', 'low', 'acute', 'chronic'}
        return [(t, ty, s, e) for t, ty, s, e in entities if t.lower().strip() not in modifiers]

    def entities_to_bio(self, text: str, entities: List[Tuple[str, str, int, int]]) -> List[Tuple[str, str]]:
        """Tag each whitespace-delimited token of *text* with a BIO label.

        A token belongs to an entity when their character ranges overlap, so a
        token like ``(Lipitor),`` is tagged even though the entity starts
        inside it. The first token of an entity gets ``B-<type>`` and later
        ones ``I-<type>``. When a token overlaps several entities, the first
        one in *entities* wins.
        """
        word_tags = []
        started = set()  # indices of entities that already emitted their B- tag

        for match in re.finditer(r'\S+', text):
            word_start, word_end = match.start(), match.end()
            tag = 'O'

            for idx, (_, entity_type, entity_start, entity_end) in enumerate(entities):
                if word_start < entity_end and word_end > entity_start:
                    prefix = 'I' if idx in started else 'B'
                    started.add(idx)
                    tag = f"{prefix}-{entity_type}"
                    break

            word_tags.append((match.group(), tag))

        return word_tags

    def bio_to_cadec_format(self, word_tag_pairs: List[Tuple[str, str]], text: str) -> List[Dict]:
        """Convert BIO-tagged tokens back into CADEC/brat annotation dicts.

        Each token is located in *text* with a left-to-right cursor, so repeated
        words map to their own occurrence rather than the first one. An entity
        spans from its first token's start to its last token's end. Its text is
        the exact slice of *text*, with tab/CR/LF replaced by spaces so the
        tab-separated ``.ann`` line stays valid. An ``I-`` tag that does not
        continue an entity starts a new one.

        Returns:
            Dicts with keys ``id`` (T1, T2, ...), ``type``, ``start``, ``end``, ``text``.
        """
        # Character span of every token (None when it cannot be located).
        lowered = text.lower()
        spans = []
        cursor = 0
        for word, _ in word_tag_pairs:
            pos = text.find(word, cursor)
            if pos == -1:
                # Tolerate case-normalised tokens, as the previous lookup did.
                pos = lowered.find(word.lower(), cursor)
            if pos == -1:
                spans.append(None)
            else:
                spans.append((pos, pos + len(word)))
                cursor = pos + len(word)

        annotations = []
        i = 0
        while i < len(word_tag_pairs):
            tag = word_tag_pairs[i][1]
            if not tag.startswith(('B-', 'I-')) or spans[i] is None:
                i += 1
                continue

            entity_type = tag[2:]
            j = i + 1
            while (j < len(word_tag_pairs)
                   and word_tag_pairs[j][1] == f"I-{entity_type}"
                   and spans[j] is not None):
                j += 1

            start = spans[i][0]
            end = spans[j - 1][1]
            annotations.append({
                'id': f'T{len(annotations) + 1}',
                'type': entity_type,
                'start': start,
                'end': end,
                'text': re.sub(r'[\t\r\n]', ' ', text[start:end])
            })
            i = j

        return annotations

    def format_bio(self, word_tag_pairs: List[Tuple[str, str]]) -> str:
        """Render word/tag pairs as ``word TAG`` lines."""
        return '\n'.join([f"{word} {tag}" for word, tag in word_tag_pairs])

    def format_cadec(self, annotations: List[Dict]) -> str:
        """Render annotation dicts as ``ID<TAB>Type start end<TAB>text`` lines."""
        lines = []
        for ann in annotations:
            line = f"{ann['id']}\t{ann['type']} {ann['start']} {ann['end']}\t{ann['text']}"
            lines.append(line)
        return '\n'.join(lines)

    def process_text(self, text: str, verbose: bool = True) -> Tuple[str, str, List[Dict]]:
        """Run the full pipeline on *text*.

        Returns:
            ``(bio_output, cadec_output, annotations)``: the BIO text, the
            ``.ann`` text, and the annotation dicts. Prints each stage when
            *verbose*.
        """
        if verbose:
            print("="*80)
            print("PROCESSING TEXT")
            print("="*80)

        entities = self.extract_entities(text)

        if verbose:
            print(f"\nExtracted {len(entities)} entities:")
            for ent_text, ent_type, start, end in entities:
                print(f"  [{start:3d}-{end:3d}] {ent_type:10s} '{ent_text}'")

        word_tag_pairs = self.entities_to_bio(text, entities)

        if verbose:
            print(f"\n\nBIO Format ({len(word_tag_pairs)} tokens):")
            for word, tag in word_tag_pairs:
                print(f"  {word:20s} {tag}")

        bio_output = self.format_bio(word_tag_pairs)
        annotations = self.bio_to_cadec_format(word_tag_pairs, text)
        cadec_output = self.format_cadec(annotations)

        if verbose:
            print(f"\n\nCADEC Format ({len(annotations)} entities):")
            print(cadec_output)
            print("\n" + "="*80)

        return bio_output, cadec_output, annotations


def process_single_file(file_path: str, output_dir: str = None, labeler=None):
    """Label one text file, print the full trace, and optionally save the outputs.

    Args:
        file_path: UTF-8 text file to label.
        output_dir: when given, ``<basename>.bio`` and ``<basename>.ann`` (e.g.
            ``001.txt.ann``) are written into this directory (created if needed).
        labeler: optional ready-made ``CADECNERLabeler`` to reuse; when omitted a
            new one is built, which loads the model.

    The file is read with ``newline=''`` and only trailing whitespace is
    removed, so character offsets in the ``.ann`` output refer to the raw file.
    """
    if labeler is None:
        labeler = CADECNERLabeler()

    with open(file_path, 'r', encoding='utf-8', newline='') as f:
        # rstrip only: stripping leading whitespace would shift every offset.
        text = f.read().rstrip()

    print(f"File: {file_path}")
    print(f"Text: {text[:100]}...\n")

    bio_output, cadec_output, annotations = labeler.process_text(text, verbose=True)

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        base_name = os.path.basename(file_path)

        # Explicit UTF-8 so non-ASCII entity text is written identically on every platform.
        with open(os.path.join(output_dir, f"{base_name}.bio"), 'w', encoding='utf-8') as f:
            f.write(bio_output)

        with open(os.path.join(output_dir, f"{base_name}.ann"), 'w', encoding='utf-8') as f:
            f.write(cadec_output)

        print(f"\n✓ Saved to {output_dir}")


def process_directory(text_dir: str, output_dir: str, max_files: int = None, labeler=None):
    """Label every ``*.txt`` file in *text_dir* (sorted, optionally capped at *max_files*).

    For each input file it writes ``output_dir/bio/<name>.txt`` (BIO tokens) and
    ``output_dir/ann/<stem>.ann`` (brat lines). The latter is the path
    ``CADECEvaluator.evaluate_single_file`` reads predictions from.

    Args:
        text_dir: directory of UTF-8 text files.
        output_dir: output root; ``bio/`` and ``ann/`` are created inside it.
        max_files: if truthy, only the first *max_files* files (sorted) are processed.
        labeler: optional ready-made ``CADECNERLabeler`` to reuse; when omitted a
            new one is built, which loads the model.

    Files are read with ``newline=''`` and only trailing whitespace is removed,
    so offsets in the ``.ann`` output refer to the raw files.
    """
    if labeler is None:
        labeler = CADECNERLabeler()

    bio_dir = os.path.join(output_dir, 'bio')
    ann_dir = os.path.join(output_dir, 'ann')
    os.makedirs(bio_dir, exist_ok=True)
    os.makedirs(ann_dir, exist_ok=True)

    files = sorted(f for f in os.listdir(text_dir) if f.endswith('.txt'))
    if max_files:
        files = files[:max_files]

    print(f"Processing {len(files)} files...\n")

    for idx, filename in enumerate(files, 1):
        print(f"[{idx}/{len(files)}] {filename}")

        with open(os.path.join(text_dir, filename), 'r', encoding='utf-8', newline='') as f:
            # rstrip only: stripping leading whitespace would shift every offset.
            text = f.read().rstrip()

        bio_output, cadec_output, annotations = labeler.process_text(text, verbose=False)

        with open(os.path.join(bio_dir, filename), 'w', encoding='utf-8') as f:
            f.write(bio_output)

        stem = os.path.splitext(filename)[0]
        with open(os.path.join(ann_dir, stem + '.ann'), 'w', encoding='utf-8') as f:
            f.write(cadec_output)

        print(f"  → {len(annotations)} entities\n")


if __name__ == "__main__":
    # Mount Google Drive when running in Colab. Outside Colab google.colab is
    # not importable; the demo below only uses the in-memory sample text.
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=False)
        print("✓ Drive mounted\n")
    except ImportError:
        pass
    except Exception as exc:
        # Report mount problems instead of hiding them (a bare `except:` would
        # also swallow KeyboardInterrupt); the demo itself does not need Drive.
        print(f"Drive not mounted: {exc}\n")

    sample_text = """I have been taking Lipitor for my high cholesterol for about 3 months.
Recently I have been experiencing severe muscle pain and weakness in my legs.
I also developed a rash on my arms. My doctor said these could be side effects."""

    print("="*80)
    print("EXAMPLE: Complete Pipeline Demonstration")
    print("="*80)

    labeler = CADECNERLabeler()

    print("\n" + "="*80)
    print("STEP 2A: Generate BIO Labels")
    print("="*80)

    entities = labeler.extract_entities(sample_text)
    print(f"\nExtracted {len(entities)} entities:")
    for ent_text, ent_type, start, end in entities:
        print(f"  [{start:3d}-{end:3d}] {ent_type:10s} '{ent_text}'")

    word_tag_pairs = labeler.entities_to_bio(sample_text, entities)
    bio_output = labeler.format_bio(word_tag_pairs)

    print(f"\n\nBIO Format Output ({len(word_tag_pairs)} tokens):")
    print("-" * 40)
    print(bio_output)
    print("-" * 40)

    print("\n\n" + "="*80)
    print("STEP 2B: Convert BIO Labels to CADEC Original Format")
    print("="*80)

    print("\nInput: BIO format word-tag pairs")
    print("Processing: Identifying entity boundaries (B- tags start entities, I- tags continue them)")

    annotations = labeler.bio_to_cadec_format(word_tag_pairs, sample_text)
    cadec_output = labeler.format_cadec(annotations)

    print(f"\n\nOutput: CADEC Original Format ({len(annotations)} entities)")
    print("-" * 80)
    print("ID\tType Start End\tText")
    print("-" * 80)
    print(cadec_output)
    print("-" * 80)

    # Character span of each BIO token: entities_to_bio tokenizes with the same
    # r'\S+' pattern, so the i-th match corresponds to the i-th word_tag_pair.
    token_spans = [(m.start(), m.end()) for m in re.finditer(r'\S+', sample_text)]

    print("\n\nDetailed Conversion Explanation:")
    print("-" * 80)
    for ann in annotations:
        print(f"\nEntity {ann['id']}:")
        print(f"  Type: {ann['type']}")
        print(f"  Character Range: {ann['start']}-{ann['end']}")
        print(f"  Text: '{ann['text']}'")

        # Select tokens by character overlap with this entity, not by word text,
        # so same-typed tokens elsewhere (e.g. a repeated word) are not shown.
        entity_tokens = [(word, tag)
                         for (word, tag), (tok_start, tok_end) in zip(word_tag_pairs, token_spans)
                         if tok_start < ann['end'] and tok_end > ann['start']]
        if entity_tokens:
            print(f"  BIO Tags:")
            for word, tag in entity_tokens:
                print(f"    {word:15s} {tag}")

    print("\n\n" + "="*80)
    print("Summary:")
    print("="*80)
    print(f"✓ Step 2a: Generated BIO format with {len(word_tag_pairs)} word-tag pairs")
    print(f"✓ Step 2b: Converted to CADEC format with {len(annotations)} entities")
    print(f"✓ Multi-word entities properly handled (e.g., 'severe muscle pain')")
    print(f"✓ Character offsets calculated correctly")

    print("\n" + "="*80)
    print("To process CADEC files:")
    print("  process_single_file(os.path.join(TEXT_DIR, 'file.txt'), OUTPUT_DIR)")
    print("  process_directory(TEXT_DIR, OUTPUT_DIR, max_files=10)")
    print("="*80)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✓ Drive mounted

EXAMPLE: Complete Pipeline Demonstration
Loading model: d4data/biomedical-ner-all


Device set to use cuda:0


✓ Model loaded


STEP 2A: Generate BIO Labels

Extracted 5 entities:
  [ 19- 26] Drug       'lipitor'
  [ 39- 50] Disease    'cholesterol'
  [106-124] Symptom    'severe muscle pain'
  [129-137] Symptom    'weakness'
  [170-174] Symptom    'rash'


BIO Format Output (42 tokens):
----------------------------------------
I O
have O
been O
taking O
Lipitor B-Drug
for O
my O
high O
cholesterol B-Disease
for O
about O
3 O
months. O
Recently O
I O
have O
been O
experiencing O
severe B-Symptom
muscle I-Symptom
pain I-Symptom
and O
weakness B-Symptom
in O
my O
legs. O
I O
also O
developed O
a O
rash B-Symptom
on O
my O
arms. O
My O
doctor O
said O
these O
could O
be O
side O
effects. O
----------------------------------------


STEP 2B: Convert BIO Labels to CADEC Original Format

Input: BIO format word-tag pairs
Processing: Identifying entity boundaries (B- tags start entities, I- tags continue them)


Output: CADEC Original Format (5 entities)
------------------------------------------------

In [None]:
import os
from typing import Dict, List, Tuple, Set
from dataclasses import dataclass
from collections import defaultdict

# Root of the CADEC corpus on the mounted Drive.
BASE_PATH = '/content/drive/MyDrive/cadec'
# Raw posts, gold annotations, and the generated (predicted) annotations.
TEXT_DIR, ORIGINAL_DIR, OUTPUT_DIR = (
    os.path.join(BASE_PATH, sub_dir)
    for sub_dir in ('text', 'original', 'generated_annotations')
)


@dataclass
class Entity:
    """Represents a single entity annotation.

    Equality and hashing use only ``(entity_type, start, end)``: two
    annotations match when they have the same type and exact character span,
    regardless of id or text. This is the exact-match criterion used by the
    evaluator's set operations.
    """
    entity_id: str
    entity_type: str
    start: int
    end: int
    text: str

    def __hash__(self):
        return hash((self.entity_type, self.start, self.end))

    def __eq__(self, other):
        # Let Python fall back for non-Entity operands instead of raising AttributeError.
        if not isinstance(other, Entity):
            return NotImplemented
        return (self.entity_type == other.entity_type and
                self.start == other.start and
                self.end == other.end)


class CADECEvaluator:


    def __init__(self):
        self.entity_types = ['Drug', 'Disease', 'Symptom', 'ADR']

    def parse_annotation_file(self, file_path: str) -> List[Entity]:
        """Parse a brat/CADEC ``.ann`` file into Entity objects.

        Expected line shape (tab-separated)::

            T1<TAB>Drug 145 158<TAB>Lipitor

        Discontinuous ranges (``ADR 10 15;20 25``) use their FIRST range for
        ``start``/``end``, the same convention as ``parse_original_file``.
        Previously such lines failed on ``int('15;20')`` and were silently
        dropped from the evaluation. A range with a single number gets
        ``end == start``. The text field is kept whole, even if it contains
        tabs. Blank lines, ``#`` lines and malformed lines are skipped.
        """
        entities = []

        with open(file_path, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith('#'):
                    continue

                parts = line.split('\t')
                if len(parts) < 3:
                    continue

                type_and_ranges = parts[1].split(None, 1)
                if len(type_and_ranges) < 2:
                    continue
                entity_type, ranges = type_and_ranges

                first_range = ranges.split(';')[0].split()
                try:
                    start = int(first_range[0])
                    end = int(first_range[1]) if len(first_range) > 1 else start
                except (ValueError, IndexError):
                    continue

                text = '\t'.join(parts[2:])
                entities.append(Entity(parts[0], entity_type, start, end, text))

        return entities

    def calculate_metrics(self, true_entities: List[Entity],
                         pred_entities: List[Entity]) -> Dict:
        """Exact-match precision / recall / F1, overall and per entity type.

        Matching relies on Entity equality (same type and span) via set
        operations, so duplicated annotations count once in TP/FP/FN. The
        'true_entities', 'pred_entities' and per-type 'support' counts come
        from the raw lists and do include duplicates. Each score is 0.0 when
        its denominator is 0. Per-type rows cover ``self.entity_types`` only.
        """
        def score(gold, predicted):
            gold_set, pred_set = set(gold), set(predicted)
            tp = len(gold_set & pred_set)
            fp = len(pred_set - gold_set)
            fn = len(gold_set - pred_set)
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
            return {
                'precision': precision,
                'recall': recall,
                'f1': f1,
                'true_positive': tp,
                'false_positive': fp,
                'false_negative': fn,
            }

        metrics = score(true_entities, pred_entities)
        metrics['true_entities'] = len(true_entities)
        metrics['pred_entities'] = len(pred_entities)

        per_type = {}
        for entity_type in self.entity_types:
            gold_of_type = [e for e in true_entities if e.entity_type == entity_type]
            pred_of_type = [e for e in pred_entities if e.entity_type == entity_type]
            row = score(gold_of_type, pred_of_type)
            row['support'] = len(gold_of_type)
            per_type[entity_type] = row

        metrics['per_type'] = per_type
        return metrics

    def analyze_errors(self, true_entities: List[Entity],
                       pred_entities: List[Entity]) -> Dict:
        """
        Detailed error analysis showing what went wrong.

        Returns a dict with:
          * 'false_positives' / 'false_negatives': unmatched predicted / gold
            entities (exact type+span matching). They are sorted by
            (start, end, type) so reports do not depend on set/hash order.
          * 'type_confusion': {gold_type: {pred_type: count}} for spans that
            match exactly but carry a different type.
          * 'boundary_errors': same-type FP/FN pairs whose spans overlap. Spans
            are end-exclusive (as produced by the labeler), so touching spans
            such as [0, 5) and [5, 9) do not overlap.
        """
        def position(entity):
            return (entity.start, entity.end, entity.entity_type)

        true_set = set(true_entities)
        pred_set = set(pred_entities)

        false_positives = sorted(pred_set - true_set, key=position)
        false_negatives = sorted(true_set - pred_set, key=position)

        type_confusion = defaultdict(lambda: defaultdict(int))
        boundary_errors = []

        for fp in false_positives:
            for fn in false_negatives:
                if fp.start == fn.start and fp.end == fn.end:
                    # Same span; the types must differ, otherwise it would be a match.
                    type_confusion[fn.entity_type][fp.entity_type] += 1
                elif fp.entity_type == fn.entity_type and fp.start < fn.end and fp.end > fn.start:
                    boundary_errors.append({
                        'true': fn,
                        'pred': fp,
                        'overlap': True
                    })

        return {
            'false_positives': false_positives,
            'false_negatives': false_negatives,
            'type_confusion': dict(type_confusion),
            'boundary_errors': boundary_errors
        }

    def evaluate_single_file(self, text_file: str, verbose: bool = True) -> Dict:
        """
        Evaluate predictions for a single file against ground truth.

        Gold annotations are read from ``ORIGINAL_DIR/<id>.ann`` and predictions
        from ``OUTPUT_DIR/ann/<id>.ann``, where ``<id>`` is the basename with
        '.txt' removed. Returns None when either file is missing (warning if
        *verbose*); otherwise a dict with 'metrics', 'errors' and 'file'.
        """
        base_name = os.path.basename(text_file)
        stem = base_name.replace('.txt', '')
        gold_path = os.path.join(ORIGINAL_DIR, stem + '.ann')
        pred_path = os.path.join(OUTPUT_DIR, 'ann', stem + '.ann')

        if not os.path.exists(gold_path):
            if verbose:
                print(f"Warning: Ground truth not found for {base_name}")
            return None

        if not os.path.exists(pred_path):
            if verbose:
                print(f"Warning: Prediction not found for {base_name}")
            return None

        gold = self.parse_annotation_file(gold_path)
        predicted = self.parse_annotation_file(pred_path)

        metrics = self.calculate_metrics(gold, predicted)
        errors = self.analyze_errors(gold, predicted)

        if verbose:
            self.print_evaluation(base_name, metrics, errors, gold, predicted)

        return {'metrics': metrics, 'errors': errors, 'file': base_name}

    def print_evaluation(self, filename: str, metrics: Dict, errors: Dict,
                        true_entities: List[Entity], pred_entities: List[Entity]):
        """Print one file's evaluation report.

        The report contains overall scores, a per-type table, up to five false
        positives and five false negatives, and type-confusion counts. All
        numbers come from *metrics* and *errors*; *true_entities* and
        *pred_entities* are accepted for interface compatibility but not read.
        """
        rule = "-" * 80

        print("\n" + "="*80)
        print(f"EVALUATION: {filename}")
        print("="*80)

        print("\nOVERALL METRICS:")
        print(rule)
        print(f"Precision: {metrics['precision']:.3f}")
        print(f"Recall:    {metrics['recall']:.3f}")
        print(f"F1 Score:  {metrics['f1']:.3f}")
        print(f"\nTrue Positives:  {metrics['true_positive']}")
        print(f"False Positives: {metrics['false_positive']}")
        print(f"False Negatives: {metrics['false_negative']}")
        print(f"\nGround Truth Entities: {metrics['true_entities']}")
        print(f"Predicted Entities:    {metrics['pred_entities']}")

        print("\n\nPER-TYPE METRICS:")
        print(rule)
        print(f"{'Type':<12} {'Precision':<12} {'Recall':<12} {'F1':<12} {'Support':<10}")
        print(rule)

        per_type = metrics['per_type']
        for entity_type in (t for t in self.entity_types if t in per_type):
            row = per_type[entity_type]
            print(f"{entity_type:<12} {row['precision']:<12.3f} {row['recall']:<12.3f} "
                  f"{row['f1']:<12.3f} {row['support']:<10}")

        # FP and FN sections share one layout: heading, rule, first five, overflow note.
        sections = (
            ('false_positives', "\n\nFALSE POSITIVES (Predicted but not in ground truth):"),
            ('false_negatives', "\n\nFALSE NEGATIVES (In ground truth but not predicted):"),
        )
        for key, heading in sections:
            items = errors[key]
            if not items:
                continue
            print(heading)
            print(rule)
            for rank, ent in enumerate(items[:5], 1):
                print(f"{rank}. [{ent.start}-{ent.end}] {ent.entity_type}: '{ent.text}'")
            hidden = len(items) - 5
            if hidden > 0:
                print(f"... and {hidden} more")

        confusion = errors['type_confusion']
        if confusion:
            print("\n\nTYPE CONFUSION (Same span, wrong type):")
            print(rule)
            for true_type, pred_counts in confusion.items():
                for pred_type, count in pred_counts.items():
                    print(f"  {true_type} → {pred_type}: {count} times")

    def evaluate_directory(self, text_dir: str, max_files: int = None) -> Dict:
        """Evaluate every ``.txt`` document in ``text_dir`` and report totals.

        Files are processed in sorted-name order, optionally truncated to the
        first ``max_files`` (a falsy ``max_files`` means "all files"). Each
        file is scored with ``self.evaluate_single_file(path, verbose=False)``;
        files for which that returns a falsy value are skipped. TP/FP/FN counts
        are summed across files and precision/recall/F1 are recomputed from
        the sums (micro-averaging), both overall and per entity type.

        Returns:
            Dict with keys ``overall``, ``per_type``, ``num_files`` and
            ``individual_results``.
        """
        def prf(tp, fp, fn):
            # Precision/recall/F1 from raw counts; 0.0 for empty denominators.
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
            f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
            return precision, recall, f1

        names = sorted(name for name in os.listdir(text_dir) if name.endswith('.txt'))
        if max_files:
            names = names[:max_files]
        total = len(names)

        print(f"\n{'='*80}")
        print(f"EVALUATING {total} FILES")
        print(f"{'='*80}\n")

        collected = []
        totals = {'tp': 0, 'fp': 0, 'fn': 0}
        by_type = {etype: {'tp': 0, 'fp': 0, 'fn': 0} for etype in self.entity_types}

        for position, name in enumerate(names, start=1):
            outcome = self.evaluate_single_file(os.path.join(text_dir, name), verbose=False)
            if not outcome:
                continue

            collected.append(outcome)
            scores = outcome['metrics']
            totals['tp'] += scores['true_positive']
            totals['fp'] += scores['false_positive']
            totals['fn'] += scores['false_negative']

            for etype in self.entity_types:
                if etype not in scores['per_type']:
                    continue
                counts = scores['per_type'][etype]
                bucket = by_type[etype]
                bucket['tp'] += counts['true_positive']
                bucket['fp'] += counts['false_positive']
                bucket['fn'] += counts['false_negative']

            print(f"[{position}/{total}] {name}: "
                  f"P={scores['precision']:.3f} R={scores['recall']:.3f} F1={scores['f1']:.3f}")

        overall_precision, overall_recall, overall_f1 = prf(totals['tp'], totals['fp'], totals['fn'])

        per_type_metrics = {}
        for etype in self.entity_types:
            bucket = by_type[etype]
            p, r, f = prf(bucket['tp'], bucket['fp'], bucket['fn'])
            per_type_metrics[etype] = {
                'precision': p,
                'recall': r,
                'f1': f,
                'tp': bucket['tp'],
                'fp': bucket['fp'],
                'fn': bucket['fn'],
            }

        banner = "=" * 80
        rule = "-" * 80

        print("\n" + banner)
        print("AGGREGATE RESULTS")
        print(banner)
        print(f"\nFiles Evaluated: {len(collected)}")
        print("\nOVERALL METRICS:")
        print(f"  Precision: {overall_precision:.3f}")
        print(f"  Recall:    {overall_recall:.3f}")
        print(f"  F1 Score:  {overall_f1:.3f}")
        print(f"\n  True Positives:  {totals['tp']}")
        print(f"  False Positives: {totals['fp']}")
        print(f"  False Negatives: {totals['fn']}")

        print("\n\nPER-TYPE AGGREGATE METRICS:")
        print(rule)
        print(f"{'Type':<12} {'Precision':<12} {'Recall':<12} {'F1':<12} {'TP':<8} {'FP':<8} {'FN':<8}")
        print(rule)

        for etype in self.entity_types:
            row = per_type_metrics[etype]
            print(f"{etype:<12} {row['precision']:<12.3f} {row['recall']:<12.3f} "
                  f"{row['f1']:<12.3f} {row['tp']:<8} {row['fp']:<8} {row['fn']:<8}")

        return {
            'overall': {
                'precision': overall_precision,
                'recall': overall_recall,
                'f1': overall_f1,
                'true_positive': totals['tp'],
                'false_positive': totals['fp'],
                'false_negative': totals['fn'],
            },
            'per_type': per_type_metrics,
            'num_files': len(collected),
            'individual_results': collected,
        }


def _map_bio_to_cadec_type(entity_group: str):
    """Map a biomedical-NER ``entity_group`` label onto a CADEC label, or None.

    Matching is case-insensitive and substring-based. ``medication`` is
    included so drug mentions are not dropped; the d4data/biomedical-ner-all
    label set is expected to contain a ``Medication`` label (TODO confirm via
    ``model.config.id2label``).
    """
    label = entity_group.lower()
    if 'therapeutic' in label or 'drug' in label or 'medication' in label:
        return 'Drug'
    if 'sign_symptom' in label or 'symptom' in label:
        return 'Symptom'
    if 'diagnostic' in label or 'disease' in label:
        return 'Disease'
    if 'severity' in label or 'biological_structure' in label:
        # Presumably folded into Symptom so that, e.g., a body part can merge
        # with an adjacent sign/symptom span (merging requires equal types).
        return 'Symptom'
    return None


def _merge_adjacent_spans(entities, text: str):
    """Merge same-type spans separated by at most one character.

    ``entities`` are ``(surface, type, start, end)`` tuples sorted by start.
    The merged surface text is sliced from ``text`` so it always agrees with
    the offsets. Joining token strings with ``' '`` would add a space that is
    not in the document when the gap is 0, and would replace the real
    separator (e.g. ``-`` or a newline) when the gap is 1. For overlapping or
    nested spans the end only ever moves forwards.
    """
    merged = []
    for surface, ent_type, start, end in entities:
        if merged:
            _, prev_type, prev_start, prev_end = merged[-1]
            if ent_type == prev_type and start - prev_end <= 1:
                new_end = max(prev_end, end)
                merged[-1] = (text[prev_start:new_end], prev_type, prev_start, new_end)
                continue
        merged.append((surface, ent_type, start, end))
    return merged


def generate_and_evaluate(text_dir: str, output_dir: str, max_files: int = 10):
    """Generate NER predictions for CADEC posts, write them as .ann, then evaluate.

    Step 1 runs ``d4data/biomedical-ner-all`` over the first ``max_files``
    ``.txt`` files of ``text_dir`` (sorted by name). It maps model labels onto
    CADEC types, merges adjacent same-type spans, and writes one brat-style
    ``.ann`` file per post to ``<output_dir>/ann``. Step 2 runs
    ``CADECEvaluator().evaluate_directory`` over the same files.

    Returns:
        The aggregate results dict produced by ``evaluate_directory``.
    """
    from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
    import torch

    print("="*80)
    print("STEP 1: GENERATING PREDICTIONS")
    print("="*80)

    print("\nLoading NER model...")
    tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
    model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
    ner_pipeline = pipeline(
        "ner",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
        device=0 if torch.cuda.is_available() else -1
    )
    print("✓ Model loaded\n")

    os.makedirs(os.path.join(output_dir, 'ann'), exist_ok=True)

    text_files = sorted([f for f in os.listdir(text_dir) if f.endswith('.txt')])[:max_files]

    print(f"Processing {len(text_files)} files...\n")

    for idx, filename in enumerate(text_files, 1):
        print(f"[{idx}/{len(text_files)}] {filename}")

        with open(os.path.join(text_dir, filename), 'r', encoding='utf-8') as f:
            # Only trailing whitespace is dropped: stripping leading whitespace
            # would shift every offset relative to the gold .ann files.
            text = f.read().rstrip()

        entities = []
        for entity in ner_pipeline(text):
            cadec_type = _map_bio_to_cadec_type(entity['entity_group'])
            if cadec_type:
                start, end = entity['start'], entity['end']
                # Use the document's own characters; the pipeline's 'word'
                # may differ from the source text (e.g. casing/tokenization).
                entities.append((text[start:end], cadec_type, start, end))

        entities.sort(key=lambda ent: ent[2])
        merged = _merge_adjacent_spans(entities, text)

        output_lines = []
        for idx_ent, (ent_text, ent_type, start, end) in enumerate(merged, 1):
            # Collapse internal whitespace/newlines so each annotation stays on
            # a single .ann line.
            surface = ' '.join(ent_text.split())
            output_lines.append(f"T{idx_ent}\t{ent_type} {start} {end}\t{surface}")

        output_file = os.path.join(output_dir, 'ann', filename.replace('.txt', '.ann'))
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write('\n'.join(output_lines))

        print(f"  → Generated {len(merged)} entities")

    print("\n" + "="*80)
    print("STEP 2: EVALUATING PREDICTIONS")
    print("="*80 + "\n")

    # NOTE(review): the evaluator is built with its defaults and is not given
    # output_dir; confirm it reads predictions from the same directory.
    evaluator = CADECEvaluator()
    results = evaluator.evaluate_directory(text_dir, max_files=max_files)

    return results


def detailed_error_analysis(text_dir: str, output_dir: str, num_examples: int = 2):
    """Print the raw text and a verbose evaluation for the first few posts.

    The posts are the first ``num_examples`` ``.txt`` files of ``text_dir`` in
    sorted-name order. ``output_dir`` is accepted for signature compatibility
    but is not read here; the evaluator is constructed with its defaults.
    """
    evaluator = CADECEvaluator()
    banner = "=" * 80

    candidates = sorted(name for name in os.listdir(text_dir) if name.endswith('.txt'))
    for name in candidates[:num_examples]:
        path = os.path.join(text_dir, name)

        with open(path, 'r', encoding='utf-8') as handle:
            body = handle.read().strip()

        print("\n" + banner)
        print(f"FILE: {name}")
        print(banner)
        print(f"\nText:\n{body}\n")

        # verbose=True makes the evaluator print its own detailed report.
        if evaluator.evaluate_single_file(path, verbose=True):
            print("\n" + banner)


if __name__ == "__main__":
    # Best-effort Drive mount. Outside Colab the import fails and the pipeline
    # simply continues with the configured paths. Unlike a bare ``except:``,
    # this no longer swallows KeyboardInterrupt/SystemExit.
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=False)
        print("✓ Drive mounted\n")
    except ImportError:
        pass
    except Exception as exc:
        print(f"Drive not mounted ({exc}); continuing.\n")

    print("="*80)
    print("CADEC NER EVALUATION PIPELINE")
    print("="*80)
    print("\nThis will:")
    print("1. Generate predictions using biomedical NER model")
    print("2. Compare predictions against ground truth (original directory)")
    print("3. Compute Exact Match F1 scores")
    # Matching is per entity span (type + start + end), not per token.
    print("\nMetric: Entity-level (span) Exact Match F1 Score")
    print("Requirement: Entity type, start, and end must match exactly")
    print("\n" + "="*80 + "\n")

    results = generate_and_evaluate(TEXT_DIR, OUTPUT_DIR, max_files=5)

    print("\n\n" + "="*80)
    print("DETAILED ERROR ANALYSIS")
    print("="*80)
    detailed_error_analysis(TEXT_DIR, OUTPUT_DIR, num_examples=2)

    print("\n" + "="*80)
    print("EVALUATION COMPLETE")
    print("="*80)
    print("\nTo evaluate more files:")
    print("  results = generate_and_evaluate(TEXT_DIR, OUTPUT_DIR, max_files=20)")
    print("="*80)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✓ Drive mounted

CADEC NER EVALUATION PIPELINE

This will:
1. Generate predictions using biomedical NER model
2. Compare predictions against ground truth (original directory)
3. Compute Exact Match F1 scores

Metric: Token-level Exact Match F1 Score
Requirement: Entity type, start, and end must match exactly


STEP 1: GENERATING PREDICTIONS

Loading NER model...


Device set to use cuda:0


✓ Model loaded

Processing 5 files...

[1/5] ARTHROTEC.1.txt
  → Generated 5 entities
[2/5] ARTHROTEC.10.txt
  → Generated 3 entities
[3/5] ARTHROTEC.100.txt
  → Generated 6 entities
[4/5] ARTHROTEC.101.txt
  → Generated 4 entities
[5/5] ARTHROTEC.102.txt
  → Generated 1 entities

STEP 2: EVALUATING PREDICTIONS


EVALUATING 5 FILES

[1/5] ARTHROTEC.1.txt: P=0.400 R=0.250 F1=0.308
[2/5] ARTHROTEC.10.txt: P=0.333 R=0.500 F1=0.400
[3/5] ARTHROTEC.100.txt: P=0.167 R=0.200 F1=0.182
[4/5] ARTHROTEC.101.txt: P=0.000 R=0.000 F1=0.000
[5/5] ARTHROTEC.102.txt: P=0.000 R=0.000 F1=0.000

AGGREGATE RESULTS

Files Evaluated: 5

OVERALL METRICS:
  Precision: 0.211
  Recall:    0.190
  F1 Score:  0.200

  True Positives:  4
  False Positives: 15
  False Negatives: 17


PER-TYPE AGGREGATE METRICS:
--------------------------------------------------------------------------------
Type         Precision    Recall       F1           TP       FP       FN      
------------------------------------------------

These results show that a general-purpose biomedical NER model performs poorly on this specialized task without domain-specific fine-tuning.

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from typing import List, Dict, Tuple, Set
from dataclasses import dataclass
import os
import random
import re
import json


BASE_PATH = '/content/drive/MyDrive/cadec'

# Corpus sub-directories: raw forum posts, the 'original' annotation files
# (used as ADR ground truth below), and the MedDRA-coded annotations.
TEXT_DIR, ORIGINAL_DIR, MEDDRA_DIR = (
    os.path.join(BASE_PATH, sub) for sub in ('text', 'original', 'meddra')
)


# DATA STRUCTURES


@dataclass(frozen=True)
class ADREntity:
    """Immutable character-offset span of an ADR mention.

    Equality and hashing depend only on the ``(start, end)`` offsets;
    ``text`` is carried for display and ignored when comparing. A predicted
    span therefore matches a gold span whatever its surface text.
    """

    start: int
    end: int
    text: str

    def __hash__(self):
        span = (self.start, self.end)
        return hash(span)

    def __eq__(self, other):
        if isinstance(other, ADREntity):
            return (self.start, self.end) == (other.start, other.end)
        return NotImplemented

    def __repr__(self):
        return f"ADR([{self.start}-{self.end}], '{self.text}')"


# BIOMEDICAL NER MODEL & CUSTOM ADR CLASSIFICATION LOGIC


class BiomedicalNER:
    """Heuristic ADR tagger built on a general biomedical token-classification model.

    The steps are:
    1. Run the Hugging Face NER pipeline.
    2. Map its labels onto CADEC-style types.
    3. Merge same-type neighbouring spans.
    4. Report as ADR only those ``Symptom`` spans whose surrounding context
       suggests a drug reaction.
    """

    # Lower-cased substrings that turn a disease/disorder span into a Symptom.
    SYMPTOM_KEYWORDS = (
        'pain', 'ache', 'nausea', 'vomiting', 'fever', 'rash',
        'weakness', 'fatigue', 'tired', 'dizziness', 'headache',
        'migraine', 'cramps', 'insomnia', 'diarrhea', 'constipation',
        'loss of appetite', 'irritability', 'mood swings', 'anxiety',
        'depression', 'difficulty', 'problems with',
    )

    # Phrases near a symptom that suggest it is a drug reaction.
    ADR_INDICATORS = (
        'side effect', 'side-effect', 'adverse reaction', 'adverse event',
        'reaction to', 'caused by', 'due to', 'from taking', 'because of',
        'after taking', 'while on', 'while taking', 'on this medication',
        'withdrawal', 'stopped taking', 'came off', 'discontinued',
        'experienced', 'suffered from', 'developed', 'got worse',
        'gave me', 'made me feel', 'triggered by', 'drug induced',
        'effect of', 'onset of', 'medication related',
    )

    DRUG_KEYWORDS = ('drug', 'medication', 'pill', 'tablet', 'rx', 'prescribed', 'meds')

    # Whole-word match. A plain substring test for "on" hits almost any text
    # (including "medication" itself), which made the check meaningless.
    _TAKING_OR_ON = re.compile(r'\b(?:taking|on)\b')

    def __init__(self):
        self.ner_pipeline = None
        self.model_name = "d4data/biomedical-ner-all"

    def load_model(self):
        """Load the tokenizer, model and NER pipeline once (no-op if already loaded)."""
        if self.ner_pipeline is not None:
            return

        print(f"Loading Hugging Face NER model: '{self.model_name}'...")
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        model = AutoModelForTokenClassification.from_pretrained(self.model_name)
        device = 0 if torch.cuda.is_available() else -1
        self.ner_pipeline = pipeline(
            "ner",
            model=model,
            tokenizer=tokenizer,
            aggregation_strategy="simple",
            device=device
        )
        print(f"Model '{self.model_name}' loaded successfully on {'GPU (CUDA)' if device == 0 else 'CPU'}.\n")

    def predict_adrs(self, text: str) -> 'List[ADREntity]':
        """Return the predicted ADR spans of ``text``; offsets index into ``text``."""
        if self.ner_pipeline is None:
            self.load_model()

        candidates = []
        for result in self.ner_pipeline(text):
            start, end = result['start'], result['end']
            # Take the surface text from the document so text and offsets
            # agree; the pipeline's 'word' may differ (casing, tokenization).
            surface = text[start:end]
            cadec_type = self._map_entity_type(result['entity_group'].lower(), surface.lower())
            if cadec_type:
                candidates.append((surface, cadec_type, start, end))

        candidates.sort(key=lambda item: item[2])
        merged_entities = self._merge_adjacent(candidates, text)

        text_lower = text.lower()
        return [
            ADREntity(start=start, end=end, text=ent_text)
            for ent_text, ent_type, start, end in merged_entities
            if self._classify_as_adr(text_lower, ent_text, ent_type, start, end) == 'ADR'
        ]

    def _map_entity_type(self, bio_type: str, text_lower: str) -> str:
        """Map a lower-cased model label to 'Drug'/'Symptom'/'Disease', or None.

        ``medication`` is matched as Drug. The model's label set is expected
        to contain a ``Medication`` label (TODO confirm via
        ``model.config.id2label``).
        """
        if any(key in bio_type for key in ('therapeutic', 'drug', 'pharmacologic', 'medication')):
            return 'Drug'

        if 'sign_symptom' in bio_type or 'symptom' in bio_type:
            return 'Symptom'

        if 'disease' in bio_type or 'diagnostic' in bio_type or 'disorder' in bio_type:
            if any(keyword in text_lower for keyword in self.SYMPTOM_KEYWORDS):
                return 'Symptom'
            return 'Disease'

        if 'severity' in bio_type or 'biological_structure' in bio_type:
            return 'Symptom'

        # Every other label (lab values, durations, dosages, procedures, ...)
        # is irrelevant for ADR detection.
        return None

    def _merge_adjacent(self, entities: List[Tuple], doc_text: str) -> List[Tuple]:
        """Merge same-type spans separated by at most one character.

        ``entities`` are ``(text, type, start, end)`` tuples sorted by start.
        The merged text is re-sliced from ``doc_text``. For overlapping or
        nested spans the end only moves forwards, so a span never shrinks.
        """
        merged_results = []
        for ent_text, ent_type, start, end in entities:
            if merged_results:
                _, prev_type, prev_start, prev_end = merged_results[-1]
                if ent_type == prev_type and start - prev_end <= 1:
                    new_end = max(prev_end, end)
                    merged_results[-1] = (doc_text[prev_start:new_end], prev_type, prev_start, new_end)
                    continue
            merged_results.append((ent_text, ent_type, start, end))
        return merged_results

    def _classify_as_adr(self, full_text_lower: str, entity_text: str,
                         entity_type: str, start: int, end: int) -> str:
        """Relabel a Symptom as 'ADR' when its ±150-char context suggests a drug reaction.

        Non-Symptom types are returned unchanged. A Symptom becomes 'ADR' when:
        - the context contains an ADR indicator phrase, or
        - the context contains a drug keyword together with the whole word
          "taking" or "on".
        """
        if entity_type != 'Symptom':
            return entity_type

        window = 150
        context = full_text_lower[max(0, start - window):min(len(full_text_lower), end + window)]

        if any(indicator in context for indicator in self.ADR_INDICATORS):
            return 'ADR'

        has_drug_keyword = any(keyword in context for keyword in self.DRUG_KEYWORDS)
        if has_drug_keyword and self._TAKING_OR_ON.search(context):
            return 'ADR'

        return 'Symptom'


# ANNOTATION FILE PARSERS (CUSTOMIZED FOR ADR SPAN GROUND TRUTH)


# brat text-bound line: "T<n><TAB><Label> <start> <end>[;<start> <end>...]<TAB><text>".
# A "start-end" fragment syntax (and a TAB after the label) is accepted too.
_ORIGINAL_ANN_LINE = re.compile(r'^(T\d+)\t(\S+)[ \t]+(\d+[ \-]\d+(?:;\d+[ \-]\d+)*)\t(.*)$')
_FRAGMENT_SEPARATOR = re.compile(r'[ \-]')


def parse_adr_spans_from_original_annotations(file_path: str, raw_text: str) -> List[ADREntity]:
    """Read the gold ADR spans from a CADEC ``original`` brat ``.ann`` file.

    Text-bound lines have the shape
    ``T<n><TAB><Label> <start> <end>[;<start> <end>...]<TAB><text>``. This is
    the same layout the rest of this notebook reads and writes. The previous
    pattern expected a TAB after the label and ``start-end`` fragments, so it
    never matched such lines and no gold ADR was ever returned.

    Each fragment of a discontinuous ADR becomes its own ``ADREntity``.
    Fragments outside ``raw_text`` are skipped with a warning. Fragments
    covering only whitespace are skipped silently. Non-``T`` annotations
    (notes, relations, ...) are ignored.

    Returns:
        The ADR spans. An empty list (plus a warning) is returned if
        ``file_path`` does not exist.
    """
    ground_truth_adrs = []

    if not os.path.exists(file_path):
        print(f"Warning: Ground truth file '{os.path.basename(file_path)}' not found. Skipping evaluation for this file.")
        return ground_truth_adrs

    base_name = os.path.basename(file_path)

    with open(file_path, 'r', encoding='utf-8') as f:
        for line_num, raw_line in enumerate(f, 1):
            # Keep interior TABs intact; only drop the line terminator.
            line = raw_line.rstrip('\r\n')

            if not line.strip() or line.startswith('#'):
                continue

            # Only text-bound ("T") annotations carry character spans.
            if not line.startswith('T'):
                continue

            match = _ORIGINAL_ANN_LINE.match(line)
            if not match:
                print(f"Warning: Line {line_num} in '{base_name}' did not match expected format: '{line}'. Skipping.")
                continue

            if match.group(2) != 'ADR':
                continue

            for fragment in match.group(3).split(';'):
                s_str, e_str = _FRAGMENT_SEPARATOR.split(fragment)
                start_offset = int(s_str)
                end_offset = int(e_str)

                if not (0 <= start_offset < end_offset <= len(raw_text)):
                    print(f"Warning: Invalid span [{start_offset}-{end_offset}] found in line {line_num} of '{base_name}'. Text length: {len(raw_text)}. Skipping span.")
                    continue

                entity_text_from_raw = raw_text[start_offset:end_offset].strip()

                if entity_text_from_raw:
                    ground_truth_adrs.append(
                        ADREntity(start=start_offset, end=end_offset, text=entity_text_from_raw)
                    )
    return ground_truth_adrs


# EVALUATION METRICS CALCULATION


def calculate_metrics(true_adrs: List[ADREntity],
                      pred_adrs: List[ADREntity]) -> Dict:
    """Score predicted ADR spans against gold spans by exact span match.

    Both inputs are converted to sets, so TP/FP/FN are set-based counts
    (``ADREntity`` equality and hashing use only ``start``/``end``).
    ``num_true_adrs`` and ``num_pred_adrs`` report the raw list lengths,
    duplicates included. Any ratio with a zero denominator is reported as 0.0.
    """
    gold = set(true_adrs)
    predicted = set(pred_adrs)

    tp = len(gold & predicted)
    fp = len(predicted - gold)
    fn = len(gold - predicted)

    def ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0.0

    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    f1 = ratio(2 * (precision * recall), precision + recall)

    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'true_positive': tp,
        'false_positive': fp,
        'false_negative': fn,
        'num_true_adrs': len(true_adrs),
        'num_pred_adrs': len(pred_adrs),
    }


# MAIN EVALUATOR CLASS FOR TASK 4

class Task4ADREvaluator:
    """Evaluate BiomedicalNER ADR predictions on a random sample of CADEC posts.

    Gold ADR spans come from the ``original`` annotations, and scoring is
    exact span match via ``calculate_metrics``. Results are reported both
    micro-averaged (from summed counts) and macro-averaged (mean of per-file
    scores over files that have gold or predicted ADRs).
    """

    def __init__(self, seed: int = 42):
        self.ner_model = BiomedicalNER()
        self.seed = seed
        # The global RNG is still seeded for backward compatibility. File
        # sampling uses a private generator seeded with the same value, so it
        # is reproducible whatever else consumed the global RNG in between.
        random.seed(seed)

    @staticmethod
    def _ratio(numerator: float, denominator: float) -> float:
        """Return ``numerator / denominator``, or 0.0 when the denominator is not positive."""
        return numerator / denominator if denominator > 0 else 0.0

    def evaluate_single_file(self, filename: str) -> Dict:
        """Predict and score the ADR spans of one post in ``TEXT_DIR``.

        Raises:
            FileNotFoundError: if the text file does not exist. A missing gold
                ``.ann`` file only yields an empty gold list.

        Returns:
            Dict with ``filename``, ``file_id``, ``text_length``,
            ``true_adrs``, ``pred_adrs`` and ``metrics``.
        """
        file_id = filename.replace('.txt', '')

        text_file_path = os.path.join(TEXT_DIR, filename)
        if not os.path.exists(text_file_path):
            raise FileNotFoundError(f"Text file not found: '{text_file_path}'. Cannot evaluate.")

        with open(text_file_path, 'r', encoding='utf-8') as f:
            # Only trailing whitespace is dropped: stripping leading whitespace
            # would shift every offset relative to the gold annotations.
            raw_forum_post_text = f.read().rstrip()

        pred_adrs = self.ner_model.predict_adrs(raw_forum_post_text)

        original_annotation_path = os.path.join(ORIGINAL_DIR, file_id + '.ann')
        true_adrs = parse_adr_spans_from_original_annotations(original_annotation_path, raw_forum_post_text)

        metrics = calculate_metrics(true_adrs, pred_adrs)

        return {
            'filename': filename,
            'file_id': file_id,
            'text_length': len(raw_forum_post_text),
            'true_adrs': true_adrs,
            'pred_adrs': pred_adrs,
            'metrics': metrics
        }

    def evaluate_random_sample(self, num_files: int = 50) -> Dict:
        """Evaluate ``num_files`` randomly chosen posts (or all of them, if fewer exist).

        Per-file errors are printed and the file is skipped. The result dict
        always has the same keys, even when no files are found.
        """
        print("\n" + "="*80)
        print("STARTING TASK 4: ADR PERFORMANCE EVALUATION (Random Sample)")
        print("Ground Truth Source: ADR spans from 'original' annotations.")
        print("This approach is necessary because MedDRA files (in CADEC) provide")
        print("concept IDs but lack character span annotations required for F1-score.")
        print("="*80 + "\n")

        all_text_files = sorted([f for f in os.listdir(TEXT_DIR) if f.endswith('.txt')])

        if len(all_text_files) < num_files:
            print(f"Warning: Only {len(all_text_files)} text files found. Evaluating all available files.")
            selected_files = all_text_files
        else:
            rng = random.Random(self.seed)
            selected_files = sorted(rng.sample(all_text_files, num_files))

        print(f"Random seed for file selection: {self.seed}")
        print(f"Number of files selected: {len(selected_files)}\n")
        if selected_files:
            print(f"First 5 selected files: {', '.join(selected_files[:5])}{'...' if len(selected_files) > 5 else ''}\n")
        else:
            print("No text files found. Check BASE_PATH and TEXT_DIR configuration.\n")
            return {
                'aggregate_metrics': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0,
                                      'true_positive': 0, 'false_positive': 0, 'false_negative': 0,
                                      'total_true_adrs_in_sample': 0, 'total_pred_adrs_in_sample': 0},
                'average_metrics': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0},
                'num_files_evaluated': 0,
                'selected_files': [],
                'individual_results': []
            }

        self.ner_model.load_model()

        individual_file_results = []
        total_true_positive = 0
        total_false_positive = 0
        total_false_negative = 0

        # Sums of per-file scores for macro-averaging. Files with neither gold
        # nor predicted ADRs are excluded, since their scores are undefined.
        total_precision = 0.0
        total_recall = 0.0
        total_f1 = 0.0
        num_scored_files = 0

        print("-" * 80)
        print("Processing files (individual metrics):")
        print("-" * 80)

        for idx, filename in enumerate(selected_files, 1):
            print(f"({idx}/{len(selected_files)}) Evaluating {filename}...")
            try:
                result = self.evaluate_single_file(filename)
                individual_file_results.append(result)

                metrics = result['metrics']
                total_true_positive += metrics['true_positive']
                total_false_positive += metrics['false_positive']
                total_false_negative += metrics['false_negative']

                if metrics['num_true_adrs'] > 0 or metrics['num_pred_adrs'] > 0:
                    total_precision += metrics['precision']
                    total_recall += metrics['recall']
                    total_f1 += metrics['f1']
                    num_scored_files += 1

                print(f"  -> TP: {metrics['true_positive']}, FP: {metrics['false_positive']}, FN: {metrics['false_negative']}")
                print(f"  -> P: {metrics['precision']:.4f}, R: {metrics['recall']:.4f}, F1: {metrics['f1']:.4f}\n")

            except FileNotFoundError as e:
                print(f"Error for {filename}: {e}. Skipping file.")
            except Exception as e:
                print(f"An unexpected error occurred for {filename}: {e}. Skipping file.")

        print("-" * 80)
        print("Evaluation Complete.")
        print("-" * 80)

        micro_precision = self._ratio(total_true_positive, total_true_positive + total_false_positive)
        micro_recall = self._ratio(total_true_positive, total_true_positive + total_false_negative)
        micro_f1 = self._ratio(2 * (micro_precision * micro_recall), micro_precision + micro_recall)

        macro_precision = self._ratio(total_precision, num_scored_files)
        macro_recall = self._ratio(total_recall, num_scored_files)
        macro_f1 = self._ratio(total_f1, num_scored_files)

        aggregate_metrics = {
            'precision': micro_precision,
            'recall': micro_recall,
            'f1': micro_f1,
            'true_positive': total_true_positive,
            'false_positive': total_false_positive,
            'false_negative': total_false_negative,
            'total_true_adrs_in_sample': sum(res['metrics']['num_true_adrs'] for res in individual_file_results),
            'total_pred_adrs_in_sample': sum(res['metrics']['num_pred_adrs'] for res in individual_file_results)
        }

        average_metrics = {
            'precision': macro_precision,
            'recall': macro_recall,
            'f1': macro_f1
        }

        print("\n" + "="*80)
        print("AGGREGATE (MICRO-AVERAGED) RESULTS ACROSS ALL EVALUATED FILES:")
        print("="*80)
        print(f"True Positives: {aggregate_metrics['true_positive']}")
        print(f"False Positives: {aggregate_metrics['false_positive']}")
        print(f"False Negatives: {aggregate_metrics['false_negative']}")
        print(f"Micro Precision: {aggregate_metrics['precision']:.4f}")
        print(f"Micro Recall: {aggregate_metrics['recall']:.4f}")
        print(f"Micro F1-score: {aggregate_metrics['f1']:.4f}")

        print("\n" + "="*80)
        print("AVERAGE (MACRO-AVERAGED) RESULTS ACROSS EVALUATED FILES:")
        print("="*80)
        print(f"Macro Precision: {average_metrics['precision']:.4f}")
        print(f"Macro Recall: {average_metrics['recall']:.4f}")
        print(f"Macro F1-score: {average_metrics['f1']:.4f}")
        print("="*80 + "\n")

        return {
            'aggregate_metrics': aggregate_metrics,
            'average_metrics': average_metrics,
            'num_files_evaluated': len(individual_file_results),
            'selected_files': selected_files,
            'individual_results': individual_file_results
        }

if __name__ == "__main__":
    evaluator = Task4ADREvaluator(seed=123)
    evaluation_summary = evaluator.evaluate_random_sample(num_files=50)

    class CustomJsonEncoder(json.JSONEncoder):
        """JSON encoder that renders ADREntity objects as their ``str()`` form.

        NOTE(review): this encoder is defined but never used in this cell.
        Presumably it is meant for
        ``json.dumps(evaluation_summary, cls=CustomJsonEncoder)``; confirm.
        """

        def default(self, obj):
            if not isinstance(obj, ADREntity):
                return super().default(obj)
            return str(obj)


STARTING TASK 4: ADR PERFORMANCE EVALUATION (Random Sample)
Ground Truth Source: ADR spans from 'original' annotations.
This approach is necessary because MedDRA files (in CADEC) provide
concept IDs but lack character span annotations required for F1-score.

Random seed for file selection: 123
Number of files selected: 50

First 5 selected files: ARTHROTEC.101.txt, ARTHROTEC.110.txt, ARTHROTEC.138.txt, ARTHROTEC.36.txt, ARTHROTEC.39.txt...

Loading Hugging Face NER model: 'd4data/biomedical-ner-all'...


Device set to use cpu


Model 'd4data/biomedical-ner-all' loaded successfully on CPU.

--------------------------------------------------------------------------------
Processing files (individual metrics):
--------------------------------------------------------------------------------
(1/50) Evaluating ARTHROTEC.101.txt...
  -> TP: 0, FP: 4, FN: 0
  -> P: 0.0000, R: 0.0000, F1: 0.0000

(2/50) Evaluating ARTHROTEC.110.txt...
  -> TP: 0, FP: 2, FN: 0
  -> P: 0.0000, R: 0.0000, F1: 0.0000

(3/50) Evaluating ARTHROTEC.138.txt...
  -> TP: 0, FP: 1, FN: 0
  -> P: 0.0000, R: 0.0000, F1: 0.0000

(4/50) Evaluating ARTHROTEC.36.txt...
  -> TP: 0, FP: 3, FN: 0
  -> P: 0.0000, R: 0.0000, F1: 0.0000

(5/50) Evaluating ARTHROTEC.39.txt...
  -> TP: 0, FP: 2, FN: 0
  -> P: 0.0000, R: 0.0000, F1: 0.0000

(6/50) Evaluating ARTHROTEC.49.txt...
  -> TP: 0, FP: 3, FN: 0
  -> P: 0.0000, R: 0.0000, F1: 0.0000

(7/50) Evaluating ARTHROTEC.64.txt...
  -> TP: 0, FP: 5, FN: 0
  -> P: 0.0000, R: 0.0000, F1: 0.0000

(8/50) Evaluating A

In [None]:
import os
from typing import Dict, List, Set
from dataclasses import dataclass
from collections import defaultdict

BASE_PATH = '/content/drive/MyDrive/cadec'

# Raw posts, MedDRA-coded annotations, and the directory for generated output.
TEXT_DIR, MEDDRA_DIR, OUTPUT_DIR = (
    os.path.join(BASE_PATH, sub)
    for sub in ('text', 'meddra', 'generated_annotations')
)


@dataclass
class ADREntity:
    """Represents an ADR entity annotation.

    Equality and hashing use only the ``(start, end)`` character span, so a
    prediction matches a gold annotation regardless of ``entity_id`` or
    ``text``. Comparing against a non-``ADREntity`` returns ``NotImplemented``
    (Python then falls back to its default comparison) instead of raising
    ``AttributeError``.

    NOTE: instances are mutable; do not change ``start``/``end`` while an
    instance is stored in a set or used as a dict key.
    """
    entity_id: str  # annotation id from the .ann file (e.g. "T1"); ignored when comparing
    start: int  # character offset where the span begins
    end: int  # character offset where the span ends
    text: str  # surface text of the span; ignored when comparing

    def __hash__(self):
        return hash((self.start, self.end))

    def __eq__(self, other):
        if not isinstance(other, ADREntity):
            return NotImplemented
        return self.start == other.start and self.end == other.end


class ADREvaluatorMedDRA:
    """Exact-span ADR evaluation against CADEC MedDRA annotations.

    Ground truth is read from ``MEDDRA_DIR/<name>.ann`` and predictions from
    ``OUTPUT_DIR/ann/<name>.ann``. ADREntity compares and hashes on
    ``(start, end)`` only, so a prediction is a true positive exactly when
    its character span equals a ground-truth span.
    """

    @staticmethod
    def _span_offsets(tokens: List[str]) -> List[int]:
        """Return every integer offset found in brat span tokens, in order.

        A discontinuous span is written like ``5 10;15 20``, so a single
        token may hold two offsets joined by ';'. Any non-integer piece
        raises ValueError, which the parsers treat as "skip this line".
        """
        return [int(piece) for token in tokens for piece in token.split(';')]

    def parse_meddra_file(self, file_path: str) -> List[ADREntity]:
        """Parse a MedDRA ``.ann`` file into ADREntity objects.

        Lines look like ``<id> TAB <first-token> <offsets...> TAB <text>``.
        The first token of the middle field is not used (presumably the
        MedDRA concept id; confirm against the data). When the offsets
        describe a discontinuous span, the entity covers the first through
        the last offset. Comment, blank and malformed lines are skipped, and
        a missing file yields an empty list.
        """
        entities = []

        if not os.path.exists(file_path):
            return entities

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue

                try:
                    parts = line.split('\t')
                    if len(parts) < 3:
                        continue

                    entity_id = parts[0]

                    position_part = parts[1].split()
                    if len(position_part) < 3:
                        continue

                    # Split on ';' before int(). Otherwise a discontinuous span
                    # such as '52 58;64 69' raises ValueError, and that
                    # ground-truth ADR is silently dropped.
                    offsets = self._span_offsets(position_part[1:])
                    start, end = offsets[0], offsets[-1]

                    text = parts[2]

                    entities.append(ADREntity(entity_id, start, end, text))

                except (ValueError, IndexError):
                    continue

        return entities

    def parse_prediction_file(self, file_path: str) -> List[ADREntity]:
        """Parse a prediction ``.ann`` file and keep only entities typed 'ADR'.

        Lines look like ``<id> TAB <type> <offsets...> TAB <text>``. A single
        offset yields a zero-length span (end == start). Discontinuous
        offsets are handled as in parse_meddra_file. Malformed lines are
        skipped, and a missing file yields an empty list.
        """
        entities = []

        if not os.path.exists(file_path):
            return entities

        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    continue

                try:
                    parts = line.split('\t')
                    if len(parts) < 3:
                        continue

                    entity_id = parts[0]
                    type_and_span = parts[1].split()
                    entity_type = type_and_span[0]

                    if entity_type != 'ADR':
                        continue

                    offsets = self._span_offsets(type_and_span[1:])
                    start, end = offsets[0], offsets[-1]
                    text = parts[2]

                    entities.append(ADREntity(entity_id, start, end, text))

                except (ValueError, IndexError):
                    continue

        return entities

    def calculate_metrics(self, true_adrs: List[ADREntity],
                         pred_adrs: List[ADREntity]) -> Dict:
        """Compute exact-span precision, recall and F1 for one document.

        Counts are taken over *unique* spans (both inputs are turned into
        sets). Each ratio is 0.0 when its denominator is zero. The
        ``ground_truth_adrs`` / ``predicted_adrs`` entries report the raw
        input lengths.
        """
        true_set = set(true_adrs)
        pred_set = set(pred_adrs)

        true_positive = len(true_set & pred_set)
        false_positive = len(pred_set - true_set)
        false_negative = len(true_set - pred_set)

        precision = true_positive / (true_positive + false_positive) if (true_positive + false_positive) > 0 else 0.0
        recall = true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

        return {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'true_positive': true_positive,
            'false_positive': false_positive,
            'false_negative': false_negative,
            'ground_truth_adrs': len(true_adrs),
            'predicted_adrs': len(pred_adrs)
        }

    def evaluate_single_file(self, text_filename: str, verbose: bool = True) -> Dict:
        """Evaluate the predictions for one text file against its MedDRA annotations.

        Returns a dict with ``metrics`` (see calculate_metrics),
        ``false_positives``, ``false_negatives`` (lists of ADREntity) and
        ``file``. Returns None when either the ground-truth or the prediction
        file is missing. With ``verbose`` it prints a report that includes up
        to five example errors of each kind.
        """
        # splitext removes only the trailing extension, unlike a global replace.
        file_id = os.path.splitext(text_filename)[0]

        meddra_file = os.path.join(MEDDRA_DIR, file_id + '.ann')
        pred_file = os.path.join(OUTPUT_DIR, 'ann', file_id + '.ann')

        if not os.path.exists(meddra_file):
            if verbose:
                print(f"Warning: MedDRA ground truth not found for {text_filename}")
            return None

        if not os.path.exists(pred_file):
            if verbose:
                print(f"Warning: Prediction not found for {text_filename}")
            return None

        true_adrs = self.parse_meddra_file(meddra_file)
        pred_adrs = self.parse_prediction_file(pred_file)

        metrics = self.calculate_metrics(true_adrs, pred_adrs)

        true_set = set(true_adrs)
        pred_set = set(pred_adrs)
        false_positives = list(pred_set - true_set)
        false_negatives = list(true_set - pred_set)

        if verbose:
            print(f"\n{'='*80}")
            print(f"EVALUATION: {text_filename}")
            print(f"{'='*80}")
            print(f"\nADR DETECTION METRICS (MedDRA Ground Truth):")
            print(f"{'-'*80}")
            print(f"Precision: {metrics['precision']:.3f}")
            print(f"Recall:    {metrics['recall']:.3f}")
            print(f"F1 Score:  {metrics['f1']:.3f}")
            print(f"\nTrue Positives:  {metrics['true_positive']}")
            print(f"False Positives: {metrics['false_positive']}")
            print(f"False Negatives: {metrics['false_negative']}")
            print(f"\nGround Truth ADRs (MedDRA): {metrics['ground_truth_adrs']}")
            print(f"Predicted ADRs:             {metrics['predicted_adrs']}")

            if false_positives:
                print(f"\nFALSE POSITIVES (Predicted as ADR but not in MedDRA):")
                print(f"{'-'*80}")
                for i, fp in enumerate(false_positives[:5], 1):
                    print(f"{i}. [{fp.start}-{fp.end}] '{fp.text}'")
                if len(false_positives) > 5:
                    print(f"... and {len(false_positives) - 5} more")

            if false_negatives:
                print(f"\nFALSE NEGATIVES (In MedDRA but not predicted):")
                print(f"{'-'*80}")
                for i, fn in enumerate(false_negatives[:5], 1):
                    print(f"{i}. [{fn.start}-{fn.end}] '{fn.text}'")
                if len(false_negatives) > 5:
                    print(f"... and {len(false_negatives) - 5} more")

        return {
            'metrics': metrics,
            'false_positives': false_positives,
            'false_negatives': false_negatives,
            'file': text_filename
        }

    def evaluate_directory(self, text_dir: str, max_files: int = None) -> Dict:
        """Evaluate the first ``max_files`` ``*.txt`` files (sorted by name) in ``text_dir``.

        Files lacking ground truth or predictions are skipped. The overall
        scores are micro-averaged: TP/FP/FN are summed across files before
        precision, recall and F1 are computed. A falsy ``max_files`` (None
        or 0) evaluates every file.
        """
        text_files = sorted([f for f in os.listdir(text_dir) if f.endswith('.txt')])
        if max_files:
            text_files = text_files[:max_files]

        print(f"\n{'='*80}")
        print(f"EVALUATING ADR DETECTION ON {len(text_files)} FILES")
        print(f"Ground Truth Source: MedDRA directory")
        print(f"{'='*80}\n")

        all_results = []
        aggregate_tp = 0
        aggregate_fp = 0
        aggregate_fn = 0

        for idx, filename in enumerate(text_files, 1):
            result = self.evaluate_single_file(filename, verbose=False)

            if result:
                all_results.append(result)
                m = result['metrics']
                aggregate_tp += m['true_positive']
                aggregate_fp += m['false_positive']
                aggregate_fn += m['false_negative']

                print(f"[{idx}/{len(text_files)}] {filename}: "
                      f"P={m['precision']:.3f} R={m['recall']:.3f} F1={m['f1']:.3f} "
                      f"(GT:{m['ground_truth_adrs']} Pred:{m['predicted_adrs']})")

        overall_precision = aggregate_tp / (aggregate_tp + aggregate_fp) if (aggregate_tp + aggregate_fp) > 0 else 0.0
        overall_recall = aggregate_tp / (aggregate_tp + aggregate_fn) if (aggregate_tp + aggregate_fn) > 0 else 0.0
        overall_f1 = 2 * (overall_precision * overall_recall) / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0.0

        print(f"\n{'='*80}")
        print(f"AGGREGATE ADR DETECTION RESULTS")
        print(f"{'='*80}")
        print(f"\nFiles Evaluated: {len(all_results)}")
        print(f"\nOVERALL METRICS:")
        print(f"  Precision: {overall_precision:.3f}")
        print(f"  Recall:    {overall_recall:.3f}")
        print(f"  F1 Score:  {overall_f1:.3f}")
        print(f"\n  True Positives:  {aggregate_tp}")
        print(f"  False Positives: {aggregate_fp}")
        print(f"  False Negatives: {aggregate_fn}")

        return {
            'precision': overall_precision,
            'recall': overall_recall,
            'f1': overall_f1,
            'true_positive': aggregate_tp,
            'false_positive': aggregate_fp,
            'false_negative': aggregate_fn,
            'num_files': len(all_results),
            'individual_results': all_results
        }


def generate_predictions_for_adr_eval(text_dir: str, output_dir: str, max_files: int = 10):
    """Tag CADEC texts with a biomedical NER model and write brat-style predictions.

    For the first ``max_files`` ``*.txt`` files in ``text_dir`` (sorted by
    name), this function:
      1. runs ``d4data/biomedical-ner-all`` with simple aggregation;
      2. maps model labels onto CADEC types (Drug / Symptom / Disease) and
         drops every other label;
      3. merges consecutive same-type entities that are at most one
         character apart;
      4. relabels a Symptom as ADR when an ADR cue phrase occurs within 100
         characters on either side;
      5. writes ``T<n> TAB <type> <start> <end> TAB <text>`` lines to
         ``<output_dir>/ann/<name>.ann``.

    Offsets refer to the file contents exactly as read from disk, so they
    line up with ground-truth annotations for the same file.
    """
    from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
    import torch

    print(f"{'='*80}")
    print(f"GENERATING PREDICTIONS FOR ADR EVALUATION")
    print(f"{'='*80}\n")

    print("Loading NER model...")
    tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
    model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
    ner_pipeline = pipeline(
        "ner",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
        device=0 if torch.cuda.is_available() else -1
    )
    print("Model loaded\n")

    os.makedirs(os.path.join(output_dir, 'ann'), exist_ok=True)

    text_files = sorted([f for f in os.listdir(text_dir) if f.endswith('.txt')])[:max_files]

    # Cue phrases which, when found near a Symptom, turn it into an ADR.
    adr_keywords = ['side effect', 'adverse', 'reaction', 'after taking', 'since taking',
                    'caused by', 'due to the', 'from the medication']
    # Number of characters of context inspected on each side of a Symptom span.
    context_window = 100
    # Characters that would break the one-line, tab-separated .ann format.
    unsafe_chars = str.maketrans({'\t': ' ', '\n': ' ', '\r': ' '})

    print(f"Processing {len(text_files)} files...\n")

    for idx, filename in enumerate(text_files, 1):
        print(f"[{idx}/{len(text_files)}] {filename}")

        with open(os.path.join(text_dir, filename), 'r', encoding='utf-8') as f:
            # Keep the text exactly as stored. Stripping leading whitespace
            # would shift every predicted offset away from the ground truth.
            text = f.read()

        # Skip the model for empty / whitespace-only posts: nothing to tag.
        results = ner_pipeline(text) if text.strip() else []

        # Map model labels to CADEC types, dropping everything else.
        candidates = []
        for entity in results:
            label = entity['entity_group'].lower()
            if 'therapeutic' in label or 'drug' in label:
                cadec_type = 'Drug'
            elif 'sign_symptom' in label or 'symptom' in label:
                cadec_type = 'Symptom'
            elif 'diagnostic' in label or 'disease' in label:
                cadec_type = 'Disease'
            elif 'severity' in label or 'biological_structure' in label:
                cadec_type = 'Symptom'
            else:
                continue
            candidates.append((cadec_type, entity['start'], entity['end']))
        candidates.sort(key=lambda c: c[1])

        # Merge same-type neighbours separated by at most one character
        # (presumably to re-join spans the tagger split apart). max() keeps
        # an overlapping neighbour from shrinking the merged span.
        merged = []
        for cadec_type, start, end in candidates:
            if merged:
                prev_type, prev_start, prev_end = merged[-1]
                if prev_type == cadec_type and start - prev_end <= 1:
                    merged[-1] = (prev_type, prev_start, max(prev_end, end))
                    continue
            merged.append((cadec_type, start, end))

        final_entities = []
        for cadec_type, start, end in merged:
            if cadec_type == 'Symptom':
                # Slice first, then lowercase: lowercasing the whole text can
                # change its length and misalign offsets.
                context = text[max(0, start - context_window):end + context_window].lower()
                if any(kw in context for kw in adr_keywords):
                    cadec_type = 'ADR'
            # Take the entity text from the source so it matches the offsets.
            span_text = text[start:end].translate(unsafe_chars)
            final_entities.append((span_text, cadec_type, start, end))

        output_lines = []
        for idx_ent, (ent_text, ent_type, start, end) in enumerate(final_entities, 1):
            output_lines.append(f"T{idx_ent}\t{ent_type} {start} {end}\t{ent_text}")

        output_file = os.path.join(output_dir, 'ann', filename.replace('.txt', '.ann'))
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write('\n'.join(output_lines))

        adr_count = sum(1 for _, t, _, _ in final_entities if t == 'ADR')
        print(f"  Generated {len(final_entities)} entities ({adr_count} ADRs)")


def debug_meddra_format(meddra_dir: str, num_files: int = 2):
    """Print the first five raw lines of the first ``num_files`` MedDRA ``.ann`` files.

    Files are visited in sorted filename order. This is handy for eyeballing
    the column layout before relying on a parser.
    """
    banner = '=' * 80
    print(f"\n{banner}")
    print("DEBUGGING MEDDRA FILE FORMAT")
    print(f"{banner}\n")

    ann_names = sorted(name for name in os.listdir(meddra_dir) if name.endswith('.ann'))
    for ann_name in ann_names[:num_files]:
        print(f"File: {ann_name}")
        print('-' * 80)

        with open(os.path.join(meddra_dir, ann_name), 'r', encoding='utf-8') as handle:
            head = handle.readlines()[:5]
        for line_no, raw in enumerate(head, start=1):
            print(f"Line {line_no}: {raw.strip()}")

        print("\n")


if __name__ == "__main__":
    # Number of files to predict on and then evaluate. A single constant keeps
    # both phases aligned, so every evaluated file has a prediction.
    NUM_EVAL_FILES = 10

    try:
        from google.colab import drive
    except ImportError:
        # Not running in Colab: assume BASE_PATH is already reachable.
        print("google.colab not available; skipping Drive mount\n")
    else:
        try:
            drive.mount('/content/drive', force_remount=False)
            print("Drive mounted\n")
        except Exception as e:
            # Best-effort mount: report the problem rather than hiding it.
            print(f"Warning: could not mount Google Drive: {e}\n")

    print(f"{'='*80}")
    print(f"ADR EVALUATION USING MEDDRA GROUND TRUTH")
    print(f"{'='*80}")
    print(f"\nDifference from previous evaluation:")
    print(f"  - Ground truth: MedDRA directory (ADR-only annotations)")
    print(f"  - Previous: Original directory (all entity types)")
    print(f"  - Focus: Only ADR detection performance")
    print(f"\n{'='*80}\n")

    debug_meddra_format(MEDDRA_DIR, num_files=3)

    generate_predictions_for_adr_eval(TEXT_DIR, OUTPUT_DIR, max_files=NUM_EVAL_FILES)

    print(f"\n{'='*80}")
    print(f"EVALUATION PHASE")
    print(f"{'='*80}\n")

    evaluator = ADREvaluatorMedDRA()

    print("Example: Single file evaluation\n")
    text_files = sorted([f for f in os.listdir(TEXT_DIR) if f.endswith('.txt')])
    if text_files:
        evaluator.evaluate_single_file(text_files[0], verbose=True)

    print(f"\n\nAggregate evaluation on {NUM_EVAL_FILES} files:\n")
    results = evaluator.evaluate_directory(TEXT_DIR, max_files=NUM_EVAL_FILES)

    print(f"\n{'='*80}")
    print(f"To evaluate more files:")
    print(f"  results = evaluator.evaluate_directory(TEXT_DIR, max_files=50)")
    print(f"{'='*80}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Drive mounted

EVALUATING 50 RANDOMLY SELECTED FORUM POSTS

Selected 50 files (seed=42)
Sample: ARTHROTEC.38.txt, ARTHROTEC.41.txt, ARTHROTEC.45.txt, ARTHROTEC.46.txt, CAMBIA.2.txt...

Loading NER model...


Device set to use cuda:0


Model loaded

Processing files...

[1/50] ARTHROTEC.38.txt: P=0.200 R=0.083 F1=0.118
[2/50] ARTHROTEC.41.txt: P=0.077 R=0.053 F1=0.062
[3/50] ARTHROTEC.45.txt: P=0.000 R=0.000 F1=0.000
[4/50] ARTHROTEC.46.txt: P=0.000 R=0.000 F1=0.000
[5/50] CAMBIA.2.txt: P=0.000 R=0.000 F1=0.000
[6/50] CATAFLAM.2.txt: P=0.000 R=0.000 F1=0.000
[7/50] DICLOFENAC-SODIUM.7.txt: P=0.000 R=0.000 F1=0.000
[8/50] LIPITOR.142.txt: P=0.050 R=0.100 F1=0.067
[9/50] LIPITOR.151.txt: P=0.000 R=0.000 F1=0.000
[10/50] LIPITOR.153.txt: P=0.200 R=0.143 F1=0.167
[11/50] LIPITOR.173.txt: P=0.000 R=0.000 F1=0.000
[12/50] LIPITOR.183.txt: P=0.125 R=0.077 F1=0.095
[13/50] LIPITOR.196.txt: P=0.000 R=0.000 F1=0.000
[14/50] LIPITOR.213.txt: P=0.000 R=0.000 F1=0.000
[15/50] LIPITOR.234.txt: P=0.000 R=0.000 F1=0.000
[16/50] LIPITOR.250.txt: P=0.000 R=0.000 F1=0.000
[17/50] LIPITOR.289.txt: P=0.000 R=0.000 F1=0.000
[18/50] LIPITOR.306.txt: P=0.231 R=0.375 F1=0.286
[19/50] LIPITOR.319.txt: P=0.000 R=0.000 F1=0.000
[20/50] LIPITOR.

In [None]:
import os
import re
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, field
from difflib import SequenceMatcher
import numpy as np


from google.colab import drive
# Mount Drive best-effort: a failure is reported instead of aborting the
# cell, and the path constants below are still defined.
try:
    drive.mount('/content/drive')
    print("Google Drive mounted successfully.")
except Exception as e:
    print(f"Error mounting Google Drive: {e}")

# Root of the CADEC corpus on the mounted Google Drive.
BASE_PATH = '/content/drive/MyDrive/cadec'
TEXT_DIR = os.path.join(BASE_PATH, 'text')
# Annotations carrying entity types (ADR, Drug, Disease, Symptom) and spans.
ORIGINAL_DIR = os.path.join(BASE_PATH, 'original')
# SNOMED CT codes, keyed by 'T' + the matching id in ORIGINAL_DIR.
SCT_DIR = os.path.join(BASE_PATH, 'sct')
MEDDRA_DIR = os.path.join(BASE_PATH, 'meddra')
# Task 2 prediction output searched by load_predicted_adrs.
PREDICTED_DIR = os.path.join(BASE_PATH, 'generated_annotations')

# DATA STRUCTURES

@dataclass
class SNOMEDMapping:
    """One SNOMED CT concept attached to one annotated entity.

    create_snomed_data_structure builds these by joining an ``original``
    annotation (type, span, text) with the code/description pairs listed for
    it in the ``sct`` file. An entity with several codes yields several
    mappings.
    """

    standard_code: str          # SNOMED CT concept id
    standard_description: str   # description paired with the code in the sct file
    label_type: str             # entity type from the original annotation, e.g. 'ADR'
    ground_truth_text: str      # entity text from the original annotation
    entity_id: str = ""         # id used in the original annotation file
    char_start: int = 0         # span start offset
    char_end: int = 0           # span end offset

    def __repr__(self):
        desc = self.standard_description
        # Only signal truncation when the description was actually cut.
        if len(desc) > 50:
            desc = desc[:50] + "..."
        return (f"SNOMEDMapping(code={self.standard_code}, "
                f"desc='{desc}', "
                f"type={self.label_type}, text='{self.ground_truth_text}')")


@dataclass
class PredictedADR:
    """One ADR span produced by the Task 2 tagger (or by the ground-truth fallback)."""
    entity_id: str  # id from the annotation file the span was read from
    text: str       # surface text of the span
    start: int      # span start offset
    end: int        # span end offset

    def __repr__(self):
        ident = self.entity_id
        snippet = self.text
        return f"PredictedADR(id={ident}, text='{snippet}')"


@dataclass
class MatchResult:
    """The best SNOMED mapping found for one predicted ADR by one matching method."""
    predicted_adr: PredictedADR
    matched_mapping: SNOMEDMapping
    similarity_score: float  # score of the winning candidate under match_method
    match_method: str  # 'string' or 'embedding'

    def __repr__(self):
        desc = self.matched_mapping.standard_description
        # Only add an ellipsis when the description was actually truncated.
        if len(desc) > 30:
            desc = desc[:30] + "..."
        return (f"Match('{self.predicted_adr.text}' → "
                f"{self.matched_mapping.standard_code}|"
                f"{desc}, "
                f"score={self.similarity_score:.3f})")


# STEP 1: PARSE FILES AND CREATE DATA STRUCTURE

def parse_original_annotations(filepath: str) -> Dict[str, Dict]:
    """Parse a brat-style ``.ann`` file into ``{entity_id: info}``.

    Kept lines look like ``T1 TAB ADR 52 58 TAB pain``. A discontinuous span
    lists extra fragments separated by ';' (``ADR 52 58;64 69``). ``info``
    has the keys ``'type'``, ``'start'``, ``'end'`` and ``'text'``, where
    ``start``/``end`` cover the full extent from the first to the last
    offset. Blank lines, ``#`` comments, lines with fewer than three
    tab-separated fields and lines with fewer than two offsets are skipped.

    Returns an empty dict (after printing a message) if the file is missing.
    """
    entities = {}

    if not os.path.exists(filepath):
        print(f"❌ File not found: {filepath}")
        return entities

    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue

            try:
                parts = line.split('\t')
                if len(parts) < 3:
                    continue

                entity_id = parts[0].strip()

                # Type followed by offsets, e.g. ['ADR', '52', '58;64', '69'].
                type_pos = parts[1].strip().split()
                entity_type = type_pos[0]

                # Split on ';' before converting. Deleting the ';' instead
                # would fuse '58;64' into the bogus offset 5864.
                positions = [
                    int(piece)
                    for token in type_pos[1:]
                    for piece in token.split(';')
                    if piece.isdigit()
                ]
                if len(positions) < 2:
                    continue

                entities[entity_id] = {
                    'type': entity_type,
                    'start': positions[0],
                    'end': positions[-1],
                    'text': parts[2].strip()
                }

            except (ValueError, IndexError):
                continue

    return entities


def parse_sct_annotations(filepath: str) -> Dict[str, List[Tuple[str, str]]]:
    """Read an SCT ``.ann`` file into ``{sct_entity_id: [(code, description), ...]}``.

    The second tab-separated field holds one or more ``<code> | <description> |``
    groups. Every group found is returned, in order of appearance. Lines
    whose second field has no such group are omitted. A missing file prints
    a message and yields an empty dict.

    SCT ids carry an extra leading 'T' compared with the matching ids in the
    original annotation files.
    """
    code_desc_re = re.compile(r'(\d+)\s*\|\s*([^|]+?)\s*\|')
    result = {}

    if not os.path.exists(filepath):
        print(f" File not found: {filepath}")
        return result

    with open(filepath, 'r', encoding='utf-8') as handle:
        for raw in handle:
            stripped = raw.strip()
            if stripped == '' or stripped.startswith('#'):
                continue

            fields = stripped.split('\t')
            if len(fields) < 2:
                continue

            pairs = code_desc_re.findall(fields[1].strip())
            if pairs:
                result[fields[0].strip()] = [(code.strip(), desc.strip()) for code, desc in pairs]

    return result


def create_snomed_data_structure(filename: str) -> List[SNOMEDMapping]:
    """Join the ``original`` and ``sct`` annotations of one file into SNOMEDMapping records.

    Each original entity that has an SCT entry produces one record per
    SNOMED code. Entities without an SCT entry are skipped. Progress and a
    per-type breakdown (ADR / Drug / Disease / Symptom) are printed.

    Args:
        filename: Annotation filename present in both ORIGINAL_DIR and SCT_DIR.
    """
    rule = '=' * 80
    print(f"\n{rule}")
    print(f"CREATING SNOMED DATA STRUCTURE FOR: {filename}")
    print(rule)

    original_path = os.path.join(ORIGINAL_DIR, filename)
    sct_path = os.path.join(SCT_DIR, filename)

    print(f"\nReading files:")
    print(f"  Original: {original_path}")
    print(f"  SCT: {sct_path}")

    original_entities = parse_original_annotations(original_path)
    sct_data = parse_sct_annotations(sct_path)

    print(f"\nParsed:")
    print(f"  {len(original_entities)} entities from original file")
    print(f"  {len(sct_data)} SNOMED mappings from SCT file")

    # SCT ids are the original ids with an extra 'T' in front (T1 -> TT1).
    snomed_mappings = [
        SNOMEDMapping(
            standard_code=code,
            standard_description=description,
            label_type=info['type'],
            ground_truth_text=info['text'],
            entity_id=eid,
            char_start=info['start'],
            char_end=info['end'],
        )
        for eid, info in original_entities.items()
        for code, description in sct_data.get('T' + eid, [])
    ]

    print(f"\n✓ Created {len(snomed_mappings)} SNOMED mappings")

    per_type = {}
    for record in snomed_mappings:
        per_type[record.label_type] = per_type.get(record.label_type, 0) + 1

    print(f"\nBreakdown by type:")
    for entity_type in ('ADR', 'Drug', 'Disease', 'Symptom'):
        n = per_type.get(entity_type, 0)
        if n:
            print(f"  {entity_type}: {n}")

    return snomed_mappings


def display_snomed_structure(mappings: List[SNOMEDMapping], max_display: int = 10):
    """Print up to ``max_display`` SNOMED mappings as multi-line records.

    A trailer reports how many mappings were left out.
    """
    bar = '=' * 80
    print(f"\n{bar}")
    print(f"SNOMED DATA STRUCTURE (showing first {max_display})")
    print(f"{bar}\n")

    for rank, m in enumerate(mappings[:max_display], start=1):
        print(f"{rank}. Entity ID: {m.entity_id}")
        print(f"   Label Type: {m.label_type}")
        print(f"   Ground Truth Text: '{m.ground_truth_text}'")
        print(f"   SNOMED Code: {m.standard_code}")
        print(f"   SNOMED Description: {m.standard_description}")
        print(f"   Position: [{m.char_start}-{m.char_end}]")
        print()

    hidden = len(mappings) - max_display
    if hidden > 0:
        print(f"... and {hidden} more mappings\n")



# STEP 2: LOAD PREDICTED ADRs FROM TASK 2

def load_predicted_adrs(filename: str) -> List[PredictedADR]:
    """
    Load predicted ADR entities from Task 2 output files.

    Searches for ``filename`` directly under PREDICTED_DIR and under
    ``PREDICTED_DIR/ann``, then keeps only entities typed 'ADR'.

    If no prediction file is found, falls back to the ground-truth ADRs from
    the original file for demonstration purposes. Results obtained that way
    reflect the annotations, not the tagger.

    Args:
        filename: Base filename (e.g., "ARTHROTEC.1.ann")

    Returns:
        List of PredictedADR objects
    """
    print(f"\n{'='*80}")
    print(f"LOADING PREDICTED ADRs FROM TASK 2")
    print(f"{'='*80}")

    candidate_paths = [
        os.path.join(PREDICTED_DIR, filename),
        os.path.join(PREDICTED_DIR, 'ann', filename),
        os.path.join(PREDICTED_DIR, filename.replace('.ann', '') + '.ann'),
    ]
    # The third candidate equals the first whenever filename already ends in
    # '.ann'. dict.fromkeys drops such duplicates while keeping search order.
    possible_paths = list(dict.fromkeys(candidate_paths))

    predicted_file = next((path for path in possible_paths if os.path.exists(path)), None)

    if not predicted_file:
        print(f"\n  No Task 2 predictions found. Searched:")
        for path in possible_paths:
            print(f"  • {path}")
        print(f"\n Using ground truth ADRs from original file as demonstration...")
        return load_ground_truth_adrs_as_predictions(filename)

    print(f"\n✓ Found predictions: {predicted_file}")

    entities = parse_original_annotations(predicted_file)
    predicted_adrs = [
        PredictedADR(
            entity_id=entity_id,
            text=entity_info['text'],
            start=entity_info['start'],
            end=entity_info['end']
        )
        for entity_id, entity_info in entities.items()
        if entity_info['type'] == 'ADR'
    ]

    print(f"\n✓ Loaded {len(predicted_adrs)} predicted ADR entities")

    return predicted_adrs


def load_ground_truth_adrs_as_predictions(filename: str) -> List[PredictedADR]:
    """Fallback: wrap the ground-truth ADRs of ``filename`` as PredictedADR objects.

    Reads ``ORIGINAL_DIR/filename`` and keeps entities typed 'ADR'. Intended
    only to demonstrate the matching pipeline when no real predictions exist.
    """
    entities = parse_original_annotations(os.path.join(ORIGINAL_DIR, filename))

    predicted_adrs = [
        PredictedADR(entity_id=eid, text=info['text'], start=info['start'], end=info['end'])
        for eid, info in entities.items()
        if info['type'] == 'ADR'
    ]

    print(f"\n✓ Using {len(predicted_adrs)} ground truth ADRs for demonstration")
    return predicted_adrs



# STEP 3: METHOD A - STRING SIMILARITY MATCHING


def calculate_string_similarity(text1: str, text2: str) -> float:
    """Return the case-insensitive Ratcliff/Obershelp similarity of two strings.

    The result lies in [0.0, 1.0]. Two empty strings score 1.0.
    """
    left = text1.lower()
    right = text2.lower()
    matcher = SequenceMatcher(None, left, right)
    return matcher.ratio()


def match_adr_using_string_similarity(predicted_adr: PredictedADR,
                                     snomed_mappings: List[SNOMEDMapping],
                                     min_score: float = 0.0) -> Optional[MatchResult]:
    """
    METHOD A: Match a predicted ADR to a SNOMED code using approximate string matching.

    Strategy:
    1. Compare the predicted ADR text with each ADR mapping's SNOMED description.
    2. Also compare it with that mapping's ground-truth text.
    3. Score the mapping with the better of the two and keep the best mapping.

    Args:
        predicted_adr: The predicted ADR entity.
        snomed_mappings: All available SNOMED mappings; only those with
            label_type 'ADR' are considered.
        min_score: A candidate is accepted only if its score is strictly
            greater than this. The default 0.0 accepts any non-zero overlap.

    Returns:
        MatchResult for the highest-scoring mapping (the earliest wins ties),
        or None when there is no ADR mapping or none scores above ``min_score``.
    """
    best_match = None
    best_score = min_score

    for mapping in snomed_mappings:
        if mapping.label_type != 'ADR':
            continue

        score = max(
            calculate_string_similarity(predicted_adr.text, mapping.standard_description),
            calculate_string_similarity(predicted_adr.text, mapping.ground_truth_text),
        )

        if score > best_score:
            best_score = score
            best_match = MatchResult(
                predicted_adr=predicted_adr,
                matched_mapping=mapping,
                similarity_score=score,
                match_method='string'
            )

    return best_match


def perform_string_matching(predicted_adrs: List[PredictedADR],
                           snomed_mappings: List[SNOMEDMapping]) -> List[MatchResult]:
    """
    Run METHOD A (string similarity) for every predicted ADR.

    ADRs for which no mapping is found are left out of the result.

    Returns:
        List of MatchResult objects, in the order of ``predicted_adrs``.
    """
    bar = '=' * 80
    print(f"\n{bar}")
    print(f"METHOD A: STRING SIMILARITY MATCHING")
    print(bar)
    print(f"\nMatching {len(predicted_adrs)} predicted ADRs using string similarity...")

    candidates = (match_adr_using_string_similarity(adr, snomed_mappings) for adr in predicted_adrs)
    matches = [found for found in candidates if found]

    print(f"\n✓ Successfully matched {len(matches)}/{len(predicted_adrs)} ADRs")
    return matches



# STEP 4: METHOD B - EMBEDDING-BASED MATCHING


def load_embedding_model():
    """
    Load a pre-trained embedding model from Hugging Face.

    Uses sentence-transformers/all-MiniLM-L6-v2 for semantic similarity.
    For the medical domain a biomedical encoder (e.g. dmis-lab/biobert-v1.1)
    could be substituted.

    Returns:
        (tokenizer, model, device). The model has been moved to ``device``
        (CUDA when available, otherwise CPU) and switched to eval mode.
    """
    from transformers import AutoTokenizer, AutoModel
    import torch

    print(f"\n🔄 Loading embedding model...")

    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    print(f"  Model: {model_name}")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    # Pick the device and place the model on it. get_text_embedding moves
    # its inputs to this same device, so both must agree.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Inference only: eval() disables dropout so embeddings are deterministic.
    model.eval()
    print(f"  Device: {device}")

    return tokenizer, model, device


def get_text_embedding(text: str, tokenizer, model, device) -> np.ndarray:
    """
    Embed ``text`` as one vector by mean-pooling the model's last hidden state.

    Only real tokens are averaged: the tokenizer's attention mask excludes
    any padding positions, matching the pooling recommended for
    sentence-transformers models. When the tokenizer provides no mask, a
    plain mean over all positions is used. Inputs are truncated to 128 tokens.

    Returns:
        1-D numpy array whose length is the model's hidden size.
    """
    import torch

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        padding=True
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    hidden = outputs.last_hidden_state  # (batch, seq, hidden)
    mask = inputs.get("attention_mask")
    if mask is None:
        pooled = hidden.mean(dim=1)
    else:
        weights = mask.unsqueeze(-1).to(hidden.dtype)
        # clamp guards against division by zero for an all-padding row.
        pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)

    return pooled.cpu().numpy()[0]


def calculate_cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    """Return the cosine similarity of two vectors as a Python float.

    Returns 0.0 when either vector has zero norm.
    """
    dot = np.dot(vec1, vec2)
    scale = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    return 0.0 if scale == 0 else float(dot / scale)


def match_adr_using_embeddings(predicted_adr: PredictedADR,
                               snomed_mappings: List[SNOMEDMapping],
                               tokenizer, model, device,
                               min_score: float = 0.0) -> Optional[MatchResult]:
    """
    METHOD B: Match a predicted ADR to a SNOMED code using embedding similarity.

    Each ADR mapping is scored with the better cosine similarity of the
    predicted text against (a) the SNOMED description and (b) the
    ground-truth text. The best-scoring mapping wins.

    Args:
        predicted_adr: The predicted ADR entity.
        snomed_mappings: All available SNOMED mappings; only those with
            label_type 'ADR' are considered.
        tokenizer, model, device: As returned by load_embedding_model().
        min_score: A candidate is accepted only if its score is strictly
            greater than this. The default 0.0 matches any positive similarity.

    Returns:
        MatchResult for the highest-scoring mapping (the earliest wins ties),
        or None when there is no ADR mapping or none scores above ``min_score``.
    """
    # Mappings often repeat strings: an entity with several SNOMED codes
    # repeats its ground-truth text. Embed each distinct string only once.
    embedding_cache = {}

    def embed(s: str) -> np.ndarray:
        if s not in embedding_cache:
            embedding_cache[s] = get_text_embedding(s, tokenizer, model, device)
        return embedding_cache[s]

    pred_embedding = embed(predicted_adr.text)

    best_match = None
    best_score = min_score

    for mapping in snomed_mappings:
        if mapping.label_type != 'ADR':
            continue

        score = max(
            calculate_cosine_similarity(pred_embedding, embed(mapping.standard_description)),
            calculate_cosine_similarity(pred_embedding, embed(mapping.ground_truth_text)),
        )

        if score > best_score:
            best_score = score
            best_match = MatchResult(
                predicted_adr=predicted_adr,
                matched_mapping=mapping,
                similarity_score=score,
                match_method='embedding'
            )

    return best_match


def perform_embedding_matching(predicted_adrs: List[PredictedADR],
                              snomed_mappings: List[SNOMEDMapping]) -> List[MatchResult]:
    """
    Run METHOD B (embedding similarity) for every predicted ADR.

    The embedding model is loaded once per call. ADRs for which no mapping
    is found are left out of the result.

    Returns:
        List of MatchResult objects, in the order of ``predicted_adrs``.
    """
    bar = '=' * 80
    print(f"\n{bar}")
    print(f"METHOD B: EMBEDDING-BASED MATCHING")
    print(bar)

    tokenizer, model, device = load_embedding_model()

    print(f"Matching {len(predicted_adrs)} predicted ADRs using embeddings...")

    matches = []
    for adr in predicted_adrs:
        found = match_adr_using_embeddings(adr, snomed_mappings, tokenizer, model, device)
        if not found:
            continue
        matches.append(found)

    print(f"\n✓ Successfully matched {len(matches)}/{len(predicted_adrs)} ADRs")
    return matches



# STEP 5: COMPARE RESULTS

def display_matching_results(matches: List[MatchResult], method_name: str):
    """
    Print one method's matches followed by summary statistics.

    Args:
        matches: Match results produced by a single matching method.
        method_name: Heading used for the results section.
    """
    rule = '-' * 80
    print(f"\n{rule}\n{method_name} - RESULTS\n{rule}\n")

    if not matches:
        print("❌ No matches found")
        return

    for rank, result in enumerate(matches, start=1):
        mapping = result.matched_mapping
        entry = [
            f"{rank}. Predicted ADR: '{result.predicted_adr.text}'",
            "   ↓",
            f"   SNOMED Code: {mapping.standard_code}",
            f"   SNOMED Description: {mapping.standard_description}",
            f"   Ground Truth: '{mapping.ground_truth_text}'",
            f"   Similarity Score: {result.similarity_score:.4f}",
            "",  # blank separator line between entries
        ]
        print("\n".join(entry))

    scores = [result.similarity_score for result in matches]
    mean_score = sum(scores) / len(scores)
    print(f"📊 Average Similarity Score: {mean_score:.4f}")
    print(f"📊 Total Matches: {len(matches)}")


def compare_methods(string_matches: List[MatchResult],
                   embedding_matches: List[MatchResult]):
    """
    Compare results from both matching methods.

    Analyzes:
    - Number of matches found
    - Average similarity scores
    - Agreement rate (same SNOMED code assigned to the same predicted ADR)
    - Disagreements (first five shown, with both candidate codes)

    Args:
        string_matches: Results of Method A (string similarity).
        embedding_matches: Results of Method B (embedding similarity).

    Note:
        Each method leaves out ADRs it could not match, so the two lists are
        not index-aligned. Results are paired by the ``predicted_adr`` object
        they refer to. This assumes both methods were given the same
        PredictedADR objects, as ``run_task6_pipeline`` does.
    """
    print(f"\n{'='*80}")
    print("COMPARISON OF METHODS")
    print(f"{'='*80}")

    print("\n📊 Match Statistics:")
    print(f"  String Method: {len(string_matches)} matches")
    print(f"  Embedding Method: {len(embedding_matches)} matches")

    if not string_matches or not embedding_matches:
        print("\n⚠️  Cannot compare - one or both methods found no matches")
        return

    # Compare average scores
    avg_string = sum(m.similarity_score for m in string_matches) / len(string_matches)
    avg_embedding = sum(m.similarity_score for m in embedding_matches) / len(embedding_matches)

    print("\n📈 Average Similarity Scores:")
    print(f"  String Method: {avg_string:.4f}")
    print(f"  Embedding Method: {avg_embedding:.4f}")

    # Pair results by predicted ADR (object identity), not by list position.
    # A positional comparison would line up different ADRs as soon as one
    # method skipped an ADR that the other one matched.
    embedding_by_adr = {id(m.predicted_adr): m for m in embedding_matches}
    pairs = [
        (idx, str_match, embedding_by_adr[id(str_match.predicted_adr)])
        for idx, str_match in enumerate(string_matches)
        if id(str_match.predicted_adr) in embedding_by_adr
    ]

    agreement_rate = None
    disagreements = []
    if pairs:
        disagreements = [
            (idx, str_match, emb_match)
            for idx, str_match, emb_match in pairs
            if str_match.matched_mapping.standard_code
            != emb_match.matched_mapping.standard_code
        ]
        agreements = len(pairs) - len(disagreements)
        agreement_rate = (agreements / len(pairs)) * 100
        print(f"\n🤝 Agreement Rate: {agreements}/{len(pairs)} ({agreement_rate:.1f}%)")
    else:
        print("\n🤝 Agreement Rate: n/a (no predicted ADR was matched by both methods)")

    # Analyze disagreements
    if disagreements:
        print(f"\n⚠️  Disagreements ({len(disagreements)} cases):")
        for idx, str_match, emb_match in disagreements[:5]:  # Show first 5
            print(f"\n  Case {idx + 1}: '{str_match.predicted_adr.text}'")
            print(f"    String → {str_match.matched_mapping.standard_code} "
                  f"(score: {str_match.similarity_score:.3f})")
            print(f"      Description: {str_match.matched_mapping.standard_description}")
            print(f"    Embedding → {emb_match.matched_mapping.standard_code} "
                  f"(score: {emb_match.similarity_score:.3f})")
            print(f"      Description: {emb_match.matched_mapping.standard_description}")

        if len(disagreements) > 5:
            print(f"\n  ... and {len(disagreements) - 5} more disagreements")

    # Conclusion
    print("\n🎯 Conclusion:")
    if avg_embedding > avg_string:
        # Guard the relative difference against a zero/negative baseline
        if avg_string > 0:
            diff = ((avg_embedding - avg_string) / avg_string) * 100
            print(f"  • Embedding method shows {diff:.1f}% higher average similarity")
        else:
            print("  • Embedding method shows higher average similarity")
        print("  • Embeddings capture semantic similarity better")
        print("  • Recommended for medical terminology with synonyms")
    elif avg_string > avg_embedding:
        if avg_embedding > 0:
            diff = ((avg_string - avg_embedding) / avg_embedding) * 100
            print(f"  • String method shows {diff:.1f}% higher average similarity")
        else:
            print("  • String method shows higher average similarity")
        print("  • String matching works well for exact/near-exact matches")
        print("  • Faster and simpler than embeddings")
    else:
        print("  • Both methods show similar performance")

    # Agreement is undefined when no ADR was matched by both methods
    if agreement_rate is not None:
        if agreement_rate >= 80:
            print(f"  • High agreement rate ({agreement_rate:.1f}%) indicates consistency")
        else:
            print("  • Lower agreement rate suggests methods capture different aspects")



# MAIN PIPELINE


def run_task6_pipeline(filename: str):
    """
    Run the complete Task 6 SNOMED CT mapping workflow for one file.

    Steps:
    1. Build the SNOMED mapping structure from the original + sct files
    2. Load the ADRs predicted in Task 2
    3. Match them with Method A (string similarity)
    4. Match them with Method B (embedding similarity)
    5. Compare both methods

    Stops early, with a message, if the file yields no SNOMED mappings or
    no predicted ADRs.

    Args:
        filename: Name of file to process (e.g., "ARTHROTEC.1.ann")
    """
    hashes = '#' * 80
    print(f"\n{hashes}")
    print("# TASK 6: SNOMED CT CODE MAPPING")
    print(f"# File: {filename}")
    print(hashes)

    # 1) Reference mappings for this file
    snomed_mappings = create_snomed_data_structure(filename)
    if not snomed_mappings:
        print(f"\n❌ No SNOMED mappings found for {filename}")
        print("   This file may not have any entities with SNOMED codes")
        return
    display_snomed_structure(snomed_mappings, max_display=5)

    # 2) Predictions to be mapped
    predicted_adrs = load_predicted_adrs(filename)
    if not predicted_adrs:
        print("\n❌ No ADR entities to process")
        return

    print("\nPredicted ADRs to match:")
    for position, adr in enumerate(predicted_adrs, start=1):
        print(f"  {position}. {adr}")

    # 3) + 4) Run each matcher and show its output immediately
    matchers = (
        ("METHOD A: STRING SIMILARITY", perform_string_matching),
        ("METHOD B: EMBEDDING SIMILARITY", perform_embedding_matching),
    )
    results = []
    for title, matcher in matchers:
        found = matcher(predicted_adrs, snomed_mappings)
        display_matching_results(found, title)
        results.append(found)
    string_matches, embedding_matches = results

    # 5) Side-by-side comparison
    compare_methods(string_matches, embedding_matches)

    print(f"\n{hashes}")
    print(f"# TASK 6 COMPLETED FOR {filename}")
    print(f"{hashes}\n")



# MAIN EXECUTION


if __name__ == "__main__":
    # Mount Google Drive when running in Colab. Outside Colab the import
    # fails; inside Colab a mount failure is reported rather than mislabelled
    # as a local run (and KeyboardInterrupt is no longer swallowed).
    try:
        from google.colab import drive
    except ImportError:
        print("Running in local environment\n")
    else:
        try:
            drive.mount('/content/drive', force_remount=False)
            print("✓ Google Drive mounted\n")
        except Exception as e:
            print(f"⚠️  Could not mount Google Drive: {e}\n")


    # SELECT FILE TO PROCESS


    # Choose a file from the dataset
    # From the listing, we know files are named like: ARTHROTEC.1.ann
    filename = "ARTHROTEC.1.ann"

    print(f"Processing file: {filename}")
    print("You can change this by modifying the 'filename' variable\n")

    # Run the complete Task 6 pipeline
    run_task6_pipeline(filename)

    print("\n" + "="*80)
    print("TASK 6 COMPLETE")
    print("="*80)
    print("\nSummary:")
    print("✓ Combined original and sct data into SNOMED mapping structure")
    print("✓ Matched predicted ADRs using string similarity (Method A)")
    print("✓ Matched predicted ADRs using embedding similarity (Method B)")
    print("✓ Compared both methods and analyzed results")
    print("\n" + "="*80)



# ADDITIONAL UTILITIES


def process_multiple_files(filenames: List[str]):
    """
    Process multiple files in batch mode and summarise both matching methods.

    For each file:
    - build the SNOMED mappings;
    - load the predicted ADRs;
    - run the string and embedding matchers;
    - record each method's average similarity and the rate at which both
      methods assign the same SNOMED code to the same predicted ADR.

    Files without mappings or ADRs are skipped. An exception raised while
    processing a file is printed, and the batch continues.

    Note: ``perform_embedding_matching`` calls ``load_embedding_model()`` on
    every invocation, i.e. once per processed file.

    Args:
        filenames: List of filenames to process
    """
    print(f"\n{'#'*80}")
    print(f"# BATCH PROCESSING: {len(filenames)} FILES")
    print(f"{'#'*80}\n")

    all_string_scores = []
    all_embedding_scores = []
    all_agreements = []

    for file_idx, filename in enumerate(filenames, 1):
        print(f"\n{'='*80}")
        print(f"Processing file {file_idx}/{len(filenames)}: {filename}")
        print(f"{'='*80}")

        # Deliberate best-effort: one bad file must not abort the whole batch
        try:
            # Create SNOMED structure
            snomed_mappings = create_snomed_data_structure(filename)
            if not snomed_mappings:
                print(f"⚠️  No SNOMED mappings for {filename}, skipping...")
                continue

            # Load predictions
            predicted_adrs = load_predicted_adrs(filename)
            if not predicted_adrs:
                print(f"⚠️  No ADRs for {filename}, skipping...")
                continue

            # Match with both methods
            string_matches = perform_string_matching(predicted_adrs, snomed_mappings)
            embedding_matches = perform_embedding_matching(predicted_adrs, snomed_mappings)

            # Collect statistics
            avg_str = avg_emb = None
            if string_matches:
                avg_str = sum(m.similarity_score for m in string_matches) / len(string_matches)
                all_string_scores.append(avg_str)

            if embedding_matches:
                avg_emb = sum(m.similarity_score for m in embedding_matches) / len(embedding_matches)
                all_embedding_scores.append(avg_emb)

            # Calculate agreement
            if string_matches and embedding_matches:
                # Pair results by predicted ADR (object identity): each method
                # drops ADRs it could not match, so list positions do not line
                # up between the two result lists.
                emb_code_by_adr = {
                    id(m.predicted_adr): m.matched_mapping.standard_code
                    for m in embedding_matches
                }
                shared = [m for m in string_matches
                          if id(m.predicted_adr) in emb_code_by_adr]

                print("\nFile Results:")
                print(f"  String avg score: {avg_str:.4f}")
                print(f"  Embedding avg score: {avg_emb:.4f}")
                if shared:
                    agreements = sum(
                        1 for m in shared
                        if m.matched_mapping.standard_code == emb_code_by_adr[id(m.predicted_adr)]
                    )
                    agreement_rate = (agreements / len(shared)) * 100
                    all_agreements.append(agreement_rate)
                    print(f"  Agreement: {agreement_rate:.1f}%")
                else:
                    print("  Agreement: n/a (no ADR matched by both methods)")

        except Exception as e:
            print(f"❌ Error processing {filename}: {e}")

    # Aggregate results
    print(f"\n{'#'*80}")
    print("# BATCH PROCESSING SUMMARY")
    print(f"{'#'*80}\n")

    if all_string_scores:
        print("String Similarity Method:")
        print(f"  Files processed: {len(all_string_scores)}")
        print(f"  Average score: {np.mean(all_string_scores):.4f}")
        print(f"  Std deviation: {np.std(all_string_scores):.4f}")
        print(f"  Min/Max: {min(all_string_scores):.4f} / {max(all_string_scores):.4f}")

    if all_embedding_scores:
        print("\nEmbedding Similarity Method:")
        print(f"  Files processed: {len(all_embedding_scores)}")
        print(f"  Average score: {np.mean(all_embedding_scores):.4f}")
        print(f"  Std deviation: {np.std(all_embedding_scores):.4f}")
        print(f"  Min/Max: {min(all_embedding_scores):.4f} / {max(all_embedding_scores):.4f}")

    if all_agreements:
        print("\nAgreement Rates:")
        print(f"  Files compared: {len(all_agreements)}")
        print(f"  Average agreement: {np.mean(all_agreements):.1f}%")
        print(f"  Std deviation: {np.std(all_agreements):.1f}%")
        print(f"  Min/Max: {min(all_agreements):.1f}% / {max(all_agreements):.1f}%")


def verify_file_exists(filename: str) -> bool:
    """
    Verify that a file exists in the original, sct and text directories.

    The text file is expected to share the annotation file's name with a
    trailing ".ann" replaced by ".txt"; other names are checked unchanged.
    The status of each path is printed.

    Args:
        filename: Name of file to check

    Returns:
        True if file exists in all directories
    """
    # Swap only the trailing extension; str.replace would also rewrite a
    # ".ann" occurring earlier in the name (e.g. "x.annotated.ann").
    if filename.endswith('.ann'):
        text_name = filename[:-len('.ann')] + '.txt'
    else:
        text_name = filename

    required_files = {
        'original': os.path.join(ORIGINAL_DIR, filename),
        'sct': os.path.join(SCT_DIR, filename),
        'text': os.path.join(TEXT_DIR, text_name),
    }

    all_exist = True
    print(f"\nVerifying file: {filename}")

    for name, path in required_files.items():
        exists = os.path.exists(path)
        status = "✓" if exists else "❌"
        print(f"  {status} {name:10}: {path}")
        if not exists:
            all_exist = False

    return all_exist


def get_sample_files(n: int = 5) -> List[str]:
    """
    Get a sample of annotation files from the dataset for testing.

    Returns the first ``n`` ``.ann`` files of ORIGINAL_DIR in sorted order,
    printing the selected names.

    Args:
        n: Number of files to return (values <= 0 yield an empty list)

    Returns:
        List of filenames (empty if the directory is missing or has no
        ``.ann`` files)
    """
    if not os.path.exists(ORIGINAL_DIR):
        print(f"❌ Directory not found: {ORIGINAL_DIR}")
        return []

    files = sorted(f for f in os.listdir(ORIGINAL_DIR) if f.endswith('.ann'))

    if not files:
        print(f"❌ No .ann files found in: {ORIGINAL_DIR}")
        return []

    # Clamp: a negative n would otherwise slice to "all but the last |n|" files
    sample = files[:max(n, 0)]

    # Report the actual count, which is smaller than n when fewer files exist
    print(f"\nSample files ({len(sample)}):")
    for i, f in enumerate(sample, 1):
        print(f"  {i}. {f}")

    return sample


def analyze_snomed_coverage(max_files: Optional[int] = 100):
    """
    Report how many annotated entities have a SNOMED CT annotation.

    Scans up to ``max_files`` ``.ann`` files from ORIGINAL_DIR in sorted
    order, so the sample is reproducible. Each file is parsed with
    ``parse_original_annotations`` and its SCT_DIR counterpart with
    ``parse_sct_annotations``. The function prints overall coverage and
    per-type coverage for ADR, Drug, Disease and Symptom.

    Args:
        max_files: Maximum number of files to scan (default 100 for speed);
            ``None`` scans every file.
    """
    if not os.path.exists(ORIGINAL_DIR):
        print(f"❌ Directory not found: {ORIGINAL_DIR}")
        return

    all_files = sorted(f for f in os.listdir(ORIGINAL_DIR) if f.endswith('.ann'))
    selected = all_files if max_files is None else all_files[:max(max_files, 0)]
    print(f"Analyzing {len(selected)} of {len(all_files)} files...")

    entity_types = ('ADR', 'Drug', 'Disease', 'Symptom')
    type_counts = dict.fromkeys(entity_types, 0)
    type_with_snomed = dict.fromkeys(entity_types, 0)
    total_entities = 0
    total_with_snomed = 0

    for filename in selected:
        original_entities = parse_original_annotations(os.path.join(ORIGINAL_DIR, filename))
        sct_data = parse_sct_annotations(os.path.join(SCT_DIR, filename))

        for entity_id, entity_info in original_entities.items():
            total_entities += 1
            entity_type = entity_info['type']

            if entity_type in type_counts:
                type_counts[entity_type] += 1

            # NOTE(review): sct entries are looked up as 'T' + original id;
            # confirm this matches the key format of both parsers.
            sct_id = 'T' + entity_id
            if sct_id in sct_data:
                total_with_snomed += 1
                if entity_type in type_with_snomed:
                    type_with_snomed[entity_type] += 1

    # Display results
    coverage_rate = (total_with_snomed / total_entities * 100) if total_entities > 0 else 0

    print(f"Total Entities: {total_entities}")
    print(f"Total Entities with SNOMED: {total_with_snomed} ({coverage_rate:.1f}%)")
    for entity_type in entity_types:
        total = type_counts[entity_type]
        with_snomed = type_with_snomed[entity_type]
        coverage = (with_snomed / total * 100) if total > 0 else 0
        print(f"  {entity_type:10}: {with_snomed:4}/{total:4} ({coverage:5.1f}%)")



