In [None]:
!pip install psycopg2 datasets
!pip install sentencepiece

In [None]:
import os
import base64
import io
import numpy as np
import psycopg2
from psycopg2.extras import execute_batch
from datasets import load_dataset
from transformers import AutoProcessor, AutoModel
from PIL import Image
import torch
from tqdm.auto import tqdm
import tensorflow as tf
from typing import Dict, List, Tuple
import time

In [None]:
# --- Run configuration --------------------------------------------------------
DATASET_NAME = "itsanmolgupta/mimic-cxr-dataset"  # Hugging Face dataset id
MODEL_NAME = "google/medsiglip-448"                # Hugging Face model id
BATCH_SIZE = 16    # samples per forward pass; tune to available GPU memory
IMAGE_SIZE = 448   # side length (pixels) of the square model input

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

In [None]:
from huggingface_hub import login

# Authenticate with the Hugging Face Hub (the model's terms of use must be
# accepted there). If HF_TOKEN is set in the environment it is passed
# directly, so the notebook can run unattended (Restart & Run All); otherwise
# login() falls back to its interactive prompt exactly as before.
# `or None` treats an empty HF_TOKEN as "not set".
login(token=os.environ.get("HF_TOKEN") or None)

In [None]:
def resize_image_tf_style(image: Image.Image, size: int = 448) -> Image.Image:
    """
    Resize image using TensorFlow's bilinear method to match MedSigLIP's
    preprocessing pipeline from Big Vision library.

    Both multi-channel (H, W, C) and single-channel (e.g. PIL mode "L")
    images are accepted. tf.image.resize only takes rank-3/rank-4 tensors, so
    a 2-D grayscale array gets a temporary channel axis that is removed again
    before converting back to PIL.

    Args:
        image: PIL Image
        size: Target side length in pixels (default 448); output is size x size.

    Returns:
        Resized PIL Image with the same number of channels as the input.
    """
    # Convert PIL to numpy array
    img_array = np.array(image)

    # Grayscale images become 2-D (H, W) arrays, which tf.image.resize rejects;
    # add a trailing channel axis for the resize and drop it afterwards
    # (Image.fromarray also cannot build an image from an (H, W, 1) array).
    is_single_channel = img_array.ndim == 2
    if is_single_channel:
        img_array = img_array[..., np.newaxis]

    # Use TensorFlow's resize (bilinear, no antialiasing)
    resized = tf.image.resize(
        images=img_array,
        size=[size, size],
        method='bilinear',
        antialias=False
    ).numpy()

    if is_single_channel:
        resized = resized[..., 0]

    # Truncating float -> uint8 cast kept exactly as in the original pipeline
    # (NOTE(review): presumably mirrors the reference MedSigLIP preprocessing;
    # confirm before switching to rounding).
    return Image.fromarray(resized.astype(np.uint8))


def image_to_base64(image: Image.Image) -> str:
    """
    Encode an image as a base64 text string.

    The image is serialised losslessly as PNG into an in-memory buffer, and
    the resulting bytes are base64-encoded and decoded to a UTF-8 ``str``.

    Args:
        image: PIL Image (anything with a compatible ``save(fp, format=...)``)

    Returns:
        Base64 encoded string of the PNG bytes
    """
    with io.BytesIO() as png_buffer:
        # PNG is lossless, so the stored image keeps full quality.
        image.save(png_buffer, format="PNG")
        encoded_bytes = base64.b64encode(png_buffer.getvalue())
    return encoded_bytes.decode('utf-8')


def process_batch(
    batch: Dict,
    processor,
    model,
    device: str
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """
    Process a batch of data to generate embeddings and base64 images.

    Images are converted to RGB, resized with ``resize_image_tf_style`` to
    IMAGE_SIZE x IMAGE_SIZE and encoded with ``model.get_image_features``.
    Findings and impression are joined into one string per sample and encoded
    with ``model.get_text_features``. Images and texts are fed to the model
    separately, never as a joint input.

    Args:
        batch: Dictionary with equal-length lists under 'image' (PIL images),
            'findings' and 'impression' (text; may be None or empty).
        processor: MedSigLIP processor (its image path and ``.tokenizer`` are used)
        model: MedSigLIP model exposing get_image_features / get_text_features
        device: Computing device (cuda/cpu)

    Returns:
        Tuple of (image_embeddings, text_embeddings, base64_images): two NumPy
        arrays with one row per sample, and PNG/base64 encodings of the
        ORIGINAL (un-resized) images.

    Raises:
        ValueError: if the three batch columns have different lengths.
    """
    images = batch['image']
    findings_list = batch['findings']
    impressions = batch['impression']

    # zip() would silently truncate mismatched columns and misalign records.
    if not (len(images) == len(findings_list) == len(impressions)):
        raise ValueError(
            "batch columns 'image', 'findings' and 'impression' must have equal length"
        )

    # Resize images using TensorFlow-style resizing, to the configured size.
    resized_images = [
        resize_image_tf_style(img.convert("RGB"), size=IMAGE_SIZE) for img in images
    ]

    # Convert images to base64 (original images, not resized)
    base64_images = [image_to_base64(img) for img in images]

    # Combine findings and impression for text embeddings. A missing/empty
    # value in either field is rendered as 'None' so both fields are handled
    # the same way.
    texts = [
        f"Findings: {finding if finding else 'None'}. "
        f"Impression: {impression if impression else 'None'}"
        for finding, impression in zip(findings_list, impressions)
    ]

    # Images-only inputs for the vision side of the model.
    image_inputs = processor(
        images=resized_images,
        return_tensors="pt"
    ).to(device)

    # Text-only inputs. With max_length=64 and truncation, long findings can
    # push the impression (which comes second) out of the encoded text.
    text_inputs = processor.tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=64,
        return_tensors="pt"
    ).to(device)

    # Inference only: one no-grad scope for both forward passes.
    with torch.no_grad():
        image_embeds = model.get_image_features(**image_inputs)
        text_embeds = model.get_text_features(**text_inputs)

    # Move embeddings to CPU NumPy arrays for downstream serialization.
    return image_embeds.cpu().numpy(), text_embeds.cpu().numpy(), base64_images

