In [None]:
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForPreTraining
from dataclasses import dataclass
import pandas as pd
import Levenshtein
import json
from tqdm.auto import tqdm
from typing import Tuple, List, Dict
import os
from pathlib import Path
import logging
from datetime import datetime



@dataclass()
class PosixConfig:
    """Settings for a POSIX (prompt sensitivity index) run."""

    # Maximum number of new tokens generated for each prompt.
    max_new_tokens: int = 20
    # Batching flag; not read anywhere in this file.
    batched: bool = False

@dataclass()
class PosixTrace:
    """Intermediate results of a POSIX computation, kept for inspection.

    Field order is significant: instances are constructed positionally.
    """

    # Original prompt of every prompt set.
    prompts: list
    # Per prompt set: decoded responses for the original prompt and each variation.
    responses: list
    # Per prompt set: square matrix whose [i][j] entry is the log-probability
    # of response j under prompt i.
    logprob_matrices: list
    # Per prompt set: its prompt-sensitivity score.
    prompt_sensitivities: list
    # Mean prompt sensitivity across all prompt sets.
    posix: float

    
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

class LlavaModel:
    """Thin wrapper around a LLaVA-NeXT checkpoint used for POSIX scoring.

    Provides sampled generation for several prompts over one image, and the
    log-probability of a given response token sequence under a given
    prompt + image.
    """

    def __init__(self, model_path: str, device: str = "cuda:7"):
        """Load processor and model from ``model_path`` and move the model to ``device``."""
        self.processor = LlavaNextProcessor.from_pretrained(model_path)
        self.model = LlavaNextForConditionalGeneration.from_pretrained(model_path)
        self.device = device
        self.model.to(device)

    @staticmethod
    def _load_image(image_path: str):
        """Read the image fully into memory as RGB and close its file handle."""
        # Image.open is lazy and keeps the file open; convert() loads the
        # pixels so the context manager can release the handle immediately.
        with Image.open(image_path) as img:
            return img.convert("RGB")

    def _build_prompt(self, prompt: str) -> str:
        """Render ``prompt`` plus one image slot through the processor's chat template."""
        conversation = [
            {
                "role": "user",
                "content": [
                    # The template inserts the image slot itself, so drop any
                    # literal "<image>" marker from the prompt text.
                    {"type": "text", "text": prompt.replace("<image>", "").strip()},
                    {"type": "image"},
                ],
            }
        ]
        return self.processor.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_tensors=None,
        )

    def get_responses(self, image_path: str, prompts: list[str], **kwargs) -> Tuple[list[list[int]], list[str], list[int]]:
        """Sample one response per prompt for the image at ``image_path``.

        Args:
            image_path: Path of the image shared by every prompt.
            prompts: Prompt texts.
            **kwargs: ``max_new_tokens`` (default 20) is forwarded to ``generate``.

        Returns:
            ``(token_ids, responses, instruction_lengths)``. For prompt k,
            ``token_ids[k]`` is the full sequence returned by ``generate``
            (prompt ids followed by generated ids), ``instruction_lengths[k]``
            is the number of leading prompt ids, so
            ``token_ids[k][instruction_lengths[k]:]`` are exactly the generated
            ids, and ``responses[k]`` is their decoded text.
        """
        image = self._load_image(image_path)
        max_new_tokens = kwargs.get('max_new_tokens', 20)
        response_tokens = []
        responses = []
        instruction_lengths = []

        for prompt in prompts:
            raw_prompt = self._build_prompt(prompt)
            inputs = self.processor(
                images=image,
                text=raw_prompt,
                return_tensors="pt"
            ).to(self.device)
            # Measure the prompt on the ids generate() actually receives; a
            # separate text-only tokenisation disagrees as soon as the
            # processor expands <image> into per-patch tokens.
            inst_length = int(inputs["input_ids"].shape[1])

            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=0.1,
                    do_sample=True
                )

            sequence = output[0]
            response_tokens.append(sequence.tolist())
            instruction_lengths.append(inst_length)
            # Decode only the continuation rather than splitting the full text.
            responses.append(
                self.processor.decode(sequence[inst_length:], skip_special_tokens=True).strip()
            )

        return response_tokens, responses, instruction_lengths

    def compute_log_probabilties(self, image_path: str, prompt: str, response_tokens: list[int], instruction_length: int) -> float:
        """Return log P(response_tokens | prompt, image), summed over the tokens.

        The response ids are appended directly to the processor's prompt ids
        instead of being decoded to text and re-tokenised. A round trip would
        duplicate the BOS token and may re-segment the text. The scoring
        logits are located from the END of the sequence, which stays correct
        whether image features are expanded by the processor or inside the
        model's forward pass. Both change how many positions precede the
        response.

        Args:
            image_path: Path of the image.
            prompt: Prompt text.
            response_tokens: Token ids of the response to score.
            instruction_length: Accepted for backward compatibility; no longer
                needed because positions are taken from the sequence end.

        Returns:
            Sum of per-token log-probabilities (0.0 for an empty response).
        """
        if not response_tokens:
            return 0.0

        image = self._load_image(image_path)
        raw_prompt = self._build_prompt(prompt)
        inputs = self.processor(
            images=image,
            text=raw_prompt,
            return_tensors="pt"
        ).to(self.device)

        prompt_ids = inputs["input_ids"]
        target = torch.tensor(
            [list(response_tokens)], dtype=prompt_ids.dtype, device=prompt_ids.device
        )
        inputs["input_ids"] = torch.cat([prompt_ids, target], dim=1)
        if "attention_mask" in inputs:
            mask = inputs["attention_mask"]
            inputs["attention_mask"] = torch.cat(
                [mask, torch.ones_like(target, dtype=mask.dtype)], dim=1
            )

        n_resp = target.shape[1]
        with torch.no_grad():
            logits = self.model(**inputs).logits
            # Logit at position t predicts token t+1, so the n_resp positions
            # before the last one score the n_resp appended response tokens.
            response_logits = logits[:, -n_resp - 1:-1, :].float()
            log_probs = torch.log_softmax(response_logits, dim=-1)
            token_log_probs = log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)

        return float(token_log_probs.sum().item())

