# Explainability Task

## Approach

1. A Llava Llama 8B model was trained and used for the detection of concepts and the captions of each image. Then, for this explainability task, Llama 3.1 was used to merge the concepts and the captions to generate a more complete explanation to relate with the image.

2. The list of concepts was translated into natural language. From those files, the rows corresponding to the test set of images for the task were extracted and renamed as caption-csv and ref_mini_concepts_natural_.csv.

3. The list of concepts in natural language and the subtitles of the test set were provided as part of a prompt to Open AI APIs to obtain approximate location coordinates, which were saved in the file sam_coord.csv.

4. Natural Language Processing integration with NER extraction and others techniques were used to have good candidate labels for SAM.

5. The file with the coordinates was sent to Segment Anything Model (SAM) to create the bounding boxes. Precision was added using heat maps. Specifically, the following was done:

    - Generates multiple segmentation candidates and automatically selects the best one per anatomical structure
    - Provides quantified quality metrics (0-1 scale) for each segmentation based on anatomical likelihood
    - Uses probability maps to identify most likely anatomical locations from approximate coordinates
    - Fine-tuned parameters specifically for medical imaging precision

6. Enhancement was done with YOLO (You Only Look Once) for automatic detection of missed anatomical structures.

7. Arrow detection and following, and keypoint detection algorithms (SIFT, FAST, LoG) were used.

8. Computer vision preprocessing (Canny edge detection and adaptive thresholding) and geometric & spatial analysis were also used.

9. Concepts, captions and explanations are printed below the bounding box images for clarity. Also some statistical analysis & confidence scoring.


### 1. Getting aproximate location coordinates

The objective is to get approximate coordinates to SAM. This script:

- Correctly extract and group compound medical terms (e.g., “small intestine” not “Intestines;Small”).

- Omit generic modality terms (X-ray, MRI, PET, etc.).

- Identify bounding boxes for medical structures/findings based on visual and textual cues.

- Detect arrow tips and link them to the relevant terms/regions.

- Leverage GPT-4V for visual reasoning and GPT-4.1 for text refinement, merging, and error correction.

- Produce CSV/JSON output with ImageID, Label, x, y, width, height.

- Collaborate between models for better accuracy and compound term preservation.

In [None]:
# ========== LIBRARIES ==========

import openai
import pandas as pd
import json
import os
import time
import base64
import re
from dotenv import load_dotenv
import csv

In [22]:
# ========== CONFIGURATION ==========

load_dotenv() # Load environment variables from the .env file into the environment
openai.api_key = os.getenv("OPENAI_API_KEY")

IMAGE_FOLDER = "./data/test_set_explain/"
CONCEPTS_FILE = "./data/ref_mini_concepts_natural_.csv"
CAPTIONS_FILE = "./data/3_submission_explainability.csv"
CSV_OUTPUT = "./results_explain-3/sam_coord.csv"
JSON_OUTPUT = "./results_explain-3/sam_coord.json"
VISION_MODEL = "gpt-4o"
LANG_MODEL = "gpt-4.1"
MAX_VISION_TOKENS = 800
MAX_LANG_TOKENS = 900

In [None]:
# ========== UTILITIES ==========

def get_image_path(image_id):
    for ext in ['.jpeg', '.jpg', '.png']:
        img_path = os.path.join(IMAGE_FOLDER, f"{image_id}{ext}")
        if os.path.exists(img_path):
            return img_path
    return None

def is_modality(term):
    modalities = [
        "plain x-ray", "x-ray", "ct", "mri", "ultrasound", "pet", "ct-scan", "computed tomography",
        "magnetic resonance imaging", "positron-emission tomography", "ultrasonography", "angiogram"
    ]
    return term.strip().lower() in [m.lower() for m in modalities]

def clean_term(term):
    return re.sub(r"^(Structure of |structure of )", "", term).strip()

def group_compound_terms(term_list):
    known_compounds = [
        "small intestine", "large intestine", "left axillary region", "right lower lobe",
        "lower lobe", "upper lobe", "anterior chamber", "posterior chamber", "right kidney",
        "left kidney", "left atrium", "right atrium", "left ventricle", "right ventricle",
        "biliary tree", "pulmonary artery", "coronary artery", "saphenous vein"
    ]
    clean_terms = [clean_term(t).lower() for t in term_list if t and not is_modality(t)]
    final_terms = set()
    used = set()
    for comp in known_compounds:
        parts = comp.split()
        if all(any(part == t or part in t for t in clean_terms) for part in parts):
            final_terms.add(comp)
            used.update(parts)
    for t in clean_terms:
        if not any(t == u or t in u for u in used):
            final_terms.add(t)
    return [t if " " not in t else " ".join([w.capitalize() for w in t.split()]) for t in final_terms]

In [36]:
# ========== PROMPTS ==========

def gpt4v_prompt(terms, caption):
    return f"""
You are an expert radiologist. Analyze the provided medical image together with these terms: {terms}, and caption: "{caption}".

Important: In radiology, "left" and "right" always refer to the patient's left and right (the patient's perspective), which is the opposite of the observer's (radiologist's) view. Always reason and answer using the patient's perspective.

Your tasks:
1. Compound terms like "small intestine", "left axillary region", etc., must be treated as unique, unsplit entities. Never split such terms.
2. Omit general modality terms such as X-ray, MRI, PET, CT, etc.; focus on anatomical structures, findings, or pathologies.
3. For each medical term, estimate the most probable bounding box (x, y, width, height) in the image. Use both the image and the caption to reason location.
4. If there are any arrows, lines, or visual pointers, detect their tip coordinates (x, y) and, if possible, assign the pointed term or region.
5. For ambiguous terms, return your best-guess bounding box or arrow tip, and flag as "uncertain" in the label.
6. Return your results ONLY as a JSON list:
    [
        {{"label": "<compound_term>", "box": {{"x": ..., "y": ..., "width": ..., "height": ...}}}},
        {{"label": "<compound_term>", "arrow_tip": {{"x": ..., "y": ...}}}},
        ...
    ]
Do not include modality labels. Never split compound terms. Do not return any explanation, only the JSON array.
"""

def gpt41_prompt(detections_json, terms, caption, image_id):
    return f"""
You are an expert radiologist data curator. Here are the vision detections from a medical image: {json.dumps(detections_json)}.
- Original caption: "{caption}"
- Medical terms to consider: {terms}
- Image ID: {image_id}

Important: In radiology, "left" and "right" always refer to the patient's left and right, not the observer's. Always use the patient's perspective.

Your tasks:
1. For each term in the medical terms list, make sure compound terms are preserved as single entities (e.g., "small intestine" not split).
2. Omit general imaging modality terms (X-ray, MRI, PET, CT, ultrasound, etc.) and only keep anatomical structures, pathologies, or findings.
3. Validate the bounding boxes and arrow tips. If a term is missing a bounding box but has an arrow tip, estimate a default box (40x40 pixels) centered at the arrow tip.
4. Output the full data in:
    - CSV with columns: ImageID, Label, x, y, width, height
    - JSON list: [{{"ImageID":..., "Label":..., "x":..., "y":..., "width":..., "height":...}}, ...]
Only include rows for valid medical terms (no modalities, no split terms).
If any info must be inferred, do so based on your expertise. Output both the CSV and JSON (CSV first, then JSON).
"""

def extract_json_anywhere(text):
    # Busca el primer bloque válido de JSON array (list)
    matches = re.findall(r"(\[.*?\])", text, re.DOTALL)
    for m in matches:
        try:
            parsed = json.loads(m)
            return m
        except Exception:
            continue
    # Busca dentro de bloques markdown
    match = re.search(r"```json(.*?)```", text, re.DOTALL)
    if match:
        inner = match.group(1).strip()
        try:
            parsed = json.loads(inner)
            return inner
        except Exception:
            pass
    return ""

def call_gpt4v_vision(image_path, terms, caption, retries=2):
    prompt = gpt4v_prompt(terms, caption)
    with open(image_path, "rb") as imgf:
        img_b64 = base64.b64encode(imgf.read()).decode()
    for attempt in range(retries + 1):
        try:
            print("Prompt:\n", prompt[:400], "...")
            print("Image exists?", os.path.exists(image_path))
            response = openai.chat.completions.create(
                model=VISION_MODEL,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:image/jpeg;base64,{img_b64}"
                                }
                            }
                        ]
                    }
                ],
                max_tokens=MAX_VISION_TOKENS
            )
            result = response.choices[0].message.content.strip()
            if not result:
                print("[Vision model] Empty response from OpenAI. Check quota, input size, or prompt.")
                return []
            clean_json = extract_json_anywhere(result)
            try:
                return json.loads(clean_json)
            except Exception as je:
                print("Raw result (not JSON):", result)
                raise je
        except Exception as e:
            print(f"[Vision model] Error (attempt {attempt+1}): {e}")
            time.sleep(3 + attempt*3)
    print("[Vision model] Failed after retries.")
    return []

def call_gpt4_1_refiner(detections_json, terms, caption, image_id, retries=2):
    prompt = gpt41_prompt(detections_json, terms, caption, image_id)
    for attempt in range(retries + 1):
        try:
            completion = openai.chat.completions.create(
                model=LANG_MODEL,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=MAX_LANG_TOKENS
            )
            reply = completion.choices[0].message.content.strip()
            print("\n=== GPT-4.1 RAW REPLY ===\n", reply, "\n=== END REPLY ===\n")
            # Busca CSV (puedes dejar igual)
            csv_part, json_part = "", ""
            csv_match = re.search(r"ImageID,.*?(\n(?:[^\[]|\[)*?)\n?\[", reply, re.DOTALL)
            if csv_match:
                csv_part = csv_match.group(0).split("\n[")[0].strip()
            else:
                parts = re.split(r"\n\s*\n", reply)
                if len(parts) >= 2 and parts[0].startswith("ImageID"):
                    csv_part = parts[0].strip()
                else:
                    idx = reply.find('[')
                    if idx != -1:
                        csv_part = reply[:idx].strip()
            # Busca el JSON con extractor avanzado
            json_part = extract_json_anywhere(reply)
            if not json_part:
                print("[Language model] Could not parse JSON part.")
                json_part = "[]"
            return csv_part, json_part
        except Exception as e:
            print(f"[Language model] Error (attempt {attempt+1}): {e}")
            time.sleep(3 + attempt*3)
    print("[Language model] Failed after retries.")
    return "", "[]"

In [37]:
# ========== MAIN PIPELINE ==========

def run_pipeline():
    if not os.path.exists(CONCEPTS_FILE) or not os.path.exists(CAPTIONS_FILE):
        print(f"ERROR: Concepts file or captions file not found.")
        return
    concepts_df = pd.read_csv(CONCEPTS_FILE)
    captions_df = pd.read_csv(CAPTIONS_FILE)
    cap_map = {str(row["ID"]): str(row["Caption"]) for _, row in captions_df.iterrows()}

    all_csv_rows = []
    all_json_objs = []

    for idx, row in concepts_df.iterrows():
        image_id = str(row["ID"])
        cuis = [c.strip() for c in str(row["CUIs"]).split(';') if c.strip()]
        caption = cap_map.get(image_id, "")
        image_path = get_image_path(image_id)
        if image_path is None:
            print(f"[{image_id}] Image not found, skipping.")
            continue
        terms = group_compound_terms(cuis)
        print(f"\n=== Processing {image_id} ===")
        print(f"Terms: {terms}")
        vision_results = call_gpt4v_vision(image_path, terms, caption)
        if not vision_results:
            print(f"[{image_id}] No vision results, skipping.")
            continue
        csv_out, json_out = call_gpt4_1_refiner(vision_results, terms, caption, image_id)
        if csv_out:
            if not all_csv_rows and "ImageID" in csv_out:
                header, *rest = csv_out.splitlines()
                all_csv_rows.append(header)
                all_csv_rows.extend(rest)
            else:
                lines = [line for line in csv_out.splitlines() if not line.lower().startswith("imageid")]
                all_csv_rows.extend(lines)
        if json_out and json_out != "[]":
            try:
                data = json.loads(json_out)
                if isinstance(data, list):
                    all_json_objs.extend(data)
            except Exception as e:
                print(f"[{image_id}] Error parsing JSON output: {e}")
        time.sleep(2.5)

    print(f"\nWriting outputs to {CSV_OUTPUT} and {JSON_OUTPUT} ...")
    with open(CSV_OUTPUT, "w") as f:
        for row in all_csv_rows:
            f.write(row + "\n")
    with open(JSON_OUTPUT, "w") as f:
        json.dump(all_json_objs, f, indent=2)
    print("Done.")


In [60]:
# Run the pipeline
run_pipeline()


=== Processing ImageCLEFmedical_Caption_2025_test_118 ===
Terms: ['anterior-posterior', 'Bone Structure Of Ilium', 'pelvis', 'abdomen', 'Bone Structure Of Pubis']
Prompt:
 
You are an expert radiologist. Analyze the provided medical image together with these terms: ['anterior-posterior', 'Bone Structure Of Ilium', 'pelvis', 'abdomen', 'Bone Structure Of Pubis'], and caption: "Plain radiograph of the abdomen, taken in the anterior-posterior projection, shows multiple air-fluid levels within the bowel loops, extending into the pelvis. The bone structure of the ilium a ...
Image exists? True

=== GPT-4.1 RAW REPLY ===
 **CSV Output:**
