In [16]:
import os
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from scipy import stats
from typing import Dict, List, Tuple, Optional
import json
from tqdm import tqdm
import logging
import gc
import GPUtil
import ast
from dataclasses import dataclass

In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [3]:
@dataclass
class AttentionConfig:
    """Configuration for MedGemma attention analysis"""
    image_start: int = 1
    image_end: int = 257
    num_patches: int = 256
    patch_grid_size: int = 16
    num_heads: int = 8
    num_layers: int = 34

In [4]:
class MedGemmaAttentionExtractor:

    def __init__(self, model, processor, device='cuda'):
        self.model = model
        self.processor = processor
        self.device = device
        self.config = AttentionConfig()
        self.model.eval()

        # Storage for debugging
        self.last_extraction_debug = {}

    def prepare_inputs(self, image_path: str, question: str, options: List[str]) -> Dict:
        """Prepare inputs using your working approach"""

        # Load and preprocess image
        image = Image.open(image_path).convert('RGB')

        # Create clinical prompt
        prompt = self._create_clinical_prompt(question, options)

        # Create messages format that works with MedGemma
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image", "image": image}
                ]
            }
        ]

        # Process inputs using your working method
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        )

        # Move to GPU
        inputs_gpu = {
            k: v.to(self.device) if torch.is_tensor(v) else v
            for k, v in inputs.items()
        }

        # Store debug info
        self.last_extraction_debug['input_length'] = inputs_gpu['input_ids'].shape[1]

        return inputs_gpu, image

    def _create_clinical_prompt(self, question: str, options: List[str]) -> str:
        """Create prompt optimized for medical analysis"""
        valid_options = [opt for opt in options if opt]

        prompt = f"""Analyze this chest X-ray and answer: {question}
Options: {', '.join(valid_options)}

Focus on the relevant anatomical regions for this question."""

        return prompt

    def extract_attention_safe(self, outputs, inputs, token_idx: int = 0) -> np.ndarray:
        """
        Safe extraction based on your working code
        Handles MedGemma's fixed attention matrix structure
        """
        try:
            if token_idx >= len(outputs.attentions):
                token_idx = len(outputs.attentions) - 1

            # Get attention for this token (last layer)
            attn = outputs.attentions[token_idx][-1].cpu().float()

            if len(attn.shape) == 4:
                attn = attn[0]  # Remove batch dimension

            # Average over heads
            avg_attn = attn.mean(dim=0)

            # MedGemma uses fixed size, find the right position
            input_len = inputs['input_ids'].shape[1]

            # The attention matrix doesn't grow, so we use the last valid position
            if avg_attn.shape[0] == avg_attn.shape[1]:
                gen_pos = min(input_len + token_idx, avg_attn.shape[0] - 1)
            else:
                gen_pos = avg_attn.shape[0] - 1

            # Extract attention to image tokens
            if gen_pos >= 0 and self.config.image_end <= avg_attn.shape[1]:
                attn_to_image = avg_attn[gen_pos, self.config.image_start:self.config.image_end]

                # Verify we got 256 values (16x16 patches)
                if len(attn_to_image) == self.config.num_patches:
                    return attn_to_image.reshape(self.config.patch_grid_size,
                                                self.config.patch_grid_size).numpy()

            logger.warning(f"Could not extract valid attention for token {token_idx}")
            return np.zeros((self.config.patch_grid_size, self.config.patch_grid_size))

        except Exception as e:
            logger.error(f"Error extracting attention: {e}")
            return np.zeros((self.config.patch_grid_size, self.config.patch_grid_size))

    def compute_layer_weighted_relevancy(self, outputs, inputs, token_idx: int = 0) -> np.ndarray:
        """
        Compute relevancy across all layers with increasing weights for deeper layers
        Based on your relevancy analyzer implementation
        """
        relevancy_scores = []

        if token_idx < len(outputs.attentions):
            token_attentions = outputs.attentions[token_idx]

            for layer_idx, layer_attn in enumerate(token_attentions):
                if torch.is_tensor(layer_attn):
                    layer_attn = layer_attn.cpu().float()

                    if len(layer_attn.shape) == 4:
                        layer_attn = layer_attn[0]

                    # Average over heads
                    layer_attn = layer_attn.mean(dim=0)

                    # Find position
                    input_len = inputs['input_ids'].shape[1]
                    src_pos = min(input_len + token_idx, layer_attn.shape[0] - 1)

                    # Extract attention to image
                    if src_pos >= 0 and self.config.image_end <= layer_attn.shape[1]:
                        attn_to_image = layer_attn[src_pos, self.config.image_start:self.config.image_end]

                        # Weight by layer depth
                        layer_weight = (layer_idx + 1) / len(token_attentions)
                        weighted_attn = attn_to_image * layer_weight

                        relevancy_scores.append(weighted_attn)

        if relevancy_scores:
            final_relevancy = torch.stack(relevancy_scores).mean(dim=0)
            return final_relevancy.reshape(self.config.patch_grid_size,
                                         self.config.patch_grid_size).numpy()

        return np.zeros((self.config.patch_grid_size, self.config.patch_grid_size))

    def compute_head_importance(self, outputs, inputs, token_idx: int = 0) -> Dict:
        """Identify which attention heads are most important for this token"""

        if token_idx >= len(outputs.attentions):
            return {}

        # Get last layer attention
        last_layer_attn = outputs.attentions[token_idx][-1].cpu().float()

        if len(last_layer_attn.shape) == 4:
            last_layer_attn = last_layer_attn[0]

        head_importance = {}
        input_len = inputs['input_ids'].shape[1]
        src_pos = min(input_len + token_idx, last_layer_attn.shape[0] - 1)

        for h in range(self.config.num_heads):
            head_attn = last_layer_attn[h]

            if src_pos >= 0 and self.config.image_end <= head_attn.shape[1]:
                attn_to_image = head_attn[src_pos, self.config.image_start:self.config.image_end]

                # Calculate importance metrics
                max_attn = attn_to_image.max().item()
                entropy = -(attn_to_image * torch.log(attn_to_image + 1e-10)).sum().item()

                head_importance[h] = {
                    'max_attention': max_attn,
                    'entropy': entropy,
                    'focus_score': max_attn * (1 / (1 + entropy))
                }

        return head_importance

