# ODIR RET-CLIP Pipeline Diagnostics

This notebook verifies data quality and identifies potential issues while training runs.

**Upload this to Google Colab and run it while training continues in the main notebook.**

## Setup

In [None]:
# Mount Google Drive at /content/drive; every path configured below lives under it.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Configuration - MUST match your training notebook
DRIVE_BASE = "/content/drive/MyDrive/RET-CLIP-ODIR"

# Each artifact folder sits directly under DRIVE_BASE.
DRIVE_DATA, DRIVE_PROMPTS, DRIVE_LMDB = (
    f"{DRIVE_BASE}/{folder}" for folder in ("data", "prompts", "lmdb")
)

# Echo the resolved paths so a mismatch with the training notebook is easy to spot.
print("\n".join([
    f"Base directory: {DRIVE_BASE}",
    f"Data directory: {DRIVE_DATA}",
    f"Prompts directory: {DRIVE_PROMPTS}",
    f"LMDB directory: {DRIVE_LMDB}",
]))

In [None]:
# Helper function (same as training notebook)
import re
import pandas as pd

# Keyword separators: ASCII comma or fullwidth/Chinese comma (U+FF0C).
# The \uff0c escape keeps the pattern correct regardless of how this file is
# encoded; a literal character here is easily garbled into unrelated letters.
_KEYWORD_SEPARATORS = re.compile('[,\uff0c]')

def get_primary_disease(keywords_str):
    """Extract the first disease keyword, splitting on standard and Chinese commas.

    Args:
        keywords_str: Raw keyword cell, e.g. "cataract, glaucoma". NaN/None allowed.

    Returns:
        The first non-empty keyword, stripped and lower-cased, skipping literal
        "nan" tokens; "" when there is none.
    """
    if pd.isna(keywords_str):
        return ""

    for kw in _KEYWORD_SEPARATORS.split(str(keywords_str)):
        kw = kw.strip().lower()
        if kw and kw != 'nan':
            return kw
    return ""

## 1. JSONL `eye_side` Field Verification ⭐ CRITICAL

The tripartite loss depends on the `eye_side` field being correct.

In [None]:
import json

print("="*80)
print("CHECKING JSONL FORMAT - TRAIN SET")
print("="*80)

# Read as UTF-8 explicitly (prompt text may be non-ASCII) and skip blank lines,
# which are not JSON records and would make json.loads raise.
with open(f"{DRIVE_DATA}/train_texts.jsonl", 'r', encoding='utf-8') as f:
    lines = [line for line in f if line.strip()]

# Parse every record once; the preview and the validation below both reuse it.
entries = [json.loads(line) for line in lines]

print(f"\nTotal entries: {len(entries)}")
print(f"Expected: {len(entries) // 3} patients (3 entries each)\n")
if len(entries) % 3:
    print(f"⚠️  {len(entries)} entries is not a multiple of 3 - some patients are missing an entry\n")

# Preview the first 3 patients (9 entries). Assumes each patient's 3 entries
# are stored consecutively, as the grouping by i % 3 implies.
for i, entry in enumerate(entries[:9]):
    if i % 3 == 0:
        # `or` also covers an empty image_ids list, where [0] would raise
        patient_label = (entry.get('image_ids') or ['?'])[0]
        print(f"\n{'='*80}")
        print(f"Patient: {patient_label}")
        print(f"{'='*80}")

    print(f"\nEntry {i % 3 + 1}/3:")
    print(f"  text_id: {entry.get('text_id')}")
    print(f"  eye_side: {entry.get('eye_side', '❌ MISSING!')}")
    print(f"  text: {(entry.get('text') or '')[:100]}...")

# Verify all entries have eye_side field
print(f"\n{'='*80}")
print("VALIDATION")
print("="*80)

VALID_EYE_SIDES = ('left', 'right', 'both')
missing_eye_side = 0
eye_side_counts = {side: 0 for side in (*VALID_EYE_SIDES, 'other')}

for entry in entries:
    eye_side = entry.get('eye_side')
    if not eye_side:
        missing_eye_side += 1
    elif eye_side in VALID_EYE_SIDES:
        eye_side_counts[eye_side] += 1
    else:
        # any non-empty value other than left/right/both
        eye_side_counts['other'] += 1

