In [None]:
!pip install torch transformers onnx onnxruntime pillow accelerate

In [None]:
!pip install bitsandbytes

In [None]:
!pip install onnxscript

In [None]:
!pip install transformers==4.57.6

In [None]:
"""
Export MedGemma Vision Encoder to ONNX - Direct Export Version
================================================================

This version exports the vision encoder directly without reloading weights.
Simpler and more reliable.

Colab Setup:
```python
# Cell 1: Install dependencies  
!pip install torch transformers onnx onnxruntime pillow accelerate

# Cell 2: Run export
%run export_vision_onnx.py
```
"""

import os
import gc
import json
import torch
import torch.nn as nn
import numpy as np


def clear_memory():
    """Release memory held by unreachable objects and PyTorch's CUDA cache.

    A full garbage-collection pass runs first so tensors owned by dead Python
    objects are freed; then, only if CUDA is available, cached-but-unused GPU
    blocks are handed back via ``torch.cuda.empty_cache()``.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()


class VisionEncoderWrapper(nn.Module):
    """
    Export wrapper: vision tower -> spatial average pooling -> projection.

    The vision tower is used directly (its weights are not copied); only the
    projection matrix is stored, as a buffer, so it is baked into the graph.

    ``forward`` returns three tensors:
        projected: [batch, grid_size**2, proj_dim] projected patch embeddings
        attention: [batch, grid_size**2] L2 norm of each projected patch,
                   min-max scaled to [0, 1] within each image
        pooled:    [batch, proj_dim] mean of ``projected`` over patches
    """

    def __init__(self, vision_tower, projection_weight, grid_size=16, soft_emb_norm=None):
        """
        Args:
            vision_tower: Module mapping pixel_values to patch embeddings. It may
                return an object with ``last_hidden_state`` or a tuple whose
                first item is the hidden states.
            projection_weight: Matrix applied as ``embeds @ projection_weight``,
                i.e. shaped [vision_dim, proj_dim].
            grid_size: Side of the pooled patch grid (16 -> 256 output patches).
            soft_emb_norm: Optional module applied to the pooled embeddings
                before projection. Gemma 3's multi-modal projector normalises
                with its ``mm_soft_emb_norm`` at this point; the default
                ``None`` skips it, matching the previous behaviour.
        """
        super().__init__()
        self.vision_tower = vision_tower
        self.grid_size = int(grid_size)
        self.soft_emb_norm = soft_emb_norm
        self.register_buffer('projection_weight', projection_weight.clone().detach())

    def _pool_to_grid(self, vision_embeds):
        """Average-pool [B, N, D] patch embeddings down to [B, grid_size**2, D].

        Inputs that already have at most grid_size**2 patches are returned
        unchanged. Otherwise the N patches must form a square side x side grid
        with side divisible by grid_size.
        """
        num_patches = int(vision_embeds.shape[1])
        hidden_dim = int(vision_embeds.shape[2])
        grid = self.grid_size
        target = grid * grid
        if num_patches <= target:
            return vision_embeds

        side = int(num_patches ** 0.5 + 0.5)
        if side * side != num_patches or side % grid != 0:
            raise ValueError(
                f"Cannot pool {num_patches} patches to a {grid}x{grid} grid: "
                f"expected a square patch grid whose side is a multiple of {grid}"
            )
        pool = side // grid
        # Patches are laid out row-major over a side x side grid, so this view
        # splits each axis into (grid, pool) blocks; averaging the pool axes is
        # an avg-pool with kernel = stride = pool. -1 keeps batch dynamic.
        blocks = vision_embeds.reshape(-1, grid, pool, grid, pool, hidden_dim)
        return blocks.mean(dim=(2, 4)).reshape(-1, target, hidden_dim)

    def forward(self, pixel_values):
        """Run the tower, pool, (optionally) normalise and project.

        Returns:
            (projected, attention, pooled) as described in the class docstring.
        """
        vision_outputs = self.vision_tower(pixel_values)
        if hasattr(vision_outputs, 'last_hidden_state'):
            vision_embeds = vision_outputs.last_hidden_state
        else:
            vision_embeds = vision_outputs[0]

        vision_embeds = self._pool_to_grid(vision_embeds)
        if self.soft_emb_norm is not None:
            vision_embeds = self.soft_emb_norm(vision_embeds)

        # Project to LLM embedding space
        projected = torch.matmul(vision_embeds, self.projection_weight)

        # Attention scores = projection magnitude, min-max normalised per image
        # (the epsilon avoids 0/0 when every patch has the same magnitude).
        attention = torch.norm(projected, dim=-1)
        att_min = attention.min(dim=-1, keepdim=True).values
        att_max = attention.max(dim=-1, keepdim=True).values
        attention = (attention - att_min) / (att_max - att_min + 1e-8)

        # Global pooled representation
        pooled = projected.mean(dim=1)

        return projected, attention, pooled


def _extract_projection_weight(projector):
    """Return the vision->text projection matrix as float32, shaped [vision_dim, text_dim].

    Gemma 3 projectors expose ``mm_input_projection_weight``, which is already
    laid out for ``embeds @ W``. Otherwise the first ``nn.Linear`` is used; its
    ``weight`` is stored as [out_features, in_features], so it is transposed to
    the same ``embeds @ W`` layout.

    Raises:
        RuntimeError: if the projector has neither.
    """
    weight = getattr(projector, 'mm_input_projection_weight', None)
    if weight is not None:
        return weight.detach().float()
    for module in projector.modules():
        if isinstance(module, nn.Linear):
            return module.weight.detach().float().t().contiguous()
    raise RuntimeError(
        f"Could not find a projection weight in {type(projector).__name__}"
    )


def _resolve_image_size(processor, default=896):
    """Return (height, width) the processor resizes images to, else (default, default).

    Handles ``size`` given as a dict, a plain int, or an object exposing
    ``height`` / ``width`` attributes.
    """
    image_processor = getattr(processor, 'image_processor', None)
    size = getattr(image_processor, 'size', None)
    if isinstance(size, dict):
        return int(size.get('height', default)), int(size.get('width', default))
    if isinstance(size, int):
        return size, size
    height = getattr(size, 'height', None) or default
    width = getattr(size, 'width', None) or default
    return int(height), int(width)


def _resolve_normalization(processor):
    """Return the processor's (image_mean, image_std) as 3-float lists, or (None, None)."""
    image_processor = getattr(processor, 'image_processor', None)
    mean = getattr(image_processor, 'image_mean', None)
    std = getattr(image_processor, 'image_std', None)
    if mean is None or std is None:
        return None, None

    def as_rgb_list(value):
        # Scalars are broadcast to one value per RGB channel.
        return [float(v) for v in np.broadcast_to(np.asarray(value, dtype=np.float64), (3,))]

    return as_rgb_list(mean), as_rgb_list(std)


