In [None]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
import h5py
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from transformers import ViTFeatureExtractor, ViTModel, ViTConfig
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
# Root of the MIMIC-CXR-JPG 2.1.0 download; the metadata CSV read below lives directly under it.
data_path = Path('../MIMIC/physionet.org/content/mimic-cxr-jpg/get-zip/2.1.0/mimic-cxr-jpg-chest-radiographs-with-structured-labels-2.1.0')
# MIMIC-IV 3.1 ICU module directory and its ICU stay table (loaded later as df_icustay).
icu_data_path = Path('../MIMIC/MIMICIV/physionet.org/files/mimiciv/3.1/icu')
icustay_path = icu_data_path / 'icustays.csv.gz'
# MIMIC-IV 3.1 hospital module directory (not referenced again in the visible cells).
hosp_data_path = Path('../MIMIC/MIMICIV/physionet.org/files/mimiciv/3.1/hosp')
# MIMIC-IV-Note 2.2 directory with the discharge and radiology note tables.
note_path = Path('../MIMIC/MIMICIV/physionet.org/files/mimic-iv-note/2.2/note')
discharge_data_path = note_path / 'discharge.csv.gz'
radiology_data_path = note_path / 'radiology.csv.gz'

In [None]:
def list_files_in_directory(directory, level=0, max_depth=2):
    """Print an indented listing of ``directory`` down to ``max_depth`` levels.

    Entries are printed in sorted order, indented four spaces per level.
    Sub-directories on the deepest printed level are not expanded; only the
    number of items they contain is shown. Nothing is returned.
    """
    if level < max_depth:
        pad = '    ' * level
        at_deepest_level = (level == max_depth - 1)

        for entry in sorted(directory.iterdir()):
            if not entry.is_dir():
                print(f"{pad}{entry.name}")
                continue

            if at_deepest_level:
                # Summarise instead of recursing any further.
                n_children = sum(1 for _ in entry.iterdir())
                print(f"{pad}[DIR] {entry.name}: {n_children} items")
                continue

            print(f"{pad}[DIR] {entry.name}")
            list_files_in_directory(entry, level + 1, max_depth)

# Show the top two levels of the CXR download, or say so when the path is missing.
if not data_path.exists():
    print(f"Directory '{data_path}' does not exist.")
else:
    list_files_in_directory(data_path, max_depth=2)

In [None]:
# Load the MIMIC-CXR metadata table and print a quick structural summary.
# The file sits directly under data_path, so the path is derived from it
# instead of repeating the long literal.
metadata_path = data_path / 'mimic-cxr-2.0.0-metadata.csv.gz'

df = pd.read_csv(metadata_path, compression="gzip" if metadata_path.suffix == ".gz" else None)

# DataFrame.info() prints its report itself and returns None, so it is called
# directly; wrapping it in print() would emit a stray "None" line.
print("DataFrame Info:")
df.info()

# Print the first 5 rows
print("\nFirst 5 Rows:")
print(df.head())

# Print column names
print("\nColumn Names:")
print(df.columns.tolist())

# Print the dtype of every column
print("\nData Types and Non-Null Value Count:")
print(df.dtypes)

# Print descriptive statistics for numeric columns
print("\nNumeric Column Statistics:")
print(df.describe())

# Print unique value counts for each column
print("\nUnique Value Count per Column:")
print(df.nunique())


In [None]:
# Load the ICU stay table and print a quick structural summary of it.
# The compression check and every summary below refer to the ICU table itself
# (icustay_path / df_icustay), not to the CXR metadata loaded in the previous cell.
df_icustay = pd.read_csv(icustay_path, compression="gzip" if icustay_path.suffix == ".gz" else None)

# DataFrame.info() prints its report itself and returns None, so it is called
# directly; wrapping it in print() would emit a stray "None" line.
print("DataFrame Info:")
df_icustay.info()

# Print the first 5 rows
print("\nFirst 5 Rows:")
print(df_icustay.head())

# Print column names
print("\nColumn Names:")
print(df_icustay.columns.tolist())

# Print the dtype of every column
print("\nData Types and Non-Null Value Count:")
print(df_icustay.dtypes)

# Print descriptive statistics for numeric columns
print("\nNumeric Column Statistics:")
print(df_icustay.describe())

# Print unique value counts for each column
print("\nUnique Value Count per Column:")
print(df_icustay.nunique())


