In [3]:
import json
import random
import re

# Configuration
INPUT_JSON = "radgraph_processed.json"
OUTPUT_JSON = "slava_llava_split_"
TEST_SPLIT = 0.1# 10% test split

# Use multi-image tokens (LLaVA-Phi expects these for dual images if trained that way)
IMAGE_TOKEN = "<image>"

# Prompt pools
recognition_prompts = [
    "Enumerate all abnormal radiographic findings seen on the frontal and lateral chest X-rays, along with their precise anatomical locations.",
    "List every visible pathology in the lungs, heart, pleura, and bones, as observed on both frontal and lateral views.",
    "Describe only the radiographic abnormalities visible in these dual-view chest X-rays. Exclude normal structures.",
    "Identify and localize any abnormal opacities, effusions, consolidations, or structural deviations present in the chest radiographs.",
    "Specify all observed abnormalities in the dual chest views, including their type (e.g., mass, effusion, opacity) and anatomical location.",
    "Report abnormal findings only from the provided frontal and lateral chest X-ray images. Do not describe normal appearances.",
    "What pathological signs can be identified across both X-ray views? Be specific about laterality and anatomical regions.",
    "Describe all clinically relevant radiographic findings, focusing on abnormalities in the lungs, mediastinum, and chest wall.",
    "From the dual-view chest radiographs, list any deviations from normal radiographic anatomy or pathology that requires clinical attention.",
    "Summarize all abnormal chest X-ray findings, organized by anatomical region (e.g., lungs, heart, pleura, bones)."
]


reasoning_prompts = [
    "Based on the findings in the frontal and lateral chest X-rays, what is the most likely clinical diagnosis?",
    "Interpret the radiographic abnormalities observed in both views and explain their clinical implications.",
    "Given the dual-view chest radiographs, what is your diagnostic impression and reasoning behind it?",
    "Using the observed abnormalities in these X-rays, infer the likely pathology and explain its clinical relevance.",
    "What clinical condition best explains the abnormal findings visible in these frontal and lateral chest radiographs?"
]
def impression_starts_with_normal_phrases(impression):
    impression_text = impression.strip() if impression else ''
    starts_with_patterns = [
        r"^no\b",
        r"^no evidence\b",
        r"^no acute\b",
        r"^normal\b"
    ]

    for pattern in starts_with_patterns:
        if re.match(pattern, impression_text, re.IGNORECASE):
            return False

    return True
def infer_view_hint(findings_text):
    text_upper = findings_text.upper()
    if "PA" in text_upper and "AP" in text_upper:
        return "PA and AP chest X-ray views are shown."
    elif "PA" in text_upper:
        return "PA and lateral chest X-ray views are shown."
    elif "AP" in text_upper:
        return "AP and lateral chest X-ray views are shown."
    else:
        return "Frontal and lateral chest X-ray views are shown."

# Load RadGraph-processed JSON
with open(INPUT_JSON, "r") as f:
    data = json.load(f)

output = []
normalcount=0
abnormalCount = 0
for study_id, entry in data.items():
    findings = entry.get("findings", "").strip()
    impression = entry.get("impression", "").strip()
    image_paths = entry.get("image_paths", [])
    frontal = entry.get("image_paths", "")
    lateral= entry.get("image_paths", "")
    if not findings or not impression:
        continue
    if len(image_paths) < 2:
        continue  # Need both frontal and lateral views

    # Dynamic view hint based on findings
    view_hint = infer_view_hint(findings)

    # Build prompts with image tokens and helpful context
    recognition_prompt = f"{IMAGE_TOKEN}{view_hint} {random.choice(recognition_prompts)}"
    reasoning_prompt = f"{IMAGE_TOKEN}{view_hint} {random.choice(reasoning_prompts)}"
    item = {
        "frontal": frontal[0],
        "lateral": lateral[1],
        "recognition_input": recognition_prompt,
        "reasoning_input": reasoning_prompt,
        "findings": findings,
        "impression": impression
    }
    output.append(item)
    
    # if impression_starts_with_normal_phrases(impression) and normalcount<100:
    #     output.append(item)
    #     normalcount+=1
    # elif not impression_starts_with_normal_phrases(impression):
    #     abnormalCount+=1
    #     output.append(item)  

# Shuffle and split into train/test
random.shuffle(output)
split_index = int(len(output) * (1 - TEST_SPLIT))
train_data = output[:split_index]
test_data = output[split_index:]

# Save to JSON files
with open(OUTPUT_JSON + 'train.json', "w") as f:
    json.dump(train_data, f, indent=2)

with open(OUTPUT_JSON + 'test.json', "w") as f:
    json.dump(test_data, f, indent=2)



print(f"✅ Saved {len(train_data)} train, {len(test_data)} test")
print(f"   '{OUTPUT_JSON}train.json'")
print(f"   '{OUTPUT_JSON}test.json'")



✅ Saved 28647 train, 3183 test
   'slava_llava_split_train.json'
   'slava_llava_split_test.json'


In [4]:
import json

input_json = "slava_llava_split_train.json"  # Original input
output_json = "slava_llava_recognition.json"  # Output for LLaVA fine-tuning

with open(input_json, "r") as f:
    data = json.load(f)

converted = []

for item in data:
        try:
            frontal = item["frontal"]
            lateral = item["lateral"]
            instruction = item["recognition_input"]
            response = item["findings"]  # Yes, spelling is preserved as-is
    
            # Skip empty responses
            if not response.strip():
                continue
    
            sample = {
                "frontal": frontal,
                "lateral":lateral,
                "conversations": [
                    {"from": "human", "value": instruction},
                    {"from": "gpt", "value": response}
                ]
            }
    
            converted.append(sample)
           
    
        except KeyError as e:
            print(f"[Skip] Missing key {e} in item")

        