def main():
    """Export the MedGemma vision encoder (+ projection) to ONNX and verify it.

    Loads the full model in FP16, strips accelerate hooks from the vision
    tower, moves it to CPU in FP32, wraps it with VisionEncoderWrapper and
    exports ``vision_encoder.onnx``. Also writes ``projection_weight.npy`` and
    ``vision_config.json`` (input size, grid, projection dim and, when the
    processor provides them, the pixel normalisation stats) to ./onnx_export.
    """
    # Imported up front so a missing package fails before the slow model load.
    from transformers import AutoModelForImageTextToText, AutoProcessor
    from accelerate.hooks import remove_hook_from_module
    import onnx
    import onnxruntime as ort

    output_dir = "./onnx_export"
    os.makedirs(output_dir, exist_ok=True)

    print("="*60)
    print("MedGemma Vision Encoder ONNX Export")
    print("Direct Export Version (No Weight Reload)")
    print("="*60)

    # ========================================
    # Step 1: Load model WITHOUT quantization
    # ========================================
    print("\n[Step 1] Loading model in FP16 (no quantization)...")

    model_id = "convaiinnovations/medgemma-4b-ecginstruct"

    # FP16 without quantization: more memory, but real (non-quantized) weights.
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    processor = AutoProcessor.from_pretrained(model_id)

    print(f"Model loaded on {model.device}")

    # ========================================
    # Step 2: Extract vision components
    # ========================================
    print("\n[Step 2] Extracting vision encoder...")

    vision_tower = model.vision_tower
    projector = model.multi_modal_projector
    proj_weight = _extract_projection_weight(projector)

    print(f"Vision tower: {type(vision_tower).__name__}")
    print(f"Projection weight: {proj_weight.shape}")

    # Save projection weight for edge CMAS
    np.save(os.path.join(output_dir, "projection_weight.npy"), proj_weight.cpu().numpy())

    height, width = _resolve_image_size(processor)
    image_mean, image_std = _resolve_normalization(processor)

    # ========================================
    # Step 3: Create wrapper and move to CPU
    # ========================================
    print("\n[Step 3] Creating export wrapper...")

    # device_map="auto" attaches accelerate hooks to submodules; remove them
    # (recursively) before moving the tower to CPU for tracing.
    remove_hook_from_module(vision_tower, recurse=True)
    print("Accelerate hooks removed.")

    # Move vision tower to CPU and float32 for ONNX export
    vision_tower = vision_tower.cpu().float()
    proj_weight = proj_weight.cpu()

    # Drop the rest of the model to free GPU memory; vision_tower keeps its own
    # reference to the submodule it needs.
    del model
    del projector
    clear_memory()
    print("GPU memory freed.")

    wrapper = VisionEncoderWrapper(vision_tower, proj_weight)
    wrapper.eval()

    # ========================================
    # Step 4: Test forward pass
    # ========================================
    print("\n[Step 4] Testing forward pass...")

    # Dummy input on CPU (same device as the wrapper); reused in Step 6 to
    # compare ONNX Runtime against PyTorch on identical data.
    dummy_input = torch.randn(1, 3, height, width, device='cpu')

    with torch.no_grad():
        projected, attention, pooled = wrapper(dummy_input)

    print(f"  Projected: {projected.shape}")
    print(f"  Attention: {attention.shape}")
    print(f"  Pooled: {pooled.shape}")

    # ========================================
    # Step 5: Export to ONNX
    # ========================================
    print("\n[Step 5] Exporting to ONNX...")

    onnx_path = os.path.join(output_dir, "vision_encoder.onnx")

    # Legacy (TorchScript-based) exporter; only the batch axis is dynamic.
    torch.onnx.export(
        wrapper,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=['pixel_values'],
        output_names=['projected_embeddings', 'attention_scores', 'pooled_embedding'],
        dynamic_axes={
            'pixel_values': {0: 'batch_size'},
            'projected_embeddings': {0: 'batch_size'},
            'attention_scores': {0: 'batch_size'},
            'pooled_embedding': {0: 'batch_size'}
        },
        dynamo=False,
    )

    file_size = os.path.getsize(onnx_path) / (1024 * 1024)
    print(f"ONNX exported: {file_size:.1f} MB")

    # ========================================
    # Step 6: Verify
    # ========================================
    print("\n[Step 6] Verifying ONNX model...")

    # Path form: avoids loading the model into memory and also works for
    # models stored with external data.
    onnx.checker.check_model(onnx_path)
    print("  ONNX validation: ✅")

    session = ort.InferenceSession(onnx_path, providers=['CPUExecutionProvider'])
    outputs = session.run(None, {'pixel_values': dummy_input.numpy()})
    print(f"  Inference test: ✅")
    print(f"    - Projected: {outputs[0].shape}")
    print(f"    - Attention: {outputs[1].shape}")
    print(f"    - Pooled: {outputs[2].shape}")

    max_diff = max(
        float(np.abs(ort_out - torch_out.numpy()).max())
        for ort_out, torch_out in zip(outputs, (projected, attention, pooled))
    )
    print(f"    - Max |ONNX - PyTorch| difference: {max_diff:.2e}")

    # ========================================
    # Save config
    # ========================================
    num_patches = int(projected.shape[1])
    config = {
        'input_height': height,
        'input_width': width,
        'num_patches': num_patches,
        'grid_size': int(round(num_patches ** 0.5)),
        'projected_dim': int(proj_weight.shape[1]),
        'model_id': model_id,
    }
    # Record the processor's pixel normalisation so edge code can reproduce it.
    if image_mean is not None:
        config['image_mean'] = image_mean
        config['image_std'] = image_std

    with open(os.path.join(output_dir, "vision_config.json"), 'w') as f:
        json.dump(config, f, indent=2)

    # ========================================
    # Done!
    # ========================================
    print("\n" + "="*60)
    print("✅ EXPORT COMPLETE!")
    print("="*60)
    print(f"""
Files in {output_dir}/:
  • vision_encoder.onnx    ({file_size:.1f} MB)
  • vision_config.json
  • projection_weight.npy

Download and use with edge_inference.py on your desktop.
""")


