<a href="https://colab.research.google.com/github/adihatake/Improving-OCT-Interpretation-through-Retrieval-Guided-Diagnosis-and-LLM-Based-Reporting/blob/main/Indicium_Download_and_Split_Datasets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Load dependencies

In [None]:
# Mount Google Drive so the parquet files written later in this notebook
# can be saved under /content/drive/MyDrive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
! pip install -U datasets

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd

import matplotlib.pyplot as plt
import datasets
import numpy as np


from PIL import Image
from datasets import load_dataset, load_from_disk

from torch.utils.data import DataLoader
from torchvision import transforms

# Single random seed passed to every train_test_split call in this notebook,
# so the patient-level splits are reproducible across runs.
SEED: int = 42

# Download biomarker benchmarking dataset

In [None]:
# Biomarker-detection test split from the Hugging Face Hub, minus the
# per-eye/per-patient metadata columns this notebook does not keep.
olives_biomarker_benchmark = (
    load_dataset('gOLIVES/OLIVES_Dataset', 'biomarker_detection', split='test')
    .remove_columns(['Eye_ID', 'BCVA', 'CST', 'Patient_ID'])
)


In [None]:
# Persist the benchmark split to Drive; the value returned by to_parquet is shown as the cell output.
benchmark_parquet_path = "/content/drive/MyDrive/Indicium/OLIVES_biomarker_benchmark_dataset.parquet"
olives_biomarker_benchmark.to_parquet(benchmark_parquet_path)

Creating parquet from Arrow format:   0%|          | 0/39 [00:00<?, ?ba/s]

968454669

# Download the OLIVES train split from Hugging Face

In [None]:
# Download the disease-classification training split from the Hugging Face Hub
# and drop the columns this notebook does not keep.
olives = load_dataset(
    'gOLIVES/OLIVES_Dataset',
    'disease_classification',
    split='train',
).remove_columns(['Scan (n/49)', 'Eye_ID', 'BCVA', 'CST'])


In [None]:
# Move to pandas for row filtering and patient-level grouping; preview the first rows.
olives_df = olives.to_pandas()
olives_df.head(5)

In [None]:
# Keep only complete rows: any row with a missing value in any column is dropped.
filtered_df = olives_df.dropna(how="any")

In [None]:
# Convert the filtered pandas frame back to a Hugging Face dataset.
# After dropna the index may no longer be a default RangeIndex; without
# preserve_index=False, from_pandas would then store it as an extra
# `__index_level_0__` column.
filtered_dataset = datasets.Dataset.from_pandas(filtered_df, preserve_index=False)

In [None]:
from collections import Counter

# Number of rows contributed by each patient after NaN filtering.
patient_counts = Counter(filtered_df['Patient_ID'])
print(patient_counts)

# Iterating a Counter yields its keys, so this is the sorted list of distinct patient IDs.
unique_patients = sorted(patient_counts)
print(f"Total unique patients: {len(unique_patients)}")



Counter({201: 392, 213: 392, 226: 392, 222: 384, 204: 343, 234: 294, 243: 294, 249: 294, 232: 290, 58: 196, 59: 196, 60: 196, 61: 196, 62: 196, 63: 196, 64: 196, 65: 196, 66: 196, 67: 196, 68: 196, 69: 196, 70: 196, 71: 196, 72: 196, 73: 196, 74: 196, 75: 196, 76: 196, 77: 196, 57: 196, 78: 196, 80: 196, 81: 196, 82: 196, 83: 196, 84: 196, 85: 196, 86: 196, 87: 196, 88: 196, 89: 196, 90: 196, 91: 196, 92: 196, 93: 196, 94: 196, 95: 196, 96: 196, 203: 196, 215: 196, 217: 196, 221: 196, 225: 196, 229: 196, 210: 196, 212: 196, 216: 196, 218: 196, 219: 196, 245: 196, 248: 196, 253: 196, 207: 196, 208: 196, 209: 196, 211: 196, 220: 196, 224: 196, 228: 196, 230: 196, 231: 196, 235: 196, 237: 196, 240: 196, 242: 196, 254: 196, 255: 196, 206: 184, 79: 98, 236: 98, 238: 98, 239: 98, 241: 98, 247: 98, 251: 98, 252: 98, 256: 98})
Total unique patients: 87


# Split patients into train and a held-out (validation + test) pool

In [None]:
from sklearn.model_selection import train_test_split

# One row per patient. The first 'Disease Label' seen for each patient is used
# as the stratification key, and rows are later selected by patient ID, so all
# rows of a given patient land in exactly one split.
# NOTE(review): if a patient has rows with different labels, only the first
# label is used for stratification — confirm labels are constant per patient.
patient_labels = filtered_df.groupby('Patient_ID')['Disease Label'].first().reset_index()

# Split patients into train (80%) and a held-out pool (20%); the held-out
# pool is divided into validation and test in the next cell.
train_patients, test_val_patients = train_test_split(
    patient_labels['Patient_ID'],
    test_size=0.2,
    stratify=patient_labels['Disease Label'],
    random_state=SEED,
)

# Select rows for each split based on patient IDs.
train_df = filtered_df[filtered_df['Patient_ID'].isin(train_patients)]
test_val_df = filtered_df[filtered_df['Patient_ID'].isin(test_val_patients)]