```
ImageID,Label,x,y,width,height
ImageCLEFmedical_Caption_2025_test_118,Bone Structure Of Ilium,200,450,150,100
ImageCLEFmedical_Caption_2025_test_118,pelvis,220,500,160,120
ImageCLEFmedical_Caption_2025_test_118,abdomen,150,100,300,350
ImageCLEFmedical_Caption_2025_test_118,Bone Structure Of Pubis,230,580,140,60
ImageCLEFmedical_Caption_2025_test_118,air

In [61]:
# ========== RESULTS ==========

# Visualize the results
display(pd.read_csv(CSV_OUTPUT).head(5))

Unnamed: 0,ImageID,Label,x,y,width,height
0,ImageCLEFmedical_Caption_2025_test_118,Bone Structure Of Ilium,200.0,450.0,150.0,100.0
1,ImageCLEFmedical_Caption_2025_test_118,pelvis,220.0,500.0,160.0,120.0
2,ImageCLEFmedical_Caption_2025_test_118,abdomen,150.0,100.0,300.0,350.0
3,ImageCLEFmedical_Caption_2025_test_118,Bone Structure Of Pubis,230.0,580.0,140.0,60.0
4,ImageCLEFmedical_Caption_2025_test_118,air-fluid levels,230.0,180.0,40.0,40.0


## SAM Medical Image Analysis with Heatmap-based Confidence

This script adds heatmap generation and confidence scoring to select the best mask per label.


### 3. Using SAM for bounding boxes

In [76]:
# ========== LIBRARIES ==========

import pandas as pd
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle
from mpl_toolkits.mplot3d import Axes3D
import textwrap
import re
import urllib.request
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline
from IPython.display import display, Image as IPImage
import warnings
from scipy import ndimage
from sklearn.cluster import DBSCAN
from skimage.measure import label as skimage_label, regionprops
from scipy.spatial.distance import euclidean
import json
warnings.filterwarnings('ignore')
import csv
from collections import defaultdict
import logging

In [77]:
# ========== CONFIGURATION ==========

SAM_COORD_PATH = './sam_coord.csv'
CAPTIONS_PATH = './3_submission_explainability.csv'
CONCEPTS_PATH = './ref_mini_concepts_natural_.csv'
CAPTION_FILE_PATH = './caption.csv'
IMAGES_DIR = './data/test_set_explain'
OUTPUT_DIR = './result_explain-3/sam'

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Color palette for consistency
COLORS = [
    '#FF0000',  # Red
    '#00FF00',  # Green  
    '#0000FF',  # Blue
    '#FFA500',  # Orange
    '#FF00FF',  # Magenta
    '#00FFFF',  # Cyan
    '#800080',  # Purple
    '#FFC0CB',  # Pink
    '#A52A2A',  # Brown
    '#FFFF00',  # Yellow
    '#808080',  # Gray
    '#000080'   # Navy Blue
]

print("Setting up Advanced Medical SAM with NER and Arrow Detection...")

Setting up Advanced Medical SAM with NER and Arrow Detection...


In [9]:
# ========== SAM SETUP ==========

# Setup SAM model
def setup_sam():
    """Setup SAM model with automatic download if needed."""
    # Download SAM model if not exists
    os.makedirs('sam_models', exist_ok=True)
    sam_checkpoint = "sam_models/sam_vit_h_4b8939.pth"
    
    if not os.path.exists(sam_checkpoint):
        print("Downloading SAM model (this may take a few minutes)...")
        urllib.request.urlretrieve(
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
            sam_checkpoint
        )
        print("SAM model downloaded successfully!")
    
    # Load SAM model
    model_type = "vit_h"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    
    # Configure mask generator optimized for medical images
    mask_generator = SamAutomaticMaskGenerator(
        model=sam,
        points_per_side=32,
        pred_iou_thresh=0.85,
        stability_score_thresh=0.92,
        crop_n_layers=1,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=200
    )
    
    predictor = SamPredictor(sam)
    print("SAM configured for medical images!")
    return mask_generator, predictor, device

# Initialize SAM
mask_generator, predictor, device = setup_sam()

Using device: cuda
SAM configured for medical images!


In [11]:
# Setup Medical NER Model
def setup_medical_ner():
    """Setup biomedical NER model for extracting medical terms."""
    try:
        print("Loading biomedical NER model...")
        tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
        model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
        ner_pipeline = pipeline("ner", 
                               model=model, 
                               tokenizer=tokenizer, 
                               aggregation_strategy="simple")
        print("Medical NER model loaded successfully!")
        return ner_pipeline
    except Exception as e:
        print(f"Could not load biomedical NER model: {e}")
        print("Installing transformers: pip install transformers")
        return None

# Initialize Medical NER
medical_ner = setup_medical_ner()

Loading biomedical NER model...
Medical NER model loaded successfully!


## NER Extraction

Medical NER Extraction and Comparison Script
Extracts medical terms from explanations, cleans them, and compares with existing annotations.

In [12]:
import pandas as pd
import re
import os
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

In [None]:
# File paths
EXPLANATIONS_PATH = './3_submission_explainability.csv'
SAM_COORD_PATH = './sam_coord.csv'
OUTPUT_PATH = './data/ner.csv'

def extract_location_hint(caption, term):
    """Extract location hints from caption context."""
    caption_lower = caption.lower()
    term_lower = term.lower()
    
    # Find the term position in caption
    term_pos = caption_lower.find(term_lower)
    if term_pos == -1:
        return None
    
    # Look for directional indicators around the term
    context_window = 50  # Characters before and after
    start = max(0, term_pos - context_window)
    end = min(len(caption), term_pos + len(term) + context_window)
    context = caption_lower[start:end]
    
    location_hints = {
        'left': ['left', 'sinister', 'l.'],
        'right': ['right', 'dexter', 'r.'],
        'upper': ['upper', 'superior', 'top', 'cranial'],
        'lower': ['lower', 'inferior', 'bottom', 'caudal'],
        'anterior': ['anterior', 'front', 'ventral'],
        'posterior': ['posterior', 'back', 'dorsal'],
        'medial': ['medial', 'central', 'middle'],
        'lateral': ['lateral', 'side', 'peripheral'],
        'arrow': ['arrow', 'pointer', 'indicated', 'shown', 'marked']
    }
    
    found_hints = {}
    for direction, keywords in location_hints.items():
        for keyword in keywords:
            if keyword in context:
                found_hints[direction] = True
                break
    
    return found_hints if found_hints else None

def extract_medical_terms_rule_based(caption):
    """Enhanced rule-based extraction of medical terms from captions."""
    medical_patterns = [
        # Pathological conditions with boundaries
        r'\b(?:mass|tumor|tumour|lesion|nodule|growth|neoplasm)\b',
        r'\b(?:hematoma|hemorrhage|bleeding|blood|clot)\b',
        r'\b(?:air-fluid|fluid|effusion|pneumothorax|pleural\s+effusion)\b',
        r'\b(?:hypermetabolic|metabolic|uptake|enhancement)\b',
        r'\b(?:fracture|break|dislocation|injury)\b',
        r'\b(?:stenosis|occlusion|blockage|obstruction)\b',
        r'\b(?:inflammation|infection|abscess|sepsis)\b',
        r'\b(?:calcification|calcified|calcium|mineralization)\b',
        
        # Anatomical references with descriptors
        r'\b(?:left|right)\s+(?:upper|lower|middle)?\s*(?:lobe|lung|breast|kidney|liver|ventricle)\b',
        r'\b(?:left|right)\s+(?:ventricle|atrium|carotid|cerebral|temporal|frontal|parietal|occipital)\b',
        r'\b(?:internal|external|common)\s+(?:carotid|iliac|mammary)\s+artery\b',
        r'\blymph\s+nodes?\b|\badenopathy\b',
        
        # Size and severity descriptors with following nouns
        r'\b(?:large|small|massive|extensive|diffuse|focal)\s+(?:mass|lesion|area|collection|density)\b',
        r'\b(?:severe|moderate|mild|acute|chronic)\s+(?:inflammation|infection|stenosis|obstruction)\b',
        r'\b(?:free|trapped|loculated)\s+fluid\b',
        
        # Specific medical conditions
        r'\b(?:pneumonia|atelectasis|consolidation|opacity|infiltrate)\b',
        r'\b(?:cardiomegaly|hepatomegaly|splenomegaly)\b',
        r'\b(?:thrombosis|embolism|infarct|ischemia)\b',
        r'\b(?:aneurysm|dissection|stenosis)\b',
        r'\b(?:cyst|polyp|diverticulum|hernia)\b',
        
        # Contrast and imaging terms
        r'\b(?:contrast|enhancement|hypodense|hyperdense|isodense)\b',
        r'\b(?:hypointense|hyperintense|signal|intensity)\b',
        
        # Anatomical structures
        r'\b(?:vertebra|vertebrae|disc|facet|joint|bone|rib|sternum)\b',
        r'\b(?:aorta|vena\s+cava|pulmonary|coronary)\b',
        r'\b(?:bladder|prostate|uterus|ovary|kidney|adrenal)\b'
    ]
    
    medical_terms = []
    caption_lower = caption.lower()
    
    for pattern in medical_patterns:
        matches = re.finditer(pattern, caption_lower, re.IGNORECASE)
        for match in matches:
            term = match.group().strip()
            if len(term) > 2:  # Filter out very short terms
                # Clean the term
                cleaned_term = clean_medical_term(term)
                if cleaned_term and is_valid_medical_term(cleaned_term):
                    medical_terms.append({
                        'term': cleaned_term,
                        'type': 'MEDICAL_RULE',
                        'confidence': 0.8,
                        'location_hint': extract_location_hint(caption, term),
                        'original_match': term,
                        'pattern_matched': pattern
                    })
    
    return medical_terms

def clean_medical_term(term):
    """Clean and standardize medical terms."""
    # Convert to title case
    cleaned = term.strip().title()
    
    # Handle specific medical abbreviations
    abbreviations = {
        'Ca': 'Calcium',
        'Hb': 'Hemoglobin',
        'O2': 'Oxygen',
        'Co2': 'Carbon Dioxide',
        'Iv': 'Intravenous',
        'Im': 'Intramuscular'
    }
    
    for abbr, full in abbreviations.items():
        if cleaned == abbr:
            return full
    
    # Remove redundant words
    remove_words = ['Structure Of', 'Bone Structure Of', 'Entire', 'Complete']
    for remove_word in remove_words:
        cleaned = cleaned.replace(remove_word + ' ', '').replace(remove_word, '')
    
    # Standardize compound terms
    standardizations = {
        'Air-Fluid': 'Air-Fluid Level',
        'Free Fluid': 'Free Fluid',
        'Lymph Node': 'Lymph Nodes',
        'Blood Clot': 'Thrombosis'
    }
    
    for old, new in standardizations.items():
        if old in cleaned:
            cleaned = cleaned.replace(old, new)
    
    return cleaned.strip()

def is_valid_medical_term(term):
    """Validate if a term is a proper medical term."""
    term_lower = term.lower()
    
    # Exclude common non-medical words
    exclude_words = {
        'the', 'and', 'or', 'of', 'in', 'on', 'at', 'to', 'for', 'with', 'by',
        'from', 'as', 'is', 'are', 'was', 'were', 'be', 'been', 'have', 'has',
        'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should', 'may',
        'might', 'can', 'must', 'shall', 'this', 'that', 'these', 'those',
        'there', 'here', 'where', 'when', 'how', 'what', 'which', 'who', 'why',
        'very', 'more', 'most', 'some', 'all', 'any', 'each', 'every', 'both',
        'either', 'neither', 'one', 'two', 'three', 'first', 'second', 'third',
        'new', 'old', 'good', 'bad', 'big', 'small', 'long', 'short', 'high', 'low'
    }
    
    if term_lower in exclude_words:
        return False
    
    # Must be at least 3 characters
    if len(term) < 3:
        return False
    
    # Must contain at least one letter
    if not re.search(r'[a-zA-Z]', term):
        return False
    
    # Should not be purely numeric
    if term.replace('-', '').replace(' ', '').isdigit():
        return False
    
    # Should not contain too many special characters
    special_char_ratio = len(re.findall(r'[^a-zA-Z0-9\s\-]', term)) / len(term)
    if special_char_ratio > 0.3:
        return False
    
    return True

def normalize_term_for_comparison(term):
    """Normalize term for comparison purposes."""
    # Convert to lowercase
    normalized = term.lower()
    
    # Remove common prefixes/suffixes for comparison
    prefixes_to_remove = ['left ', 'right ', 'bilateral ', 'upper ', 'lower ', 'middle ']
    suffixes_to_remove = [' left', ' right', ' bilateral', ' upper', ' lower', ' middle']
    
    for prefix in prefixes_to_remove:
        if normalized.startswith(prefix):
            normalized = normalized[len(prefix):]
    
    for suffix in suffixes_to_remove:
        if normalized.endswith(suffix):
            normalized = normalized[:-len(suffix)]
    
    # Remove articles and common words
    words_to_remove = ['the ', ' the', 'of ', ' of', 'structure ', ' structure']
    for word in words_to_remove:
        normalized = normalized.replace(word, ' ')
    
    # Clean up extra spaces
    normalized = ' '.join(normalized.split())
    
    return normalized.strip()

def compare_with_existing_annotations(ner_terms, sam_labels):
    """Compare NER terms with existing SAM annotations."""
    # Normalize existing SAM labels for comparison
    sam_normalized = set()
    for label in sam_labels:
        normalized = normalize_term_for_comparison(label)
        if normalized:
            sam_normalized.add(normalized)
    
    # Find new terms not covered by existing annotations
    new_terms = []
    covered_terms = []
    
    for term_data in ner_terms:
        term = term_data['term']
        normalized_term = normalize_term_for_comparison(term)
        
        is_covered = False
        
        # Check exact match
        if normalized_term in sam_normalized:
            is_covered = True
            covered_terms.append({**term_data, 'coverage_type': 'exact_match'})
        else:
            # Check partial matches
            for sam_term in sam_normalized:
                # Check if terms share significant words
                ner_words = set(normalized_term.split())
                sam_words = set(sam_term.split())
                
                # If more than 50% of words overlap, consider it covered
                if len(ner_words.intersection(sam_words)) / max(len(ner_words), 1) > 0.5:
                    is_covered = True
                    covered_terms.append({**term_data, 'coverage_type': 'partial_match', 'matched_with': sam_term})
                    break
        
        if not is_covered:
            new_terms.append({**term_data, 'status': 'new_term'})
    
    return new_terms, covered_terms

def process_ner_extraction():
    """Main function to process NER extraction and comparison."""
    print("Starting Medical NER Extraction and Comparison")
    print("=" * 60)
    
    # Load data files
    print("Loading data files...")
    
    if not os.path.exists(EXPLANATIONS_PATH):
        print(f"Explanations file not found: {EXPLANATIONS_PATH}")
        return
    
    if not os.path.exists(SAM_COORD_PATH):
        print(f"SAM coordinates file not found: {SAM_COORD_PATH}")
        return
    
    explanations_df = pd.read_csv(EXPLANATIONS_PATH)
    sam_df = pd.read_csv(SAM_COORD_PATH)
    
    print(f"Loaded {len(explanations_df)} explanations and {len(sam_df)} SAM annotations")
    
    # Extract medical terms from all captions
    print("\nExtracting medical terms from captions...")
    all_ner_terms = []
    
    for idx, row in explanations_df.iterrows():
        image_id = row.get('ID', row.get('id', idx))
        caption = row.get('Caption', row.get('caption', ''))
        
        if pd.notna(caption) and caption.strip():
            ner_terms = extract_medical_terms_rule_based(str(caption))
            
            for term_data in ner_terms:
                term_data['image_id'] = image_id
                term_data['source_caption'] = caption[:100] + '...' if len(caption) > 100 else caption
                all_ner_terms.append(term_data)
    
    print(f"Extracted {len(all_ner_terms)} raw medical terms")
    
    # Remove duplicates and count frequencies
    print("\nCleaning and deduplicating terms...")
    term_counter = Counter()
    unique_terms = {}
    
    for term_data in all_ner_terms:
        term = term_data['term']
        normalized = normalize_term_for_comparison(term)
        
        if normalized not in unique_terms:
            unique_terms[normalized] = term_data
        
        term_counter[normalized] += 1
    
    # Add frequency information
    for normalized_term, term_data in unique_terms.items():
        term_data['frequency'] = term_counter[normalized_term]
    
    print(f"Found {len(unique_terms)} unique medical terms")
    
    # Get existing SAM labels
    sam_labels = sam_df['Label'].dropna().unique().tolist()
    print(f"Comparing with {len(sam_labels)} existing SAM labels")
    
    # Compare with existing annotations
    new_terms, covered_terms = compare_with_existing_annotations(
        list(unique_terms.values()), sam_labels
    )
    
    print(f"\nComparison Results:")
    print(f"New terms not in SAM: {len(new_terms)}")
    print(f"Terms covered by SAM: {len(covered_terms)}")
    
    # Sort new terms by frequency and confidence
    new_terms.sort(key=lambda x: (x['frequency'], x['confidence']), reverse=True)
    
    # Prepare output DataFrame
    output_records = []
    for i, term_data in enumerate(new_terms):
        output_records.append({
            'rank': i + 1,
            'term': term_data['term'],
            'frequency': term_data['frequency'],
            'confidence': term_data['confidence'],
            'type': term_data['type'],
            'location_hints': str(term_data.get('location_hint', {})),
            'sample_image_id': term_data['image_id'],
            'sample_caption': term_data['source_caption'],
            'status': term_data['status']
        })
    
    # Create output directory if needed
    os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)
    
    # Save to CSV
    output_df = pd.DataFrame(output_records)
    output_df.to_csv(OUTPUT_PATH, index=False)
    
    print(f"\nResults saved to: {OUTPUT_PATH}")
    
    # Display top 15 new terms
    print(f"\nTOP 15 NEW MEDICAL TERMS:")
    print("=" * 80)
    print(f"{'Rank':<4} {'Term':<25} {'Freq':<6} {'Conf':<6} {'Location Hints':<20} {'Sample ID':<10}")
    print("-" * 80)
    
    for i, record in enumerate(output_records[:15]):
        location_display = record['location_hints'][:18] + '..' if len(record['location_hints']) > 20 else record['location_hints']
        print(f"{record['rank']:<4} {record['term']:<25} {record['frequency']:<6} {record['confidence']:<6.2f} {location_display:<20} {record['sample_image_id']:<10}")
    
    print("-" * 80)
    print(f"Total new terms found: {len(new_terms)}")
    print(f"Coverage rate: {len(covered_terms)/(len(covered_terms)+len(new_terms))*100:.1f}%")
    
    return output_df

if __name__ == "__main__":
    result_df = process_ner_extraction()
    print("\nNER extraction and comparison completed!")

Starting Medical NER Extraction and Comparison
Loading data files...
Loaded 16 explanations and 110 SAM annotations

Extracting medical terms from captions...
Extracted 36 raw medical terms

Cleaning and deduplicating terms...
Found 18 unique medical terms
Comparing with 40 existing SAM labels

Comparison Results:
New terms not in SAM: 12
Terms covered by SAM: 6

Results saved to: ./data/ner.csv

TOP 15 NEW MEDICAL TERMS:
Rank Term                      Freq   Conf   Location Hints       Sample ID 
--------------------------------------------------------------------------------
1    Lymph Nodess              4      0.80   {'left': True, 'ar.. ImageCLEFmedical_Caption_2025_test_15167
2    Right Breast              3      0.80   {'right': True, 'a.. ImageCLEFmedical_Caption_2025_test_4346
3    Air-Fluid Level           2      0.80   {'anterior': True,.. ImageCLEFmedical_Caption_2025_test_118
4    Stenosis                  2      0.80   {'left': True, 'ar.. ImageCLEFmedical_Caption_2025_te

In [26]:
"""
NER Pipeline Adapter
Uses existing pipeline output format and preserves duplicates (multiple findings per image)
"""

import pandas as pd
import json
import os
import time
from datetime import datetime

# Configuration
IMAGE_FOLDER = "./data/test_set_explain/"
NER_FILE = './data/ner.csv'
CAPTIONS_FILE = './3_submission_explainability.csv'
ORIGINAL_SAM_COORD = './sam_coord.csv'
NER_COORD_OUTPUT = './ner_coord.csv'
MERGED_SAM_COORD = './sam_coord.csv'  # Overwrites original with merged data
BACKUP_SAM_COORD = './sam_coord_backup.csv'  # Backup of original

def backup_original_sam_coord():
    """Create backup of original sam_coord.csv before merging."""
    if os.path.exists(ORIGINAL_SAM_COORD):
        print(f"Creating backup: {BACKUP_SAM_COORD}")
        import shutil
        shutil.copy2(ORIGINAL_SAM_COORD, BACKUP_SAM_COORD)
        return True
    else:
        print(f"Original sam_coord.csv not found at: {ORIGINAL_SAM_COORD}")
        return False

def prepare_ner_data_for_pipeline():
    """Prepare NER data in the same format as the original concepts pipeline expects."""
    print("Loading NER terms and captions...")
    
    # Load files
    ner_df = pd.read_csv(NER_FILE)
    captions_df = pd.read_csv(CAPTIONS_FILE)
    
    # Create caption mapping
    cap_map = {str(row["ID"]): str(row["Caption"]) for _, row in captions_df.iterrows()}
    
    # Group NER terms by image_id
    image_groups = ner_df.groupby('sample_image_id')
    
    # Create a concepts-like DataFrame
    concepts_like_data = []
    
    for image_id, group in image_groups:
        # Get all terms for this image (no filtering - all terms)
        terms_list = group['term'].tolist()
        
        # Join terms with semicolon (like CUIs in original format)
        terms_string = '; '.join(terms_list)
        
        concepts_like_data.append({
            'ID': image_id,
            'CUIs': terms_string
        })
    
    concepts_like_df = pd.DataFrame(concepts_like_data)
    
    print(f"Loaded {len(ner_df)} NER terms")
    print(f"Processing {len(concepts_like_df)} images with NER terms")
    print(f"Available captions: {len(captions_df)}")
    
    return concepts_like_df, cap_map

def merge_with_original_sam_coord(ner_csv_rows):
    """
    Merge new NER coordinates with existing sam_coord.csv.
    Preserves duplicates since multiple findings per image are valid.
    """
    print("\nMerging with existing sam_coord.csv...")
    
    # Prepare NER data for merging
    ner_data_lines = []
    if ner_csv_rows:
        # Skip header if present, keep all data rows
        for row in ner_csv_rows:
            if row and not row.lower().startswith('imageid'):
                ner_data_lines.append(row)
    
    if not os.path.exists(ORIGINAL_SAM_COORD):
        print(f"Original sam_coord.csv not found. Creating new file with NER data only.")
        # Create new file with header + NER data
        with open(MERGED_SAM_COORD, 'w') as f:
            f.write("ImageID,Label,x,y,width,height\n")
            for line in ner_data_lines:
                f.write(line + "\n")
        merged_count = len(ner_data_lines)
    else:
        # Read original file
        with open(ORIGINAL_SAM_COORD, 'r') as f:
            original_lines = f.readlines()
        
        original_count = len(original_lines) - 1  # Subtract header
        print(f"Original sam_coord.csv: {original_count} records")
        
        # Write merged file
        with open(MERGED_SAM_COORD, 'w') as f:
            # Write original content
            for line in original_lines:
                f.write(line)
            
            # Append NER data (no duplicates check - preserving all findings)
            for line in ner_data_lines:
                f.write(line + "\n")
        
        merged_count = original_count + len(ner_data_lines)
    
    print(f"Merged file: {merged_count} total records")
    print(f"Added {len(ner_data_lines)} new NER records")
    
    return merged_count

def run_ner_pipeline():
    """
    Run NER pipeline using existing functions and preserve original output format.
    """
    print("Starting NER Terms Pipeline (Preserving Original Format)")
    print("=" * 60)
    
    # Check files exist
    if not os.path.exists(NER_FILE):
        print(f"ERROR: NER file not found: {NER_FILE}")
        return None, None
    
    if not os.path.exists(CAPTIONS_FILE):
        print(f"ERROR: Captions file not found: {CAPTIONS_FILE}")
        return None, None
    
    # Create backup of original sam_coord.csv
    backup_created = backup_original_sam_coord()
    
    # Prepare data in the format expected by existing pipeline
    concepts_df, cap_map = prepare_ner_data_for_pipeline()
    
    # Use the existing pipeline logic - EXACTLY as original
    all_csv_rows = []
    all_json_objs = []
    
    for idx, row in concepts_df.iterrows():
        image_id = str(row["ID"])
        terms_raw = [c.strip() for c in str(row["CUIs"]).split(';') if c.strip()]
        caption = cap_map.get(image_id, "")
        
        # Use existing function to get image path
        image_path = get_image_path(image_id)
        if image_path is None:
            print(f"[{image_id}] Image not found, skipping.")
            continue
        
        # Use existing function to group compound terms
        terms = group_compound_terms(terms_raw)
        
        print(f"\n=== Processing {image_id} ===")
        print(f"Original NER terms: {terms_raw}")
        print(f"Grouped terms: {terms}")
        
        if not terms:
            print(f"[{image_id}] No valid terms after grouping, skipping.")
            continue
        
        # Use existing vision function
        vision_results = call_gpt4v_vision(image_path, terms, caption)
        if not vision_results:
            print(f"[{image_id}] No vision results, skipping.")
            continue
        
        print(f"Vision model returned {len(vision_results)} detections")
        
        # Use existing refiner function
        csv_out, json_out = call_gpt4_1_refiner(vision_results, terms, caption, image_id)
        
        # Process CSV output - EXACTLY as original pipeline
        if csv_out:
            if not all_csv_rows and "ImageID" in csv_out:
                header, *rest = csv_out.splitlines()
                all_csv_rows.append(header)
                all_csv_rows.extend(rest)
            else:
                lines = [line for line in csv_out.splitlines() if not line.lower().startswith("imageid")]
                all_csv_rows.extend(lines)
        
        # Process JSON output - EXACTLY as original pipeline
        if json_out and json_out != "[]":
            try:
                data = json.loads(json_out)
                if isinstance(data, list):
                    all_json_objs.extend(data)
                    print(f"Added {len(data)} JSON objects")
            except Exception as e:
                print(f"[{image_id}] Error parsing JSON output: {e}")
        
        # Rate limiting
        time.sleep(2.5)
    
    # Save ner_coord.csv - EXACTLY as original pipeline saves
    print(f"\nSaving NER coordinates to: {NER_COORD_OUTPUT}")
    with open(NER_COORD_OUTPUT, "w") as f:
        for row in all_csv_rows:
            f.write(row + "\n")
    
    # Merge with original sam_coord.csv (preserving duplicates)
    merged_count = merge_with_original_sam_coord(all_csv_rows)
    
    print(f"\nNER Pipeline completed!")
    print(f"Final Summary:")
    print(f"NER coordinates: {len(all_csv_rows)} rows → {NER_COORD_OUTPUT}")
    print(f"Merged coordinates: {merged_count} total rows → {MERGED_SAM_COORD}")
    if backup_created:
        print(f"Original backup: {BACKUP_SAM_COORD}")
    print(f"JSON objects: {len(all_json_objs)}")
    print(f"Duplicates preserved (multiple findings per image allowed)")
    
    return all_csv_rows, merged_count

# ========== EXECUTION ==========

if __name__ == "__main__":
    """
    To use this with existing functions in your notebook:
    
    1. Make sure you have the existing functions available:
       - get_image_path()
       - group_compound_terms() 
       - call_gpt4v_vision()
       - call_gpt4_1_refiner()
    
    2. Then call:
       csv_rows, merged_count = run_ner_pipeline()
    """
    
    print("Required functions from your notebook:")
    print("   - get_image_path()")
    print("   - group_compound_terms()")
    print("   - call_gpt4v_vision()")
    print("   - call_gpt4_1_refiner()")
    print()
    print("Files that will be created/modified:")
    print(f"   {NER_COORD_OUTPUT} (new NER coordinates)")
    print(f"   {MERGED_SAM_COORD} (merged coordinates)")
    print(f"   {BACKUP_SAM_COORD} (backup of original)")
    print()
    print("Medical context: Duplicates preserved (multiple findings per image)")
    print()
    print("To run: csv_rows, merged_count = run_ner_pipeline()")

Required functions from your notebook:
   - get_image_path()
   - group_compound_terms()
   - call_gpt4v_vision()
   - call_gpt4_1_refiner()

Files that will be created/modified:
   ./ner_coord.csv (new NER coordinates)
   ./sam_coord.csv (merged coordinates)
   ./sam_coord_backup.csv (backup of original)

Medical context: Duplicates preserved (multiple findings per image)

To run: csv_rows, merged_count = run_ner_pipeline()


In [27]:
# Ejecutar el pipeline
csv_rows, merged_count = run_ner_pipeline()

Starting NER Terms Pipeline (Preserving Original Format)
Creating backup: ./sam_coord_backup.csv
Loading NER terms and captions...
Loaded 12 NER terms
Processing 9 images with NER terms
Available captions: 16

=== Processing ImageCLEFmedical_Caption_2025_test_118 ===
Original NER terms: ['Air-Fluid Level']
Grouped terms: ['Air-fluid Level']
Prompt:
 
You are an expert radiologist. Analyze the provided medical image together with these terms: ['Air-fluid Level'], and caption: "Plain radiograph of the abdomen, taken in the anterior-posterior projection, shows multiple air-fluid levels within the bowel loops, extending into the pelvis. The bone structure of the ilium and pubis is unremarkable.".

Important: In radiology, "left" and "right" alway ...
Image exists? True
Vision model returned 5 detections

=== GPT-4.1 RAW REPLY ===
 **CSV Output**

```
ImageID,Label,x,y,width,height
ImageCLEFmedical_Caption_2025_test_118,Air-fluid Level,160,280,40,40
ImageCLEFmedical_Caption_2025_test_118,Ai

