In [None]:
# @title 1. Imports

from datasets import load_dataset
from collections import defaultdict
import numpy as np
import json
from PIL import Image
import base64
import os
import google.generativeai as genai
from PIL import Image
import time
from tqdm import tqdm
import random
import datetime
import hashlib

In [None]:
# @title 2. Set up Gemini API
print("Setting up API keys...")
# Prefer keys from the environment (comma-separated GEMINI_API_KEYS) so real
# credentials never get saved inside the notebook; the placeholder list below
# is only a fallback used when that variable is unset or empty.
_env_api_keys = [k.strip() for k in os.environ.get("GEMINI_API_KEYS", "").split(",") if k.strip()]
api_keys = _env_api_keys or [
    "xyz1",
    "xyz2",
    "xyz3",
    "xyz4",
    # Add your additional API keys here (prefer the GEMINI_API_KEYS env var)
]
# Drop duplicates while preserving order: the per-key dicts below are keyed by
# the key string, so a repeated key would otherwise be "rotated" onto itself.
api_keys = list(dict.fromkeys(api_keys))
current_key_index = 0
print(f"Loaded {len(api_keys)} API key(s)")
genai.configure(api_key=api_keys[current_key_index])
print("Initial API key configured successfully")

# Rate limiting variables (tracking per key)
API_CALLS = {key: [] for key in api_keys}  # time.time() stamps of recent calls, per key
KEY_ERRORS = {key: 0 for key in api_keys}  # Cumulative error count, per key
RATE_LIMIT = 60  # Requests per minute allowed per key (free-tier assumption)
MIN_DELAY = 1.2  # Seconds slept before every request

# Function to rotate to the next available API key
def rotate_api_key():
    """Advance to the next API key (wrapping around) and point Gemini at it.

    Updates the module-level ``current_key_index`` and calls
    ``genai.configure`` with the newly selected key.

    Returns:
        The API key string that is now active.
    """
    global current_key_index

    outgoing = api_keys[current_key_index]
    current_key_index = (current_key_index + 1) % len(api_keys)
    incoming = api_keys[current_key_index]

    genai.configure(api_key=incoming)

    print(f"\nRotated from API key ending in '...{outgoing[-4:]}' to '...{incoming[-4:]}'")
    return incoming

# Function to wait if needed to respect rate limits
def respect_rate_limit():
    """Throttle the next Gemini request and record it against the key that will serve it.

    Steps:
      1. Drop timestamps older than 60 s from the active key's ``API_CALLS`` list.
      2. If the active key has RATE_LIMIT - 5 or more calls in that window,
         switch to the first other key with fewer than RATE_LIMIT - 10 recent
         calls (re-running ``genai.configure``). If none qualifies, sleep until
         the active key's oldest tracked call leaves the window, plus 1-3 s jitter.
      3. Sleep MIN_DELAY, then append the current time to the active key's list.

    Mutates the module-level ``current_key_index`` and ``API_CALLS``.
    """
    global current_key_index

    now = time.time()
    active_key = api_keys[current_key_index]
    API_CALLS[active_key] = [t for t in API_CALLS[active_key] if now - t < 60]

    if len(API_CALLS[active_key]) >= RATE_LIMIT - 5:  # Leave a small buffer
        for idx, key in enumerate(api_keys):
            if idx == current_key_index:
                continue
            # Prune while checking so idle keys don't accumulate stale timestamps.
            API_CALLS[key] = [t for t in API_CALLS[key] if now - t < 60]
            if len(API_CALLS[key]) < RATE_LIMIT - 10:
                current_key_index = idx
                active_key = key
                genai.configure(api_key=key)
                print(f"\nSwitched to API key ending in '...{key[-4:]}' with available capacity")
                break
        else:
            # No other key has headroom: wait for the active key's oldest call to
            # age out. The emptiness guard only matters if RATE_LIMIT <= 5.
            if API_CALLS[active_key]:
                wait_time = 60 - (now - API_CALLS[active_key][0]) + random.uniform(1, 3)

                cooldown_end = datetime.datetime.now() + datetime.timedelta(seconds=wait_time)
                print(f"\nAll keys near rate limit. Cooling down for {wait_time:.1f} seconds until {cooldown_end.strftime('%H:%M:%S')}")

                time.sleep(wait_time)

                # After waiting, clear expired entries again
                now = time.time()
                API_CALLS[active_key] = [t for t in API_CALLS[active_key] if now - t < 60]

    # Add a minimum delay between requests regardless (also after a key switch)
    time.sleep(MIN_DELAY)

    # Record this call against whichever key is now active, including a freshly
    # switched-to key, so its budget reflects the request about to be made.
    API_CALLS[active_key].append(time.time())