# Split the held-out patients into validation and test

In [None]:
# One row per held-out patient, keyed by the first 'Disease Label' seen for
# that patient (used only as the stratification key).
test_val_patient_labels = (
    test_val_df
    .groupby('Patient_ID')['Disease Label']
    .first()
    .reset_index()
)

# Halve the held-out patients into validation and test, stratified by label.
val_patients, test_patients = train_test_split(
    test_val_patient_labels['Patient_ID'],
    stratify=test_val_patient_labels['Disease Label'],
    test_size=0.5,
    random_state=SEED,
)


In [None]:
from collections import Counter

# Materialise the validation and test frames from the held-out pool once,
# then report each split's row-level label balance from those same frames.
val_df = test_val_df[test_val_df['Patient_ID'].isin(val_patients)]
test_df = test_val_df[test_val_df['Patient_ID'].isin(test_patients)]

val_label_counts = Counter(val_df['Disease Label'])
test_label_counts = Counter(test_df['Disease Label'])

print("Val label distribution:", val_label_counts)
print("Test label distribution:", test_label_counts)


Val label distribution: Counter({1.0: 968, 0.0: 686})
Test label distribution: Counter({0.0: 784, 1.0: 784})


# Verify and Save

In [None]:
# Convert each split to a Hugging Face dataset.
# The boolean row filters above leave a non-default pandas index. Resetting it
# for every split (train included) keeps from_pandas from adding an
# `__index_level_0__` column, so all three saved parquet files share one schema.
train_dataset = datasets.Dataset.from_pandas(train_df.reset_index(drop=True))
val_dataset = datasets.Dataset.from_pandas(val_df.reset_index(drop=True))
test_dataset = datasets.Dataset.from_pandas(test_df.reset_index(drop=True))



In [None]:
from collections import Counter

# Re-read the labels from the converted Hugging Face datasets to confirm the
# per-split label balance survived the pandas -> datasets conversion.
train_labels = train_dataset['Disease Label']
test_labels = test_dataset['Disease Label']
val_labels = val_dataset['Disease Label']

train_label_counts, test_label_counts, val_label_counts = (
    Counter(split_labels) for split_labels in (train_labels, test_labels, val_labels)
)

print("Train label distribution:", train_label_counts)
print("Test label distribution:", test_label_counts)
print("Val label distribution:", val_label_counts)

Train label distribution: Counter({1.0: 7975, 0.0: 6272})
Test label distribution: Counter({0.0: 784, 1.0: 784})
Val label distribution: Counter({1.0: 968, 0.0: 686})


In [None]:
# Patient IDs present in each converted split.
train_patients = set(train_dataset['Patient_ID'])
val_patients = set(val_dataset['Patient_ID'])
test_patients = set(test_dataset['Patient_ID'])

print(train_patients)
print(val_patients)
print(test_patients)

# Every pair of splits must be patient-disjoint. Each pair is checked
# independently, and the "no leakage" message is printed only when none of
# the three pairs overlaps (the previous if/else only covered val/test).
split_pairs = [
    ("train", "validation", train_patients, val_patients),
    ("train", "test", train_patients, test_patients),
    ("validation", "test", val_patients, test_patients),
]

leak_found = False
for name_a, name_b, ids_a, ids_b in split_pairs:
    overlap = ids_a & ids_b
    if overlap:
        leak_found = True
        print(f"{len(overlap)} patient(s) appear in both {name_a} and {name_b} datasets.")
        print("Example overlapping patient IDs:", list(overlap)[:5])

if not leak_found:
    print("No patient ID leakage between train, validation, and test datasets.")


{208, 209, 220, 210, 212, 57, 58, 59, 60, 61, 62, 63, 216, 65, 66, 68, 69, 70, 72, 73, 74, 201, 76, 204, 78, 207, 80, 81, 82, 83, 84, 213, 86, 87, 88, 89, 90, 91, 92, 93, 94, 215, 96, 225, 226, 218, 219, 229, 222, 228, 232, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 245, 248, 249, 251, 252, 253, 254, 255}
{64, 224, 67, 203, 206, 79, 217, 221, 95}
{256, 230, 71, 231, 75, 77, 211, 85, 247}
No patient ID leakage between train and test datasets.


In [None]:
# Distinct patients per split, printed in val / test / train order.
val_listing, test_listing, train_listing = (
    list(set(split_ds['Patient_ID'])) for split_ds in (val_dataset, test_dataset, train_dataset)
)

for listing in (val_listing, test_listing, train_listing):
    print(len(listing))

9
9
69


In [None]:
# Save each split to Google Drive as parquet.
# NOTE(review): the file names are inconsistent (a `_42` suffix on train only,
# a trailing `_` on test, no suffix on val). They are kept unchanged because
# other notebooks may load these exact paths — confirm before renaming.
train_dataset.to_parquet("/content/drive/MyDrive/Indicium/OLIVES_train_42.parquet")
test_dataset.to_parquet("/content/drive/MyDrive/Indicium/OLIVES_test_.parquet")
val_dataset.to_parquet("/content/drive/MyDrive/Indicium/OLIVES_val.parquet")

Creating parquet from Arrow format:   0%|          | 0/143 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/16 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/17 [00:00<?, ?ba/s]

413946748