if __name__ == '__main__':
    main()


In [None]:
%%writefile edge_inference.py
"""
Edge Inference with ONNX Vision Encoder
=========================================

Use this on the desktop app to compute attention/CMAS locally
without needing the full MedGemma model.

Usage:
    from edge_inference import EdgeVisionEncoder
    
    encoder = EdgeVisionEncoder("./onnx_export")
    attention_map, heatmap_image = encoder.get_attention_heatmap(ecg_image)
"""

import os
import json
import numpy as np
from PIL import Image


class EdgeVisionEncoder:
    """
    Local edge inference for attention/CMAS computation.

    Runs the exported vision encoder with ONNX Runtime (CPU, or CUDA when
    requested) and turns its three outputs -- projected patch embeddings,
    per-patch attention scores and a pooled embedding -- into 2-D grids,
    heatmap overlays and top-k region lists.
    """

    # Fallback pixel normalisation, used when vision_config.json does not
    # record the processor's image_mean / image_std (configs written by older
    # export runs). The Gemma 3 image processor in transformers rescales to
    # [0, 1] and then normalises with mean = std = 0.5 per channel; ImageNet
    # statistics do not match what the SigLIP tower expects.
    DEFAULT_IMAGE_MEAN = (0.5, 0.5, 0.5)
    DEFAULT_IMAGE_STD = (0.5, 0.5, 0.5)

    def __init__(self, onnx_dir, use_gpu=False, model_name="vision_encoder.onnx"):
        """
        Initialize the edge vision encoder.

        Args:
            onnx_dir: Directory containing the ONNX model and vision_config.json
            use_gpu: Whether to try CUDA first (requires onnxruntime-gpu); the
                CPU provider is kept as a fallback.
            model_name: ONNX file inside ``onnx_dir`` to load, e.g.
                "vision_encoder_quant.onnx" for the INT8 model.
        """
        import onnxruntime as ort

        config_path = os.path.join(onnx_dir, "vision_config.json")
        with open(config_path, 'r') as f:
            self.config = json.load(f)

        onnx_path = os.path.join(onnx_dir, model_name)
        if use_gpu:
            providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        else:
            providers = ['CPUExecutionProvider']
        self.session = ort.InferenceSession(onnx_path, providers=providers)

        # Projection matrix saved by the export script; loaded when present but
        # not used by the methods of this class.
        proj_path = os.path.join(onnx_dir, "projection_weight.npy")
        self.projection_weight = np.load(proj_path) if os.path.exists(proj_path) else None

        self.input_height = self.config.get('input_height', 896)
        self.input_width = self.config.get('input_width', 896)
        self.grid_size = self.config.get('grid_size', 16)
        self.image_mean = tuple(self.config.get('image_mean', self.DEFAULT_IMAGE_MEAN))
        self.image_std = tuple(self.config.get('image_std', self.DEFAULT_IMAGE_STD))

        print("EdgeVisionEncoder initialized")
        print(f"  Input size: {self.input_height}x{self.input_width}")
        print(f"  Grid size: {self.grid_size}x{self.grid_size}")

    @staticmethod
    def _load_rgb(image):
        """Return *image* as an RGB PIL image, opening (and closing) it if it is a path."""
        if isinstance(image, (str, os.PathLike)):
            with Image.open(image) as opened:
                # convert() returns a fully loaded copy, so the file can close.
                return opened.convert('RGB')
        return image if image.mode == 'RGB' else image.convert('RGB')

    @staticmethod
    def _to_model_input(pixels, mean, std):
        """Turn uint8 HWC RGB pixels into a normalised float32 [1, 3, H, W] array."""
        arr = np.asarray(pixels, dtype=np.float32) / 255.0
        arr = (arr - np.asarray(mean, dtype=np.float32)) / np.asarray(std, dtype=np.float32)
        return np.ascontiguousarray(arr.transpose(2, 0, 1)[np.newaxis])

    @staticmethod
    def _get_colormap(name):
        """Look up a Matplotlib colormap by name (registry API, with a legacy fallback)."""
        import matplotlib
        registry = getattr(matplotlib, 'colormaps', None)
        if registry is not None:
            return registry[name]
        import matplotlib.cm as cm  # Matplotlib < 3.5
        return cm.get_cmap(name)

    @staticmethod
    def _top_regions(attention_2d, top_k, grid_size):
        """Return up to top_k (row, col, score) tuples, highest score first."""
        if top_k <= 0:
            # Guard: a slice of [-0:] would otherwise return every cell.
            return []
        flat = np.asarray(attention_2d).ravel()
        top_indices = np.argsort(flat)[-top_k:][::-1]
        return [(*divmod(int(idx), grid_size), float(flat[idx])) for idx in top_indices]

    def preprocess_image(self, image):
        """
        Preprocess image for vision encoder.

        Converts to RGB, resizes bilinearly to the configured input size, scales
        to [0, 1] and normalises with ``self.image_mean`` / ``self.image_std``.

        Args:
            image: PIL Image or path to image

        Returns:
            float32 numpy array of shape [1, 3, H, W]
        """
        rgb = self._load_rgb(image)
        rgb = rgb.resize((self.input_width, self.input_height), Image.BILINEAR)
        return self._to_model_input(np.asarray(rgb), self.image_mean, self.image_std)

    def compute_attention(self, image):
        """
        Compute attention scores for an image.

        Args:
            image: PIL Image or path to image

        Returns:
            attention_2d: [grid_size, grid_size] attention map
            projected: [num_patches, hidden_dim] projected embeddings
            pooled: [hidden_dim] global representation
        """
        pixel_values = self.preprocess_image(image)
        projected, attention, pooled = self.session.run(None, {'pixel_values': pixel_values})
        attention_2d = attention[0].reshape(self.grid_size, self.grid_size)
        return attention_2d, projected[0], pooled[0]

    def compute_cmas_with_reference(self, image, reference_embedding):
        """
        Compute full CMAS using a reference diagnosis embedding.

        CMAS per patch = ||patch|| * cos(patch, reference), then min-max scaled
        to [0, 1].

        Args:
            image: PIL Image or path to image
            reference_embedding: [hidden_dim] reference embedding for comparison

        Returns:
            cmas_2d: [grid_size, grid_size] CMAS attention map
        """
        _, projected, _ = self.compute_attention(image)
        reference = np.asarray(reference_embedding)

        proj_norm = projected / (np.linalg.norm(projected, axis=-1, keepdims=True) + 1e-8)
        ref_norm = reference / (np.linalg.norm(reference) + 1e-8)

        cos_sim = proj_norm @ ref_norm                   # [num_patches]
        magnitudes = np.linalg.norm(projected, axis=-1)  # [num_patches]
        cmas = magnitudes * cos_sim

        cmas = cmas - cmas.min()
        cmas = cmas / (cmas.max() + 1e-8)
        return cmas.reshape(self.grid_size, self.grid_size)

    def get_attention_heatmap(self, image, alpha=0.5, colormap='jet', invert=True):
        """
        Generate attention heatmap overlay on image.

        Args:
            image: PIL Image or path to image
            alpha: Overlay transparency (0-1)
            colormap: Matplotlib colormap name ('jet', 'hot', 'viridis', etc.)
            invert: If True (default), render ``1 - attention``, so with 'jet'
                the LOWEST-magnitude patches appear red. Pass False to colour
                the raw scores (highest magnitude = red).

        Returns:
            attention_2d: [grid_size, grid_size] scores that were rendered
                (already flipped when ``invert`` is True)
            heatmap_image: PIL Image with heatmap overlay
        """
        original_image = self._load_rgb(image)
        attention_2d, _, _ = self.compute_attention(original_image)
        if invert:
            attention_2d = 1.0 - attention_2d
        heatmap_image = self._generate_heatmap_overlay(
            original_image, attention_2d, alpha, colormap
        )
        return attention_2d, heatmap_image

    def _generate_heatmap_overlay(self, image, attention_2d, alpha=0.5, colormap='jet'):
        """Upsample *attention_2d* to the image size, colourise it and alpha-blend."""
        rgb = self._load_rgb(image)
        img_array = np.asarray(rgb)
        h, w = img_array.shape[:2]

        # Clip so values outside [0, 1] cannot wrap around in the uint8 cast.
        att_u8 = (np.clip(attention_2d, 0.0, 1.0) * 255).astype(np.uint8)
        attention_resized = Image.fromarray(att_u8).resize((w, h), Image.BILINEAR)
        attention_array = np.asarray(attention_resized) / 255.0

        cmap = self._get_colormap(colormap)
        heatmap_colored = (cmap(attention_array)[:, :, :3] * 255).astype(np.uint8)

        blended = (1 - alpha) * img_array + alpha * heatmap_colored
        return Image.fromarray(blended.astype(np.uint8))

    def get_top_attention_regions(self, image, top_k=5):
        """
        Get the top-k most attended regions.

        Args:
            image: PIL Image or path to image
            top_k: Number of top regions to return (<= 0 returns an empty list)

        Returns:
            List of (row, col, attention_score) tuples, highest score first
        """
        attention_2d, _, _ = self.compute_attention(image)
        return self._top_regions(attention_2d, top_k, self.grid_size)

    def get_pooled_embedding(self, image):
        """
        Get the global pooled embedding for an image.
        Useful for comparing images or clustering.

        Args:
            image: PIL Image or path to image

        Returns:
            pooled: [hidden_dim] global representation
        """
        _, _, pooled = self.compute_attention(image)
        return pooled