### NLP Techniques to detect missing terms

In [None]:
import pandas as pd
import re
import csv
from collections import defaultdict
import logging

In [None]:
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def extract_medical_terms_enhanced(text):
    """
    Enhanced rule-based extraction of medical terms from captions.
    Ensures measurements are always extracted with their associated nouns.
    """
    if not text or pd.isna(text):
        return []
    
    text = text.lower()
    medical_terms = []
    
    # PRIORITY: Complete measurement phrases with descriptors and nouns
    measurement_phrase_patterns = [
        # Complete phrases: measurement + descriptor + noun
        r'\b\d+(?:\.\d+)?\s*(?:cm|mm)\s+(?:hypoechoic|hyperechoic|isoechoic|anechoic|mixed|echogenic|echolucent|solid|cystic|complex|heterogeneous|homogeneous|well-defined|ill-defined|poorly-defined)\s+(?:mass|lesion|nodule|cyst|tumor|tumour|growth|collection|area|region|structure)\b',
        
        # Reverse order: descriptor + noun + measurement
        r'\b(?:hypoechoic|hyperechoic|isoechoic|anechoic|mixed|echogenic|echolucent|solid|cystic|complex|heterogeneous|homogeneous|well-defined|ill-defined|poorly-defined)\s+(?:mass|lesion|nodule|cyst|tumor|tumour|growth|collection|area|region|structure)\s+(?:measuring|of|approximately)\s+\d+(?:\.\d+)?\s*(?:cm|mm)\b',
        
        # Noun + measuring + dimension
        r'\b(?:mass|lesion|nodule|cyst|tumor|tumour|growth|collection|area|region|structure)\s+measuring\s+\d+(?:\.\d+)?\s*(?:x\s*\d+(?:\.\d+)?)?\s*(?:x\s*\d+(?:\.\d+)?)?\s*(?:cm|mm)\b',
        
        # Multi-dimensional measurements with nouns
        r'\b\d+(?:\.\d+)?\s*(?:x\s*\d+(?:\.\d+)?)?\s*(?:x\s*\d+(?:\.\d+)?)?\s*(?:cm|mm)\s+(?:mass|lesion|nodule|tumor|tumour|cyst|growth|collection|area|region)\b',
        
        # Approximate measurements with nouns
        r'\b(?:approximately|about|roughly)\s+\d+(?:\.\d+)?\s*(?:cm|mm)\s+(?:mass|lesion|nodule|tumor|tumour|cyst|growth|collection|area|region)\b',
        r'\b(?:mass|lesion|nodule|tumor|tumour|cyst|growth|collection|area|region)\s+(?:approximately|about|roughly)\s+\d+(?:\.\d+)?\s*(?:cm|mm)\b',
        
        # Multi-word descriptive phrases with measurements
        r'\b\d+(?:\.\d+)?\s*(?:cm|mm)\s+(?:well|ill|poorly)\s*-?\s*(?:defined|circumscribed|demarcated)\s+(?:hypoechoic|hyperechoic|isoechoic|anechoic|mixed|echogenic)\s+(?:mass|lesion|nodule|cyst|structure)\b',
        
        # Complex anatomical measurements
        r'\b(?:approximately|about|roughly)\s+\d+(?:\.\d+)?\s*(?:cm|mm)\s+(?:hypoechoic|hyperechoic|isoechoic|anechoic|mixed|echogenic|solid|cystic|complex)\s+(?:mass|lesion|nodule|cyst|structure)\s+(?:in|within|at)\s+(?:the\s+)?(?:liver|kidney|thyroid|pancreas|gallbladder|uterus|ovary|breast|heart)\b'
    ]
    
    # Extract complete measurement phrases first (highest priority)
    for pattern in measurement_phrase_patterns:
        matches = re.findall(pattern, text, re.IGNORECASE)
        for match in matches:
            if isinstance(match, tuple):
                term = ' '.join(match).strip()
            else:
                term = match.strip()
            
            if len(term) > 5:  # Ensure meaningful phrases
                medical_terms.append(term.lower())
                logger.debug(f"Captured measurement phrase: {term}")
    
    # Enhanced medical patterns - MEASUREMENTS MUST BE LINKED TO NOUNS
    medical_patterns = [
        # Specific ultrasound characteristics WITH CONTEXT (no standalone characteristics)
        r'\b(?:well|ill|poorly)\s*-?\s*(?:defined|circumscribed|demarcated)\s+(?:mass|lesion|nodule|area|region|cyst|tumor|tumour|growth)\b',
        r'\b(?:mass|lesion|nodule|area|region|cyst|tumor|tumour|growth)\s+(?:with\s+)?(?:well|ill|poorly)\s*-?\s*(?:defined|circumscribed|demarcated)\s+(?:margins|borders|contours)\b',
        r'\b(?:heterogeneous|homogeneous)\s+(?:mass|lesion|nodule|area|region|cyst|tumor|tumour|growth|echogenicity|echotexture|appearance)\b',
        r'\b(?:mass|lesion|nodule|area|region|cyst|tumor|tumour|growth)\s+(?:with\s+)?(?:heterogeneous|homogeneous)\s+(?:echogenicity|echotexture|appearance)\b',
        r'\b(?:acoustic|posterior)\s+(?:shadowing|enhancement)\s+(?:from|behind|of)\s+(?:mass|lesion|nodule|cyst|structure)\b',
        
        # Echo characteristics ONLY with anatomical context
        r'\b(?:hypoechoic|hyperechoic|isoechoic|anechoic|mixed\s+echogenicity|echogenic|echolucent|echopoor|echorich)\s+(?:mass|lesion|nodule|area|region|cyst|tumor|tumour|growth|structure|parenchyma|wall|cortex|medulla)\b',
        r'\b(?:mass|lesion|nodule|area|region|cyst|tumor|tumour|growth|structure|parenchyma|wall|cortex|medulla)\s+(?:with\s+|showing\s+)?(?:hypoechoic|hyperechoic|isoechoic|anechoic|mixed|echogenic|echolucent|echopoor|echorich)\s+(?:echogenicity|appearance|characteristics)\b',
        
        # Pathologies and conditions (standalone medical terms - these are OK without measurements)
        r'\b(?:mass|tumor|tumour|lesion|nodule|growth|neoplasm|cyst|polyp|hematoma|hemorrhage|bleeding|abscess|collection|calcification|thrombus|thrombosis|embolism|stenosis|occlusion|obstruction|dilatation|dilation|distension|inflammation|infection|fibrosis|scarring)\b',
        
        # Anatomical structures with modifiers (these provide medical context)
        r'\b(?:gallbladder|cholecyst)\s+(?:wall|mucosa|fundus|neck|distension|thickening)\b',
        r'\b(?:liver|hepatic)\s+(?:parenchyma|segment|lobe|capsule|surface|contour)\b',
        r'\b(?:pancreatic|pancreas)\s+(?:head|body|tail|duct|parenchyma|contour)\b',
        r'\b(?:renal|kidney)\s+(?:cortex|medulla|pelvis|calix|capsule|parenchyma|collecting\s+system)\b',
        r'\b(?:thyroid|thyroidal)\s+(?:gland|nodule|lobe|isthmus|parenchyma|capsule)\b',
        r'\b(?:uterine|uterus)\s+(?:fundus|body|cervix|wall|cavity|endometrium|myometrium)\b',
        r'\b(?:ovarian|ovary|ovaries)\s+(?:follicle|cyst|mass|parenchyma|capsule|stroma)\b',
        r'\b(?:cardiac|heart)\s+(?:chamber|valve|wall|septum|muscle|pericardium)\b',
        r'\b(?:vascular|vessel|artery|vein)\s+(?:wall|lumen|flow|diameter|caliber)\b',
        
        # Organs and anatomical structures (contextual - these are meaningful standalone)
        r'\b(?:liver|hepatic|gallbladder|pancreas|pancreatic|spleen|splenic|kidney|renal|bladder|prostate|prostatic|thyroid|thyroidal|parathyroid|carotid|uterus|uterine|ovary|ovarian|cervix|cervical|heart|cardiac|aorta|aortic|mitral|tricuspid|breast|mammary|axillary|lymph\s+node|pleura|pleural|lung|pulmonary|bronchial)\b',
        
        # Texture/characteristics ONLY with anatomical context
        r'\b(?:solid|cystic|complex|simple)\s+(?:mass|lesion|nodule|cyst|structure|component|area|region)\b',
        r'\b(?:mass|lesion|nodule|cyst|structure|component|area|region)\s+(?:with\s+)?(?:solid|cystic|complex|simple)\s+(?:components|characteristics|appearance|features)\b',
        r'\b(?:regular|irregular)\s+(?:contour|margin|border|outline)\s+(?:of\s+)?(?:mass|lesion|nodule|cyst|organ|structure)\b',
        r'\b(?:mass|lesion|nodule|cyst|organ|structure)\s+(?:with\s+)?(?:regular|irregular)\s+(?:contour|margin|border|outline)\b',
        r'\b(?:increased|decreased|normal)\s+(?:echogenicity|vascularity|enhancement)\s+(?:of\s+)?(?:mass|lesion|organ|parenchyma|structure)\b',
        r'\b(?:doppler|color\s+flow|power\s+doppler)\s+(?:signal|assessment|evaluation)\s+(?:of\s+)?(?:mass|lesion|vessel|organ)\b',
        
        # Specific anatomical locations with direction/position
        r'\b(?:right|left|bilateral|unilateral)\s+(?:upper|lower|middle)\s+(?:pole|third|quadrant|lobe|segment)\b',
        r'\b(?:anterior|posterior|superior|inferior|medial|lateral)\s+(?:wall|aspect|portion|surface)\s+(?:of\s+)?(?:organ|structure|mass|lesion)\b',
        r'\b(?:subcapsular|intraparenchymal|extraperitoneal|retroperitoneal)\s+(?:mass|lesion|collection|hematoma|structure)\b'
    ]
    
    # Extract other medical patterns (contextual terms only)
    for pattern in medical_patterns:
        matches = re.findall(pattern, text, re.IGNORECASE)
        for match in matches:
            if isinstance(match, tuple):
                term = ' '.join(match).strip()
            else:
                term = match.strip()
            
            if len(term) > 2:  # Filter very short terms
                medical_terms.append(term.lower())
    
    # Remove duplicates and filter out standalone measurements without context
    unique_terms = []
    seen_terms = set()
    
    for term in medical_terms:
        term = term.strip()
        if len(term) < 3:
            continue
        
        # Skip standalone measurements without medical context
        if re.match(r'^\d+(?:\.\d+)?\s*(?:cm|mm)$', term):
            logger.debug(f"Skipped standalone measurement: {term}")
            continue
        
        # Skip standalone dimensions without context  
        if re.match(r'^\d+(?:\.\d+)?\s*x\s*\d+(?:\.\d+)?\s*(?:cm|mm)$', term):
            logger.debug(f"Skipped standalone dimensions: {term}")
            continue
            
        # Skip very generic terms without context
        generic_terms = {'measuring', 'approximately', 'about', 'cm', 'mm'}
        if term.lower() in generic_terms:
            logger.debug(f"Skipped generic term: {term}")
            continue
        
        if term not in seen_terms:
            unique_terms.append(term)
            seen_terms.add(term)
    
    # Sort by length (more specific terms first) and then by presence of measurements
    def sort_key(term):
        has_measurement = bool(re.search(r'\d+(?:\.\d+)?\s*(?:cm|mm)', term))
        return (has_measurement, len(term))
    
    unique_terms.sort(key=sort_key, reverse=True)
    
    return unique_terms

def load_existing_terms(sam_coord_file):
    """
    Load existing medical terms from sam_coord.csv
    """
    try:
        sam_df = pd.read_csv(sam_coord_file)
        existing_terms = set()
        
        # Assume there's a column with medical terms (adjust according to actual structure)
        # Commonly could be 'term', 'medical_term', 'object', etc.
        possible_columns = ['term', 'medical_term', 'object', 'label', 'annotation']
        
        term_column = None
        for col in possible_columns:
            if col in sam_df.columns:
                term_column = col
                break
        
        if term_column:
            existing_terms = set(sam_df[term_column].dropna().str.lower())
            logger.info(f"Loaded {len(existing_terms)} existing terms from column '{term_column}'")
        else:
            logger.warning(f"No term column found in {sam_coord_file}. Available columns: {list(sam_df.columns)}")
            # If we don't find the column, check all text columns
            for col in sam_df.columns:
                if sam_df[col].dtype == 'object':
                    existing_terms.update(sam_df[col].dropna().str.lower())
        
        return existing_terms
        
    except Exception as e:
        logger.error(f"Error loading {sam_coord_file}: {e}")
        return set()

def find_missing_medical_terms(explainability_file, sam_coord_file, output_file='missing_terms.csv'):
    """
    Main function to find missing medical terms
    """
    logger.info("Starting search for missing medical terms...")
    
    # Load existing terms
    existing_terms = load_existing_terms(sam_coord_file)
    logger.info(f"Existing terms loaded: {len(existing_terms)}")
    
    # Load explainability file
    try:
        explain_df = pd.read_csv(explainability_file)
        logger.info(f"Explainability file loaded: {len(explain_df)} rows")
    except Exception as e:
        logger.error(f"Error loading {explainability_file}: {e}")
        return
    
    # Identify relevant columns
    id_column = None
    text_column = None
    
    # Search for ID column
    for col in explain_df.columns:
        if 'id' in col.lower() or 'image' in col.lower():
            id_column = col
            break
    
    # Search for text/explanation column
    for col in explain_df.columns:
        if any(keyword in col.lower() for keyword in ['caption', 'explanation', 'text', 'description']):
            text_column = col
            break
    
    if not id_column or not text_column:
        logger.error(f"ID or text columns not found. Available columns: {list(explain_df.columns)}")
        return
    
    logger.info(f"Using ID column: '{id_column}', text column: '{text_column}'")
    
    # Process each row
    missing_terms_data = []
    processed_count = 0
    
    for idx, row in explain_df.iterrows():
        image_id = row[id_column]
        caption = row[text_column]
        
        if pd.isna(caption) or pd.isna(image_id):
            continue
        
        # Extract medical terms from explanation
        extracted_terms = extract_medical_terms_enhanced(caption)
        
        # Find missing terms
        missing_terms = []
        for term in extracted_terms:
            # Check if the term or any variation is in existing terms
            term_found = False
            for existing_term in existing_terms:
                if term in existing_term or existing_term in term:
                    term_found = True
                    break
            
            if not term_found:
                missing_terms.append(term)
        
        # If there are missing terms, add to results
        if missing_terms:
            for term in missing_terms:
                missing_terms_data.append({
                    'image_id': image_id,
                    'missing_term': term,
                    'original_caption': caption[:200] + '...' if len(caption) > 200 else caption
                })
        
        processed_count += 1
        if processed_count % 100 == 0:
            logger.info(f"Processed {processed_count} images...")
    
    # Save results
    if missing_terms_data:
        missing_df = pd.DataFrame(missing_terms_data)
        
        # Remove duplicates
        missing_df = missing_df.drop_duplicates(subset=['image_id', 'missing_term'])
        
        # Save CSV
        missing_df.to_csv(output_file, index=False)
        logger.info(f"File {output_file} created with {len(missing_df)} missing terms")
        
        # Show statistics
        unique_images = missing_df['image_id'].nunique()
        unique_terms = missing_df['missing_term'].nunique()
        
        logger.info(f"Summary:")
        logger.info(f"- Images with missing terms: {unique_images}")
        logger.info(f"- Unique missing terms: {unique_terms}")
        logger.info(f"- Total records: {len(missing_df)}")
        
        # Show most common missing terms
        term_counts = missing_df['missing_term'].value_counts().head(10)
        logger.info(f"\nMost frequent missing terms:")
        for term, count in term_counts.items():
            logger.info(f"  - '{term}': {count} times")
        
        return missing_df
    else:
        logger.info("No missing medical terms found")
        return None

# Example usage function
def main():
    """
    Main example function
    """
    explainability_file = '3_submission_explainability.csv'
    sam_coord_file = 'sam_coord.csv'
    output_file = 'missing_terms.csv'
    
    result = find_missing_medical_terms(explainability_file, sam_coord_file, output_file)
    
    if result is not None:
        print(f"\nProcess completed. Check the file {output_file}")
        print(f"Example of missing terms for ImageCLEFmedical_Caption_2025_test_4346:")
        
        # Search for the specific image mentioned
        example_image = result[result['image_id'].str.contains('4346', na=False)]
        if not example_image.empty:
            print(example_image[['image_id', 'missing_term']].to_string(index=False))
        else:
            print("No missing terms found for that specific image")

if __name__ == "__main__":
    main()

INFO:__main__:Starting search for missing medical terms...
INFO:__main__:Existing terms loaded: 83
INFO:__main__:Explainability file loaded: 16 rows
INFO:__main__:Using ID column: 'ID', text column: 'Caption'
INFO:__main__:File missing_terms.csv created with 4 missing terms
INFO:__main__:Summary:
INFO:__main__:- Images with missing terms: 3
INFO:__main__:- Unique missing terms: 4
INFO:__main__:- Total records: 4
INFO:__main__:
Most frequent missing terms:
INFO:__main__:  - 'mitral': 1 times
INFO:__main__:  - 'spleen': 1 times
INFO:__main__:  - '1.5 cm hypoechoic mass': 1 times
INFO:__main__:  - 'hypoechoic mass': 1 times



Process completed. Check the file missing_terms.csv
Example of missing terms for ImageCLEFmedical_Caption_2025_test_4346:
                               image_id           missing_term
ImageCLEFmedical_Caption_2025_test_4346 1.5 cm hypoechoic mass
ImageCLEFmedical_Caption_2025_test_4346        hypoechoic mass


In [51]:
import time
import shutil

In [41]:
"""
Missing Medical Terms Pipeline Adapter
Reuses existing notebook functions to process missing terms
"""

# Configuration
MISSING_TERMS_FILE = './missing_terms.csv'
CAPTIONS_FILE = './3_submission_explainability.csv'
SAM_COORD_FILE = './sam_coord.csv'
MISSING_COORD_OUTPUT = './missing_terms_coord.csv'
BACKUP_SAM_COORD = './sam_coord_backup.csv'

def prepare_missing_terms_data():
    """Convert missing_terms.csv to the format expected by existing functions."""
    missing_df = pd.read_csv(MISSING_TERMS_FILE)
    captions_df = pd.read_csv(CAPTIONS_FILE)
    
    # Create caption mapping (auto-detect columns)
    id_col = next((col for col in captions_df.columns if 'id' in col.lower()), None)
    cap_col = next((col for col in captions_df.columns if any(k in col.lower() for k in ['caption', 'explanation', 'text'])), None)
    
    cap_map = {str(row[id_col]): str(row[cap_col]) for _, row in captions_df.iterrows()}
    
    # Group terms by image (concepts format)
    concepts_data = []
    for image_id, group in missing_df.groupby('image_id'):
        terms = '; '.join(group['missing_term'].unique())
        concepts_data.append({'ID': image_id, 'CUIs': terms})
    
    return pd.DataFrame(concepts_data), cap_map

