# Generate the train / retrieve / test splits for the 8-label MIMIC long-tail subset

In [13]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path



BASEDIR_MIMIC = "../data/mimic/"

# Load the preprocessed metadata table.
metadata_df = pd.read_csv(os.path.join(BASEDIR_MIMIC, 'mimic_metadata_preprocessed.csv'))

# The 8 finding columns used as class labels for every split below.
numerical_cols = ['Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
                  'No Finding', 'Pleural Effusion',
                  'Pneumonia', 'Pneumothorax']
label_cols = numerical_cols

# Binarize the labels: only an explicit 1.0 counts as positive (True); every other
# value (e.g. 0.0, NaN, or any other code) becomes False.
metadata_df[label_cols] = metadata_df[label_cols].eq(1.0)

# Keep only single-label rows. This drops multi-label rows AND rows with no positive
# label, so every remaining row belongs to exactly one class.
metadata_df = metadata_df[metadata_df[label_cols].sum(axis=1) == 1]

# Per-class counts of the single-label pool.
metadata_df[label_cols].sum()

Atelectasis         10029
Cardiomegaly        10614
Consolidation        1689
Edema                4966
No Finding          69677
Pleural Effusion     9301
Pneumonia            5385
Pneumothorax         3412
dtype: int64

In [14]:
# Create a class-balanced test set with `test_size` images per label.
test_size = 200

test_indices = []

# For each label, draw `test_size` positive rows without replacement.
for label in label_cols:
    # Index labels of the rows that are positive for this label.
    label_indices = metadata_df[metadata_df[label]].index.tolist()
    # Re-seeding before every label makes each label's draw reproducible on its own.
    # Kept as-is on purpose: changing the seeding scheme would change the exported split.
    np.random.seed(1)
    sampled_indices = np.random.choice(label_indices, size=test_size, replace=False)

    test_indices.extend(sampled_indices)

# Deduplicate in case a row was drawn for more than one label; np.unique also sorts,
# which fixes the row order of test_df.
test_indices = np.unique(test_indices)

test_df = metadata_df.loc[test_indices]

# Patient-level split: drop EVERY row of each test subject from training, not just the
# sampled rows, so no subject appears in both sets.
train_subjects = set(metadata_df['subject_id']) - set(test_df['subject_id'])
train_df = metadata_df[metadata_df['subject_id'].isin(train_subjects)]

print(f"Train set size: {len(train_df)}")
print(f"Test set size: {len(test_df)}")
print("\nTest set label distribution:")
print(test_df[label_cols].sum())


Train set size: 105115
Test set size: 1600

Test set label distribution:
Atelectasis         200
Cardiomegaly        200
Consolidation       200
Edema               200
No Finding          200
Pleural Effusion    200
Pneumonia           200
Pneumothorax        200
dtype: int64


In [15]:
# Create a class-balanced retrieval set with `retrieve_size` images per label, drawn from
# the non-test pool, then remove every retrieval subject from training.
retrieve_size = 300

# Rebuild the non-test pool from metadata_df/test_df instead of reading train_df, because
# this cell overwrites train_df below. Reading train_df would make a re-run sample from an
# already-shrunken train set. This pool is row-for-row identical to the train_df built by
# the test-split cell (all rows whose subject is not in test_df, original order).
train_pool_df = metadata_df[~metadata_df['subject_id'].isin(test_df['subject_id'])]

retrieve_indices = []

# For each label, draw `retrieve_size` positive rows without replacement.
for label in label_cols:
    # Index labels of the pool rows that are positive for this label.
    label_indices = train_pool_df[train_pool_df[label]].index.tolist()
    # Same per-label re-seeding as the test split, so each label's draw is reproducible.
    np.random.seed(1)
    sampled_indices = np.random.choice(label_indices, size=retrieve_size, replace=False)

    retrieve_indices.extend(sampled_indices)

# Deduplicate (and sort) in case a row was drawn for more than one label.
retrieve_indices = np.unique(retrieve_indices)

retrieve_df = train_pool_df.loc[retrieve_indices]

# Remaining training subjects: non-test subjects that contributed no retrieval image.
# All rows of a retrieval subject leave the training set (patient-level split).
remaining_train_subjects = set(train_pool_df['subject_id']) - set(retrieve_df['subject_id'])
train_df = train_pool_df[train_pool_df['subject_id'].isin(remaining_train_subjects)]

print(f"Train set size: {len(train_df)}")
print(f"retrieve set size: {len(retrieve_df)}")
print("\nRetrieve set label distribution:")
print(retrieve_df[label_cols].sum())
train_df[label_cols].sum()

Train set size: 93063
retrieve set size: 2400

Retrieve set label distribution:
Atelectasis         300
Cardiomegaly        300
Consolidation       300
Edema               300
No Finding          300
Pleural Effusion    300
Pneumonia           300
Pneumothorax        300
dtype: int64