In [None]:
def format_time(t):
    """Render a numeric HHMMSS[.ffffff] time as a fixed-width ``HHMMSS.ffffff`` string.

    The integer part is left-padded with zeros to six digits and the fraction
    is written with six decimals, so the result can be parsed with
    ``%H%M%S.%f``.

    Parameters
    ----------
    t : anything accepted by ``float()``
        The time value, e.g. ``80556.875`` or ``"80556.875"``.

    Returns
    -------
    str
        The padded time string, or ``"000000.000000"`` when ``t`` cannot be
        converted (None, non-numeric text, NaN or infinity).
    """
    try:
        t = float(t)
        int_part = int(t)  # raises ValueError for NaN and OverflowError for +/-inf
        decimal_part = f"{t:.6f}".split(".")[1]
        int_str = str(int_part).zfill(6)
        return f"{int_str}.{decimal_part}"
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures fall back to midnight; anything else
        # (e.g. KeyboardInterrupt) is allowed to propagate.
        return "000000.000000"
# Zero-pad the date to 8 characters and normalise the time with format_time so
# the two strings can be parsed together with a single fixed format.
df["StudyDateStr"] = df["StudyDate"].astype(str).str.zfill(8)
df["StudyTimeStr"] = df["StudyTime"].apply(format_time)

# errors="coerce" turns any date/time pair that does not match the format into NaT
# instead of raising.
df["StudyDateTime"] = pd.to_datetime(
    df["StudyDateStr"] + " " + df["StudyTimeStr"],
    format="%Y%m%d %H%M%S.%f",
    errors="coerce"
)
# Compare the raw inputs with the parsed timestamp.
print(df[["StudyDate", "StudyTime", "StudyDateTime"]])

In [None]:
def report_nat_rows(df, datetime_col="StudyDateTime", preview=5):
    """Report and return the rows whose ``datetime_col`` value is missing.

    Prints how many rows have a missing value in ``datetime_col`` and, when
    there are any, displays the first ``preview`` of them.

    Raises
    ------
    ValueError
        If ``datetime_col`` is not a column of ``df``.

    Returns
    -------
    pandas.DataFrame
        The subset of ``df`` where ``datetime_col`` is missing.
    """
    if datetime_col in df.columns:
        missing_mask = df[datetime_col].isna()
    else:
        raise ValueError(f"❌ '{datetime_col}' is not in DataFrame.")

    nat_rows = df.loc[missing_mask]
    n_missing = nat_rows.shape[0]
    print(f"⚠️ '{datetime_col}' Nat counts: {n_missing}")

    if n_missing > 0:
        print(f"\n📋 First {preview} as：")
        display(nat_rows.head(preview))

    return nat_rows
# Keep the metadata rows whose StudyDateTime could not be parsed (NaT) for inspection.
bad_rows = report_nat_rows(df, datetime_col="StudyDateTime")

In [None]:
# Sort each subject's images chronologically (ties broken by ViewPosition) and
# number them 1, 2, 3, ... within the subject. All view positions are kept.
df_sorted = df.sort_values(by=["subject_id", "StudyDateTime", "ViewPosition"]).assign(
    order=lambda frame: frame.groupby("subject_id").cumcount() + 1
)
# Slim image-level table used by the cohort steps below.
df_final = df_sorted.loc[
    :, ["subject_id", "study_id", "StudyDateTime", "dicom_id", "ViewPosition", "order"]
]

In [None]:
# Preview the first ten rows of the ordered image table.
df_final.head(10)