class PromptSensitivityAnalyzer:
    """Character-level comparison of an original prompt with its variations.

    NOTE(review): a second class with this name is defined later in this
    file and replaces this one once the module has executed.
    """

    @staticmethod
    def calculate_char_level_similarity(str1: str, str2: str) -> float:
        """Return ``1 - levenshtein(str1, str2) / max(len(str1), len(str2))``.

        Two empty strings are identical and score 1.0 rather than dividing
        by zero.
        """
        max_len = max(len(str1), len(str2))
        if max_len == 0:
            return 1.0
        distance = Levenshtein.distance(str1, str2)
        return 1 - (distance / max_len)

    @staticmethod
    def extract_prompt_sets(json_data: List[Dict], num_variations: int = 10) -> List[Dict[str, List[str]]]:
        """Build prompt-set dicts from parsed JSON entries.

        Args:
            json_data: A list of entries, or a single entry dict. Each entry
                needs ``question``, ``answer``, ``image`` and
                ``variation_1`` .. ``variation_<num_variations>`` keys.
            num_variations: How many ``variation_<i>`` keys to read.

        Returns:
            One ``{'original', 'variations', 'answer', 'image'}`` dict per entry.

        Raises:
            KeyError: if an entry lacks one of the required keys.
        """
        entries = [json_data] if isinstance(json_data, dict) else json_data
        return [
            {
                'original': entry['question'],
                'variations': [
                    entry[f'variation_{i}'] for i in range(1, num_variations + 1)
                ],
                'answer': entry['answer'],
                'image': entry['image'],
            }
            for entry in entries
        ]

    @staticmethod
    def analyze_prompt_sensitivity(prompt_sets: List[Dict[str, List[str]]]) -> pd.DataFrame:
        """Return one row per (prompt set, variation) with similarity and sensitivity.

        Sensitivity is ``1 - similarity``; both are rounded to 4 decimals.
        Prompt sets and variations are numbered from 1.
        """
        results = []

        for set_idx, prompt_set in enumerate(prompt_sets, 1):
            original = prompt_set['original']

            for var_idx, variation in enumerate(prompt_set['variations'], 1):
                similarity = PromptSensitivityAnalyzer.calculate_char_level_similarity(original, variation)
                sensitivity = 1 - similarity

                results.append({
                    'Prompt Set': set_idx,
                    'Variation': f'Variation {var_idx}',
                    'Original': original,
                    'Variation Text': variation,
                    'Expected Answer': prompt_set['answer'],
                    'Image Path': prompt_set['image'],
                    'Similarity': round(similarity, 4),
                    'Sensitivity': round(sensitivity, 4)
                })

        return pd.DataFrame(results)


