In [11]:
# ============================================================================
# IMPORTS AND CONFIGURATION
# ============================================================================

import os
import re
import random
import pandas as pd
from tqdm.auto import tqdm
from collections import defaultdict
from typing import TypedDict, Annotated, List, Dict
import operator

# Load environment variables from .env file
from dotenv import load_dotenv
load_dotenv()

# LangGraph & LangChain
from langgraph.graph import StateGraph, END
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage, SystemMessage

# Embeddings & Matching
from sentence_transformers import SentenceTransformer, util
from thefuzz import process as fuzzy_process

# For evaluation
from sklearn.metrics import precision_score, recall_score, f1_score

# ============================================================================
# CONFIGURATION
# ============================================================================

# Data paths - reads from .env or uses defaults
DATA_DIR = os.getenv('DATA_DIR', './cadec')
TEXT_DIR = os.path.join(DATA_DIR, 'text')
ORIGINAL_DIR = os.path.join(DATA_DIR, 'original')
MEDDRA_DIR = os.path.join(DATA_DIR, 'meddra')
SCT_DIR = os.path.join(DATA_DIR, 'sct')

# Verify paths exist
if not os.path.exists(DATA_DIR):
    raise FileNotFoundError(f"Data directory not found: {DATA_DIR}")

print(f"Data directory: {os.path.abspath(DATA_DIR)}")
print(f"Text files: {len([f for f in os.listdir(TEXT_DIR) if f.endswith('.txt')])} found")

# ============================================================================
# API INITIALIZATION
# ============================================================================

# Get API key from environment
gemini_api_key = os.getenv('GEMINI_API_KEY')
if not gemini_api_key:
    raise ValueError(
        "GEMINI_API_KEY not found. Please:\n"
        "1. Create a .env file in your project root\n"
        "2. Add: GEMINI_API_KEY=your_key_here\n"
        "3. Get key from: https://aistudio.google.com/app/apikey"
    )

# Initialize Gemini model
try:
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash",
        google_api_key=gemini_api_key,
        temperature=0.1,
        max_output_tokens=4096
    )
    print("✓ Gemini API initialized")
except Exception as e:
    raise ConnectionError(f"Failed to initialize Gemini API: {e}")

# Initialize embedding model (downloads on first run)
print("Loading embedding model (this may take a moment on first run)...")
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
print("✓ Embedding model loaded")

# ============================================================================
# TASK 1: ENUMERATE DISTINCT ENTITIES
# ============================================================================

def enumerate_entities():
    """Parse all files to find distinct entities for each label type"""
    distinct_entities = defaultdict(set)
    line_regex = re.compile(r'^(T\d+)\t(\w+)[\s\d;]+\t(.+)$')
    
    print(f"\nProcessing files in: {ORIGINAL_DIR}")
    filenames = [f for f in os.listdir(ORIGINAL_DIR) if f.endswith('.ann')]
    
    for filename in tqdm(filenames, desc="Parsing annotation files"):
        filepath = os.path.join(ORIGINAL_DIR, filename)
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                if line.startswith('#'):
                    continue
                match = line_regex.match(line.strip())
                if match:
                    _, label, text = match.groups()
                    normalized_text = text.strip().lower()
                    if normalized_text:
                        distinct_entities[label].add(normalized_text)
    
    return distinct_entities

print("\n" + "="*70)
print("TASK 1: ENUMERATING DISTINCT ENTITIES")
print("="*70)

distinct_entities_by_label = enumerate_entities()

print("\nResults:")
for label, entities in distinct_entities_by_label.items():
    print(f"  {label:10} | {len(entities):5} unique entities")
    print(f"             | Examples: {list(entities)[:3]}")

# ============================================================================
# LANGGRAPH STATE DEFINITION
# ============================================================================

class NERState(TypedDict):
    """State for medical NER workflow"""
    # Input
    text: str
    filename: str
    
    # LLM outputs
    bio_output: str
    llm_attempts: int
    
    # Parsed entities
    raw_entities: List[tuple]
    entities: List[Dict]
    
    # Ground truth & evaluation
    ground_truth: List[Dict]
    performance: Dict[str, float]
    
    # Error tracking
    errors: Annotated[List[str], operator.add]

# ============================================================================
# PROMPT TEMPLATES
# ============================================================================

SYSTEM_PROMPT = """You are a medical NER expert specializing in extracting entities from patient forum posts.

Your task is to label EVERY word using BIO format with these EXACT tags:
- B-ADR, I-ADR: Adverse drug reactions (side effects from medication)
- B-Drug, I-Drug: Medication names ONLY (not dosages)
- B-Disease, I-Disease: Medical conditions being treated
- B-Symptom, I-Symptom: Disease symptoms (NOT drug side effects)
- O: Outside any entity

CRITICAL FORMAT RULES:
1. Output format: word/TAG word/TAG word/TAG
2. NO punctuation in tags: "pain/B-Symptom" NOT "pain/B-Symptom,"
3. NO invalid tags like B-DOSAGE, B-ACTIVITY, S-ADR
4. Every word needs exactly ONE tag
5. Punctuation gets O tag

ENTITY DISTINCTION:
- If symptom appears AFTER taking medication → ADR
- If symptom is from the disease itself → Symptom
- Numbers alone (50, 75) → O (not Drug)

EXAMPLES:

Input: "I feel dizzy and nauseous after taking Lipitor"
Output: I/O feel/B-ADR dizzy/I-ADR and/O nauseous/B-ADR after/O taking/O Lipitor/B-Drug

Input: "My arthritis pain is unbearable"
Output: My/O arthritis/B-Disease pain/B-Symptom is/O unbearable/O

Input: "Started Arthrotec 50 twice daily, no stomach issues so far"
Output: Started/O Arthrotec/B-Drug 50/O twice/O daily/O ,/O no/O stomach/B-ADR issues/I-ADR so/O far/O

Input: "Severe headache and muscle weakness from the medication"
Output: Severe/B-ADR headache/I-ADR and/O muscle/B-ADR weakness/I-ADR from/O the/O medication/O"""