In [None]:
def report_time_spans(df_sorted, threshold_hours=24):
    """Summarise, per subject, the time between the first and last study.

    Rows with a missing ``StudyDateTime`` are ignored. A one-line summary is
    printed stating how many subjects have a span shorter than
    ``threshold_hours``.

    Parameters
    ----------
    df_sorted : pandas.DataFrame
        Must contain ``subject_id`` and a datetime ``StudyDateTime`` column.
        The frame is not modified.
    threshold_hours : float, default 24
        Spans strictly below this many hours are counted in the summary.

    Returns
    -------
    pandas.DataFrame
        One row per subject with columns ``subject_id``, ``start_time``,
        ``end_time`` and ``span_hours``. A subject with a single study has a
        span of 0.0. When no row has a valid ``StudyDateTime`` an empty frame
        with these columns is returned.

    Raises
    ------
    ValueError
        If ``StudyDateTime`` is not a column of ``df_sorted``.
    """
    if "StudyDateTime" not in df_sorted.columns:
        raise ValueError("❗️Missing 'StudyDateTime' column. Please create it before running this function.")

    timed = df_sorted[df_sorted["StudyDateTime"].notna()]

    if timed.empty:
        # Nothing to aggregate; returning early also avoids dividing by zero
        # when computing the percentage for the summary line.
        print("⚠️ No rows with a valid StudyDateTime; nothing to summarise.")
        return pd.DataFrame(columns=["subject_id", "start_time", "end_time", "span_hours"])

    # min/max per subject is the same as "first/last after sorting", without a
    # Python-level loop over every subject.
    span_df = (
        timed.groupby("subject_id")["StudyDateTime"]
        .agg(start_time="min", end_time="max")
        .reset_index()
    )
    span_df["span_hours"] = (
        (span_df["end_time"] - span_df["start_time"]).dt.total_seconds() / 3600
    )

    total = len(span_df)
    below_thresh = (span_df["span_hours"] < threshold_hours).sum()
    print(f"✅ {total} subject_ids in total; among them, {below_thresh} have a time span shorter than {threshold_hours} hours ({below_thresh/total:.1%}).")

    return span_df

def filter_by_time_window(df, first_24=False, last_24=False, last_48=False, return_df=False):
    df = df.copy()
    
    if "StudyDateTime" not in df.columns:
        raise ValueError("❌ No 'StudyDateTime'")

    df = df[df["StudyDateTime"].notna()] 
    result_frames = []

    for subject_id, group in df.groupby("subject_id"):
        group = group.sort_values("StudyDateTime").reset_index(drop=True)

        if len(group) == 1:
            result_frames.append(group)
            continue

        t0 = group["StudyDateTime"].iloc[0]
        t_end = group["StudyDateTime"].iloc[-1]

        group["relative_hours_from_start"] = (group["StudyDateTime"] - t0).dt.total_seconds() / 3600
        group["relative_hours_to_end"] = (t_end - group["StudyDateTime"]).dt.total_seconds() / 3600

        selected_rows = pd.DataFrame()

        if first_24:
            selected_rows = pd.concat([selected_rows, group[group["relative_hours_from_start"] <= 24]])

        if last_24:
            selected_rows = pd.concat([selected_rows, group[group["relative_hours_to_end"] <= 24]])

        if last_48:
            selected_rows = pd.concat([selected_rows, group[group["relative_hours_to_end"] <= 48]])

        if not selected_rows.empty:
            result_frames.append(selected_rows)

    if result_frames:
        combined = pd.concat(result_frames).drop_duplicates()
        total = combined["dicom_id"].nunique()
        print(f"✅ Total selected dicom_id: {total}")
        if return_df:
            return combined
    else:
        print("⚠️ No time window selected or no StudyDateTime available.")
        if return_df:
            return pd.DataFrame()

In [None]:
# Per-subject span between first and last image; prints how many subjects span less than 24 hours.
span_df = report_time_spans(df_final, threshold_hours=24)

In [None]:
# Left-join every image row to the ICU stay rows of the same subject. The join key is
# subject_id only, so an image is repeated once per matching df_icustay row, and images
# of subjects with no match keep NaN in the stay columns.
df_merged = df_final.merge(df_icustay[["subject_id", "stay_id", "intime" ,"outtime","los"]], on='subject_id', how='left')

In [None]:
# df_merged[df_merged['subject_id'].isin([10000032, 10001217])]

In [None]:
# df_merged[df_merged['stay_id'].notna()]

In [None]:
# Make sure all three time columns are datetimes so they can be compared below
# (intime/outtime come straight from the CSV).
df_merged['StudyDateTime'] = pd.to_datetime(df_merged['StudyDateTime'])
df_merged['intime'] = pd.to_datetime(df_merged['intime'])
df_merged['outtime'] = pd.to_datetime(df_merged['outtime'])

In [None]:
# Image rows that matched no ICU stay in the left join.
no_icu_count = df_merged['stay_id'].isna().sum()

# Rows that did match a stay, and the number of distinct stays per subject among them.
df_has_icu = df_merged[df_merged['stay_id'].notna()]
icu_counts = df_merged.dropna(subset=['stay_id']).groupby('subject_id')['stay_id'].nunique()
more_than_once_icu = icu_counts.gt(1).sum()

