In [1]:
import torch
from PIL import Image
from dataclasses import dataclass
import pandas as pd
import Levenshtein
import json
from tqdm.auto import tqdm
from typing import Tuple, List, Dict
import os
from pathlib import Path
import logging
from datetime import datetime
from open_flamingo import create_model_and_transforms
from einops import repeat
import sys
sys.path.append('..')
from src.utils import FlamingoProcessor
from demo_utils import image_paths, clean_generation


@dataclass
class PosixConfig:
    max_new_tokens: int = 50
    temperature: float = 0.1
    llama_path: str = '/shared/shashmi/llama-7b'
    checkpoint_path: str = '/shared/shashmi/med-flamingo/model.pt'

@dataclass
class PosixTrace:
    prompts: list
    responses: list
    sensitivities: list
    posix: float

class MedFlamingoModel:
    def __init__(self, config: PosixConfig, device: str = "cuda:7"):
        """Initialize Med-Flamingo model."""
        self.config = config
        self.device = device
        
        print('Loading model...')
        self.model, image_processor, tokenizer = create_model_and_transforms(
            clip_vision_encoder_path="ViT-L-14",
            clip_vision_encoder_pretrained="openai",
            lang_encoder_path=config.llama_path,
            tokenizer_path=config.llama_path,
            cross_attn_every_n_layers=4
        )
        
        self.model.load_state_dict(
            torch.load(config.checkpoint_path, map_location=device),
            strict=False
        )
        self.processor = FlamingoProcessor(tokenizer, image_processor)
        self.model.eval()
        self.model.to(device)
    
    def get_responses(self, image_path: str, prompts: list[str], **kwargs) -> list[str]:
        """Generate responses for given prompts."""
        responses = []
        
        # Load and process image
        image = Image.open(image_path)
        pixels = self.processor.preprocess_images([image])
        pixels = repeat(pixels, 'N c h w -> b N T c h w', b=1, T=1)
        
        for prompt in prompts:
            # Process text
            tokenized_data = self.processor.encode_text(prompt)
            
            # Generate response
            with torch.no_grad():
                generated_text = self.model.generate(
                    vision_x=pixels.to(self.device),
                    lang_x=tokenized_data["input_ids"].to(self.device),
                    attention_mask=tokenized_data["attention_mask"].to(self.device),
                    max_new_tokens=kwargs.get('max_new_tokens', self.config.max_new_tokens),
                    temperature=self.config.temperature
                )
            
            response = self.processor.tokenizer.decode(generated_text[0])
            responses.append(response)
            
        return responses

class PromptSensitivityAnalyzer:
    @staticmethod
    def calculate_char_level_similarity(str1: str, str2: str) -> float:
        """Calculate character-level similarity using Levenshtein distance."""
        distance = Levenshtein.distance(str1, str2)
        max_len = max(len(str1), len(str2))
        return 1 - (distance / max_len)
    
    @staticmethod
    def extract_prompt_sets(json_data: Dict) -> List[Dict[str, List[str]]]:
        """Extract prompt sets from the JSON data."""
        prompt_sets = []
        
        variations = [json_data[f'variation_{i}'] for i in range(1, 11)]
        current_set = {
            'original': json_data['question'],
            'variations': variations,
            'answer': json_data['answer'],
            'image': json_data['image']
        }
        prompt_sets.append(current_set)
        
        return prompt_sets

def get_medflamingo_posix(
    model: MedFlamingoModel,
    prompt_sets: List[Dict[str, List[str]]],
    config: PosixConfig,
    verbose: bool = False
) -> Tuple[float, PosixTrace, List[Dict]]:
    """Calculate POSIX scores using response similarity instead of log probabilities."""
    N_prompt_sets = len(prompt_sets)
    
    responses = []
    sensitivities = []
    detailed_scores = []
    
    pbar = tqdm(range(N_prompt_sets))
    for i in range(N_prompt_sets):
        prompt_set = prompt_sets[i]
        all_prompts = [prompt_set['original']] + prompt_set['variations']
        image_path = prompt_set['image']
        
        if verbose:
            print(f"\nProcessing prompt set {i+1}")
            print(f"Original prompt: {all_prompts[0]}")
            print(f"First variation: {all_prompts[1]}")
        
        # Generate responses for all prompts
        set_responses = model.get_responses(
            image_path,
            all_prompts,
            max_new_tokens=config.max_new_tokens
        )
        responses.append(set_responses)
        
        # Calculate sensitivities based on response similarities
        scores = {}
        set_sensitivities = []
        
        original_response = set_responses[0]
        for var_idx, var_response in enumerate(set_responses[1:], 1):
            similarity = PromptSensitivityAnalyzer.calculate_char_level_similarity(
                original_response,
                var_response
            )
            sensitivity = 1 - similarity
            set_sensitivities.append(sensitivity)
            scores[f"Variation {var_idx}"] = sensitivity
            
        detailed_scores.append(scores)
        
        # Calculate average sensitivity for this set
        set_sensitivity = sum(set_sensitivities) / len(set_sensitivities) if set_sensitivities else 0.0
        sensitivities.append(set_sensitivity)
        
        if verbose:
            print(f"\nPrompt sensitivity: {set_sensitivity:.4f}")
        
        pbar.update(1)
    
    # Calculate overall POSIX score
    posix = sum(sensitivities) / N_prompt_sets if sensitivities else 0.0
    
    trace = PosixTrace(
        [set['original'] for set in prompt_sets],
        responses,
        sensitivities,
        posix
    )
    
    return posix, trace, detailed_scores