def create_user_prompt(text: str) -> str:
    return f"""Label the following patient forum post in BIO format.

Text:
{text}

Output (word/TAG format only, no explanations):"""

# ============================================================================
# LANGGRAPH NODE FUNCTIONS
# ============================================================================

def generate_bio_labels(state: NERState) -> NERState:
    """Node: Generate BIO labels using Gemini"""
    try:
        messages = [
            SystemMessage(content=SYSTEM_PROMPT),
            HumanMessage(content=create_user_prompt(state['text']))
        ]
        
        response = llm.invoke(messages)
        bio_output = response.content.strip()
        
        state['bio_output'] = bio_output
        state['llm_attempts'] = state.get('llm_attempts', 0) + 1
        
    except Exception as e:
        state['errors'].append(f"LLM generation error: {str(e)}")
        state['bio_output'] = ""
    
    return state

def parse_bio_to_entities(state: NERState) -> NERState:
    """Node: Parse BIO output into entity tuples"""
    bio_string = state['bio_output']
    
    if '/' not in bio_string:
        state['errors'].append("No valid BIO tags found")
        state['raw_entities'] = []
        return state
    
    VALID_TAGS = {'B-ADR', 'I-ADR', 'B-DRUG', 'I-DRUG', 
                  'B-DISEASE', 'I-DISEASE', 'B-SYMPTOM', 'I-SYMPTOM', 'O'}
    
    entities = []
    current_entity = None
    entity_words = []
    
    for token in bio_string.split():
        token = token.rstrip('.,;:!?/').strip()
        
        if not token or '/' not in token or token.startswith('/'):
            continue
        
        parts = token.rsplit('/', 1)
        if len(parts) != 2:
            continue
        
        word, tag = parts
        word = word.strip('.,;:!?/')
        tag = tag.strip('.,;:!?/').upper().replace('DRUG', 'DRUG')
        
        if not word or word in ['/', ',', '.']:
            continue
        
        if tag not in VALID_TAGS:
            if current_entity and entity_words:
                entities.append((current_entity, entity_words))
            current_entity = None
            entity_words = []
            continue
        
        if tag.startswith('B-'):
            if current_entity and entity_words:
                entities.append((current_entity, entity_words))
            current_entity = tag[2:]
            entity_words = [word]
            
        elif tag.startswith('I-'):
            label = tag[2:]
            if current_entity == label:
                entity_words.append(word)
            else:
                if current_entity and entity_words:
                    entities.append((current_entity, entity_words))
                current_entity = label
                entity_words = [word]
                
        elif tag == 'O':
            if current_entity and entity_words:
                entities.append((current_entity, entity_words))
            current_entity = None
            entity_words = []
    
    if current_entity and entity_words:
        entities.append((current_entity, entity_words))
    
    state['raw_entities'] = entities
    
    return state

def map_to_text_spans(state: NERState) -> NERState:
    """Node: Map entities to character spans in original text"""
    original_text = state['text']
    original_lower = original_text.lower()
    
    final_entities = []
    
    for label, words in state['raw_entities']:
        entity_text = ' '.join(words)
        
        if len(entity_text) <= 1 or (entity_text.isdigit() and len(entity_text) <= 2):
            continue
        
        entity_lower = entity_text.lower()
        start_idx = original_lower.find(entity_lower)
        
        if start_idx != -1:
            actual_text = original_text[start_idx:start_idx + len(entity_text)]
            final_entities.append({
                'label': label,
                'text': actual_text,
                'start': start_idx,
                'end': start_idx + len(entity_text)
            })
        else:
            pattern = r'\s+'.join(re.escape(w) for w in words)
            match = re.search(pattern, original_lower, re.IGNORECASE)
            if match:
                actual_text = original_text[match.start():match.end()]
                final_entities.append({
                    'label': label,
                    'text': actual_text,
                    'start': match.start(),
                    'end': match.end()
                })
    
    state['entities'] = final_entities
    
    return state

def apply_postprocessing(state: NERState) -> NERState:
    """Node: Context-aware ADR detection and deduplication"""
    original_text = state['text']
    entities = state['entities']
    
    for entity in entities:
        if entity['label'] == 'SYMPTOM':
            context_start = max(0, entity['start'] - 120)
            context_end = min(len(original_text), entity['end'] + 120)
            context = original_text[context_start:context_end].lower()
            
            drug_phrases = ['taking', 'take it', 'on', 'after', 'when on', 
                           'started me on', 'caused', 'from the']
            
            if any(phrase in context for phrase in drug_phrases):
                has_drug = any(
                    e['label'] == 'DRUG' and abs(e['start'] - entity['start']) < 150
                    for e in entities
                )
                if has_drug:
                    entity['label'] = 'ADR'
    
    seen = set()
    unique = []
    for ent in entities:
        key = (ent['label'], ent['text'].lower().strip())
        if key not in seen:
            seen.add(key)
            unique.append(ent)
    
    unique.sort(key=lambda x: x['start'])
    state['entities'] = unique
    
    return state