def demo_edge_inference():
    """Run EdgeVisionEncoder end to end on a sample ECG image.

    Exits with status 1 if ./onnx_export is missing. Uses the first candidate
    image found in the working directory (falling back to a blank white
    896x896 image), prints the attention range and top-5 regions, and saves
    the overlay to edge_attention_heatmap.png.

    Returns:
        The initialised EdgeVisionEncoder.
    """
    import sys

    onnx_dir = "./onnx_export"
    if not os.path.exists(onnx_dir):
        print(f"Error: ONNX export directory not found: {onnx_dir}")
        print("Please run export_vision_onnx.py first to generate the ONNX model.")
        sys.exit(1)

    print("Initializing EdgeVisionEncoder...")
    encoder = EdgeVisionEncoder(onnx_dir)

    # Candidate sample images, in priority order. The last entry is the file
    # the notebook downloads with wget.
    test_images = [
        "ECG-Atrial-Fibrillation-4-1024x561 (1).jpg",
        "normal-sinus-rhythm-2 (1).jpg",
        "ecg_image.png",
        "F1.png",
        "ECG_Atrial_Fibrillation.jpg",
    ]
    test_image = next((name for name in test_images if os.path.exists(name)), None)

    if test_image is None:
        test_image = Image.new('RGB', (896, 896), color='white')
        print("Using dummy white image for testing")
    else:
        print(f"Using test image: {test_image}")

    print("\nComputing attention...")
    attention_2d, heatmap_image = encoder.get_attention_heatmap(test_image)

    print(f"Attention map shape: {attention_2d.shape}")
    print(f"Attention range: [{attention_2d.min():.4f}, {attention_2d.max():.4f}]")

    top_regions = encoder.get_top_attention_regions(test_image)
    print("\nTop 5 attention regions:")
    for rank, (row, col, score) in enumerate(top_regions, start=1):
        print(f"  {rank}. Grid ({row}, {col}): score = {score:.4f}")

    output_path = "edge_attention_heatmap.png"
    heatmap_image.save(output_path)
    print(f"\nHeatmap saved to: {output_path}")

    return encoder