if missing_eye_side > 0:
    print(f"\n❌ CRITICAL: {missing_eye_side} entries missing eye_side field!")
else:
    print(f"\n✅ All {len(entries)} entries have eye_side field")

print(f"\nEye side distribution:")
for side, count in eye_side_counts.items():
    if count > 0:
        print(f"  {side:10s}: {count:5d} ({count/len(entries)*100:.1f}%)")

expected_each = len(entries) // 3
# Also require a non-empty file whose size is an exact multiple of 3: with e.g.
# 10 entries, 3/3/3 plus one stray entry would otherwise count as balanced.
is_balanced = (
    len(entries) > 0
    and len(entries) % 3 == 0
    and all(eye_side_counts[side] == expected_each for side in VALID_EYE_SIDES)
)
if is_balanced:
    print(f"\n✅ Perfect balance: {expected_each} entries for each eye_side")
else:
    print(f"\n⚠️  Imbalance detected! Expected {expected_each} for each eye_side")

## 2. Prompt Quality Check

Inspect actual prompts to ensure they're clinically meaningful.

In [None]:
print("="*80)
print("PROMPT QUALITY CHECK")
print("="*80)

prompts_df = pd.read_csv(f"{DRIVE_PROMPTS}/odir_retclip_prompts.csv")

print(f"\nTotal patients with prompts: {len(prompts_df)}")
print(f"Columns: {list(prompts_df.columns)}")

# Sample prompts from different disease categories
diseases_to_check = ['normal fundus', 'moderate non proliferative retinopathy', 'glaucoma']

for disease in diseases_to_check:
    # regex=False: disease names are literal keywords, not regex patterns
    has_disease = (
        prompts_df['left_keywords'].str.contains(disease, na=False, case=False, regex=False)
        | prompts_df['right_keywords'].str.contains(disease, na=False, case=False, regex=False)
    )
    matches = prompts_df[has_disease]

    if len(matches) > 0:
        sample = matches.iloc[0]
        print(f"\n{'='*80}")
        print(f"Disease: {disease}")
        print(f"Patient: {sample['patient_id']}, Age: {sample['age']}, Sex: {sample['sex']}")
        print(f"Left keywords: {sample['left_keywords']}")
        print(f"Right keywords: {sample['right_keywords']}")
        print(f"\n📝 Left eye prompt:")
        print(f"{sample['prompt_left']}")
        print(f"\n📝 Right eye prompt:")
        print(f"{sample['prompt_right']}")
        print(f"\n📝 Patient-level prompt:")
        print(f"{sample['prompt_patient']}")
    else:
        print(f"\n⚠️  No examples found for '{disease}'")

# Check for empty or very short prompts
print(f"\n{'='*80}")
print("PROMPT LENGTH ANALYSIS")
print("="*80)

MIN_PROMPT_CHARS = 50  # prompts shorter than this are flagged for manual review

for col in ['prompt_left', 'prompt_right', 'prompt_patient']:
    # pandas reads empty CSV cells as NaN, and .str.len() keeps them NaN, so they
    # would silently fail the `<` comparison below; count them as length 0.
    lengths = prompts_df[col].fillna('').astype(str).str.len()
    print(f"\n{col}:")
    print(f"  Mean: {lengths.mean():.1f} chars")
    print(f"  Min: {lengths.min()} chars")
    print(f"  Max: {lengths.max()} chars")

    empty = (lengths == 0).sum()
    if empty > 0:
        print(f"  ❌ {empty} prompts are empty")

    very_short = (lengths < MIN_PROMPT_CHARS).sum()
    if very_short > 0:
        print(f"  ⚠️  {very_short} prompts are suspiciously short (<{MIN_PROMPT_CHARS} chars)")

## 3. LMDB Data Integrity

Verify that the LMDB pair and image counts match the training split, and preview a few stored pairs.

In [None]:
# Install lmdb if needed
!pip install -q lmdb

In [None]:
import contextlib
import lmdb
import pickle