def load_ground_truth(state: NERState) -> NERState:
    """Node: Load ground truth annotations"""
    ann_filename = state['filename'].replace('.txt', '.ann')
    ann_path = os.path.join(ORIGINAL_DIR, ann_filename)
    
    entities = []
    line_regex = re.compile(r'^(T\d+)\t(\w+)\s([\d\s;]+)\t(.+)$')
    
    with open(ann_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.startswith('#'):
                continue
            match = line_regex.match(line.strip())
            if match:
                _, label, _, text = match.groups()
                entities.append({
                    'label': label,
                    'text': text.strip()
                })
    
    state['ground_truth'] = entities
    
    return state

def evaluate_performance(state: NERState) -> NERState:
    """Node: Calculate precision, recall, F1 with Jaccard matching"""
    predictions = state['entities']
    ground_truth = state['ground_truth']
    
    tp = 0
    matched_gt = set()
    
    for pred in predictions:
        pred_words = set(pred['text'].lower().split())
        
        for idx, gt in enumerate(ground_truth):
            if idx in matched_gt:
                continue
            
            if pred['label'].upper() != gt['label'].upper():
                continue
            
            gt_words = set(gt['text'].lower().split())
            
            if pred_words and gt_words:
                intersection = len(pred_words & gt_words)
                union = len(pred_words | gt_words)
                similarity = intersection / union if union > 0 else 0
                
                if similarity > 0.66:
                    tp += 1
                    matched_gt.add(idx)
                    break
    
    fp = len(predictions) - tp
    fn = len(ground_truth) - len(matched_gt)
    
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
    
    state['performance'] = {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'tp': tp,
        'fp': fp,
        'fn': fn
    }
    
    return state

# ============================================================================
# BUILD LANGGRAPH WORKFLOW
# ============================================================================

def build_ner_graph():
    """Construct the LangGraph workflow"""
    workflow = StateGraph(NERState)
    
    workflow.add_node("generate_bio", generate_bio_labels)
    workflow.add_node("parse_entities", parse_bio_to_entities)
    workflow.add_node("map_spans", map_to_text_spans)
    workflow.add_node("postprocess", apply_postprocessing)
    workflow.add_node("load_ground_truth", load_ground_truth)
    workflow.add_node("evaluate", evaluate_performance)
    
    workflow.set_entry_point("generate_bio")
    workflow.add_edge("generate_bio", "parse_entities")
    workflow.add_edge("parse_entities", "map_spans")
    workflow.add_edge("map_spans", "postprocess")
    workflow.add_edge("postprocess", "load_ground_truth")
    workflow.add_edge("load_ground_truth", "evaluate")
    workflow.add_edge("evaluate", END)
    
    return workflow.compile()

ner_graph = build_ner_graph()
print("\n✓ LangGraph workflow compiled")

# ============================================================================
# TASK 2-3: TEST ON SINGLE FILE
# ============================================================================

print("\n" + "="*70)
print("TASK 2-3: TESTING PIPELINE ON SAMPLE FILE")
print("="*70)

sample_filename = 'ARTHROTEC.1.txt'
sample_path = os.path.join(TEXT_DIR, sample_filename)

if not os.path.exists(sample_path):
    print(f"WARNING: {sample_filename} not found. Using first available file.")
    sample_filename = [f for f in os.listdir(TEXT_DIR) if f.endswith('.txt')][0]
    sample_path = os.path.join(TEXT_DIR, sample_filename)

with open(sample_path, 'r', encoding='utf-8') as f:
    sample_text = f.read()

initial_state = {
    'text': sample_text,
    'filename': sample_filename,
    'bio_output': '',
    'llm_attempts': 0,
    'raw_entities': [],
    'entities': [],
    'ground_truth': [],
    'performance': {},
    'errors': []
}

print(f"\nProcessing: {sample_filename}")
print(f"Input length: {len(sample_text)} characters\n")

result = ner_graph.invoke(initial_state)

print("\n" + "="*70)
print("RESULTS")
print("="*70)

print(f"\nBIO Output (first 500 chars):")
print(result['bio_output'][:500] + "..." if len(result['bio_output']) > 500 else result['bio_output'])

print(f"\nPredicted Entities ({len(result['entities'])}):")
for ent in result['entities']:
    print(f"  {ent['label']:10} | {ent['start']:3}-{ent['end']:3} | {ent['text']}")

print(f"\nGround Truth ({len(result['ground_truth'])}):")
for gt in result['ground_truth']:
    print(f"  {gt['label']:10} | {gt['text']}")

print(f"\nPerformance:")
perf = result['performance']
print(f"  Precision: {perf['precision']:.4f}")
print(f"  Recall:    {perf['recall']:.4f}")
print(f"  F1-Score:  {perf['f1']:.4f}")
print(f"  TP: {perf['tp']}, FP: {perf['fp']}, FN: {perf['fn']}")

if result['errors']:
    print(f"\nErrors: {result['errors']}")

# ============================================================================
# TASK 5: EVALUATE ON MULTIPLE FILES
# ============================================================================

def evaluate_file_with_graph(filename: str, verbose=False):
    """Evaluate single file using LangGraph"""
    try:
        with open(os.path.join(TEXT_DIR, filename), 'r', encoding='utf-8') as f:
            text = f.read()
        
        state = {
            'text': text,
            'filename': filename,
            'bio_output': '',
            'llm_attempts': 0,
            'raw_entities': [],
            'entities': [],
            'ground_truth': [],
            'performance': {},
            'errors': []
        }
        
        if verbose:
            print(f"\n{'='*70}")
            print(f"FILE: {filename}")
            print(f"{'='*70}")
            print(f"Text length: {len(text)} chars\n")
        
        result = ner_graph.invoke(state)
        
        if verbose:
            print(f"\nBIO output length: {len(result['bio_output'])} chars")
            print(f"Extracted entities: {len(result['entities'])}")
            print(f"Ground truth: {len(result['ground_truth'])}")
            print(f"\nPredictions:")
            for ent in result['entities']:
                print(f"  {ent['label']:10} | {ent['text']}")
            print(f"\nPerformance: P={result['performance']['precision']:.3f} "
                  f"R={result['performance']['recall']:.3f} "
                  f"F1={result['performance']['f1']:.3f}")
        
        return result['performance']
        
    except Exception as e:
        print(f"ERROR in {filename}: {e}")
        return None

print("\n" + "="*70)
print("TASK 5: EVALUATING MULTIPLE FILES")
print("="*70)

all_files = [f for f in os.listdir(TEXT_DIR) if f.endswith('.txt')]
random.seed(42)
random_files = random.sample(all_files, min(50, len(all_files)))

print(f"\nEvaluating {len(random_files)} random files\n")

# Verbose for first 2
scores = []
print("Detailed output for first 2 files:")
print("-" * 70)

for f in random_files[:2]:
    perf = evaluate_file_with_graph(f, verbose=True)
    if perf:
        scores.append(perf)

# Summary for rest
print("\n" + "-" * 70)
print("Processing remaining files:")
print("-" * 70)

for f in tqdm(random_files[2:], desc="Evaluating"):
    perf = evaluate_file_with_graph(f, verbose=False)
    if perf:
        scores.append(perf)
        print(f"{f:35} | F1: {perf['f1']:.3f}")

# Final statistics
if scores:
    avg_p = sum(s['precision'] for s in scores) / len(scores)
    avg_r = sum(s['recall'] for s in scores) / len(scores)
    avg_f1 = sum(s['f1'] for s in scores) / len(scores)
    
    print("\n" + "="*70)
    print("FINAL RESULTS")
    print("="*70)
    print(f"Files processed: {len(scores)}/{len(random_files)}")
    print(f"\nMacro-Averaged Metrics:")
    print(f"  Precision: {avg_p:.4f}")
    print(f"  Recall:    {avg_r:.4f}")
    print(f"  F1-Score:  {avg_f1:.4f}")
    
    f1_scores = [s['f1'] for s in scores]
    print(f"\nF1 Distribution:")
    print(f"  Min:    {min(f1_scores):.4f}")
    print(f"  Q1:     {sorted(f1_scores)[len(f1_scores)//4]:.4f}")
    print(f"  Median: {sorted(f1_scores)[len(f1_scores)//2]:.4f}")
    print(f"  Q3:     {sorted(f1_scores)[3*len(f1_scores)//4]:.4f}")
    print(f"  Max:    {max(f1_scores):.4f}")
    
    zero_count = sum(1 for s in f1_scores if s == 0.0)
    low_count = sum(1 for s in f1_scores if 0 < s < 0.3)
    mid_count = sum(1 for s in f1_scores if 0.3 <= s < 0.6)
    high_count = sum(1 for s in f1_scores if s >= 0.6)
    
    print(f"\nPerformance Breakdown:")
    print(f"  Failures (F1=0):      {zero_count}/{len(scores)} ({zero_count/len(scores)*100:.1f}%)")
    print(f"  Poor (0 < F1 < 0.3):  {low_count}/{len(scores)} ({low_count/len(scores)*100:.1f}%)")
    print(f"  Moderate (0.3-0.6):   {mid_count}/{len(scores)} ({mid_count/len(scores)*100:.1f}%)")
    print(f"  Good (F1 >= 0.6):     {high_count}/{len(scores)} ({high_count/len(scores)*100:.1f}%)")

# ============================================================================
# TASK 6: SNOMED-CT LINKING
# ============================================================================

def create_sct_datastore(ann_filename: str):
    """Combine original and SCT annotations"""
    sct_filepath = os.path.join(SCT_DIR, ann_filename)
    original_filepath = os.path.join(ORIGINAL_DIR, ann_filename)
    
    text_to_label = {}
    line_regex_orig = re.compile(r'^(T\d+)\t(\w+)[\s\d;]+\t(.+)$')
    
    with open(original_filepath, 'r', encoding='utf-8') as f:
        for line in f:
            match = line_regex_orig.match(line.strip())
            if match:
                text_to_label[match.group(3).strip()] = match.group(2)
    
    sct_data = []
    line_regex_sct = re.compile(r'^(TT\d+)\t(.*?)\t(.*?)$')
    
    with open(sct_filepath, 'r', encoding='utf-8') as f:
        for line in f:
            match = line_regex_sct.match(line.strip())
            if match:
                _, codes_and_descs_raw, text_and_spans = match.groups()
                ground_truth_text = re.split(r'\d+\s\d+', text_and_spans)[-1].strip()
                label = text_to_label.get(ground_truth_text, 'Unknown')
                
                codes_and_descs = re.findall(r'(\d+)\s*\|\s*(.*?)\s*(?=\||$)', codes_and_descs_raw)
                
                for code, desc in codes_and_descs:
                    sct_data.append({
                        'sct_code': code.strip(),
                        'sct_description': desc.strip(),
                        'label_type': label,
                        'ground_truth_text': ground_truth_text
                    })
    
    return pd.DataFrame(sct_data)

def link_with_string_matching(predicted_adrs, sct_datastore):
    """Link ADRs using fuzzy string matching"""
    sct_adr_df = sct_datastore[sct_datastore['label_type'] == 'ADR']
    sct_descriptions = sct_adr_df['sct_description'].unique().tolist()
    
    if not predicted_adrs or not sct_descriptions:
        return pd.DataFrame()
    
    results = []
    for adr in predicted_adrs:
        best_match, score = fuzzy_process.extractOne(adr['text'], sct_descriptions)
        matched_row = sct_adr_df[sct_adr_df['sct_description'] == best_match].iloc[0]
        
        results.append({
            'predicted_text': adr['text'],
            'best_sct_match': best_match,
            'sct_code': matched_row['sct_code'],
            'match_score': score
        })
    
    return pd.DataFrame(results)

def link_with_embeddings(predicted_adrs, sct_datastore):
    """Link ADRs using semantic embeddings"""
    sct_adr_df = sct_datastore[sct_datastore['label_type'] == 'ADR'].drop_duplicates(subset=['sct_description'])
    sct_descriptions = sct_adr_df['sct_description'].tolist()
    
    if not predicted_adrs or not sct_descriptions:
        return pd.DataFrame()
    
    pred_texts = [adr['text'] for adr in predicted_adrs]
    pred_embeddings = embedding_model.encode(pred_texts, convert_to_tensor=True)
    sct_embeddings = embedding_model.encode(sct_descriptions, convert_to_tensor=True)
    
    similarities = util.cos_sim(pred_embeddings, sct_embeddings)
    
    results = []
    for i, adr in enumerate(predicted_adrs):
        best_idx = similarities[i].argmax().item()
        
        results.append({
            'predicted_text': adr['text'],
            'best_sct_match': sct_descriptions[best_idx],
            'sct_code': sct_adr_df.iloc[best_idx]['sct_code'],
            'match_score': similarities[i][best_idx].item()
        })
    
    return pd.DataFrame(results)

print("\n" + "="*70)
print("TASK 6: SNOMED-CT CODE LINKING")
print("="*70)

sample_ann = sample_filename.replace('.txt', '.ann')
sct_df = create_sct_datastore(sample_ann)

print(f"\nSCT Datastore:")
print(sct_df.head(10))

adr_predictions = [e for e in result['entities'] if e['label'] == 'ADR']

if adr_predictions:
    print(f"\nLinking {len(adr_predictions)} ADR predictions to SNOMED-CT...")
    
    string_results = link_with_string_matching(adr_predictions, sct_df)
    embedding_results = link_with_embeddings(adr_predictions, sct_df)
    
    print("\n" + "-"*70)
    print("String Matching Results:")
    print("-"*70)
    print(string_results.to_string(index=False))
    
    print("\n" + "-"*70)
    print("Embedding Similarity Results:")
    print("-"*70)
    print(embedding_results.to_string(index=False))
    
    print("\n" + "="*70)
    print("COMPARISON")
    print("="*70)
    print("""
String Matching: Fast, handles typos, good for lexically similar terms
Embedding Similarity: Captures semantic meaning, better for colloquial terms

Recommendation: Use embeddings for patient→clinical terminology mapping
""")
else:
    print("\nNo ADR predictions found for linking demonstration.")

# ============================================================================
# TASK 4: ADR-SPECIFIC EVALUATION (MedDRA)
# ============================================================================

def parse_meddra_truth(filepath: str):
    """Parse MedDRA ground truth file"""
    entities = []
    line_regex = re.compile(r'^(TT\d+)\t\d+\s[\d\s]+\t(.+)$')
    
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            match = line_regex.match(line.strip())
            if match:
                _, text = match.groups()
                entities.append({'label': 'ADR', 'text': text.strip()})
    return entities

print("\n" + "="*70)
print("TASK 4: ADR-SPECIFIC EVALUATION (vs MedDRA)")
print("="*70)

meddra_ann = sample_filename.replace('.txt', '.ann')
meddra_path = os.path.join(MEDDRA_DIR, meddra_ann)

if os.path.exists(meddra_path):
    meddra_ground_truth = parse_meddra_truth(meddra_path)
    
    # Filter only ADR predictions
    adr_only_preds = [e for e in result['entities'] if e['label'] == 'ADR']
    
    # Calculate ADR-specific performance
    tp = 0
    matched_gt = set()
    
    for pred in adr_only_preds:
        pred_words = set(pred['text'].lower().split())
        
        for idx, gt in enumerate(meddra_ground_truth):
            if idx in matched_gt:
                continue
            
            gt_words = set(gt['text'].lower().split())
            
            if pred_words and gt_words:
                intersection = len(pred_words & gt_words)
                union = len(pred_words | gt_words)
                similarity = intersection / union if union > 0 else 0
                
                if similarity > 0.66:
                    tp += 1
                    matched_gt.add(idx)
                    break
    
    fp = len(adr_only_preds) - tp
    fn = len(meddra_ground_truth) - len(matched_gt)
    
    adr_precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    adr_recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    adr_f1 = 2 * adr_precision * adr_recall / (adr_precision + adr_recall) if (adr_precision + adr_recall) > 0 else 0
    
    print(f"\nADR-Only Performance (vs MedDRA):")
    print(f"  ADR Predictions: {len(adr_only_preds)}")
    print(f"  MedDRA Ground Truth: {len(meddra_ground_truth)}")
    print(f"  Precision: {adr_precision:.4f}")
    print(f"  Recall:    {adr_recall:.4f}")
    print(f"  F1-Score:  {adr_f1:.4f}")
    print(f"  TP: {tp}, FP: {fp}, FN: {fn}")
else:
    print(f"\nMedDRA file not found: {meddra_path}")

# ============================================================================
# SAVE RESULTS TO FILE
# ============================================================================

def save_results_to_csv(all_scores, output_filename='ner_results.csv'):
    """Save evaluation results to CSV"""
    results_df = pd.DataFrame(all_scores)
    results_df.to_csv(output_filename, index=False)
    print(f"\n✓ Results saved to {output_filename}")

# Uncomment to save results from the 10-file evaluation:
# save_results_to_csv(scores)

# ============================================================================
# VISUALIZATION & ANALYSIS
# ============================================================================

def analyze_errors(filename: str):
    """Detailed error analysis for a specific file"""
    with open(os.path.join(TEXT_DIR, filename), 'r', encoding='utf-8') as f:
        text = f.read()
    
    state = {
        'text': text,
        'filename': filename,
        'bio_output': '',
        'llm_attempts': 0,
        'raw_entities': [],
        'entities': [],
        'ground_truth': [],
        'performance': {},
        'errors': []
    }
    
    result = ner_graph.invoke(state)
    
    predictions = result['entities']
    ground_truth = result['ground_truth']
    
    print(f"\n{'='*70}")
    print(f"ERROR ANALYSIS: {filename}")
    print(f"{'='*70}")
    
    # Find false positives
    matched_preds = set()
    for pred in predictions:
        pred_words = set(pred['text'].lower().split())
        for gt in ground_truth:
            if pred['label'].upper() == gt['label'].upper():
                gt_words = set(gt['text'].lower().split())
                if pred_words and gt_words:
                    similarity = len(pred_words & gt_words) / len(pred_words | gt_words)
                    if similarity > 0.66:
                        matched_preds.add(pred['text'].lower())
                        break
    
    false_positives = [p for p in predictions if p['text'].lower() not in matched_preds]
    
    if false_positives:
        print(f"\nFalse Positives ({len(false_positives)}):")
        for fp in false_positives:
            print(f"  {fp['label']:10} | {fp['text']}")
            # Show context
            start = max(0, fp['start'] - 50)
            end = min(len(text), fp['end'] + 50)
            context = text[start:end].replace('\n', ' ')
            print(f"    Context: ...{context}...")
    
    # Find false negatives
    matched_gt = set()
    for gt in ground_truth:
        gt_words = set(gt['text'].lower().split())
        for pred in predictions:
            if pred['label'].upper() == gt['label'].upper():
                pred_words = set(pred['text'].lower().split())
                if pred_words and gt_words:
                    similarity = len(pred_words & gt_words) / len(pred_words | gt_words)
                    if similarity > 0.66:
                        matched_gt.add(gt['text'].lower())
                        break
    
    false_negatives = [g for g in ground_truth if g['text'].lower() not in matched_gt]
    
    if false_negatives:
        print(f"\nFalse Negatives ({len(false_negatives)}) - Missed by system:")
        for fn in false_negatives:
            print(f"  {fn['label']:10} | {fn['text']}")
            # Try to find in text
            if fn['text'].lower() in text.lower():
                idx = text.lower().find(fn['text'].lower())
                start = max(0, idx - 50)
                end = min(len(text), idx + len(fn['text']) + 50)
                context = text[start:end].replace('\n', ' ')
                print(f"    Found in text: ...{context}...")
            else:
                print(f"    Text not found verbatim in input")
    
    return {
        'false_positives': false_positives,
        'false_negatives': false_negatives
    }

# Example: Analyze errors for a specific file
if scores and len(scores) > 0:
    # Find a file with moderate performance for interesting analysis
    mid_performing_idx = len(scores) // 2
    sorted_files = sorted(zip(random_files, scores), key=lambda x: x[1]['f1'])
    mid_file = sorted_files[mid_performing_idx][0]
    
    print("\n" + "="*70)
    print("DETAILED ERROR ANALYSIS")
    print("="*70)
    print(f"Analyzing: {mid_file}")
    
    error_analysis = analyze_errors(mid_file)

# ============================================================================
# SUMMARY STATISTICS BY ENTITY TYPE
# ============================================================================

def calculate_per_entity_metrics(all_results):
    """Calculate P/R/F1 for each entity type separately"""
    entity_stats = defaultdict(lambda: {'tp': 0, 'fp': 0, 'fn': 0})
    
    for filename, performance in zip(random_files, scores):
        with open(os.path.join(TEXT_DIR, filename), 'r') as f:
            text = f.read()
        
        state = {
            'text': text,
            'filename': filename,
            'bio_output': '',
            'llm_attempts': 0,
            'raw_entities': [],
            'entities': [],
            'ground_truth': [],
            'performance': {},
            'errors': []
        }
        
        result = ner_graph.invoke(state)
        
        for entity_type in ['ADR', 'Drug', 'Disease', 'Symptom']:
            preds = [e for e in result['entities'] if e['label'].upper() == entity_type.upper()]
            gts = [e for e in result['ground_truth'] if e['label'].upper() == entity_type.upper()]
            
            matched = set()
            for pred in preds:
                pred_words = set(pred['text'].lower().split())
                for idx, gt in enumerate(gts):
                    if idx in matched:
                        continue
                    gt_words = set(gt['text'].lower().split())
                    if pred_words and gt_words:
                        similarity = len(pred_words & gt_words) / len(pred_words | gt_words)
                        if similarity > 0.66:
                            entity_stats[entity_type]['tp'] += 1
                            matched.add(idx)
                            break
                else:
                    entity_stats[entity_type]['fp'] += 1
            
            entity_stats[entity_type]['fn'] += len(gts) - len(matched)
    
    return entity_stats

if scores:
    print("\n" + "="*70)
    print("PER-ENTITY-TYPE PERFORMANCE")
    print("="*70)
    
    entity_metrics = calculate_per_entity_metrics(scores)
    
    print(f"\n{'Entity Type':<15} {'Precision':<12} {'Recall':<12} {'F1-Score':<12}")
    print("-" * 55)
    
    for entity_type in ['ADR', 'Drug', 'Disease', 'Symptom']:
        stats = entity_metrics[entity_type]
        tp, fp, fn = stats['tp'], stats['fp'], stats['fn']
        
        p = tp / (tp + fp) if (tp + fp) > 0 else 0
        r = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0
        
        print(f"{entity_type:<15} {p:<12.4f} {r:<12.4f} {f1:<12.4f}")

# ============================================================================
# FINAL SUMMARY & RECOMMENDATIONS
# ============================================================================

print("\n" + "="*70)
print("PIPELINE SUMMARY")
print("="*70)

print("""
APPROACH:
- Zero-shot learning with Gemini 2.0 Flash
- BIO tagging format with structured LangGraph pipeline
- Relaxed Jaccard matching (>66% word overlap)

STRENGTHS:
✓ Modular architecture with LangGraph
✓ Clear state management and error tracking
✓ Context-aware ADR detection
✓ Robust to format variations in LLM output

LIMITATIONS:
⚠ Zero-shot approach limits performance vs fine-tuned models
⚠ Boundary detection for modifiers inconsistent
⚠ Context-dependent labels (ADR vs Symptom) remain challenging
⚠ Long texts may exceed model context window

RECOMMENDATIONS FOR IMPROVEMENT:
1. Fine-tune smaller model (e.g., BioBERT) on CADEC training set
2. Add retry logic for malformed LLM outputs
3. Implement active learning for uncertain cases
4. Use ensemble of multiple models
5. Add human-in-the-loop validation for low-confidence predictions
""")

if scores and len(scores) > 0:
    avg_f1 = sum(s['f1'] for s in scores) / len(scores)
    
    if avg_f1 >= 0.6:
        assessment = "EXCELLENT - Outperforms typical zero-shot baselines"
    elif avg_f1 >= 0.5:
        assessment = "GOOD - Competitive with zero-shot approaches"
    elif avg_f1 >= 0.4:
        assessment = "ACCEPTABLE - Baseline performance"
    elif avg_f1 >= 0.3:
        assessment = "BELOW EXPECTATIONS - Needs improvement"
    else:
        assessment = "POOR - System requires significant debugging"
    
    print(f"\nOVERALL ASSESSMENT (F1={avg_f1:.3f}): {assessment}")

print("\n" + "="*70)
print("PIPELINE COMPLETE")
print("="*70)

Data directory: /home/aakash/Development/Miimansa_assignment/CADEC.v2/cadec
Text files: 1250 found
✓ Gemini API initialized
Loading embedding model (this may take a moment on first run)...


E0000 00:00:1759772983.327931    9556 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.


✓ Embedding model loaded

TASK 1: ENUMERATING DISTINCT ENTITIES

Processing files in: /home/aakash/Development/Miimansa_assignment/CADEC.v2/cadec/original


Parsing annotation files: 100%|██████████| 1250/1250 [00:00<00:00, 27954.72it/s]


Results:
  ADR        |  3400 unique entities
             | Examples: ['trouble with walking/balance', 'scalp itching', 'shrinking muscels neck']
  Drug       |   323 unique entities
             | Examples: ['lecithin granules', 'ezetimbe', 'pravochol']
  Finding    |   298 unique entities
             | Examples: ['type2 diabetes', 'cramps in legs', 'generalized skin discoloration']
  Disease    |   164 unique entities
             | Examples: ['hyprochondria', 'rhabdomyolosis', 'lipid problem']
  Symptom    |   148 unique entities
             | Examples: ['elbow pain', 'severly restricted', 'chronic pain']

✓ LangGraph workflow compiled

TASK 2-3: TESTING PIPELINE ON SAMPLE FILE

Processing: ARTHROTEC.1.txt
Input length: 484 characters







RESULTS

BIO Output (first 500 chars):
I/O
feel/O
a/O
bit/O
drowsy/B-ADR
&/O
have/O
a/O
little/O
blurred/B-ADR
vision/I-ADR
,/O
so/O
far/O
no/O
gastric/B-ADR
problems/I-ADR
./O
I've/O
been/O
on/O
Arthrotec/B-Drug
50/O
for/O
over/O
10/O
years/O
on/O
and/O
off/O
,/O
only/O
taking/O
it/O
when/O
I/O
needed/O
it/O
./O
Due/O
to/O
my/O
arthritis/B-Disease
getting/O
progressively/O
worse/O
,/O
to/O
the/O
point/O
where/O
I/O
am/O
in/O
tears/O
with/O
the/O
agony/B-Symptom
,/O
gp's/O
started/O
me/O
on/O
75/O
twice/O
a/O
day/O
and/O
I/O
have/O
t...

Predicted Entities (8):
  ADR        |  13- 19 | drowsy
  ADR        |  36- 50 | blurred vision
  ADR        |  62- 78 | gastric problems
  DRUG       |  93-102 | Arthrotec
  DISEASE    | 179-188 | arthritis
  SYMPTOM    | 260-265 | agony
  SYMPTOM    | 412-417 | pains
  ADR        | 448-453 | weird

Ground Truth (8):
  ADR        | bit drowsy
  ADR        | little blurred vision
  Drug       | Arthrotec
  Disease    | arthritis
  Symptom    | agony
 

Evaluating:   2%|▏         | 1/48 [00:01<01:24,  1.79s/it]

LIPITOR.528.txt                     | F1: 0.000


Evaluating:   4%|▍         | 2/48 [00:12<05:35,  7.29s/it]

LIPITOR.464.txt                     | F1: 0.571


Evaluating:   6%|▋         | 3/48 [00:19<05:15,  7.01s/it]

LIPITOR.252.txt                     | F1: 0.400


Evaluating:   8%|▊         | 4/48 [00:26<05:01,  6.85s/it]

LIPITOR.816.txt                     | F1: 0.727


Evaluating:  10%|█         | 5/48 [00:45<08:11, 11.43s/it]

LIPITOR.718.txt                     | F1: 0.800


Evaluating:  12%|█▎        | 6/48 [00:51<06:39,  9.52s/it]

LIPITOR.583.txt                     | F1: 1.000


Evaluating:  15%|█▍        | 7/48 [01:17<10:04, 14.75s/it]

LIPITOR.173.txt                     | F1: 0.364


Evaluating:  17%|█▋        | 8/48 [01:19<07:17, 10.94s/it]

LIPITOR.975.txt                     | F1: 1.000


Evaluating:  19%|█▉        | 9/48 [01:32<07:25, 11.41s/it]

LIPITOR.681.txt                     | F1: 0.545


Evaluating:  21%|██        | 10/48 [01:53<09:06, 14.37s/it]

LIPITOR.401.txt                     | F1: 0.615


Evaluating:  23%|██▎       | 11/48 [02:04<08:14, 13.36s/it]

ARTHROTEC.24.txt                    | F1: 0.400


Evaluating:  25%|██▌       | 12/48 [02:15<07:40, 12.78s/it]

LIPITOR.674.txt                     | F1: 0.333


Evaluating:  27%|██▋       | 13/48 [02:29<07:33, 12.94s/it]

LIPITOR.649.txt                     | F1: 1.000


Evaluating:  29%|██▉       | 14/48 [02:40<07:08, 12.59s/it]

LIPITOR.253.txt                     | F1: 0.647


Evaluating:  31%|███▏      | 15/48 [02:47<05:58, 10.87s/it]

LIPITOR.817.txt                     | F1: 0.706


Evaluating:  33%|███▎      | 16/48 [02:58<05:45, 10.80s/it]

LIPITOR.144.txt                     | F1: 0.667


Evaluating:  35%|███▌      | 17/48 [03:08<05:31, 10.69s/it]

LIPITOR.892.txt                     | F1: 0.667


Evaluating:  38%|███▊      | 18/48 [03:21<05:38, 11.28s/it]

LIPITOR.761.txt                     | F1: 0.571


Evaluating:  40%|███▉      | 19/48 [03:41<06:38, 13.74s/it]

LIPITOR.824.txt                     | F1: 0.444


Evaluating:  42%|████▏     | 20/48 [03:55<06:29, 13.93s/it]

LIPITOR.185.txt                     | F1: 0.429


Evaluating:  44%|████▍     | 21/48 [04:02<05:21, 11.90s/it]

LIPITOR.55.txt                      | F1: 1.000


Evaluating:  46%|████▌     | 22/48 [04:16<05:24, 12.48s/it]

LIPITOR.614.txt                     | F1: 0.235


Evaluating:  48%|████▊     | 23/48 [04:36<06:05, 14.63s/it]

LIPITOR.568.txt                     | F1: 0.765


Evaluating:  50%|█████     | 24/48 [04:53<06:13, 15.55s/it]

LIPITOR.316.txt                     | F1: 0.600


Evaluating:  52%|█████▏    | 25/48 [05:04<05:27, 14.22s/it]

LIPITOR.782.txt                     | F1: 0.667


Evaluating:  54%|█████▍    | 26/48 [05:10<04:13, 11.51s/it]

ARTHROTEC.118.txt                   | F1: 0.857


Evaluating:  56%|█████▋    | 27/48 [05:40<06:00, 17.18s/it]

ARTHROTEC.135.txt                   | F1: 0.125


Evaluating:  58%|█████▊    | 28/48 [05:54<05:25, 16.30s/it]

LIPITOR.382.txt                     | F1: 0.750


Evaluating:  60%|██████    | 29/48 [06:08<04:53, 15.46s/it]

LIPITOR.420.txt                     | F1: 0.889


Evaluating:  62%|██████▎   | 30/48 [06:23<04:38, 15.44s/it]

VOLTAREN-XR.3.txt                   | F1: 0.737


Evaluating:  65%|██████▍   | 31/48 [06:36<04:07, 14.58s/it]

ARTHROTEC.56.txt                    | F1: 0.833


Evaluating:  67%|██████▋   | 32/48 [06:42<03:15, 12.20s/it]

LIPITOR.680.txt                     | F1: 0.000


Evaluating:  69%|██████▉   | 33/48 [06:50<02:42, 10.81s/it]

LIPITOR.4.txt                       | F1: 0.000


Evaluating:  71%|███████   | 34/48 [06:56<02:09,  9.27s/it]

LIPITOR.733.txt                     | F1: 0.667


Evaluating:  73%|███████▎  | 35/48 [07:07<02:09,  9.95s/it]

LIPITOR.743.txt                     | F1: 0.000


Evaluating:  75%|███████▌  | 36/48 [07:24<02:24, 12.00s/it]

ARTHROTEC.117.txt                   | F1: 0.154


Evaluating:  77%|███████▋  | 37/48 [07:35<02:10, 11.85s/it]

LIPITOR.206.txt                     | F1: 0.720


Evaluating:  79%|███████▉  | 38/48 [07:44<01:48, 10.89s/it]

ARTHROTEC.102.txt                   | F1: 0.333


Evaluating:  81%|████████▏ | 39/48 [07:52<01:30, 10.03s/it]

ARTHROTEC.72.txt                    | F1: 1.000


Evaluating:  83%|████████▎ | 40/48 [08:10<01:39, 12.43s/it]

LIPITOR.392.txt                     | F1: 0.625


Evaluating:  85%|████████▌ | 41/48 [08:15<01:10, 10.04s/it]

LIPITOR.525.txt                     | F1: 0.667


Evaluating:  88%|████████▊ | 42/48 [08:32<01:13, 12.29s/it]

ARTHROTEC.50.txt                    | F1: 0.500


Evaluating:  90%|████████▉ | 43/48 [08:50<01:10, 14.06s/it]

LIPITOR.786.txt                     | F1: 0.118


Evaluating:  92%|█████████▏| 44/48 [09:02<00:52, 13.23s/it]

LIPITOR.272.txt                     | F1: 1.000


Evaluating:  94%|█████████▍| 45/48 [09:32<00:55, 18.51s/it]

LIPITOR.800.txt                     | F1: 0.727


Evaluating:  96%|█████████▌| 46/48 [09:46<00:34, 17.13s/it]

LIPITOR.735.txt                     | F1: 0.800


Evaluating:  98%|█████████▊| 47/48 [10:13<00:19, 19.98s/it]

LIPITOR.17.txt                      | F1: 0.118


Evaluating: 100%|██████████| 48/48 [10:26<00:00, 13.06s/it]

ARTHROTEC.20.txt                    | F1: 0.667

FINAL RESULTS
Files processed: 50/50

Macro-Averaged Metrics:
  Precision: 0.5787
  Recall:    0.5971
  F1-Score:  0.5731

F1 Distribution:
  Min:    0.0000
  Q1:     0.4000
  Median: 0.6667
  Q3:     0.7619
  Max:    1.0000

Performance Breakdown:
  Failures (F1=0):      4/50 (8.0%)
  Poor (0 < F1 < 0.3):  5/50 (10.0%)
  Moderate (0.3-0.6):   12/50 (24.0%)
  Good (F1 >= 0.6):     29/50 (58.0%)

TASK 6: SNOMED-CT CODE LINKING

SCT Datastore:
           sct_code                       sct_description label_type  \
0         271782001                                Drowsy        ADR   
1         246636008                 Blurred vision - hazy        ADR   
2         162076009  Excessive upper gastrointestinal gas        ADR   
3  3384011000036100                             Arthrotec       Drug   
4           3723001                             Arthritis    Disease   
5         102498003                                 Agony    Symptom   
6





----------------------------------------------------------------------
String Matching Results:
----------------------------------------------------------------------
  predicted_text                       best_sct_match  sct_code  match_score
          drowsy                               Drowsy 271782001          100
  blurred vision                Blurred vision - hazy 246636008           90
gastric problems Excessive upper gastrointestinal gas 162076009           45
           weird Excessive upper gastrointestinal gas 162076009           49

----------------------------------------------------------------------
Embedding Similarity Results:
----------------------------------------------------------------------
  predicted_text                       best_sct_match  sct_code  match_score
          drowsy                               Drowsy 271782001     1.000000
  blurred vision                Blurred vision - hazy 246636008     0.790501
gastric problems Excessive upper gastrointe