In [None]:
import os
import pandas as pd
import numpy as np
from math import floor

In [14]:
# Build an overview of every subject in the five ECG databases: one row per
# record ID listed in each database's RECORDS file, tagged with the database
# name and its CHF label (1 = congestive-heart-failure database, 0 = no CHF).
# The random train/validation/test split is generated in the next cell.

# Define the folder paths and CHF status for each database
database_info = {
    "BIDMC-CHF": {"path": "../data/BIDMC-CHF_bidmc-congestive-heart-failure-database-1.0.0/files/", "CHF": 1},
    "CHF-RR": {"path": "../data/CHF-RR_congestive-heart-failure-rr-interval-database-1.0.0/", "CHF": 1},
    "NSR": {"path": "../data/NSR_mit-bih-normal-sinus-rhythm-database-1.0.0/", "CHF": 0},
    "NSR-RR": {"path": "../data/NSR-RR_normal-sinus-rhythm-rr-interval-database-1.0.0/", "CHF": 0},
    "Fantasia": {"path": "../data/FD_fantasia-database-1.0.0/", "CHF": 0},
}

# One dict per subject; turned into a DataFrame once at the end
data = []

for db_name, info in database_info.items():
    db_path = info["path"]
    chf_status = info["CHF"]

    records_file = os.path.join(db_path, "RECORDS")
    if not os.path.exists(records_file):
        # Make a missing database visible instead of silently dropping it from the overview
        print(f"Warning: no RECORDS file found for {db_name} at {records_file}; database skipped.")
        continue

    with open(records_file, "r") as f:
        # One record ID per line; strip stray whitespace and ignore blank lines
        subjects = [line.strip() for line in f if line.strip()]

    for subject in subjects:
        data.append({"subject_id": subject, "database": db_name, "CHF": chf_status})

# Explicit columns keep the CSV readable (with a header) even if no subjects were found
df = pd.DataFrame(data, columns=["subject_id", "database", "CHF"])

# Save the overview for the split-generation cell
output_folder = "../data/dataset_splits"
os.makedirs(output_folder, exist_ok=True)
df.to_csv(os.path.join(output_folder, "all_subjects_overview.csv"), index=False)

Unnamed: 0,subject_id,database,CHF
0,chf01,BIDMC-CHF,1
1,chf02,BIDMC-CHF,1
2,chf03,BIDMC-CHF,1
3,chf04,BIDMC-CHF,1
4,chf05,BIDMC-CHF,1
...,...,...,...
151,f2y06,Fantasia,0
152,f2y07,Fantasia,0
153,f2y08,Fantasia,0
154,f2y09,Fantasia,0


In [16]:
# Generate a new set of splits

def _allocate_test_counts(db_sizes, total):
    """Spread `total` test slots over databases as evenly as their sizes allow.

    Args:
        db_sizes: Number of available subjects per database, in allocation order.
        total: Number of test subjects wanted for the whole class.

    Returns:
        One count per database. The first `total % len(db_sizes)` databases get
        one extra slot. Slots a database is too small to take are handed out
        round-robin to databases with spare subjects. The counts sum to
        `min(total, sum(db_sizes))`.
    """
    if not db_sizes:
        return []
    base, remainder = divmod(total, len(db_sizes))
    counts = [min(base + (1 if i < remainder else 0), size) for i, size in enumerate(db_sizes)]
    shortfall = total - sum(counts)
    while shortfall > 0:
        progressed = False
        for i, size in enumerate(db_sizes):
            if shortfall == 0:
                break
            if counts[i] < size:
                counts[i] += 1
                shortfall -= 1
                progressed = True
        if not progressed:
            break  # every database is exhausted; the class simply has too few subjects
    return counts