with open(output_json, "w") as f:
    json.dump(converted, f, indent=2)

print(f"✅ Converted {len(converted)} samples to {output_json}")


✅ Converted 28647 samples to slava_llava_recognition.json


In [5]:
import json
from collections import defaultdict
import random

input_json = "slava_llava_split_train.json"
output_json = "slava_llava_reasoning.json"

# Load original data
with open(input_json, "r") as f:
    data = json.load(f)

# Categorize by normal/abnormal findings
sample_buckets = defaultdict(list)
for item in data:
    impression = item.get("impression", "").lower()
    is_normal = "no acute" in impression or "normal" in impression or "no evidence" in impression
    sample_buckets["normal" if is_normal else "abnormal"].append(item)

print(f"Original counts - Normal: {len(sample_buckets['normal'])}, Abnormal: {len(sample_buckets['abnormal'])}")

# Oversampling parameters
ABNORMAL_OVERSAMPLE_FACTOR = 3  # 3x more abnormal cases
ENHANCED_PROMPT_TEMPLATES = [
    "GIVEN THESE FINDINGS: {findings}\n{original_instruction}",
    "THE RADIOLOGIST NOTED: {findings}\nBASED ON THIS, {original_instruction}",
    "CLINICAL CONTEXT: {findings}\nPLEASE PROVIDE YOUR ANALYSIS: {original_instruction}",
    "{original_instruction}\nRELEVANT FINDINGS INCLUDE: {findings}"
]
REASONING_PROMPTS = [
    "<image> ANALYZE THE DUAL-VIEW CHEST RADIOGRAPHS AND DESCRIBE THE MOST CLINICALLY SIGNIFICANT FINDINGS.",
    "<image> IDENTIFY AND PRIORITIZE THE TOP 3 RADIOGRAPHIC ABNORMALITIES THAT REQUIRE CLINICAL ATTENTION.",
    "<image> COMPARE THE CURRENT STUDY WITH PRIOR IMAGING (IF AVAILABLE). WHAT INTERVAL CHANGES ARE MOST CONCERNING?",
    "<image> WHAT RADIOGRAPHIC SIGNS SUGGEST DECOMPENSATION IN THIS PATIENT'S CONDITION?",
    "<image> WHICH FINDINGS WOULD YOU IMMEDIATELY REPORT TO THE TREATING PHYSICIAN AND WHY?",
    "<image> ASSESS THE POSITIONING AND PLACEMENT OF ALL TUBES, LINES, AND DEVICES.",
    "<image> DESCRIBE ANY FINDINGS THAT SUGGEST ACUTE VERSUS CHRONIC PATHOLOGICAL PROCESSES.",
    "<image> WHAT DIFFERENTIAL DIAGNOSES WOULD YOU CONSIDER BASED ON THESE RADIOGRAPHIC FINDINGS?",
    "<image> EVALUATE THE CARDIOPULMONARY STATUS AND COMMENT ON ANY DECOMPENSATION SIGNS.",
    "<image> IDENTIFY ANY FINDINGS THAT MAY REQUIRE IMMEDIATE INTERVENTION VERSUS FOLLOW-UP MONITORING."
]


converted = []

def build_enhanced_prompt(item):
    """Generate multiple prompt variations incorporating findings"""
    original_instruction = random.choice(REASONING_PROMPTS)
    findings = item.get("findings", "").strip()
    
    if not findings:
        return [original_instruction]
    
    return [
        template.format(
            findings=findings,
            original_instruction=original_instruction
        )
        for template in ENHANCED_PROMPT_TEMPLATES
    ]

# Process normal samples (no oversampling)
for item in sample_buckets["normal"]:
    try:
        enhanced_prompts = build_enhanced_prompt(item)
        converted.append({
            "frontal": item["frontal"],
            "lateral": item["lateral"],
            "conversations": [
                {"from": "human", "value": enhanced_prompts[0]},
                {"from": "gpt", "value": item["impression"]}
            ],
            "category": "normal"
        })
    except KeyError as e:
        print(f"[Skip] Missing key {e} in normal sample")

# Process and oversample abnormal cases
for item in sample_buckets["abnormal"]:
    try:
        # Original sample
        enhanced_prompts = build_enhanced_prompt(item)
        converted.append({
            "frontal": item["frontal"],
            "lateral": item["lateral"],
            "conversations": [
                {"from": "human", "value": enhanced_prompts[0] },
                {"from": "gpt", "value": item["impression"]}
            ],
            "category": "abnormal_original"
        })
        
        # Create enhanced variations (oversampling)
        for i, prompt in enumerate(enhanced_prompts[:ABNORMAL_OVERSAMPLE_FACTOR]):
            converted.append({
                "frontal": item["frontal"],
                "lateral": item["lateral"],
                "conversations": [
                    {"from": "human", "value": prompt},
                    {"from": "gpt", "value": item["impression"]}
                ],
                "category": f"abnormal_enhanced_{i+1}"
            })
            
    except KeyError as e:
        print(f"[Skip] Missing key {e} in abnormal sample")

# Shuffle the dataset to mix normal/abnormal samples
random.shuffle(converted)

# Save enhanced dataset
with open(output_json, "w") as f:
    json.dump(converted, f, indent=2)

# Print statistics
category_counts = defaultdict(int)
for item in converted:
    category_counts[item["category"]] += 1

print("\nFinal Dataset Composition:")
for category, count in category_counts.items():
    print(f"- {category}: {count} samples")
print(f"\n✅ Saved {len(converted)} samples to {output_json}")

Original counts - Normal: 12755, Abnormal: 15892

Final Dataset Composition:
- abnormal_original: 15892 samples
- abnormal_enhanced_2: 15892 samples
- abnormal_enhanced_1: 15892 samples
- abnormal_enhanced_3: 15892 samples
- normal: 12755 samples

