In [None]:
import torch
from PIL import Image
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path

def setup_model_on_device(model_path, device_id=0):
    # Set the device
    torch.cuda.set_device(device_id)
    device = f'cuda:{device_id}'
    
    # Disable torch initialization for faster loading
    disable_torch_init()
    
    # Load the model and components
    processor = load_pretrained_model(
        model_path=model_path,
        model_base=None,
        model_name='llava-med-v1.5-mistral-7b'
    )
    
    # Unpack processor
    tokenizer, model, image_processor, context_len = processor
    
    # Move model to specified device
    model = model.to(device)
    
    return (image_processor, tokenizer), model, device

def process_image_and_generate(image_path, prompt, processor, model, vlm_args):
    # Unpack processor
    image_processor, tokenizer = processor
    
    # Load and process image
    raw_image = Image.open(image_path)
    image_tensor = image_processor.preprocess(
        raw_image, 
        return_tensors="pt"
    )["pixel_values"][0]
    
    # Prepare inputs
    inputs = tokenizer([prompt])
    input_ids = torch.as_tensor(inputs.input_ids).to(model.device)
    
    # Generate response
    with torch.inference_mode():
        output = model.generate(
            input_ids,
            images=image_tensor.unsqueeze(0).half().to(model.device),
            max_new_tokens=vlm_args["max_new_tokens"],
        )
    
    # Decode response
    result = tokenizer.batch_decode(output, skip_special_tokens=True)[0].strip()
    return result