# Position of each (image, stay) row relative to that stay's interval; both ends inclusive for "within".
before_icu = len(df_has_icu[df_has_icu['StudyDateTime'].lt(df_has_icu['intime'])])
within_icu = len(
    df_has_icu[df_has_icu['StudyDateTime'].between(df_has_icu['intime'], df_has_icu['outtime'])]
)
after_icu = len(df_has_icu[df_has_icu['StudyDateTime'].gt(df_has_icu['outtime'])])

In [None]:
# Report the counts from the previous cell. They are counted over rows of df_merged,
# i.e. over (image, ICU stay) pairs, not over distinct images.
print("Summary:")
print(f"1. Number of images without an ICU stay: {no_icu_count}")
print(f"2. Number of subject_ids with more than one ICU stay: {more_than_once_icu}")
print(f"3. Distribution of StudyDateTime relative to ICU stay:")
print(f"   - Before ICU stay: {before_icu}")
print(f"   - During ICU stay: {within_icu}")
print(f"   - After ICU stay: {after_icu}")


In [None]:
# One row per subject: the stay with the earliest intime.
first_icu = (
    df_merged.dropna(subset=['stay_id'])
    .sort_values(by=['subject_id', 'intime'])
    .drop_duplicates(subset=['subject_id'], keep='first')
)

# Image-level table (duplicates from the multi-stay join removed) with the
# first ICU stay of the subject attached to every image.
df_first_icu = (
    df_merged[['subject_id', 'StudyDateTime', 'dicom_id', 'ViewPosition', 'order']]
    .drop_duplicates()
    .merge(
        first_icu[['subject_id', 'stay_id', 'intime', 'outtime', 'los']],
        on='subject_id',
        how='left',
    )
)
df_first_icu['StudyDateTime'] = pd.to_datetime(df_first_icu['StudyDateTime'])
df_first_icu['intime'] = pd.to_datetime(df_first_icu['intime'])
df_first_icu['outtime'] = pd.to_datetime(df_first_icu['outtime'])

no_icu_count = df_first_icu['stay_id'].isna().sum()

# Image counts relative to the first stay; "within" includes both end points.
before_icu = len(df_first_icu[df_first_icu['StudyDateTime'].lt(df_first_icu['intime'])])
within_icu = len(
    df_first_icu[
        df_first_icu['StudyDateTime'].between(df_first_icu['intime'], df_first_icu['outtime'])
    ]
)
after_icu = len(df_first_icu[df_first_icu['StudyDateTime'].gt(df_first_icu['outtime'])])

print("Summary:")
print(f"1. Number of images without an ICU stay: {no_icu_count}")
print(f"2. Distribution of StudyDateTime relative to the first ICU stay:")
print(f"   - Before first ICU stay: {before_icu}")
print(f"   - During first ICU stay: {within_icu}")
print(f"   - After first ICU stay: {after_icu}")

In [None]:
# Subject-level version of the previous summary: which subjects have at least one
# image before / during / after their first ICU stay. A subject can appear in
# several of these groups.
before_icu_ids = df_first_icu.loc[
    df_first_icu['StudyDateTime'] < df_first_icu['intime'], 'subject_id'
].unique()
within_icu_ids = df_first_icu.loc[
    df_first_icu['StudyDateTime'].between(df_first_icu['intime'], df_first_icu['outtime']),
    'subject_id',
].unique()
after_icu_ids = df_first_icu.loc[
    df_first_icu['StudyDateTime'] > df_first_icu['outtime'], 'subject_id'
].unique()

total_unique_subjects = df_first_icu['subject_id'].nunique()
no_icu_ids = df_first_icu.loc[df_first_icu['stay_id'].isna(), 'subject_id'].unique()
num_no_icu = len(no_icu_ids)

print("Summary:")
print(f"1. Total unique subject_ids: {total_unique_subjects}")
print(f"2. Number of subject_ids with more than one ICU stay: {more_than_once_icu}")
print(f"3. Unique subject_id counts by StudyDateTime (relative to first ICU stay):")
print(f"   - Before first ICU stay: {len(before_icu_ids)}")
print(f"   - During first ICU stay: {len(within_icu_ids)}")
print(f"   - After first ICU stay: {len(after_icu_ids)}")
print(f"   - No ICU stay: {len(no_icu_ids)}")