In [None]:
# @title 3. Load data

split = "train"
json_file = f"md5_{split}/vqa_rad_gemini_batch1_{split}.json"
temp_json_file = f"md5_{split}/vqa_rad_gemini_batch1_{split}_temp.json"

# The output JSONs live in a per-split folder that nothing else creates.
# Without this, the first checkpoint write in step 4 raises FileNotFoundError.
os.makedirs(os.path.dirname(json_file), exist_ok=True)

# Load the dataset
ds = load_dataset("flaviagiammarino/vqa-rad")

def get_image_hash(image):
    """Return the hex MD5 digest of an image's raw pixel bytes.

    Only ``image.tobytes()`` is hashed; mode and dimensions are not part of the
    digest. NOTE(review): two images whose raw bytes match but whose sizes
    differ would therefore share a hash.
    """
    pixel_bytes = image.tobytes()
    digest = hashlib.md5(pixel_bytes)
    return digest.hexdigest()

# Create a structured dataset: image hash -> {'image': {...} or None, 'qa_pairs': [...]}
structured_dataset = defaultdict(lambda: {'image': None, 'qa_pairs': []})

# Process data set
print(f"Processing {split} set...")
# Iterate the split once. Indexing ds[split][idx] materialises the whole row
# (image included) on every access; the row is now fetched once, not three times.
for row in ds[split]:
    image = row['image']
    img_hash = get_image_hash(image)
    entry = structured_dataset[img_hash]

    # Several questions can share one image; keep the first copy only.
    if entry['image'] is None:
        entry['image'] = {
            'data': image,
            'size': image.size,
            'split': split
        }

    entry['qa_pairs'].append({
        'question': row['question'],
        'answer': row['answer']
    })

print(f"\nTotal unique images: {len(structured_dataset)}")
total_qa_pairs = sum(len(data['qa_pairs']) for data in structured_dataset.values())
print(f"Total QA pairs: {total_qa_pairs}")

# Create md5_images directory if it doesn't exist
md5_images_dir = "md5_images"
os.makedirs(md5_images_dir, exist_ok=True)
print(f"Ensured {md5_images_dir} directory exists for storing MD5-hashed images")



In [None]:
# @title 4. Generate descriptive answer using Gemini with exponential backoff for retries

def generate_descriptive_answer(image_path, question, reference_answer, max_retries=5):
    """Ask Gemini to justify a reference answer for one image/question pair.

    Every attempt first calls ``respect_rate_limit()``. Errors whose message
    contains "rate limit", "quota" or "429" rotate to the next API key when
    more than one is configured, otherwise back off exponentially. Other errors
    back off linearly and rotate keys once the active key has 3+ errors
    recorded in ``KEY_ERRORS``. Rotations are capped at
    ``max_retries * len(api_keys)``, so a persistent failure (every key out of
    quota, or a blocked response whose ``.text`` raises) terminates.

    Args:
        image_path: Path of the image file sent alongside the prompt.
        question: Question text from the dataset.
        reference_answer: Ground-truth answer the model is asked to explain.
        max_retries: Failed attempts allowed on a key before giving up.

    Returns:
        The model's response text, or ``"Error after {max_retries} retries: ..."``
        once retries and rotations are exhausted. ``Exception`` subclasses are
        never propagated.
    """
    retry_count = 0
    base_wait = 2  # Base wait time in seconds for exponential backoff

    # Each rotation resets retry_count, so rotations must be bounded separately
    # or a failure that no key can fix would loop forever.
    rotations_left = max_retries * len(api_keys)

    while True:
        try:
            # Respect the rate limit before making a call
            respect_rate_limit()

            # Initialize Gemini 2.0 Flash model
            model = genai.GenerativeModel('gemini-2.0-flash')

            # Prepare the prompt
            prompt = f"""You are a medical imaging expert. Please analyze this medical image and answer the following question.

            Question: {question}
            Reference Answer: {reference_answer}

            Your response should have the following pattern: <Reference Answer>. Explanation: <a clear explanation supporting the reference answer.>
            """

            # Generate response with max token limit
            generation_config = genai.GenerationConfig(
                max_output_tokens=100,  # Limit to 100 tokens (approximately 75-100 words)
                temperature=0.2,  # Lower temperature for more focused responses
            )

            # Open the image in a context manager so its file handle is closed
            # after the request, instead of leaking one per attempt.
            with Image.open(image_path) as image:
                response = model.generate_content(
                    [prompt, image],
                    generation_config=generation_config
                )

            return response.text

        except Exception as e:
            retry_count += 1
            error_str = str(e)
            current_key = api_keys[current_key_index]

            # Track error for the key that served this attempt
            KEY_ERRORS[current_key] += 1

            lowered = error_str.lower()
            is_rate_limit = "rate limit" in lowered or "quota" in lowered or "429" in error_str
            can_rotate = len(api_keys) > 1 and rotations_left > 0

            if is_rate_limit:
                print(f"\nRate limit or quota error with key ending in '...{current_key[-4:]}'")

                if can_rotate:
                    rotate_api_key()
                    rotations_left -= 1
                    retry_count = 0  # Give the new key a fresh set of retries
                    continue

                if retry_count > max_retries:
                    return f"Error after {max_retries} retries: {error_str}"

                # Single key (or rotations exhausted): exponential backoff with jitter
                wait_time = (base_wait ** retry_count) + random.uniform(1, 5)
                cooldown_end = datetime.datetime.now() + datetime.timedelta(seconds=wait_time)
                print(f"\nNo additional keys available. Retry {retry_count}/{max_retries} after {wait_time:.1f}s ({cooldown_end.strftime('%H:%M:%S')})")
                time.sleep(wait_time)
            else:
                # A key with repeated failures is swapped out after a short wait
                should_rotate = KEY_ERRORS[current_key] >= 3 and can_rotate

                if retry_count > max_retries and not should_rotate:
                    # Out of retries: return now rather than sleeping for nothing
                    return f"Error after {max_retries} retries: {error_str}"

                # For non-rate-limit errors, use a shorter linear backoff
                wait_time = retry_count * 2 + random.uniform(0, 1)
                print(f"\nError: {error_str}. Retrying in {wait_time:.1f}s ({retry_count}/{max_retries})")
                time.sleep(wait_time)

                if should_rotate:
                    rotate_api_key()
                    rotations_left -= 1
                    retry_count = 0  # Fresh retries for the new key (KEY_ERRORS is kept)
                    continue

# Initialize empty enhanced dataset
enhanced_dataset = {}

# Get all image hashes
img_hash_list = list(structured_dataset.keys())

# Process every unique image (set BATCH_SIZE lower for a partial run)
BATCH_SIZE = len(img_hash_list)
print(f"\nProcessing first {BATCH_SIZE} images...")


def _dump_json_atomically(obj, path, **dump_kwargs):
    """Serialize ``obj`` as JSON to ``path`` without leaving a half-written file.

    Writes to ``<path>.partial`` first, then ``os.replace``s it over ``path``,
    so an interrupt mid-write keeps the previous checkpoint intact. Extra
    keyword arguments are forwarded to ``json.dump``.
    """
    partial_path = f"{path}.partial"
    with open(partial_path, 'w') as f:
        json.dump(obj, f, **dump_kwargs)
    os.replace(partial_path, path)


# Process images
processed_count = 0
start_time = time.time()

for img_idx, img_hash in enumerate(tqdm(img_hash_list[:BATCH_SIZE], desc="Processing Images")):
    img_data = structured_dataset[img_hash]
    split_label = img_data['image']['split']  # train/test label copied onto every QA pair

    # Save the image to disk (use MD5 hash in filename)
    image_filename = f"{md5_images_dir}/image_md5_{img_hash}.png"
    img_data['image']['data'].save(image_filename)

    # Process this image's Q&A pairs
    enhanced_image_data = {
        'image_info': {
            'size': img_data['image']['size'],
            'split': split_label
        },
        'image_path': image_filename,  # Store the path with MD5 hash
        'qa_pairs': []
    }

    qa_count = len(img_data['qa_pairs'])
    print(f"\nProcessing image {img_idx + 1}/{BATCH_SIZE} (Total: {img_idx + 1}/{len(structured_dataset)})")
    print(f"QA pairs in this image: {qa_count}")

    # Calculate and show ETA for this image
    eta_seconds = qa_count * (MIN_DELAY + 0.5)  # Rough estimate
    eta_time = datetime.datetime.now() + datetime.timedelta(seconds=eta_seconds)
    print(f"Estimated completion time for this image: {eta_time.strftime('%H:%M:%S')}")

    for i, qa in enumerate(tqdm(img_data['qa_pairs'], desc="Processing QA pairs")):
        try:
            # Display current progress with timestamps
            now = datetime.datetime.now().strftime("%H:%M:%S")
            print(f"[{now}] Processing QA pair {i+1}/{qa_count}: {qa['question'][:50]}...")

            # Generate descriptive answer with retry logic
            descriptive_answer = generate_descriptive_answer(image_filename, qa['question'], qa['answer'])
        except Exception as e:
            print(f"\nUnhandled error processing QA pair {i+1} for image {img_idx + 1}: {str(e)}")
            # Keep the pair, with the error recorded in place of the answer
            descriptive_answer = f"Error generating answer: {str(e)}"

        # Store both original and generated answers
        enhanced_image_data['qa_pairs'].append({
            'question': qa['question'],
            'original_answer': qa['answer'],
            'descriptive_answer': descriptive_answer,
            'split': split_label  # Include train/test split label
        })

        # Checkpoint the in-progress image every 5 QA pairs to avoid data loss
        if (i + 1) % 5 == 0:
            _dump_json_atomically({**enhanced_dataset, img_hash: enhanced_image_data}, temp_json_file)

    # Record this image before any cooldown, so an interrupt during the sleep loses nothing
    enhanced_dataset[img_hash] = enhanced_image_data
    processed_count += 1

    # Save intermediate results every 5 images or at the end
    if processed_count % 5 == 0 or processed_count == BATCH_SIZE:
        _dump_json_atomically(enhanced_dataset, json_file, indent=2)

        # Calculate and display statistics about progress
        elapsed = time.time() - start_time
        images_left = BATCH_SIZE - processed_count
        avg_time_per_image = elapsed / processed_count
        eta_seconds = images_left * avg_time_per_image
        eta_time = datetime.datetime.now() + datetime.timedelta(seconds=eta_seconds)

        print(f"\nProgress: {processed_count}/{BATCH_SIZE} images processed ({processed_count/BATCH_SIZE*100:.1f}%)")
        print(f"Time elapsed: {elapsed/60:.1f} minutes")
        print(f"Estimated time remaining: {eta_seconds/60:.1f} minutes (completion around {eta_time.strftime('%H:%M:%S')})")
        print(f"Results saved after processing {processed_count}/{BATCH_SIZE} images")

    # Cooldown between images to be extra safe with rate limits (not needed after the last one)
    if processed_count < BATCH_SIZE:
        cooldown = random.uniform(5, 10)
        print(f"\nCooldown period of {cooldown:.1f}s after completing image {img_idx + 1}")
        time.sleep(cooldown)


In [None]:
# @title 5. Save the final enhanced dataset

# Write via a temp file + os.replace so an interrupt can't truncate the result.
_final_partial_path = f"{json_file}.partial"
with open(_final_partial_path, 'w') as f:
    json.dump(enhanced_dataset, f, indent=2)
os.replace(_final_partial_path, json_file)

# Calculate and show final statistics
total_qa_processed = sum(len(data['qa_pairs']) for data in enhanced_dataset.values())
total_time = time.time() - start_time

print(f"\nComplete enhanced dataset has been saved to '{json_file}'")
print(f"Processed {processed_count} images with {total_qa_processed} QA pairs")
print(f"Total processing time: {total_time/60:.1f} minutes ({total_time/3600:.2f} hours)")
# Guard the division: an empty run would otherwise raise ZeroDivisionError.
if total_qa_processed:
    print(f"Average time per QA pair: {total_time/total_qa_processed:.2f} seconds")

# Print sample of the enhanced dataset (just the first image); an empty dataset
# would otherwise make next(iter(...)) raise StopIteration.
if enhanced_dataset:
    print("\nSample of Enhanced Dataset (first image):")
    first_key = next(iter(enhanced_dataset))
    print(json.dumps({first_key: enhanced_dataset[first_key]}, indent=2))
else:
    print("\nEnhanced dataset is empty; no sample to show.")