def split_data(df, included_dbs, test_set_size_each_class, validation_percentage, random_state=42):
    """Split subjects into training, validation and test sets.

    The test set is drawn per class (CHF=0 and CHF=1). For each class,
    `test_set_size_each_class` subjects are spread as evenly as possible over
    the databases of that class. The validation set is then drawn per database
    from the subjects that remain.

    Args:
        df: Subject overview with at least the columns 'database' and 'CHF'.
        included_dbs: Database names to keep; rows of other databases are ignored.
        test_set_size_each_class: Number of test subjects per class. Example:
            3 gives 3 CHF subjects and 3 NSR subjects. It is fewer only if the
            class does not contain that many subjects.
        validation_percentage: Fraction (0-1) of each database's remaining
            subjects moved to the validation set, rounded down.
        random_state: Seed passed to every `DataFrame.sample` call, so a split
            can be reproduced from the recorded seed.

    Returns:
        Tuple (training_set, validation_set, test_set). Each DataFrame keeps
        `df`'s columns and original index labels, even when it is empty.
    """
    # Filter the dataframe to include only specified databases
    df = df[df['database'].isin(included_dbs)].copy()

    # --- Test set: drawn per class, balanced across that class's databases ---
    test_parts = []
    for chf_status in [0, 1]:
        class_df = df[df['CHF'] == chf_status]
        class_dbs = class_df['database'].unique()
        if len(class_dbs) == 0:
            # No included database has this class (previously a ZeroDivisionError)
            continue

        db_subsets = [class_df[class_df['database'] == db] for db in class_dbs]
        counts = _allocate_test_counts([len(s) for s in db_subsets], test_set_size_each_class)
        for db_subset, n_test in zip(db_subsets, counts):
            test_parts.append(db_subset.sample(n=n_test, random_state=random_state))

    # iloc[0:0] keeps the columns, so an empty split still writes a CSV header
    test_set = pd.concat(test_parts) if test_parts else df.iloc[0:0]

    # Remaining data for training
    training_set = df.drop(test_set.index)

    # --- Validation set: drawn per database from the training data ---
    validation_parts = []
    for database in training_set['database'].unique():
        db_training_subset = training_set[training_set['database'] == database]
        n_validation = floor(len(db_training_subset) * validation_percentage)
        validation_parts.append(db_training_subset.sample(n=n_validation, random_state=random_state))

    validation_set = pd.concat(validation_parts) if validation_parts else training_set.iloc[0:0]

    # Update training set by removing validation subjects
    training_set = training_set.drop(validation_set.index)

    return training_set, validation_set, test_set

def get_next_folder_name(base_folder_name, directory="."):
    """
    Create the next free folder in a numbered sequence and return its path.

    Tries "<base_folder_name>1", "<base_folder_name>2", ... inside `directory`
    and creates the first one that does not exist yet. `directory` itself is
    created if it is missing.

    Args:
        base_folder_name: Prefix of the folder name (e.g., "generated_splits_").
        directory: Directory holding the numbered folders (default: current directory).

    Returns:
        The path of the newly created folder, e.g. ".../generated_splits_3".
        If no folder with that prefix exists yet, the "...1" folder is created
        and returned.
    """
    i = 1
    while True:
        folder_path = os.path.join(directory, f"{base_folder_name}{i}")
        try:
            # Create-or-fail in one step: no gap between "does it exist?" and "create it",
            # and anything already at this path (folder, file, dangling link) is skipped.
            os.makedirs(folder_path)
        except FileExistsError:
            i += 1
            continue
        return folder_path

# Various (sub)sets of databases
all_databases = ["BIDMC-CHF", "NSR", "Fantasia", "NSR-RR", "CHF-RR"]
subset_1 = ["BIDMC-CHF", "NSR", "Fantasia"]
subset_2 = ["NSR-RR", "CHF-RR"]

# Every run writes into its own numbered folder under this directory
output_dir = "../data/dataset_splits"

# Re-read the overview from disk so this cell does not depend on the previous cell's in-memory state
df = pd.read_csv(os.path.join(output_dir, "all_subjects_overview.csv"))

# Draw a fresh seed for this run; it is saved to metadata.txt so the splits can be reproduced
rand_state = np.random.randint(0, 10**6)


def _save_splits(folder, prefix, training_set, validation_set, test_set):
    """Write one split triple as <prefix>_{training,validation,test}_set.csv into `folder`."""
    for split_name, split_df in (("training", training_set),
                                 ("validation", validation_set),
                                 ("test", test_set)):
        split_df.to_csv(os.path.join(folder, f"{prefix}_{split_name}_set.csv"), index=False)


# Compute all splits before creating the output folder, so a failing split
# does not leave a half-filled numbered folder behind.