In [None]:
# Subjects with at least one row that carries a stay_id.
has_icu_ids = df_first_icu[df_first_icu['stay_id'].notna()]['subject_id'].unique()

# Subjects that never appear with a stay_id. NOTE(review): df_first_icu attaches a single
# first-stay row per subject, so a subject's rows should be either all matched or all
# unmatched and this difference is probably a no-op safeguard; confirm.
true_no_icu_ids = set(no_icu_ids) - set(has_icu_ids)
num_true_no_icu = len(true_no_icu_ids)

# Same report as the previous cell, with the stricter "never in ICU" count on the last line.
print("Summary (exclude duplicates):")
print(f"1. Total unique subject_ids: {total_unique_subjects}")
print(f"2. Number of subject_ids with more than one ICU stay: {more_than_once_icu}")
print(f"3. Unique subject_id counts by StudyDateTime (relative to first ICU stay):")
print(f"   - Before first ICU stay: {len(before_icu_ids)}")
print(f"   - During first ICU stay: {len(within_icu_ids)}")
print(f"   - After first ICU stay: {len(after_icu_ids)}")
print(f"   - No ICU stay (never appeared in ICU at all): {num_true_no_icu}")

In [None]:
# Load both note tables and keep only the notes that actually have text.
df_discharge = pd.read_csv(discharge_data_path).dropna(subset=['text'])
df_radiology = pd.read_csv(radiology_data_path).dropna(subset=['text'])

# Subjects covered by each note type, and by either of them.
discharge_ids = set(df_discharge['subject_id'].unique())
radiology_ids = set(df_radiology['subject_id'].unique())
all_note_subjects = discharge_ids | radiology_ids

print(f"Discharge note subject count: {len(discharge_ids)}")
print(f"Radiology note subject count: {len(radiology_ids)}")
print(f"Total unique subjects with notes: {len(all_note_subjects)}")

In [None]:
# df_discharge

In [None]:
# df_radiology

In [None]:
# Tag each table with its note type (adds a column to both frames in place)
# so the source stays identifiable after stacking.
df_discharge['note_type'] = 'discharge'
df_radiology['note_type'] = 'radiology'

# Stack both note types into one table with a fresh 0..n-1 index.
df_note = pd.concat([df_discharge, df_radiology], ignore_index=True)

df_note['charttime'] = pd.to_datetime(df_note['charttime'])

# Both inputs were already filtered to non-null text when they were loaded,
# so this filter is only a safeguard.
df_note = df_note[df_note['text'].notna()]

print(df_note.head())

In [None]:
# Subjects that have at least one discharge or radiology note.
note_ids = set(df_note['subject_id'].unique())

# Overlap between each imaging group and the subjects that have notes.
before_with_notes = note_ids.intersection(before_icu_ids)
within_with_notes = note_ids.intersection(within_icu_ids)
after_with_notes = note_ids.intersection(after_icu_ids)
true_no_icu_with_notes = true_no_icu_ids.intersection(note_ids)

print("Summary:")
print(f"1. Total unique subject_ids: {total_unique_subjects}")
print(f"2. Before first ICU stay with notes: {len(before_with_notes)}")
print(f"3. During first ICU stay with notes: {len(within_with_notes)}")
print(f"4. After first ICU stay with notes: {len(after_with_notes)}")
print(f"5. No ICU stay at all with notes: {len(true_no_icu_with_notes)}")

In [None]:
subjects_with_notes = set(df_note['subject_id'].unique())

df_first_icu['StudyDateTime'] = pd.to_datetime(df_first_icu['StudyDateTime'])
df_first_icu['intime'] = pd.to_datetime(df_first_icu['intime'])

# Images taken no later than 24h / 48h after the first ICU admission. There is
# no lower bound, so images from any time before admission are included too.
# Rows without an ICU stay (intime is NaT) never satisfy the comparison.
images_within_24h = df_first_icu.loc[
    df_first_icu['StudyDateTime'].le(df_first_icu['intime'] + pd.Timedelta(hours=24))
]
images_within_48h = df_first_icu.loc[
    df_first_icu['StudyDateTime'].le(df_first_icu['intime'] + pd.Timedelta(hours=48))
]

subjects_with_images_24h = set(images_within_24h['subject_id'].unique())
subjects_with_images_48h = set(images_within_48h['subject_id'].unique())

