## MIMIC-IV Data Preprocessing

In [31]:
import sys
print(sys.executable)

import os
import pickle as pickle
import numpy as np
from datetime import datetime
import pandas as pd
import scipy.sparse as sps
#import torch
from copy import deepcopy
#import torch.nn as nn
#import torch.nn.init as init
#from torch.nn import functional as F
from collections import OrderedDict
#import torch.utils.data as data
#from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
import random
import warnings
warnings.filterwarnings("ignore")

c:\Users\Joshua\Documents\GitHub\SHy\venv\Scripts\python.exe


In [32]:
def parse_admission(path) -> dict:
    """Group admissions by patient, keeping only patients with repeat visits.

    Reads ``admissions.csv`` under *path*. Returns a dict mapping each
    ``subject_id`` with more than one admission to a list of
    ``{'admission_id': hadm_id, 'admission_time': datetime}`` dicts sorted by
    admission time (stable for ties, i.e. file order is kept).
    """
    print('parsing ADMISSIONS.csv ...')
    table = pd.read_csv(
        os.path.join(path, 'admissions.csv'),
        usecols=['subject_id', 'hadm_id', 'admittime'],
        converters={'subject_id': int, 'hadm_id': int, 'admittime': str},
    )

    visits_by_patient = {}
    for subject_id, hadm_id, admit_text in zip(table['subject_id'], table['hadm_id'], table['admittime']):
        visit = {
            'admission_id': hadm_id,
            'admission_time': datetime.strptime(admit_text, '%Y-%m-%d %H:%M:%S'),
        }
        visits_by_patient.setdefault(subject_id, []).append(visit)

    # A single admission leaves nothing to learn a "next visit" from.
    return {
        subject_id: sorted(visits, key=lambda visit: visit['admission_time'])
        for subject_id, visits in visits_by_patient.items()
        if len(visits) > 1
    }

In [33]:
def parse_diagnoses(path, dump_list, patient_admission: dict) -> dict:
    """Collect standardized ICD-9-style diagnosis codes per admission.

    Reads ``diagnoses_icd.csv`` under *path* and keeps rows whose patient is in
    *patient_admission* and whose code starts with 'E', 'V' or a digit. Kept
    codes are rewritten with a decimal point (after 4 characters for codes
    starting with 'E', after 3 otherwise) and discarded if the rewritten code
    is in *dump_list*.

    Args:
        path: directory containing ``diagnoses_icd.csv``.
        dump_list: iterable of standardized codes to discard.
        patient_admission: patients to keep (only key membership is used).

    Returns:
        dict mapping ``hadm_id`` -> list of codes in file order (duplicates kept).
    """
    print('parsing DIAGNOSES_ICD.csv ...')
    diagnoses_path = os.path.join(path, 'diagnoses_icd.csv')
    diagnoses = pd.read_csv(
        diagnoses_path,
        usecols=['subject_id', 'hadm_id', 'icd_code'],
        converters={ 'subject_id': int, 'hadm_id': int, 'icd_code': str }
    )

    def to_standard_icd9(code: str):
        split_pos = 4 if code.startswith('E') else 3
        return code[:split_pos] + '.' + code[split_pos:] if len(code) > split_pos else code

    # Build the lookup set once: `in` on a plain list scans the whole list for
    # every diagnosis row.
    dumped = set(dump_list)
    # Same accepted first characters as before; str.startswith takes a tuple.
    # NOTE(review): MIMIC-IV is understood to mix ICD-9 and ICD-10 codes;
    # ICD-10 codes beginning with 'E' or 'V' also pass this prefix test.
    # Filtering on an icd_version column (if present) would be stricter.
    icd9_prefixes = ('E', 'V') + tuple('0123456789')

    admission_codes = dict()
    # Column zip instead of iterrows(): same values, much less per-row overhead.
    for pid, admission_id, code in zip(diagnoses['subject_id'], diagnoses['hadm_id'], diagnoses['icd_code']):
        # Empty codes fail the prefix test and are skipped, as before.
        if pid not in patient_admission or not code.startswith(icd9_prefixes):
            continue
        code = to_standard_icd9(code)
        if code in dumped:
            continue
        admission_codes.setdefault(admission_id, []).append(code)

    return admission_codes

In [34]:
def calibrate_patient_by_admission(patient_admission: dict, admission_codes: dict):
    """Drop, in place, every patient that has an uncoded admission.

    A patient is removed from *patient_admission* when any of their
    admissions is absent from *admission_codes*; the admissions of removed
    patients that do have codes are removed from *admission_codes* too, so the
    two dicts stay consistent. Returns ``None``.
    """
    print('calibrating patients by admission ...')
    incomplete = [
        pid
        for pid, visits in patient_admission.items()
        if any(visit['admission_id'] not in admission_codes for visit in visits)
    ]
    for pid in incomplete:
        for visit in patient_admission.pop(pid):
            admission_codes.pop(visit['admission_id'], None)