def write_trace_to_json(trace: PosixTrace, output_path: str):
    """Serialise a POSIX trace as a JSON list, one record per prompt set.

    Each record stores the set's original prompt (under the key "prompts"),
    its responses, its log-probability matrix and its prompt sensitivity.
    ``trace.posix`` itself is not written.
    """
    records = [
        {
            "prompts": prompt,
            "responses": trace.responses[idx],
            "log_probability_matrix": trace.logprob_matrices[idx],
            "prompt_sensitivity": trace.prompt_sensitivities[idx],
        }
        for idx, prompt in enumerate(trace.prompts)
    ]
    with open(output_path, "w") as handle:
        json.dump(records, handle, indent=4)

class PromptSensitivityAnalyzer:
    """Character-level comparison of an original prompt with its variations.

    This definition replaces the earlier class of the same name in this file.
    """

    @staticmethod
    def calculate_char_level_similarity(str1: str, str2: str) -> float:
        """Return ``1 - levenshtein(str1, str2) / max(len(str1), len(str2))``.

        Two empty strings are identical and score 1.0 rather than dividing
        by zero.
        """
        max_len = max(len(str1), len(str2))
        if max_len == 0:
            return 1.0
        distance = Levenshtein.distance(str1, str2)
        return 1 - (distance / max_len)

    @staticmethod
    def extract_prompt_sets(json_data: Dict, num_variations: int = 10) -> List[Dict[str, List[str]]]:
        """Build prompt-set dicts from parsed JSON.

        Args:
            json_data: A single entry dict (one file = one question), or a
                list of such entries. Each entry needs ``question``,
                ``answer``, ``image`` and ``variation_1`` ..
                ``variation_<num_variations>`` keys.
            num_variations: How many ``variation_<i>`` keys to read.

        Returns:
            One ``{'original', 'variations', 'answer', 'image'}`` dict per entry.

        Raises:
            KeyError: if an entry lacks one of the required keys.
        """
        entries = [json_data] if isinstance(json_data, dict) else json_data
        prompt_sets = []
        for entry in entries:
            variations = [entry[f'variation_{i}'] for i in range(1, num_variations + 1)]
            prompt_sets.append({
                'original': entry['question'],
                'variations': variations,
                'answer': entry['answer'],
                'image': entry['image'],
            })
        return prompt_sets

    @staticmethod
    def analyze_prompt_sensitivity(prompt_sets: List[Dict[str, List[str]]]) -> pd.DataFrame:
        """Return one row per (prompt set, variation) with similarity and sensitivity.

        Sensitivity is ``1 - similarity``; both are rounded to 4 decimals.
        Prompt sets and variations are numbered from 1.
        """
        results = []

        for set_idx, prompt_set in enumerate(prompt_sets, 1):
            original = prompt_set['original']

            for var_idx, variation in enumerate(prompt_set['variations'], 1):
                similarity = PromptSensitivityAnalyzer.calculate_char_level_similarity(original, variation)
                sensitivity = 1 - similarity

                results.append({
                    'Prompt Set': set_idx,
                    'Variation': f'Variation {var_idx}',
                    'Original': original,
                    'Variation Text': variation,
                    'Expected Answer': prompt_set['answer'],
                    'Image Path': prompt_set['image'],
                    'Similarity': round(similarity, 4),
                    'Sensitivity': round(sensitivity, 4)
                })

        return pd.DataFrame(results)

    @staticmethod
    def get_summary_statistics(df: pd.DataFrame) -> pd.DataFrame:
        """Return mean/std/min/max of 'Sensitivity' per 'Prompt Set', rounded to 4 decimals."""
        summary = df.groupby('Prompt Set').agg({
            'Sensitivity': ['mean', 'std', 'min', 'max']
        }).round(4)

        # Flatten the (column, statistic) MultiIndex produced by agg().
        summary.columns = ['Mean Sensitivity', 'Std Sensitivity', 'Min Sensitivity', 'Max Sensitivity']
        return summary.reset_index()