def _open_readonly(path):
    """Open an existing LMDB environment for reading.

    create=False makes a wrong path raise lmdb.Error instead of creating it.
    """
    return lmdb.open(path, readonly=True, create=False)

def _read_count(txn, key):
    """Return the integer metadata value stored under `key` (bytes).

    Raises:
        KeyError: if the key is absent, rather than an opaque AttributeError
            from calling .decode() on None.
    """
    raw = txn.get(key)
    if raw is None:
        raise KeyError(f"LMDB metadata key {key!r} not found")
    return int(raw.decode('utf-8'))

# Expected: 3 texts per patient. Read the patient list before opening LMDB so
# no file I/O happens inside a transaction.
train_patients = len(pd.read_csv(f"{DRIVE_DATA}/train_patients.csv"))
expected = train_patients * 3

print("="*80)
print("LMDB INTEGRITY CHECK - PAIRS")
print("="*80)

# closing() guarantees env.close() even if a check below raises.
with contextlib.closing(_open_readonly(f"{DRIVE_LMDB}/train/pairs")) as env:
    with env.begin() as txn:
        num_samples = _read_count(txn, b'num_samples')
        print(f"\nTotal LMDB pairs: {num_samples}")
        print(f"Expected pairs: {expected} (3 per patient × {train_patients} patients)")

        if num_samples == expected:
            print(f"✅ LMDB pair count matches expectation")
        else:
            print(f"⚠️  Mismatch! Difference: {abs(num_samples - expected)}")

        # Check first 5 entries
        print(f"\n{'='*80}")
        print("SAMPLE ENTRIES")
        print("="*80)

        for i in range(min(5, num_samples)):
            data = txn.get(f"{i}".encode('utf-8'))
            if data is None:
                print(f"\nPair {i}: ❌ key missing from LMDB")
                continue
            # NOTE: pickle.loads can execute arbitrary code; only use this on
            # LMDB files you built yourself.
            patient_id, text_id, text = pickle.loads(data)
            print(f"\nPair {i}:")
            print(f"  Patient ID: {patient_id}")
            print(f"  Text ID: {text_id}")
            print(f"  Text length: {len(text)} chars")
            print(f"  Text preview: {text[:100]}...")

# Check images LMDB
print(f"\n{'='*80}")
print("LMDB INTEGRITY CHECK - IMAGES")
print("="*80)

with contextlib.closing(_open_readonly(f"{DRIVE_LMDB}/train/imgs")) as env:
    with env.begin() as txn:
        num_images = _read_count(txn, b'num_images')

print(f"\nTotal images: {num_images}")
print(f"Expected: {train_patients} (one binocular pair per patient)")

if num_images == train_patients:
    print(f"✅ LMDB image count matches expectation")
else:
    print(f"⚠️  Mismatch! Difference: {abs(num_images - train_patients)}")

## 4. Disease Distribution Analysis

Check for class imbalance and coverage.

In [None]:
from collections import Counter

print("="*80)
print("DISEASE DISTRIBUTION ANALYSIS")
print("="*80)

train_df = pd.read_csv(f"{DRIVE_DATA}/train_patients.csv")
test_df = pd.read_csv(f"{DRIVE_DATA}/test_patients.csv")

def _primary_diseases(df):
    """Return one primary disease per patient row of `df`, skipping rows with none.

    The left eye's first keyword is used; the right eye's is the fallback.
    """
    diseases = []
    for left_raw, right_raw in zip(df['left_keywords'], df['right_keywords']):
        primary = get_primary_disease(left_raw) or get_primary_disease(right_raw)
        if primary:
            diseases.append(primary)
    return diseases

# Identical extraction for both splits keeps the train/test counts comparable.
train_diseases = _primary_diseases(train_df)
test_diseases = _primary_diseases(test_df)

train_counts = Counter(train_diseases)
test_counts = Counter(test_diseases)

print(f"\nTraining set: {len(train_diseases)} samples, {len(train_counts)} unique diseases")
print(f"Test set: {len(test_diseases)} samples, {len(test_counts)} unique diseases")

# Combined disease list
all_diseases = sorted(set(train_counts) | set(test_counts))

print(f"\n{'Disease':<45} {'Train':>7} {'Test':>7} {'Total':>7}")
print("="*80)