In [35]:
raw_path = '../data/RAW/MIMIC_IV/'
dump_list_path = os.path.join(raw_path, 'dump_list_icd9.pkl')
# Codes to discard while parsing diagnoses (a local, trusted pickle).
with open(dump_list_path, 'rb') as f0:
    dump_list_icd9 = pickle.load(f0)

patient_admission = parse_admission(raw_path)
admission_codes = parse_diagnoses(raw_path, dump_list_icd9, patient_admission)
# Mutates both dicts in place: drops patients with any uncoded admission.
calibrate_patient_by_admission(patient_admission, admission_codes)
print('There are %d valid patients' % len(patient_admission))

parsing ADMISSIONS.csv ...
parsing DIAGNOSES_ICD.csv ...
calibrating patients by admission ...
There are 41919 valid patients


In [36]:
# Persist the cleaned cohort: patient -> sorted admissions, admission -> codes.
for out_path, obj in (
    ('../data/MIMIC_IV/patient_admission.pkl', patient_admission),
    ('../data/MIMIC_IV/admission_codes.pkl', admission_codes),
):
    with open(out_path, 'wb') as fh:
        pickle.dump(obj, fh)

In [37]:
# Padding sizes for the dense arrays built later: the longest admission
# history of any patient and the most codes recorded in one admission.
max_admission_num = max((len(admissions) for admissions in patient_admission.values()), default=0)
max_code_num_in_a_visit = max((len(codes) for codes in admission_codes.values()), default=0)

In [38]:
def encode_code(admission_codes: dict) -> (dict, dict):
    """Assign integer ids to codes and re-encode every admission with them.

    Ids start at 1 and follow first appearance while scanning
    *admission_codes*; id 0 is never assigned (later arrays use 0 as padding).
    Returns ``(admission_codes_encoded, code_map)`` where ``code_map`` maps
    code -> id and ``admission_codes_encoded`` maps admission id -> id list.
    """
    print('encoding code ...')
    code_map = {}
    for visit_codes in admission_codes.values():
        for code in visit_codes:
            # len() is evaluated before insertion, so new ids are 1, 2, 3, ...
            code_map.setdefault(code, len(code_map) + 1)

    admission_codes_encoded = {}
    for admission_id, visit_codes in admission_codes.items():
        admission_codes_encoded[admission_id] = [code_map[code] for code in visit_codes]
    return admission_codes_encoded, code_map

In [39]:
def encode_time_duration(patient_admission: dict) -> dict:
    """Return per-patient gaps, in whole days, between consecutive admissions.

    The first entry is always 0; entry k is ``timedelta.days`` between
    admission k-1 and admission k (assumes admissions are already sorted).
    """
    print('encoding time duration ...')
    encoded = {}
    for pid, visits in patient_admission.items():
        gaps = [
            (later['admission_time'] - earlier['admission_time']).days
            for earlier, later in zip(visits, visits[1:])
        ]
        encoded[pid] = [0] + gaps
    return encoded

In [40]:
def split_patients(patient_admission: dict, admission_codes: dict, code_map: dict, seed=6669,
                   train_num=40725) -> (np.ndarray, np.ndarray):
    """Split patient ids into train and test sets.

    Every code in *code_map* is guaranteed a training example: for each code,
    the first patient (in ``patient_admission`` order) with an admission
    containing it is forced into train, as is the first patient with the most
    admissions. The other patients are shuffled with the global numpy RNG
    seeded by *seed* and used to fill train up to *train_num* patients; the
    rest become test.

    Args:
        patient_admission: pid -> list of admission dicts (``'admission_id'``).
        admission_codes: admission id -> list of codes; must contain every
            admission of every patient (raises KeyError otherwise).
        code_map: codes that must be covered by training patients.
        seed: seed for ``np.random``.
        train_num: target training-set size (default is the value this
            notebook was run with). If the forced patients alone exceed it,
            train holds only them and all other patients go to test.

    Returns:
        ``(train_pids, test_pids)`` as numpy arrays.
    """
    print('splitting train, valid, and test pids')
    np.random.seed(seed)

    # One pass over all admissions: the first patient in which each code
    # occurs. Replaces a per-code scan over every patient's admissions.
    first_pid_for_code = {}
    for pid, admissions in patient_admission.items():
        for admission in admissions:
            for code in admission_codes[admission['admission_id']]:
                first_pid_for_code.setdefault(code, pid)

    # Add pids in code_map order -- the same add sequence the per-code scan
    # produced -- so set iteration order (and thus train_pids order) is kept.
    common_pids = set()
    for code in code_map:
        if code in first_pid_for_code:
            common_pids.add(first_pid_for_code[code])
    print('\r\t100%')

    if patient_admission:
        # max() returns the first maximal pid, matching a strict '>' scan.
        common_pids.add(max(patient_admission, key=lambda pid: len(patient_admission[pid])))
    remaining_pids = np.array(list(set(patient_admission.keys()).difference(common_pids)))
    np.random.shuffle(remaining_pids)

    # Clamp at 0: a negative count would slice from the end of the array.
    n_fill = max(train_num - len(common_pids), 0)
    train_pids = np.array(list(common_pids.union(set(remaining_pids[:n_fill].tolist()))))
    test_pids = remaining_pids[n_fill:]
    return train_pids, test_pids