def write_trace_to_json(trace: PosixTrace, output_path: str):
    """Dump ``trace`` to ``output_path`` as a JSON list, one object per prompt set.

    NOTE: this redefinition is identical in behaviour to the earlier
    ``write_trace_to_json`` in this file and is the binding in effect.
    Each object carries the set's original prompt (key "prompts"), its
    responses, log-probability matrix and prompt sensitivity; the aggregate
    ``trace.posix`` is not included.
    """
    serialisable = []
    for position, original_prompt in enumerate(trace.prompts):
        serialisable.append(
            dict(
                prompts=original_prompt,
                responses=trace.responses[position],
                log_probability_matrix=trace.logprob_matrices[position],
                prompt_sensitivity=trace.prompt_sensitivities[position],
            )
        )
    with open(output_path, "w") as out_file:
        json.dump(serialisable, out_file, indent=4)

def compute_prompt_sensitivity(
    logprob_matrix: List[List[float]],
    response_lengths: List[int],
) -> Tuple[float, Dict[str, float]]:
    """Return the POSIX prompt sensitivity of one prompt set plus per-variation terms.

    ``logprob_matrix[i][j]`` is log P(response_j | prompt_i), with index 0
    the original prompt; ``response_lengths[j]`` is the token length of
    response_j. Each ordered pair ``i != j`` contributes

        |log P(y_j | x_i) - log P(y_j | x_j)| / len(y_j)

    i.e. how much the likelihood of response j changes when its own prompt
    is replaced by prompt i, normalised per token. The result is the mean
    over the ``N * (N - 1)`` pairs (0.0 when there are fewer than two
    prompts). The second return value maps ``"Variation j"`` to the term
    for ``i == 0``.
    """
    n_prompts = len(logprob_matrix)
    if n_prompts < 2:
        return 0.0, {}

    total = 0.0
    scores: Dict[str, float] = {}
    for i in range(n_prompts):
        for j in range(n_prompts):
            if i == j:
                continue
            # Guard empty responses against division by zero.
            length = max(response_lengths[j], 1)
            diff = abs(logprob_matrix[i][j] - logprob_matrix[j][j]) / length
            total += diff
            if i == 0:
                scores[f"Variation {j}"] = diff
    return total / (n_prompts * (n_prompts - 1)), scores


