# In [None]:
import os
import json
import random
import torch
import logging
from PIL import Image
from tqdm import tqdm

from transformers import AutoTokenizer, CLIPImageProcessor
from llava_phi.model import LlavaPhiForCausalLM
from llava_phi.constants import DEFAULT_IMAGE_TOKEN
from llava_phi.conversation import conv_templates
from llava_phi.utils import disable_torch_init
from transformers.generation.utils import GenerationMixin
from transformers import logging as hf_logging


# --- Run configuration -------------------------------------------------------
# Fine-tuned checkpoint directory (model weights + tokenizer).
MODEL_PATH = "/media/volume/Slava/Dual-View-Slava-Final"
# Root folder that the "frontal"/"lateral" paths in INPUT_JSON are relative to.
IMAGE_FOLDER = "/media/volume/Slava/MIMIC_Dataset224"
# Annotation file to run inference on.
INPUT_JSON = "slava_llava_split_test.json"
# Predictions file; also read back on start-up to resume an interrupted run.
# NOTE(review): "Duaal" looks like a typo, but renaming it would orphan any
# existing partial output that the resume logic relies on.
OUTPUT_JSON = "Duaal_slava_llava_predict.json"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# float16 is only reliably supported on GPU: many CPU kernels have no Half
# implementation, so fall back to float32 when running on the CPU.
TORCH_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
# Seed for the torch and Python RNGs (sampling and prompt selection).
MANUAL_SEED = 42

# Silence the transformers library below ERROR level (hides its info/warning chatter).
hf_logging.set_verbosity_error()

# Pool of instruction prompts; process_validation() picks one at random per entry.
# Every string already begins with "<image> ", and prepare_conversation() additionally
# prepends DEFAULT_IMAGE_TOKEN (presumably also "<image>"), so the final prompt may
# carry two image markers. NOTE(review): possibly intentional (one per view:
# frontal + lateral) to match the training format — confirm before editing these.
REPORT_INSTRUCTIONS = [
    "<image> You are a board-certified radiologist. Analyze the frontal and lateral chest X-rays. Identify and describe all radiographic findings, including location, size, density, and distribution. Provide a structured report with FINDINGS and IMPRESSION. Ensure clinical relevance and diagnostic precision.",
    "<image> Evaluate the dual-view CXRs thoroughly. Identify all signs of disease, subtle or prominent. For each finding, describe anatomical position, severity, and possible differential diagnoses. Conclude with a well-reasoned and structured IMPRESSION.",
    "<image> Generate a detailed radiology report using standard clinical format. Include all abnormal and incidental findings in the FINDINGS section. Summarize key diagnostic insights in the IMPRESSION. Avoid generic language; favor precise anatomical and pathological descriptors.",
    "<image> Carefully examine the frontal and lateral chest X-rays. Write a complete radiology report structured into: FINDINGS (organized by organ system and zones) and IMPRESSION (summary with clinical prioritization). Include all deviations from normal, no matter how subtle.",
    "<image> Interpret these dual-view CXRs. Document all radiologic abnormalities by describing their appearance, location (lung zones, mediastinum, pleura, bones), and clinical implications. Format the response as a formal radiology report with clear headings for FINDINGS and IMPRESSION.",
    "<image> Examine the provided chest X-rays from both frontal and lateral views. Accurately identify and describe any abnormal radiologic signs, even subtle or borderline cases. Structure the report with FINDINGS and a diagnostic IMPRESSION.",
    "<image> Perform a radiological assessment of these dual-view CXRs. List all notable observations and pathological signs, with attention to anatomical detail and clinical context. Format the response as a full report: FINDINGS followed by IMPRESSION.",
    "<image> Review the chest radiographs and generate a report that includes all visible abnormalities. Pay attention to symmetry, lung markings, cardiac silhouette, and bony structures. Use the standard format with FINDINGS and IMPRESSION.",
    "<image> You are an expert thoracic radiologist. Describe all pathological and incidental findings in these frontal and lateral CXRs. Be thorough and concise. Conclude with a structured IMPRESSION summarizing the clinical picture.",
    "<image> Analyze these chest X-rays using your clinical expertise. Report every abnormality using specific radiologic terminology. Clearly differentiate FINDINGS and IMPRESSION. Include zone-wise, side-wise, and severity-based descriptions."
]


# Module-level logger for progress and warning messages.
logger = logging.getLogger(__name__)
# Configure the root logger at INFO so logger.info(...) output is visible
# (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)