In [41]:
# Integer-encode codes and inter-visit gaps, then split patients.
admission_codes_encoded, code_map = encode_code(admission_codes)
patient_time_duration_encoded = encode_time_duration(patient_admission)

code_num = len(code_map)  # code ids run 1..code_num; 0 is padding

train_pids, test_pids = split_patients(patient_admission, admission_codes, code_map)

encoding code ...
encoding time duration ...
splitting train, valid, and test pids
	100%00%


In [42]:
# Persist the encodings (pickle) and the patient split (.npy), in that order.
for pkl_path, obj in (
    ('../data/MIMIC_IV/code_map.pkl', code_map),
    ('../data/MIMIC_IV/admission_codes_encoded.pkl', admission_codes_encoded),
    ('../data/MIMIC_IV/patient_time_duration_encoded.pkl', patient_time_duration_encoded),
):
    with open(pkl_path, 'wb') as fh:
        pickle.dump(obj, fh)

for npy_path, arr in (
    ('../data/MIMIC_IV/train_pids.npy', train_pids),
    ('../data/MIMIC_IV/test_pids.npy', test_pids),
):
    with open(npy_path, 'wb') as fh:
        np.save(fh, arr)

In [43]:
def build_code_xy(pids: np.ndarray,
                  patient_admission: dict,
                  admission_codes_encoded: dict,
                  max_admission_num: int,
                  code_num: int,
                  max_code_num_in_a_visit: int) -> (np.ndarray, np.ndarray, np.ndarray):
    """Build padded code-history inputs and multi-hot next-visit labels.

    For the patient in row r (all admissions except the last are inputs):
      * ``x[r, t, :]`` holds the code ids of input admission t, zero-padded.
      * ``y[r, c]`` is 1 when code id ``c + 1`` occurs in the last admission.
      * ``lens[r]`` is the number of input admissions.
    """
    print('building train/test codes features and labels ...')
    total = len(pids)
    x = np.zeros((total, max_admission_num, max_code_num_in_a_visit), dtype=int)
    y = np.zeros((total, code_num), dtype=int)
    lens = np.zeros((total, ), dtype=int)
    for row, pid in enumerate(pids):
        print('\r\t%d / %d' % (row + 1, total), end='')
        visits = patient_admission[pid]
        history, target = visits[:-1], visits[-1]
        for step, visit in enumerate(history):
            visit_ids = admission_codes_encoded[visit['admission_id']]
            x[row, step, :len(visit_ids)] = visit_ids
        # Ids are 1-based; label columns are 0-based.
        target_cols = np.array(admission_codes_encoded[target['admission_id']]) - 1
        y[row, target_cols] = 1
        lens[row] = len(history)
    print('\r\t%d / %d' % (total, total))
    return x, y, lens

In [44]:
def build_time_duration_xy(pids: np.ndarray,
                           patient_time_duration_encoded: dict,
                           max_admission_num: int) -> (np.ndarray, np.ndarray):
    """Build padded inter-visit gap inputs and the final gap as the target.

    Row r of the features holds every gap except the last (zero-padded to
    *max_admission_num*); the target is the last gap. Both arrays are float.
    """
    print('building train/valid/test time duration features and labels ...')
    total = len(pids)
    features = np.zeros((total, max_admission_num))
    targets = np.zeros((total, ))
    for row, pid in enumerate(pids):
        print('\r\t%d / %d' % (row + 1, total), end='')
        gaps = patient_time_duration_encoded[pid]
        features[row, :len(gaps) - 1] = gaps[:-1]
        targets[row] = gaps[-1]
    print('\r\t%d / %d' % (total, total))
    return features, targets