def setup_logging(output_dir: str):
    """Setup logging configuration"""
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_dir = Path(output_dir) / 'logs'
    log_dir.mkdir(parents=True, exist_ok=True)
    
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_dir / f'processing_{timestamp}.log'),
            logging.StreamHandler()
        ]
    )

def main():
    # File paths and configuration
    model_path = "/shared/shashmi/med-flamingo/model.pt"
    input_dir = "/ephemeral/shashmi/posix_new_improved/Thesis/spell_error_question_variants"
    output_dir = "/ephemeral/shashmi/posix_new_improved/medflamingo/med-flamingo/scripts/spell_error_posix"
    
    # Setup logging
    setup_logging(output_dir)
    logging.info(f"Starting analysis with model: {model_path}")
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Initialize configuration and model
    config = PosixConfig(checkpoint_path=model_path)
    model = MedFlamingoModel(config)
    
    # Get list of JSON files
    json_files = sorted(Path(input_dir).glob("question_*_variants.json"))
    logging.info(f"Found {len(json_files)} files to process")
    
    # Process each file
    all_results = []
    for file_path in tqdm(json_files, desc="Processing files"):
        try:
            # Read JSON file
            with open(file_path, 'r') as f:
                json_data = json.load(f)
            
            # Process file
            analyzer = PromptSensitivityAnalyzer()
            prompt_sets = analyzer.extract_prompt_sets(json_data)
            
            posix, trace, detailed_scores = get_medflamingo_posix(
                model,
                prompt_sets,
                config,
                verbose=False
            )
            
            # Save individual result
            output_file = Path(output_dir) / f"{file_path.stem}_results.json"
            with open(output_file, 'w') as f:
                json.dump({
                    "file_name": file_path.name,
                    "overall_posix": posix,
                    "detailed_scores": detailed_scores,
                    "trace": trace.__dict__
                }, f, indent=4)
            
            # Store summary
            all_results.append({
                "file_name": file_path.name,
                "posix_score": posix
            })
            
            logging.info(f"Successfully processed {file_path.name} with POSIX score: {posix:.4f}")
            
        except Exception as e:
            logging.error(f"Error processing {file_path.name}: {str(e)}")
            all_results.append({
                "file_name": file_path.name,
                "posix_score": None,
                "error": str(e)
            })
        
        # Clear CUDA cache
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    # Save summary results
    summary_df = pd.DataFrame(all_results)
    summary_df.to_csv(os.path.join(output_dir, "all_results_summary.csv"), index=False)
    logging.info("Analysis complete. Results saved to CSV.")

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm
2024-11-28 08:11:00,419 - INFO - Starting analysis with model: /shared/shashmi/med-flamingo/model.pt
2024-11-28 08:11:00,421 - INFO - Loaded ViT-L-14 model config.


Loading model...


2024-11-28 08:11:03,791 - INFO - Loading pretrained ViT-L-14 weights (openai).
Using pad_token, but it is not set yet.
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.26s/it]


Flamingo model initialized with 1309919248 trainable parameters


2024-11-28 08:12:22,474 - INFO - Found 400 files to process
100%|██████████| 1/1 [00:30<00:00, 30.78s/it]00<?, ?it/s]
2024-11-28 08:12:53,262 - INFO - Successfully processed question_100_variants.json with POSIX score: 0.2841
100%|██████████| 1/1 [00:30<00:00, 30.86s/it]31<3:32:10, 31.91s/it]
2024-11-28 08:13:25,246 - INFO - Successfully processed question_101_variants.json with POSIX score: 0.1939
100%|██████████| 1/1 [00:30<00:00, 30.95s/it]02<3:27:40, 31.31s/it]
2024-11-28 08:13:56,227 - INFO - Successfully processed question_102_variants.json with POSIX score: 0.1630
100%|██████████| 1/1 [00:30<00:00, 30.04s/it]33<3:26:09, 31.16s/it]
2024-11-28 08:14:26,296 - INFO - Successfully processed question_103_variants.json with POSIX score: 0.0387
100%|██████████| 1/1 [00:30<00:00, 30.15s/it]03<3:22:47, 30.73s/it]
2024-11-28 08:14:56,471 - INFO - Successfully processed question_104_variants.json with POSIX score: 0.1055
100%|██████████| 1/1 [00:30<00:00, 30.14s/it]34<3:20:58, 30.53s/it]
20