if __name__ == '__main__':
    demo_edge_inference()


In [None]:
!wget https://upload.wikimedia.org/wikipedia/commons/3/32/ECG_Atrial_Fibrillation.jpg

## Test it

In [None]:
from edge_inference import EdgeVisionEncoder

# Initialize with the path to the ONNX export folder. This loads the FP32
# vision_encoder.onnx written by the export cell; the INT8 model
# (vision_encoder_quant.onnx) is only produced later by the quantization cell.
encoder = EdgeVisionEncoder("./onnx_export")

sample_image = "ECG_Atrial_Fibrillation.jpg"

# Get attention heatmap
attention_2d, heatmap_image = encoder.get_attention_heatmap(sample_image)
heatmap_image.save("attention_overlay.png")

# Get top attention regions
top_regions = encoder.get_top_attention_regions(sample_image, top_k=5)
print(top_regions)

# Get pooled embedding for similarity comparisons
embedding = encoder.get_pooled_embedding(sample_image)
print(embedding.shape)

## INT8 (8-bit) Quantization

In [None]:
"""
Quantize ONNX Vision Encoder to INT8 (4x Smaller)
==================================================

Run this on your DESKTOP to compress the vision_encoder.onnx model.

Usage:
    python quantize_onnx.py

This will create:
- onnx_export/vision_encoder_quant.onnx (~150MB)
"""