def main():
    model_path = '/share/ssddata/sarimhashmi/posix/llava_med/llava-med'
    image_path = '/share/ssddata/sarimhashmi/iuxray/vanillah_really/matching_subset_images/CXR3716_IM-1856-2001.png'
    # prompt = "[INST] <image> what is this image of? [/INST]"
    prompt = "[INST] <image> Does the X-ray image show any signs of pneumothorax? Please choose from the following two options: [yes, no] [/INST]"
    
    # Configuration for generation
    vlm_args = {
        "max_new_tokens": 512
    }
    
    # Setup model with specific device
    processor, model, device = setup_model_on_device(model_path, device_id=1)
    
    # Process image and generate response
    response = process_image_and_generate(image_path, prompt, processor, model, vlm_args)
    print("Model Response:", response)

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.82it/s]
Some weights of the model checkpoint at /share/ssddata/sarimhashmi/posix/llava_med/llava-med were not used when initializing LlavaMistralForCausalLM: ['model.vision_tower.vision_tower.vision_model.encoder.layers.11.self_attn.v_proj.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.20.self_attn.out_proj.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.9.self_attn.v_proj.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.8.self_attn.v_proj.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.20.mlp.fc1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.2.mlp.fc2.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.16.self_attn.out_proj.weight', 'm

Model Response: The chest x-ray image does not show any signs of pneumothorax.


In [1]:
import torch
from PIL import Image
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from dataclasses import dataclass
import pandas as pd
import Levenshtein
import json
from tqdm.auto import tqdm
from typing import Tuple, List, Dict
import os
from pathlib import Path
import logging
from datetime import datetime

@dataclass
class PosixConfig:
    max_new_tokens: int = 20
    device_id: int = 0
    batched: bool = False

@dataclass
class PosixTrace:
    prompts: list
    responses: list
    logprob_matrices: list
    prompt_sensitivities: list
    posix: float

class LLaVaMedModel:
    def __init__(self, model_path: str, device_id: int = 0):
        """Initialize LLaVA-med model and components."""
        # Set the device
        torch.cuda.set_device(device_id)
        self.device = f'cuda:{device_id}'
        
        # Disable torch initialization for faster loading
        disable_torch_init()
        
        # Load the model and components
        processor = load_pretrained_model(
            model_path=model_path,
            model_base=None,
            model_name='llava-med-v1.5-mistral-7b'
        )
        
        # Unpack processor
        self.tokenizer, self.model, self.image_processor, self.context_len = processor
        
        # Move model to specified device
        self.model = self.model.to(self.device)
    
    def format_prompt(self, question: str) -> str:
        """Format the question using LLaVA's instruction template."""
        return f"[INST] <image> {question} [/INST]"
    
    def get_responses(self, image_path: str, prompts: list[str], max_new_tokens: int = 20) -> Tuple[list[list[int]], list[str], list[int]]:
        """Generate responses for a set of prompts."""
        response_tokens = []
        responses = []
        instruction_lengths = []
        
        # Load and process image once
        raw_image = Image.open(image_path)
        image_tensor = self.image_processor.preprocess(
            raw_image, 
            return_tensors="pt"
        )["pixel_values"][0]
        
        for prompt in prompts:
            # Format prompt with instruction template
            formatted_prompt = self.format_prompt(prompt)
            
            # Prepare inputs
            inputs = self.tokenizer([formatted_prompt])
            input_ids = torch.as_tensor(inputs.input_ids).to(self.device)
            instruction_lengths.append(len(input_ids[0]))
            
            # Generate response
            with torch.inference_mode():
                output = self.model.generate(
                    input_ids,
                    images=image_tensor.unsqueeze(0).half().to(self.device),
                    max_new_tokens=max_new_tokens,
                )
            
            # Get response tokens and text
            response_tokens.append(output[0].tolist())
            response_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
            responses.append(response_text)
        
        return response_tokens, responses, instruction_lengths
    
    def compute_log_probabilities(self, image_path: str, prompt: str, response_tokens: list[int], instruction_length: int) -> float:
        """Compute log probabilities for a prompt-response pair."""
        # Load and process image
        raw_image = Image.open(image_path)
        image_tensor = self.image_processor.preprocess(
            raw_image, 
            return_tensors="pt"
        )["pixel_values"][0]
        
        # Format prompt with instruction template
        formatted_prompt = self.format_prompt(prompt)
        
        # Prepare inputs
        inputs = self.tokenizer([formatted_prompt])
        input_ids = torch.as_tensor(inputs.input_ids).to(self.device)
        
        # Combine input and response tokens
        full_sequence = torch.cat([
            input_ids,
            torch.tensor([response_tokens]).to(self.device)
        ], dim=1)
        
        with torch.inference_mode():
            outputs = self.model(
                full_sequence,
                images=image_tensor.unsqueeze(0).half().to(self.device)
            )
            logits = outputs.logits
            
            # Get logits for response portion only
            response_logits = logits[:, instruction_length-1:instruction_length-1+len(response_tokens), :]
            log_probs = torch.log_softmax(response_logits, dim=-1)
            
            final_logprob = 0.0
            for i, token in enumerate(response_tokens):
                token_logprob = log_probs[0, i, token].item()
                final_logprob += token_logprob
                
        return final_logprob

class PromptSensitivityAnalyzer:
    @staticmethod
    def calculate_char_level_similarity(str1: str, str2: str) -> float:
        """Calculate character-level similarity using Levenshtein distance."""
        distance = Levenshtein.distance(str1, str2)
        max_len = max(len(str1), len(str2))
        return 1 - (distance / max_len)
    
    @staticmethod
    def extract_prompt_sets(json_data: Dict) -> List[Dict[str, List[str]]]:
        """Extract prompt sets from the JSON data."""
        prompt_sets = []
        
        variations = [json_data[f'variation_{i}'] for i in range(1, 11)]
        current_set = {
            'original': json_data['question'],
            'variations': variations,
            'answer': json_data['answer'],
            'image': json_data['image']
        }
        prompt_sets.append(current_set)
        
        return prompt_sets

def get_llava_posix(
    model: LLaVaMedModel,
    prompt_sets: List[Dict[str, List[str]]],
    config: PosixConfig,
    verbose: bool = False
) -> Tuple[float, PosixTrace, List[Dict]]:
    """Calculate POSIX scores using LLaVA-med model."""
    N_prompt_sets = len(prompt_sets)
    
    responses = []
    response_tokens = []
    logprob_matrices = []
    prompt_sensitivities = []
    detailed_scores = []
    
    pbar = tqdm(range(N_prompt_sets))
    for i in range(N_prompt_sets):
        prompt_set = prompt_sets[i]
        all_prompts = [prompt_set['original']] + prompt_set['variations']
        image_path = prompt_set['image']
        
        if verbose:
            print(f"\nProcessing prompt set {i+1}")
            print(f"Original prompt: {model.format_prompt(all_prompts[0])}")
            print(f"First variation: {model.format_prompt(all_prompts[1])}")
        
        # Generate responses for all prompts
        set_tokens, set_responses, instruction_lengths = model.get_responses(
            image_path,
            all_prompts,
            max_new_tokens=config.max_new_tokens
        )
        
        response_tokens.append(set_tokens)
        responses.append(set_responses)
        
        N_prompts = len(all_prompts)
        logprob_matrix = [[0 for _ in range(N_prompts)] for _ in range(N_prompts)]
        
        # Calculate log probabilities for all prompt-response pairs
        for i in range(N_prompts):
            for j in range(N_prompts):
                logprob = model.compute_log_probabilities(
                    image_path,
                    all_prompts[i],
                    set_tokens[j],
                    instruction_lengths[i]
                )
                logprob_matrix[i][j] = logprob
        
        logprob_matrices.append(logprob_matrix)
        
        # Calculate prompt sensitivity
        psi = 0.0
        scores = {}
        for i in range(N_prompts):
            for j in range(N_prompts):
                if i != j:
                    response_length = len(set_tokens[j])
                    diff = abs(logprob_matrix[i][j] - logprob_matrix[i][i]) / response_length
                    psi += diff
                    if i == 0:
                        scores[f"Variation {j}"] = diff
                        
        prompt_sensitivity = psi / (N_prompts * (N_prompts - 1))
        prompt_sensitivities.append(prompt_sensitivity)
        detailed_scores.append(scores)
        
        if verbose:
            print(f"\nPrompt sensitivity: {prompt_sensitivity:.4f}")
        
        pbar.update(1)
    
    posix = sum(prompt_sensitivities) / N_prompt_sets
    
    trace = PosixTrace(
        [set['original'] for set in prompt_sets],
        responses,
        logprob_matrices,
        prompt_sensitivities,
        posix
    )
    
    return posix, trace, detailed_scores

def main():
    # File paths and configuration
    model_path = "/share/ssddata/sarimhashmi/posix/llava_med/llava-med"
    input_dir = "/share/ssddata/sarimhashmi/posix_thesis/new_improve_stuff/Thesis/spell_demo"
    output_dir = "/share/ssddata/sarimhashmi/posix_thesis/new_improve_stuff/llava_med_i_swear_final/spell_error_result"
    
    os.makedirs(output_dir, exist_ok=True)
    
    # Initialize model and config
    model = LLaVaMedModel(model_path, device_id=2)
    config = PosixConfig(max_new_tokens=50, device_id=2)
    
    # Get list of JSON files to process
    json_files = sorted(Path(input_dir).glob("question_*_variants.json"))
    print(f"Found {len(json_files)} files to process")
    
    # Process each file
    all_results = []
    for file_path in tqdm(json_files, desc="Processing files"):
        try:
            # Read JSON file
            with open(file_path, 'r') as f:
                json_data = json.load(f)
            
            # Process file
            analyzer = PromptSensitivityAnalyzer()
            prompt_sets = analyzer.extract_prompt_sets(json_data)
            
            # Calculate POSIX scores
            posix, trace, detailed_scores = get_llava_posix(
                model,
                prompt_sets,
                config,
                verbose=False
            )
            
            # Save individual result
            output_file = os.path.join(output_dir, f"{file_path.stem}_results.json")
            with open(output_file, 'w') as f:
                json.dump({
                    "file_name": file_path.name,
                    "overall_posix": posix,
                    "detailed_scores": detailed_scores,
                    "trace": trace.__dict__
                }, f, indent=4)
            
            # Store summary
            all_results.append({
                "file_name": file_path.name,
                "posix_score": posix,
                "status": "success"
            })
            
        except Exception as e:
            print(f"Error processing {file_path.name}: {str(e)}")
            all_results.append({
                "file_name": file_path.name,
                "posix_score": None,
                "status": "failed",
                "error": str(e)
            })
            
        # Optional: Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    # Save summary results
    summary_df = pd.DataFrame(all_results)
    summary_df.to_csv(os.path.join(output_dir, "all_results_summary.csv"), index=False)

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.81it/s]
Some weights of the model checkpoint at /share/ssddata/sarimhashmi/posix/llava_med/llava-med were not used when initializing LlavaMistralForCausalLM: ['model.vision_tower.vision_tower.vision_model.encoder.layers.3.self_attn.k_proj.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.17.layer_norm1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.8.self_attn.k_proj.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.12.layer_norm1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.18.self_attn.out_proj.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.18.self_attn.v_proj.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.11.layer_norm2.weight', 

Found 1 files to process


Processing files:   0%|          | 0/1 [00:00<?, ?it/s]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to