Atelectasis          7700
Cardiomegaly         7880
Consolidation         706
Edema                3227
No Finding          61820
Pleural Effusion     6489
Pneumonia            3846
Pneumothorax         1395
dtype: int64

In [16]:
def assert_empty_intersect(set_a, set_b):
    """Assert that two collections have no element in common.

    In this notebook it checks that no subject_id leaks between data splits.

    Args:
        set_a: First collection (a set, or any iterable of hashable items).
        set_b: Second collection (a set, or any iterable of hashable items).

    Raises:
        AssertionError: If any element occurs in both. The message reports how many
            elements are shared and shows up to five of them, so a leak can be traced.
    """
    # set() on both sides also accepts lists / pandas Series, not just sets.
    overlap = set(set_a) & set(set_b)
    assert not overlap, f"{len(overlap)} shared element(s), e.g. {list(overlap)[:5]}"

# The splits are patient-level: no subject_id may occur in two of test / train / retrieve.
for split_a, split_b in ((test_df, train_df), (test_df, retrieve_df), (train_df, retrieve_df)):
    assert_empty_intersect(set(split_a["subject_id"]), set(split_b["subject_id"]))

In [17]:
# Create a class-balanced training subset with the same number of images for every label.
# The size is capped by the rarest label in train_df (706, Consolidation, for the current
# data), so sampling without replacement can never ask for more rows than exist.
balanced_train_size = int(train_df[label_cols].sum().min())

# Own index list, so the retrieve split's `retrieve_indices` is not clobbered.
balanced_train_indices = []

# For each label, draw `balanced_train_size` positive rows without replacement.
for label in label_cols:
    # Index labels of the training rows that are positive for this label.
    label_indices = train_df[train_df[label]].index.tolist()
    # Same per-label re-seeding as the other splits, so each label's draw is reproducible.
    np.random.seed(1)
    sampled_indices = np.random.choice(label_indices, size=balanced_train_size, replace=False)

    balanced_train_indices.extend(sampled_indices)

# Deduplicate (and sort) in case a row was drawn for more than one label.
balanced_train_indices = np.unique(balanced_train_indices)

train_balanced_df = train_df.loc[balanced_train_indices]

train_balanced_df[label_cols].sum()

Atelectasis         706
Cardiomegaly        706
Consolidation       706
Edema               706
No Finding          706
Pleural Effusion    706
Pneumonia           706
Pneumothorax        706
dtype: int64

In [18]:
# Positive count per label in the (long-tailed) training split.
train_df.loc[:, label_cols].sum(axis=0)

Atelectasis          7700
Cardiomegaly         7880
Consolidation         706
Edema                3227
No Finding          61820
Pleural Effusion     6489
Pneumonia            3846
Pneumothorax         1395
dtype: int64

In [19]:
# Export dataframes to CSV files
train_balanced_df.to_csv(os.path.join(BASEDIR_MIMIC, 'longtail_8_balanced_train.csv'), index=False)
train_df.to_csv(os.path.join(BASEDIR_MIMIC, 'longtail_8_train.csv'), index=False)
retrieve_df.to_csv(os.path.join(BASEDIR_MIMIC, 'longtail_8_balanced_retrieve.csv'), index=False)
test_df.to_csv(os.path.join(BASEDIR_MIMIC,'longtail_8_balanced_test.csv'), index=False)

In [31]:
# Create an unbalanced variant of the balanced train set: one label keeps only
# `n_undersampled` randomly chosen images, every other label keeps all of its rows.
disease_to_undersample = "Consolidation"
n_undersampled = 10

# Boolean mask of the rows positive for the undersampled label (label columns are bool).
is_target = train_balanced_df[disease_to_undersample]

target_indices = train_balanced_df[is_target].index.tolist()
np.random.seed(1)
sampled_target_indices = np.random.choice(target_indices, size=n_undersampled, replace=False)

# All rows of the other labels; taken from the same mask so the two parts are complementary.
non_target_indices = train_balanced_df[~is_target].index.tolist()

# Non-target rows first, then the sampled target rows.
train_df_unbalanced = train_balanced_df.loc[non_target_indices + list(sampled_target_indices)]

# Verify counts
print("Disease counts in unbalanced training set:")
print(train_df_unbalanced[label_cols].sum())

# Export to CSV
train_df_unbalanced.to_csv(os.path.join(BASEDIR_MIMIC, f'longtail_8_train_unbalanced_{disease_to_undersample}.csv'), index=False)


Disease counts in unbalanced training set:
Atelectasis         706
Cardiomegaly        706
Consolidation        10
Edema               706
No Finding          706
Pleural Effusion    706
Pneumonia           706
Pneumothorax        706
dtype: int64