import os
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType

def quantize_model(input_path, output_path):
    """Dynamically quantize an ONNX model's weights to UINT8 and report sizes.

    Args:
        input_path: Path of the FP32 ONNX model to read.
        output_path: Path the quantized model is written to.
    """
    print(f"Quantizing {input_path}...")

    # quantize_dynamic needs no calibration data; weights are stored as
    # 8-bit integers, roughly a quarter of their FP32 size.
    quantize_dynamic(
        model_input=input_path,
        model_output=output_path,
        weight_type=QuantType.QUInt8,
    )

    mib = 1024 * 1024
    orig_size, quant_size = (os.path.getsize(path) / mib for path in (input_path, output_path))

    print("Done!")
    print(f"Original size: {orig_size:.2f} MB")
    print(f"Quantized size: {quant_size:.2f} MB")
    print(f"Reduction: {orig_size / quant_size:.1f}x")

def main():
    """Quantize ./onnx_export/vision_encoder.onnx into vision_encoder_quant.onnx.

    If the FP32 model is missing, print a hint and return without writing
    anything.
    """
    onnx_dir = "./onnx_export"
    input_model = os.path.join(onnx_dir, "vision_encoder.onnx")
    output_model = os.path.join(onnx_dir, "vision_encoder_quant.onnx")

    if os.path.exists(input_model):
        quantize_model(input_model, output_model)
        print("\n✅ Quantization complete!")
        print(f"New model saved to: {output_model}")
        print("\nTo use this model, update your edge_inference.py call:")
        print('encoder = EdgeVisionEncoder("./onnx_export", model_name="vision_encoder_quant.onnx")')
    else:
        print(f"Error: Could not find {input_model}")
        print("Please make sure you have downloaded the onnx_export folder from Kaggle.")


if __name__ == '__main__':
    main()


# Run inference with the quantized model

In [2]:
"""
Edge Inference with Quantized ONNX Model (INT8)
=================================================

Dedicated script for running the ~150MB INT8 quantized vision encoder.
Run this on your desktop app for fast, low-memory inference.

Step 1: Run 'quantize_onnx.py' first to generate the quantized model
Step 2: Run this script to test it

Usage:
    from quantized_inference import QuantizedVisionEncoder
    
    encoder = QuantizedVisionEncoder("./onnx_export")
    attention_map, heatmap_image, infer_time = encoder.get_attention_heatmap(ecg_image)
"""