✅ Saved 76323 samples to slava_llava_reasoning.json


In [6]:
import json
from collections import defaultdict
import random

input_json = "slava_llava_split_train.json"
output_json = "slava_llava_report.json"

# Load original data
with open(input_json, "r") as f:
    data = json.load(f)

# Categorize by normal/abnormal findings
sample_buckets = defaultdict(list)
for item in data:
    impression = item.get("impression", "").lower()
    is_normal = "no acute" in impression or "normal" in impression or "no evidence" in impression
    sample_buckets["normal" if is_normal else "abnormal"].append(item)

print(f"Original counts - Normal: {len(sample_buckets['normal'])}, Abnormal: {len(sample_buckets['abnormal'])}")

# Oversampling parameters
ABNORMAL_OVERSAMPLE_FACTOR = 4  # 3x more abnormal cases
REPORT_INSTRUCTIONS = [
    "<image> Describe the findings in these frontal and lateral chest X-rays in a structured radiology report format.",
    "<image> Generate a complete radiology report for these CXRs including findings and impression.",
    "<image> Interpret these chest X-rays and provide a professional radiology report.",
    "<image> What abnormalities are visible in these CXRs? Provide a structured report.",
    "<image> Analyze these frontal and lateral chest X-rays and summarize the key findings.",
    "<image> Prepare a radiology report for these images with findings and clinical impression.",
    "<image> Evaluate these CXRs and document your observations in standard report format.",
    "<image> Provide a detailed interpretation of these chest X-rays with findings and conclusion.",
    "<image> Write a radiology report for these images following clinical documentation standards.",
    "<image> Identify and describe any pathological findings in these CXRs in report format."
]


converted = []


# Process normal samples (no oversampling)
for item in sample_buckets["normal"]:
    try:
        base_instruction = random.choice(REPORT_INSTRUCTIONS)
        findings = item["findings"]
        impression = item["impression"] 
        report_text = f"FINDINGS: {findings}"
        report_text += f"IMPRESSION: {impression}"
        converted.append({
            "frontal": item["frontal"],
            "lateral": item["lateral"],
            "conversations": [
                {"from": "human", "value": base_instruction},
                {"from": "gpt", "value":report_text}
            ],
            "category": "normal"
        })
    except KeyError as e:
        print(f"[Skip] Missing key {e} in normal sample")

# Process and oversample abnormal cases
for item in sample_buckets["abnormal"]:
    try:
        # Original sample
        findings = item["findings"]
        impression = item["impression"] 
        report_text = f"FINDINGS: {findings}"
        report_text += f"IMPRESSION: {impression}"
        converted.append({
            "frontal": item["frontal"],
            "lateral": item["lateral"],
            "conversations": [
                {"from": "human", "value": random.choice(REPORT_INSTRUCTIONS) },
                {"from": "gpt", "value":report_text}
            ],
            "category": "abnormal_original"
        })
        
        # Create enhanced variations (oversampling)
        for i in range(ABNORMAL_OVERSAMPLE_FACTOR):
            converted.append({
                "frontal": item["frontal"],
                "lateral": item["lateral"],
                "conversations": [
                    {"from": "human", "value": random.choice(REPORT_INSTRUCTIONS)},
                    {"from": "gpt", "value": report_text}
                ],
                "category": f"abnormal_enhanced_{i+1}"
            })
            
    except KeyError as e:
        print(f"[Skip] Missing key {e} in abnormal sample")

# Shuffle the dataset to mix normal/abnormal samples
random.shuffle(converted)

# Save enhanced dataset
with open(output_json, "w") as f:
    json.dump(converted, f, indent=2)

# Print statistics
category_counts = defaultdict(int)
for item in converted:
    category_counts[item["category"]] += 1

print("\nFinal Dataset Composition:")
for category, count in category_counts.items():
    print(f"- {category}: {count} samples")
print(f"\n✅ Saved {len(converted)} samples to {output_json}")

Original counts - Normal: 12755, Abnormal: 15892

Final Dataset Composition:
- abnormal_original: 15892 samples
- normal: 12755 samples
- abnormal_enhanced_4: 15892 samples
- abnormal_enhanced_2: 15892 samples
- abnormal_enhanced_3: 15892 samples
- abnormal_enhanced_1: 15892 samples

✅ Saved 92215 samples to slava_llava_report.json


In [1]:
import os
import json
import random
import torch
import logging
from PIL import Image
from tqdm import tqdm

from transformers import AutoTokenizer, CLIPImageProcessor
from llava_phi.model import LlavaPhiForCausalLM
from llava_phi.constants import DEFAULT_IMAGE_TOKEN
from llava_phi.conversation import conv_templates
from llava_phi.utils import disable_torch_init
from transformers.generation.utils import GenerationMixin
from transformers import logging as hf_logging

# ====================
# 🔧 CONFIGURATION
# ====================
MODEL_PATH = "/media/volume/Slava/Dual-View-Slava-Final"
IMAGE_FOLDER = "/media/volume/Slava/MIMIC_Dataset224"
INPUT_JSON = "slava_llava_split_test.json"
OUTPUT_JSON = "slava_llava_predict_1.json"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16  # ✅ Faster
MANUAL_SEED = 42

hf_logging.set_verbosity_error()

KEYWORDS = [
    "pneumonia", "atelectasis", "effusion", "cardiomegaly", "consolidation",
    "edema", "scarring", "opacity", "vascular congestion", "catheter", "radiation"
]