In [45]:
def build_time_gaps(pids: np.ndarray,
                    patient_time_duration_encoded: dict,
                    max_admission_num: int) -> list:
    """
    Build, per patient, "days before the last input visit" for each input visit.

    The final admission (the prediction target) is excluded. For input visits
    0..k-1 the value is the number of days between that visit and visit k-1,
    so the most recent input visit is 0 and earlier visits are larger. This is
    intended for temporal weighting (more recent visits = higher weight).

    Args:
        pids: patient ids, in output order.
        patient_time_duration_encoded: pid -> list of day gaps between
            consecutive admissions, whose first entry is 0.
        max_admission_num: unused; kept so existing callers keep working.

    Returns:
        A list with one 1-D integer numpy array per pid (lengths vary; not padded).
    """
    print('building time gaps for time-aware extension ...')
    time_gaps = []
    for i, pid in enumerate(pids):
        print('\r\t%d / %d' % (i + 1, len(pids)), end='')
        duration = patient_time_duration_encoded[pid]
        # Days since the first visit for every input visit (target excluded).
        cumsum = np.cumsum(duration[:-1])
        if len(cumsum) > 0:
            days_ago = cumsum[-1] - cumsum  # last input visit = 0
        else:
            # Fewer than two admissions: no input history to measure.
            days_ago = np.array([0])
        time_gaps.append(days_ago)
    print('\r\t%d / %d' % (len(pids), len(pids)))
    return time_gaps

# Per-visit "days before the last input visit", for each split.
train_time_gaps = build_time_gaps(train_pids, patient_time_duration_encoded, max_admission_num)
test_time_gaps = build_time_gaps(test_pids, patient_time_duration_encoded, max_admission_num)

for split_name, gaps in (('train', train_time_gaps), ('test', test_time_gaps)):
    with open(f'../data/MIMIC_IV/{split_name}_time_gaps.pkl', 'wb') as fh:
        pickle.dump(gaps, fh)

first_gaps = train_time_gaps[0]
print(f"Example time gaps for first patient: {first_gaps}")
print(f"Number of visits: {len(first_gaps)}")

building time gaps for time-aware extension ...
	40725 / 40725
building time gaps for time-aware extension ...
	1194 / 1194
Example time gaps for first patient: [39  0]
Number of visits: 2


In [46]:
# Arguments shared by both splits; only the pid array differs.
_code_xy_args = (patient_admission, admission_codes_encoded, max_admission_num,
                 code_num, max_code_num_in_a_visit)
train_codes_x, train_codes_y, train_visit_lens = build_code_xy(train_pids, *_code_xy_args)
test_codes_x, test_codes_y, test_visit_lens = build_code_xy(test_pids, *_code_xy_args)

building train/test codes features and labels ...
	40725 / 40725
building train/test codes features and labels ...
	1194 / 1194


In [47]:
# Save labels, history lengths and inputs for both splits (original order kept).
for npy_path, arr in (
    ('../data/MIMIC_IV/train_codes_y.npy', train_codes_y),
    ('../data/MIMIC_IV/train_visit_lens.npy', train_visit_lens),
    ('../data/MIMIC_IV/test_codes_y.npy', test_codes_y),
    ('../data/MIMIC_IV/test_visit_lens.npy', test_visit_lens),
    ('../data/MIMIC_IV/train_codes_x.npy', train_codes_x),
    ('../data/MIMIC_IV/test_codes_x.npy', test_codes_x),
):
    with open(npy_path, 'wb') as fh:
        np.save(fh, arr)

In [48]:
def parse_icd9_range(range_: str) -> (str, str, int, int):
    """Parse an (indented) ICD-9 range such as ' 001-009' or ' V01-V09'.

    Returns ``(prefix, format_, start, end)`` where ``prefix + format_ % i``
    rebuilds the three-character code for each integer i. 'V' codes use
    two-digit numbers, 'E' codes and numeric codes three digits.
    """
    bounds = range_.lstrip().split('-')
    head = bounds[0]
    letter = head[0]
    if letter in ('V', 'E'):
        format_ = '%02d' if letter == 'V' else '%03d'
        start, end = int(head[1:]), int(bounds[1][1:])
        return letter, format_, start, end
    start = int(head)
    # NOTE(review): a single numeric code yields end = start + 1, and callers
    # iterate start..end inclusive, so the next code number is covered too.
    # Kept as-is; confirm whether end = start was intended.
    end = start + 1 if len(bounds) == 1 else int(bounds[1])
    return '', '%03d', start, end