class LlavaWithGenerate(LlavaPhiForCausalLM, GenerationMixin):
    """LLaVA-Phi causal LM whose ``generate`` accepts an ``images=`` keyword.

    The images are NOT forwarded to ``GenerationMixin.generate``. Instead they are
    removed from the keyword arguments and stored on the instance as ``self.images``,
    and ``self.image_features`` is cleared, before delegating to the parent
    ``generate``. Presumably the underlying model reads ``self.images`` during its
    forward pass — TODO confirm against llava_phi's model code.
    """

    def generate(self, *args, **kwargs):
        pixel_batch = kwargs.pop("images", None)  # None when the caller passes no images
        self.image_features = None
        self.images = pixel_batch
        return super().generate(*args, **kwargs)


def load_model_and_tokenizer():
    """Load the tokenizer and fine-tuned model from MODEL_PATH, ready for inference.

    The tokenizer is guaranteed to contain the "<image>" token and a pad token
    (EOS is reused if none is set). The model is loaded in TORCH_DTYPE, its
    embedding table is grown if the tokenizer outgrew it, and it is moved to DEVICE.
    On CUDA it is additionally wrapped with ``torch.compile`` when available.

    Returns:
        tuple: ``(model, tokenizer, image_token_id)`` where ``image_token_id`` is the
        vocabulary id of "<image>".

    Raises:
        ValueError: If the "<image>" id does not fit in the model's input-embedding table.
    """
    disable_torch_init()
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    if "<image>" not in tokenizer.get_vocab():
        tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
    if tokenizer.pad_token is None:
        # Generation here is single-sequence only, so reusing EOS as pad is sufficient.
        tokenizer.pad_token = tokenizer.eos_token

    model = LlavaWithGenerate.from_pretrained(MODEL_PATH, torch_dtype=TORCH_DTYPE, low_cpu_mem_usage=True)
    # Only ever grow the embedding table. The checkpoint's table may already be larger
    # than len(tokenizer) (e.g. a padded vocabulary); shrinking it would needlessly drop
    # rows of the trained embedding/output matrices.
    if len(tokenizer) > model.get_input_embeddings().num_embeddings:
        model.resize_token_embeddings(len(tokenizer))
    model = model.to(DEVICE)

    model.model.bypass_vision_tower = False
    image_token_id = tokenizer.convert_tokens_to_ids("<image>")
    # NOTE(review): the vocab id of "<image>" is used as the image marker (not a negative
    # sentinel index); confirm the model's multimodal splice looks for this id.
    num_embeddings = model.get_input_embeddings().num_embeddings
    # Explicit check instead of `assert` (stripped under `python -O`); an out-of-range id
    # would otherwise only fail later inside the embedding lookup during generation.
    if image_token_id >= num_embeddings:
        raise ValueError(
            f"<image> token id {image_token_id} is outside the embedding table "
            f"({num_embeddings} rows)"
        )

    if hasattr(torch, "compile") and torch.cuda.is_available():
        # torch.compile is lazy: most compilation errors surface on first call, not here.
        # NOTE(review): callers use model.generate, reached via the compiled wrapper's
        # attribute forwarding — confirm the compiled forward is actually exercised.
        try:
            model = torch.compile(model)
        except Exception as e:
            logger.warning("torch.compile() failed: %s", e)

    return model, tokenizer, image_token_id


def prepare_conversation(prompt):
    """Wrap ``prompt`` in the first registered conversation template and render it.

    Candidate templates are tried in order: "llava_phi", "phi", "v0", "med",
    "vicuna_v1". The user turn is ``DEFAULT_IMAGE_TOKEN + "\\n" + prompt``; the
    assistant turn is appended with ``None`` so the rendered prompt presumably ends
    where the model should begin its answer (depends on the template).

    Args:
        prompt: Instruction text for the user turn.

    Returns:
        str: The rendered prompt from ``conv.get_prompt()``.

    Raises:
        KeyError: If none of the candidate templates exists in ``conv_templates``.
    """
    candidates = ("llava_phi", "phi", "v0", "med", "vicuna_v1")
    template_name = next((name for name in candidates if name in conv_templates), None)
    if template_name is None:
        # Previously this fell through to conv_templates[None] -> an opaque "KeyError: None".
        raise KeyError(f"No conversation template found; tried: {', '.join(candidates)}")
    conv = conv_templates[template_name].copy()  # copy so the shared template is not mutated
    conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + prompt)
    conv.append_message(conv.roles[1], None)
    return conv.get_prompt()


# CLIP preprocessor shared by prepare_images(); loaded once when this module/cell executes.
# NOTE(review): hard-coded to openai/clip-vit-large-patch14-336 rather than read from
# MODEL_PATH — presumably the vision tower used in training; confirm they match.
IMAGE_PROCESSOR = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")