REPORT_INSTRUCTIONS = [
    "<image> You are a board-certified radiologist. Analyze the frontal and lateral chest X-rays. Identify and describe all radiographic findings, including location, size, density, and distribution. Provide a structured report with FINDINGS and IMPRESSION. Ensure clinical relevance and diagnostic precision.",
    "<image> Evaluate the dual-view CXRs thoroughly. Identify all signs of disease, subtle or prominent. For each finding, describe anatomical position, severity, and possible differential diagnoses. Conclude with a well-reasoned and structured IMPRESSION.",
    "<image> Generate a detailed radiology report using standard clinical format. Include all abnormal and incidental findings in the FINDINGS section. Summarize key diagnostic insights in the IMPRESSION. Avoid generic language; favor precise anatomical and pathological descriptors.",
    "<image> Carefully examine the frontal and lateral chest X-rays. Write a complete radiology report structured into: FINDINGS (organized by organ system and zones) and IMPRESSION (summary with clinical prioritization). Include all deviations from normal, no matter how subtle.",
    "<image> Interpret these dual-view CXRs. Document all radiologic abnormalities by describing their appearance, location (lung zones, mediastinum, pleura, bones), and clinical implications. Format the response as a formal radiology report with clear headings for FINDINGS and IMPRESSION.",
    "<image> Examine the provided chest X-rays from both frontal and lateral views. Accurately identify and describe any abnormal radiologic signs, even subtle or borderline cases. Structure the report with FINDINGS and a diagnostic IMPRESSION.",
    "<image> Perform a radiological assessment of these dual-view CXRs. List all notable observations and pathological signs, with attention to anatomical detail and clinical context. Format the response as a full report: FINDINGS followed by IMPRESSION.",
    "<image> Review the chest radiographs and generate a report that includes all visible abnormalities. Pay attention to symmetry, lung markings, cardiac silhouette, and bony structures. Use the standard format with FINDINGS and IMPRESSION.",
    "<image> You are an expert thoracic radiologist. Describe all pathological and incidental findings in these frontal and lateral CXRs. Be thorough and concise. Conclude with a structured IMPRESSION summarizing the clinical picture.",
    "<image> Analyze these chest X-rays using your clinical expertise. Report every abnormality using specific radiologic terminology. Clearly differentiate FINDINGS and IMPRESSION. Include zone-wise, side-wise, and severity-based descriptions."
]


# ====================
# 📋 LOGGING SETUP
# ====================
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ====================
# 🧠 MODEL WRAPPER
# ====================
class LlavaWithGenerate(LlavaPhiForCausalLM, GenerationMixin):
    def generate(self, *args, **kwargs):
        self.image_features = None
        self.images = kwargs.pop("images", None)
        return super().generate(*args, **kwargs)

# ====================
# LOAD MODEL + TOKENIZER
# ====================
def load_model_and_tokenizer():
    disable_torch_init()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    if "<image>" not in tokenizer.get_vocab():
        tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = LlavaWithGenerate.from_pretrained(MODEL_PATH, torch_dtype=TORCH_DTYPE, low_cpu_mem_usage=True)
    model.resize_token_embeddings(len(tokenizer))
    model = model.to(DEVICE)

    model.model.bypass_vision_tower = False
    image_token_id = tokenizer.convert_tokens_to_ids("<image>")
    assert image_token_id < model.get_input_embeddings().num_embeddings

    if torch.__version__ >= "2" and torch.cuda.is_available():
        try:
            model = torch.compile(model)
        except Exception as e:
            logger.warning(f"torch.compile() failed: {e}")

    return model, tokenizer, image_token_id

# ====================
# CONVERSATION FORMATTER
# ====================
def prepare_conversation(prompt):
    template_name = next((name for name in ["llava_phi", "phi", "v0", "med", "vicuna_v1"] if name in conv_templates), None)
    conv = conv_templates[template_name].copy()
    conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + prompt)
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()

# ====================
# IMAGE PROCESSING
# ====================
IMAGE_PROCESSOR = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")