In [5]:
class ClinicalAttentionAnalyzer:
    """
    Analyze attention patterns for clinical interpretation
    Maps attention to anatomical regions relevant for specific pathologies
    """

    def __init__(self, extractor: MedGemmaAttentionExtractor):
        self.extractor = extractor

        # Define anatomical regions for 16x16 grid
        self.anatomical_regions = {
            'upper_left': (0, 5, 0, 5),
            'upper_center': (0, 5, 5, 11),
            'upper_right': (0, 5, 11, 16),
            'middle_left': (5, 11, 0, 5),
            'middle_center': (5, 11, 5, 11),
            'middle_right': (5, 11, 11, 16),
            'lower_left': (11, 16, 0, 5),
            'lower_center': (11, 16, 5, 11),
            'lower_right': (11, 16, 11, 16)
        }

        # Map pathologies to expected regions
        self.pathology_expectations = {
            'pleural effusion': ['lower_left', 'lower_right'],
            'pneumothorax': ['upper_left', 'upper_right'],
            'consolidation': ['middle_left', 'middle_right', 'lower_left', 'lower_right'],
            'cardiomegaly': ['middle_center', 'lower_center'],
            'fracture': ['upper_left', 'upper_right', 'middle_left', 'middle_right']
        }

    def analyze_regional_attention(self, attention_map: np.ndarray) -> Dict:
        """Calculate attention statistics for each anatomical region"""

        regional_stats = {}

        for region_name, (r1, r2, c1, c2) in self.anatomical_regions.items():
            region_attention = attention_map[r1:r2, c1:c2]

            regional_stats[region_name] = {
                'mean': float(region_attention.mean()),
                'max': float(region_attention.max()),
                'sum': float(region_attention.sum()),
                'std': float(region_attention.std())
            }

        # Sort by mean attention
        sorted_regions = sorted(regional_stats.items(),
                              key=lambda x: x[1]['mean'],
                              reverse=True)

        regional_stats['top_regions'] = [r[0] for r in sorted_regions[:3]]

        return regional_stats

    def evaluate_pathology_alignment(self, attention_map: np.ndarray,
                                    pathology: str) -> Dict:
        """Evaluate if attention aligns with expected regions for pathology"""

        regional_stats = self.analyze_regional_attention(attention_map)

        if pathology.lower() in self.pathology_expectations:
            expected_regions = self.pathology_expectations[pathology.lower()]

            # Calculate attention in expected vs unexpected regions
            expected_attention = sum(
                regional_stats[r]['mean']
                for r in expected_regions
                if r in regional_stats
            )

            all_attention = sum(
                stats['mean']
                for r, stats in regional_stats.items()
                if r != 'top_regions'
            )

            alignment_score = expected_attention / (all_attention + 1e-8)

            # Check if top regions overlap with expected
            top_regions = regional_stats.get('top_regions', [])
            overlap = len(set(top_regions) & set(expected_regions))

            return {
                'alignment_score': float(alignment_score),
                'expected_regions': expected_regions,
                'actual_top_regions': top_regions,
                'overlap_count': overlap,
                'is_aligned': overlap >= 1
            }

        return {
            'alignment_score': 0.0,
            'expected_regions': [],
            'actual_top_regions': regional_stats.get('top_regions', []),
            'overlap_count': 0,
            'is_aligned': False
        }

In [7]:
class ViTPrismaIntegration:
    """
    Integrate ViT-Prisma concepts with MedGemma
    Attention rollout and diversity analysis
    """

    def __init__(self, extractor: MedGemmaAttentionExtractor):
        self.extractor = extractor

    def compute_attention_rollout(self, outputs, inputs, token_idx: int = 0) -> np.ndarray:
        """
        Compute attention rollout across layers
        Shows how information flows through the network
        """

        if token_idx >= len(outputs.attentions):
            return np.zeros((16, 16))

        token_attentions = outputs.attentions[token_idx]
        input_len = inputs['input_ids'].shape[1]

        # Initialize with identity
        rollout = None

        for layer_idx, layer_attn in enumerate(token_attentions):
            if torch.is_tensor(layer_attn):
                layer_attn = layer_attn.cpu().float()

                if len(layer_attn.shape) == 4:
                    layer_attn = layer_attn[0]

                # Average over heads
                layer_attn = layer_attn.mean(dim=0)

                # Add residual connection
                eye = torch.eye(layer_attn.shape[0])
                layer_attn = 0.5 * layer_attn + 0.5 * eye

                # Normalize
                layer_attn = layer_attn / layer_attn.sum(dim=-1, keepdim=True)

                if rollout is None:
                    rollout = layer_attn
                else:
                    rollout = torch.matmul(rollout, layer_attn)

        if rollout is not None:
            # Extract final attention to image
            src_pos = min(input_len + token_idx, rollout.shape[0] - 1)

            if src_pos >= 0 and self.extractor.config.image_end <= rollout.shape[1]:
                attn_to_image = rollout[src_pos,
                                       self.extractor.config.image_start:self.extractor.config.image_end]
                return attn_to_image.reshape(16, 16).numpy()

        return np.zeros((16, 16))

    def compute_attention_diversity(self, outputs, inputs, num_tokens: int = 5) -> float:
        """
        Measure diversity of attention patterns across tokens
        Higher diversity suggests model examines different aspects
        """

        attention_maps = []

        for token_idx in range(min(num_tokens, len(outputs.attentions))):
            attn_map = self.extractor.extract_attention_safe(outputs, inputs, token_idx)
            if attn_map.max() > 0:  # Valid attention map
                attention_maps.append(attn_map.flatten())

        if len(attention_maps) < 2:
            return 0.0

        # Compute pairwise distances
        distances = []
        for i in range(len(attention_maps)):
            for j in range(i + 1, len(attention_maps)):
                dist = np.linalg.norm(attention_maps[i] - attention_maps[j])
                distances.append(dist)

        return float(np.mean(distances))