def get_llava_posix(
    model: "LlavaModel",
    prompt_sets: List[Dict[str, List[str]]],
    config: "PosixConfig",
    verbose: bool = False
) -> Tuple[float, "PosixTrace", List[Dict]]:
    """Compute the POSIX score of ``model`` over ``prompt_sets``.

    For each set, one response is sampled for the original prompt and every
    variation, every response is scored under every prompt, and the set's
    sensitivity is obtained from ``compute_prompt_sensitivity``.

    Args:
        model: Object exposing ``get_responses`` and ``compute_log_probabilties``.
        prompt_sets: Dicts with ``'original'``, ``'variations'`` and ``'image'`` keys.
        config: Its ``max_new_tokens`` bounds each generated response.
        verbose: Print per-set progress details.

    Returns:
        ``(posix, trace, detailed_scores)``. ``posix`` is the mean
        sensitivity over all sets, or 0.0 when ``prompt_sets`` is empty.
        ``trace`` is a ``PosixTrace`` with the intermediate results.
        ``detailed_scores`` holds one ``{"Variation j": score}`` dict per
        set, for the original prompt.
    """
    responses = []
    logprob_matrices = []
    prompt_sensitivities = []
    detailed_scores = []

    # Iterating the tqdm wrapper closes the bar when the loop ends.
    for set_idx, prompt_set in enumerate(tqdm(prompt_sets)):
        all_prompts = [prompt_set['original']] + list(prompt_set['variations'])
        image_path = prompt_set['image']

        if verbose:
            print(f"\nProcessing prompt set {set_idx+1}")
            print(f"Original prompt: {all_prompts[0]}")
            if len(all_prompts) > 1:
                print(f"First variation: {all_prompts[1]}")

        set_tokens, set_responses, instruction_lengths = model.get_responses(
            image_path,
            all_prompts,
            max_new_tokens=config.max_new_tokens
        )

        if verbose:
            print("\nGenerated responses:")
            for idx, resp in enumerate(set_responses):
                print(f"Response {idx}: {resp[:100]}...")

        responses.append(set_responses)

        # Strip each prompt prefix once: these are the generated tokens only.
        response_only = [
            tokens[start:] for tokens, start in zip(set_tokens, instruction_lengths)
        ]

        # logprob_matrix[i][j] = log P(response j | prompt i)
        logprob_matrix = []
        for i, prompt in enumerate(all_prompts):
            row = []
            for j, resp_tokens in enumerate(response_only):
                logprob = model.compute_log_probabilties(
                    image_path,
                    prompt,
                    resp_tokens,
                    instruction_lengths[i]
                )
                row.append(logprob)
                if verbose and i == 0:
                    print(f"\nLog probability for response {j} under original prompt: {logprob:.4f}")
            logprob_matrix.append(row)
        logprob_matrices.append(logprob_matrix)

        prompt_sensitivity, scores = compute_prompt_sensitivity(
            logprob_matrix, [len(tokens) for tokens in response_only]
        )
        prompt_sensitivities.append(prompt_sensitivity)
        detailed_scores.append(scores)

        if verbose:
            print(f"\nPrompt sensitivity: {prompt_sensitivity:.4f}")

    posix = (
        sum(prompt_sensitivities) / len(prompt_sensitivities)
        if prompt_sensitivities else 0.0
    )

    trace = PosixTrace(
        [ps['original'] for ps in prompt_sets],
        responses,
        logprob_matrices,
        prompt_sensitivities,
        posix
    )

    return posix, trace, detailed_scores

def setup_logging():
    """Configure root logging at INFO level to a timestamped file and to the console.

    The log file is ``posix_results/logs/processing_<YYYYmmdd_HHMMSS>.log``;
    its directory is created if needed.
    """
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    logs_path = Path('posix_results/logs')
    logs_path.mkdir(parents=True, exist_ok=True)

    file_handler = logging.FileHandler(logs_path / f'processing_{stamp}.log')
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[file_handler, console_handler],
    )


def process_single_file(
    model: "LlavaModel",
    file_path: Path,
    output_dir: Path,
    config: "PosixConfig"
) -> Dict:
    """Score one prompt-variant JSON file and save its detailed results.

    Reads ``file_path``, computes POSIX for its prompt set(s), and writes
    ``<stem>_results.json`` into ``output_dir`` (created if missing). Any
    exception is logged with its traceback and reported as a ``"failed"``
    record instead of propagating, so one bad file does not stop a batch.

    Args:
        model: Model wrapper passed through to ``get_llava_posix``.
        file_path: Input JSON file (a ``str`` path is accepted too).
        output_dir: Directory for the per-file results (``str`` accepted).
        config: Generation settings passed through to ``get_llava_posix``.

    Returns:
        ``{"file_name", "posix", "status"}``, plus ``"error"`` on failure.
    """
    file_path = Path(file_path)
    output_dir = Path(output_dir)
    try:
        with open(file_path, 'r') as f:
            json_data = json.load(f)

        prompt_sets = PromptSensitivityAnalyzer.extract_prompt_sets(json_data)

        # get_llava_posix returns three values; its per-variation scores are
        # the ones to report (not the set-level score copied per variation).
        posix, trace, detailed_scores = get_llava_posix(
            model,
            prompt_sets,
            config,
            verbose=False
        )

        results = {
            "file_name": file_path.name,
            "overall_posix": posix,
            "detailed_scores": detailed_scores,
            "trace": trace.__dict__
        }

        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / f"{file_path.stem}_results.json"
        with open(output_file, 'w') as f:
            json.dump(results, f, indent=4)

        return {
            "file_name": file_path.name,
            "posix": posix,
            "status": "success"
        }

    except Exception as e:
        # Top-level per-file boundary: record the failure and keep going.
        logging.exception("Error processing %s: %s", file_path, e)
        return {
            "file_name": file_path.name,
            "posix": None,
            "status": "failed",
            "error": str(e)
        }