In [49]:
def generate_code_levels(path, code_map: dict) -> np.ndarray:
    """Build a 4-level hierarchy id vector for every encoded code.

    ``icd9.txt`` under *path* is read line by line. Un-indented lines are
    headers; indented lines are ranges parsed by ``parse_icd9_range``. The
    level counters only advance when something in *code_map* is matched:

    * level 1 advances at a header line that follows at least one match,
    * level 2 advances after each range line containing a match,
    * level 3 advances per matched three-character code,
    * level 4 advances per code in ``code_map`` whose three-character prefix
      was matched (in ``code_map`` order).

    Codes whose three-character prefix is not covered get ``[0, 0, 0, 0]``.

    Returns:
        int array of shape ``(len(code_map) + 1, 4)``; row ``cid`` holds the
        levels of the code with id ``cid``; row 0 stays all zero.
    """
    print('generating code levels ...')
    three_level_code_set = set(code.split('.')[0] for code in code_map)
    icd9_path = os.path.join(path, 'icd9.txt')
    # Read via a context manager so the file handle is closed deterministically.
    with open(icd9_path, 'r', encoding='utf-8') as fh:
        icd9_range = fh.readlines()

    three_level_dict = dict()
    level1, level2, level3 = (1, 1, 1)
    level1_can_add = False
    for range_ in icd9_range:
        range_ = range_.rstrip()
        if not range_:
            # Blank lines carry no structure (range_[0] used to raise IndexError).
            continue
        if range_[0] == ' ':
            prefix, format_, start, end = parse_icd9_range(range_)
            range_has_codes = False
            for i in range(start, end + 1):
                code = prefix + format_ % i
                if code in three_level_code_set:
                    three_level_dict[code] = [level1, level2, level3]
                    level3 += 1
                    level1_can_add = True
                    range_has_codes = True
            if range_has_codes:
                level2 += 1
        elif level1_can_add:
            # Header line: start a new level-1 group only if the last one was used.
            level1 += 1
            level1_can_add = False

    level4 = 1
    code_level = dict()
    for code in code_map:
        three_level_code = code.split('.')[0]
        if three_level_code in three_level_dict:
            code_level[code] = three_level_dict[three_level_code] + [level4]
            level4 += 1
        else:
            code_level[code] = [0, 0, 0, 0]

    code_level_matrix = np.zeros((len(code_map) + 1, 4), dtype=int)
    for code, cid in code_map.items():
        code_level_matrix[cid] = code_level[code]

    return code_level_matrix

In [50]:
def generate_patient_code_adjacent(code_x: np.ndarray, code_num: int) -> np.ndarray:
    """Patient-by-code incidence matrix of shape ``(len(code_x), code_num + 1)``.

    Row r has a 1 in column c for every non-zero code id c found anywhere in
    ``code_x[r]``; column 0 (the padding id) is never set.
    """
    print('generating patient code adjacent matrix ...')
    incidence = np.zeros((len(code_x), code_num + 1), dtype=int)
    for row, patient_codes in enumerate(code_x):
        incidence[row, patient_codes[patient_codes > 0]] = 1
    return incidence

In [51]:
def generate_code_code_adjacent(code_num: int, code_level_matrix: np.ndarray) -> np.ndarray:
    """Score every pair of codes by the deepest hierarchy level they share.

    For code ids ``i != j`` in ``1..code_num`` the entry is ``s + 1``, where
    ``s`` is the largest level ``1..4`` at which rows ``i`` and ``j`` of
    *code_level_matrix* hold equal values, or 0 if no level matches, so
    entries range over 1..5. Row/column 0 (padding) and the diagonal are 0.

    Vectorized: one pairwise comparison per level instead of a Python double
    loop over ``code_num ** 2`` pairs.
    """
    print('generating code code adjacent matrix ...')
    n = code_num + 1
    result = np.zeros((n, n), dtype=int)
    levels = np.asarray(code_level_matrix)[1:n]
    pairs = result[1:, 1:]  # view: writes land in result
    pairs[...] = 1          # no shared level -> s = 0 -> 0 + 1
    # Ascending depth so deeper matches overwrite shallower ones (largest s wins).
    for depth in range(1, 5):
        column = levels[:, depth - 1]
        pairs[column[:, None] == column[None, :]] = depth + 1
    np.fill_diagonal(result, 0)
    print('\r\t%d / %d' % (n, n))
    return result

In [52]:
def co_occur(pids: np.ndarray,
             patient_admission: dict,
             admission_codes_encoded: dict,
             code_num: int) -> np.ndarray:
    """Binary code co-occurrence matrix over patients' input admissions.

    Entry ``[a, b]`` is 1.0 when codes ``a`` and ``b`` appear at two different
    positions of the same admission of some patient in *pids*; each patient's
    last admission (the prediction target) is excluded. The result is a
    symmetric float array of shape ``(code_num + 1, code_num + 1)``; a
    diagonal entry is set only if a code is listed twice in one admission.
    """
    from itertools import combinations

    print('calculating co-occurrence ...')
    x = np.zeros((code_num + 1, code_num + 1), dtype=float)
    for i, pid in enumerate(pids):
        print('\r\t%d / %d' % (i + 1, len(pids)), end='')
        for admission in patient_admission[pid][:-1]:
            codes = admission_codes_encoded[admission['admission_id']]
            # Every unordered pair of positions (m < n).
            for c_i, c_j in combinations(codes, 2):
                x[c_i, c_j] = 1
                x[c_j, c_i] = 1
    print('\r\t%d / %d' % (len(pids), len(pids)))
    return x

In [53]:
# Re-index patients contiguously: train pids -> 0..l1-1, test pids follow.
l1 = len(train_pids)
l2 = l1  # no separate validation block, so test ids start right after train
l3 = l2 + len(test_pids)
train_patient_ids = np.arange(0, l1)
test_patient_ids = np.arange(l2, l3)
pid_map = dict(zip(train_pids, train_patient_ids))
pid_map.update(zip(test_pids, test_patient_ids))

In [54]:
# Persist the original-pid -> contiguous-index mapping.
pid_map_path = '../data/MIMIC_IV/pid_map.pkl'
with open(pid_map_path, 'wb') as fh:
    pickle.dump(pid_map, fh)

In [55]:
data_path = '../data/RAW/'
code_levels = generate_code_levels(data_path, code_map)

# Patient x code incidence for training histories, minus padding column 0.
patient_code_adj = np.delete(
    generate_patient_code_adjacent(code_x=train_codes_x, code_num=code_num), 0, axis=1)
with open('../data/MIMIC_IV/patient_code_adj.npy', 'wb') as fh:
    np.save(fh, patient_code_adj)

# The pairwise hierarchy scores need the padded level matrix (row 0 kept) ...
code_code_adj_t = generate_code_code_adjacent(code_level_matrix=code_levels, code_num=code_num)
# ... while the saved level matrix drops the padding row.
code_levels = code_levels[1:]
with open('../data/MIMIC_IV/code_levels.npy', 'wb') as fh:
    np.save(fh, code_levels)

generating code levels ...
generating patient code adjacent matrix ...
generating code code adjacent matrix ...
	8134 / 8134


In [56]:
# Keep hierarchy scores only for code pairs that co-occur in a training input
# visit, then drop the padding row and column 0.
co_occur_matrix = co_occur(train_pids, patient_admission, admission_codes_encoded, code_num)
code_code_adj = (code_code_adj_t * co_occur_matrix)[1:, 1:]
with open('../data/MIMIC_IV/code_code_adj.npy', 'wb') as fh:
    np.save(fh, code_code_adj)

calculating co-occurrence ...
	40725 / 40725


In [57]:
def _binarize_visits(codes_x, visit_lens, num_patients, num_codes):
    """Return one float ``(visits, num_codes)`` multi-hot matrix per patient.

    Code id k (1-based; 0 is padding and ignored) sets column k - 1.
    """
    matrices = []
    for row in range(num_patients):
        n_visits = visit_lens[row]
        multi_hot = np.zeros((n_visits, num_codes))
        for visit in range(n_visits):
            ids = codes_x[row][visit]
            multi_hot[visit][ids[ids > 0] - 1] = 1
        matrices.append(multi_hot)
    return matrices


binary_train_codes_x = _binarize_visits(train_codes_x, train_visit_lens, len(train_pids), code_num)
with open('../data/MIMIC_IV/binary_train_codes_x.pkl', 'wb') as fh:
    pickle.dump(binary_train_codes_x, fh)

binary_test_codes_x = _binarize_visits(test_codes_x, test_visit_lens, len(test_pids), code_num)
with open('../data/MIMIC_IV/binary_test_codes_x.pkl', 'wb') as fh:
    pickle.dump(binary_test_codes_x, fh)

In [58]:
# Anchor sample per split: the first patient with the most input visits.
idx1 = max(range(len(binary_train_codes_x)),
           key=lambda k: binary_train_codes_x[k].shape[0], default=0)
target = binary_train_codes_x[idx1]
np.save('../data/MIMIC_IV/anchor_train.npy', target)

idx2 = max(range(len(binary_test_codes_x)),
           key=lambda k: binary_test_codes_x[k].shape[0], default=0)