# A subject enters a cohort when it has both a qualifying image and a note.
cohort_subjects_24h = subjects_with_images_24h.intersection(subjects_with_notes)
cohort_subjects_48h = subjects_with_images_48h.intersection(subjects_with_notes)

print("In-ICU mortality cohort (based on image before/within first ICU stay + clinical notes):")
print(f"1. Subjects with CXR before or within 24h and clinical note: {len(cohort_subjects_24h)}")
print(f"2. Subjects with CXR before or within 48h and clinical note: {len(cohort_subjects_48h)}")

In [None]:
# Images taken at any time up to the end of the first ICU stay (before or during it).
images_within_visit = df_first_icu[
    df_first_icu['StudyDateTime'] <= df_first_icu['outtime']
]
subjects_with_images = set(images_within_visit['subject_id'].unique())
cohort_subjects_all = subjects_with_images & subjects_with_notes
# Report the cohort computed on the line above (the name is cohort_subjects_all).
print(f"3. Subjects with CXR before or during and clinical note: {len(cohort_subjects_all)}")

In [None]:
df_first_icu['StudyDateTime'] = pd.to_datetime(df_first_icu['StudyDateTime'])
df_first_icu['outtime'] = pd.to_datetime(df_first_icu['outtime'])

# Images taken in the closed interval [outtime - 48h, outtime] of the first ICU stay.
images_within_last_48h = df_first_icu.loc[
    df_first_icu['StudyDateTime'].between(
        df_first_icu['outtime'] - pd.Timedelta(hours=48),
        df_first_icu['outtime'],
    )
]

subjects_with_images_last_48h = set(images_within_last_48h['subject_id'].unique())

cohort_subjects_last_48h = subjects_with_images_last_48h.intersection(subjects_with_notes)

print("Readmission cohort (based on image within last 48h of first ICU stay + clinical notes):")
print(f"Subjects with CXR in last 48h before ICU discharge and clinical note: {len(cohort_subjects_last_48h)}")

In [None]:
def extract_cxr_metadata_by_subjects(
    metadata_path: str,
    cohort_24h: set,
    cohort_48h: set,
    cohort_l48h: set,
    cohort_all: set,
    output_24h_path: str = "cxr_f24h_metadata.csv.gz",
    output_48h_path: str = "cxr_f48h_metadata.csv.gz",
    output_l48h_path: str = "cxr_l48h_metadata.csv.gz",
    output_all_path: str = "cxr_all_metadata.csv.gz"
):
    """Write one gzip CSV of metadata rows per cohort, filtered by subject.

    The metadata file is read once. For each cohort, all rows whose
    ``subject_id`` belongs to that cohort are written to the matching output
    path, and the number of rows saved is printed. The selection is by
    subject only: every image of a cohort subject is kept, whatever its time.
    Nothing is returned.
    """
    metadata = pd.read_csv(metadata_path)

    cohorts = (cohort_24h, cohort_48h, cohort_l48h, cohort_all)
    targets = (output_24h_path, output_48h_path, output_l48h_path, output_all_path)

    # Filter everything first, then write, then report.
    subsets = [metadata.loc[metadata['subject_id'].isin(ids)] for ids in cohorts]

    for subset, target in zip(subsets, targets):
        subset.to_csv(target, index=False, compression='gzip')

    for subset, target in zip(subsets, targets):
        print(f"✅ Saved {len(subset)} rows to {target}")


In [None]:
# extract_cxr_metadata_by_subjects(
#     metadata_path=metadata_path,
#     cohort_24h=cohort_subjects_24h,
#     cohort_48h=cohort_subjects_48h,
#     cohort_l48h=cohort_subjects_last_48h,
#     cohort_all=cohort_subjects_all,
# )

In [None]:
def _format_study_time(t):
    """Render a numeric HHMMSS[.ffffff] time as ``HHMMSS.ffffff``; unparseable values give midnight."""
    try:
        t = float(t)
        int_part = int(t)  # raises ValueError for NaN and OverflowError for +/-inf
        decimal_part = f"{t:.6f}".split(".")[1]
        return f"{str(int_part).zfill(6)}.{decimal_part}"
    except (TypeError, ValueError, OverflowError):
        return "000000.000000"