In [None]:
def load_data_and_model():
    """
    Fetch the MIMIC-CXR training split and the MedSigLIP processor and model.

    The model is moved to DEVICE and put into evaluation mode before it is
    returned. Progress messages are printed along the way.

    Returns:
        Tuple of (dataset, processor, model)
    """
    print("Loading dataset...")
    train_split = load_dataset(DATASET_NAME, split="train")
    print(f"Dataset loaded: {len(train_split)} samples")

    print(f"\nLoading MedSigLIP model from {MODEL_NAME}...")
    print("Note: You need to accept the model's terms of use on Hugging Face")

    medsiglip_processor = AutoProcessor.from_pretrained(MODEL_NAME)
    medsiglip_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
    medsiglip_model.eval()  # evaluation mode: this notebook only runs inference

    print("Model loaded successfully!")
    return train_split, medsiglip_processor, medsiglip_model

# Load the dataset, processor and model used by the processing cell below.
# This call runs whenever the cell executes.
dataset, processor, model = load_data_and_model()

In [None]:
def process_dataset(
    dataset,
    processor,
    model,
    batch_size: int = BATCH_SIZE,
    device: str = DEVICE,
):
    """
    Process entire dataset and generate embeddings.

    Batches are processed best-effort: a batch that raises is logged with its
    traceback and skipped, so the result can be shorter than the dataset. The
    start indices of skipped batches are printed at the end.

    Args:
        dataset: HuggingFace dataset with 'image', 'findings' and 'impression' columns
        processor: MedSigLIP processor
        model: MedSigLIP model
        batch_size: Batch size for processing (must be positive)
        device: Computing device passed to process_batch (defaults to DEVICE)

    Returns:
        List of processed records (dicts with keys 'findings', 'impression',
        'image_base64', 'image_embedding', 'text_embedding')

    Raises:
        ValueError: if batch_size is not positive.
    """
    import traceback

    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    all_records = []
    failed_starts = []
    total = len(dataset)
    num_batches = (total + batch_size - 1) // batch_size

    print(f"\nProcessing {total} samples in {num_batches} batches...")

    batch_starts = range(0, total, batch_size)
    for batch_idx, start in enumerate(tqdm(batch_starts, desc="Processing batches")):
        end = min(start + batch_size, total)

        try:
            # Slicing fetches each row once as a dict of column lists. The old
            # per-column dataset[j] indexing materialized every full row,
            # including its decoded image, three times.
            rows = dataset[start:end]
            batch = {
                'image': rows['image'],
                'findings': rows['findings'],
                'impression': rows['impression'],
            }

            # Process batch
            image_embeds, text_embeds, base64_images = process_batch(
                batch, processor, model, device
            )

            # Prepare records for database insertion
            for finding, impression, image_b64, image_emb, text_emb in zip(
                batch['findings'], batch['impression'],
                base64_images, image_embeds, text_embeds,
            ):
                all_records.append({
                    'findings': finding,
                    'impression': impression,
                    'image_base64': image_b64,
                    'image_embedding': image_emb.tolist(),
                    'text_embedding': text_emb.tolist(),
                })

            # Release cached GPU memory every 10 batches.
            if batch_idx % 10 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

        except Exception as e:
            # Deliberate best-effort: log, remember the batch, keep going.
            print(f"\nError processing batch starting at index {start}: {e}")
            traceback.print_exc()
            failed_starts.append(start)
            continue

    print(f"\nProcessing complete! Total records: {len(all_records)}")
    if failed_starts:
        print(
            f"Skipped {len(failed_starts)} failed batch(es) starting at indices: "
            f"{failed_starts}"
        )
    return all_records

# Run embedding generation over every sample in `dataset`.
# This call runs whenever the cell executes.
all_records = process_dataset(dataset, processor, model)