def run_missing_terms_pipeline():
    """Pipeline to process missing medical terms using existing functions."""
    
    print("Processing Missing Medical Terms...")
    print("=" * 50)
    
    # Backup sam_coord.csv
    if os.path.exists(SAM_COORD_FILE):
        shutil.copy2(SAM_COORD_FILE, BACKUP_SAM_COORD)
        print(f"Backup created: {BACKUP_SAM_COORD}")
    
    # Prepare data
    concepts_df, cap_map = prepare_missing_terms_data()
    print(f"Processing {len(concepts_df)} images with missing terms")
    
    # Use existing pipeline
    all_csv_rows = []
    all_json_objs = []
    
    for idx, row in concepts_df.iterrows():
        image_id = str(row["ID"])
        terms_raw = [c.strip() for c in str(row["CUIs"]).split(';') if c.strip()]
        caption = cap_map.get(image_id, "")
        
        # Reuse existing functions
        image_path = get_image_path(image_id)
        if image_path is None:
            continue
        
        terms = group_compound_terms(terms_raw)
        if not terms:
            continue
        
        print(f"\n[{image_id}] Processing {len(terms)} missing terms: {terms}")
        
        # GPT-4 Vision
        vision_results = call_gpt4v_vision(image_path, terms, caption)
        if not vision_results:
            continue
        
        # Refiner
        csv_out, json_out = call_gpt4_1_refiner(vision_results, terms, caption, image_id)
        
        # Process outputs (same format as previous adapter)
        if csv_out:
            if not all_csv_rows and "ImageID" in csv_out:
                header, *rest = csv_out.splitlines()
                all_csv_rows.append(header)
                all_csv_rows.extend(rest)
            else:
                lines = [line for line in csv_out.splitlines() if not line.lower().startswith("imageid")]
                all_csv_rows.extend(lines)
        
        if json_out and json_out != "[]":
            try:
                data = json.loads(json_out)
                if isinstance(data, list):
                    all_json_objs.extend(data)
            except:
                pass
        
        time.sleep(2.5)  # Rate limiting
    
    # Save missing_terms_coord.csv
    with open(MISSING_COORD_OUTPUT, "w") as f:
        for row in all_csv_rows:
            f.write(row + "\n")
    
    # Update sam_coord.csv (merge)
    if os.path.exists(SAM_COORD_FILE):
        with open(SAM_COORD_FILE, 'r') as f:
            original_lines = f.readlines()
        original_count = len(original_lines) - 1
    else:
        original_lines = ["ImageID,Label,x,y,width,height\n"]
        original_count = 0
    
    # Write final merged file
    with open(SAM_COORD_FILE, 'w') as f:
        for line in original_lines:
            f.write(line)
        for row in all_csv_rows:
            if row and not row.lower().startswith("imageid"):
                f.write(row + "\n")
    
    new_count = len([r for r in all_csv_rows if r and not r.lower().startswith("imageid")])
    total_count = original_count + new_count
    
    print(f"\nCompleted!")
    print(f"Missing terms coordinates: {new_count} → {MISSING_COORD_OUTPUT}")
    print(f"Updated sam_coord.csv: {total_count} total records")
    print(f"Original backup: {BACKUP_SAM_COORD}")
    
    return all_csv_rows, total_count

# Complete function that includes find_missing_medical_terms
def run_complete_workflow():
    """Execute everything: find missing terms + process with GPT-4."""
    
    # Step 1: Find missing terms
    print("Step 1: Finding missing medical terms...")
    missing_df = find_missing_medical_terms('3_submission_explainability.csv', 'sam_coord.csv', 'missing_terms.csv')
    
    if missing_df is None or missing_df.empty:
        print("No missing terms found.")
        return None, None
    
    print(f"Found {len(missing_df)} missing terms")
    
    # Step 2: Process with GPT-4
    print("\nStep 2: Processing with GPT-4 Vision...")
    return run_missing_terms_pipeline()

if __name__ == "__main__":
    print("Missing Medical Terms Pipeline")
    print("Required functions: get_image_path, group_compound_terms, call_gpt4v_vision, call_gpt4_1_refiner")
    print("Usage:")
    print("  csv_rows, total = run_complete_workflow()  # Complete pipeline") 
    print("  csv_rows, total = run_missing_terms_pipeline()  # Process existing missing_terms.csv")

Missing Medical Terms Pipeline
Required functions: get_image_path, group_compound_terms, call_gpt4v_vision, call_gpt4_1_refiner
Usage:
  csv_rows, total = run_complete_workflow()  # Complete pipeline
  csv_rows, total = run_missing_terms_pipeline()  # Process existing missing_terms.csv


In [42]:
# Run the pipeline
csv_rows, total_count = run_missing_terms_pipeline()

Processing Missing Medical Terms...
Backup created: ./sam_coord_backup.csv
Processing 3 images with missing terms

[ImageCLEFmedical_Caption_2025_test_1251] Processing 1 missing terms: ['mitral']
Prompt:
 
You are an expert radiologist. Analyze the provided medical image together with these terms: ['mitral'], and caption: "Transthoracic ultrasonography echocardiogram showing a large vegetation on the mitral valve involving the heart ventricle, with possible thrombus formation affecting the right ventricular structure.".

Important: In radiology, "left" and "right" always refer to the patient's left ...
Image exists? True


INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"



=== GPT-4.1 RAW REPLY ===
 **Step-by-step Rationalization:**

1. **Compound Terms**: Medical terms list: ['mitral']  
   - "mitral" is not a complete anatomic structure; most likely referring to "mitral valve" (compound word). Use "mitral valve" as the preserved term.

2. **Exclude Imaging Modality Terms**:  
   - Exclude terms like "ultrasound", "ultrasonography", "echocardiogram."
   - Do not split terms (e.g. "heart ventricle" remains as is).

3. **Analyze Detections**:  
   - "mitral valve" (x=240, y=120, width=50, height=50)
   - "heart ventricle" (x=200, y=160, width=70, height=70)
   - "right ventricular structure" (uncertain, x=100, y=200, width=70, height=70)
   - "large vegetation" (x=240, y=140, width=50, height=50)
   - "possible thrombus formation" (uncertain, x=120, y=180, width=60, height=60)
   - "arrow" (arrow_tip at x=230, y=135)

4. **Candidate Terms to Output:**  
   - "mitral valve" – present as a detection
   - "large vegetation" – is a finding aligned with "vege

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"



=== GPT-4.1 RAW REPLY ===
 **Step-by-step solution:**

1. **Medical Terms to Retain (from list):** `['spleen']`  
   - It's not split; keep as is.
2. **Omit Modalities:** Already omitted (PET/CT, etc.) per instructions.
3. **Bounding Box Validation:**  
   - "spleen": **No bounding box given**, only arrow tip at (160, 210). Per instruction, create a default 40x40 px box centered at that point. So:
     - x = 160 - 20 = 140
     - y = 210 - 20 = 190
     - width = 40
     - height = 40

---

**CSV:**
```
ImageID,Label,x,y,width,height
ImageCLEFmedical_Caption_2025_test_15167,spleen,140,190,40,40
```

---

**JSON:**
```json
[
  {
    "ImageID": "ImageCLEFmedical_Caption_2025_test_15167",
    "Label": "spleen",
    "x": 140,
    "y": 190,
    "width": 40,
    "height": 40
  }
]
``` 
=== END REPLY ===


[ImageCLEFmedical_Caption_2025_test_4346] Processing 2 missing terms: ['1.5 Cm Hypoechoic Mass', 'Hypoechoic Mass']
Prompt:
 
You are an expert radiologist. Analyze the provided medical im

INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"



=== GPT-4.1 RAW REPLY ===
 CSV:
```
ImageID,Label,x,y,width,height
ImageCLEFmedical_Caption_2025_test_4346,Hypoechoic Mass,180,60,40,40
ImageCLEFmedical_Caption_2025_test_4346,Hypoechoic Mass,180,50,40,40
```

JSON:
```json
[
  {
    "ImageID": "ImageCLEFmedical_Caption_2025_test_4346",
    "Label": "Hypoechoic Mass",
    "x": 180,
    "y": 60,
    "width": 40,
    "height": 40
  },
  {
    "ImageID": "ImageCLEFmedical_Caption_2025_test_4346",
    "Label": "Hypoechoic Mass",
    "x": 180,
    "y": 50,
    "width": 40,
    "height": 40
  }
]
```

**Notes:**
- Compound term "Hypoechoic Mass" is preserved as one entity (matches both terms in medical terms list: '1.5 Cm Hypoechoic Mass', 'Hypoechoic Mass').
- General modality terms (e.g., "ultrasound") are not included.
- The first detection had a valid bounding box.
- The second had only an arrow_tip at (200,70). Default box (40x40) is centered at the arrow tip, so x = 200-20=180, y = 70-20=50.
- Both are labeled "Hypoechoic Mass" becaus

### SAM: Defining boxes

After applying generative AI, NER, and other NLP techniques, we finally have a file to feed SAM with the key medical terms and their approximate positions for refinement through computing vision techniques.

In [17]:
# ==============================================================================
# SECTION I: IMPORTS, CONFIGURATION, AND INITIAL SETUP
# ==============================================================================
print("----------------------------------------------------------------------")
print("SECTION I: INITIALIZING LIBRARIES, CONFIGURATION, AND CUDA")
print("----------------------------------------------------------------------")

# Standard Libraries
import os
import gc # Garbage collection
import time # For adding pauses
import json
import re
import warnings
import urllib.request
from collections import defaultdict
import csv # Though pandas is used more for CSVs later
import logging # For more structured logging if desired

# Data Handling & Numerics
import pandas as pd
import numpy as np

# Image Processing
import cv2
from scipy import ndimage
from skimage.measure import label as skimage_label, regionprops
from scipy.spatial.distance import euclidean
try:
    from skimage import feature
    SKIMAGE_AVAILABLE = True
    print("✓ Scikit-image (skimage) available.")
except ImportError:
    SKIMAGE_AVAILABLE = False
    print(" Scikit-image (skimage) not available. Some keypoint features will be disabled.")

# Plotting & Visualization
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle
from mpl_toolkits.mplot3d import Axes3D # For 3D heatmap
import textwrap # For wrapping text in visualizations

# Machine Learning & Deep Learning - PyTorch & Transformers
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
# from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline # Commented out as not used in the final logic provided
try:
    import ultralytics
    YOLO_AVAILABLE = True
    print("✓ Ultralytics YOLO available.")
except ImportError:
    YOLO_AVAILABLE = False
    print(" Ultralytics YOLO not available. YOLO enhancement will be disabled.")

# IPython/Jupyter specific (useful for notebooks, might not be needed for pure scripts)
from IPython.display import display, Image as IPImage

# --- Configuration ---
print("\n--- Loading Configuration ---")
SAM_COORD_PATH = './sam_coord.csv'
CAPTIONS_PATH = './3_submission_explainability.csv' # This seems to be the 'explanations_df' in create_image_data_table
CONCEPTS_PATH = './ref_mini_concepts_natural_.csv'
CAPTION_FILE_PATH = './caption.csv' # This seems to be the 'caption_file_df' in create_image_data_table
IMAGES_DIR = './data/test_set_explain'
OUTPUT_DIR = './result_explain-3/sam_deduplicated' # Changed to avoid overwriting original results

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)
print(f"✓ Output directory set to: {OUTPUT_DIR}")

# Color palette for consistency (using the second, more vibrant palette you provided)
COLORS = [
    '#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FF8C00',
    '#DDA0DD', '#98D8C8', '#F7DC6F', '#BB8FCE', '#85C1E9',
    '#F8C471', '#82E0AA', '#F1948A', '#85C1E9', '#D5DBDB'
]
print(f"✓ Color palette loaded with {len(COLORS)} colors.")

# Warnings behavior
warnings.filterwarnings('ignore')
print("✓ Warnings are suppressed (set to 'ignore').")

# --- CUDA Memory Management Utility ---
def clear_cuda_cache():
    """
    Clears CUDA cache and runs garbage collection to free up GPU memory.
    Essential for processing multiple large images/models sequentially.
    """
    if torch.cuda.is_available():
        print(" Clearing CUDA cache...")
        torch.cuda.empty_cache()
        torch.cuda.synchronize()  # Wait for all CUDA operations to complete
        gc.collect() # Python's garbage collection
        print("✓ CUDA cache cleared and garbage collected.")
    else:
        print(" CUDA not available, skipping cache clear.")

# Initial clear at the start of the script
clear_cuda_cache()

# ==============================================================================
# SECTION II: CORE UTILITIES
# ==============================================================================
print("\n----------------------------------------------------------------------")
print("SECTION II: DEFINING CORE UTILITY FUNCTIONS")
print("----------------------------------------------------------------------")

def hex_to_rgb(hex_color):
    """Convert hexadecimal color string to a list of RGB values (0-1 range)."""
    hex_color = hex_color.lstrip('#')
    return [int(hex_color[i:i+2], 16) / 255.0 for i in (0, 2, 4)]

def load_data():
    """
    Load all necessary CSV data files (SAM coordinates, captions, concepts, explanations).
    Handles potential file not found errors.
    """
    print("--- Loading CSV Data Files ---")
    data_loaded_successfully = True
    try:
        sam_df = pd.read_csv(SAM_COORD_PATH)
        print(f"✓ SAM coordinates loaded: {len(sam_df)} records from {SAM_COORD_PATH}")
    except FileNotFoundError:
        print(f"  ERROR: SAM coordinates file not found at {SAM_COORD_PATH}")
        sam_df = pd.DataFrame() # Return empty DataFrame
        data_loaded_successfully = False
    except Exception as e:
        print(f"  ERROR loading SAM coordinates from {SAM_COORD_PATH}: {e}")
        sam_df = pd.DataFrame()
        data_loaded_successfully = False

    try:
        # This was referred to as captions_df, but seems to hold 'explanations' for create_image_data_table
        # and 'caption' for the main visualization. Let's call it explanations_df for clarity based on its use.
        explanations_df = pd.read_csv(CAPTIONS_PATH)
        print(f"✓ Explanations/Captions (for main vis) loaded: {len(explanations_df)} records from {CAPTIONS_PATH}")
    except FileNotFoundError:
        print(f"  ERROR: Explanations file not found at {CAPTIONS_PATH}")
        explanations_df = pd.DataFrame()
        data_loaded_successfully = False
    except Exception as e:
        print(f"  ERROR loading explanations from {CAPTIONS_PATH}: {e}")
        explanations_df = pd.DataFrame()
        data_loaded_successfully = False

    if os.path.exists(CONCEPTS_PATH):
        try:
            concepts_df = pd.read_csv(CONCEPTS_PATH)
            print(f"✓ Concepts loaded: {len(concepts_df)} records from {CONCEPTS_PATH}")
        except Exception as e:
            print(f"  ERROR loading concepts from {CONCEPTS_PATH}: {e}")
            concepts_df = pd.DataFrame() # Return empty if error
            data_loaded_successfully = False
    else:
        print(f" Concepts file not found (optional): {CONCEPTS_PATH}")
        concepts_df = pd.DataFrame() # Optional, so return empty DataFrame

    if os.path.exists(CAPTION_FILE_PATH):
        try:
            # This seems to be distinct from CAPTIONS_PATH, used for table display
            caption_file_df = pd.read_csv(CAPTION_FILE_PATH)
            print(f"✓ Caption file (for table) loaded: {len(caption_file_df)} records from {CAPTION_FILE_PATH}")
        except Exception as e:
            print(f"  ERROR loading caption file from {CAPTION_FILE_PATH}: {e}")
            caption_file_df = pd.DataFrame() # Return empty if error
            data_loaded_successfully = False
    else:
        print(f" Caption file not found (optional): {CAPTION_FILE_PATH}")
        caption_file_df = pd.DataFrame() # Optional, so return empty DataFrame

    if data_loaded_successfully:
        print("✓ All available data files loaded.")
    else:
        print(" Some data files failed to load. Functionality might be affected.")
    return sam_df, explanations_df, concepts_df, caption_file_df

def find_image_file(image_id, images_dir):
    """Find image file corresponding to an ID within a directory, checking common extensions."""
    extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
    base_image_id = str(image_id).split('.')[0] # Handle if image_id has an extension

    for ext in extensions:
        full_path = os.path.join(images_dir, f"{base_image_id}{ext}")
        if os.path.exists(full_path):
            return full_path

    # Fallback: Walk through the directory if direct match fails (more exhaustive)
    for root, _, files in os.walk(images_dir):
        for file in files:
            file_base_name = os.path.splitext(file)[0]
            if base_image_id == file_base_name:
                return os.path.join(root, file)
    print(f"Could not find image for ID: {image_id} in {images_dir}")
    return None

def clean_label_text(label):
    """Clean and simplify label text by removing predefined terms and title-casing."""
    if not isinstance(label, str): # Ensure label is a string
        label = str(label)

    remove_terms = [
        "Structure of", "structure of", "Bone structure of", "bone structure of",
        "Entire", "entire", "Complete", "complete"
    ]
    cleaned = label
    for term in remove_terms:
        cleaned = cleaned.replace(term + " ", "").replace(term, "")
    return cleaned.strip().title()

print("✓ Core utility functions defined.")

# ==============================================================================
# SECTION III: SEGMENT ANYTHING MODEL (SAM) SETUP
# ==============================================================================
print("\n----------------------------------------------------------------------")
print("SECTION III: SETTING UP SEGMENT ANYTHING MODEL (SAM)")
print("----------------------------------------------------------------------")

def setup_sam():
    """
    Sets up the SAM model (ViT-H), downloads checkpoint if needed,
    and initializes the SamPredictor.
    This version is based on your `setup_sam_fixed` for robustness.
    """
    print("--- Initializing SAM ---")
    try:
        # Ensure model directory exists
        os.makedirs('sam_models', exist_ok=True)
        sam_checkpoint = "sam_models/sam_vit_h_4b8939.pth"
        model_type = "vit_h" # Hardcoded to ViT-H as in your script

        if not os.path.exists(sam_checkpoint):
            print(f"SAM checkpoint not found at {sam_checkpoint}. Downloading (this may take a few minutes)...")
            urllib.request.urlretrieve(
                "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
                sam_checkpoint
            )
            print("✓ SAM model downloaded successfully!")

        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Attempting to load SAM model on device: {device_type}")

        # Load SAM model with error handling
        try:
            sam_model = sam_model_registry[model_type](checkpoint=sam_checkpoint)
            sam_model.to(device=device_type)
            # Test the model with a simple operation by creating a predictor
            predictor = SamPredictor(sam_model)
            print(f"✓ SAM model ({model_type}) loaded and predictor initialized successfully on {device_type}!")
            # Note: SamAutomaticMaskGenerator is configured differently and separately if needed.
            # Your main loop uses SamPredictor, which is what we're returning.
            return predictor, device_type, sam_model # Return sam_model for potential explicit deletion
        except Exception as e:
            print(f" CUDA SAM loading failed: {e}. Falling back to CPU for SAM.")
            if device_type == "cuda": # Only try CPU if CUDA failed
                device_type = "cpu"
                sam_model = sam_model_registry[model_type](checkpoint=sam_checkpoint)
                sam_model.to(device=device_type)
                predictor = SamPredictor(sam_model)
                print(f"✓ SAM model ({model_type}) loaded and predictor initialized successfully on CPU (fallback)!")
                return predictor, device_type, sam_model
            else: # If initial device was CPU and it failed
                print(f"  SAM model loading failed on CPU as well: {e}")
                return None, "cpu", None

    except Exception as e:
        print(f" SAM setup failed completely: {e}")
        return None, "cpu", None

# Initialize SAM globally if you intend to use the same instance throughout.
# However, your `run_ultimate_analysis_fixed` calls `setup_sam_fixed` (now `setup_sam`)
# inside, which is fine if you want to re-initialize or handle setup per run.
# For a script processing many images, initializing once is usually better.
# Let's initialize it here for the script's scope.
# The main loop will receive this predictor.

# SAM_PREDICTOR, SAM_DEVICE, SAM_MODEL_INSTANCE = setup_sam()
# The main loop will call setup_sam() to get these.

print("✓ SAM setup function defined.")

# ==============================================================================
# SECTION IV: SAM MASKING FUNCTIONS
# (Leveraging SamPredictor for more controlled segmentation based on prompts)
# ==============================================================================
print("\n----------------------------------------------------------------------")
print("SECTION IV: DEFINING SAM MASKING LOGIC")
print("----------------------------------------------------------------------")

def get_sam_mask_from_bbox(image_rgb, bbox_coords, predictor):
    """
    Generates a segmentation mask using SAM based on a bounding box prompt.
    Args:
        image_rgb (np.ndarray): The input image in RGB format.
        bbox_coords (tuple): (x, y, w, h) for the bounding box.
        predictor (SamPredictor): The initialized SAM predictor instance.
    Returns:
        mask (np.ndarray): Boolean mask, or zeros if failed.
        confidence (float): Confidence score of the best mask.
        status (str): Status message ("bbox_success", "invalid_bbox", etc.).
    """
    if predictor is None:
        print(" SAM predictor not available in get_sam_mask_from_bbox.")
        return np.zeros((image_rgb.shape[0], image_rgb.shape[1]), dtype=bool), 0.0, "predictor_unavailable"

    # It's good practice to clear cache before potentially large GPU operations,
    # though this will be called more broadly between images.
    # clear_cuda_cache() # Potentially too frequent if called for every bbox. Manage at image level.

    x, y, w, h = bbox_coords

    # Validate bbox
    if x < 0 or y < 0 or w <= 0 or h <= 0:
        print(f" Invalid bbox: {bbox_coords}")
        return np.zeros((image_rgb.shape[0], image_rgb.shape[1]), dtype=bool), 0.0, "invalid_bbox"
    if x + w > image_rgb.shape[1] or y + h > image_rgb.shape[0]:
        # Clamp bbox to image dimensions if slightly out, or reject if too far
        # For simplicity, let's reject if the starting point is okay but extent is too large
        print(f" Bbox {bbox_coords} partially outside image bounds ({image_rgb.shape[1]}x{image_rgb.shape[0]}). Clamping not implemented here, might lead to errors.")
        # A more robust way would be to clamp:
        # x2 = min(x + w, image_rgb.shape[1])
        # y2 = min(y + h, image_rgb.shape[0])
        # w = x2 - x
        # h = y2 - y
        # if w <= 0 or h <= 0: return ... "bbox_out_of_bounds_after_clamping"
        # For now, proceed with caution or return error for out of bounds.
        # Let's return an error state for this case to be safe.
        # return np.zeros((image_rgb.shape[0], image_rgb.shape[1]), dtype=bool), 0.0, "bbox_out_of_bounds"

    try:
        predictor.set_image(image_rgb)
    except Exception as e:
        print(f" SAM predictor.set_image failed: {e}")
        # Attempt to reset predictor or re-initialize SAM if this becomes a recurring issue
        return np.zeros((image_rgb.shape[0], image_rgb.shape[1]), dtype=bool), 0.0, "set_image_failed"

    input_box = np.array([x, y, x + w, y + h])

    try:
        # multimask_output=True gives 3 masks. The first one is usually the best.
        # Scores are IOU predictions for each mask.
        masks, scores, logits = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_box[None, :],  # SAM expects a batch of boxes
            multimask_output=True,
        )
        
        if masks is not None and len(masks) > 0 and scores is not None and len(scores) > 0:
            best_idx = np.argmax(scores)
            return masks[best_idx], float(scores[best_idx]), "bbox_success"
        else:
            print(" SAM predictor.predict returned no masks/scores for bbox.")
            return np.zeros((image_rgb.shape[0], image_rgb.shape[1]), dtype=bool), 0.0, "no_masks_from_bbox"
            
    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            print(f" CUDA Out of Memory during SAM bbox prediction! Bbox: {bbox_coords}")
            clear_cuda_cache() # Attempt to clear for next operations
            # Consider more drastic recovery or skipping if OOM is frequent
        else:
            print(f" SAM bbox prediction runtime error: {e}")
        return np.zeros((image_rgb.shape[0], image_rgb.shape[1]), dtype=bool), 0.0, "prediction_failed_runtime_error"
    except Exception as e:
        print(f" SAM bbox prediction failed with unknown error: {e}")
        return np.zeros((image_rgb.shape[0], image_rgb.shape[1]), dtype=bool), 0.0, "prediction_failed_unknown_error"