def expand2square(pil_img, background_color=(0, 0, 0)):
    w, h = pil_img.size
    if w == h:
        return pil_img
    elif w > h:
        result = Image.new(pil_img.mode, (w, w), background_color)
        result.paste(pil_img, (0, (w - h) // 2))
    else:
        result = Image.new(pil_img.mode, (h, h), background_color)
        result.paste(pil_img, ((h - w) // 2, 0))
    return result

def prepare_images(frontal_path, lateral_path):
    image_paths = [frontal_path, lateral_path]
    images = []
    bg_color = tuple(int(x * 255) for x in IMAGE_PROCESSOR.image_mean)
    for img_path in image_paths:
        with Image.open(os.path.join(IMAGE_FOLDER, img_path)) as img:
            img = img.convert("RGB")
            img = expand2square(img, bg_color)
            image_tensor = IMAGE_PROCESSOR(img, return_tensors="pt")['pixel_values'][0]
            images.append(image_tensor)
    return torch.stack(images).contiguous()

# ====================
# TOKENIZER IMAGE HANDLING
# ====================
def tokenizer_image_token(prompt, tokenizer, image_token_index, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])
    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors == 'pt':
        return torch.tensor(input_ids, dtype=torch.long)
    return input_ids

# ====================
# HEURISTIC CHECKS
# ====================
def is_low_confidence(text: str, min_word_count=40):
    return len(text.split()) < min_word_count or "impression:" not in text.lower()

def must_mention_keywords(reference: str, prediction: str, keywords=KEYWORDS):
    ref_keywords = [kw for kw in keywords if kw in reference.lower()]
    if not ref_keywords:
        return False
    missing = [kw for kw in ref_keywords if kw not in prediction.lower()]
    return len(missing) > 0

# ====================
# GENERATION LOGIC
# ====================
def generate_response(model, tokenizer, images, prompt, reference, image_token_index, max_retries=1):
    regen_reason = None
    for attempt in range(max_retries + 1):
        full_prompt = prepare_conversation(prompt)
        input_ids = tokenizer_image_token(full_prompt, tokenizer, image_token_index, return_tensors="pt").unsqueeze(0).to(DEVICE)
        attention_mask = (input_ids != tokenizer.pad_token_id).long().to(DEVICE)
        images = images.unsqueeze(0).to(DEVICE)

        with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=TORCH_DTYPE):
            try:
                output_ids = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    images=images,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    max_new_tokens=1024,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
            except RuntimeError as e:
                if "indexSelectLargeIndex" in str(e):
                    return "[ERROR: CUDA index bug]", True, "cuda_index_error"
                raise e

        gen_tokens = output_ids[0, input_ids.shape[1]:]
        decoded = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()

        if is_low_confidence(decoded):
            regen_reason = "low_confidence"
            continue
        if must_mention_keywords(reference, decoded):
            regen_reason = "missing_keywords"
            continue

        return decoded, (attempt > 0), regen_reason

    return decoded, True, regen_reason or "max_attempts"

# ====================
# MAIN LOOP
# ====================
def process_validation():
    logger.info("Loading model and tokenizer...")
    model, tokenizer, image_token_index = load_model_and_tokenizer()
    torch.manual_seed(MANUAL_SEED)

    logger.info("Reading annotation file...")
    with open(INPUT_JSON, "r") as f:
        val_data = json.load(f)
    # val_data = val_data[:1000]  
    logger.info(" Checking for existing output...")
    existing_predictions = {}
    if os.path.exists(OUTPUT_JSON):
        with open(OUTPUT_JSON, "r") as f:
            for item in json.load(f):
                key = f"{item['frontal']}|{item['lateral']}"
                existing_predictions[key] = item

    processed_keys = set(existing_predictions.keys())
    logger.info(f"Found {len(processed_keys)} previously processed entries. Skipping them...")
    for entry in tqdm(val_data, desc="Processing"):
        try:
            findings = entry.get("findings")
            impression = entry.get("impression")
            reference_text = f"{findings}. IMPRESSION: {impression}" if findings and impression else ""
            frontal = entry.get("frontal")
            lateral = entry.get("lateral")
            key = f"{frontal}|{lateral}"

            if key in processed_keys:
                continue
            prompt = random.choice(REPORT_INSTRUCTIONS)

            if not frontal or not lateral or not reference_text:
                continue

            images = prepare_images(frontal, lateral)
            prediction, was_regenerated, regen_reason = generate_response(
                model, tokenizer, images, prompt, reference_text, image_token_index
            )

            result={
                "frontal": frontal,
                "lateral": lateral,
                "prompt": prompt,
                "reference": reference_text,
                "prediction": prediction,
                # "regenerated": was_regenerated,
                # "regeneration_reason": regen_reason if was_regenerated else None
            }
            existing_predictions[key] = result
            with open(OUTPUT_JSON, "w") as f:
                    json.dump(list(existing_predictions.values()), f, indent=2)
        except Exception as e:
            logger.warning(f"[Skip] Error processing {entry.get('frontal')}: {str(e)}")

    # with open(OUTPUT_JSON, "w") as f:
    #     json.dump(results, f, indent=2)

    logger.info(f"Completed. Saved")

if __name__ == "__main__":
    process_validation()


  from .autonotebook import tqdm as notebook_tqdm
INFO:__main__:Loading model and tokenizer...
Loading checkpoint shards: 100%|██████████| 3/3 [00:07<00:00,  2.45s/it]
INFO:__main__:Reading annotation file...
INFO:__main__: Checking for existing output...
INFO:__main__:Found 1615 previously processed entries. Skipping them...
Processing: 100%|██████████| 3183/3183 [2:26:07<00:00,  2.75s/it]  


NameError: name 'results' is not defined

In [None]:
import os
import json
import random
import torch
import logging
from PIL import Image
from tqdm import tqdm

from transformers import AutoTokenizer, CLIPImageProcessor
from llava_phi.model import LlavaPhiForCausalLM
from llava_phi.constants import DEFAULT_IMAGE_TOKEN
from llava_phi.conversation import conv_templates
from llava_phi.utils import disable_torch_init
from transformers.generation.utils import GenerationMixin
from transformers import logging as hf_logging

# ====================
# 🔧 CONFIGURATION
# ====================
MODEL_PATH = "/media/volume/Slava/Dual-View-Slava-Final"
IMAGE_FOLDER = "/media/volume/Slava/IU_Xray"
INPUT_JSON = "slava_llava_split_test.json"
OUTPUT_JSON = "slava_llava_predict_1.json"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float16  # ✅ Faster
MANUAL_SEED = 42

hf_logging.set_verbosity_error()

KEYWORDS = [
    "pneumonia", "atelectasis", "effusion", "cardiomegaly", "consolidation",
    "edema", "scarring", "opacity", "vascular congestion", "catheter", "radiation"
]

REPORT_INSTRUCTIONS = [
    "<image> You are a board-certified radiologist. Analyze the frontal and lateral chest X-rays. Identify and describe all radiographic findings, including location, size, density, and distribution. Provide a structured report with FINDINGS and IMPRESSION. Ensure clinical relevance and diagnostic precision.",
    "<image> Evaluate the dual-view CXRs thoroughly. Identify all signs of disease, subtle or prominent. For each finding, describe anatomical position, severity, and possible differential diagnoses. Conclude with a well-reasoned and structured IMPRESSION.",
    "<image> Generate a detailed radiology report using standard clinical format. Include all abnormal and incidental findings in the FINDINGS section. Summarize key diagnostic insights in the IMPRESSION. Avoid generic language; favor precise anatomical and pathological descriptors.",
    "<image> Carefully examine the frontal and lateral chest X-rays. Write a complete radiology report structured into: FINDINGS (organized by organ system and zones) and IMPRESSION (summary with clinical prioritization). Include all deviations from normal, no matter how subtle.",
    "<image> Interpret these dual-view CXRs. Document all radiologic abnormalities by describing their appearance, location (lung zones, mediastinum, pleura, bones), and clinical implications. Format the response as a formal radiology report with clear headings for FINDINGS and IMPRESSION.",
    "<image> Examine the provided chest X-rays from both frontal and lateral views. Accurately identify and describe any abnormal radiologic signs, even subtle or borderline cases. Structure the report with FINDINGS and a diagnostic IMPRESSION.",
    "<image> Perform a radiological assessment of these dual-view CXRs. List all notable observations and pathological signs, with attention to anatomical detail and clinical context. Format the response as a full report: FINDINGS followed by IMPRESSION.",
    "<image> Review the chest radiographs and generate a report that includes all visible abnormalities. Pay attention to symmetry, lung markings, cardiac silhouette, and bony structures. Use the standard format with FINDINGS and IMPRESSION.",
    "<image> You are an expert thoracic radiologist. Describe all pathological and incidental findings in these frontal and lateral CXRs. Be thorough and concise. Conclude with a structured IMPRESSION summarizing the clinical picture.",
    "<image> Analyze these chest X-rays using your clinical expertise. Report every abnormality using specific radiologic terminology. Clearly differentiate FINDINGS and IMPRESSION. Include zone-wise, side-wise, and severity-based descriptions."
]


# ====================
# 📋 LOGGING SETUP
# ====================
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ====================
# 🧠 MODEL WRAPPER
# ====================
class LlavaWithGenerate(LlavaPhiForCausalLM, GenerationMixin):
    def generate(self, *args, **kwargs):
        self.image_features = None
        self.images = kwargs.pop("images", None)
        return super().generate(*args, **kwargs)

# ====================
# LOAD MODEL + TOKENIZER
# ====================
def load_model_and_tokenizer():
    disable_torch_init()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    if "<image>" not in tokenizer.get_vocab():
        tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = LlavaWithGenerate.from_pretrained(MODEL_PATH, torch_dtype=TORCH_DTYPE, low_cpu_mem_usage=True)
    model.resize_token_embeddings(len(tokenizer))
    model = model.to(DEVICE)

    model.model.bypass_vision_tower = False
    image_token_id = tokenizer.convert_tokens_to_ids("<image>")
    assert image_token_id < model.get_input_embeddings().num_embeddings

    if torch.__version__ >= "2" and torch.cuda.is_available():
        try:
            model = torch.compile(model)
        except Exception as e:
            logger.warning(f"torch.compile() failed: {e}")

    return model, tokenizer, image_token_id

# ====================
# CONVERSATION FORMATTER
# ====================
def prepare_conversation(prompt):
    template_name = next((name for name in ["llava_phi", "phi", "v0", "med", "vicuna_v1"] if name in conv_templates), None)
    conv = conv_templates[template_name].copy()
    conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + prompt)
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()

# ====================
# IMAGE PROCESSING
# ====================
IMAGE_PROCESSOR = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")

def expand2square(pil_img, background_color=(0, 0, 0)):
    w, h = pil_img.size
    if w == h:
        return pil_img
    elif w > h:
        result = Image.new(pil_img.mode, (w, w), background_color)
        result.paste(pil_img, (0, (w - h) // 2))
    else:
        result = Image.new(pil_img.mode, (h, h), background_color)
        result.paste(pil_img, ((h - w) // 2, 0))
    return result

def prepare_images(frontal_path, lateral_path):
    image_paths = [frontal_path, lateral_path]
    images = []
    bg_color = tuple(int(x * 255) for x in IMAGE_PROCESSOR.image_mean)
    for img_path in image_paths:
        with Image.open(os.path.join(IMAGE_FOLDER, img_path)) as img:
            img = img.convert("RGB")
            img = expand2square(img, bg_color)
            image_tensor = IMAGE_PROCESSOR(img, return_tensors="pt")['pixel_values'][0]
            images.append(image_tensor)
    return torch.stack(images).contiguous()

# ====================
# TOKENIZER IMAGE HANDLING
# ====================
def tokenizer_image_token(prompt, tokenizer, image_token_index, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])
    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors == 'pt':
        return torch.tensor(input_ids, dtype=torch.long)
    return input_ids

# ====================
# HEURISTIC CHECKS
# ====================
def is_low_confidence(text: str, min_word_count=40):
    return len(text.split()) < min_word_count or "impression:" not in text.lower()

def must_mention_keywords(reference: str, prediction: str, keywords=KEYWORDS):
    ref_keywords = [kw for kw in keywords if kw in reference.lower()]
    if not ref_keywords:
        return False
    missing = [kw for kw in ref_keywords if kw not in prediction.lower()]
    return len(missing) > 0

# ====================
# GENERATION LOGIC
# ====================
def generate_response(model, tokenizer, images, prompt, reference, image_token_index, max_retries=1):
    regen_reason = None
    for attempt in range(max_retries + 1):
        full_prompt = prepare_conversation(prompt)
        input_ids = tokenizer_image_token(full_prompt, tokenizer, image_token_index, return_tensors="pt").unsqueeze(0).to(DEVICE)
        attention_mask = (input_ids != tokenizer.pad_token_id).long().to(DEVICE)
        images = images.unsqueeze(0).to(DEVICE)

        with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=TORCH_DTYPE):
            try:
                output_ids = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    images=images,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.95,
                    max_new_tokens=1024,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
            except RuntimeError as e:
                if "indexSelectLargeIndex" in str(e):
                    return "[ERROR: CUDA index bug]", True, "cuda_index_error"
                raise e

        gen_tokens = output_ids[0, input_ids.shape[1]:]
        decoded = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()

        if is_low_confidence(decoded):
            regen_reason = "low_confidence"
            continue
        if must_mention_keywords(reference, decoded):
            regen_reason = "missing_keywords"
            continue

        return decoded, (attempt > 0), regen_reason

    return decoded, True, regen_reason or "max_attempts"

# ====================
# MAIN LOOP
# ====================
def process_validation():
    logger.info("Loading model and tokenizer...")
    model, tokenizer, image_token_index = load_model_and_tokenizer()
    torch.manual_seed(MANUAL_SEED)

    logger.info("Reading annotation file...")
    with open(INPUT_JSON, "r") as f:
        val_data = json.load(f)
    # val_data = val_data[:1000]  
    logger.info(" Checking for existing output...")
    existing_predictions = {}
    if os.path.exists(OUTPUT_JSON):
        with open(OUTPUT_JSON, "r") as f:
            for item in json.load(f):
                key = f"{item['frontal']}|{item['lateral']}"
                existing_predictions[key] = item

    processed_keys = set(existing_predictions.keys())
    logger.info(f"Found {len(processed_keys)} previously processed entries. Skipping them...")
    for entry in tqdm(val_data, desc="Processing"):
        try:
            findings = entry.get("findings")
            impression = entry.get("impression")
            reference_text = f"{findings}. IMPRESSION: {impression}" if findings and impression else ""
            frontal = entry.get("frontal")
            lateral = entry.get("lateral")
            key = f"{frontal}|{lateral}"

            if key in processed_keys:
                continue
            prompt = random.choice(REPORT_INSTRUCTIONS)

            if not frontal or not lateral or not reference_text:
                continue

            images = prepare_images(frontal, lateral)
            prediction, was_regenerated, regen_reason = generate_response(
                model, tokenizer, images, prompt, reference_text, image_token_index
            )

            result={
                "frontal": frontal,
                "lateral": lateral,
                "prompt": prompt,
                "reference": reference_text,
                "prediction": prediction,
                # "regenerated": was_regenerated,
                # "regeneration_reason": regen_reason if was_regenerated else None
            }
            existing_predictions[key] = result
            with open(OUTPUT_JSON, "w") as f:
                    json.dump(list(existing_predictions.values()), f, indent=2)
        except Exception as e:
            logger.warning(f"[Skip] Error processing {entry.get('frontal')}: {str(e)}")

    # with open(OUTPUT_JSON, "w") as f:
    #     json.dump(results, f, indent=2)

    logger.info(f"Completed. Saved")

if __name__ == "__main__":
    process_validation()


In [None]:
import os
import json
import random
import torch
import logging
from PIL import Image
from tqdm import tqdm
from bert_score import score as bert_score
from transformers import AutoTokenizer, CLIPImageProcessor
from llava_phi.model import LlavaPhiForCausalLM
from llava_phi.constants import DEFAULT_IMAGE_TOKEN
from llava_phi.conversation import conv_templates
from llava_phi.utils import disable_torch_init
from transformers.generation.utils import GenerationMixin

# ====================
# 🔧 CONFIGURATION
# ====================
MODEL_PATH = "/media/volume/Slava/Dual-View-Slava/Reporting"
IMAGE_FOLDER = "/media/volume/Slava/MIMIC_Dataset224"
INPUT_JSON = "slava_llava_split_test.json"
OUTPUT_JSON = "slava_llava_predict.json"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TORCH_DTYPE = torch.float32
MANUAL_SEED = 42

REPORT_INSTRUCTIONS = [
    "<image> Carefully analyze the frontal and lateral chest X-rays. Describe all observed abnormalities in detail, including anatomical location and severity. Conclude with a structured impression.",
    "<image> Generate a detailed radiology report including both FINDINGS and IMPRESSION. Focus especially on any abnormal features or clinically relevant signs.",
    "<image> Write a complete and structured report for these dual-view CXRs. Prioritize abnormal observations and explain their possible clinical implications.",
    "<image> Interpret these chest X-rays. Clearly describe all abnormal findings and avoid assumptions of normality unless explicitly visible. Provide a structured radiology report.",
    "<image> Create a full radiology report with FINDINGS and IMPRESSION. Do not omit any subtle or mild abnormalities seen in the images.",
    "<image> Analyze the given chest X-rays (frontal and lateral). List all pathological findings, even minor ones, and summarize them in the impression.",
    "<image> Provide a radiology report describing all radiographic abnormalities in detail. Include location, extent, and possible causes. Format it as FINDINGS and IMPRESSION.",
    "<image> You are a radiologist. Examine these CXRs and report every deviation from normal. Follow a standard structured reporting format.",
    "<image> Describe all visible pathologies and incidental findings in this CXR study. Use clinical terminology and conclude with a professional impression.",
    "<image> These are chest X-rays showing potential abnormalities. Write a radiology report that prioritizes accuracy and completeness, mentioning all visible signs of disease."
]

KEYWORDS = [
    "pneumonia", "effusion", "pneumothorax", "consolidation", "atelectasis", "edema", "opacity", "pleural",
    "pulmonary", "lung", "lungs", "cardiomegaly", "enlargement", "focal", "infiltrate", "interstitial",
    "scarring", "fibrosis", "calcification", "mass", "nodule", "hyperinflation", "collapse", "thickening",
    "blunting", "catheter", "tubes", "lines", "hernia", "fracture", "kyphosis", "lucency", "emphysema",
    "silhouette", "hilar", "cardiomediastinal", "mediastinal", "vascular congestion"
]

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class LlavaWithGenerate(LlavaPhiForCausalLM, GenerationMixin):
    def generate(self, *args, **kwargs):
        self.image_features = None
        self.images = kwargs.pop("images", None)
        return super().generate(*args, **kwargs)

def load_model_and_tokenizer():
    disable_torch_init()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = LlavaWithGenerate.from_pretrained(MODEL_PATH, torch_dtype=TORCH_DTYPE, low_cpu_mem_usage=True)
    model = model.to(DEVICE)
    model.model.bypass_vision_tower = False
    return model, tokenizer

IMAGE_PROCESSOR = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")

def expand2square(pil_img, background_color=(0, 0, 0)):
    w, h = pil_img.size
    if w == h:
        return pil_img
    result = Image.new(pil_img.mode, (max(w, h), max(w, h)), background_color)
    result.paste(pil_img, ((max(w, h) - w) // 2, (max(w, h) - h) // 2))
    return result

def prepare_images(frontal_path, lateral_path):
    image_paths = [frontal_path, lateral_path]
    images = []
    for img_path in image_paths:
        img = Image.open(os.path.join(IMAGE_FOLDER, img_path)).convert("RGB")
        img = expand2square(img, tuple(int(x * 255) for x in IMAGE_PROCESSOR.image_mean))
        image_tensor = IMAGE_PROCESSOR(img, return_tensors="pt")['pixel_values'][0]
        images.append(image_tensor)
    return torch.stack(images)

def prepare_conversation(prompt):
    template_name = next((name for name in ["llava_phi", "phi", "v0", "med", "vicuna_v1"] if name in conv_templates), None)
    conv = conv_templates[template_name].copy()
    conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + prompt)
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()

def is_low_confidence(text: str, min_word_count=40):
    return len(text.split()) < min_word_count or "impression:" not in text.lower()

def must_mention_keywords(reference: str, prediction: str, keywords=KEYWORDS):
    ref_keywords = [kw for kw in keywords if kw in reference.lower()]
    return any(kw not in prediction.lower() for kw in ref_keywords)

def is_semantically_mismatched(reference, prediction, threshold=0.85):
    if not prediction.strip() or not reference.strip():
        return True, 0.0
    try:
        P, R, F1 = bert_score([prediction], [reference], lang="en", rescale_with_baseline=False)
        return F1[0].item() < threshold, F1[0].item()
    except Exception as e:
        logger.warning(f"BERTScore failed: {e}")
        return True, 0.0

def generate_response(model, tokenizer, images, prompt, reference, max_retries=2):
    regen_reason = None
    best_f1 = 0.0
    generation_settings = [
        {"temperature": 0.4, "top_p": 0.9, "max_new_tokens": 768},
        {"temperature": 0.2, "top_p": 0.85, "max_new_tokens": 896},
        {"temperature": 0.1, "top_p": 0.80, "max_new_tokens": 1024}
    ]

    for attempt in range(max_retries + 1):
        settings = generation_settings[min(attempt, len(generation_settings) - 1)]
        full_prompt = prepare_conversation(prompt)
        tokenized = tokenizer(
            full_prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=tokenizer.model_max_length - 512
        )
        input_ids = tokenized["input_ids"].to(DEVICE)
        attention_mask = tokenized["attention_mask"].to(DEVICE)
        images = images.unsqueeze(0).to(DEVICE)

        with torch.inference_mode(), torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=TORCH_DTYPE == torch.bfloat16):
            output_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                images=images,
                do_sample=True,
                temperature=settings["temperature"],
                top_p=settings["top_p"],
                max_new_tokens=settings["max_new_tokens"],
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )

        gen_tokens = output_ids[0, input_ids.shape[1]:]
        decoded = tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()

        if is_low_confidence(decoded):
            regen_reason = "low_confidence"
            continue

        if must_mention_keywords(reference, decoded):
            regen_reason = "missing_keywords"
            continue

        mismatch, f1_score = is_semantically_mismatched(reference, decoded)
        best_f1 = f1_score
        if mismatch:
            regen_reason = f"low_bertscore ({f1_score:.3f})"
            continue

        return decoded, (attempt > 0), None, round(f1_score, 4)

    return decoded, True, regen_reason or "max_attempts", round(best_f1, 4)

def process_validation():
    logger.info("📦 Loading model and tokenizer...")
    model, tokenizer = load_model_and_tokenizer()
    torch.manual_seed(MANUAL_SEED)

    logger.info("📖 Reading annotation file...")
    with open(INPUT_JSON, "r") as f:
        val_data = json.load(f)

    logger.info("🗂️  Checking for existing output...")
    existing_predictions = {}
    if os.path.exists(OUTPUT_JSON):
        with open(OUTPUT_JSON, "r") as f:
            for item in json.load(f):
                key = f"{item['frontal']}|{item['lateral']}"
                existing_predictions[key] = item

    processed_keys = set(existing_predictions.keys())
    logger.info(f"🔁 Found {len(processed_keys)} previously processed entries. Skipping them...")

    logger.info("🧚 Running inference...")
    for entry in tqdm(val_data, desc="Processing"):
        try:
            frontal = entry.get("frontal")
            lateral = entry.get("lateral")
            key = f"{frontal}|{lateral}"

            if key in processed_keys:
                continue

            findings = entry.get("findings")
            impression = entry.get("impression")
            reference_text = f"{findings}. IMPRESSION: {impression}" if findings and impression else ""
            prompt = random.choice(REPORT_INSTRUCTIONS)

            if not frontal or not lateral or not reference_text:
                continue

            images = prepare_images(frontal, lateral)
            prediction, was_regenerated, regen_reason, bert_f1 = generate_response(
                model, tokenizer, images, prompt, reference_text
            )

            result = {
                "frontal": frontal,
                "lateral": lateral,
                "prompt": prompt,
                "reference": reference_text,
                "prediction": prediction,
                "regenerated": was_regenerated,
                "regeneration_reason": regen_reason,
                "bertscore_f1": bert_f1
            }

            existing_predictions[key] = result

            with open(OUTPUT_JSON, "w") as f:
                json.dump(list(existing_predictions.values()), f, indent=2)

        except Exception as e:
            logger.warning(f"[Skip] Error processing {entry.get('frontal')}: {str(e)}")

    logger.info(f"✅ Finished. Total entries saved: {len(existing_predictions)}")

if __name__ == "__main__":
    process_validation()