## Subset 1
training_set_1, validation_set_1, test_set_1 = split_data(
    df,
    included_dbs=subset_1,
    test_set_size_each_class=3,
    validation_percentage=0.2,
    random_state=rand_state,
)

## Subset 2
training_set_2, validation_set_2, test_set_2 = split_data(
    df,
    included_dbs=subset_2,
    test_set_size_each_class=6,
    validation_percentage=0.2,
    random_state=rand_state,
)

## All databases
training_set_3, validation_set_3, test_set_3 = split_data(
    df,
    included_dbs=all_databases,
    test_set_size_each_class=9,
    validation_percentage=0.2,
    random_state=rand_state,
)

output_path = get_next_folder_name(base_folder_name="generated_splits_", directory=output_dir)
print(f"Output folder path: {output_path}")

# Save the random state as metadata in the output folder
metadata_file = os.path.join(output_path, "metadata.txt")
with open(metadata_file, "w") as f:
    f.write(f"random_state: {rand_state}\n")

_save_splits(output_path, "db1", training_set_1, validation_set_1, test_set_1)
_save_splits(output_path, "db2", training_set_2, validation_set_2, test_set_2)
_save_splits(output_path, "all_dbs", training_set_3, validation_set_3, test_set_3)

Output folder path: ../data/dataset_splits\generated_splits_2


Unnamed: 0,subject_id,database,CHF
59,19093,NSR,0
58,19090,NSR,0
142,f2o07,Fantasia,0
7,chf08,BIDMC-CHF,1
14,chf15,BIDMC-CHF,1
0,chf01,BIDMC-CHF,1


Unnamed: 0,subject_id,database,CHF
65,nsr004,NSR-RR,0
82,nsr021,NSR-RR,0
69,nsr008,NSR-RR,0
103,nsr042,NSR-RR,0
74,nsr013,NSR-RR,0
79,nsr018,NSR-RR,0
39,chf225,CHF-RR,1
31,chf217,CHF-RR,1
41,chf227,CHF-RR,1
43,chf229,CHF-RR,1


Unnamed: 0,subject_id,database,CHF
59,19093,NSR,0
58,19090,NSR,0
61,19830,NSR,0
65,nsr004,NSR-RR,0
82,nsr021,NSR-RR,0
69,nsr008,NSR-RR,0
142,f2o07,Fantasia,0
119,f1o04,Fantasia,0
133,f1y08,Fantasia,0
7,chf08,BIDMC-CHF,1


In [17]:
# Optional: append a suffix to the names of dataset-split CSV files, e.g. to tag
# an older generation of splits.
# NOTE(review): only files directly inside `directory` are considered. The cell
# above writes new splits into numbered "generated_splits_N" subfolders, so point
# `directory` at one of those if that is what should be renamed.

# Suffix inserted before ".csv". Change it to whatever you'd like it to be.
string_to_append = "_2"

# Directory containing the files
directory = "../data/dataset_splits"

for filename in os.listdir(directory):
    # Only touch split files: names that start with "db" and end with ".csv"
    if not (filename.startswith("db") and filename.endswith(".csv")):
        continue

    base_name, ext = os.path.splitext(filename)
    if base_name.endswith(string_to_append):
        print(f"Skipped: {filename} (already ends with '{string_to_append}')")
        continue

    new_name = f"{base_name}{string_to_append}{ext}"
    new_path = os.path.join(directory, new_name)
    if os.path.exists(new_path):
        # os.rename overwrites silently on POSIX (and raises on Windows): never clobber an existing split
        print(f"Skipped: {filename} (target {new_name} already exists)")
        continue

    os.rename(os.path.join(directory, filename), new_path)
    print(f"Renamed: {filename} -> {new_name}")
print("Renaming complete.")

Renamed: db1_test_set_1.csv -> db1_test_set_1_2.csv
Renamed: db1_training_set_1.csv -> db1_training_set_1_2.csv
Renamed: db1_validation_set_1.csv -> db1_validation_set_1_2.csv
Renamed: db2_test_set_1.csv -> db2_test_set_1_2.csv
Renamed: db2_training_set_1.csv -> db2_training_set_1_2.csv
Renamed: db2_validation_set_1.csv -> db2_validation_set_1_2.csv
Renaming complete.