import os
import json
import time
import numpy as np
from PIL import Image


class QuantizedVisionEncoder:
    """
    Inference wrapper for the Quantized (INT8) ONNX model.

    Loads ``model_name`` from ``onnx_dir`` into a CPU-only ONNX Runtime session
    and renders a contrast-windowed attention heatmap for an image.
    """

    # Fallback pixel normalisation for configs that do not record the
    # processor's image_mean / image_std. The Gemma 3 image processor in
    # transformers normalises [0, 1] pixels with mean = std = 0.5 per channel;
    # ImageNet statistics do not match the SigLIP vision tower.
    DEFAULT_IMAGE_MEAN = (0.5, 0.5, 0.5)
    DEFAULT_IMAGE_STD = (0.5, 0.5, 0.5)

    def __init__(self, onnx_dir, model_name="vision_encoder_quant.onnx", num_threads=4):
        """
        Args:
            onnx_dir: Directory containing the model and vision_config.json.
            model_name: ONNX file inside ``onnx_dir`` to load.
            num_threads: ONNX Runtime intra-op thread count (0 lets ONNX
                Runtime choose its default).

        Raises:
            FileNotFoundError: if the model file does not exist.
        """
        import onnxruntime as ort

        self.onnx_dir = onnx_dir
        self.model_path = os.path.join(onnx_dir, model_name)
        self.config_path = os.path.join(onnx_dir, "vision_config.json")

        if not os.path.exists(self.model_path):
            raise FileNotFoundError(f"Quantized model not found at {self.model_path}. Please run quantize_onnx.py first.")

        with open(self.config_path, 'r') as f:
            self.config = json.load(f)

        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = num_threads
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        print(f"Loading quantized model: {model_name}...")
        start_time = time.perf_counter()

        # CPU provider only.
        self.session = ort.InferenceSession(
            self.model_path,
            sess_options,
            providers=['CPUExecutionProvider'],
        )

        # Projection matrix saved by the export script; loaded when present but
        # not used by the methods of this class.
        proj_path = os.path.join(onnx_dir, "projection_weight.npy")
        self.projection_weight = np.load(proj_path) if os.path.exists(proj_path) else None

        load_time = time.perf_counter() - start_time
        print(f"Model loaded in {load_time:.2f}s")

        self.input_height = self.config.get('input_height', 896)
        self.input_width = self.config.get('input_width', 896)
        self.grid_size = self.config.get('grid_size', 16)
        self.image_mean = tuple(self.config.get('image_mean', self.DEFAULT_IMAGE_MEAN))
        self.image_std = tuple(self.config.get('image_std', self.DEFAULT_IMAGE_STD))

    @staticmethod
    def _load_rgb(image):
        """Return *image* as an RGB PIL image, opening (and closing) it if it is a path."""
        if isinstance(image, (str, os.PathLike)):
            with Image.open(image) as opened:
                # convert() returns a fully loaded copy, so the file can close.
                return opened.convert('RGB')
        return image if image.mode == 'RGB' else image.convert('RGB')

    @staticmethod
    def _to_model_input(pixels, mean, std):
        """Turn uint8 HWC RGB pixels into a normalised float32 [1, 3, H, W] array."""
        arr = np.asarray(pixels, dtype=np.float32) / 255.0
        arr = (arr - np.asarray(mean, dtype=np.float32)) / np.asarray(std, dtype=np.float32)
        return np.ascontiguousarray(arr.transpose(2, 0, 1)[np.newaxis])

    @staticmethod
    def _get_colormap(name):
        """Look up a Matplotlib colormap by name (registry API, with a legacy fallback)."""
        import matplotlib
        registry = getattr(matplotlib, 'colormaps', None)
        if registry is not None:
            return registry[name]
        import matplotlib.cm as cm  # Matplotlib < 3.5
        return cm.get_cmap(name)

    @staticmethod
    def _window(attention_2d, vmin, vmax):
        """Min-max normalise to [0, 1], then linearly map [vmin, vmax] onto [0, 1] and clip.

        Values below ``vmin`` become 0 (suppressed background) and values above
        ``vmax`` become 1 (saturated peaks).

        Raises:
            ValueError: if ``vmax <= vmin``.
        """
        if vmax <= vmin:
            raise ValueError(f"vmax ({vmax}) must be greater than vmin ({vmin})")
        att = np.asarray(attention_2d)
        lo, hi = att.min(), att.max()
        if hi > lo:
            att = (att - lo) / (hi - lo)
        att = (att - vmin) / (vmax - vmin + 1e-8)
        return np.clip(att, 0, 1)

    def preprocess(self, image):
        """Standard preprocessing for MedGemma vision encoder.

        Converts to RGB, resizes bilinearly to the configured input size, scales
        to [0, 1] and normalises with ``self.image_mean`` / ``self.image_std``.
        Returns a float32 array of shape [1, 3, H, W].
        """
        rgb = self._load_rgb(image)
        rgb = rgb.resize((self.input_width, self.input_height), Image.BILINEAR)
        return self._to_model_input(np.asarray(rgb), self.image_mean, self.image_std)

    def get_attention_heatmap(self, image, alpha=0.5, colormap='jet', invert=True,
                              threshold=0.6, vmin=0.3, vmax=0.7):
        """
        Get attention heatmap overlay.

        Args:
            image: PIL Image (any mode) or path to an image.
            alpha: Overlay opacity in [0, 1].
            colormap: Matplotlib colormap name.
            invert: Kept with its original meaning for compatibility: when True
                (default) the windowed map is used as-is, so with 'jet' the
                highest-magnitude patches are red; when False it is flipped to
                ``1 - attention``.
            threshold: Currently ignored (the hard-threshold step is disabled);
                accepted only so existing callers keep working. Use vmin/vmax.
            vmin: After min-max normalisation, scores at or below this map to 0.
                Raise it to suppress more background.
            vmax: Scores at or above this map to 1. Lower it to saturate
                fainter peaks.

        Returns:
            attention_2d: [grid_size, grid_size] map that was rendered.
            heatmap_image: PIL Image with the overlay.
            infer_time: ONNX Runtime call duration in seconds.
        """
        # Load once; the same RGB image feeds the model and the overlay.
        rgb_image = self._load_rgb(image)
        pixel_values = self.preprocess(rgb_image)

        start_time = time.perf_counter()
        _projected, attention, _pooled = self.session.run(None, {'pixel_values': pixel_values})
        infer_time = time.perf_counter() - start_time

        attention_2d = attention[0].reshape(self.grid_size, self.grid_size)
        attention_2d = self._window(attention_2d, vmin, vmax)

        if not invert:
            attention_2d = 1.0 - attention_2d

        heatmap_image = self._overlay_heatmap(rgb_image, attention_2d, alpha, colormap)
        return attention_2d, heatmap_image, infer_time

    def _overlay_heatmap(self, image, attention, alpha, colormap):
        """Upsample *attention* to the image size, colourise it and alpha-blend."""
        rgb = self._load_rgb(image)
        w, h = rgb.size

        # Clip so values outside [0, 1] cannot wrap around in the uint8 cast.
        att_img = Image.fromarray((np.clip(attention, 0.0, 1.0) * 255).astype(np.uint8))
        att_arr = np.asarray(att_img.resize((w, h), Image.BILINEAR)) / 255.0

        cmap = self._get_colormap(colormap)
        colored = (cmap(att_arr)[:, :, :3] * 255).astype(np.uint8)

        img_arr = np.asarray(rgb)
        blended = (1 - alpha) * img_arr + alpha * colored
        return Image.fromarray(blended.astype(np.uint8))