In [8]:
class MIMICCXRProcessor:
    """
    Process MIMIC-CXR data with comprehensive attention analysis
    Builds on your existing code structure
    """

    def __init__(self, model, processor, base_image_path: str, device='cuda'):
        self.model = model
        self.processor = processor
        self.base_image_path = Path(base_image_path)
        self.device = device

        # Initialize components
        self.extractor = MedGemmaAttentionExtractor(model, processor, device)
        self.clinical_analyzer = ClinicalAttentionAnalyzer(self.extractor)
        self.prisma = ViTPrismaIntegration(self.extractor)

        # Storage for results
        self.results_cache = {}

    def process_single_sample(self, row: pd.Series) -> Dict:
        """Process a single MIMIC-CXR sample"""

        try:
            # Prepare inputs
            image_path = self.base_image_path / row['ImagePath']
            if not image_path.exists():
                logger.warning(f"Image not found: {image_path}")
                return None

            inputs_gpu, image = self.extractor.prepare_inputs(
                str(image_path),
                row['question'],
                ast.literal_eval(row['options']) if isinstance(row['options'], str) else row['options']
            )

            # Generate with attention using your working config
            gen_kwargs = {
                "max_new_tokens": 50,
                "do_sample": False,  # Greedy as per your findings
                "output_attentions": True,
                "return_dict_in_generate": True,
                "pad_token_id": self.processor.tokenizer.pad_token_id,
                "eos_token_id": self.processor.tokenizer.eos_token_id,
            }

            with torch.no_grad():
                outputs = self.model.generate(**inputs_gpu, **gen_kwargs)

            # Extract generated text
            generated_ids = outputs.sequences[0][len(inputs_gpu['input_ids'][0]):]
            generated_text = self.processor.decode(generated_ids, skip_special_tokens=True)

            # Attention analysis
            attention_results = self.analyze_attention(outputs, inputs_gpu, row)

            # Compile results
            result = {
                'study_id': row['study_id'],
                'question': row['question'],
                'correct_answer': row['correct_answer'],
                'generated_answer': self._extract_answer(generated_text, row['options']),
                'generated_text': generated_text,
                'attention_analysis': attention_results,
                'image_path': str(image_path)
            }

            # Clean up GPU memory
            del outputs
            torch.cuda.empty_cache()

            return result

        except Exception as e:
            logger.error(f"Error processing sample {row['study_id']}: {e}")
            return None

    def analyze_attention(self, outputs, inputs, row) -> Dict:
        """Comprehensive attention analysis for a sample"""

        # Extract pathology from question
        pathology = self._extract_pathology(row['question'])

        # 1. Raw attention extraction
        raw_attention = self.extractor.extract_attention_safe(outputs, inputs, 0)

        # 2. Layer-weighted relevancy
        relevancy = self.extractor.compute_layer_weighted_relevancy(outputs, inputs, 0)

        # 3. Head importance
        head_importance = self.extractor.compute_head_importance(outputs, inputs, 0)

        # 4. Attention rollout
        rollout = self.prisma.compute_attention_rollout(outputs, inputs, 0)

        # 5. Regional analysis
        regional_stats = self.clinical_analyzer.analyze_regional_attention(relevancy)

        # 6. Pathology alignment
        alignment = self.clinical_analyzer.evaluate_pathology_alignment(relevancy, pathology)

        # 7. Attention diversity
        diversity = self.prisma.compute_attention_diversity(outputs, inputs)

        return {
            'raw_attention': raw_attention.tolist(),
            'relevancy_map': relevancy.tolist(),
            'attention_rollout': rollout.tolist(),
            'head_importance': head_importance,
            'regional_stats': regional_stats,
            'pathology_alignment': alignment,
            'attention_diversity': diversity,
            'pathology_type': pathology
        }

    def _extract_pathology(self, question: str) -> str:
        """Extract pathology type from question"""
        question_lower = question.lower()

        pathologies = ['pleural effusion', 'pneumothorax', 'consolidation',
                      'opacity', 'fracture', 'hernia', 'kyphosis', 'cardiomegaly']

        for pathology in pathologies:
            if pathology in question_lower:
                return pathology

        if 'effusion' in question_lower:
            return 'pleural effusion'

        return 'other'

    def _extract_answer(self, generated_text: str, options: List[str]) -> str:
        """Extract answer from generated text"""
        text_lower = generated_text.lower()

        # Simple heuristic - look for yes/no
        if 'yes' in text_lower[:20]:
            return 'yes'
        elif 'no' in text_lower[:20]:
            return 'no'

        return 'uncertain'

    def process_dataset(self, df: pd.DataFrame, sample_size: Optional[int] = None) -> pd.DataFrame:
        """Process entire dataset"""

        if sample_size:
            df = df.sample(min(sample_size, len(df)))

        results = []

        for idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing samples"):
            result = self.process_single_sample(row)
            if result:
                results.append(result)

            # Periodic cleanup
            if idx % 10 == 0:
                gc.collect()
                torch.cuda.empty_cache()

        return pd.DataFrame(results)