def main(
    model_path: str = "/shared/shashmi/llava-v1.6-mistral-7b-hf",
    input_dir: str = "/ephemeral/shashmi/posix_new_improved/Thesis/paraphrase_error_iuxray_variant",
    output_dir: str = "/ephemeral/shashmi/posix_new_improved/llava_1.6/new_paraphrase_result_posix",
    device: str = "cuda:5",
    max_new_tokens: int = 50,
):
    """Compute POSIX for every ``question_*_variants.json`` file in ``input_dir``.

    Writes one ``<stem>_results.json`` per input file and an
    ``all_results_summary.csv`` with the status and score of every processed
    file into ``output_dir``. Per-file errors are recorded and processing
    continues. The summary is written in a ``finally`` block, so an
    interrupted run (e.g. KeyboardInterrupt during a many-hour job) still
    leaves a summary of the files handled so far.

    Args:
        model_path: Hugging Face checkpoint directory for LLaVA-NeXT.
        input_dir: Directory containing the variant JSON files.
        output_dir: Destination directory (created if missing).
        device: Torch device the model is moved to.
        max_new_tokens: Generation budget per prompt.
    """
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    model = LlavaModel(model_path, device=device)
    config = PosixConfig(max_new_tokens=max_new_tokens)

    json_files = sorted(Path(input_dir).glob("question_*_variants.json"))
    print(f"Found {len(json_files)} files to process")

    all_results = []
    try:
        for file_path in tqdm(json_files, desc="Processing files"):
            try:
                with open(file_path, 'r') as f:
                    json_data = json.load(f)

                prompt_sets = PromptSensitivityAnalyzer.extract_prompt_sets(json_data)

                posix, trace, detailed_scores = get_llava_posix(
                    model,
                    prompt_sets,
                    config,
                    verbose=False
                )

                output_file = out_path / f"{file_path.stem}_results.json"
                with open(output_file, 'w') as f:
                    json.dump({
                        "file_name": file_path.name,
                        "overall_posix": posix,
                        "detailed_scores": detailed_scores,
                        "trace": trace.__dict__
                    }, f, indent=4)

                all_results.append({
                    "file_name": file_path.name,
                    "posix_score": posix,
                    "status": "success"
                })

            except Exception as e:
                # Per-file boundary: record the failure and move on.
                print(f"Error processing {file_path.name}: {str(e)}")
                all_results.append({
                    "file_name": file_path.name,
                    "posix_score": None,
                    "status": "failed",
                    "error": str(e)
                })
            finally:
                # Release cached GPU memory between files.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
    finally:
        pd.DataFrame(all_results).to_csv(out_path / "all_results_summary.csv", index=False)

if __name__ == "__main__":
    # Start the batch POSIX evaluation when this code is executed directly
    # rather than imported.
    main()
   

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.33s/it]


Found 400 files to process


Processing files:   0%|          | 0/400 [00:00<?, ?it/s]Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. Using processors without these attributes in the config is deprecated and will throw an error in v4.47.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. Using processors without these attributes in the config is deprecated and will throw an error in v4.47.
Setting `pad_token_id` to `eos

Error processing question_179_variants.json: [Errno 2] No such file or directory: '/ephemeral/shashmi/posix_new_improved/matching_subset_images/CXR3991_IM-2044-1001.png'


Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
100%|██████████| 1/1 [06:42<00:00, 402.97s/it]
Processing files:  22%|██▏       | 88/400 [10:21:09<28:33:36, 329.54s/it]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `e