# ODIR-5K RET-CLIP Unified Pipeline

This notebook implements a complete end-to-end pipeline for training **RET-CLIP** on the **ODIR-5K dataset** with English clinical text.

---

## Research Contribution

**"English BERT Embeddings for Binocular Retinal Image-Text Alignment"**

- ✅ First validation of RET-CLIP architecture on English clinical text
- ✅ Cross-lingual transfer validation (original: Chinese → our work: English)
- ✅ Real binocular fundus images from ODIR-5K (not duplicated monocular)
- ✅ Comparison of medical domain-specific vs general BERT models

---

## Dataset: ODIR-5K (Ocular Disease Intelligent Recognition)

- **5,000 patients** with genuine paired left/right eye fundus images
- **10,000 images** total (2 per patient)
- **Metadata**: Patient Age, Sex, Diagnostic Keywords (English)
- **8 Disease Categories**: Normal, Diabetes, Glaucoma, Cataract, AMD, Hypertension, Myopia, Other

---

## RET-CLIP Architecture

**Binocular Vision-Language Foundation Model**

```
Left Eye Image  ──→  Vision Encoder  ──→  Left Projection  ─┐
                                                            ├─→ Tripartite Loss
Right Eye Image ──→  Vision Encoder  ──→  Right Projection ─┤
                                                            │
Clinical Text   ──→  Text Encoder    ──→  Text Embedding   ─┘
```

**Three-Level Contrastive Learning**:
1. Left eye ↔ Left-specific clinical description
2. Right eye ↔ Right-specific clinical description
3. Patient-level ↔ Holistic diagnostic impression (both eyes)
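
The three levels above can be sketched as three contrastive terms summed into one loss. This is a minimal pure-Python illustration, not the RET-CLIP training code: it assumes a symmetric InfoNCE loss per level, an unweighted sum, and a temperature of 0.07, and the helper names `info_nce` and `tripartite_loss` are hypothetical.

```python
import math

def info_nce(image_feats, text_feats, temperature=0.07):
    """Symmetric InfoNCE for L2-normalised features; image_feats[i] pairs with text_feats[i]."""
    n = len(image_feats)
    # Cosine-similarity logits scaled by the (assumed) temperature.
    logits = [
        [sum(x * y for x, y in zip(image_feats[i], text_feats[j])) / temperature
         for j in range(n)]
        for i in range(n)
    ]

    def cross_entropy(rows):
        # Mean of -log softmax(row)[i], where index i is the matching pair.
        total = 0.0
        for i, row in enumerate(rows):
            peak = max(row)
            log_z = peak + math.log(sum(math.exp(v - peak) for v in row))
            total += log_z - row[i]
        return total / len(rows)

    transposed = [list(col) for col in zip(*logits)]  # text -> image direction
    return 0.5 * (cross_entropy(logits) + cross_entropy(transposed))

def tripartite_loss(left_img, right_img, patient_img, left_txt, right_txt, patient_txt):
    """Sum of the left-eye, right-eye, and patient-level contrastive terms."""
    return (
        info_nce(left_img, left_txt)
        + info_nce(right_img, right_txt)
        + info_nce(patient_img, patient_txt)
    )

# Toy batch of two patients with 2-D unit vectors.
aligned = [[1.0, 0.0], [0.0, 1.0]]
swapped = [[0.0, 1.0], [1.0, 0.0]]
loss_matched = tripartite_loss(aligned, aligned, aligned, aligned, aligned, aligned)
loss_mismatched = tripartite_loss(aligned, aligned, aligned, swapped, swapped, swapped)
```

With matched pairs the loss is close to zero, and it grows when each image is paired with the wrong text, which is the behaviour contrastive training relies on.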

---

## Pipeline Overview

```
1. Setup & Configuration        → Install packages, authenticate APIs
2. Load ODIR-5K Dataset         → Excel/CSV metadata + paired fundus images
3. Generate Clinical Prompts    → DSPy + OpenRouter (3 prompts/patient)
4. Preprocess for RET-CLIP      → TSV + JSONL with eye_side annotations
5. Build LMDB Database          → Efficient PyTorch DataLoader format
6. Train RET-CLIP               → 10 epochs contrastive learning
7. Zero-Shot Evaluation         → Vision-language alignment test
8. Linear Probing Evaluation    → Feature quality assessment
9. Final Report                 → Metrics, comparison, artifacts
```

---

## ⏱️ Estimated Runtime

| Mode | Patients | Prompts Time | Training Time | Total |
|------|----------|--------------|---------------|-------|
| TEST | 100 | ~30 min | ~30 min (2 epochs) | **~2-3 hours** |
| FULL | 5,000 | ~4-5 hours | ~12-15 hours (10 epochs) | **~18-24 hours** |

---

## Prerequisites

1. **Google Colab** with A100 GPU (or T4 for testing)
2. **API Keys**:
   - HuggingFace Token: https://huggingface.co/settings/tokens
   - OpenRouter API Key: https://openrouter.ai/keys
3. **ODIR-5K Dataset**: Will be downloaded automatically

---

**Let's begin!**

# SECTION 1: Setup & Configuration

## Cell 1.1: Check GPU

In [None]:
# Check GPU availability
!nvidia-smi

## Cell 1.2: Mount Google Drive

In [None]:
# Mount Google Drive for persistent storage
from google.colab import drive
drive.mount('/content/drive')

print("✅ Google Drive mounted successfully")

## Cell 1.3: Install Dependencies

In [None]:
# Install all required packages
print("📦 Installing dependencies...\n")

# DSPy packages (for prompt generation)
!pip install -q dspy-ai datasets huggingface-hub pandas pillow tqdm ipywidgets matplotlib

# RET-CLIP packages (for training)
!pip install -q ftfy regex
!pip install -q git+https://github.com/openai/CLIP.git
!pip install -q transformers
!pip install -q lmdb
!pip install -q scikit-learn seaborn

# OpenCV for image processing
!pip install -q opencv-python-headless

# Kaggle dataset download (automatic!)
!pip install -q kagglehub openpyxl

print("\n✅ All dependencies installed")

## Cell 1.4: Clone RET-CLIP Repository (with English BERT fixes)

In [None]:
import os
import shutil
import sys

# Clone repository with fixed RET-CLIP
REPO_URL = "https://github.com/FahadAlothman-fsd/retclip-english.git"

if not os.path.exists('/content/retclip_repo'):
    print(f"Cloning repository from {REPO_URL}...")
    !git clone {REPO_URL} /content/retclip_repo
    print("✅ Repository cloned")
else:
    print("✅ Repository already exists")

# Copy retclip to /content/retclip
if not os.path.exists('/content/retclip'):
    print("\nCopying fixed RET-CLIP...")
    shutil.copytree('/content/retclip_repo/retclip', '/content/retclip')
    print("✅ Copied to /content/retclip")
else:
    print("✅ /content/retclip already exists")

# Add to Python path
sys.path.insert(0, '/content/retclip')
os.environ['PYTHONPATH'] = '/content/retclip'

print("\n✅ RET-CLIP repository ready")
print("\n" + "="*80)
print("VERIFY: Using fixed RET-CLIP with:")
print("="*80)
print("  1. ✓ English BERT configs (PubMedBERT, BERT-base, BioBERT)")
print("  2. ✓ URL-safe base64 encoding/decoding")
print("  3. ✓ 3-column TSV format support (patient_id, left_img, right_img)")
print("  4. ✓ DDP checkpoint loading with 'module.' prefix stripping")
print(f"\nLocation: /content/retclip")
print(f"Source: {REPO_URL}")
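
The checklist printed above mentions URL-safe base64 and a 3-column TSV layout (`patient_id`, `left_img`, `right_img`). The sketch below shows what one such row could look like using only the standard library. The exact format the repository expects is an assumption here, and the helper names are hypothetical.

```python
import base64

def encode_image_bytes(raw: bytes) -> str:
    """URL-safe base64: uses '-' and '_' where standard base64 uses '+' and '/'."""
    return base64.urlsafe_b64encode(raw).decode("ascii")

def decode_image_bytes(encoded: str) -> bytes:
    """Inverse of encode_image_bytes."""
    return base64.urlsafe_b64decode(encoded)

def make_tsv_row(patient_id, left_bytes: bytes, right_bytes: bytes) -> str:
    """One tab-separated row: patient_id, left image, right image (assumed layout)."""
    return "\t".join(
        [str(patient_id), encode_image_bytes(left_bytes), encode_image_bytes(right_bytes)]
    )

# These three bytes encode to "+//+" in standard base64, so they show the difference.
sample = b"\xfb\xff\xfe"
encoded = encode_image_bytes(sample)
row = make_tsv_row(0, sample, sample)
```

Neither of the two substituted characters is a tab, so the encoded payload cannot break the TSV columns. The encoder and decoder must use the same alphabet, otherwise images fail to decode at load time.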

## Cell 1.5: Configuration Parameters

In [None]:
# Configuration Parameters
import os

# TEST MODE: Set to True for quick testing with subset of data
TEST_MODE = True  # Set to False for full dataset training
NUM_TEST_PATIENTS = 100 if TEST_MODE else None

# TRAINING HYPERPARAMETERS
VISION_MODEL = "ViT-B-16"
IMAGE_SIZE = 224
BATCH_SIZE = 32 if TEST_MODE else 128
NUM_EPOCHS = 2 if TEST_MODE else 10
LEARNING_RATE = 5e-5
WARMUP_STEPS = 500

# TEXT ENCODER CONFIGURATION
# Default text encoder (used when comparison is disabled)
TEXT_MODEL = "microsoft-BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

# TEXT ENCODER COMPARISON
RUN_TEXT_ENCODER_COMPARISON = True  # Set to False to train only PubMedBERT (faster)

# PROMPT GENERATION MODEL
PRIMARY_MODEL = "openrouter/google/gemini-2.5-flash-lite"  # Fast, high-quality model for prompt generation

# PROMPT GENERATION SETTINGS
CHECKPOINT_INTERVAL = 10  # Save checkpoint every N patients
DELAY_BETWEEN_CALLS = 0.5  # Delay between API calls (seconds) to avoid rate limiting

print(f"Configuration:")
print(f"  Mode: {'TEST (100 patients, 2 epochs)' if TEST_MODE else 'FULL (5000 patients, 10 epochs)'}")
print(f"  Vision Model: {VISION_MODEL}")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Image Size: {IMAGE_SIZE}x{IMAGE_SIZE}")
print(f"  Prompt Model: {PRIMARY_MODEL}")
print(f"  Checkpoint Interval: {CHECKPOINT_INTERVAL} patients")
print(f"  API Delay: {DELAY_BETWEEN_CALLS}s")
print(f"\nText Encoder Comparison: {'✅ ENABLED' if RUN_TEXT_ENCODER_COMPARISON else '❌ DISABLED (PubMedBERT only)'}")

if RUN_TEXT_ENCODER_COMPARISON:
    print("   Will train and compare 3 text encoders:")
    print("   - PubMedBERT (medical domain)")
    print("   - BERT-base (general English)")
    print("   - BioBERT (biomedical domain)")

# Define text encoders for comparison
# Note: model_id is for RET-CLIP config files (uses dashes), hf_model_id is for HuggingFace tokenizer (uses slashes)
TEXT_ENCODERS = [
    {
        "name": "PubMedBERT",
        "model_id": "microsoft-BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        "hf_model_id": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        "description": "Medical domain-specific BERT trained on PubMed abstracts and PubMed Central full text"
    },
    {
        "name": "BERT-base",
        "model_id": "bert-base-uncased",
        "hf_model_id": "bert-base-uncased",
        "description": "General English BERT (baseline)"
    },
    {
        "name": "BioBERT",
        "model_id": "dmis-lab-biobert-base-cased-v1.1",
        "hf_model_id": "dmis-lab/biobert-base-cased-v1.1",
        "description": "Biomedical domain BERT trained on PubMed + PMC"
    }
]

# GOOGLE DRIVE PATHS (will be set after mounting)
DRIVE_BASE = "/content/drive/MyDrive/RET-CLIP-ODIR"
DRIVE_DATA = f"{DRIVE_BASE}/data"
DRIVE_PROMPTS = f"{DRIVE_BASE}/prompts"
DRIVE_LMDB = f"{DRIVE_BASE}/lmdb"
DRIVE_CHECKPOINTS = f"{DRIVE_BASE}/checkpoints"
DRIVE_RESULTS = f"{DRIVE_BASE}/results"

print(f"\nGoogle Drive paths configured:")
print(f"  Base: {DRIVE_BASE}")
print(f"  Data: {DRIVE_DATA}")
print(f"  Prompts: {DRIVE_PROMPTS}")
print(f"  LMDB: {DRIVE_LMDB}")
print(f"  Checkpoints: {DRIVE_CHECKPOINTS}")
print(f"  Results: {DRIVE_RESULTS}")
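
It can help to compare `WARMUP_STEPS` with the number of optimizer steps each mode produces. The arithmetic below is a rough sketch under stated assumptions: every patient is used for training, one optimizer step is taken per batch (no gradient accumulation, single GPU), and the last partial batch is kept. How the RET-CLIP trainer actually counts steps is not shown in this notebook, and `total_optimizer_steps` is a hypothetical helper.

```python
import math

def total_optimizer_steps(num_samples: int, batch_size: int, num_epochs: int) -> int:
    """Batches per epoch (partial last batch kept) times the number of epochs."""
    return math.ceil(num_samples / batch_size) * num_epochs

WARMUP_STEPS = 500  # value from the configuration cell

test_steps = total_optimizer_steps(num_samples=100, batch_size=32, num_epochs=2)
full_steps = total_optimizer_steps(num_samples=5000, batch_size=128, num_epochs=10)

print(f"TEST mode: {test_steps} steps, FULL mode: {full_steps} steps, warmup: {WARMUP_STEPS}")
```

Under these assumptions both modes finish before a 500-step warmup would complete, so the learning rate may never reach `LEARNING_RATE`. If the trainer counts steps this way, a warmup expressed as a fraction of total steps may be worth considering.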

## Cell 1.6: API Authentication

In [None]:
import os

# Try to load from Colab secrets first
try:
    from google.colab import userdata
    HF_TOKEN = userdata.get('HF_TOKEN')
    OPENROUTER_API_KEY = userdata.get('OPENROUTER_API_KEY')
    KAGGLE_USERNAME = userdata.get('KAGGLE_USERNAME')
    KAGGLE_API_TOKEN = userdata.get('KAGGLE_API_TOKEN')
    print("✅ API keys loaded from Colab secrets")
except Exception:
    # Manual entry if secrets not available
    HF_TOKEN = ""
    OPENROUTER_API_KEY = ""
    KAGGLE_USERNAME = ""
    KAGGLE_API_TOKEN = ""
    print("⚠️ Colab secrets not available - please set tokens manually")

# Authenticate with HuggingFace
if HF_TOKEN:
    from huggingface_hub import login
    login(token=HF_TOKEN)
    print("✅ Authenticated with HuggingFace")
else:
    print("❌ HF_TOKEN not set")

# Configure Kaggle credentials
if KAGGLE_API_TOKEN and KAGGLE_USERNAME:
    os.environ['KAGGLE_USERNAME'] = KAGGLE_USERNAME
    os.environ['KAGGLE_KEY'] = KAGGLE_API_TOKEN
    print(f"✅ Kaggle configured (username: {KAGGLE_USERNAME})")
elif KAGGLE_API_TOKEN:
    os.environ['KAGGLE_KEY'] = KAGGLE_API_TOKEN
    print("⚠️ Kaggle key set, but no username - may fail")
else:
    print("❌ KAGGLE_API_TOKEN not set")

# Configure DSPy LLM with OpenRouter
if OPENROUTER_API_KEY:
    import dspy
    
    primary_lm = dspy.LM(
        model=PRIMARY_MODEL,
        api_key=OPENROUTER_API_KEY,
        api_base="https://openrouter.ai/api/v1",
        extra_headers={"HTTP-Referer": "https://chiron.app", "X-Title": "Chiron"},
        num_retries=3,
    )
    
    dspy.configure(lm=primary_lm)
    print(f"✅ LLM configured: {PRIMARY_MODEL.split('/')[-1]}")
else:
    print("❌ OPENROUTER_API_KEY not set")

print("\n" + "="*80)
print("API AUTHENTICATION STATUS")
print("="*80)
print(f"  HuggingFace: {'✅ Ready' if HF_TOKEN else '❌ Not configured'}")
print(f"  Kaggle: {'✅ Ready' if (KAGGLE_API_TOKEN and KAGGLE_USERNAME) else '❌ Not configured'}")
print(f"  OpenRouter: {'✅ Ready' if OPENROUTER_API_KEY else '❌ Not configured'}")

# SECTION 2: Load ODIR-5K Dataset

## Cell 2.1: Download ODIR-5K Dataset from Kaggle

**Using kagglehub for automatic download!**

### What This Does:
- Downloads ODIR-5K dataset (~8 GB) including:
  - 10,000 fundus images (paired left/right)
  - Excel metadata file with diagnostic keywords
- Copies to Google Drive for persistence  
- **First run**: ~5-10 min download + ~5 min copy to Drive
- **Subsequent runs**: Instant (uses Drive cache)

In [None]:
import kagglehub
import shutil
from pathlib import Path

# Define paths
ODIR_DRIVE_DIR = f"{DRIVE_DATA}/ODIR-5K"

# Check if already downloaded to Drive
if os.path.exists(ODIR_DRIVE_DIR):
    # Find the Training Images directory (handle different nesting levels)
    possible_paths = [
        f"{ODIR_DRIVE_DIR}/ODIR-5K/ODIR-5K/Training Images",  # Extra nested
        f"{ODIR_DRIVE_DIR}/ODIR-5K/Training Images",          # Standard
        f"{ODIR_DRIVE_DIR}/Training Images",                   # Flat
    ]
    
    ODIR_IMAGES_DIR = None
    for path in possible_paths:
        if os.path.exists(path):
            ODIR_IMAGES_DIR = path
            break
    
    if ODIR_IMAGES_DIR:
        image_files = list(Path(ODIR_IMAGES_DIR).glob("*.jpg"))
        left_images = [f for f in image_files if '_left' in f.name]
        right_images = [f for f in image_files if '_right' in f.name]
        
        print("✅ ODIR-5K found in Google Drive!")
        print(f"   Images directory: {ODIR_IMAGES_DIR}")
        print(f"   Total training images: {len(image_files)}")
        print(f"   Left eye: {len(left_images)}")
        print(f"   Right eye: {len(right_images)}")
        
        if len(left_images) == len(right_images):
            print(f"\n✅ Paired images validated: {len(left_images)} patients")
            print("   Skipping download (using cached data from Drive)")
        else:
            print("\n⚠️ Warning: Unequal number of left and right images - will re-download")
            shutil.rmtree(ODIR_DRIVE_DIR, ignore_errors=True)
            ODIR_IMAGES_DIR = None
    else:
        print(f"⚠️ ODIR-5K directory exists but Training Images not found in expected locations")
        print(f"   Will re-download...")
        shutil.rmtree(ODIR_DRIVE_DIR, ignore_errors=True)

if not os.path.exists(ODIR_DRIVE_DIR) or not ODIR_IMAGES_DIR:
    print("📥 Downloading ODIR-5K from Kaggle...")
    print("   This will take ~5-10 minutes (~8 GB dataset)\n")
    
    # Download using kagglehub
    dataset_path = kagglehub.dataset_download("andrewmvd/ocular-disease-recognition-odir5k")
    
    print(f"✅ Downloaded to: {dataset_path}")
    
    # Explore the FULL downloaded structure
    print("\n" + "="*80)
    print("COMPLETE DOWNLOADED STRUCTURE:")
    print("="*80)
    for root, dirs, files in os.walk(dataset_path):
        level = root.replace(dataset_path, '').count(os.sep)
        indent = ' ' * 2 * level
        rel_path = os.path.relpath(root, dataset_path)
        print(f"{indent}{rel_path}/")
        sub_indent = ' ' * 2 * (level + 1)
        
        # Show all files if less than 10, otherwise show summary
        if len(files) <= 10:
            for file in files:
                print(f"{sub_indent}{file}")
        else:
            for file in files[:3]:
                print(f"{sub_indent}{file}")
            print(f"{sub_indent}... and {len(files) - 3} more files")
    
    # Copy EVERYTHING to Google Drive to avoid missing anything
    print("\n" + "="*80)
    print("📂 Copying ENTIRE dataset to Google Drive...")
    print("="*80)
    print(f"   Destination: {ODIR_DRIVE_DIR}")
    
    # Remove existing if present
    if os.path.exists(ODIR_DRIVE_DIR):
        shutil.rmtree(ODIR_DRIVE_DIR)
    
    # Copy everything
    shutil.copytree(dataset_path, ODIR_DRIVE_DIR)
    
    print("✅ Copy complete!")
    
    # Show what we got
    print("\n" + "="*80)
    print("COPIED TO DRIVE:")
    print("="*80)
    for root, dirs, files in os.walk(ODIR_DRIVE_DIR):
        level = root.replace(ODIR_DRIVE_DIR, '').count(os.sep)
        if level < 3:  # Only show top 3 levels
            indent = ' ' * 2 * level
            rel_path = os.path.relpath(root, ODIR_DRIVE_DIR)
            print(f"{indent}{rel_path}/ ({len(files)} files, {len(dirs)} dirs)")
    
    # Detect Training Images directory
    possible_paths = [
        f"{ODIR_DRIVE_DIR}/ODIR-5K/ODIR-5K/Training Images",
        f"{ODIR_DRIVE_DIR}/ODIR-5K/Training Images",
        f"{ODIR_DRIVE_DIR}/Training Images",
    ]
    
    ODIR_IMAGES_DIR = None
    for path in possible_paths:
        if os.path.exists(path):
            ODIR_IMAGES_DIR = path
            print(f"\n✅ Found Training Images at: {ODIR_IMAGES_DIR}")
            break
    
    if not ODIR_IMAGES_DIR:
        print(f"\n❌ Could not find Training Images in any expected location")
        raise FileNotFoundError("Training Images directory not found after download")

# Final validation - load a sample image
if ODIR_IMAGES_DIR and os.path.exists(ODIR_IMAGES_DIR):
    image_files = list(Path(ODIR_IMAGES_DIR).glob("*.jpg"))
    left_images = [f for f in image_files if '_left' in f.name]
    right_images = [f for f in image_files if '_right' in f.name]
    
    print(f"\n✅ Training images available:")
    print(f"   Total: {len(image_files)}")
    print(f"   Left eye: {len(left_images)}")
    print(f"   Right eye: {len(right_images)}")
    print(f"   Paired patients: {len(left_images)}")
    
    # Load sample image
    sample_images = list(Path(ODIR_IMAGES_DIR).glob("*_left.jpg"))
    if sample_images:
        from PIL import Image
        sample_path = sample_images[0]
        sample_img = Image.open(sample_path)
        print(f"\n✅ Sample image loaded: {sample_img.size}")
        print(f"   Format: {sample_img.format}, Mode: {sample_img.mode}")
        print(f"\n🎉 Dataset ready!")
    else:
        print(f"\n❌ No left images found in: {ODIR_IMAGES_DIR}")
else:
    print(f"\n❌ Training Images directory not found: {ODIR_IMAGES_DIR}")
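
The cell above treats the images as paired when the left and right counts match. Equal counts do not guarantee that every patient has both eyes, so a stricter check can group files by patient ID. The following is a minimal sketch assuming the `<ID>_left.jpg` / `<ID>_right.jpg` naming used in this notebook; `find_unpaired` is a hypothetical helper.

```python
from collections import defaultdict

def find_unpaired(filenames):
    """Return the sorted patient IDs that lack either a left or a right image."""
    sides_by_patient = defaultdict(set)
    for name in filenames:
        stem = name.rsplit(".", 1)[0]            # drop the extension
        patient_id, _, side = stem.rpartition("_")
        if side in ("left", "right"):
            sides_by_patient[patient_id].add(side)
    return sorted(
        pid for pid, sides in sides_by_patient.items() if sides != {"left", "right"}
    )

# Two left and two right images, yet only patient 0 is actually paired.
example_files = ["0_left.jpg", "0_right.jpg", "1_left.jpg", "2_right.jpg"]
unpaired = find_unpaired(example_files)
```

In the notebook this could be called with `[f.name for f in image_files]`. An empty result means every patient found on disk has both eyes.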

## Cell 2.2: Load ODIR-5K Metadata

Load the Excel metadata file (downloaded in Cell 2.1 along with images).

In [None]:
import pandas as pd

# Metadata Excel file path - check both possible locations
metadata_paths = [
    f"{DRIVE_DATA}/ODIR-5K/ODIR-5K/data.xlsx",  # Nested structure
    f"{DRIVE_DATA}/ODIR-5K/data.xlsx",  # Flat structure
    f"{DRIVE_DATA}/ODIR-5K/full_df.csv",  # Alternative CSV at root
]

metadata_path = None
for path in metadata_paths:
    if os.path.exists(path):
        metadata_path = path
        break

print(f"Loading ODIR-5K metadata...")

if not metadata_path:
    print(f"❌ Metadata not found in any expected location:")
    for path in metadata_paths:
        print(f"   {path}")
    print(f"\n   Run Cell 2.1 first to download the dataset from Kaggle!")
    raise FileNotFoundError(f"Metadata file not found")

print(f"✅ Found metadata at: {metadata_path}")

# Load metadata
if metadata_path.endswith('.csv'):
    odir_df = pd.read_csv(metadata_path)
else:
    odir_df = pd.read_excel(metadata_path)

print(f"✅ Loaded ODIR-5K metadata")
print(f"   Total patients: {len(odir_df)}")
print(f"   Columns: {list(odir_df.columns)}")

# Apply TEST_MODE sampling if enabled
if TEST_MODE and NUM_TEST_PATIENTS:
    odir_df = odir_df.head(NUM_TEST_PATIENTS)
    print(f"\n⚠️ TEST MODE: Using {NUM_TEST_PATIENTS} patients")

print(f"\nFinal dataset size: {len(odir_df)} patients")
print(f"\nSample row:")
display(odir_df.head(1))

## Cell 2.3: Validate Dataset Structure

In [None]:
# Validate dataset structure and find image directory
print("Validating ODIR-5K dataset structure...\n")

# Check multiple possible image directory locations
image_dir_candidates = [
    f"{DRIVE_DATA}/ODIR-5K/ODIR-5K/Training Images",  # Nested structure
    f"{DRIVE_DATA}/ODIR-5K/Training Images",  # Flat structure
    f"{DRIVE_DATA}/ODIR-5K/preprocessed_images",  # Preprocessed folder
]

ODIR_IMAGES_DIR = None
for dir_path in image_dir_candidates:
    if os.path.exists(dir_path):
        test_images = list(Path(dir_path).glob("*.jpg"))
        if test_images:
            ODIR_IMAGES_DIR = dir_path
            print(f"✅ Found images at: {ODIR_IMAGES_DIR}")
            break

if not ODIR_IMAGES_DIR:
    print("❌ Training images not found in any expected location:")
    for dir_path in image_dir_candidates:
        print(f"   {dir_path}")
    raise FileNotFoundError("Training images directory not found")

# Check images exist
image_files = list(Path(ODIR_IMAGES_DIR).glob("*.jpg"))
left_images = [f for f in image_files if '_left' in f.name]
right_images = [f for f in image_files if '_right' in f.name]

print(f"\n📊 Images statistics:")
print(f"   Total images: {len(image_files)}")
print(f"   Left eye: {len(left_images)}")
print(f"   Right eye: {len(right_images)}")

# Validate required columns in metadata
required_columns = ['ID', 'Patient Age', 'Patient Sex', 'Left-Diagnostic Keywords', 'Right-Diagnostic Keywords']
missing_columns = [col for col in required_columns if col not in odir_df.columns]

if missing_columns:
    print(f"\n❌ Missing required columns: {missing_columns}")
    print(f"   Available columns: {list(odir_df.columns)}")
    raise ValueError(f"Missing required columns: {missing_columns}")
else:
    print(f"\n✅ All required metadata columns present")

# Check for missing values
print(f"\nMissing values in metadata:")
for col in required_columns:
    missing = odir_df[col].isna().sum()
    print(f"  {col}: {missing} ({missing/len(odir_df)*100:.1f}%)")

print(f"\n✅ Dataset validation complete!")
print(f"   Using images from: {ODIR_IMAGES_DIR}")

## Cell 2.4: Dataset Statistics

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

print("="*80)
print("ODIR-5K DATASET STATISTICS")
print("="*80)

# Age distribution
print(f"\nAge Statistics:")
print(f"  Mean: {odir_df['Patient Age'].mean():.1f} years")
print(f"  Median: {odir_df['Patient Age'].median():.1f} years")
print(f"  Range: {odir_df['Patient Age'].min():.0f} - {odir_df['Patient Age'].max():.0f} years")

# Sex distribution
print(f"\nSex Distribution:")
sex_counts = odir_df['Patient Sex'].value_counts()
for sex, count in sex_counts.items():
    print(f"  {sex}: {count} ({count/len(odir_df)*100:.1f}%)")

# Keywords distribution (top 10 most common)
from collections import Counter

all_keywords = []
for keywords_str in pd.concat([odir_df['Left-Diagnostic Keywords'], odir_df['Right-Diagnostic Keywords']]).dropna():
    all_keywords.extend([k.strip() for k in str(keywords_str).split(',') if k.strip()])

keyword_counts = Counter(all_keywords)
top_keywords = keyword_counts.most_common(10)

print(f"\nTop 10 Disease Keywords:")
for keyword, count in top_keywords:
    print(f"  {keyword}: {count} ({count/(len(odir_df)*2)*100:.1f}%)")

# Visualizations
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Age distribution
axes[0].hist(odir_df['Patient Age'].dropna(), bins=20, color='skyblue', edgecolor='black')
axes[0].set_title('Age Distribution', fontsize=14, fontweight='bold')
axes[0].set_xlabel('Age (years)')
axes[0].set_ylabel('Count')

# Sex distribution
sex_counts.plot(kind='bar', ax=axes[1], color=['lightcoral', 'lightblue'])
axes[1].set_title('Sex Distribution', fontsize=14, fontweight='bold')
axes[1].set_xlabel('Sex')
axes[1].set_ylabel('Count')
axes[1].tick_params(axis='x', rotation=0)

# Top keywords
keywords, counts = zip(*top_keywords)
axes[2].barh(range(len(keywords)), counts, color='lightgreen')
axes[2].set_yticks(range(len(keywords)))
axes[2].set_yticklabels(keywords)
axes[2].set_title('Top 10 Disease Keywords', fontsize=14, fontweight='bold')
axes[2].set_xlabel('Frequency')
axes[2].invert_yaxis()

plt.tight_layout()
os.makedirs(DRIVE_RESULTS, exist_ok=True)  # The results folder may not exist on Drive yet
plt.savefig(f"{DRIVE_RESULTS}/odir_dataset_statistics.png", dpi=300, bbox_inches='tight')
plt.show()

print(f"\n✅ Statistics plot saved to {DRIVE_RESULTS}/odir_dataset_statistics.png")

# SECTION 3: Generate Clinical Prompts

Use DSPy + OpenRouter to generate 3 clinical prompts per patient:
1. **Left eye prompt**: Specific to left eye pathology
2. **Right eye prompt**: Specific to right eye pathology
3. **Patient-level prompt**: Holistic diagnostic impression (both eyes)
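
As a small illustration of the three prompts listed above, the inputs for one patient can be laid out as three requests. This sketch only builds the request data and makes no LLM call. `build_prompt_requests` is a hypothetical helper, and joining the two keyword strings with `"; "` mirrors the test cell later in this section.

```python
def build_prompt_requests(left_keywords: str, right_keywords: str):
    """Return the three (eye_side, keywords) requests generated for one patient."""
    return [
        {"eye_side": "left", "keywords": left_keywords},
        {"eye_side": "right", "keywords": right_keywords},
        # The patient-level prompt sees the findings of both eyes.
        {"eye_side": "both", "keywords": f"{left_keywords}; {right_keywords}"},
    ]

requests = build_prompt_requests("cataract", "normal fundus")
for request in requests:
    print(request["eye_side"], "->", request["keywords"])
```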

## Cell 3.1: Define ODIR Prompt Signature

In [None]:
import dspy

class OdirPromptSignature(dspy.Signature):
    """Generate a clinical diagnostic impression for a retinal fundus image with patient metadata.
    
    Requirements:
    - Use varied medical terminology
    - Be concise (1-2 sentences)
    - Include relevant clinical features when visible
    - Incorporate patient demographics (age, sex) when clinically relevant
    - Vary phrasing to avoid repetition
    - Specify eye laterality (left/right) in the description
    """
    
    image = dspy.InputField(desc="Color fundus photograph")
    keywords = dspy.InputField(desc="Diagnostic keywords from clinical annotations")
    eye_side = dspy.InputField(desc="Eye laterality: 'left', 'right', or 'both'")
    age = dspy.InputField(desc="Patient age in years")
    sex = dspy.InputField(desc="Patient sex (M/F)")
    style_hint = dspy.InputField(desc="Writing style guidance for variation")
    
    impression = dspy.OutputField(desc="Clinical diagnostic impression")

print("✅ OdirPromptSignature defined")

## Cell 3.2: Create ODIR Prompt Generator

In [None]:
import dspy
import random
from typing import Optional
import re

class OdirPromptGenerator(dspy.Module):
    """Generates varied clinical prompts for ODIR-5K fundus images with metadata."""
    
    def __init__(self, use_chain_of_thought: bool = False):
        super().__init__()
        
        # Choose between simple prediction or chain-of-thought
        if use_chain_of_thought:
            self.generate = dspy.ChainOfThought(OdirPromptSignature)
        else:
            self.generate = dspy.Predict(OdirPromptSignature)
        
        # Style variations for randomization
        self.writing_styles = [
            "formal and detailed",
            "concise and direct",
            "descriptive with key findings",
            "focused on diagnostic features",
            "educational clinical note style",
            "brief assessment format",
        ]
        
        self.perspectives = [
            "describe visible pathology",
            "summarize diagnostic impression",
            "note clinical significance",
            "describe characteristic findings",
            "identify key abnormalities",
        ]
        
        self.detail_levels = ["brief", "moderate", "detailed"]
    
    def _create_style_hint(self, rng: random.Random) -> str:
        """Create a randomized style hint."""
        style = rng.choice(self.writing_styles)
        perspective = rng.choice(self.perspectives)
        detail = rng.choice(self.detail_levels)
        
        return f"{detail} description, {style}, {perspective}"
    
    def forward(self, image, keywords: str, eye_side: str, age: int, sex: str, 
                rng: Optional[random.Random] = None):
        """Generate a clinical prompt for the given image and metadata."""
        rng = rng or random.Random()
        
        style_hint = self._create_style_hint(rng)
        
        result = self.generate(
            image=image,
            keywords=keywords,
            eye_side=eye_side,
            age=age,
            sex=sex,
            style_hint=style_hint
        )
        
        # Clean up the output
        impression = result.impression.strip()
        
        # Ensure it ends with a period
        if not impression.endswith(('.', '!', '?')):
            impression += '.'
        
        # Keep only the first line if the model returned several
        if '\n' in impression:
            impression = impression.split('\n')[0].strip()
        
        # Remove model artifacts
        impression = re.sub(r'\s*\[\[\s*##\s*completed\s*##\s*\]\]\.?', '', impression)
        impression = impression.strip()
        
        # Ensure it still ends with a period after cleanup
        if impression and not impression.endswith(('.', '!', '?')):
            impression += '.'
        
        return impression

print("✅ OdirPromptGenerator defined")
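
The output cleanup in `forward` can be tried in isolation, without DSPy or an API key. The standalone sketch below reuses the same regular expression but applies the steps in a slightly simplified order (first line, artifact removal, then terminal punctuation). `clean_impression` is a hypothetical helper, not part of the notebook.

```python
import re

# Same pattern as in OdirPromptGenerator.forward: strips a trailing DSPy completion marker.
ARTIFACT = re.compile(r'\s*\[\[\s*##\s*completed\s*##\s*\]\]\.?')

def clean_impression(raw: str) -> str:
    """Keep the first line, drop completion markers, and end with punctuation."""
    text = raw.strip().split("\n")[0].strip()
    text = ARTIFACT.sub("", text).strip()
    if text and not text.endswith((".", "!", "?")):
        text += "."
    return text

cleaned_marker = clean_impression("Mild NPDR in the left eye [[ ## completed ## ]]")
cleaned_multiline = clean_impression("Normal fundus.\nSecond line to discard")
cleaned_empty = clean_impression("   ")
```

Running the cleanup on a few hand-written strings like these is a cheap way to catch formatting regressions before spending API calls on the full generation loop.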

## Cell 3.3: Test Generator on Single Patient

In [None]:
from PIL import Image

# Initialize generator
generator = OdirPromptGenerator(use_chain_of_thought=False)

# Get first patient for testing
test_patient = odir_df.iloc[0]
patient_id = test_patient['ID']
age = int(test_patient['Patient Age'])
sex = test_patient['Patient Sex']
left_keywords = str(test_patient['Left-Diagnostic Keywords'])
right_keywords = str(test_patient['Right-Diagnostic Keywords'])

# Load images from file paths
left_img_path = f"{ODIR_IMAGES_DIR}/{patient_id}_left.jpg"
right_img_path = f"{ODIR_IMAGES_DIR}/{patient_id}_right.jpg"

if not os.path.exists(left_img_path) or not os.path.exists(right_img_path):
    print(f"❌ Images not found for patient {patient_id}")
    print(f"   Expected: {left_img_path}, {right_img_path}")
    print(f"\n⚠️ Make sure you've downloaded and extracted ODIR-5K images!")
else:
    left_img = Image.open(left_img_path).convert('RGB')
    right_img = Image.open(right_img_path).convert('RGB')
    
    print(f"Test Patient: {patient_id}")
    print(f"  Age: {age}, Sex: {sex}")
    print(f"  Left Keywords: {left_keywords}")
    print(f"  Right Keywords: {right_keywords}")
    print(f"  Left Image: {left_img.size}")
    print(f"  Right Image: {right_img.size}")
    
    # Generate 3 prompts
    rng = random.Random(42)
    
    print("\nGenerating prompts...\n")
    
    # 1. Left eye prompt
    left_prompt = generator(
        image=left_img,
        keywords=left_keywords,
        eye_side="left",
        age=age,
        sex=sex,
        rng=rng
    )
    print(f"✅ Left Eye Prompt:\n   {left_prompt}\n")
    
    # 2. Right eye prompt
    right_prompt = generator(
        image=right_img,
        keywords=right_keywords,
        eye_side="right",
        age=age,
        sex=sex,
        rng=rng
    )
    print(f"✅ Right Eye Prompt:\n   {right_prompt}\n")
    
    # 3. Patient-level prompt (both eyes)
    patient_prompt = generator(
        image=left_img,  # Use either image
        keywords=f"{left_keywords}; {right_keywords}",
        eye_side="both",
        age=age,
        sex=sex,
        rng=rng
    )
    print(f"✅ Patient-Level Prompt:\n   {patient_prompt}\n")
    
    print("✅ Test generation successful - 3 prompts created per patient")

## Cell 3.4: Retry Logic for Rate Limiting

In [None]:
import random  # used for backoff jitter below
import time

def retry_with_backoff(func, max_retries: int = 5, base_delay: float = 10.0):
    """Enhanced retry function with exponential backoff and better error handling.
    
    Handles:
    - Rate limiting (429, quota exceeded)
    - Bad requests (400, JSON schema errors)
    - Model unavailability (404, not found)
    """
    for attempt in range(max_retries):
        try:
            return func()
        except Exception as e:
            error_str = str(e).lower()
            
            # Categorize errors
            is_rate_limit = any(kw in error_str for kw in ["rate limit", "429", "quota", "too many requests"])
            is_bad_request = any(kw in error_str for kw in ["400", "bad request", "json schema"])
            is_not_found = any(kw in error_str for kw in ["404", "not found", "no endpoints"])
            
            # Don't retry 404 errors - model doesn't exist
            if is_not_found:
                print(f"❌ Model unavailable (404): {error_str[:150]}")
                raise
            
            # Don't retry JSON schema errors - let DSPy handle fallback
            if is_bad_request and "json schema" in error_str:
                print(f"❌ JSON schema error (will try fallback model)")
                raise
            
            # Last attempt - give up
            if attempt == max_retries - 1:
                print(f"❌ Max retries ({max_retries}) exhausted")
                raise
            
            # Calculate backoff delay
            if is_rate_limit:
                delay = base_delay * (2 ** attempt) + random.uniform(0, 5)
                print(f"⏳ Rate limited. Waiting {delay:.1f}s (retry {attempt + 1}/{max_retries})")
            else:
                delay = base_delay + random.uniform(0, 3)
                print(f"⚠️ Error occurred. Waiting {delay:.1f}s (retry {attempt + 1}/{max_retries})")
                print(f"   Error: {error_str[:150]}")
            
            time.sleep(delay)

print("✅ Retry logic defined")
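
To get a feel for how long a rate-limited call can stall, the exponential part of the backoff can be tabulated without the random jitter. This is a sketch of the schedule only: the real function adds up to 5 seconds of jitter per wait, and `backoff_schedule` is a hypothetical helper.

```python
def backoff_schedule(max_retries: int = 5, base_delay: float = 10.0):
    """Waits (seconds, jitter excluded) after each failed attempt except the last."""
    # The final attempt re-raises instead of sleeping, hence max_retries - 1 waits.
    return [base_delay * (2 ** attempt) for attempt in range(max_retries - 1)]

schedule = backoff_schedule()
worst_case_wait = sum(schedule)
print(f"Waits: {schedule}, total: {worst_case_wait:.0f}s before giving up")
```

With the defaults, a persistently rate-limited call waits 150 seconds plus jitter before the exception propagates. That figure matters when estimating the runtime of the full generation loop.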

## Cell 3.5: Main Prompt Generation Loop

**⚠️ This will take ~30 minutes for 100 patients (`TEST_MODE = True`) or ~4-5 hours for 5,000 patients (`TEST_MODE = False`).**

Generates 3 prompts per patient:
- Left eye-specific prompt
- Right eye-specific prompt  
- Patient-level holistic prompt

In [None]:
from tqdm.notebook import tqdm
import pandas as pd
import json
from PIL import Image

# Output paths
prompts_csv_path = f"{DRIVE_PROMPTS}/odir_retclip_prompts.csv"
checkpoint_path = f"{DRIVE_PROMPTS}/generation_checkpoint.json"

# Load checkpoint if exists
processed_patients = set()
prompts_rows = []

if os.path.exists(checkpoint_path):
    with open(checkpoint_path, 'r') as f:
        checkpoint_data = json.load(f)
        processed_patients = set(checkpoint_data.get('processed_patients', []))
        print(f"✅ Resuming from checkpoint: {len(processed_patients)} patients already processed")
    
    # Load existing prompts
    if os.path.exists(prompts_csv_path):
        existing_df = pd.read_csv(prompts_csv_path)
        prompts_rows = existing_df.to_dict('records')

print(f"\nGenerating prompts for {len(odir_df)} patients...")
print(f"Checkpoint interval: {CHECKPOINT_INTERVAL}")
print(f"Delay between calls: {DELAY_BETWEEN_CALLS}s\n")

# Initialize generator
generator = OdirPromptGenerator(use_chain_of_thought=False)

# Process each patient
for idx, row in tqdm(odir_df.iterrows(), total=len(odir_df), desc="Generating prompts"):
    patient_id = row['ID']
    
    # Skip if already processed
    if patient_id in processed_patients:
        continue
    
    try:
        # Load patient metadata
        age = int(row['Patient Age'])
        sex = row['Patient Sex']
        left_keywords = str(row['Left-Diagnostic Keywords'])
        right_keywords = str(row['Right-Diagnostic Keywords'])
        
        # Load images from file paths
        left_img_path = f"{ODIR_IMAGES_DIR}/{patient_id}_left.jpg"
        right_img_path = f"{ODIR_IMAGES_DIR}/{patient_id}_right.jpg"
        
        if not os.path.exists(left_img_path) or not os.path.exists(right_img_path):
            print(f"\n⚠️ Images not found for patient {patient_id}, skipping")
            continue
        
        left_img = Image.open(left_img_path).convert('RGB')
        right_img = Image.open(right_img_path).convert('RGB')
        
        # Create a deterministic RNG per patient. A string seed is used because
        # the built-in hash() of a str changes between Python processes.
        rng = random.Random(f"{patient_id}-42")
        
        # Generate 3 prompts with retry logic
        def make_left_prompt():
            return generator(
                image=left_img,
                keywords=left_keywords,
                eye_side="left",
                age=age,
                sex=sex,
                rng=rng
            )
        
        def make_right_prompt():
            return generator(
                image=right_img,
                keywords=right_keywords,
                eye_side="right",
                age=age,
                sex=sex,
                rng=rng
            )
        
        def make_patient_prompt():
            return generator(
                image=left_img,
                keywords=f"{left_keywords}; {right_keywords}",
                eye_side="both",
                age=age,
                sex=sex,
                rng=rng
            )
        
        left_prompt = retry_with_backoff(make_left_prompt)
        time.sleep(DELAY_BETWEEN_CALLS)
        
        right_prompt = retry_with_backoff(make_right_prompt)
        time.sleep(DELAY_BETWEEN_CALLS)
        
        patient_prompt = retry_with_backoff(make_patient_prompt)
        
        # Store results
        prompts_rows.append({
            'patient_id': patient_id,
            'age': age,
            'sex': sex,
            'left_keywords': left_keywords,
            'right_keywords': right_keywords,
            'prompt_left': left_prompt,
            'prompt_right': right_prompt,
            'prompt_patient': patient_prompt
        })
        
        processed_patients.add(patient_id)
        
        # Save checkpoint periodically
        if len(processed_patients) % CHECKPOINT_INTERVAL == 0:
            pd.DataFrame(prompts_rows).to_csv(prompts_csv_path, index=False)
            with open(checkpoint_path, 'w') as f:
                json.dump({'processed_patients': list(processed_patients)}, f)
        
        # Rate limiting
        if idx < len(odir_df) - 1:
            time.sleep(DELAY_BETWEEN_CALLS)
    
    except Exception as e:
        print(f"\n❌ Error processing patient {patient_id}: {e}")
        # Save checkpoint on error
        pd.DataFrame(prompts_rows).to_csv(prompts_csv_path, index=False)
        with open(checkpoint_path, 'w') as f:
            json.dump({'processed_patients': list(processed_patients)}, f)
        continue

# Save final results
prompts_df = pd.DataFrame(prompts_rows)
prompts_df.to_csv(prompts_csv_path, index=False)

# Clean up checkpoint
if os.path.exists(checkpoint_path):
    os.remove(checkpoint_path)

print(f"\n✅ Prompt generation complete!")
print(f"   Total patients: {len(prompts_df)}")
print(f"   Total prompts: {len(prompts_df) * 3} (3 per patient)")
print(f"   Saved to: {prompts_csv_path}")

# SECTION 4: Preprocess for RET-CLIP

Convert generated prompts and images to RET-CLIP format:
- **TSV**: Real paired left/right images (patient_id, left_img_base64, right_img_base64)
- **JSONL**: Text annotations with eye_side field for tripartite loss

## Cell 4.1: Helper Functions for Image Encoding

In [None]:
import base64
from io import BytesIO

def image_to_base64_urlsafe(pil_image, size=224):
    """Convert PIL Image to URL-safe base64 string with resizing"""
    if not isinstance(pil_image, Image.Image):
        raise ValueError(f"Expected PIL Image, got {type(pil_image)}")
    
    # Convert to RGB and resize
    img = pil_image.convert('RGB')
    img = img.resize((size, size), Image.BICUBIC)
    
    # Encode to base64
    buffered = BytesIO()
    img.save(buffered, format="JPEG", quality=95)
    img_str = base64.urlsafe_b64encode(buffered.getvalue()).decode()
    return img_str

print("✅ Image encoding helper defined")

## Cell 4.2: Create TSV File (Paired Images)

**Format**: `patient_id\tleft_img_base64\tright_img_base64`

Real binocular pairs - NOT duplicated monocular images!
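
A reader of this TSV has to reverse the encoding: split each line on tabs, then decode the two URL-safe base64 fields back to image bytes. The sketch below shows that round trip with short stand-in byte strings instead of real fundus images. `make_tsv_row` and `parse_tsv_row` are hypothetical helpers written for this illustration, not part of RET-CLIP.

```python
import base64

def make_tsv_row(patient_id, left_bytes, right_bytes):
    """Build one TSV line: patient_id, left image, right image (URL-safe base64)."""
    left_b64 = base64.urlsafe_b64encode(left_bytes).decode("ascii")
    right_b64 = base64.urlsafe_b64encode(right_bytes).decode("ascii")
    return f"{patient_id}\t{left_b64}\t{right_b64}\n"

def parse_tsv_row(line):
    """Split a TSV line and decode both image fields back to raw bytes."""
    patient_id, left_b64, right_b64 = line.rstrip("\n").split("\t")
    return (patient_id,
            base64.urlsafe_b64decode(left_b64),
            base64.urlsafe_b64decode(right_b64))

# Stand-in payloads; in the notebook these would be JPEG-encoded fundus images
left_payload = bytes([0xFF, 0xD8, 0xFB, 0xFF, 0xFE])
right_payload = bytes([0xFF, 0xD8, 0x3E, 0x3F, 0x00])

row = make_tsv_row(42, left_payload, right_payload)
pid, left_decoded, right_decoded = parse_tsv_row(row)
```

Note that the patient ID comes back as a string, and that the URL-safe alphabet uses `-` and `_` in place of `+` and `/`. With real data, the decoded bytes can be opened with `Image.open(BytesIO(...))`.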

In [None]:
from tqdm.notebook import tqdm

# Create TSV file with real paired images
tsv_path = f"{DRIVE_DATA}/odir_train_imgs.tsv"

print(f"Creating TSV file with paired left/right images...")
print(f"Output: {tsv_path}\n")

with open(tsv_path, 'w', encoding='utf-8') as tsv_file:
    for idx, row in tqdm(prompts_df.iterrows(), total=len(prompts_df), desc="Encoding images"):
        patient_id = row['patient_id']
        
        # Get image paths
        left_img_path = f"{ODIR_IMAGES_DIR}/{patient_id}_left.jpg"
        right_img_path = f"{ODIR_IMAGES_DIR}/{patient_id}_right.jpg"
        
        if not os.path.exists(left_img_path) or not os.path.exists(right_img_path):
            print(f"\n⚠️ Images not found for patient {patient_id}")
            continue
        
        # Load and encode both images separately (real binocular pair!)
        left_img = Image.open(left_img_path)
        right_img = Image.open(right_img_path)
        
        left_b64 = image_to_base64_urlsafe(left_img, IMAGE_SIZE)
        right_b64 = image_to_base64_urlsafe(right_img, IMAGE_SIZE)
        
        # Write TSV line: patient_id, left_img, right_img
        tsv_file.write(f"{patient_id}\t{left_b64}\t{right_b64}\n")

print(f"\n✅ TSV file created: {tsv_path}")

# Validate format (report success only if the first row really has 3 columns)
with open(tsv_path, 'r', encoding='utf-8') as f:
    first_line = f.readline().strip()
    parts = first_line.split('\t')
    print(f"\nValidation:")
    print(f"  Columns: {len(parts)} (expected: 3)")
    if len(parts) == 3:
        print(f"  Patient ID: {parts[0]}")
        print(f"  Left image length: {len(parts[1])} chars")
        print(f"  Right image length: {len(parts[2])} chars")
        print(f"  ✅ Format correct - Real binocular pairs!")
    else:
        print(f"  ❌ Unexpected column count; check the TSV writer above")

## Cell 4.3: Create JSONL File (Text Annotations with Eye Side)

**Format**: Each patient generates 3 JSONL entries:
```json
{"text_id": 0, "text": "left prompt", "image_ids": ["patient_id"], "eye_side": "left"}
{"text_id": 1, "text": "right prompt", "image_ids": ["patient_id"], "eye_side": "right"}
{"text_id": 2, "text": "patient prompt", "image_ids": ["patient_id"], "eye_side": "both"}
```

The `eye_side` field is used by RET-CLIP's tripartite loss function.
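
Since every patient is expected to contribute exactly one `left`, one `right`, and one `both` entry, a quick consistency check on the JSONL can catch missing prompts before training. The sketch below relies only on the field names shown above. `check_eye_sides` is a hypothetical helper and the sample lines are invented for illustration.

```python
import json
from collections import defaultdict

def check_eye_sides(jsonl_lines):
    """Return image IDs whose entries are not exactly one each of left/right/both."""
    sides_by_image = defaultdict(list)
    for line in jsonl_lines:
        entry = json.loads(line)
        for image_id in entry["image_ids"]:
            sides_by_image[image_id].append(entry["eye_side"])
    expected = ["both", "left", "right"]
    return sorted(i for i, sides in sides_by_image.items() if sorted(sides) != expected)

sample = [
    '{"text_id": 0, "text": "left prompt", "image_ids": [7], "eye_side": "left"}',
    '{"text_id": 1, "text": "right prompt", "image_ids": [7], "eye_side": "right"}',
    '{"text_id": 2, "text": "patient prompt", "image_ids": [7], "eye_side": "both"}',
    '{"text_id": 3, "text": "left prompt", "image_ids": [8], "eye_side": "left"}',
]
incomplete = check_eye_sides(sample)
```

In this sample, patient 7 is complete and patient 8 is flagged because only a left-eye entry is present.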

In [None]:
import json

# Create JSONL file with text annotations and eye_side field
jsonl_path = f"{DRIVE_DATA}/odir_train_texts.jsonl"

print(f"Creating JSONL file with eye_side annotations...")
print(f"Output: {jsonl_path}\n")

text_id = 0
with open(jsonl_path, 'w', encoding='utf-8') as jsonl_file:
    for idx, row in tqdm(prompts_df.iterrows(), total=len(prompts_df), desc="Writing JSONL"):
        patient_id = row['patient_id']
        
        # Entry 1: Left eye prompt
        left_entry = {
            "text_id": text_id,
            "text": row['prompt_left'],
            "image_ids": [patient_id],
            "eye_side": "left"
        }
        jsonl_file.write(json.dumps(left_entry, ensure_ascii=False) + '\n')
        text_id += 1
        
        # Entry 2: Right eye prompt
        right_entry = {
            "text_id": text_id,
            "text": row['prompt_right'],
            "image_ids": [patient_id],
            "eye_side": "right"
        }
        jsonl_file.write(json.dumps(right_entry, ensure_ascii=False) + '\n')
        text_id += 1
        
        # Entry 3: Patient-level prompt (both eyes)
        patient_entry = {
            "text_id": text_id,
            "text": row['prompt_patient'],
            "image_ids": [patient_id],
            "eye_side": "both"
        }
        jsonl_file.write(json.dumps(patient_entry, ensure_ascii=False) + '\n')
        text_id += 1

print(f"\n✅ JSONL file created: {jsonl_path}")
print(f"   Total patients: {len(prompts_df)}")
print(f"   Total text entries: {text_id} (3 per patient)")

# Validate format
with open(jsonl_path, 'r', encoding='utf-8') as f:
    sample_lines = [json.loads(f.readline()) for _ in range(3)]
    
print(f"\nValidation - First patient's 3 prompts:")
for entry in sample_lines:
    print(f"  [{entry['eye_side']}] {entry['text'][:80]}...")

print(f"\n✅ Format correct - 3 prompts per patient with eye_side annotations!")

## Cell 4.4: Train/Test Split

Split the prompts into training and testing sets (80/20 split).
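
The split is made per patient, so both images and all three prompts of a patient end up on the same side. A standard-library sketch of a seeded 80/20 patient split with a disjointness check is given below. It uses `random.Random.shuffle` rather than scikit-learn's `train_test_split`, so the actual assignment of patients differs from the notebook's, and `split_patients` is a hypothetical helper.

```python
import random

def split_patients(patient_ids, test_fraction=0.2, seed=42):
    """Shuffle patient IDs with a fixed seed and hold out `test_fraction` of them."""
    ids = list(patient_ids)
    random.Random(seed).shuffle(ids)
    n_test = round(len(ids) * test_fraction)
    # First n_test shuffled IDs form the test set, the rest the training set
    return ids[n_test:], ids[:n_test]

train_ids, test_ids = split_patients(range(100))
overlap = set(train_ids) & set(test_ids)
print(len(train_ids), len(test_ids), len(overlap))
```

An empty overlap confirms that no patient appears in both sets, which is the property that matters for the evaluation in Section 7.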

In [None]:
from sklearn.model_selection import train_test_split

# Split prompts DataFrame (80/20 split)
train_df, test_df = train_test_split(
    prompts_df,
    test_size=0.2,
    random_state=42,
    stratify=None  # Can't stratify if some classes have only 1 sample
)

print(f"Train patients: {len(train_df)}")
print(f"Test patients: {len(test_df)}")

# Save splits
train_df.to_csv(f"{DRIVE_DATA}/train_patients.csv", index=False)
test_df.to_csv(f"{DRIVE_DATA}/test_patients.csv", index=False)

print(f"\n✅ Split saved:")
print(f"   Train: {DRIVE_DATA}/train_patients.csv")
print(f"   Test: {DRIVE_DATA}/test_patients.csv")

## Cell 4.5: Create Train TSV and JSONL Files

In [None]:
# Create Train TSV file
tsv_path = f"{DRIVE_DATA}/train_imgs.tsv"  # Named after the split ("train"), without an "odir_" prefix

print(f"Creating TRAIN TSV file...")
print(f"Output: {tsv_path}\n")

with open(tsv_path, 'w', encoding='utf-8') as tsv_file:
    for idx, row in tqdm(train_df.iterrows(), total=len(train_df), desc="Encoding train images"):
        patient_id = row['patient_id']
        
        # Get image paths
        left_img_path = f"{ODIR_IMAGES_DIR}/{patient_id}_left.jpg"
        right_img_path = f"{ODIR_IMAGES_DIR}/{patient_id}_right.jpg"
        
        if not os.path.exists(left_img_path) or not os.path.exists(right_img_path):
            continue
        
        # Load and encode both images
        left_img = Image.open(left_img_path)
        right_img = Image.open(right_img_path)
        
        left_b64 = image_to_base64_urlsafe(left_img, IMAGE_SIZE)
        right_b64 = image_to_base64_urlsafe(right_img, IMAGE_SIZE)
        
        tsv_file.write(f"{patient_id}\t{left_b64}\t{right_b64}\n")

print(f"\n✅ Train TSV created: {tsv_path}")

# Create Train JSONL file
jsonl_path = f"{DRIVE_DATA}/train_texts.jsonl"  # Named after the split ("train"), without an "odir_" prefix

print(f"\nCreating TRAIN JSONL file...")
print(f"Output: {jsonl_path}\n")

text_id = 0
with open(jsonl_path, 'w', encoding='utf-8') as jsonl_file:
    for idx, row in tqdm(train_df.iterrows(), total=len(train_df), desc="Writing train JSONL"):
        patient_id = row['patient_id']
        
        # Entry 1: Left eye prompt
        left_entry = {
            "text_id": text_id,
            "text": row['prompt_left'],
            "image_ids": [patient_id],
            "eye_side": "left"
        }
        jsonl_file.write(json.dumps(left_entry, ensure_ascii=False) + '\n')
        text_id += 1
        
        # Entry 2: Right eye prompt
        right_entry = {
            "text_id": text_id,
            "text": row['prompt_right'],
            "image_ids": [patient_id],
            "eye_side": "right"
        }
        jsonl_file.write(json.dumps(right_entry, ensure_ascii=False) + '\n')
        text_id += 1
        
        # Entry 3: Patient-level prompt
        patient_entry = {
            "text_id": text_id,
            "text": row['prompt_patient'],
            "image_ids": [patient_id],
            "eye_side": "both"
        }
        jsonl_file.write(json.dumps(patient_entry, ensure_ascii=False) + '\n')
        text_id += 1

print(f"\n✅ Train JSONL created: {jsonl_path}")
print(f"   Total text entries: {text_id} (3 per patient)")

## Cell 4.6: Create Test TSV and JSONL Files

In [None]:
# Create Test TSV file
tsv_path = f"{DRIVE_DATA}/test_imgs.tsv"  # Named after the split ("test"), without an "odir_" prefix

print(f"Creating TEST TSV file...")
print(f"Output: {tsv_path}\n")

with open(tsv_path, 'w', encoding='utf-8') as tsv_file:
    for idx, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Encoding test images"):
        patient_id = row['patient_id']
        
        # Get image paths
        left_img_path = f"{ODIR_IMAGES_DIR}/{patient_id}_left.jpg"
        right_img_path = f"{ODIR_IMAGES_DIR}/{patient_id}_right.jpg"
        
        if not os.path.exists(left_img_path) or not os.path.exists(right_img_path):
            continue
        
        # Load and encode both images
        left_img = Image.open(left_img_path)
        right_img = Image.open(right_img_path)
        
        left_b64 = image_to_base64_urlsafe(left_img, IMAGE_SIZE)
        right_b64 = image_to_base64_urlsafe(right_img, IMAGE_SIZE)
        
        tsv_file.write(f"{patient_id}\t{left_b64}\t{right_b64}\n")

print(f"\n✅ Test TSV created: {tsv_path}")

# Create Test JSONL file
jsonl_path = f"{DRIVE_DATA}/test_texts.jsonl"  # Named after the split ("test"), without an "odir_" prefix

print(f"\nCreating TEST JSONL file...")
print(f"Output: {jsonl_path}\n")

text_id = 0
with open(jsonl_path, 'w', encoding='utf-8') as jsonl_file:
    for idx, row in tqdm(test_df.iterrows(), total=len(test_df), desc="Writing test JSONL"):
        patient_id = row['patient_id']
        
        # Entry 1: Left eye prompt
        left_entry = {
            "text_id": text_id,
            "text": row['prompt_left'],
            "image_ids": [patient_id],
            "eye_side": "left"
        }
        jsonl_file.write(json.dumps(left_entry, ensure_ascii=False) + '\n')
        text_id += 1
        
        # Entry 2: Right eye prompt
        right_entry = {
            "text_id": text_id,
            "text": row['prompt_right'],
            "image_ids": [patient_id],
            "eye_side": "right"
        }
        jsonl_file.write(json.dumps(right_entry, ensure_ascii=False) + '\n')
        text_id += 1
        
        # Entry 3: Patient-level prompt
        patient_entry = {
            "text_id": text_id,
            "text": row['prompt_patient'],
            "image_ids": [patient_id],
            "eye_side": "both"
        }
        jsonl_file.write(json.dumps(patient_entry, ensure_ascii=False) + '\n')
        text_id += 1

print(f"\n✅ Test JSONL created: {jsonl_path}")
print(f"   Total text entries: {text_id} (3 per patient)")

# SECTION 5: Build LMDB Database

Create LMDB databases for efficient data loading during training.

LMDB (Lightning Memory-Mapped Database) is a high-performance embedded database that allows fast random access to image data during training.
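
LMDB keys and values are raw bytes, so patient IDs and image pairs have to be serialized before they are stored. The sketch below illustrates the layout that the validation cell in 5.3 reads back: a UTF-8 patient ID as the key and a pickled `(left, right)` tuple as the value. A plain `dict` stands in for the LMDB environment so the example runs without the `lmdb` package. The layout is an assumption inferred from Cell 5.3, not taken from the build script, and `encode_record` and `decode_record` are hypothetical helpers.

```python
import pickle

def encode_record(patient_id, left_b64, right_b64):
    """Serialize one patient as (key bytes, value bytes)."""
    key = str(patient_id).encode("utf-8")  # integer IDs must become str first
    value = pickle.dumps((left_b64, right_b64))
    return key, value

def decode_record(value):
    """Recover the (left, right) base64 strings from a stored value."""
    left_b64, right_b64 = pickle.loads(value)
    return left_b64, right_b64

fake_db = {}  # stand-in for an LMDB transaction: bytes -> bytes
key, value = encode_record(42, "bGVmdA==", "cmlnaHQ=")
fake_db[key] = value

stored = fake_db.get(str(42).encode("utf-8"))
left, right = decode_record(stored)
```

Looking up an ID that was never written returns `None`, which is the case the validation cell reports as missing data.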

## Cell 5.1: Build Train LMDB

In [None]:
# Build LMDB for train set
print("="*80)
print("Building LMDB for TRAIN set")
print("="*80)

!python /content/retclip/RET_CLIP/preprocess/build_lmdb_dataset_for_RET-CLIP.py \
    --data_dir {DRIVE_DATA} \
    --splits train \
    --lmdb_dir {DRIVE_LMDB}

## Cell 5.2: Build Test LMDB

In [None]:
# Build LMDB for test set
print("\n" + "="*80)
print("Building LMDB for TEST set")
print("="*80)

!python /content/retclip/RET_CLIP/preprocess/build_lmdb_dataset_for_RET-CLIP.py \
    --data_dir {DRIVE_DATA} \
    --splits test \
    --lmdb_dir {DRIVE_LMDB}

## Cell 5.3: Validate LMDB Databases

In [None]:
# Validate LMDB by reading a few samples
import lmdb
import pickle

print("\n" + "="*80)
print("Validating LMDB databases")
print("="*80)

for split_name in ['train', 'test']:
    lmdb_path = f"{DRIVE_LMDB}/{split_name}/imgs"
    
    if not os.path.exists(lmdb_path):
        print(f"❌ LMDB not found: {lmdb_path}")
        continue
    
    env = lmdb.open(lmdb_path, readonly=True, lock=False, readahead=False, meminit=False)
    
    with env.begin() as txn:
        # Try to read first 3 samples
        print(f"\n{split_name.upper()} LMDB:")
        
        # Get list of patient IDs from the corresponding dataframe
        if split_name == 'train':
            sample_ids = train_df['patient_id'].head(3).tolist()
        else:
            sample_ids = test_df['patient_id'].head(3).tolist()
        
        for patient_id in sample_ids:
            # LMDB keys are bytes, so cast patient_id to str before encoding
            value = txn.get(str(patient_id).encode('utf-8'))
            
            if value is None:
                print(f"  ⚠️ No data for {patient_id}")
                continue
            
            try:
                img_left_b64, img_right_b64 = pickle.loads(value)
                print(f"  ✅ {patient_id}: left={len(img_left_b64)} chars, right={len(img_right_b64)} chars")
            except Exception as e:
                print(f"  ❌ Error unpacking {patient_id}: {e}")
    
    env.close()

print("\n✅ LMDB validation complete")

# SECTION 6: Train RET-CLIP

Train RET-CLIP using contrastive learning on the ODIR-5K dataset with the English PubMedBERT text encoder.

**Expected Time**:
- TEST MODE (100 patients, 2 epochs): ~30 minutes on T4
- FULL MODE (5,000 patients, 10 epochs): ~12-15 hours on A100

## Cell 6.1: Training Configuration

In [None]:
# Display training configuration
print("="*80)
print("TRAINING CONFIGURATION")
print("="*80)
print(f"  Mode: {'TEST' if TEST_MODE else 'FULL'}")
print(f"  Vision Model: {VISION_MODEL}")
print(f"  Text Model: {TEXT_MODEL}")
print(f"  Batch Size: {BATCH_SIZE}")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Learning Rate: {LEARNING_RATE}")
print(f"  Warmup Steps: {WARMUP_STEPS}")
print(f"  Image Size: {IMAGE_SIZE}x{IMAGE_SIZE}")
print(f"  Train samples: {len(train_df)}")
print(f"  Test samples: {len(test_df)}")
print(f"  LMDB dir: {DRIVE_LMDB}/train")
print(f"  Checkpoint dir: {DRIVE_CHECKPOINTS}")
print("\n💡 Note: Using PubMedBERT (medical domain-specific text encoder)")
print("✅ Configuration ready for training")

## Cell 6.2: Run Training with torchrun

**⚠️ This will take ~30 minutes for TEST_MODE or ~12-15 hours for FULL_MODE**

Trains RET-CLIP with:
- Distributed data parallel (DDP) launched via `torchrun` with a single process (`--nproc_per_node=1`) on one GPU
- Checkpoints saved every epoch to Google Drive
- Uses tripartite contrastive loss for binocular architecture
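
Each text encoder gets its own run directory, named after the encoder's display name. The sketch below isolates the naming rule used in the cells that follow. The function name `checkpoint_run_name` is hypothetical, and the encoder names other than PubMedBERT are illustrative examples only.

```python
def checkpoint_run_name(encoder_name):
    """Lowercase the encoder name and drop hyphens and spaces."""
    short = encoder_name.lower().replace("-", "").replace(" ", "")
    return f"retclip_odir_{short}"

# "BERT-base" and "Clinical BERT" are made-up names to show the normalization
names = [checkpoint_run_name(n) for n in ["PubMedBERT", "BERT-base", "Clinical BERT"]]
for name in names:
    print(name)
```

Checkpoints for a run are then expected under `{DRIVE_CHECKPOINTS}/{name}/checkpoints/`, which is the path the training and verification cells print.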

In [None]:
# Set PYTHONPATH for subprocess
import os
os.environ['PYTHONPATH'] = '/content/retclip'
print(f"✅ Set PYTHONPATH to: {os.environ['PYTHONPATH']}")

# Determine which text encoders to train
if RUN_TEXT_ENCODER_COMPARISON:
    encoders_to_train = TEXT_ENCODERS
    print(f"\n✅ TEXT ENCODER COMPARISON MODE: Will train {len(encoders_to_train)} models")
else:
    # Train only the default TEXT_MODEL
    encoders_to_train = [{
        "name": "PubMedBERT",
        "model_id": TEXT_MODEL,
        "description": "Default text encoder"
    }]
    print(f"\n✅ SINGLE ENCODER MODE: Training only {TEXT_MODEL}")

# Train each text encoder
for encoder_idx, encoder_config in enumerate(encoders_to_train):
    encoder_name = encoder_config["name"]
    encoder_model_id = encoder_config["model_id"]
    encoder_desc = encoder_config["description"]
    
    print("\n" + "="*80)
    print(f"Training Model {encoder_idx + 1}/{len(encoders_to_train)}: {encoder_name}")
    print("="*80)
    print(f"  Description: {encoder_desc}")
    print(f"  Model ID: {encoder_model_id}")
    print(f"  Vision Model: {VISION_MODEL}")
    print(f"  Epochs: {NUM_EPOCHS}")
    print(f"  Batch Size: {BATCH_SIZE}")
    
    # Create encoder-specific name for checkpoints
    encoder_short_name = encoder_name.lower().replace('-', '').replace(' ', '')
    model_name = f"retclip_odir_{encoder_short_name}"
    
    print(f"  Checkpoint Name: {model_name}")
    print("="*80)
    
    # Run training with distributed launcher
    !torchrun --nproc_per_node=1 --master_port=29500 \
        /content/retclip/RET_CLIP/training/main.py \
        --train-data {DRIVE_LMDB}/train \
        --batch-size {BATCH_SIZE} \
        --max-epochs {NUM_EPOCHS} \
        --lr {LEARNING_RATE} \
        --warmup {WARMUP_STEPS} \
        --vision-model {VISION_MODEL} \
        --text-model {encoder_model_id} \
        --logs {DRIVE_CHECKPOINTS} \
        --name {model_name} \
        --save-epoch-frequency 1 \
        --skip-aggregate
    
    print(f"\n✅ Completed training for {encoder_name}")
    print(f"   Checkpoints saved to: {DRIVE_CHECKPOINTS}/{model_name}/checkpoints/")
    
    if encoder_idx < len(encoders_to_train) - 1:
        print(f"\n⏳ Moving to next encoder...")

print("\n" + "="*80)
print("🎉 ALL TRAINING COMPLETE!")
print("="*80)
if RUN_TEXT_ENCODER_COMPARISON:
    print(f"Trained {len(encoders_to_train)} models:")
    for encoder_config in encoders_to_train:
        encoder_short_name = encoder_config["name"].lower().replace('-', '').replace(' ', '')
        print(f"  ✅ retclip_odir_{encoder_short_name}")
else:
    print(f"Trained 1 model: retclip_odir_pubmedbert")

## Cell 6.3: Verify Saved Checkpoints

In [None]:
# List saved checkpoints for all trained models
print("\n" + "="*80)
print("Saved Checkpoints")
print("="*80)

# Determine which models were trained
if RUN_TEXT_ENCODER_COMPARISON:
    model_names = [f"retclip_odir_{enc['name'].lower().replace('-', '').replace(' ', '')}" for enc in TEXT_ENCODERS]
else:
    model_names = ["retclip_odir_pubmedbert"]

for model_name in model_names:
    checkpoint_dir = f"{DRIVE_CHECKPOINTS}/{model_name}/checkpoints"
    
    print(f"\n📁 {model_name}:")
    
    if os.path.exists(checkpoint_dir):
        checkpoint_files = sorted([f for f in os.listdir(checkpoint_dir) if f.endswith('.pt')])
        
        if checkpoint_files:
            print(f"   ✅ Found {len(checkpoint_files)} checkpoint(s):")
            for ckpt in checkpoint_files:
                ckpt_path = f"{checkpoint_dir}/{ckpt}"
                size_mb = os.path.getsize(ckpt_path) / (1024 * 1024)
                print(f"      {ckpt} ({size_mb:.1f} MB)")
            
            # Verify final checkpoint exists
            final_checkpoint = f"{checkpoint_dir}/epoch_latest.pt"
            if os.path.exists(final_checkpoint):
                print(f"   ✅ Final checkpoint ready: epoch_latest.pt")
            else:
                print(f"   ⚠️ Final checkpoint not found: epoch_latest.pt")
        else:
            print(f"   ❌ No .pt files found")
    else:
        print(f"   ❌ Directory not found: {checkpoint_dir}")

print("\n" + "="*80)
print("Checkpoint Verification Complete")
print("="*80)

# SECTION 7: Zero-Shot Evaluation

Test RET-CLIP's vision-language alignment by computing similarity between image embeddings and text embeddings for all disease classes.

Zero-shot evaluation doesn't require training a classifier - it directly measures how well the model aligns images with clinical text descriptions.

## Cell 7.1: Prepare Zero-Shot Prompts

Get unique disease keywords from test set and create representative prompts for each disease class.
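
Class names come straight from the comma-separated keyword strings. The sketch below shows that extraction together with a literal, case-insensitive match that does not treat keyword characters as regex syntax. `unique_keywords` and `rows_matching` are hypothetical helpers, and the sample keyword strings are invented for illustration.

```python
def unique_keywords(keyword_strings, sep=","):
    """Collect distinct keywords in first-seen order, skipping blanks and 'nan'."""
    seen = []
    for raw in keyword_strings:
        for part in str(raw).split(sep):
            kw = part.strip()
            if kw and kw != "nan" and kw not in seen:
                seen.append(kw)
    return seen

def rows_matching(keyword, keyword_strings):
    """Indices of rows whose keyword string contains `keyword` (case-insensitive, literal)."""
    needle = keyword.lower()
    return [i for i, raw in enumerate(keyword_strings) if needle in str(raw).lower()]

left = ["normal fundus", "cataract, drusen", "nan", "cataract"]
classes = unique_keywords(left)
cataract_rows = rows_matching("cataract", left)
```

Substring matching is permissive: a short keyword also matches any longer keyword that contains it, so the first matching patient is not guaranteed to have that keyword as an exact label.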

In [None]:
# Prepare zero-shot classification prompts from test set disease keywords
print("Preparing zero-shot disease class prompts...\n")

# Get unique disease keywords from test set
disease_classes = []
for keywords_str in pd.concat([test_df['left_keywords'], test_df['right_keywords']]).dropna():
    for keyword in str(keywords_str).split(','):
        disease = keyword.strip()
        if disease and disease not in disease_classes and disease != 'nan':
            disease_classes.append(disease)

print(f"Found {len(disease_classes)} unique disease keywords in test set")
print(f"Disease classes: {disease_classes[:10]}..." if len(disease_classes) > 10 else f"Disease classes: {disease_classes}")

# Create zero-shot prompts using patient-level prompts from test set
# For each disease, find a representative prompt
zero_shot_prompts = {}

for disease in disease_classes:
    # Find test patients with this disease
    matching = test_df[
        (test_df['left_keywords'].str.contains(disease, na=False, case=False, regex=False)) |
        (test_df['right_keywords'].str.contains(disease, na=False, case=False, regex=False))
    ]
    
    if len(matching) > 0:
        # Use the patient-level prompt from the first matching patient
        zero_shot_prompts[disease] = matching.iloc[0]['prompt_patient']
    else:
        # Fallback: create a simple clinical description
        zero_shot_prompts[disease] = f"Fundus photograph showing {disease}."

print(f"\n✅ Created {len(zero_shot_prompts)} zero-shot prompts")
print(f"\nSample prompts:")
for disease, prompt in list(zero_shot_prompts.items())[:3]:
    print(f"  [{disease}]: {prompt[:80]}...")

## Cell 7.2: Zero-Shot Evaluation Loop

Evaluate all trained models on the test set:
- Load each model checkpoint
- Encode text embeddings for disease classes
- Extract image features from test set
- Compute zero-shot predictions via cosine similarity
- Calculate accuracy and F1 scores
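
The prediction step reduces to a nearest-neighbour search in embedding space: normalize both sets of vectors, take dot products, and pick the class with the highest similarity. The toy sketch below shows this in pure Python with made-up 2-D embeddings. The real cell does the same with PyTorch tensors, and the helper names here are hypothetical.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit length so dot products equal cosine similarity."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def zero_shot_predict(image_feats, text_feats):
    """For each image vector, return the index of the most similar class vector."""
    texts = [l2_normalize(t) for t in text_feats]
    preds = []
    for img in image_feats:
        img = l2_normalize(img)
        sims = [sum(a * b for a, b in zip(img, t)) for t in texts]
        preds.append(max(range(len(sims)), key=sims.__getitem__))
    return preds

def accuracy(true_labels, predictions):
    """Fraction of predictions that equal the true label."""
    correct = sum(t == p for t, p in zip(true_labels, predictions))
    return correct / len(true_labels)

# Toy 2-D embeddings: two classes and three images (invented values)
text_feats = [[1.0, 0.0], [0.0, 1.0]]
image_feats = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.5]]
true_labels = [0, 1, 1]

preds = zero_shot_predict(image_feats, text_feats)
acc = accuracy(true_labels, preds)
```

The third image lies closer to class 0 than to its true class 1, so two of three predictions are correct. No classifier is trained at any point; the text embeddings act as the class prototypes.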

In [None]:
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, classification_report
from tqdm.notebook import tqdm
import json
from transformers import AutoTokenizer
import numpy as np

# Determine which models to evaluate
if RUN_TEXT_ENCODER_COMPARISON:
    models_to_evaluate = TEXT_ENCODERS
else:
    models_to_evaluate = [{
        "name": "PubMedBERT",
        "model_id": TEXT_MODEL,
        "hf_model_id": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
        "description": "Default text encoder"
    }]

# Store results for all models
all_results = {}

for encoder_config in models_to_evaluate:
    encoder_name = encoder_config["name"]
    encoder_model_id = encoder_config["model_id"]  # For RET-CLIP config files (dashes)
    encoder_hf_id = encoder_config["hf_model_id"]  # For HuggingFace tokenizer (slashes)
    encoder_short_name = encoder_name.lower().replace('-', '').replace(' ', '')
    model_name = f"retclip_odir_{encoder_short_name}"
    
    print("\n" + "="*80)
    print(f"Evaluating: {encoder_name}")
    print("="*80)
    print(f"  Config Model ID: {encoder_model_id}")
    print(f"  HuggingFace ID: {encoder_hf_id}")
    print(f"  Checkpoint: {model_name}")
    
    # Load model
    from RET_CLIP.clip.model import CLIP
    
    vision_config_path = f"/content/retclip/RET_CLIP/clip/model_configs/{VISION_MODEL}.json"
    text_config_path = f"/content/retclip/RET_CLIP/clip/model_configs/{encoder_model_id}.json"
    
    with open(vision_config_path, 'r') as fv, open(text_config_path, 'r') as ft:
        model_cfg = json.load(fv)
        for k, v in json.load(ft).items():
            model_cfg[k] = v
    
    model = CLIP(**model_cfg)
    
    # Load checkpoint
    checkpoint_path = f"{DRIVE_CHECKPOINTS}/{model_name}/checkpoints/epoch_latest.pt"
    print(f"  Loading: {checkpoint_path}")
    
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    
    # Handle DDP state dict
    state_dict = checkpoint['state_dict']
    new_state_dict = {}
    for k, v in state_dict.items():
        name = k.replace('module.', '')
        new_state_dict[name] = v
    
    model.load_state_dict(new_state_dict)
    model = model.cuda()
    model.eval()
    
    print("  ✅ Model loaded")
    
    # Load tokenizer for this text encoder using HuggingFace ID
    print(f"  Loading tokenizer from HuggingFace: {encoder_hf_id}...")
    tokenizer = AutoTokenizer.from_pretrained(encoder_hf_id)
    
    # Encode text embeddings for disease classes IN THE SAME ORDER as disease_classes
    print(f"\n  Encoding {len(disease_classes)} disease class prompts...")
    print(f"  Disease classes order: {disease_classes}")
    text_features = []
    
    with torch.no_grad():
        # IMPORTANT: Iterate in disease_classes order to match labels!
        for disease in disease_classes:
            prompt = zero_shot_prompts[disease]
            
            # Tokenize using the loaded tokenizer
            text_input = tokenizer(
                [prompt],
                max_length=77,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            
            # Move to GPU
            input_ids = text_input['input_ids'].cuda()
            
            # Encode text - handle both tuple and tensor returns
            text_output = model.encode_text(input_ids)
            if isinstance(text_output, tuple):
                text_feat = text_output[0]  # Take first element if tuple
            else:
                text_feat = text_output
            
            text_feat = F.normalize(text_feat, dim=-1)
            text_features.append(text_feat)
    
    text_features = torch.cat(text_features, dim=0)
    print(f"  ✅ Text features shape: {text_features.shape}")
    
    # Extract image features from test set
    print(f"\n  Extracting image features from {len(test_df)} test patients...")
    image_features = []
    true_labels = []
    
    from PIL import Image
    import torchvision.transforms as transforms
    
    # Image preprocessing
    preprocess = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    with torch.no_grad():
        for idx, row in tqdm(test_df.iterrows(), total=len(test_df), desc=f"  Processing"):
            patient_id = row['patient_id']
            
            # Load images
            left_img_path = f"{ODIR_IMAGES_DIR}/{patient_id}_left.jpg"
            right_img_path = f"{ODIR_IMAGES_DIR}/{patient_id}_right.jpg"
            
            if not os.path.exists(left_img_path) or not os.path.exists(right_img_path):
                continue
            
            left_img = Image.open(left_img_path).convert('RGB')
            right_img = Image.open(right_img_path).convert('RGB')
            
            left_tensor = preprocess(left_img).unsqueeze(0).cuda()
            right_tensor = preprocess(right_img).unsqueeze(0).cuda()
            
            # RET-CLIP encode_image takes BOTH images as arguments (binocular architecture)
            img_output = model.encode_image(left_tensor, right_tensor)
            
            # Handle both tuple and tensor returns
            if isinstance(img_output, tuple):
                img_feat = img_output[0]
            else:
                img_feat = img_output
            
            img_feat = F.normalize(img_feat, dim=-1)
            image_features.append(img_feat)
            
            # Get true label (use primary keyword from left or right)
            left_kw = str(row['left_keywords']).split(',')[0].strip() if pd.notna(row['left_keywords']) else ""
            right_kw = str(row['right_keywords']).split(',')[0].strip() if pd.notna(row['right_keywords']) else ""
            primary_keyword = left_kw if left_kw and left_kw in disease_classes else right_kw
            
            if primary_keyword in disease_classes:
                true_labels.append(disease_classes.index(primary_keyword))
            else:
                # Use first disease class as fallback
                true_labels.append(0)
    
    image_features = torch.cat(image_features, dim=0)
    print(f"  ✅ Image features shape: {image_features.shape}")
    
    # Compute zero-shot predictions
    print(f"\n  Computing zero-shot predictions...")
    with torch.no_grad():
        # Cosine similarity between images and texts
        similarity = (image_features @ text_features.T)  # [N_test, N_classes]
        predictions = similarity.argmax(dim=-1).cpu().numpy()
    
    true_labels = np.array(true_labels)
    
    # Compute metrics
    accuracy = accuracy_score(true_labels, predictions)
    f1_macro = f1_score(true_labels, predictions, average='macro', zero_division=0)
    f1_weighted = f1_score(true_labels, predictions, average='weighted', zero_division=0)
    
    print(f"\n  📊 Results for {encoder_name}:")
    print(f"     Accuracy: {accuracy * 100:.2f}%")
    print(f"     F1 (Macro): {f1_macro * 100:.2f}%")
    print(f"     F1 (Weighted): {f1_weighted * 100:.2f}%")
    
    # Store results
    all_results[encoder_name] = {
        "accuracy": accuracy,
        "f1_macro": f1_macro,
        "f1_weighted": f1_weighted,
        "predictions": predictions,
        "true_labels": true_labels,
        "model_name": model_name
    }
    
    # Save metrics to file
    metrics_path = f"{DRIVE_RESULTS}/zeroshot_metrics_{encoder_short_name}.json"
    with open(metrics_path, 'w') as f:
        json.dump({
            "encoder_name": encoder_name,
            "accuracy": float(accuracy),
            "f1_macro": float(f1_macro),
            "f1_weighted": float(f1_weighted),
            "num_test_samples": len(true_labels),
            "num_classes": len(disease_classes)
        }, f, indent=2)
    
    print(f"  ✅ Metrics saved to: {metrics_path}")
    
    # Clean up GPU memory
    del model
    del tokenizer
    torch.cuda.empty_cache()

print("\n" + "="*80)
print("🎉 ZERO-SHOT EVALUATION COMPLETE!")
print("="*80)
if RUN_TEXT_ENCODER_COMPARISON:
    print(f"\nResults Summary:")
    for encoder_name, results in all_results.items():
        print(f"\n  {encoder_name}:")
        print(f"    Accuracy: {results['accuracy'] * 100:.2f}%")
        print(f"    F1 (Macro): {results['f1_macro'] * 100:.2f}%")
        print(f"    F1 (Weighted): {results['f1_weighted'] * 100:.2f}%")

## Cell 7.3: Visualize Zero-Shot Confusion Matrices

Create confusion matrices for all evaluated models to visualize zero-shot classification performance.
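
As a minimal sketch of what the normalized matrices show, assuming integer class indices and a fixed number of classes: each row (true class) is divided by its total, so every cell reads as the fraction of that class assigned to each predicted label. The helper `row_normalized_confusion` is hypothetical and uses plain Python rather than scikit-learn; rows with no samples are left at zero instead of dividing by zero.

```python
def row_normalized_confusion(true_labels, predictions, num_classes):
    """Return (counts, normalized) confusion matrices as nested lists."""
    counts = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(true_labels, predictions):
        counts[t][p] += 1

    normalized = []
    for row in counts:
        total = sum(row)
        # A class absent from the test set keeps an all-zero row
        normalized.append([c / total if total else 0.0 for c in row])
    return counts, normalized


# Toy labels: three classes, class 2 never appears among the true labels
counts, normalized = row_normalized_confusion(
    true_labels=[0, 0, 0, 1, 1],
    predictions=[0, 0, 1, 1, 2],
    num_classes=3,
)
print(counts)
print(normalized)
```

Fixing the number of classes up front keeps the matrix the same shape as the tick-label list even when a class is missing from the test split.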

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Check if we have results to visualize
if not all_results:
    print("⚠️  No results to visualize!")
else:
    # Visualize confusion matrices for all models
    num_models = len(all_results)
    
    if num_models == 1:
        # Single model - larger single plot
        fig, ax = plt.subplots(1, 1, figsize=(12, 10))
        axes = [ax]
    else:
        # Multiple models - horizontal layout
        fig, axes = plt.subplots(1, num_models, figsize=(10 * num_models, 8))
    
    for idx, (encoder_name, results) in enumerate(all_results.items()):
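        # Fix the label set so the matrix always matches the disease_classes tick labels
        cm = confusion_matrix(
            results['true_labels'],
            results['predictions'],
            labels=list(range(len(disease_classes)))
        )
        
        # Normalize each row (true class); rows with no samples stay at zero
        cm_norm = cm.astype('float') / np.maximum(cm.sum(axis=1, keepdims=True), 1)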
        
        # Plot with smaller annotations and adjusted styling
        sns.heatmap(
            cm_norm,
            annot=True,
            fmt='.2f',
            cmap='Blues',
            xticklabels=disease_classes,
            yticklabels=disease_classes,
            ax=axes[idx],
            cbar_kws={'label': 'Normalized Count'},
            annot_kws={'fontsize': 7},  # Smaller annotation font
            vmin=0,
            vmax=1
        )
        
        # Title
        axes[idx].set_title(
            f'{encoder_name}\nZero-Shot Accuracy: {results["accuracy"]*100:.2f}%',
            fontsize=14,
            fontweight='bold',
            pad=15
        )
        
        # Axis labels
        axes[idx].set_xlabel('Predicted Disease', fontsize=11, fontweight='bold')
        axes[idx].set_ylabel('True Disease', fontsize=11, fontweight='bold')
        
        # Tick labels - rotate and adjust size
        axes[idx].set_xticklabels(
            axes[idx].get_xticklabels(),
            rotation=45,
            ha='right',
            fontsize=8
        )
        axes[idx].set_yticklabels(
            axes[idx].get_yticklabels(),
            rotation=0,
            fontsize=8
        )
    
    # Adjust layout to prevent label cutoff
    plt.tight_layout()
    
    # Save with high DPI for better quality
    plt.savefig(
        f'{DRIVE_RESULTS}/zeroshot_confusion_matrices.png',
        dpi=200,
        bbox_inches='tight',
        facecolor='white'
    )
    plt.show()
    
    print(f"\n✅ Confusion matrices saved to {DRIVE_RESULTS}/zeroshot_confusion_matrices.png")

# SECTION 8: Linear Probing Evaluation

Test the quality of learned visual representations by training a simple linear classifier (logistic regression) on frozen features.

Linear probing freezes the model weights and only trains a classifier on top, measuring how linearly separable the learned features are.
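
The idea can be sketched on a toy binary problem. Everything here is an illustrative assumption: `frozen_encoder` is a hypothetical fixed function standing in for the frozen RET-CLIP vision encoder, and the probe is a hand-written logistic regression rather than the scikit-learn classifier used later in this section. Only the probe's weights and bias are updated; the encoder is never touched.

```python
import math


def frozen_encoder(x):
    """Stand-in for a frozen image encoder: a fixed map that is never updated."""
    return [x[0] + x[1], x[0] - x[1]]


def train_linear_probe(features, labels, lr=0.5, steps=200):
    """Fit binary logistic regression (weights + bias) with gradient descent."""
    dim = len(features[0])
    n = len(features)
    w, b = [0.0] * dim, 0.0
    for _ in range(steps):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(features, labels):
            z = sum(wi * xi for wi, xi in zip(w, x)) + b
            p = 1.0 / (1.0 + math.exp(-z))
            for j in range(dim):
                grad_w[j] += (p - y) * x[j] / n
            grad_b += (p - y) / n
        w = [wi - lr * g for wi, g in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b


def predict(w, b, x):
    """Predict class 1 when the linear score is positive."""
    return 1 if sum(wi * xi for wi, xi in zip(w, x)) + b > 0 else 0


# Toy inputs (not fundus data): two well-separated clusters
train_inputs = [(-1.0, -0.8), (-0.9, -1.2), (-1.1, -1.0),
                (1.0, 0.9), (0.8, 1.1), (1.2, 1.0)]
train_labels = [0, 0, 0, 1, 1, 1]
test_inputs = [(-0.7, -0.9), (0.9, 0.7)]
test_labels = [0, 1]

# Features come from the frozen encoder; only the probe is trained
train_feats = [frozen_encoder(x) for x in train_inputs]
test_feats = [frozen_encoder(x) for x in test_inputs]
w, b = train_linear_probe(train_feats, train_labels)
test_preds = [predict(w, b, f) for f in test_feats]
accuracy = sum(p == t for p, t in zip(test_preds, test_labels)) / len(test_labels)
print(f"Linear probe accuracy: {accuracy:.2f}")
```

High probe accuracy indicates that the classes are already linearly separable in the frozen feature space, which is the property linear probing is meant to measure.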

## Cell 8.1: Extract Features from Train and Test Sets

Extract frozen image features from all trained models to use for linear classifier training.
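
Alongside each feature vector, the extraction cell assigns one class index per patient from the diagnostic keywords. A minimal sketch of that rule follows; the function name `primary_label` and the three-class list are assumptions for illustration (the notebook's actual `disease_classes` list is defined elsewhere), and `None` stands in for missing metadata.

```python
def primary_label(left_keywords, right_keywords, classes):
    """Map a patient's keyword strings to a class index (fallback: 0)."""
    def first_keyword(text):
        # Missing metadata is treated as an empty keyword
        if text is None:
            return ""
        return str(text).split(',')[0].strip()

    left_kw = first_keyword(left_keywords)
    right_kw = first_keyword(right_keywords)

    # Prefer the left eye's first keyword when it is a known class
    keyword = left_kw if left_kw and left_kw in classes else right_kw

    # Unknown keywords fall back to index 0, which adds noise to that class
    return classes.index(keyword) if keyword in classes else 0


# Hypothetical class list for illustration
classes = ["normal fundus", "cataract", "glaucoma"]
print(primary_label("cataract, lens dust", "normal fundus", classes))
print(primary_label("lens dust", "glaucoma", classes))
print(primary_label(None, "drusen", classes))
```

Because unmatched keywords are mapped to index 0, metrics for the first class should be read with that fallback in mind.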

In [None]:
import json
import os

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from tqdm.notebook import tqdm

# Image preprocessing
preprocess = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Helper function to extract features from a dataset
def extract_features(model, dataframe, desc="Extracting features"):
    """Extract image features from a dataframe of patients."""
    features = []
    labels = []
    
    with torch.no_grad():
        for idx, row in tqdm(dataframe.iterrows(), total=len(dataframe), desc=desc):
            patient_id = row['patient_id']
            
            # Load images
            left_img_path = f"{ODIR_IMAGES_DIR}/{patient_id}_left.jpg"
            right_img_path = f"{ODIR_IMAGES_DIR}/{patient_id}_right.jpg"
            
            if not os.path.exists(left_img_path) or not os.path.exists(right_img_path):
                continue
            
            left_img = Image.open(left_img_path).convert('RGB')
            right_img = Image.open(right_img_path).convert('RGB')
            
            left_tensor = preprocess(left_img).unsqueeze(0).cuda()
            right_tensor = preprocess(right_img).unsqueeze(0).cuda()
            
            # Encode binocular pair (RET-CLIP binocular architecture expects both images)
            img_output = model.encode_image(left_tensor, right_tensor)
            
            # Handle tuple return (some models return tuple of features)
            if isinstance(img_output, tuple):
                img_feat = img_output[0]
            else:
                img_feat = img_output
            
            # Normalize features
            img_feat = F.normalize(img_feat, dim=-1)
            
            features.append(img_feat.cpu())
            
            # Get true label (use primary keyword)
            left_kw = str(row['left_keywords']).split(',')[0].strip() if pd.notna(row['left_keywords']) else ""
            right_kw = str(row['right_keywords']).split(',')[0].strip() if pd.notna(row['right_keywords']) else ""
            primary_keyword = left_kw if left_kw and left_kw in disease_classes else right_kw
            
            if primary_keyword in disease_classes:
                labels.append(disease_classes.index(primary_keyword))
            else:
                labels.append(0)
    
    features = torch.cat(features, dim=0).numpy()
    labels = np.array(labels)
    
    return features, labels

# Store features for all models
all_linear_probe_data = {}

# Determine which models to evaluate
if RUN_TEXT_ENCODER_COMPARISON:
    models_to_evaluate = TEXT_ENCODERS
else:
    models_to_evaluate = [{
        "name": "PubMedBERT",
        "model_id": TEXT_MODEL,
        "description": "Default text encoder"
    }]

for encoder_config in models_to_evaluate:
    encoder_name = encoder_config["name"]
    encoder_model_id = encoder_config["model_id"]
    encoder_short_name = encoder_name.lower().replace('-', '').replace(' ', '')
    model_name = f"retclip_odir_{encoder_short_name}"
    
    print("\n" + "="*80)
    print(f"Extracting features for: {encoder_name}")
    print("="*80)
    
    # Load model
    from RET_CLIP.clip.model import CLIP
    
    vision_config_path = f"/content/retclip/RET_CLIP/clip/model_configs/{VISION_MODEL}.json"
    text_config_path = f"/content/retclip/RET_CLIP/clip/model_configs/{encoder_model_id}.json"
    
    with open(vision_config_path, 'r') as fv, open(text_config_path, 'r') as ft:
        model_cfg = json.load(fv)
        for k, v in json.load(ft).items():
            model_cfg[k] = v
    
    model = CLIP(**model_cfg)
    
    checkpoint_path = f"{DRIVE_CHECKPOINTS}/{model_name}/checkpoints/epoch_latest.pt"
    print(f"  Loading: {checkpoint_path}")
    
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    
    # Handle DDP state dict
    state_dict = checkpoint['state_dict']
    new_state_dict = {}
    for k, v in state_dict.items():
        name = k.replace('module.', '')
        new_state_dict[name] = v
    
    model.load_state_dict(new_state_dict)
    model = model.cuda()
    model.eval()
    
    print("  ✅ Model loaded")
    
    # Extract train features
    print(f"\n  Extracting train features from {len(train_df)} patients...")
    train_features, train_labels = extract_features(model, train_df, desc=f"  Train")
    print(f"  ✅ Train features: {train_features.shape}")
    
    # Extract test features
    print(f"\n  Extracting test features from {len(test_df)} patients...")
    test_features, test_labels = extract_features(model, test_df, desc=f"  Test")
    print(f"  ✅ Test features: {test_features.shape}")
    
    # Store data
    all_linear_probe_data[encoder_name] = {
        "train_features": train_features,
        "train_labels": train_labels,
        "test_features": test_features,
        "test_labels": test_labels,
        "model_name": model_name
    }
    
    # Clean up GPU memory
    del model
    torch.cuda.empty_cache()

print("\n" + "="*80)
print("✅ Feature extraction complete for all models")
print("="*80)

## Cell 8.2: Train Logistic Regression Classifiers

Train a linear classifier on frozen features for each model and evaluate performance.
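
The cell reports accuracy, macro F1, and weighted F1. As a worked sketch on toy labels (not results from this pipeline), the difference between the two F1 averages is that macro F1 weights every class equally, while weighted F1 weights each class by its number of true samples. The helper `per_class_f1` is hypothetical and mirrors the `zero_division=0` convention by scoring an empty class as 0.

```python
def per_class_f1(true_labels, predictions, num_classes):
    """Return per-class F1 scores and supports (true sample counts)."""
    scores, supports = [], []
    for c in range(num_classes):
        pairs = list(zip(true_labels, predictions))
        tp = sum(t == c and p == c for t, p in pairs)
        fp = sum(t != c and p == c for t, p in pairs)
        fn = sum(t == c and p != c for t, p in pairs)
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); an empty class scores 0
        scores.append(2 * tp / denom if denom else 0.0)
        supports.append(tp + fn)
    return scores, supports


# Toy labels: class 0 has four samples, class 1 has two
true_labels = [0, 0, 0, 0, 1, 1]
predictions = [0, 0, 0, 1, 1, 0]

scores, supports = per_class_f1(true_labels, predictions, num_classes=2)
f1_macro = sum(scores) / len(scores)
f1_weighted = sum(s * n for s, n in zip(scores, supports)) / sum(supports)
print(scores, supports)
print(f"macro={f1_macro:.4f} weighted={f1_weighted:.4f}")
```

On an imbalanced dataset the weighted score is pulled toward the majority class, so a gap between the two averages points to weaker performance on rare classes.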

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, classification_report
import json

# Store linear probe results
all_linear_probe_results = {}

for encoder_name, data in all_linear_probe_data.items():
    encoder_short_name = encoder_name.lower().replace('-', '').replace(' ', '')
    
    print("\n" + "="*80)
    print(f"Training linear classifier for: {encoder_name}")
    print("="*80)
    
    # Train logistic regression
    print("  Training logistic regression...")
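    # lbfgs already fits a multinomial model for multiclass targets; the explicit
    # multi_class argument is deprecated in recent scikit-learn releases
    clf = LogisticRegression(
        max_iter=1000,
        random_state=42,
        solver='lbfgs',
        n_jobs=-1
    )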
    
    clf.fit(data['train_features'], data['train_labels'])
    print("  ✅ Training complete")
    
    # Predict on test set
    print("  Evaluating on test set...")
    test_predictions = clf.predict(data['test_features'])
    
    # Compute metrics
    accuracy = accuracy_score(data['test_labels'], test_predictions)
    f1_macro = f1_score(data['test_labels'], test_predictions, average='macro', zero_division=0)
    f1_weighted = f1_score(data['test_labels'], test_predictions, average='weighted', zero_division=0)
    
    print(f"\n  📊 Linear Probe Results for {encoder_name}:")
    print(f"     Accuracy: {accuracy * 100:.2f}%")
    print(f"     F1 (Macro): {f1_macro * 100:.2f}%")
    print(f"     F1 (Weighted): {f1_weighted * 100:.2f}%")
    
    # Store results
    all_linear_probe_results[encoder_name] = {
        "accuracy": accuracy,
        "f1_macro": f1_macro,
        "f1_weighted": f1_weighted,
        "predictions": test_predictions,
        "true_labels": data['test_labels'],
        "classifier": clf
    }
    
    # Save metrics
    metrics_path = f"{DRIVE_RESULTS}/linear_probe_metrics_{encoder_short_name}.json"
    with open(metrics_path, 'w') as f:
        json.dump({
            "encoder_name": encoder_name,
            "accuracy": float(accuracy),
            "f1_macro": float(f1_macro),
            "f1_weighted": float(f1_weighted),
            "num_train_samples": len(data['train_labels']),
            "num_test_samples": len(data['test_labels']),
            "num_classes": len(disease_classes)
        }, f, indent=2)
    
    print(f"  ✅ Metrics saved to: {metrics_path}")

print("\n" + "="*80)
print("🎉 LINEAR PROBING COMPLETE!")
print("="*80)
if RUN_TEXT_ENCODER_COMPARISON:
    print(f"\nResults Summary:")
    for encoder_name, results in all_linear_probe_results.items():
        print(f"\n  {encoder_name}:")
        print(f"    Accuracy: {results['accuracy'] * 100:.2f}%")
        print(f"    F1 (Macro): {results['f1_macro'] * 100:.2f}%")
        print(f"    F1 (Weighted): {results['f1_weighted'] * 100:.2f}%")

## Cell 8.3: Visualize Linear Probing Confusion Matrices

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Check if we have results to visualize
if not all_linear_probe_results:
    print("⚠️  No results to visualize!")
else:
    # Create confusion matrices for all models
    num_models = len(all_linear_probe_results)
    
    if num_models == 1:
        # Single model - larger single plot
        fig, ax = plt.subplots(1, 1, figsize=(12, 10))
        axes = [ax]
    else:
        # Multiple models - horizontal layout
        fig, axes = plt.subplots(1, num_models, figsize=(10 * num_models, 8))
    
    for idx, (encoder_name, results) in enumerate(all_linear_probe_results.items()):
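        # Compute confusion matrix with a fixed label set so it matches the tick labels
        cm = confusion_matrix(
            results['true_labels'],
            results['predictions'],
            labels=list(range(len(disease_classes)))
        )
        
        # Normalize by row (true labels); rows with no samples stay at zero
        cm_norm = cm.astype('float') / np.maximum(cm.sum(axis=1, keepdims=True), 1)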
        
        # Plot with smaller annotations and adjusted styling
        sns.heatmap(
            cm_norm,
            annot=True,
            fmt='.2f',
            cmap='Greens',
            xticklabels=disease_classes,
            yticklabels=disease_classes,
            ax=axes[idx],
            cbar_kws={'label': 'Normalized Count'},
            annot_kws={'fontsize': 7},  # Smaller annotation font
            vmin=0,
            vmax=1
        )
        
        # Title
        axes[idx].set_title(
            f'{encoder_name}\nLinear Probe Accuracy: {results["accuracy"]*100:.2f}%',
            fontsize=14,
            fontweight='bold',
            pad=15
        )
        
        # Axis labels
        axes[idx].set_xlabel('Predicted Disease', fontsize=11, fontweight='bold')
        axes[idx].set_ylabel('True Disease', fontsize=11, fontweight='bold')
        
        # Tick labels - rotate and adjust size
        axes[idx].set_xticklabels(
            axes[idx].get_xticklabels(),
            rotation=45,
            ha='right',
            fontsize=8
        )
        axes[idx].set_yticklabels(
            axes[idx].get_yticklabels(),
            rotation=0,
            fontsize=8
        )
    
    # Adjust layout to prevent label cutoff
    plt.tight_layout()
    
    # Save with high DPI for better quality
    plt.savefig(
        f'{DRIVE_RESULTS}/linear_probe_confusion_matrices.png',
        dpi=200,
        bbox_inches='tight',
        facecolor='white'
    )
    plt.show()
    
    print(f"\n✅ Confusion matrices saved to {DRIVE_RESULTS}/linear_probe_confusion_matrices.png")
    
    # Also save individual confusion matrices for each model
    for encoder_name, results in all_linear_probe_results.items():
        encoder_short_name = encoder_name.lower().replace('-', '').replace(' ', '')
        
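        # Compute confusion matrix with a fixed label set; empty rows stay at zero
        cm = confusion_matrix(
            results['true_labels'],
            results['predictions'],
            labels=list(range(len(disease_classes)))
        )
        cm_norm = cm.astype('float') / np.maximum(cm.sum(axis=1, keepdims=True), 1)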
        
        # Create single plot
        fig_single, ax_single = plt.subplots(figsize=(12, 10))
        sns.heatmap(
            cm_norm,
            annot=True,
            fmt='.2f',
            cmap='Greens',
            xticklabels=disease_classes,
            yticklabels=disease_classes,
            ax=ax_single,
            cbar_kws={'label': 'Normalized Count'},
            annot_kws={'fontsize': 7},
            vmin=0,
            vmax=1
        )
        
        ax_single.set_title(
            f'Linear Probe: {encoder_name}\nAccuracy: {results["accuracy"]*100:.2f}%',
            fontsize=14,
            fontweight='bold',
            pad=15
        )
        ax_single.set_xlabel('Predicted Disease', fontsize=11, fontweight='bold')
        ax_single.set_ylabel('True Disease', fontsize=11, fontweight='bold')
        
        ax_single.set_xticklabels(
            ax_single.get_xticklabels(),
            rotation=45,
            ha='right',
            fontsize=8
        )
        ax_single.set_yticklabels(
            ax_single.get_yticklabels(),
            rotation=0,
            fontsize=8
        )
        
        plt.tight_layout()
        
        cm_path = f"{DRIVE_RESULTS}/linear_probe_cm_{encoder_short_name}.png"
        fig_single.savefig(cm_path, dpi=200, bbox_inches='tight', facecolor='white')
        plt.close(fig_single)
        
        print(f"  ✅ Saved: {cm_path}")

# SECTION 9: Final Report

Generate a comprehensive report comparing all trained models and summarizing the entire pipeline.

## Cell 9.1: Text Encoder Comparison Table

Compare zero-shot and linear probing performance across all text encoders.
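
A minimal sketch of how the two result dictionaries can be joined into table rows, assuming each maps an encoder name to a dict with an `accuracy` entry. The helper `build_comparison_rows` and the encoder names are hypothetical, and the numbers are placeholders for illustration only, not measured results. Encoders missing from either dictionary are skipped rather than raising a `KeyError`.

```python
def build_comparison_rows(zero_shot, linear_probe):
    """Build one table row per encoder present in both result dicts."""
    rows = []
    for name in zero_shot:
        if name not in linear_probe:
            continue  # no linear-probe run for this encoder
        rows.append({
            "Text Encoder": name,
            "Zero-Shot Accuracy": f"{zero_shot[name]['accuracy'] * 100:.2f}%",
            "Linear Probe Accuracy": f"{linear_probe[name]['accuracy'] * 100:.2f}%",
        })
    return rows


# Placeholder values, not results from this pipeline
zero_shot = {"EncoderA": {"accuracy": 0.5}, "EncoderB": {"accuracy": 0.25}}
linear_probe = {"EncoderA": {"accuracy": 0.75}}

rows = build_comparison_rows(zero_shot, linear_probe)
best = max(zero_shot.items(), key=lambda item: item[1]["accuracy"])
print(rows)
print(f"Best zero-shot encoder: {best[0]}")
```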

In [None]:
import pandas as pd

# Create comparison table
comparison_data = []

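# Only compare encoders that have both zero-shot and linear-probe results
for encoder_name in all_results.keys():
    if encoder_name not in all_linear_probe_results:
        continue
    zs_results = all_results[encoder_name]
    lp_results = all_linear_probe_results[encoder_name]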
    
    comparison_data.append({
        "Text Encoder": encoder_name,
        "Zero-Shot Accuracy": f"{zs_results['accuracy'] * 100:.2f}%",
        "Zero-Shot F1 (Macro)": f"{zs_results['f1_macro'] * 100:.2f}%",
        "Zero-Shot F1 (Weighted)": f"{zs_results['f1_weighted'] * 100:.2f}%",
        "Linear Probe Accuracy": f"{lp_results['accuracy'] * 100:.2f}%",
        "Linear Probe F1 (Macro)": f"{lp_results['f1_macro'] * 100:.2f}%",
        "Linear Probe F1 (Weighted)": f"{lp_results['f1_weighted'] * 100:.2f}%"
    })

comparison_df = pd.DataFrame(comparison_data)

print("="*120)
print("TEXT ENCODER COMPARISON - ZERO-SHOT VS LINEAR PROBING")
print("="*120)
print()
display(comparison_df)
print()

# Find best performers
zs_best = max(all_results.items(), key=lambda x: x[1]['accuracy'])
lp_best = max(all_linear_probe_results.items(), key=lambda x: x[1]['accuracy'])

print(f"🏆 Best Zero-Shot Performance: {zs_best[0]} ({zs_best[1]['accuracy'] * 100:.2f}%)")
print(f"🏆 Best Linear Probe Performance: {lp_best[0]} ({lp_best[1]['accuracy'] * 100:.2f}%)")

# Save comparison table
comparison_path = f"{DRIVE_RESULTS}/text_encoder_comparison.csv"
comparison_df.to_csv(comparison_path, index=False)
print(f"\n✅ Comparison table saved to: {comparison_path}")

## Cell 9.2: Generate Comprehensive Final Report

In [None]:
import os
from pathlib import Path

print("="*120)
print("VERIFYING ALL ARTIFACTS")
print("="*120)

# Expected artifacts
artifacts = {
    "Prompts": [
        f"{DRIVE_PROMPTS}/odir_retclip_prompts.csv"
    ],
    "Data Splits": [
        f"{DRIVE_DATA}/train_patients.csv",
        f"{DRIVE_DATA}/test_patients.csv",
        f"{DRIVE_DATA}/train_imgs.tsv",
        f"{DRIVE_DATA}/test_imgs.tsv",
        f"{DRIVE_DATA}/train_texts.jsonl",
        f"{DRIVE_DATA}/test_texts.jsonl"
    ],
    "LMDB Databases": [
        f"{DRIVE_LMDB}/train/imgs",
        f"{DRIVE_LMDB}/test/imgs"
    ],
    "Model Checkpoints": [],
    "Results & Metrics": [
        f"{DRIVE_RESULTS}/odir_dataset_statistics.png",
        f"{DRIVE_RESULTS}/text_encoder_comparison.csv",
        f"{DRIVE_RESULTS}/final_report.txt"
    ]
}

# Add checkpoints for all trained models
for encoder_config in (TEXT_ENCODERS if RUN_TEXT_ENCODER_COMPARISON else [{"name": "PubMedBERT"}]):
    encoder_short_name = encoder_config['name'].lower().replace('-', '').replace(' ', '')
    artifacts["Model Checkpoints"].append(
        f"{DRIVE_CHECKPOINTS}/retclip_odir_{encoder_short_name}/checkpoints/epoch_latest.pt"
    )
    # Per-model outputs; the linear-probe plot name matches what Cell 8.3 saves
    artifacts["Results & Metrics"].extend([
        f"{DRIVE_RESULTS}/zeroshot_metrics_{encoder_short_name}.json",
        f"{DRIVE_RESULTS}/zeroshot_confusion_matrix_{encoder_short_name}.png",
        f"{DRIVE_RESULTS}/linear_probe_metrics_{encoder_short_name}.json",
        f"{DRIVE_RESULTS}/linear_probe_cm_{encoder_short_name}.png"
    ])

# Combined confusion-matrix figures (saved by Cells 7.3 and 8.3 for any number of models)
artifacts["Results & Metrics"].extend([
    f"{DRIVE_RESULTS}/zeroshot_confusion_matrices.png",
    f"{DRIVE_RESULTS}/linear_probe_confusion_matrices.png"
])

# Check each artifact
total_artifacts = 0
found_artifacts = 0
missing_artifacts = []

for category, paths in artifacts.items():
    print(f"\n{category}:")
    for path in paths:
        total_artifacts += 1
        if os.path.exists(path):
            # Get file size
            if os.path.isfile(path):
                size_mb = os.path.getsize(path) / (1024 * 1024)
                print(f"  ✅ {os.path.basename(path)} ({size_mb:.2f} MB)")
                found_artifacts += 1
            else:
                # Directory (LMDB)
                print(f"  ✅ {os.path.basename(path)} (directory)")
                found_artifacts += 1
        else:
            print(f"  ❌ {os.path.basename(path)} - NOT FOUND")
            missing_artifacts.append(path)

print("\n" + "="*120)
print(f"ARTIFACT SUMMARY: {found_artifacts}/{total_artifacts} found")
print("="*120)

if missing_artifacts:
    print(f"\n⚠️ Missing {len(missing_artifacts)} artifact(s):")
    for path in missing_artifacts:
        print(f"  - {path}")
else:
    print("\n✅ All artifacts successfully created and saved to Google Drive!")
    print(f"\n📁 Base directory: {DRIVE_BASE}")
    print(f"\n🎉 Pipeline complete! Ready for analysis and publication.")

print("\n" + "="*120)
print("NEXT STEPS")
print("="*120)
if TEST_MODE:
    print("""
1. ✅ Review test results to ensure pipeline works correctly
2. ⚠️ Set TEST_MODE = False in Cell 1.5 for full training
3. 🚀 Run full pipeline on all 5,000 patients (~18-24 hours)
4. 📊 Analyze final results and write research paper
5. 📄 Target venues: MICCAI, IEEE TMI, or similar
""")
else:
    print("""
1. ✅ Full pipeline complete on all patients
2. 📊 Analyze results and create visualizations
3. 📝 Write research paper draft
4. 🔬 Consider additional experiments:
   - Fine-tuning on downstream tasks
   - Ablation studies
   - Cross-dataset validation
5. 📄 Submit to target venue (MICCAI, IEEE TMI, etc.)
""")

## Cell 9.3: List All Artifacts

Verify all output files were successfully created and saved to Google Drive.
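
The check itself is a lookup over a category-to-paths mapping. A minimal, self-contained sketch follows, assuming a hypothetical helper `check_artifacts` and throwaway file names created in a temporary directory (none of these are the pipeline's real paths):

```python
import tempfile
from pathlib import Path


def check_artifacts(artifacts):
    """Return (found, missing) path lists for a {category: [paths]} mapping."""
    found, missing = [], []
    for category, paths in artifacts.items():
        for path in paths:
            # Files and directories (such as LMDB folders) both count as present
            (found if Path(path).exists() else missing).append(str(path))
    return found, missing


with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp)
    (base / "metrics.json").write_text("{}")
    (base / "lmdb_train").mkdir()

    artifacts = {
        "Results": [base / "metrics.json", base / "missing_plot.png"],
        "LMDB": [base / "lmdb_train"],
    }
    found, missing = check_artifacts(artifacts)
    print(f"{len(found)}/{len(found) + len(missing)} found")
    for path in missing:
        print(f"  missing: {path}")
```

Returning the missing paths, rather than only a count, makes it easier to see which upstream cell needs to be rerun.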

In [None]:
# List all expected artifacts and verify they exist
import os

print("Verifying Output Artifacts:")
print("=" * 120)

# Define expected artifacts by category
artifacts = {
    "Prompts": [
        f"{DRIVE_PROMPTS}/odir_retclip_prompts.csv"
    ],
    "Data Splits": [
        f"{DRIVE_DATA}/train_patients.csv",
        f"{DRIVE_DATA}/test_patients.csv",
        f"{DRIVE_DATA}/odir_train_imgs.tsv",
        f"{DRIVE_DATA}/odir_test_imgs.tsv",
        f"{DRIVE_DATA}/odir_train_texts.jsonl",
        f"{DRIVE_DATA}/odir_test_texts.jsonl"
    ],
    "LMDB Databases": [
        f"{DRIVE_LMDB}/train",
        f"{DRIVE_LMDB}/test"
    ],
    "Model Checkpoints": [],
    "Results & Metrics": [
        f"{DRIVE_RESULTS}/odir_dataset_statistics.png",
        f"{DRIVE_RESULTS}/zeroshot_confusion_matrices.png",
        f"{DRIVE_RESULTS}/linear_probe_confusion_matrices.png",
        f"{DRIVE_RESULTS}/final_report.txt"
    ]
}

# Add checkpoints for all trained models
if RUN_TEXT_ENCODER_COMPARISON:
    encoders_list = TEXT_ENCODERS
else:
    encoders_list = [{"name": "PubMedBERT", "model_id": TEXT_MODEL}]

for encoder_config in encoders_list:
    encoder_short_name = encoder_config['name'].lower().replace('-', '').replace(' ', '')
    model_name = f"retclip_odir_{encoder_short_name}"
    artifacts["Model Checkpoints"].append(
        f"{DRIVE_CHECKPOINTS}/{model_name}/checkpoints/epoch_latest.pt"
    )

# Verify each artifact
total_artifacts = 0
found_artifacts = 0

for category, paths in artifacts.items():
    print(f"\n{category}:")
    print("-" * 120)
    
    for path in paths:
        total_artifacts += 1
        exists = os.path.exists(path)
        found_artifacts += exists
        
        status = "✅" if exists else "❌"
        print(f"  {status} {path}")

print("\n" + "=" * 120)
print(f"Artifacts Summary: {found_artifacts}/{total_artifacts} found")

if found_artifacts == total_artifacts:
    print("✅ All artifacts successfully created!")
else:
    print(f"⚠️  {total_artifacts - found_artifacts} artifacts missing - check for errors above")