target = binary_test_codes_x[idx2]
np.save('../data/MIMIC_IV/anchor_test.npy', target)

In [59]:
# One .npy file per patient, for each split.
for split, matrices in (('train', binary_train_codes_x), ('test', binary_test_codes_x)):
    for idx, matrix in enumerate(matrices):
        np.save(f'../data/MIMIC_IV/binary_{split}_x_slices/binary_{split}_codes_x_{idx}.npy', matrix)

In [61]:
# Reload the saved cohort and split so the statistics can run in a fresh kernel.
with open('../data/MIMIC_IV/patient_admission.pkl', 'rb') as fh:
    patient_admission = pickle.load(fh)

with open('../data/MIMIC_IV/admission_codes.pkl', 'rb') as fh:
    admission_codes = pickle.load(fh)

# The train/test patient ids were saved as .npy arrays.
train_pids, test_pids = (
    np.load(f'../data/MIMIC_IV/{split}_pids.npy') for split in ('train', 'test')
)

# Calculate statistics
# Errors that mean "this optional figure can't be computed from the raw files"
# (missing file, missing column, unparsable values, empty cohort).
_STAT_READ_ERRORS = (OSError, KeyError, ValueError, TypeError, ZeroDivisionError)


def _best_effort(label, compute):
    """Return ``compute()``, or ``None`` after printing why it failed.

    Only I/O and data-shape errors are absorbed, so an optional statistic does
    not abort the report while real bugs and interrupts still propagate.
    """
    try:
        return compute()
    except _STAT_READ_ERRORS as err:
        print(f'[stats] skipped {label}: {type(err).__name__}: {err}')
        return None


def _male_percentage(raw_path, patient_ids, total_patients):
    """Percentage of *total_patients* whose ``gender`` is 'M' in patients.csv."""
    patients_df = pd.read_csv(os.path.join(raw_path, 'patients.csv'),
                              usecols=['subject_id', 'gender'])
    in_cohort = patients_df['subject_id'].isin(patient_ids)
    gender_counts = patients_df.loc[in_cohort, 'gender'].value_counts()
    return (gender_counts.get('M', 0) / total_patients) * 100


def _length_of_stay(raw_path, admission_ids):
    """(mean, median) of dischtime - admittime, in days, over *admission_ids*."""
    admissions_df = pd.read_csv(os.path.join(raw_path, 'admissions.csv'),
                                usecols=['hadm_id', 'admittime', 'dischtime'])
    admit = pd.to_datetime(admissions_df['admittime'])
    discharge = pd.to_datetime(admissions_df['dischtime'])
    los_days = (discharge - admit).dt.total_seconds() / (24 * 3600)
    los_days = los_days[admissions_df['hadm_id'].isin(admission_ids)]
    return los_days.mean(), los_days.median()


def _mortality_rate(raw_path, admission_ids):
    """Percentage of *admission_ids* whose ``hospital_expire_flag`` is set."""
    admissions_df = pd.read_csv(os.path.join(raw_path, 'admissions.csv'),
                                usecols=['hadm_id', 'hospital_expire_flag'])
    valid = admissions_df[admissions_df['hadm_id'].isin(admission_ids)]
    return (valid['hospital_expire_flag'].sum() / len(valid)) * 100