In [9]:
def create_clinical_visualization(image_path: str, attention_map: np.ndarray,
                                regional_stats: Dict, alignment: Dict,
                                question: str, generated_answer: str) -> plt.Figure:
    """Create comprehensive clinical visualization"""

    fig = plt.figure(figsize=(16, 10))
    gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

    # Load image
    image = Image.open(image_path).convert('RGB')

    # 1. Original X-ray
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.imshow(image, cmap='gray')
    ax1.set_title('Input X-ray')
    ax1.axis('off')

    # 2. Attention heatmap
    ax2 = fig.add_subplot(gs[0, 1])
    im = ax2.imshow(attention_map, cmap='hot', interpolation='bicubic')
    ax2.set_title('Attention Map')
    ax2.axis('off')
    plt.colorbar(im, ax=ax2, fraction=0.046)

    # 3. Overlay
    ax3 = fig.add_subplot(gs[0, 2])
    ax3.imshow(image, cmap='gray')

    # Resize attention to image size
    h, w = image.size[::-1]
    attention_resized = cv2.resize(attention_map, (w, h), interpolation=cv2.INTER_CUBIC)
    ax3.imshow(attention_resized, cmap='hot', alpha=0.5)
    ax3.set_title('Attention Overlay')
    ax3.axis('off')

    # 4. Regional analysis bar chart
    ax4 = fig.add_subplot(gs[1, 0])
    regions = list(regional_stats.keys())
    if 'top_regions' in regions:
        regions.remove('top_regions')
    values = [regional_stats[r]['mean'] for r in regions]

    bars = ax4.bar(range(len(regions)), values, color='skyblue')
    ax4.set_xticks(range(len(regions)))
    ax4.set_xticklabels([r.replace('_', '\n') for r in regions], rotation=0, fontsize=8)
    ax4.set_ylabel('Mean Attention')
    ax4.set_title('Regional Attention Distribution')
    ax4.grid(True, alpha=0.3)

    # Highlight top regions
    top_regions = regional_stats.get('top_regions', [])
    for i, region in enumerate(regions):
        if region in top_regions:
            bars[i].set_color('coral')

    # 5. Pathology alignment
    ax5 = fig.add_subplot(gs[1, 1])
    ax5.axis('off')

    alignment_text = f"Question: {question[:50]}...\n\n"
    alignment_text += f"Generated Answer: {generated_answer}\n\n"
    alignment_text += "Pathology Alignment:\n"
    alignment_text += f"• Alignment Score: {alignment.get('alignment_score', 0):.2%}\n"
    alignment_text += f"• Expected Regions: {', '.join(alignment.get('expected_regions', []))}\n"
    alignment_text += f"• Actual Top Regions: {', '.join(alignment.get('actual_top_regions', []))}\n"
    alignment_text += f"• Clinically Aligned: {'✓' if alignment.get('is_aligned') else '✗'}"

    ax5.text(0.05, 0.95, alignment_text, transform=ax5.transAxes,
            fontsize=10, verticalalignment='top',
            bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

    # 6. Attention statistics
    ax6 = fig.add_subplot(gs[1, 2])
    ax6.hist(attention_map.flatten(), bins=30, alpha=0.7, color='purple')
    ax6.axvline(attention_map.mean(), color='red', linestyle='--',
               label=f'Mean: {attention_map.mean():.3f}')
    ax6.axvline(np.percentile(attention_map, 90), color='green', linestyle='--',
               label=f'90th %ile: {np.percentile(attention_map, 90):.3f}')
    ax6.set_xlabel('Attention Value')
    ax6.set_ylabel('Frequency')
    ax6.set_title('Attention Distribution')
    ax6.legend()
    ax6.grid(True, alpha=0.3)

    plt.suptitle('MedGemma Clinical Attention Analysis', fontsize=14)

    return fig

# Main execution function
def run_mimic_analysis(csv_path: str, image_base_path: str,
                       output_dir: str, sample_size: Optional[int] = None):
    """Run complete MIMIC-CXR analysis with MedGemma"""

    print("="*60)
    print("MEDGEMMA MIMIC-CXR ANALYSIS")
    print("="*60)

    # Set up output directory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Initialize model (using your working configuration)
    print("\n1. Loading MedGemma model...")
    model_id = 'google/medgemma-4b-it'

    from transformers import AutoProcessor, AutoModelForImageTextToText

    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="cuda:0",
        attn_implementation="eager",
        tie_word_embeddings=False
    )
    model.eval()

    print("✓ Model loaded successfully")

    # Load data
    print("\n2. Loading MIMIC-CXR data...")
    df = pd.read_csv(csv_path)
    print(f"✓ Loaded {len(df)} samples")

    # Initialize processor
    print("\n3. Initializing processor...")
    mimic_processor = MIMICCXRProcessor(model, processor, image_base_path)

    # Process dataset
    print(f"\n4. Processing {sample_size if sample_size else len(df)} samples...")
    results_df = mimic_processor.process_dataset(df, sample_size)

    # Save results
    print("\n5. Saving results...")
    results_df.to_csv(output_path / 'attention_results.csv', index=False)

    # Analyze results
    print("\n6. Analyzing results...")

    # Pathology alignment statistics
    if 'attention_analysis' in results_df.columns:
        alignments = []
        for _, row in results_df.iterrows():
            if row['attention_analysis'] and 'pathology_alignment' in row['attention_analysis']:
                alignments.append(row['attention_analysis']['pathology_alignment']['alignment_score'])

        if alignments:
            print(f"\nPathology Alignment Statistics:")
            print(f"  Mean alignment: {np.mean(alignments):.2%}")
            print(f"  Std alignment: {np.std(alignments):.2%}")
            print(f"  Max alignment: {np.max(alignments):.2%}")

    # Generate visualizations for top samples
    print("\n7. Creating visualizations...")
    viz_dir = output_path / 'visualizations'
    viz_dir.mkdir(exist_ok=True)

    for idx, row in results_df.head(5).iterrows():
        if row['attention_analysis']:
            attention_map = np.array(row['attention_analysis']['relevancy_map'])

            fig = create_clinical_visualization(
                row['image_path'],
                attention_map,
                row['attention_analysis']['regional_stats'],
                row['attention_analysis']['pathology_alignment'],
                row['question'],
                row['generated_answer']
            )

            fig.savefig(viz_dir / f'sample_{idx}.png', dpi=150, bbox_inches='tight')
            plt.close(fig)

    print(f"\n✓ Analysis complete! Results saved to {output_path}")

    return results_df