def get_sam_mask_from_points(image_rgb, points_coords, predictor, point_labels=None):
    """
    Generates a segmentation mask using SAM based on point prompts.
    Args:
        image_rgb (np.ndarray): The input image in RGB format.
        points_coords (np.ndarray): Nx2 array of (x,y) point coordinates.
        predictor (SamPredictor): The initialized SAM predictor instance.
        point_labels (np.ndarray, optional): Nx1 array of labels (1 for foreground, 0 for background). Defaults to all foreground.
    Returns:
        mask (np.ndarray): Boolean mask, or zeros if failed.
        confidence (float): Confidence score of the best mask.
    """
    if predictor is None:
        print(" SAM predictor not available in get_sam_mask_from_points.")
        return np.zeros((image_rgb.shape[0], image_rgb.shape[1]), dtype=bool), 0.0

    if point_labels is None:
        point_labels = np.ones(len(points_coords)) # Assume all points are foreground

    try:
        predictor.set_image(image_rgb)
    except Exception as e:
        print(f" SAM predictor.set_image failed in point prediction: {e}")
        return np.zeros((image_rgb.shape[0], image_rgb.shape[1]), dtype=bool), 0.0

    try:
        masks, scores, logits = predictor.predict(
            point_coords=np.array(points_coords),
            point_labels=np.array(point_labels),
            box=None,
            multimask_output=True,
        )

        if masks is not None and len(masks) > 0 and scores is not None and len(scores) > 0:
            best_idx = np.argmax(scores)
            return masks[best_idx], float(scores[best_idx])
        else:
            print("⚠️ SAM predictor.predict returned no masks/scores for points.")
            return np.zeros((image_rgb.shape[0], image_rgb.shape[1]), dtype=bool), 0.0

    except RuntimeError as e:
        if "CUDA out of memory" in str(e):
            print(f" CUDA Out of Memory during SAM point prediction! Points: {points_coords[:3]}") # Log first few points
            clear_cuda_cache()
        else:
            print(f" SAM point prediction runtime error: {e}")
        return np.zeros((image_rgb.shape[0], image_rgb.shape[1]), dtype=bool), 0.0
    except Exception as e:
        print(f" SAM point prediction failed: {e}")
        return np.zeros((image_rgb.shape[0], image_rgb.shape[1]), dtype=bool), 0.0

print("✓ SAM masking functions defined.")

# ==============================================================================
# SECTION V: KEYPOINT & ARROW DETECTION
# (Functions for advanced prompting strategies for SAM)
# ==============================================================================
print("\n----------------------------------------------------------------------")
print("SECTION V: DEFINING KEYPOINT AND ARROW DETECTION LOGIC")
print("----------------------------------------------------------------------")

def detect_keypoints_and_regions(image_rgb, bbox_coords):
    """
    Detects keypoints (SIFT, FAST, LoG Blobs) and horizontal lines within a given ROI.
    Args:
        image_rgb (np.ndarray): Full image in RGB.
        bbox_coords (tuple): (x, y, w, h) defining the Region of Interest.
    Returns:
        list: A list of tuples, where each tuple is (region_type_str, points_array).
    """
    x, y, w, h = map(int, bbox_coords) # Ensure integer coordinates

    # Validate ROI coordinates against image dimensions
    img_h, img_w = image_rgb.shape[:2]
    if x < 0 or y < 0 or w <= 0 or h <= 0 or x + w > img_w or y + h > img_h:
        print(f" Invalid ROI {bbox_coords} for keypoint detection in image of shape {image_rgb.shape}. Skipping.")
        return []

    roi = image_rgb[y:y+h, x:x+w]
    if roi.size == 0:
        print(f" Empty ROI extracted for keypoint detection with bbox {bbox_coords}. Skipping.")
        return []
        
    gray_roi = cv2.cvtColor(roi, cv2.COLOR_RGB2GRAY)
    detected_regions = []

    # 1. SIFT for distinctive points
    try:
        sift = cv2.SIFT_create()
        keypoints_sift = sift.detect(gray_roi, None) # No need for descriptors here
        if keypoints_sift:
            sift_points = np.array([[kp.pt[0] + x, kp.pt[1] + y] for kp in keypoints_sift[:10]], dtype=int) # Top 10
            if sift_points.size > 0:
                 detected_regions.append(('sift_points', sift_points))
    except Exception as e:
        print(f"SIFT detection failed or unavailable: {e}")

    # 2. LoG blobs (Laplacian of Gaussian) - if scikit-image is available
    if SKIMAGE_AVAILABLE:
        try:
            # Adjust parameters for medical images: may need larger sigma or different threshold
            blobs_log = feature.blob_log(gray_roi, max_sigma=20, min_sigma=5, num_sigma=5, threshold=0.05) # Tweaked params
            if len(blobs_log) > 0:
                # blobs_log returns (y, x, sigma)
                blob_points = np.array([[blob[1] + x, blob[0] + y] for blob in blobs_log[:5]], dtype=int) # Top 5
                if blob_points.size > 0:
                    detected_regions.append(('log_blobs', blob_points))
        except Exception as e:
            print(f"ℹ️ LoG Blob detection failed: {e}")
    
    # 3. FAST for corners
    try:
        fast = cv2.FastFeatureDetector_create(threshold=10, nonmaxSuppression=True) # Standard params
        keypoints_fast = fast.detect(gray_roi, None)
        if keypoints_fast:
            fast_points = np.array([[kp.pt[0] + x, kp.pt[1] + y] for kp in keypoints_fast[:8]], dtype=int) # Top 8
            if fast_points.size > 0:
                detected_regions.append(('fast_corners', fast_points))
    except Exception as e:
        print(f" FAST corner detection failed: {e}")

    # 4. Detect horizontal lines (e.g., for fluid levels or flat structures)
    try:
        edges = cv2.Canny(gray_roi, 50, 150)
        lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=20, minLineLength=w // 4, maxLineGap=5) # Adjusted threshold & length
        if lines is not None:
            horizontal_points = []
            for line in lines:
                x1, y1, x2, y2 = line[0]
                angle = np.abs(np.arctan2(y2 - y1, x2 - x1) * 180 / np.pi)
                if angle < 10 or angle > 170: # More strictly horizontal
                    mid_x = (x1 + x2) // 2 + x
                    mid_y = (y1 + y2) // 2 + y
                    horizontal_points.append([mid_x, mid_y])
            if horizontal_points:
                detected_regions.append(('horizontal_lines_midpoints', np.array(horizontal_points, dtype=int)))
    except Exception as e:
        print(f" Horizontal line detection failed: {e}")
        
    # print(f"Detected regions for bbox {bbox_coords}: {[r[0] for r in detected_regions]}")
    return detected_regions