def demo():
    """Run QuantizedVisionEncoder on a sample ECG image and save the overlay.

    Returns early (printing a hint) if the quantized model is missing. Uses the
    first existing candidate image; if none exists, an in-memory white 500x300
    image is used instead. Nothing is written to the candidate paths, so a
    dummy file can never shadow a real image downloaded later.
    """
    onnx_dir = "./onnx_export"
    if not os.path.exists(os.path.join(onnx_dir, "vision_encoder_quant.onnx")):
        print("Error: Quantized model not found. Run quantize_onnx.py first!")
        return

    encoder = QuantizedVisionEncoder(onnx_dir)

    # Kaggle working dir first, then the current directory (where wget saves).
    candidates = (
        "/kaggle/working/ECG_Atrial_Fibrillation.jpg",
        "ECG_Atrial_Fibrillation.jpg",
    )
    test_img = next((path for path in candidates if os.path.exists(path)), None)
    if test_img is None:
        test_img = Image.new('RGB', (500, 300), color='white')
        label = "dummy white image"
        print("Created dummy test image.")
    else:
        label = test_img

    print(f"\nRunning inference on {label}...")
    att, heatmap, latency = encoder.get_attention_heatmap(test_img)

    print(f"Inference time: {latency*1000:.1f} ms")
    print(f"Attention map: {att.shape}")

    heatmap.save("quantized_heatmap.png")
    print("Saved quantized_heatmap.png")


if __name__ == "__main__":
    demo()


Loading quantized model: vision_encoder_quant.onnx...
Model loaded in 0.96s

Running inference on /kaggle/working/ECG_Atrial_Fibrillation.jpg...


  cmap = cm.get_cmap(colormap)


Inference time: 38496.1 ms
Attention map: (16, 16)
Saved quantized_heatmap.png