def _study_datetimes(df_metadata):
    """Return a datetime Series aligned with ``df_metadata`` holding each image's study time."""
    if 'StudyDateTime' in df_metadata.columns:
        # A column read back from CSV is text; parse it so it can be compared.
        return pd.to_datetime(df_metadata['StudyDateTime'], errors='coerce')

    if 'StudyDate' in df_metadata.columns and 'StudyTime' in df_metadata.columns:
        date_str = df_metadata['StudyDate'].astype(str).str.zfill(8)
        time_str = df_metadata['StudyTime'].apply(_format_study_time)
        return pd.to_datetime(
            date_str + " " + time_str,
            format="%Y%m%d %H%M%S.%f",
            errors='coerce'
        )

    raise ValueError("Metadata missing both 'StudyDateTime' and ('StudyDate', 'StudyTime') to generate StudyDateTime.")


def extract_cxr_metadata_by_time(
    metadata_path: str,
    df_first_icu: pd.DataFrame,
    output_24h_path: str = "cxr_f24h_metadata.csv.gz",
    output_48h_path: str = "cxr_f48h_metadata.csv.gz",
    output_l48h_path: str = "cxr_l48h_metadata.csv.gz",
    output_all_path: str = "cxr_all_metadata.csv.gz"
):
    """Write gzip CSVs of metadata rows selected by time relative to an ICU stay.

    Each image in the metadata file is compared with the ``intime`` /
    ``outtime`` that ``df_first_icu`` lists for the same subject. Four files
    are written, each with the metadata's original columns only:

    * ``output_24h_path``  - study time <= intime + 24h
    * ``output_48h_path``  - study time <= intime + 48h
    * ``output_l48h_path`` - outtime - 48h <= study time <= outtime
    * ``output_all_path``  - study time <= outtime

    Images of subjects without an ICU window, and images whose study time
    cannot be parsed, are never selected.

    Parameters
    ----------
    metadata_path : str
        CSV with ``subject_id`` and either ``StudyDateTime`` or both
        ``StudyDate`` and ``StudyTime``.
    df_first_icu : pandas.DataFrame
        Must contain ``subject_id``, ``intime`` and ``outtime``. It may repeat
        the same window on many rows (for example one row per image);
        identical rows are collapsed before joining so that each image is
        written once per distinct window. The frame is not modified.

    Raises
    ------
    ValueError
        If the metadata has neither ``StudyDateTime`` nor both ``StudyDate``
        and ``StudyTime``.
    """
    df_metadata = pd.read_csv(metadata_path)
    original_columns = df_metadata.columns.tolist()

    study_dt = _study_datetimes(df_metadata)

    # Work on a copy so the caller's frame keeps its original dtypes.
    windows = df_first_icu[['subject_id', 'intime', 'outtime']].copy()
    windows['intime'] = pd.to_datetime(windows['intime'])
    windows['outtime'] = pd.to_datetime(windows['outtime'])
    # Without this, a frame with k rows per subject makes the join emit every
    # image of that subject k times.
    windows = windows.drop_duplicates()

    merged = df_metadata.assign(_study_dt=study_dt).merge(windows, on='subject_id', how='left')

    when = merged['_study_dt']
    intime = merged['intime']
    outtime = merged['outtime']

    selections = [
        (when <= intime + pd.Timedelta(hours=24), output_24h_path),
        (when <= intime + pd.Timedelta(hours=48), output_48h_path),
        ((when <= outtime) & (when >= outtime - pd.Timedelta(hours=48)), output_l48h_path),
        (when <= outtime, output_all_path),
    ]

    saved = []
    for mask, out_path in selections:
        subset = merged.loc[mask, original_columns]
        subset.to_csv(out_path, index=False, compression='gzip')
        saved.append((len(subset), out_path))

    for n_rows, out_path in saved:
        print(f"✅ Saved {n_rows} rows to {out_path}")

In [None]:
# Write the four time-filtered metadata files into the current working directory,
# using the ICU windows attached to df_first_icu.
extract_cxr_metadata_by_time(
    metadata_path=metadata_path,
    df_first_icu=df_first_icu,
    output_24h_path="cxr_filtered_f24h.csv.gz",
    output_48h_path="cxr_filtered_f48h.csv.gz",
    output_l48h_path="cxr_filtered_l48h.csv.gz",
    output_all_path="cxr_filtered_all.csv.gz"
)

In [None]:
# drop duplicate