In [11]:
csv_path = "/home/bsada1/lvlm-interpret-medgemma/one-pixel-attack/mimic_adapted_questions.csv"
image_base_path = "/home/bsada1/mimic_cxr_hundred_vqa"
output_dir = "mimic_medgemma_analysis"
results = run_mimic_analysis(csv_path, image_base_path, output_dir, sample_size=10)

MEDGEMMA MIMIC-CXR ANALYSIS

1. Loading MedGemma model...


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.29it/s]


✓ Model loaded successfully

2. Loading MIMIC-CXR data...
✓ Loaded 100 samples

3. Initializing processor...

4. Processing 10 samples...


Processing samples:   0%|          | 0/10 [00:00<?, ?it/s]ERROR:__main__:Error processing sample a6617202-f5a8661d-78eb1442-037bf3e4-3dd8967f: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of 79.25 GiB of which 626.50 MiB is free. Process 3818168 has 672.00 MiB memory in use. Process 2843597 has 66.58 GiB memory in use. Process 2985595 has 1.28 GiB memory in use. Including non-PyTorch memory, this process has 10.09 GiB memory in use. Of the allocated memory 9.57 GiB is allocated by PyTorch, and 24.01 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Processing samples:  10%|█         | 1/10 [00:00<00:03,  2.29it/s]ERROR:__main__:Error processing sample 051b7911-cb00aec9-0b309188-89803662-303ec278: CUDA out of memory. Trie


5. Saving results...

6. Analyzing results...

7. Creating visualizations...

✓ Analysis complete! Results saved to mimic_medgemma_analysis


In [None]:
i Loading MedGemma model...
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.29it/s]

✓ Model loaded successfully

2. Loading MIMIC-CXR data...
✓ Loaded 100 samples

3. Initializing processor...

4. Processing 10 samples...
Processing samples:   0%|          | 0/10 [00:00<?, ?it/s]ERROR:__main__:Error processing sample a6617202-f5a8661d-78eb1442-037bf3e4-3dd8967f: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of 79.25 GiB of which 626.50 MiB is free. Process 3818168 has 672.00 MiB memory in use. Process 2843597 has 66.58 GiB memory in use. Process 2985595 has 1.28 GiB memory in use. Including non-PyTorch memory, this process has 10.09 GiB memory in use. Of the allocated memory 9.57 GiB is allocated by PyTorch, and 24.01 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Processing samples:  10%|█         | 1/10 [00:00<00:03,  2.29it/s]ERROR:__main__:Error processing sample 051b7911-cb00aec9-0b309188-89803662-303ec278: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of 79.25 GiB of which 114.50 MiB is free. Process 3818168 has 672.00 MiB memory in use. Process 2843597 has 66.58 GiB memory in use. Process 2985595 has 1.28 GiB memory in use. Including non-PyTorch memory, this process has 10.59 GiB memory in use. Of the allocated memory 9.57 GiB is allocated by PyTorch, and 536.35 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Processing samples:  20%|██        | 2/10 [00:00<00:02,  3.72it/s]ERROR:__main__:Error processing sample edb88e4a-c04f1be7-aefcf3e0-8889542d-692ff7fd: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of 79.25 GiB of which 114.50 MiB is free. Process 3818168 has 672.00 MiB memory in use. Process 2843597 has 66.58 GiB memory in use. Process 2985595 has 1.28 GiB memory in use. Including non-PyTorch memory, this process has 10.59 GiB memory in use. Of the allocated memory 9.57 GiB is allocated by PyTorch, and 536.33 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Processing samples:  30%|███       | 3/10 [00:00<00:01,  4.74it/s]ERROR:__main__:Error processing sample 9ca8f84e-92fac212-e60ac49d-01779362-caa16791: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of 79.25 GiB of which 114.50 MiB is free. Process 3818168 has 672.00 MiB memory in use. Process 2843597 has 66.58 GiB memory in use. Process 2985595 has 1.28 GiB memory in use. Including non-PyTorch memory, this process has 10.59 GiB memory in use. Of the allocated memory 9.57 GiB is allocated by PyTorch, and 536.34 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Processing samples:  40%|████      | 4/10 [00:00<00:01,  5.34it/s]ERROR:__main__:Error processing sample d999236f-95dcb8b7-a4d20a3f-be538f50-ce13a08e: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of 79.25 GiB of which 114.50 MiB is free. Process 3818168 has 672.00 MiB memory in use. Process 2843597 has 66.58 GiB memory in use. Process 2985595 has 1.28 GiB memory in use. Including non-PyTorch memory, this process has 10.59 GiB memory in use. Of the allocated memory 9.57 GiB is allocated by PyTorch, and 536.33 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
Processing samples:  50%|█████     | 5/10 [00:01<00:00,  5.67it/s]ERROR:__main__:Error processing sample 2833b85f-3bb4273f-cffd3794-2bf2cd57-7ddb3f5f: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of 79.25 GiB of which 114.50 MiB is free. Process 3818168 has 672.00 MiB memory in use. Process 2843597 has 66.58 GiB memory in use. Process 2985595 has 1.28 GiB memory in use. Including non-PyTorch memory, this process has 10.59 GiB memory in use. Of the allocated memory 9.57 GiB is allocated by PyTorch, and 536.36 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLO

In [12]:
# Set up environment for memory efficiency
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def select_best_gpu(min_free_memory_gb: float = 20.0) -> int:
    """
    Select GPU with most free memory
    Returns GPU index or raises error if none available
    """
    try:
        gpus = GPUtil.getGPUs()
        best_gpu = None
        max_free_memory = 0

        print("\n=== GPU Memory Status ===")
        for gpu in gpus:
            free_memory_gb = gpu.memoryFree / 1024
            used_memory_gb = gpu.memoryUsed / 1024
            total_memory_gb = gpu.memoryTotal / 1024

            print(f"GPU {gpu.id}: {free_memory_gb:.1f}GB free / {total_memory_gb:.1f}GB total "
                  f"({gpu.memoryUtil*100:.1f}% used)")

            if free_memory_gb > max_free_memory and free_memory_gb >= min_free_memory_gb:
                max_free_memory = free_memory_gb
                best_gpu = gpu.id

        if best_gpu is None:
            raise RuntimeError(f"No GPU with at least {min_free_memory_gb}GB free memory available")

        print(f"\n✓ Selected GPU {best_gpu} with {max_free_memory:.1f}GB free memory")
        return best_gpu

    except Exception as e:
        print(f"Error selecting GPU: {e}")
        print("Falling back to GPU 1 (seems to have more free memory)")
        return 1

@dataclass
class MemoryConfig:
    """Configuration for memory-efficient processing"""
    max_new_tokens: int = 30  # Reduced from 50
    attention_layers_to_save: List[int] = None  # Only save specific layers
    max_attention_tokens: int = 10  # Only save attention for first N tokens
    batch_size: int = 1
    clear_cache_frequency: int = 5  # Clear cache every N samples
    use_amp: bool = True  # Use automatic mixed precision

    def __post_init__(self):
        if self.attention_layers_to_save is None:
            # Only save first, middle, and last layer by default
            self.attention_layers_to_save = [0, 17, 33]