def expand2square(pil_img, background_color=(0, 0, 0)):
    """Pad an image to a square canvas of side ``max(width, height)``.

    Square inputs are returned unchanged (the same object). Otherwise a new image of
    the same mode is filled with ``background_color`` and the original is pasted
    centred along its shorter axis.

    Args:
        pil_img: A PIL image.
        background_color: Fill colour for the padding; must suit ``pil_img.mode``.

    Returns:
        The original image if already square, else the new padded image.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    # Offset is zero along the longer axis and centres the image along the shorter one.
    canvas.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return canvas

def prepare_images(frontal_path, lateral_path):
    """Load, square-pad and preprocess a frontal/lateral image pair.

    Each path is resolved relative to IMAGE_FOLDER, converted to RGB, padded with
    ``expand2square`` using the processor's mean colour, and run through
    IMAGE_PROCESSOR.

    Args:
        frontal_path: Frontal view path, relative to IMAGE_FOLDER.
        lateral_path: Lateral view path, relative to IMAGE_FOLDER.

    Returns:
        torch.Tensor: The two processed images stacked along a new leading dimension
        (frontal first), made contiguous — presumably shape (2, 3, H, W).
    """
    pad_rgb = tuple(int(channel_mean * 255) for channel_mean in IMAGE_PROCESSOR.image_mean)

    def _load(relative_path):
        # All work happens inside the `with` so the file handle is closed afterwards.
        with Image.open(os.path.join(IMAGE_FOLDER, relative_path)) as raw:
            squared = expand2square(raw.convert("RGB"), pad_rgb)
            return IMAGE_PROCESSOR(squared, return_tensors="pt")['pixel_values'][0]

    pair = [_load(frontal_path), _load(lateral_path)]
    return torch.stack(pair).contiguous()


def tokenizer_image_token(prompt, tokenizer, image_token_index, return_tensors=None):
    """Tokenize ``prompt``, substituting ``image_token_index`` for each "<image>" marker.

    The text between markers is tokenized segment by segment. If the first segment
    starts with ``tokenizer.bos_token_id``, that BOS id is emitted once at the front
    and the first token of every segment is dropped (presumably the per-segment BOS
    the tokenizer adds — TODO confirm for this tokenizer).

    Args:
        prompt: Text possibly containing literal "<image>" markers.
        tokenizer: Callable returning an object with ``input_ids``; must expose ``bos_token_id``.
        image_token_index: Id inserted between consecutive segments.
        return_tensors: ``'pt'`` to get a 1-D ``torch.long`` tensor; anything else yields a list.

    Returns:
        list[int] | torch.Tensor: The combined token ids.
    """
    segments = [tokenizer(piece).input_ids for piece in prompt.split('<image>')]

    head = segments[0]
    strip = 1 if (len(head) > 0 and head[0] == tokenizer.bos_token_id) else 0
    token_ids = head[:1] if strip else []

    for idx, segment in enumerate(segments):
        if idx:
            # Exactly one image id between every pair of neighbouring segments.
            token_ids.append(image_token_index)
        token_ids.extend(segment[strip:])

    if return_tensors == 'pt':
        return torch.tensor(token_ids, dtype=torch.long)
    return token_ids


def generate_response(model, tokenizer, images, prompt, reference, image_token_index):
    """Generate a report for one frontal/lateral image pair.

    Args:
        model: Model from ``load_model_and_tokenizer``; its ``generate`` must accept ``images=``.
        tokenizer: Matching tokenizer, used to encode the prompt and decode the output.
        images: Stacked image tensor from ``prepare_images``; a batch dimension is added here.
        prompt: Instruction text, wrapped by ``prepare_conversation`` before tokenization.
        reference: Unused; kept so existing call sites remain valid.
        image_token_index: Token id substituted for every "<image>" marker in the prompt.

    Returns:
        str: The decoded, stripped continuation (prompt tokens excluded), or the sentinel
        ``"[ERROR: CUDA index bug]"`` if generation raised a RuntimeError whose message
        mentions ``indexSelectLargeIndex``.

    Raises:
        RuntimeError: Any other runtime error raised during generation.
    """
    import contextlib

    full_prompt = prepare_conversation(prompt)
    input_ids = tokenizer_image_token(
        full_prompt, tokenizer, image_token_index, return_tensors="pt"
    ).unsqueeze(0).to(DEVICE)
    # A single unpadded sequence: every position is real. Building the mask from
    # pad_token_id is unsafe because pad_token may be the EOS token, which would wrongly
    # mask any EOS id that occurs inside the prompt.
    attention_mask = torch.ones_like(input_ids)
    # Cast pixels to the dtype the weights were loaded in (TORCH_DTYPE) so the vision
    # tower never sees mixed dtypes when autocast is inactive.
    images = images.unsqueeze(0).to(device=DEVICE, dtype=TORCH_DTYPE)

    # Autocast only when actually running on CUDA; on CPU run in the weights' own dtype.
    device_type = torch.device(DEVICE).type
    amp_ctx = (
        torch.autocast(device_type="cuda", dtype=TORCH_DTYPE)
        if device_type == "cuda"
        else contextlib.nullcontext()
    )

    with torch.inference_mode(), amp_ctx:
        try:
            output_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                images=images,
                do_sample=True,
                temperature=0.2,
                top_p=0.8,
                max_new_tokens=1024,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        except RuntimeError as e:
            if "indexSelectLargeIndex" in str(e):
                return "[ERROR: CUDA index bug]"
            raise

    # Output includes the prompt tokens; keep only the newly generated ones.
    gen_tokens = output_ids[0, input_ids.shape[1]:]
    return tokenizer.decode(gen_tokens, skip_special_tokens=True).strip()


# Predictions beginning with this prefix are failure sentinels returned by generate_response().
_ERROR_PREDICTION_PREFIX = "[ERROR"


def _load_existing_predictions(path):
    """Return saved predictions from ``path`` keyed by "frontal|lateral" ({} if absent).

    Failure sentinels are dropped so that those entries are retried on this run.
    """
    predictions = {}
    if not os.path.exists(path):
        return predictions
    with open(path, "r", encoding="utf-8") as f:
        for item in json.load(f):
            if str(item.get("prediction", "")).startswith(_ERROR_PREDICTION_PREFIX):
                continue
            predictions[f"{item['frontal']}|{item['lateral']}"] = item
    return predictions


def _save_predictions(path, predictions):
    """Write the ``predictions`` list to ``path`` as indented JSON, atomically."""
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=2)
    # os.replace swaps the file in one step (same directory => same filesystem), so an
    # interruption mid-write cannot leave a truncated file that would break resuming.
    os.replace(tmp_path, path)


def _build_reference(entry):
    """Return "<findings>. IMPRESSION: <impression>", or "" if either field is missing/empty."""
    findings = entry.get("findings")
    impression = entry.get("impression")
    return f"{findings}. IMPRESSION: {impression}" if findings and impression else ""


def process_validation():
    """Generate a report for every annotated image pair in INPUT_JSON, resumably.

    Pairs already present in OUTPUT_JSON are skipped, as are entries missing a view or
    a findings/impression reference. Each new prediction (pair paths, the randomly
    chosen prompt, the reference text and the generated text) is written to
    OUTPUT_JSON immediately, so an interrupted run can be resumed. Per-entry errors
    are logged and the entry skipped.
    """
    logger.info("Loading model and tokenizer...")
    model, tokenizer, image_token_index = load_model_and_tokenizer()
    torch.manual_seed(MANUAL_SEED)
    # Also seed Python's RNG: random.choice below selects the instruction per entry.
    random.seed(MANUAL_SEED)

    logger.info("Reading annotation file...")
    with open(INPUT_JSON, "r", encoding="utf-8") as f:
        val_data = json.load(f)

    logger.info("Checking for existing output...")
    existing_predictions = _load_existing_predictions(OUTPUT_JSON)
    logger.info("Found %d previously processed entries. Skipping them...", len(existing_predictions))

    for entry in tqdm(val_data, desc="Processing"):
        try:
            reference_text = _build_reference(entry)
            frontal = entry.get("frontal")
            lateral = entry.get("lateral")
            key = f"{frontal}|{lateral}"

            # Checking the live dict also skips duplicate pairs within INPUT_JSON.
            if key in existing_predictions:
                continue
            if not frontal or not lateral or not reference_text:
                continue

            prompt = random.choice(REPORT_INSTRUCTIONS)
            images = prepare_images(frontal, lateral)
            prediction = generate_response(model, tokenizer, images, prompt, reference_text, image_token_index)
            if prediction.startswith(_ERROR_PREDICTION_PREFIX):
                # Not persisted: saved keys are skipped on resume, so it could never be retried.
                logger.warning("[Skip] Generation failed for %s: %s", frontal, prediction)
                continue

            existing_predictions[key] = {
                "frontal": frontal,
                "lateral": lateral,
                "prompt": prompt,
                "reference": reference_text,
                "prediction": prediction,
            }
            # Checkpoint after every entry so progress survives interruption.
            _save_predictions(OUTPUT_JSON, list(existing_predictions.values()))

        except Exception as e:
            # Best-effort batch job: log and move on to the next entry.
            logger.warning("[Skip] Error processing %s: %s", entry.get('frontal'), e)

    logger.info("Completed. Saved")

# Entry point: run the batch prediction when executed as a script (or as a notebook cell).
if __name__ == "__main__":
    process_validation()