for disease in all_diseases:
    train_count = train_counts.get(disease, 0)
    test_count = test_counts.get(disease, 0)
    total = train_count + test_count

    # Flag diseases only in test or only in train
    flag = ""
    if train_count == 0:
        flag = " ⚠️  TEST ONLY!"
    elif test_count == 0:
        flag = " ℹ️  Train only"

    print(f"{disease:<45} {train_count:>7} {test_count:>7} {total:>7}{flag}")

# Check for zero-shot issues
test_only_diseases = set(test_counts) - set(train_counts)

print(f"\n{'='*80}")
print("ZERO-SHOT COVERAGE ANALYSIS")
print("="*80)

if test_only_diseases:
    print(f"\n❌ CRITICAL: {len(test_only_diseases)} diseases appear ONLY in test set:")
    for disease in sorted(test_only_diseases):
        print(f"  - {disease} ({test_counts[disease]} test samples)")
    print(f"\n⚠️  These diseases will have NO training examples for zero-shot prompts!")
else:
    print(f"\n✅ All test diseases have training examples")

# Check class imbalance
if train_counts:
    max_train = max(train_counts.values())
    min_train = min(train_counts.values())  # >= 1: Counter only holds seen keys
    imbalance_ratio = max_train / min_train

    print(f"\nClass imbalance:")
    print(f"  Largest class: {max_train} samples")
    print(f"  Smallest class: {min_train} samples")
    print(f"  Imbalance ratio: {imbalance_ratio:.1f}:1")

    if imbalance_ratio > 100:
        print(f"  ⚠️  Severe imbalance! May need class weighting or oversampling.")
    elif imbalance_ratio > 10:
        print(f"  ℹ️  Moderate imbalance - typical for medical datasets")
    else:
        print(f"  ✅ Relatively balanced")
else:
    # No training labels: max()/min() would raise on an empty Counter. NaN keeps
    # any later `imbalance_ratio > ...` comparison False.
    imbalance_ratio = float('nan')
    print(f"\nClass imbalance:")
    print(f"  ⚠️  No training diseases found - cannot compute imbalance")

## 5. Summary Report

In [None]:
import os

print("="*80)
print("DIAGNOSTIC SUMMARY")
print("="*80)

# Derive each status from what the cells above actually measured instead of
# listing every check as passed unconditionally.
required_dirs = (DRIVE_DATA, DRIVE_PROMPTS, DRIVE_LMDB)
check_results = [
    ("File structure exists", all(os.path.isdir(d) for d in required_dirs)),
    ("JSONL format validated (every entry has eye_side)", missing_eye_side == 0),
    ("LMDB integrity confirmed (pair and image counts match)",
     num_samples == expected and num_images == train_patients),
    ("Prompts loaded successfully", len(prompts_df) > 0),
]
passed_checks = [name for name, ok in check_results if ok]
failed_checks = [name for name, ok in check_results if not ok]

if passed_checks:
    print("\n✅ Checks passed:")
    for name in passed_checks:
        print(f"  - {name}")

if failed_checks:
    print("\n❌ Checks failed (see the corresponding section above):")
    for name in failed_checks:
        print(f"  - {name}")

if test_only_diseases:
    print(f"\n❌ Critical issues:")
    print(f"  - {len(test_only_diseases)} diseases in test but not train")
    print(f"  - These will use generic fallback prompts")

if imbalance_ratio > 100:
    print(f"\n⚠️  Warnings:")
    print(f"  - Severe class imbalance ({imbalance_ratio:.0f}:1)")
    print(f"  - Consider class weighting in training")

print(f"\n📊 Dataset summary:")
print(f"  Train: {len(train_df)} patients, {len(train_diseases)} with keywords")
print(f"  Test: {len(test_df)} patients, {len(test_diseases)} with keywords")
print(f"  Unique diseases (train): {len(train_counts)}")
print(f"  Unique diseases (test): {len(test_counts)}")
print(f"  Prompts generated: {len(prompts_df)} patients")

print(f"\n{'='*80}")
print("DIAGNOSTICS COMPLETE")
print("="*80)