class MemoryEfficientMedGemmaProcessor:
    """
    Memory-optimized version of MedGemma processor
    Key optimizations:
    1. Selective attention saving (not all layers/tokens)
    2. Automatic GPU selection
    3. Mixed precision inference
    4. Aggressive cache clearing
    5. Gradient checkpointing disabled (inference only)
    """

    def __init__(self, model_id: str = 'google/medgemma-4b-it',
                 device_id: Optional[int] = None,
                 memory_config: Optional[MemoryConfig] = None):

        self.memory_config = memory_config or MemoryConfig()

        # Select best GPU if not specified
        if device_id is None:
            device_id = select_best_gpu()

        self.device = f'cuda:{device_id}'
        print(f"\nInitializing MedGemma on {self.device}")

        # Load model with memory optimizations
        self._load_model(model_id)

        # Initialize components
        self.attention_cache = {}
        self.samples_processed = 0

    def _load_model(self, model_id: str):
        """Load model with memory-efficient settings"""
        from transformers import AutoProcessor, AutoModelForImageTextToText

        print("Loading processor...")
        self.processor = AutoProcessor.from_pretrained(model_id)

        print("Loading model with memory optimizations...")
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,  # Use bfloat16 for memory efficiency
            device_map=self.device,
            attn_implementation="eager",
            tie_word_embeddings=False,
            low_cpu_mem_usage=True,  # Reduce CPU memory usage during loading
        )

        # Set to eval mode and disable gradient computation
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

        print("✓ Model loaded successfully")

        # Print memory usage after loading
        self._print_memory_usage()

    def _print_memory_usage(self):
        """Print current GPU memory usage"""
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated(self.device) / 1024**3
            reserved = torch.cuda.memory_reserved(self.device) / 1024**3
            print(f"GPU Memory: {allocated:.1f}GB allocated, {reserved:.1f}GB reserved")

    def process_single_sample_efficient(self, image_path: str, question: str,
                                       options: List[str], study_id: str) -> Dict:
        """
        Process a single sample with aggressive memory management
        """
        try:
            # Prepare inputs
            image = Image.open(image_path).convert('RGB')

            # Create prompt
            valid_options = [opt for opt in options if opt]
            prompt = f"""Analyze this chest X-ray and answer: {question}
Options: {', '.join(valid_options)}
Answer with just 'yes' or 'no'."""

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image", "image": image}
                    ]
                }
            ]

            # Process inputs
            inputs = self.processor.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt"
            )

            # Move to GPU
            inputs_gpu = {
                k: v.to(self.device) if torch.is_tensor(v) else v
                for k, v in inputs.items()
            }

            # Generate with limited tokens and selective attention saving
            with torch.cuda.amp.autocast(enabled=self.memory_config.use_amp):
                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs_gpu,
                        max_new_tokens=self.memory_config.max_new_tokens,
                        do_sample=False,
                        output_attentions=True,
                        return_dict_in_generate=True,
                        pad_token_id=self.processor.tokenizer.pad_token_id,
                        eos_token_id=self.processor.tokenizer.eos_token_id,
                    )

            # Extract generated text
            generated_ids = outputs.sequences[0][len(inputs['input_ids'][0]):]
            generated_text = self.processor.decode(generated_ids, skip_special_tokens=True)

            # Extract answer
            answer = self._extract_answer(generated_text)

            # Selective attention extraction (memory efficient)
            attention_summary = self._extract_attention_efficient(outputs, inputs_gpu)

            # Clean up intermediate tensors
            del outputs
            del inputs_gpu
            torch.cuda.empty_cache()

            return {
                'study_id': study_id,
                'question': question,
                'generated_answer': answer,
                'generated_text': generated_text[:100],  # Truncate for storage
                'attention_summary': attention_summary
            }

        except torch.cuda.OutOfMemoryError as e:
            logger.error(f"OOM for sample {study_id}: {e}")
            torch.cuda.empty_cache()
            gc.collect()
            return None

        except Exception as e:
            logger.error(f"Error processing sample {study_id}: {e}")
            return None

    def _extract_attention_efficient(self, outputs, inputs) -> Dict:
        """
        Extract only essential attention information to save memory
        Instead of saving all attention tensors, we compute statistics
        """
        try:
            if not hasattr(outputs, 'attentions') or not outputs.attentions:
                return {}

            # Only process first few tokens
            num_tokens = min(len(outputs.attentions), self.memory_config.max_attention_tokens)

            attention_stats = {
                'num_tokens_analyzed': num_tokens,
                'regional_focus': [],
                'attention_entropy': []
            }

            for token_idx in range(num_tokens):
                if token_idx >= len(outputs.attentions):
                    break

                # Get last layer attention only (most semantic)
                last_layer_attn = outputs.attentions[token_idx][-1].cpu().float()

                if len(last_layer_attn.shape) == 4:
                    last_layer_attn = last_layer_attn[0]

                # Average over heads
                avg_attn = last_layer_attn.mean(dim=0)

                # Extract attention to image region
                input_len = inputs['input_ids'].shape[1]
                src_pos = min(input_len + token_idx, avg_attn.shape[0] - 1)

                if src_pos >= 0 and 257 <= avg_attn.shape[1]:
                    attn_to_image = avg_attn[src_pos, 1:257]

                    # Compute statistics instead of saving tensor
                    attn_2d = attn_to_image.reshape(16, 16).numpy()

                    # Regional analysis (which quadrant has most attention)
                    quadrants = {
                        'upper_left': attn_2d[:8, :8].mean(),
                        'upper_right': attn_2d[:8, 8:].mean(),
                        'lower_left': attn_2d[8:, :8].mean(),
                        'lower_right': attn_2d[8:, 8:].mean()
                    }

                    max_quadrant = max(quadrants, key=quadrants.get)
                    attention_stats['regional_focus'].append(max_quadrant)

                    # Attention entropy (how distributed vs focused)
                    entropy = stats.entropy(attn_2d.flatten() + 1e-10)
                    attention_stats['attention_entropy'].append(float(entropy))

            return attention_stats

        except Exception as e:
            logger.error(f"Error extracting attention: {e}")
            return {}

    def _extract_answer(self, text: str) -> str:
        """Extract yes/no answer from generated text"""
        text_lower = text.lower()[:50]  # Only check beginning

        if 'yes' in text_lower:
            return 'yes'
        elif 'no' in text_lower:
            return 'no'
        else:
            return 'uncertain'

    def process_dataset_batch(self, df: pd.DataFrame, image_base_path: str,
                             output_dir: str, sample_size: Optional[int] = None):
        """
        Process dataset with batching and memory management
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Sample if requested
        if sample_size:
            df = df.sample(min(sample_size, len(df)), random_state=42)

        results = []
        failed_samples = []

        print(f"\nProcessing {len(df)} samples...")

        for idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing"):
            # Prepare data
            image_path = Path(image_base_path) / row['ImagePath']

            if not image_path.exists():
                logger.warning(f"Image not found: {image_path}")
                failed_samples.append(row['study_id'])
                continue

            # Parse options
            if isinstance(row['options'], str):
                options = ast.literal_eval(row['options'])
            else:
                options = row['options']

            # Process sample
            result = self.process_single_sample_efficient(
                str(image_path),
                row['question'],
                options,
                row['study_id']
            )

            if result:
                result['correct_answer'] = row['correct_answer']
                result['is_correct'] = result['generated_answer'] == row['correct_answer']
                results.append(result)
            else:
                failed_samples.append(row['study_id'])

            # Periodic memory cleanup
            self.samples_processed += 1
            if self.samples_processed % self.memory_config.clear_cache_frequency == 0:
                gc.collect()
                torch.cuda.empty_cache()
                print(f"\n[Memory cleanup at sample {self.samples_processed}]")
                self._print_memory_usage()

        # Save results
        results_df = pd.DataFrame(results)
        results_df.to_csv(output_path / 'results.csv', index=False)

        # Save failed samples
        if failed_samples:
            with open(output_path / 'failed_samples.txt', 'w') as f:
                f.write('\n'.join(failed_samples))

        # Print summary
        print("\n" + "="*60)
        print("PROCESSING COMPLETE")
        print("="*60)
        print(f"Successful: {len(results)}/{len(df)}")
        print(f"Failed: {len(failed_samples)}")

        if len(results) > 0:
            accuracy = results_df['is_correct'].mean()
            print(f"Accuracy: {accuracy:.2%}")

            # Analyze attention patterns
            self._analyze_attention_patterns(results_df)

        return results_df

    def _analyze_attention_patterns(self, results_df: pd.DataFrame):
        """Analyze attention patterns from results"""
        print("\n=== Attention Pattern Analysis ===")

        # Extract regional focus patterns
        all_regions = []
        for _, row in results_df.iterrows():
            if row['attention_summary'] and 'regional_focus' in row['attention_summary']:
                all_regions.extend(row['attention_summary']['regional_focus'])

        if all_regions:
            from collections import Counter
            region_counts = Counter(all_regions)
            total = len(all_regions)

            print("\nRegional Focus Distribution:")
            for region, count in region_counts.most_common():
                percentage = (count / total) * 100
                print(f"  {region}: {percentage:.1f}%")

        # Analyze entropy
        all_entropies = []
        for _, row in results_df.iterrows():
            if row['attention_summary'] and 'attention_entropy' in row['attention_summary']:
                all_entropies.extend(row['attention_summary']['attention_entropy'])

        if all_entropies:
            print(f"\nAttention Entropy Statistics:")
            print(f"  Mean: {np.mean(all_entropies):.2f}")
            print(f"  Std: {np.std(all_entropies):.2f}")
            print(f"  Min: {np.min(all_entropies):.2f}")
            print(f"  Max: {np.max(all_entropies):.2f}")

def run_memory_efficient_analysis(csv_path: str, image_base_path: str,
                                 output_dir: str, sample_size: int = 10):
    """
    Main function to run memory-efficient MIMIC-CXR analysis
    """
    print("="*60)
    print("MEMORY-EFFICIENT MEDGEMMA ANALYSIS")
    print("="*60)

    # Load data
    print("\nLoading MIMIC-CXR data...")
    df = pd.read_csv(csv_path)
    print(f"Loaded {len(df)} samples")

    # Configure memory settings
    memory_config = MemoryConfig(
        max_new_tokens=20,  # Even more conservative
        max_attention_tokens=5,  # Only analyze first 5 tokens
        clear_cache_frequency=3  # Clear more frequently
    )

    # Initialize processor with automatic GPU selection
    processor = MemoryEfficientMedGemmaProcessor(
        memory_config=memory_config
    )

    # Process dataset
    results_df = processor.process_dataset_batch(
        df,
        image_base_path,
        output_dir,
        sample_size=sample_size
    )

    print(f"\n✓ Results saved to {output_dir}")

    return results_df

# Utility function to check GPU availability before running
def check_gpu_availability():
    """Check and report GPU availability"""
    print("="*60)
    print("GPU AVAILABILITY CHECK")
    print("="*60)

    if not torch.cuda.is_available():
        print("❌ CUDA is not available!")
        return False

    try:
        import GPUtil
        gpus = GPUtil.getGPUs()

        suitable_gpus = []
        for gpu in gpus:
            free_gb = gpu.memoryFree / 1024
            if free_gb >= 15:  # At least 15GB free
                suitable_gpus.append((gpu.id, free_gb))

        if suitable_gpus:
            print(f"✓ Found {len(suitable_gpus)} suitable GPU(s):")
            for gpu_id, free_mem in suitable_gpus:
                print(f"  GPU {gpu_id}: {free_mem:.1f}GB free")
            return True
        else:
            print("❌ No GPUs with sufficient free memory (need at least 15GB)")
            return False

    except ImportError:
        print("⚠️ GPUtil not installed. Install with: pip install gputil")
        print("Checking basic CUDA availability...")
        print(f"✓ CUDA available with {torch.cuda.device_count()} GPU(s)")
        return True

In [17]:
csv_path = "/home/bsada1/lvlm-interpret-medgemma/one-pixel-attack/mimic_adapted_questions.csv"
image_base_path = "/home/bsada1/mimic_cxr_hundred_vqa"
output_dir = "mimic_medgemma_analysis"
results = run_memory_efficient_analysis(
        csv_path,
        image_base_path,
        output_dir,
        sample_size=10  # Start small
    )

MEMORY-EFFICIENT MEDGEMMA ANALYSIS

Loading MIMIC-CXR data...
Loaded 100 samples

=== GPU Memory Status ===
GPU 0: 10.2GB free / 80.0GB total (86.3% used)
GPU 1: 68.9GB free / 80.0GB total (13.0% used)
GPU 2: 78.4GB free / 80.0GB total (1.1% used)
GPU 3: 78.4GB free / 80.0GB total (1.1% used)
GPU 4: 78.4GB free / 80.0GB total (1.1% used)
GPU 5: 78.4GB free / 80.0GB total (1.1% used)
GPU 6: 78.4GB free / 80.0GB total (1.1% used)
GPU 7: 78.4GB free / 80.0GB total (1.1% used)

✓ Selected GPU 2 with 78.4GB free memory

Initializing MedGemma on cuda:2
Loading processor...
Loading model with memory optimizations...


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.15it/s]


✓ Model loaded successfully
GPU Memory: 8.0GB allocated, 8.0GB reserved

Processing 10 samples...


  with torch.cuda.amp.autocast(enabled=self.memory_config.use_amp):
Processing:  30%|███       | 3/10 [00:04<00:11,  1.61s/it]


[Memory cleanup at sample 3]
GPU Memory: 8.0GB allocated, 8.0GB reserved


Processing:  60%|██████    | 6/10 [00:09<00:06,  1.59s/it]


[Memory cleanup at sample 6]
GPU Memory: 8.0GB allocated, 8.0GB reserved


Processing:  90%|█████████ | 9/10 [00:14<00:01,  1.59s/it]


[Memory cleanup at sample 9]
GPU Memory: 8.0GB allocated, 8.0GB reserved


Processing: 100%|██████████| 10/10 [00:15<00:00,  1.57s/it]


PROCESSING COMPLETE
Successful: 10/10
Failed: 0
Accuracy: 60.00%

=== Attention Pattern Analysis ===

Regional Focus Distribution:
  upper_left: 100.0%

Attention Entropy Statistics:
  Mean: 4.27
  Std: 0.15
  Min: 4.01
  Max: 4.74

✓ Results saved to mimic_medgemma_analysis