def calculate_dataset_statistics(patient_admission, admission_codes, train_pids, test_pids,
                                 raw_path='../data/RAW/MIMIC_IV/'):
    """
    Calculate and print summary statistics for the preprocessed MIMIC-IV cohort.

    Args:
        patient_admission: pid -> list of admission dicts (``'admission_id'`` key).
        admission_codes: admission id -> list of diagnosis codes.
        train_pids, test_pids: patient ids of each split.
        raw_path: directory holding the raw ``patients.csv`` and
            ``admissions.csv`` used for the optional gender / LoS / mortality
            figures.

    Returns:
        dict of the computed statistics. Optional figures are ``None`` when
        their source file or columns could not be read (the reason is printed).

    Note: validation/test counts are a reporting-only 50/50 halving of
    *test_pids* (and of their stay count); only train/test pids are stored.
    """
    total_patients = len(patient_admission)
    train_patients = len(train_pids)
    test_patients = len(test_pids)

    val_patients = test_patients // 2
    final_test_patients = test_patients - val_patients

    total_stays = sum(len(admissions) for admissions in patient_admission.values())
    train_stays = sum(len(patient_admission[pid]) for pid in train_pids)
    test_all_stays = sum(len(patient_admission[pid]) for pid in test_pids)
    val_stays = test_all_stays // 2
    final_test_stays = test_all_stays - val_stays

    patient_ids = list(patient_admission)
    admission_ids = list(admission_codes)

    male_percentage = _best_effort(
        'gender', lambda: _male_percentage(raw_path, patient_ids, total_patients))
    # TODO: age at admission needs anchor_year from patients.csv plus the
    # admission time; not computed yet.
    mean_age = None
    los = _best_effort('length of stay', lambda: _length_of_stay(raw_path, admission_ids))
    mean_los, median_los = los if los is not None else (None, None)
    mortality_rate = _best_effort(
        'mortality', lambda: _mortality_rate(raw_path, admission_ids))

    unique_codes = len(set(code for codes in admission_codes.values() for code in codes))

    codes_per_visit = [len(codes) for codes in admission_codes.values()]
    mean_codes_per_visit = np.mean(codes_per_visit)
    median_codes_per_visit = np.median(codes_per_visit)
    max_codes_per_visit = max(codes_per_visit, default=0)  # np.max raises on []

    print("=" * 60)
    print("DATASET STATISTICS")
    print("=" * 60)

    print("\n--- Patient Statistics ---")
    print(f"Total Patients: {total_patients:,}")
    print(f"  Train: {train_patients:,}")
    print(f"  Validation: {val_patients:,}")
    print(f"  Test: {final_test_patients:,}")

    print("\n--- Stay (Admission) Statistics ---")
    print(f"Total Stays: {total_stays:,}")
    print(f"  Train: {train_stays:,}")
    print(f"  Validation: {val_stays:,}")
    print(f"  Test: {final_test_stays:,}")

    if male_percentage is not None:
        print(f"\n--- Demographics ---")
        print(f"Gender (% male): {male_percentage:.1f}%")

    if mean_age is not None:
        print(f"Age (mean): {mean_age:.1f}")

    if mean_los is not None:
        print(f"\n--- Length of Stay ---")
        print(f"LoS (mean): {mean_los:.2f} days")
        print(f"LoS (median): {median_los:.2f} days")

    if mortality_rate is not None:
        print(f"\n--- Outcomes ---")
        print(f"In-hospital mortality: {mortality_rate:.2f}%")

    print(f"\n--- Code Statistics ---")
    print(f"Number of unique diagnosis codes: {unique_codes:,}")
    print(f"Codes per visit (mean): {mean_codes_per_visit:.2f}")
    print(f"Codes per visit (median): {median_codes_per_visit:.2f}")
    print(f"Codes per visit (max): {max_codes_per_visit:,}")

    visits_per_patient = [len(admissions) for admissions in patient_admission.values()]
    print(f"\n--- Visit Statistics ---")
    print(f"Visits per patient (mean): {np.mean(visits_per_patient):.2f}")
    print(f"Visits per patient (median): {np.median(visits_per_patient):.0f}")
    print(f"Visits per patient (max): {max(visits_per_patient, default=0):,}")

    print("=" * 60)

    return {
        'total_patients': total_patients,
        'train_patients': train_patients,
        'val_patients': val_patients,
        'test_patients': final_test_patients,
        'total_stays': total_stays,
        'train_stays': train_stays,
        'val_stays': val_stays,
        'test_stays': final_test_stays,
        'male_percentage': male_percentage,
        'mean_age': mean_age,
        'mean_los': mean_los,
        'median_los': median_los,
        'mortality_rate': mortality_rate,
        'unique_codes': unique_codes,
        'mean_codes_per_visit': mean_codes_per_visit,
        'median_codes_per_visit': median_codes_per_visit,
        'max_codes_per_visit': max_codes_per_visit,
    }

# Print and collect the dataset summary for the reloaded artifacts.
stats = calculate_dataset_statistics(
    patient_admission=patient_admission,
    admission_codes=admission_codes,
    train_pids=train_pids,
    test_pids=test_pids,
)

DATASET STATISTICS

--- Patient Statistics ---
Total Patients: 41,919
  Train: 40,725
  Validation: 597
  Test: 597

--- Stay (Admission) Statistics ---
Total Stays: 151,847
  Train: 147,626
  Validation: 2,110
  Test: 2,111

--- Length of Stay ---
LoS (mean): 4.76 days
LoS (median): 2.91 days

--- Outcomes ---
In-hospital mortality: 2.18%

--- Code Statistics ---
Number of unique diagnosis codes: 8,133
Codes per visit (mean): 10.34
Codes per visit (median): 9.00
Codes per visit (max): 39

--- Visit Statistics ---
Visits per patient (mean): 3.62
Visits per patient (median): 2
Visits per patient (max): 95