def advanced_arrow_detection(image_rgb, bbox_coords, debug=False):
    """
    Detects arrows within an ROI and estimates their target.
    Args:
        image_rgb (np.ndarray): Full image in RGB.
        bbox_coords (tuple): (x, y, w, h) for the ROI potentially containing an arrow.
    Returns:
        dict: Information about detected arrow, including 'found', 'target_bbox'.
    """
    x, y, w, h = map(int, bbox_coords)
    img_h, img_w = image_rgb.shape[:2]

    arrow_info = {
        'found': False, 'direction': None, 'target_bbox': bbox_coords, # Default to original bbox
        'confidence': 0.0, 'method': 'none'
    }

    if x < 0 or y < 0 or w <= 0 or h <= 0 or x + w > img_w or y + h > img_h:
        if debug: print(f" Invalid ROI {bbox_coords} for arrow detection. Using original bbox as target.")
        arrow_info['method'] = 'invalid_roi_coords'
        return arrow_info
        
    roi = image_rgb[y:y+h, x:x+w]
    if roi.size == 0:
        if debug: print(f" Empty ROI for arrow detection with bbox {bbox_coords}.")
        arrow_info['method'] = 'empty_roi'
        return arrow_info

    gray_roi = cv2.cvtColor(roi, cv2.COLOR_RGB2GRAY)
    
    try:
        # Using Canny and HoughLinesP for line detection
        edges = cv2.Canny(gray_roi, 50, 150, apertureSize=3)
        lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=15, # Threshold for line detection
                                minLineLength=max(10, min(w, h) // 4), # Min length relative to ROI size
                                maxLineGap=max(5, min(w,h) // 10))     # Max gap relative to ROI size

        if lines is not None and len(lines) > 0:
            potential_arrows = []
            for line_segment in lines:
                x1, y1, x2, y2 = line_segment[0]
                length = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)
                # Heuristic: a line is a potential arrow shaft if it's reasonably long within the ROI
                if length > min(w, h) * 0.25: # Must be at least 25% of the smaller dimension of ROI
                    line_direction = np.array([x2 - x1, y2 - y1]) / length # Normalized direction
                    potential_arrows.append({
                        'line': (x1, y1, x2, y2), 'length': length, 'direction': line_direction,
                        'midpoint': ((x1+x2)/2, (y1+y2)/2)
                    })
            
            if potential_arrows:
                # Select the longest line as the primary candidate for an arrow shaft
                main_arrow_candidate = max(potential_arrows, key=lambda p: p['length'])
                ax1, ay1, ax2, ay2 = main_arrow_candidate['line']
                arrow_direction = main_arrow_candidate['direction'] # From (ax1,ay1) to (ax2,ay2)
                
                # Determine arrow tip: check intensity or structure around line ends (ax2, ay2)
                # This part can be complex. For a simpler heuristic:
                # Assume (ax2,ay2) is the tip if it's 'sharper' or points towards a change.
                # For now, we'll just use the line direction and extend from one end.
                # Let's assume the end further from ROI center is part of the shaft, pointing away.
                
                # Estimate target: extend the line beyond the ROI to find where it points
                # The 'target_bbox' should be outside the current 'arrow' bbox.
                extension_factor = 1.0 * main_arrow_candidate['length'] # Extend by its own length
                
                # Tip of the arrow within ROI (local coordinates)
                tip_local_x, tip_local_y = ax2, ay2

                # Projected point in global image coordinates
                target_global_x = x + tip_local_x + arrow_direction[0] * extension_factor
                target_global_y = y + tip_local_y + arrow_direction[1] * extension_factor

                # Create a new small bounding box around this projected target point
                target_box_size = max(20, int(min(w, h) * 0.5)) # Size of the target bbox
                
                new_target_x = int(target_global_x - target_box_size // 2)
                new_target_y = int(target_global_y - target_box_size // 2)
                
                # Clamp new_target_bbox to image boundaries
                new_target_x = max(0, min(new_target_x, img_w - target_box_size))
                new_target_y = max(0, min(new_target_y, img_h - target_box_size))
                
                new_target_bbox = (new_target_x, new_target_y, target_box_size, target_box_size)

                arrow_info.update({
                    'found': True,
                    'direction': arrow_direction.tolist(), # Convert numpy array for JSON serializability if needed
                    'target_bbox': new_target_bbox,
                    'confidence': min(main_arrow_candidate['length'] / max(1,max(w, h)), 1.0), # Confidence based on relative length
                    'method': 'line_detection_heuristic'
                })
                if debug: print(f"✓ Arrow detected by line heuristic. Original bbox: {bbox_coords}, Target bbox: {new_target_bbox}, Conf: {arrow_info['confidence']:.2f}")
    except Exception as e:
        if debug: print(f"Arrow detection (line method) failed for bbox {bbox_coords}: {e}")
        arrow_info['method'] = 'line_detection_exception'

    return arrow_info


def correct_directional_labels(label_text, bbox_coords, image_width):
    """
    Corrects directional terms (left/right) in labels based on the bbox's horizontal position.
    Assumes standard anatomical view (image left is patient's right).
    """
    x, _, w, _ = bbox_coords
    center_x = x + w / 2.0
    
    # Determine if the bbox is on the image's left or right side
    # This corresponds to patient's right or left side respectively.
    image_side_is_left = center_x < image_width / 2.0 # Bbox is on the left side of the image

    corrected_label = label_text # Start with original label
    label_lower = label_text.lower()

    # If label says "right" but bbox is on image's left side (patient's right) -> No change needed for 'right'
    # If label says "left" but bbox is on image's right side (patient's left) -> No change needed for 'left'

    # Corrections are needed if label mismatches anatomical position relative to image sides
    if "right" in label_lower and not image_side_is_left: # Label "right", bbox on image's right (patient's left)
        corrected_label = label_text.replace("right", "left").replace("Right", "Left")
        print(f"Directional label correction: '{label_text}' -> '{corrected_label}' (bbox on image right, implies patient left)")
    elif "left" in label_lower and image_side_is_left: # Label "left", bbox on image's left (patient's right)
        corrected_label = label_text.replace("left", "right").replace("Left", "Right")
        print(f" Directional label correction: '{label_text}' -> '{corrected_label}' (bbox on image left, implies patient right)")
    # Similar logic for other languages if needed (e.g., "derecha", "izquierda")

    return corrected_label.title() # Return title-cased


# This is the refined SAM mask selection logic from your original PART 6
def sam_with_intelligent_arrow_following(image_rgb, bbox_coords, predictor, label_text, debug=False):
    """
    Enhanced SAM processing: Tries direct bbox, then arrow following if confidence is low,
    then keypoint-based prompting. Returns the best mask found.
    """
    # 1. Try SAM with the original bounding box
    if debug: print(f"Attempting SAM for '{label_text}' with original bbox: {bbox_coords}")
    original_mask, original_conf, strategy_msg = get_sam_mask_from_bbox(image_rgb, bbox_coords, predictor)
    
    best_mask = original_mask
    best_conf = original_conf
    best_strategy = f"original_bbox ({strategy_msg})"

    if debug: print(f"  Original bbox SAM: conf={original_conf:.3f}, strategy='{best_strategy}'")

    # 2. If confidence is low, try arrow detection and follow-up SAM
    # Threshold for "low confidence" can be tuned.
    arrow_strategy_applied = False
    if best_conf < 0.75: # Confidence threshold to trigger arrow/keypoint strategies
        if debug: print(f"  Low confidence ({best_conf:.3f}) for '{label_text}'. Trying arrow detection from bbox: {bbox_coords}")
        arrow_info = advanced_arrow_detection(image_rgb, bbox_coords, debug=debug)
        
        if arrow_info['found'] and arrow_info['target_bbox'] != bbox_coords: # Ensure target is different
            arrow_strategy_applied = True
            if debug: print(f"  Arrow found for '{label_text}', method: {arrow_info['method']}. New target bbox: {arrow_info['target_bbox']}")
            
            arrow_target_mask, arrow_target_conf, arrow_strategy_msg = get_sam_mask_from_bbox(image_rgb, arrow_info['target_bbox'], predictor)
            
            # Boost confidence slightly if arrow logic was successful in finding a plausible target
            adjusted_arrow_conf = arrow_target_conf * (1.0 + arrow_info['confidence'] * 0.2) # Modest boost
            
            if debug: print(f"    SAM on arrow target: conf={arrow_target_conf:.3f} (adjusted: {adjusted_arrow_conf:.3f})")
            if adjusted_arrow_conf > best_conf:
                best_mask = arrow_target_mask
                best_conf = adjusted_arrow_conf
                best_strategy = f"arrow_followed_{arrow_info['method']}_{arrow_strategy_msg}"
                if debug: print(f"    Selected arrow strategy for '{label_text}' with new conf: {best_conf:.3f}")

    # 3. Try keypoint-based SAM prompting, regardless of arrow outcome if initial conf was low, or if arrow didn't improve much
    # Only try keypoints if initial confidence is still not great
    if best_conf < 0.80: # If still not highly confident
        if debug: print(f"  Confidence for '{label_text}' still potentially low ({best_conf:.3f}). Trying keypoint strategies from bbox: {bbox_coords}")
        detected_keypoint_regions = detect_keypoints_and_regions(image_rgb, bbox_coords)
        
        keypoint_strategy_applied_successfully = False
        for region_type, points_array in detected_keypoint_regions:
            if points_array.ndim == 2 and points_array.shape[0] > 0: # Ensure valid points
                if debug: print(f"    Trying SAM with '{region_type}' ({len(points_array)} points) for '{label_text}'")
                try:
                    # Using points requires predictor instance
                    keypoint_mask, keypoint_conf = get_sam_mask_from_points(image_rgb, points_array, predictor)
                    if debug: print(f"      SAM with '{region_type}': conf={keypoint_conf:.3f}")
                    
                    if keypoint_conf > best_conf:
                        best_mask = keypoint_mask
                        best_conf = keypoint_conf
                        best_strategy = f"keypoints_{region_type}"
                        keypoint_strategy_applied_successfully = True
                        if debug: print(f"      Selected keypoint strategy '{region_type}' for '{label_text}' with new conf: {best_conf:.3f}")
                except Exception as e_kp:
                    if debug: print(f"      Error during SAM with '{region_type}' points: {e_kp}")
            else:
                if debug: print(f"    Skipping '{region_type}' for '{label_text}' due to no/invalid points.")
        if keypoint_strategy_applied_successfully and debug:
             print(f"  Keypoint strategy improved mask for '{label_text}'.")


    if debug: print(f"  Final best strategy for '{label_text}': {best_strategy}, Final best conf: {best_conf:.3f}")
    return best_mask, best_conf, best_strategy


# This is the main mask selection function from your original PART 6, renamed slightly for clarity.
def select_best_mask_for_label_advanced(image_rgb, bbox_coords, predictor, label_text, image_id_for_heatmaps=None, export_heatmaps_flag=True):
    """
    Selects the best SAM mask using intelligent strategies (direct bbox, arrow following, keypoints).
    Optionally generates and saves heatmaps for good detections.
    Args:
        image_id_for_heatmaps: If provided, used in naming saved heatmaps.
        export_heatmaps_flag: Boolean to control heatmap generation.
    """
    try:
        mask, confidence, strategy = sam_with_intelligent_arrow_following(
            image_rgb, bbox_coords, predictor, label_text, debug=True # Enable debug for detailed logs
        )
        
        # Generate and save heatmaps only for masks with reasonable confidence
        if export_heatmaps_flag and confidence > 0.6 and image_id_for_heatmaps is not None:
            print(f"  Generating useful heatmap for '{label_text}' (Image: {image_id_for_heatmaps}, Conf: {confidence:.2f})")
            # Ensure generate_useful_heatmap_analysis is defined and works as intended.
            # This function was in your PART 6.
            generate_useful_heatmap_analysis(image_rgb, bbox_coords, mask, confidence, label_text, image_id_for_heatmaps, OUTPUT_DIR)
        
        return mask, confidence, strategy
        
    except Exception as e:
        print(f" Error in advanced mask selection for '{label_text}': {e}")
        # Fallback to basic SAM if advanced fails catastrophically
        try:
            print(f"  Falling back to basic bbox SAM for '{label_text}' due to error.")
            mask, confidence, strategy_msg = get_sam_mask_from_bbox(image_rgb, bbox_coords, predictor)
            return mask, confidence, f"fallback_basic_bbox ({strategy_msg})"
        except Exception as e2:
            print(f" Basic fallback SAM also failed for '{label_text}': {e2}")
            h_img, w_img = image_rgb.shape[:2]
            return np.zeros((h_img, w_img), dtype=bool), 0.0, "error_no_mask_generated"

print("✓ Keypoint, Arrow Detection, and Advanced Mask Selection functions defined.")


# ==============================================================================
# SECTION VI: YOLO ENHANCEMENT FUNCTIONS
# (Using YOLO for object detection, currently runs on CPU)
# ==============================================================================
print("\n----------------------------------------------------------------------")
print("SECTION VI: DEFINING YOLO ENHANCEMENT LOGIC")
print("----------------------------------------------------------------------")

# --- Heatmap Generation ---
def generate_useful_heatmap_analysis(image_rgb, bbox_coords, mask_array, confidence_score, label_text, image_id_str, base_output_dir):
    """
    Generates and SAVES a 4-panel analysis image: Original+BBox, Detected Mask, Mask Overlay, ROI.
    This version is for saving analysis, not immediate display.
    Args:
        base_output_dir: The main output directory for the script. Heatmaps go into a subfolder.
    """
    try:
        x, y, w, h = map(int, bbox_coords) # Ensure integer coordinates

        # Create a dedicated heatmap directory if it doesn't exist
        heatmap_output_dir = os.path.join(base_output_dir, "heatmaps_and_masks", str(image_id_str))
        os.makedirs(heatmap_output_dir, exist_ok=True)

        fig, axes = plt.subplots(2, 2, figsize=(12, 10)) # Slightly smaller for individual saves
        fig.suptitle(f'Mask Analysis: {image_id_str} - {label_text}\n(Conf: {confidence_score:.3f})', fontsize=14)

        # 1. Original image with BBox
        axes[0, 0].imshow(image_rgb)
        rect_patch = Rectangle((x, y), w, h, linewidth=2, edgecolor='red', facecolor='none')
        axes[0, 0].add_patch(rect_patch)
        axes[0, 0].set_title('Original Image + BBox')
        axes[0, 0].axis('off')

        # 2. Detected Mask
        axes[0, 1].imshow(mask_array, cmap='viridis') # 'viridis' is often good for masks
        axes[0, 1].set_title(f'Detected Mask')
        axes[0, 1].axis('off')

        # 3. Mask Overlay on image
        overlay_img = image_rgb.copy()
        # Ensure mask_array is boolean for indexing
        boolean_mask = mask_array.astype(bool)
        # Apply a distinct color to the mask region, e.g., semi-transparent red
        overlay_color = np.array([255, 0, 0], dtype=np.uint8) # Red
        # Blend where mask is true
        overlay_img[boolean_mask] = cv2.addWeighted(overlay_img[boolean_mask], 0.5, overlay_color, 0.5, 0)
        axes[1, 0].imshow(overlay_img)
        axes[1, 0].set_title('Mask Overlay')
        axes[1, 0].axis('off')

        # 4. Region of Interest (ROI) from original image
        # Clamp ROI coordinates to be within image bounds
        roi_x_end = min(x + w, image_rgb.shape[1])
        roi_y_end = min(y + h, image_rgb.shape[0])
        roi_x_start = max(0, x)
        roi_y_start = max(0, y)

        if roi_x_end > roi_x_start and roi_y_end > roi_y_start:
             roi_img = image_rgb[roi_y_start:roi_y_end, roi_x_start:roi_x_end]
             axes[1, 1].imshow(roi_img)
        else:
            axes[1,1].text(0.5, 0.5, "Invalid ROI dims", ha='center', va='center')
        axes[1, 1].set_title('Region of Interest')
        axes[1, 1].axis('off')
        
        plt.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust for suptitle
        
        # Sanitize label_text for filename
        safe_label_text = "".join(c if c.isalnum() else "_" for c in label_text)
        heatmap_filename = f"mask_analysis_{image_id_str}_{safe_label_text}.png"
        heatmap_save_path = os.path.join(heatmap_output_dir, heatmap_filename)
        
        plt.savefig(heatmap_save_path, dpi=100) # Lower DPI for individual diagnostic images
        plt.close(fig) # Close the figure to free memory and prevent display
        # print(f"  ✓ Saved mask analysis: {heatmap_save_path}")

    except Exception as e:
        print(f" Error generating 'useful heatmap analysis' for {label_text} on {image_id_str}: {e}")
        if 'fig' in locals() and fig is not None: # Ensure figure is closed on error
            plt.close(fig)

def calculate_bbox_iou(boxA, boxB):
    """
    Calculate Intersection over Union (IoU) between two bounding boxes.
    Assumes box format: [x1, y1, x2, y2] or can be adapted for (x,y,w,h)
    Let's assume input dicts like {'x': x, 'y': y, 'width': w, 'height': h} as used later.
    """
    # Convert to (x1, y1, x2, y2)
    b1_x1, b1_y1 = boxA['x'], boxA['y']
    b1_x2, b1_y2 = boxA['x'] + boxA['width'], boxA['y'] + boxA['height']
    
    b2_x1, b2_y1 = boxB['x'], boxB['y']
    b2_x2, b2_y2 = boxB['x'] + boxB['width'], boxB['y'] + boxB['height']

    # Determine the (x, y)-coordinates of the intersection rectangle
    x_A = max(b1_x1, b2_x1)
    y_A = max(b1_y1, b2_y1)
    x_B = min(b1_x2, b2_x2)
    y_B = min(b1_y2, b2_y2)

    # Compute the area of intersection rectangle
    interArea = max(0, x_B - x_A) * max(0, y_B - y_A)

    # Compute the area of both the prediction and ground-truth rectangles
    boxAArea = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
    boxBArea = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
    
    iou = interArea / float(boxAArea + boxBArea - interArea + 1e-6) # Add epsilon for stability
    return iou

def enhanced_yolo_detection(image_rgb, existing_sam_bboxes_list, confidence_threshold=0.25):
    """
    Performs object detection using YOLOv8n on CPU.
    Filters detections based on confidence and non-overlap with existing SAM bboxes.
    Args:
        image_rgb (np.ndarray): Input image.
        existing_sam_bboxes_list (list of dicts): List of bboxes from SAM, 
                                                  each dict like {'x': x, 'y': y, 'width': w, 'height': h, 'label': label}.
        confidence_threshold (float): Minimum confidence for YOLO detections.
    Returns:
        list: List of new bounding boxes detected by YOLO.
    """
    if not YOLO_AVAILABLE:
        print(" YOLO not available, skipping YOLO detection.")
        return []

    # clear_cuda_cache() # YOLO is on CPU, so less critical here, but good if other GPU ops preceded.
    
    yolo_model_name = 'yolov8n.pt' # Nano model, good for speed/CPU
    newly_detected_bboxes_by_yolo = []

    try:
        model = ultralytics.YOLO(yolo_model_name)
        print(f"✓ YOLO model '{yolo_model_name}' loaded for CPU detection.")
    except Exception as e:
        print(f" Failed to load YOLO model '{yolo_model_name}': {e}")
        return []

    try:
        # Perform detection on CPU as per your original script
        # Using with torch.no_grad() is good practice for any PyTorch inference
        with torch.no_grad():
            results = model(image_rgb, conf=confidence_threshold, verbose=False, device='cpu')
        
        if results and results[0].boxes is not None:
            yolo_boxes = results[0].boxes.xywh.cpu().numpy() # x_center, y_center, width, height
            yolo_confs = results[0].boxes.conf.cpu().numpy()
            yolo_classes = results[0].boxes.cls.cpu().numpy()
            class_names = results[0].names # Dictionary of class_id: class_name

            for i in range(len(yolo_boxes)):
                if yolo_confs[i] >= confidence_threshold:
                    xc, yc, w, h = yolo_boxes[i]
                    x1 = int(xc - w / 2)
                    y1 = int(yc - h / 2)
                    w = int(w)
                    h = int(h)
                    
                    yolo_bbox_dict = {
                        'x': x1, 'y': y1, 'width': w, 'height': h,
                        'confidence': float(yolo_confs[i]),
                        'class_id': int(yolo_classes[i]),
                        'label': class_names[int(yolo_classes[i])] # Get class name
                    }

                    # Check for significant overlap with any existing SAM bbox
                    is_new_detection = True
                    if existing_sam_bboxes_list: # Only check if there are SAM bboxes
                        for sam_bbox in existing_sam_bboxes_list:
                            # Ensure sam_bbox has width and height keys
                            if 'width' in sam_bbox and 'height' in sam_bbox :
                                if calculate_bbox_iou(yolo_bbox_dict, sam_bbox) > 0.5: # IoU threshold for "overlap"
                                    is_new_detection = False
                                    break
                            else:
                                print(f"Warning: SAM bbox missing width/height: {sam_bbox.get('label', 'Unknown Label')}")


                    if is_new_detection:
                        newly_detected_bboxes_by_yolo.append(yolo_bbox_dict)
                        # print(f"  ✓ YOLO found new object: {yolo_bbox_dict['label']} (Conf: {yolo_bbox_dict['confidence']:.2f})")
            if newly_detected_bboxes_by_yolo:
                 print(f"✓ YOLO detection complete. Found {len(newly_detected_bboxes_by_yolo)} new, non-overlapping objects.")

    except Exception as e:
        print(f" Error during YOLO detection process: {e}")
        import traceback
        traceback.print_exc()

    return newly_detected_bboxes_by_yolo


def refine_sam_data_with_yolo(image_id, image_rgb, current_sam_data_df):
    """
    Takes SAM detections (as a DataFrame) for an image, runs YOLO,
    and adds new, non-overlapping YOLO detections to the DataFrame.
    Args:
        image_id (str/int): The ID of the image.
        image_rgb (np.ndarray): The image.
        current_sam_data_df (pd.DataFrame): DataFrame of SAM detections for this image.
    Returns:
        pd.DataFrame: Updated DataFrame with SAM + new YOLO detections.
    """
    if not YOLO_AVAILABLE:
        # print(" YOLO not available, returning original SAM data.")
        return current_sam_data_df.copy() # Return a copy to avoid modifying original df outside

    # Convert SAM DataFrame rows to list of dicts for overlap checking
    existing_sam_bboxes_list = []
    if not current_sam_data_df.empty:
        for _, row in current_sam_data_df.iterrows():
            existing_sam_bboxes_list.append({
                'x': int(row['x']), 'y': int(row['y']),
                'width': int(row['width']), 'height': int(row['height']),
                'label': row['Label'] # For potential debugging
            })
    
    yolo_added_detections = enhanced_yolo_detection(image_rgb, existing_sam_bboxes_list)
    
    if not yolo_added_detections:
        return current_sam_data_df.copy()

    yolo_rows_to_add = []
    for det in yolo_added_detections:
        new_row = {
            'ImageID': image_id, # Ensure consistent ImageID
            'Label': f"YOLO_{det['label']}", # Prefix to distinguish from SAM labels
            'x': det['x'], 'y': det['y'],
            'width': det['width'], 'height': det['height'],
            # Optionally add YOLO confidence or class_id if your DataFrame schema supports it
            # 'confidence': det['confidence'] 
        }
        yolo_rows_to_add.append(new_row)
    
    if yolo_rows_to_add:
        yolo_df_to_add = pd.DataFrame(yolo_rows_to_add)
        updated_df = pd.concat([current_sam_data_df, yolo_df_to_add], ignore_index=True)
        print(f"✓ Added {len(yolo_added_detections)} new detections from YOLO to ImageID {image_id}.")
        return updated_df
    else:
        return current_sam_data_df.copy()

print("✓ YOLO Enhancement functions defined.")

# ==============================================================================
# SECTION VII: VISUALIZATION & REPORTING UTILITIES
# ==============================================================================
print("\n----------------------------------------------------------------------")
print("SECTION VII: DEFINING VISUALIZATION AND REPORTING UTILITIES")
print("----------------------------------------------------------------------")




# This is the 4-panel display function from "Headmaps Visualization" (original script)
# It is for *showing* a detailed heatmap, potentially interactively in a notebook.
def generate_comprehensive_heatmap_display(image_rgb, bbox_coords, mask_array, confidence_score, label_text, image_id_str, base_output_dir, show_plot=False):
    """
    Generates and optionally SHOWS/SAVES a comprehensive 4-panel heatmap (Original+BBox, Probability-like Heatmap, Overlay, 3D Surface).
    Note: The "probability-like heatmap" and "3D surface" are derived from confidence and bbox, not directly from SAM's internal heatmaps.
    """
    try:
        x, y, w, h = map(int, bbox_coords)
        heatmap_output_dir = os.path.join(base_output_dir, "detailed_heatmaps", str(image_id_str))
        os.makedirs(heatmap_output_dir, exist_ok=True)

        fig = plt.figure(figsize=(15, 12)) # Larger figure for detailed display
        fig.suptitle(f'Detailed Heatmap Analysis: {image_id_str} - {label_text}\n(Confidence: {confidence_score:.3f})', fontsize=16, fontweight='bold')

        # Panel 1: Original image with BBox
        ax1 = fig.add_subplot(2, 2, 1)
        ax1.imshow(image_rgb)
        rect1 = Rectangle((x, y), w, h, linewidth=3, edgecolor='red', facecolor='none')
        ax1.add_patch(rect1)
        ax1.set_title('Original Image with BBox', fontsize=12)
        ax1.axis('off')

        # Panel 2: "Probability" Heatmap (Gaussian-like based on bbox and confidence)
        ax2 = fig.add_subplot(2, 2, 2)
        # Create a simple heatmap centered on the bbox, scaled by confidence
        prob_heatmap = np.zeros((image_rgb.shape[0], image_rgb.shape[1]))
        center_x, center_y = x + w // 2, y + h // 2
        # Create a grid of coordinates
        yy, xx = np.mgrid[0:image_rgb.shape[0], 0:image_rgb.shape[1]]
        # Calculate squared distance from center, apply Gaussian-like decay
        # Sigma can be a fraction of bbox size
        sigma_x = max(1, w / 4.0)
        sigma_y = max(1, h / 4.0)
        dist_sq = ((xx - center_x) / sigma_x)**2 + ((yy - center_y) / sigma_y)**2
        prob_heatmap = confidence_score * np.exp(-dist_sq / 2.0)
        # Ensure it's zero outside a slightly larger bbox area for clarity
        mask_area = np.zeros_like(prob_heatmap, dtype=bool)
        ext_x, ext_y, ext_w, ext_h = x - w//2, y - h//2, w*2, h*2 # Extended area
        mask_area[max(0,ext_y):min(image_rgb.shape[0],ext_y+ext_h), max(0,ext_x):min(image_rgb.shape[1],ext_x+ext_w)] = True
        prob_heatmap[~mask_area] = 0


        im = ax2.imshow(prob_heatmap, cmap='hot', vmin=0, vmax=max(0.01, confidence_score)) # Ensure vmax is not 0
        ax2.set_title('Conceptual Probability Heatmap', fontsize=12)
        ax2.axis('off')
        cbar = plt.colorbar(im, ax=ax2, fraction=0.046, pad=0.04)
        cbar.set_label('Conceptual "Probability"', rotation=270, labelpad=15)

        # Panel 3: Heatmap Overlay
        ax3 = fig.add_subplot(2, 2, 3)
        overlay_img_detailed = image_rgb.astype(float) / 255.0 # Normalize for blending
        hot_cmap = plt.get_cmap('hot')
        # Blend the conceptual heatmap onto the image
        for i in range(3): # For R, G, B channels
            overlay_img_detailed[:,:,i] = overlay_img_detailed[:,:,i] * (1 - prob_heatmap*0.7) + hot_cmap(prob_heatmap)[:,:,i] * prob_heatmap*0.7
        overlay_img_detailed = np.clip(overlay_img_detailed, 0, 1)
        ax3.imshow(overlay_img_detailed)
        ax3.set_title('Heatmap Overlay (Conceptual)', fontsize=12)
        ax3.axis('off')

        # Panel 4: 3D Heatmap Surface (of the conceptual heatmap)
        ax4 = fig.add_subplot(2, 2, 4, projection='3d')
        # Create a meshgrid for the ROI to plot
        roi_x_coords = np.arange(max(0, x - w//2), min(image_rgb.shape[1], x + w + w//2))
        roi_y_coords = np.arange(max(0, y - h//2), min(image_rgb.shape[0], y + h + h//2))

        if len(roi_x_coords) > 1 and len(roi_y_coords) > 1:
            X_3d, Y_3d = np.meshgrid(roi_x_coords, roi_y_coords)
            # Z_3d is the conceptual heatmap values in this ROI
            Z_3d = prob_heatmap[Y_3d, X_3d] # Indexing: Y_3d gives row indices, X_3d column indices
            ax4.plot_surface(X_3d, Y_3d, Z_3d, cmap='hot', edgecolor='none', alpha=0.8)
            ax4.set_title('3D Heatmap Surface (Conceptual)', fontsize=10)
            ax4.set_xlabel('X', fontsize=8); ax4.set_ylabel('Y', fontsize=8); ax4.set_zlabel('Prob', fontsize=8)
            ax4.tick_params(axis='both', which='major', labelsize=6)
        else:
            ax4.text(0.5,0.5,0.5, "ROI too small for 3D plot", ha='center', va='center')


        plt.tight_layout(rect=[0, 0, 1, 0.95]) # Adjust for suptitle

        safe_label_text = "".join(c if c.isalnum() else "_" for c in label_text)
        detailed_heatmap_filename = f"detailed_heatmap_display_{image_id_str}_{safe_label_text}.png"
        detailed_heatmap_save_path = os.path.join(heatmap_output_dir, detailed_heatmap_filename)
        plt.savefig(detailed_heatmap_save_path, dpi=150) # Higher DPI for this one
        
        if show_plot:
            plt.show() # Display if requested
        
        plt.close(fig) # Always close to free memory
        print(f"  ✓ Saved detailed heatmap display: {detailed_heatmap_save_path}")

    except Exception as e:
        print(f" Error generating comprehensive heatmap display for {label_text} on {image_id_str}: {e}")
        if 'fig' in locals() and fig is not None:
            plt.close(fig)


# --- Main Result Visualization ---
def create_legend_for_main_plot(label_colors_dict, ax_to_add_legend):
    """Creates a legend for anatomical labels on the main visualization plot."""
    legend_elements = []
    sorted_labels = sorted(label_colors_dict.keys()) # Sort for consistent legend order

    for label in sorted_labels:
        color_hex = label_colors_dict[label]
        color_rgb_mpl = hex_to_rgb(color_hex) # For Matplotlib (0-1 range)
        legend_elements.append(
            Rectangle((0, 0), 1, 1, facecolor=color_rgb_mpl, edgecolor='black', linewidth=0.5, label=label)
        )
    
    if legend_elements:
        ax_to_add_legend.legend(
            handles=legend_elements, loc='upper left', bbox_to_anchor=(1.01, 1.0), # Place outside plot area
            fontsize=8, frameon=True, fancybox=True, shadow=False, title="Detected Structures", title_fontsize=10
        )

def create_clean_visualization_with_labels(image_rgb_original, sam_data_df_for_image, label_colors_map, confidence_scores_map):
    """
    Creates a clean visualization with bounding boxes and labels (including confidence).
    Returns an image with annotations drawn using OpenCV.
    """
    display_image_cv = image_rgb_original.copy() # Work on a copy
    
    # Group by label to handle multiple bboxes for the same structure if any (e.g. from YOLO vs SAM)
    # However, the logic in process_medical_analysis_ultimate seems to consolidate this to one best per label.
    # This function will draw all bboxes present in sam_data_df_for_image for the given labels.
    
    if sam_data_df_for_image.empty:
        return display_image_cv # Return original if no data

    unique_labels_in_data = sam_data_df_for_image['Label'].unique()

    for original_label_from_df in unique_labels_in_data:
        # The label in label_colors_map and confidence_scores_map might be the 'cleaned_label'
        # We need to ensure we map correctly. Assuming the sam_data_df_for_image already has the final labels
        # that match keys in label_colors_map and confidence_scores_map.
        # If not, we might need to clean original_label_from_df here too. For now, assume direct match.
        
        label_to_use = original_label_from_df # This should be the key for colors/scores

        if label_to_use in label_colors_map and label_to_use in confidence_scores_map:
            color_hex = label_colors_map[label_to_use]
            # Convert hex to BGR for OpenCV (matplotlib uses RGB 0-1, OpenCV BGR 0-255)
            color_bgr_cv = [int(c * 255) for c in reversed(hex_to_rgb(color_hex))] # Reversed for BGR
            confidence = confidence_scores_map[label_to_use]

            # Get all bboxes for this label from the DataFrame
            bboxes_for_label_df = sam_data_df_for_image[sam_data_df_for_image['Label'] == original_label_from_df]

            for _, row_bbox in bboxes_for_label_df.iterrows():
                x, y, w, h = int(row_bbox['x']), int(row_bbox['y']), int(row_bbox['width']), int(row_bbox['height'])

                # Clamp bbox to image dimensions before drawing to avoid OpenCV errors
                img_h_cv, img_w_cv = display_image_cv.shape[:2]
                x1_cv = max(0, x)
                y1_cv = max(0, y)
                x2_cv = min(img_w_cv -1 , x + w)
                y2_cv = min(img_h_cv -1 , y + h)

                if x2_cv <= x1_cv or y2_cv <= y1_cv: continue # Skip if bbox is invalid after clamping


                # Draw rectangle
                cv2.rectangle(display_image_cv, (x1_cv, y1_cv), (x2_cv, y2_cv), color_bgr_cv, 2) # Thickness 2

                # Prepare text for label
                label_text_cv = f"{label_to_use}"
                font_face = cv2.FONT_HERSHEY_SIMPLEX
                font_scale = 0.45
                font_thickness = 1

                (text_w, text_h), baseline = cv2.getTextSize(label_text_cv, font_face, font_scale, font_thickness)
                
                # Position for text background and text
                text_bg_y1 = y1_cv - text_h - 10 if y1_cv > text_h + 15 else y2_cv + baseline + 5
                text_bg_y2 = text_bg_y1 + text_h + baseline + 5
                text_x_cv = x1_cv
                text_y_cv = text_bg_y1 + text_h + baseline // 2 # Centered vertically in bg

                # Draw text background rectangle (slightly transparent if possible, or solid)
                # For solid background:
                cv2.rectangle(display_image_cv, (text_x_cv - 2, text_bg_y1 -2), (text_x_cv + text_w + 2, text_bg_y2 +2 ), color_bgr_cv, -1)
                # Draw text
                cv2.putText(display_image_cv, label_text_cv, (text_x_cv, text_y_cv),
                            font_face, font_scale, (255, 255, 255), font_thickness, cv2.LINE_AA) # White text
    return display_image_cv


def create_improved_caption_subplot(fig_main, caption_text_main, label_colors_for_caption, grid_spec_rows=12):
    """
    Creates a dedicated subplot on the main figure for displaying an explainable caption
    and detected medical structures.
    Args:
        fig_main: The main matplotlib Figure object.
        caption_text_main: The primary caption string.
        label_colors_for_caption: Dict of {label_name: hex_color}.
        grid_spec_rows: Total rows in the GridSpec for positioning.
    """
    # This assumes fig_main uses a GridSpec. If not, plt.subplot or fig.add_axes might be used.
    # Example positioning: occupies bottom part of a 12-row grid.
    # caption_ax = fig_main.add_subplot(grid_spec_rows, 1, (grid_spec_rows - 3, grid_spec_rows)) # Last 3 rows
    # Using add_axes for more direct control relative to figure [left, bottom, width, height]
    caption_ax = fig_main.add_axes([0.05, 0.02, 0.9, 0.15]) # Adjust these ratios as needed
    caption_ax.axis('off')

    # Title for the caption section
    caption_ax.text(0.01, 0.90, "Explainable Medical Image Report", transform=caption_ax.transAxes,
                    fontsize=14, weight='bold', color='navy')

    # Main Caption Text (wrapped)
    wrapped_caption = textwrap.fill(caption_text_main if caption_text_main else "No caption available.", width=120) # Adjust width
    caption_ax.text(0.01, 0.45, wrapped_caption, transform=caption_ax.transAxes, fontsize=10,
                    verticalalignment='top', horizontalalignment='left',
                    bbox=dict(boxstyle='round,pad=0.5', facecolor='aliceblue', alpha=0.7, edgecolor='lightgrey'),
                    wrap=True)

    # Detected Medical Structures (if any)
    if label_colors_for_caption:
        struct_text_y_start = 0.25 # Position below caption
        max_struct_per_line = 4
        current_struct_count = 0
        line_height = 0.08 # Approximate height for each structure line text

        struct_str_list = []
        for i, (structure_name, hex_color) in enumerate(label_colors_for_caption.items()):
            # Matplotlib text doesn't directly support rich text like HTML for individual word colors.
            # So, we list them, possibly with a colored square marker if using a legend approach.
            # Here, we just list names. The main plot will have colored bboxes.
            struct_str_list.append(structure_name)
            if (i + 1) % max_struct_per_line == 0 :
                struct_str_list.append("\n") # Newline

        detected_structures_text = "Detected Structures: " + ", ".join(s for s in struct_str_list if s != "\n")
        detected_structures_text = detected_structures_text.replace(", \n, ", "\n") # Clean up newlines

        caption_ax.text(0.01, struct_text_y_start, textwrap.fill(detected_structures_text,width=70),
                        transform=caption_ax.transAxes, fontsize=9, weight='normal', color='darkgreen',
                        verticalalignment='top')
                        
    # Pipeline Info
    pipeline_info = "Processing: Input Data -> SAM/YOLO -> Advanced Masking -> Visualization"
    caption_ax.text(0.01, 0.05, pipeline_info, transform=caption_ax.transAxes,
                    fontsize=8, style='italic', color='grey')


# --- Data Table for Reports ---
def create_image_data_table_df(image_id_str, concepts_data_df, caption_file_data_df, explanations_data_df_main):
    """
    Creates a Pandas DataFrame summarizing data from various sources for a specific image.
    Args:
        image_id_str: The image ID.
        concepts_data_df: DataFrame from CONCEPTS_PATH.
        caption_file_data_df: DataFrame from CAPTION_FILE_PATH.
        explanations_data_df_main: DataFrame from CAPTIONS_PATH (used for explanations/main caption).
    """
    image_data_summary = { 'Data Source': [], 'Content Summary': [], 'Details': [] }
    id_cols_to_check = ['ImageID', 'ID', 'id', 'image_id'] # Common ID column names

    def find_matching_row(df, id_val, id_columns_list):
        if df is None or df.empty: return None
        for col_name in id_columns_list:
            if col_name in df.columns:
                # Ensure type consistency for comparison if image_id_str is numeric string
                try:
                    df_id_col_type = df[col_name].dtype
                    if pd.api.types.is_numeric_dtype(df_id_col_type) and isinstance(id_val, str) and id_val.isnumeric():
                        id_val_comp = int(id_val)
                    elif pd.api.types.is_string_dtype(df_id_col_type) and not isinstance(id_val, str):
                         id_val_comp = str(id_val)
                    else:
                        id_val_comp = id_val

                    match = df[df[col_name] == id_val_comp]
                    if not match.empty: return match.iloc[0] # Return first match
                except TypeError: # In case of comparison error
                    continue # Try next column
        return None

    # A) Concepts Data
    concept_row = find_matching_row(concepts_data_df, image_id_str, id_cols_to_check)
    if concept_row is not None:
        # Try to find columns named 'concept' or starting with 'C'
        concept_cols = [col for col in concept_row.index if 'concept' in col.lower() or col.startswith('C')]
        concept_details = [f"{col}: {concept_row[col]}" for col in concept_cols if pd.notna(concept_row[col])]
        if not concept_details: # Fallback to first few non-ID columns
            non_id_cols = [c for c in concept_row.index if c not in id_cols_to_check][:3]
            concept_details = [f"{col}: {concept_row[col]}" for col in non_id_cols if pd.notna(concept_row[col])]
        image_data_summary['Data Source'].append('Medical Concepts')
        image_data_summary['Content Summary'].append(', '.join([str(concept_row[col]) for col in concept_cols if pd.notna(concept_row[col])][:3]))
        image_data_summary['Details'].append(' | '.join(concept_details) if concept_details else "No specific concept details found.")
    else:
        image_data_summary['Data Source'].append('Medical Concepts')
        image_data_summary['Content Summary'].append('N/A')
        image_data_summary['Details'].append('Image ID not found in concepts data or concepts data unavailable.')

    # B) Captions from caption_file_df (CAPTION_FILE_PATH)
    caption_row = find_matching_row(caption_file_data_df, image_id_str, id_cols_to_check)
    if caption_row is not None:
        caption_text_cols = [col for col in caption_row.index if 'caption' in col.lower()]
        caption_text = caption_row[caption_text_cols[0]] if caption_text_cols and pd.notna(caption_row[caption_text_cols[0]]) else "No caption text."
        image_data_summary['Data Source'].append('Image Caption (File)')
        image_data_summary['Content Summary'].append(textwrap.shorten(str(caption_text), width=50))
        image_data_summary['Details'].append(str(caption_text))
    else:
        image_data_summary['Data Source'].append('Image Caption (File)')
        image_data_summary['Content Summary'].append('N/A')
        image_data_summary['Details'].append('Image ID not found in caption file data or data unavailable.')

    # C) Explanations/Main Caption from explanations_data_df_main (CAPTIONS_PATH)
    explanation_row = find_matching_row(explanations_data_df_main, image_id_str, id_cols_to_check)
    if explanation_row is not None:
        # Try to find columns related to explanation or text
        expl_text_cols = [col for col in explanation_row.index if any(k in col.lower() for k in ['explanation', 'caption', 'text', 'description'])]
        expl_text = explanation_row[expl_text_cols[0]] if expl_text_cols and pd.notna(explanation_row[expl_text_cols[0]]) else "No explanation text."
        image_data_summary['Data Source'].append('Medical Explanation/Report')
        image_data_summary['Content Summary'].append(textwrap.shorten(str(expl_text), width=50))
        image_data_summary['Details'].append(str(expl_text))
    else:
        image_data_summary['Data Source'].append('Medical Explanation/Report')
        image_data_summary['Content Summary'].append('N/A')
        image_data_summary['Details'].append('Image ID not found in explanations data or data unavailable.')
        
    return pd.DataFrame(image_data_summary)

def display_formatted_data_table(image_id_str, data_table_df):
    """Prints a formatted version of the image data summary table."""
    print(f"\n--- Data Summary Table for Image: {image_id_str} ---")
    if data_table_df.empty:
        print("No data to display.")
        return
        
    with pd.option_context('display.max_colwidth', 70, 'display.width', 120, 'display.colheader_justify', 'left'):
        # Using to_string for better control in script output than just print(df)
        print(data_table_df.to_string(index=False, line_width=120, formatters={
            'Content Summary': lambda x: textwrap.fill(x, width=40),
            'Details': lambda x: textwrap.fill(x, width=60)
        }))
    print("--------------------------------------------------")


# --- Comprehensive Reporting ---
# These are the de-duplicated versions from your original PART 9

def export_analysis_metrics_to_csv(all_images_analysis_stats_dict, base_output_dir):
    """Exports detailed confidence scores, image summaries, and strategy performance to CSV files."""
    if not all_images_analysis_stats_dict:
        print(" No analysis stats provided to export_analysis_metrics_to_csv.")
        return {}

    metrics_output_dir = os.path.join(base_output_dir, 'analysis_metrics')
    os.makedirs(metrics_output_dir, exist_ok=True)
    print(f"--- Exporting Analysis Metrics to: {metrics_output_dir} ---")

    # 1. Detailed confidence scores per structure per image
    confidence_records_list = []
    for img_id, stats_per_image in all_images_analysis_stats_dict.items():
        if 'confidence_scores' in stats_per_image and 'processing_details' in stats_per_image:
            for structure_name, conf_score in stats_per_image['confidence_scores'].items():
                confidence_records_list.append({
                    'image_id': img_id,
                    'structure': structure_name,
                    'confidence_score': conf_score,
                    'confidence_category': 'High' if conf_score >= 0.8 else 'Medium' if conf_score >= 0.5 else 'Low',
                    'processing_strategy': stats_per_image['processing_details'].get(structure_name, 'unknown_strategy')
                })
    
    confidence_df = pd.DataFrame(confidence_records_list)
    confidence_csv_path = os.path.join(metrics_output_dir, 'detailed_confidence_scores.csv')
    if not confidence_df.empty:
        confidence_df.to_csv(confidence_csv_path, index=False)
        print(f"✓ Detailed confidence scores exported: {confidence_csv_path}")
    else:
        print(" No detailed confidence records to export.")


    # 2. Summary statistics per image
    image_summary_records_list = []
    for img_id, stats_per_image in all_images_analysis_stats_dict.items():
        yolo_stats = stats_per_image.get('yolo_stats', {'yolo_additions': 0}) # Handle if yolo_stats is missing
        image_summary_records_list.append({
            'image_id': img_id,
            'total_structures': stats_per_image.get('structures_count', 0),
            'total_bboxes_processed': stats_per_image.get('bbox_count', 0),
            'avg_confidence': stats_per_image.get('avg_confidence', 0.0),
            'yolo_enhancements_added': yolo_stats.get('yolo_additions',0),
            'high_confidence_count': len([c for c in stats_per_image.get('confidence_scores', {}).values() if c >= 0.8]),
            'medium_confidence_count': len([c for c in stats_per_image.get('confidence_scores', {}).values() if 0.5 <= c < 0.8]),
            'low_confidence_count': len([c for c in stats_per_image.get('confidence_scores', {}).values() if c < 0.5]),
            'max_confidence': max(stats_per_image.get('confidence_scores', {0:0}).values()) if stats_per_image.get('confidence_scores') else 0.0,
            'min_confidence': min(stats_per_image.get('confidence_scores', {0:0}).values()) if stats_per_image.get('confidence_scores') else 0.0
        })

    image_summary_df = pd.DataFrame(image_summary_records_list)
    summary_csv_path = os.path.join(metrics_output_dir, 'image_level_analysis_summary.csv')
    if not image_summary_df.empty:
        image_summary_df.to_csv(summary_csv_path, index=False)
        print(f"✓ Image-level analysis summary exported: {summary_csv_path}")
    else:
        print(" No image summary records to export.")


    # 3. Performance of different processing strategies
    strategy_performance_stats = defaultdict(lambda: {'count': 0, 'confidences': []})
    for img_id, stats_per_image in all_images_analysis_stats_dict.items():
        if 'processing_details' in stats_per_image and 'confidence_scores' in stats_per_image:
            for structure_name, strategy_name in stats_per_image['processing_details'].items():
                conf = stats_per_image['confidence_scores'].get(structure_name, 0.0)
                strategy_performance_stats[strategy_name]['count'] += 1
                strategy_performance_stats[strategy_name]['confidences'].append(conf)
    
    strategy_summary_list = []
    for strategy_name, data in strategy_performance_stats.items():
        confs = data['confidences']
        if confs:
            strategy_summary_list.append({
                'strategy_name': strategy_name,
                'usage_count': data['count'],
                'avg_confidence': np.mean(confs) if confs else 0.0,
                'std_dev_confidence': np.std(confs) if confs else 0.0,
                'success_rate_gt_0.6': len([c for c in confs if c > 0.6]) / len(confs) if confs else 0.0
            })
            
    strategy_df = pd.DataFrame(strategy_summary_list)
    strategy_csv_path = os.path.join(metrics_output_dir, 'processing_strategy_performance.csv')
    if not strategy_df.empty:
        strategy_df.to_csv(strategy_csv_path, index=False)
        print(f"✓ Strategy performance summary exported: {strategy_csv_path}")
    else:
        print(" No strategy performance data to export.")

    return {
        'confidence_data_df': confidence_df, 'summary_data_df': image_summary_df, 'strategy_data_df': strategy_df,
        'confidence_csv_path': confidence_csv_path, 'summary_csv_path': summary_csv_path, 'strategy_csv_path': strategy_csv_path
    }


def generate_overall_final_report_text(all_images_stats_dict, exported_data_dfs):
    """Generates a text-based comprehensive final report summarizing the entire analysis run."""
    print("\n======================================================================")
    print(" COMPREHENSIVE FINAL ANALYSIS REPORT (OVERALL)")
    print("======================================================================")

    if not all_images_stats_dict:
        print("No analysis data available to generate a final report.")
        return {}

    total_images_processed = len(all_images_stats_dict)
    total_structures_identified = sum(stats.get('structures_count', 0) for stats in all_images_stats_dict.values())
    total_bboxes_analyzed = sum(stats.get('bbox_count', 0) for stats in all_images_stats_dict.values())
    
    all_conf_scores_flat = []
    for stats in all_images_stats_dict.values():
        if 'confidence_scores' in stats:
            all_conf_scores_flat.extend(stats['confidence_scores'].values())
    
    avg_overall_confidence = np.mean(all_conf_scores_flat) if all_conf_scores_flat else 0.0
    num_high_conf = len([c for c in all_conf_scores_flat if c >= 0.8])
    num_med_conf = len([c for c in all_conf_scores_flat if 0.5 <= c < 0.8])
    num_low_conf = len([c for c in all_conf_scores_flat if c < 0.5])

    total_yolo_added = sum(stats.get('yolo_stats',{}).get('yolo_additions',0) for stats in all_images_stats_dict.values())

    print(f"\n --- OVERALL STATISTICS ---")
    print(f"  Total Images Processed: {total_images_processed}")
    print(f"  Total Medical Structures Identified (sum over images): {total_structures_identified}")
    print(f"  Total Bounding Boxes Analyzed (sum over images): {total_bboxes_analyzed}")
    if total_images_processed > 0:
        print(f"  Avg. Structures per Image: {total_structures_identified / total_images_processed:.1f}")
    if YOLO_AVAILABLE:
         print(f"  Total YOLO Detections Added Across All Images: {total_yolo_added}")


    print(f"\n --- CONFIDENCE ANALYSIS (ACROSS ALL STRUCTURES) ---")
    if all_conf_scores_flat:
        print(f"  Overall Average Confidence: {avg_overall_confidence:.3f}")
        print(f"  Confidence Range: {np.min(all_conf_scores_flat):.3f} - {np.max(all_conf_scores_flat):.3f}")
        print(f"  Confidence Std. Deviation: {np.std(all_conf_scores_flat):.3f}")
        print(f"\n   Confidence Distribution:")
        total_conf_scores = len(all_conf_scores_flat)
        print(f"    High Confidence (≥0.8): {num_high_conf} ({num_high_conf / total_conf_scores * 100:.1f}%)")
        print(f"    Medium Confidence (0.5-0.8): {num_med_conf} ({num_med_conf / total_conf_scores * 100:.1f}%)")
        print(f"    Low Confidence (<0.5): {num_low_conf} ({num_low_conf / total_conf_scores * 100:.1f}%)")
    else:
        print("  No confidence scores recorded.")

    summary_df = exported_data_dfs.get('summary_data_df', pd.DataFrame())
    if not summary_df.empty:
        print(f"\n --- TOP PERFORMING IMAGES (BY AVG. CONFIDENCE) ---")
        top_images_df = summary_df.nlargest(min(5, len(summary_df)), 'avg_confidence')
        for _, row in top_images_df.iterrows():
            print(f"  - Image '{row['image_id']}': Avg. Conf {row['avg_confidence']:.3f} ({row['total_structures']} structures)")

        print(f"\n --- IMAGES POTENTIALLY NEEDING REVIEW (LOWEST AVG. CONFIDENCE) ---")
        # Define "needing review" threshold, e.g., avg_confidence < 0.6
        attention_images_df = summary_df[summary_df['avg_confidence'] < 0.6].nsmallest(min(5, len(summary_df)), 'avg_confidence')
        if not attention_images_df.empty:
            for _, row in attention_images_df.iterrows():
                print(f"  - Image '{row['image_id']}': Avg. Conf {row['avg_confidence']:.3f} ({row['low_confidence_count']} low-conf structures)")
        else:
            print("  All processed images have average confidence ≥ 0.6 or no images processed.")
    
    strategy_df = exported_data_dfs.get('strategy_data_df', pd.DataFrame())
    if not strategy_df.empty:
        print(f"\n --- PROCESSING STRATEGY PERFORMANCE ---")
        sorted_strategy_df = strategy_df.sort_values('avg_confidence', ascending=False)
        for _, row in sorted_strategy_df.iterrows():
            print(f"  - Strategy '{row['strategy_name']}': Used {row['usage_count']} times, Avg. Conf {row['avg_confidence']:.3f}, Success (>0.6) {row['success_rate_gt_0.6']:.1%}")

    print(f"\n --- RECOMMENDATIONS & NOTES ---")
    if avg_overall_confidence >= 0.75:
        print(f"   Overall performance appears strong with an average confidence of {avg_overall_confidence:.3f}.")
    elif avg_overall_confidence >= 0.5:
        print(f"   Overall performance is moderate. Review strategies for low-confidence detections.")
    else:
        print(f"   Overall performance may need improvement. Focus on enhancing low-confidence scenarios and data quality.")
    
    if total_yolo_added > 0 and YOLO_AVAILABLE:
        print(f"   YOLO enhancement contributed {total_yolo_added} additional detections. Consider fine-tuning YOLO on specific medical entities if not already done.")
    elif YOLO_AVAILABLE:
        print(f"  ℹ YOLO was available but did not add new detections, or was not triggered significantly. Review YOLO confidence thresholds and overlap criteria.")

    print(f"\n Output files (CSVs, heatmaps, visualizations) are saved in subdirectories within: {OUTPUT_DIR}")
    print("======================================================================")

    return { # Return key overall metrics for potential further use
        'total_images_processed': total_images_processed, 'avg_overall_confidence': avg_overall_confidence,
        'num_high_conf': num_high_conf, 'num_med_conf': num_med_conf, 'num_low_conf': num_low_conf,
        'total_yolo_added': total_yolo_added
    }

print("✓ Visualization and Reporting utilities defined.")

# ==============================================================================
# SECTION VIII: MAIN ANALYSIS FUNCTION PER IMAGE
# (This is the refined version of your original `process_medical_analysis_ultimate` from PART 10)
# ==============================================================================
print("\n----------------------------------------------------------------------")
print("SECTION VIII: DEFINING MAIN PER-IMAGE ANALYSIS LOGIC")
print("----------------------------------------------------------------------")

def process_single_image_analysis(
    image_id_str,
    sam_data_for_image_df,   # DataFrame of initial SAM coordinates for this specific image
    explanations_data_for_image_df, # DataFrame containing the main caption/explanation for this image
    concepts_full_df,        # Full DataFrame for all concepts
    caption_file_full_df,    # Full DataFrame for all captions (from caption.csv)
    sam_predictor_instance,
    base_output_dir
):
    """
    Processes a single medical image: loads, enhances with YOLO, applies advanced SAM,
    generates visualizations, and collects statistics.
    """
    print(f"\n Starting Analysis for Image ID: {image_id_str} ")
    
    # 0. Initial clear_cuda_cache before loading image and model processing for this image
    # clear_cuda_cache() # This is now handled in the main loop BEFORE calling this function.

    # 1. Load Image
    image_path = find_image_file(str(image_id_str), IMAGES_DIR) # IMAGES_DIR is global config
    if not image_path:
        print(f" Image file not found for ID {image_id_str} in {IMAGES_DIR}. Skipping this image.")
        return None, {} # Return None for path, empty dict for stats

    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        print(f" Could not load image from path: {image_path}. Skipping this image.")
        return None, {}

    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    img_height, img_width = image_rgb.shape[:2]
    print(f"✓ Image '{image_id_str}' loaded ({img_width}x{img_height}).")

    # 2. Initial SAM Data & YOLO Enhancement
    # `sam_data_for_image_df` are the initial bboxes for this image
    # `refine_sam_data_with_yolo` adds new, non-overlapping YOLO detections.
    print(f"  Initial SAM bboxes for image {image_id_str}: {len(sam_data_for_image_df)}")
    enhanced_detections_df = refine_sam_data_with_yolo(image_id_str, image_rgb, sam_data_for_image_df)
    yolo_additions_count = len(enhanced_detections_df) - len(sam_data_for_image_df)
    print(f"  Total detections after YOLO for image {image_id_str}: {len(enhanced_detections_df)} ({yolo_additions_count} added by YOLO).")


    # 3. Process each detected structure/label with advanced SAM
    # `grouped_data` will iterate over unique labels in the `enhanced_detections_df`
    # For each label, it will take the bbox(es) and apply advanced SAM.
    
    processed_label_colors = {} # Stores {final_label_name: color_hex}
    final_confidence_scores = {} # Stores {final_label_name: confidence_float}
    final_processing_strategies = {} # Stores {final_label_name: strategy_string}
    # DataFrame to store the final bboxes that are kept after processing all strategies
    # This will be used for the clean visualization.
    # It might contain multiple bboxes if YOLO added some that SAM didn't refine, or if a label has multiple distinct instances.
    # However, the current `select_best_mask_for_label_advanced` processes one bbox at a time.
    # The visualization `create_clean_visualization_with_labels` will draw based on the labels in `processed_label_colors`.
    # Let's build a list of dicts for the final chosen bboxes to create a new DataFrame for visualization.
    final_bboxes_for_visualization_list = []


    total_initial_bboxes_to_process = len(enhanced_detections_df)
    processed_bbox_count_for_stats = 0 # Count of bboxes actually processed by SAM

    # Group by 'Label' column which contains original SAM labels and YOLO prefixed labels
    if not enhanced_detections_df.empty:
        grouped_initial_detections = enhanced_detections_df.groupby('Label')
        
        for original_label_from_df, group_df in grouped_initial_detections:
            print(f"\n  Processing label group: '{original_label_from_df}' ({len(group_df)} initial bbox(es))")
            
            # For each original label, we might have one or more bboxes (e.g. if SAM generated multiple for a concept)
            # We should pick the "best" initial bbox from the group if there are multiple, or process all.
            # The current structure of your script seems to imply processing each bbox row from the input CSVs.
            # Let's assume we process each row in `group_df` if they are distinct enough,
            # or select one representative one. For simplicity and to match `select_best_mask_for_label_advanced`
            # which takes one bbox, let's process the first bbox in the group as representative for this label.
            # A more advanced approach might iterate all, or merge/NMS bboxes for the same label.
            
            # Take the first row of the group as the representative for this label's initial bbox
            representative_row = group_df.iloc[0]
            initial_bbox_coords = (
                int(representative_row['x']), int(representative_row['y']),
                int(representative_row['width']), int(representative_row['height'])
            )
            processed_bbox_count_for_stats += 1 # Count this as one bbox processed by SAM

            # Clean and correct the label name
            cleaned_label_base = clean_label_text(original_label_from_df)
            # Directional correction needs the *initial* bbox for context
            final_label_name = correct_directional_labels(cleaned_label_base, initial_bbox_coords, img_width)
            
            print(f"    Cleaned/Corrected Label: '{final_label_name}' from BBox: {initial_bbox_coords}")

            # Assign color if new label
            if final_label_name not in processed_label_colors:
                color_idx = len(processed_label_colors) % len(COLORS) # COLORS is global
                processed_label_colors[final_label_name] = COLORS[color_idx]

            # Apply advanced SAM processing for this label and its representative bbox
            mask_array, confidence, strategy = select_best_mask_for_label_advanced(
                image_rgb, initial_bbox_coords, sam_predictor_instance, final_label_name,
                image_id_str, export_heatmaps_flag=True # Heatmaps (mask analysis) saved by this function
            )
            
            final_confidence_scores[final_label_name] = confidence
            final_processing_strategies[final_label_name] = strategy
            
            # Store this chosen bbox and its final label for the main visualization
            # Note: `initial_bbox_coords` is used here as the reference for drawing.
            # SAM provides a mask; the bbox is the prompt.
            final_bboxes_for_visualization_list.append({
                'ImageID': image_id_str,
                'Label': final_label_name, # Use the corrected label
                'x': initial_bbox_coords[0], 'y': initial_bbox_coords[1],
                'width': initial_bbox_coords[2], 'height': initial_bbox_coords[3],
                'confidence_sam': confidence, # Store SAM's confidence for this mask
                'strategy': strategy
            })

            print(f"    ✓ Result for '{final_label_name}': Strategy='{strategy}', Confidence={confidence:.3f}")

            # Optionally, generate the very detailed 4-panel conceptual heatmap display
            # This might be too much for every structure. Perhaps only for high-confidence ones or for debugging.
            if confidence > 0.75: # Example: generate for high confidence results
                 generate_comprehensive_heatmap_display(image_rgb, initial_bbox_coords, mask_array, confidence, final_label_name, image_id_str, base_output_dir, show_plot=False)


    final_bboxes_for_visualization_df = pd.DataFrame(final_bboxes_for_visualization_list)

    # 4. Create Main Visualization (Matplotlib figure with OpenCV drawn image and caption)
    print(f"\n  Generating main visualization for {image_id_str}...")
    # The figure size needs to accommodate the image and the caption area
    # Aspect ratio of image:
    fig_aspect_ratio = img_width / img_height
    fig_width_mpl = 12 # Base width in inches for Matplotlib figure
    fig_height_mpl_image_area = fig_width_mpl / fig_aspect_ratio
    fig_height_mpl_caption_area = 3 # Inches for caption area
    total_fig_height_mpl = fig_height_mpl_image_area + fig_height_mpl_caption_area

    main_fig = plt.figure(figsize=(fig_width_mpl, total_fig_height_mpl), constrained_layout=False) # Disable constrained_layout for add_axes
    
    # Create axes for the image using add_axes [left, bottom, width, height]
    image_ax_height_ratio = fig_height_mpl_image_area / total_fig_height_mpl
    image_ax = main_fig.add_axes([0.05, 1 - image_ax_height_ratio - 0.02, 0.9, image_ax_height_ratio]) # Top part for image


    # Use OpenCV to draw bboxes and labels on the image
    annotated_image_cv = create_clean_visualization_with_labels(
        image_rgb, final_bboxes_for_visualization_df, processed_label_colors, final_confidence_scores
    )
    image_ax.imshow(annotated_image_cv)
    image_ax.axis('off')
    image_ax.set_title(f'Medical Image Analysis: {image_id_str}', fontsize=16, fontweight='bold', pad=10)
    
    # Add legend to the image axes
    create_legend_for_main_plot(processed_label_colors, image_ax)

    # Extract main caption/explanation for this image
    main_caption_text = f"Report for Image ID: {image_id_str}." # Default caption
    if not explanations_data_for_image_df.empty:
        # Try to find a 'caption' or 'explanation' column
        cap_col_names = [col for col in explanations_data_for_image_df.columns if any(k in col.lower() for k in ['caption', 'explanation', 'text'])]
        if cap_col_names:
            main_caption_text = explanations_data_for_image_df.iloc[0][cap_col_names[0]]
            if pd.isna(main_caption_text): main_caption_text = f"No caption found for {image_id_str} in provided data."
        else:
            main_caption_text = f"Caption/explanation column not found for {image_id_str}."
    else:
        main_caption_text = f"No explanation/caption data provided for {image_id_str}."


    # Add improved caption subplot to the main figure
    create_improved_caption_subplot(main_fig, main_caption_text, processed_label_colors)

    # Save the main visualization
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") # datetime needs to be imported
    main_viz_filename = f"{image_id_str}_main_analysis_{timestamp}.png"
    main_viz_path = os.path.join(base_output_dir, "main_visualizations")
    os.makedirs(main_viz_path, exist_ok=True)
    full_save_path = os.path.join(main_viz_path, main_viz_filename)
    
    try:
        main_fig.savefig(full_save_path, dpi=150, bbox_inches='tight') # Adjust DPI as needed
        print(f"✓ Main visualization saved: {full_save_path}")
    except Exception as e_save:
        print(f" Error saving main visualization for {image_id_str}: {e_save}")
    plt.close(main_fig) # IMPORTANT: Close the figure to free memory


    # 5. Generate and Display Data Table (console output)
    data_summary_table_df = create_image_data_table_df(image_id_str, concepts_full_df, caption_file_full_df, explanations_data_for_image_df)
    display_formatted_data_table(image_id_str, data_summary_table_df)
    
    # 6. Compile statistics for this image
    avg_conf_this_image = np.mean(list(final_confidence_scores.values())) if final_confidence_scores else 0.0
    
    image_analysis_stats = {
        'image_id': image_id_str,
        'structures_count': len(final_confidence_scores),
        'bbox_count': processed_bbox_count_for_stats, # Bboxes actually processed by SAM
        'avg_confidence': avg_conf_this_image,
        'confidence_scores': final_confidence_scores,    # Dict {label: score}
        'processing_details': final_processing_strategies, # Dict {label: strategy}
        'yolo_stats': {
            'original_sam_bboxes_for_img': len(sam_data_for_image_df),
            'yolo_additions': yolo_additions_count,
            'total_detections_before_sam_adv': len(enhanced_detections_df)
        },
        'main_visualization_path': full_save_path
    }
    
    print(f" Analysis Completed for Image ID: {image_id_str} ")
    return full_save_path, image_analysis_stats

# Need to import datetime for the timestamp
from datetime import datetime

print("✓ Main Per-Image Analysis function defined.")

# ==============================================================================
# SECTION IX: MAIN EXECUTION RUNNER / PIPELINE
# ==============================================================================
print("\n----------------------------------------------------------------------")
print("SECTION IX: DEFINING MAIN EXECUTION PIPELINE")
print("----------------------------------------------------------------------")

def run_full_analysis_pipeline(num_images_to_process=None, recess_time_seconds=10):
    """
    Main pipeline to run the complete medical image analysis for multiple images.
    Args:
        num_images_to_process (int, optional): Limit the number of images to process. None for all.
        recess_time_seconds (int): Pause duration in seconds between processing images for memory recovery.
    """
    print("======================================================================")
    print(" INITIATING FULL MEDICAL IMAGE ANALYSIS PIPELINE ")
    print("======================================================================")

    # 1. Initial Setup: SAM, Load Data
    # Clear CUDA cache thoroughly at the very beginning of the pipeline
    print("\n--- Phase 1: Initial Setup ---")
    clear_cuda_cache()
    
    global SAM_PREDICTOR, SAM_DEVICE, SAM_MODEL_INSTANCE # Make them global to manage their lifecycle
    SAM_PREDICTOR, SAM_DEVICE, SAM_MODEL_INSTANCE = setup_sam() # Returns predictor, device, and model instance
    if SAM_PREDICTOR is None:
        print("CRITICAL ERROR: SAM Predictor setup failed. Aborting pipeline.")
        return

    # Load all annotation/textual data
    # These are full DataFrames, will be filtered per image inside process_single_image_analysis
    sam_coords_full_df, explanations_full_df, concepts_full_df, caption_file_full_df = load_data()
    if sam_coords_full_df.empty:
        print("CRITICAL ERROR: Initial SAM coordinates data (sam_coord.csv) is empty or failed to load. Aborting.")
        # Clean up SAM model if it was loaded
        if SAM_MODEL_INSTANCE is not None:
            del SAM_MODEL_INSTANCE
            SAM_MODEL_INSTANCE = None
        if SAM_PREDICTOR is not None:
            del SAM_PREDICTOR
            SAM_PREDICTOR = None
        clear_cuda_cache()
        return

    # Determine images to process
    # Assuming 'ImageID' is the column in sam_coords_full_df
    if 'ImageID' not in sam_coords_full_df.columns:
        print(f"CRITICAL ERROR: 'ImageID' column not found in {SAM_COORD_PATH}. Aborting.")
        # Clean up SAM model
        if SAM_MODEL_INSTANCE is not None: del SAM_MODEL_INSTANCE; SAM_MODEL_INSTANCE = None
        if SAM_PREDICTOR is not None: del SAM_PREDICTOR; SAM_PREDICTOR = None
        clear_cuda_cache()
        return
        
    unique_image_ids = sam_coords_full_df['ImageID'].unique().tolist()
    if num_images_to_process is not None and num_images_to_process > 0:
        unique_image_ids = unique_image_ids[:num_images_to_process]
    
    total_images = len(unique_image_ids)
    if total_images == 0:
        print("No images found to process based on SAM coordinates file. Exiting.")
        # Clean up SAM model
        if SAM_MODEL_INSTANCE is not None: del SAM_MODEL_INSTANCE; SAM_MODEL_INSTANCE = None
        if SAM_PREDICTOR is not None: del SAM_PREDICTOR; SAM_PREDICTOR = None
        clear_cuda_cache()
        return
        
    print(f"Found {total_images} unique image IDs to process.")
    print(f"Recess time between images: {recess_time_seconds} seconds.")
    print(f"Output will be saved in: {OUTPUT_DIR}") # OUTPUT_DIR is global

    # 2. Per-Image Processing Loop
    print("\n--- Phase 2: Processing Images ---")
    all_results_paths = []
    all_images_analysis_statistics = {} # Dict to store stats for each processed image

    for i, current_image_id in enumerate(unique_image_ids):
        print(f"\n------------------------- PROCESSING IMAGE {i+1}/{total_images}: {current_image_id} -------------------------")
        
        # Prepare data for the current image
        # Filter sam_coords_full_df for the current_image_id
        current_sam_data_for_image_df = sam_coords_full_df[sam_coords_full_df['ImageID'] == current_image_id]
        
        # Filter explanations_full_df for the current_image_id
        # Need to know the ID column in explanations_full_df (e.g., 'ImageID', 'ID')
        # Assuming 'ImageID' for explanations_full_df as well, or adapt column name
        id_col_expl = 'ImageID' # Default, check if this column exists
        if explanations_full_df is not None and not explanations_full_df.empty:
             if id_col_expl not in explanations_full_df.columns: # Try other common ID names
                 common_ids = ['ID', 'id', 'image_id']
                 for c_id in common_ids:
                     if c_id in explanations_full_df.columns:
                         id_col_expl = c_id
                         break
             current_explanations_for_image_df = explanations_full_df[explanations_full_df[id_col_expl] == current_image_id] if id_col_expl in explanations_full_df.columns else pd.DataFrame()
        else:
            current_explanations_for_image_df = pd.DataFrame()


        # Call the main processing function for this single image
        try:
            viz_path, image_stats = process_single_image_analysis(
                current_image_id,
                current_sam_data_for_image_df,
                current_explanations_for_image_df,
                concepts_full_df,
                caption_file_full_df,
                SAM_PREDICTOR, # Pass the globally initialized predictor
                OUTPUT_DIR     # Pass the global output directory
            )
            if viz_path and image_stats: # If processing was successful
                all_results_paths.append(viz_path)
                all_images_analysis_statistics[current_image_id] = image_stats
        except Exception as e_img_proc:
            print(f"Unhandled CRITICAL error during processing of image {current_image_id}: {e_img_proc}")
            import traceback
            traceback.print_exc()
            all_images_analysis_statistics[current_image_id] = {'error': str(e_img_proc), 'status': 'failed_critically'}


        # Memory Management: Pause and Clear Cache
        if i < total_images - 1: # Don't pause after the last image
            print(f"\n--- Inter-Image Recess & Cleanup ({current_image_id} done) ---")
            clear_cuda_cache() # Clear CUDA cache thoroughly
            if SAM_DEVICE == "cuda": # Only makes sense to offload if SAM is on CUDA
                print("Offloading SAM predictor to CPU during recess (if on CUDA)...")
                if SAM_PREDICTOR is not None and hasattr(SAM_PREDICTOR.model, 'to'):
                    SAM_PREDICTOR.model.to('cpu') # Move model to CPU
                    # Note: If you move the model, you'll need to move it back to CUDA
                    # before the next `predictor.set_image` call if `SAM_DEVICE` is 'cuda'.
                    # This is handled by `setup_sam` if it re-checks device, or predictor needs re-init or model.to(SAM_DEVICE)

            print(f"Pausing for {recess_time_seconds} seconds...")
            time.sleep(recess_time_seconds)

            if SAM_DEVICE == "cuda":
                 print("Restoring SAM predictor to CUDA (if applicable)...")
                 if SAM_PREDICTOR is not None and hasattr(SAM_PREDICTOR.model, 'to'):
                    SAM_PREDICTOR.model.to(SAM_DEVICE) # Move model back to original device
                 clear_cuda_cache() # Clear again after moving model back

    print("\n--- Phase 3: Final Reporting & Cleanup ---")
    # Generate overall reports from all_images_analysis_statistics
    if all_images_analysis_statistics:
        print("Generating final summary reports and CSVs...")
        exported_dfs_and_paths = export_analysis_metrics_to_csv(all_images_analysis_statistics, OUTPUT_DIR)
        generate_overall_final_report_text(all_images_analysis_statistics, exported_dfs_and_paths)
    else:
        print("No images were successfully processed, skipping final report generation.")

    # Final cleanup of SAM model from GPU memory
    print("Performing final cleanup of SAM model...")
    if SAM_MODEL_INSTANCE is not None:
        del SAM_MODEL_INSTANCE # Remove reference to model object
        SAM_MODEL_INSTANCE = None
    if SAM_PREDICTOR is not None:
        del SAM_PREDICTOR # Remove reference to predictor object
        SAM_PREDICTOR = None
    clear_cuda_cache() # Final clear

    print("\n======================================================================")
    print("✅✅✅ MEDICAL IMAGE ANALYSIS PIPELINE COMPLETED ✅✅✅")
    print(f"Processed {len(all_images_analysis_statistics)} images. Results in: {OUTPUT_DIR}")
    print("======================================================================")


# ==============================================================================
# SECTION X: SCRIPT EXECUTION TRIGGER
# ==============================================================================
print("\n----------------------------------------------------------------------")
print("SECTION X: TRIGGERING SCRIPT EXECUTION")
print("----------------------------------------------------------------------")

if __name__ == '__main__':
    print("Starting the main execution block...")
    
    # Configuration for the run:
    NUMBER_OF_IMAGES_TO_RUN = None # Set to an integer to limit, e.g., 2 for testing, None for all
    PAUSE_BETWEEN_IMAGES_SEC = 15  # Increase if OOM errors persist, decrease if memory is ample
                                  # For an 8GB card with ViT-H, a longer pause might be beneficial.
    
    # Ensure global variables for SAM are initialized to None before the pipeline tries to use them
    SAM_PREDICTOR = None
    SAM_DEVICE = "cpu" # Default
    SAM_MODEL_INSTANCE = None

    try:
        run_full_analysis_pipeline(
            num_images_to_process=NUMBER_OF_IMAGES_TO_RUN,
            recess_time_seconds=PAUSE_BETWEEN_IMAGES_SEC
        )
    except Exception as e_pipeline:
        print(f"CRITICAL FAILURE IN PIPELINE EXECUTION")
        print(f"Error: {e_pipeline}")
        import traceback
        traceback.print_exc()
        # Attempt a final cleanup even on pipeline failure
        if 'SAM_MODEL_INSTANCE' in globals() and SAM_MODEL_INSTANCE is not None: del SAM_MODEL_INSTANCE
        if 'SAM_PREDICTOR' in globals() and SAM_PREDICTOR is not None: del SAM_PREDICTOR
        clear_cuda_cache()
    finally:
        print("\nScript execution finished.")
		
		

----------------------------------------------------------------------
SECTION I: INITIALIZING LIBRARIES, CONFIGURATION, AND CUDA
----------------------------------------------------------------------
✓ Scikit-image (skimage) available.
✓ Ultralytics YOLO available.

--- Loading Configuration ---
✓ Output directory set to: ./result_explain-3/sam_deduplicated
✓ Color palette loaded with 15 colors.
 Clearing CUDA cache...
✓ CUDA cache cleared and garbage collected.

----------------------------------------------------------------------
SECTION II: DEFINING CORE UTILITY FUNCTIONS
----------------------------------------------------------------------
✓ Core utility functions defined.

----------------------------------------------------------------------
SECTION III: SETTING UP SEGMENT ANYTHING MODEL (SAM)
----------------------------------------------------------------------
✓ SAM setup function defined.

----------------------------------------------------------------------
SECTION IV: D