In [None]:
import os
import tensorflow as tf

In [2]:
def load_dataset_paths(root_folder, eras=('modern', 'proterozoic', 'archean')):
    """Group the TFRecord shard files in ``root_folder`` by geological era.

    Each file name is expected to look like ``<era>_..._<n_samples>.<ext>``:
    the text before the first ``_`` names the era, and the integer between the
    last ``_`` and the following ``.`` is the number of samples in the shard.

    Args:
        root_folder: Directory containing the shard files.
        eras: Era names to collect, in the order they should appear in the
            returned dict. Every non-hidden entry of ``root_folder`` must start
            with one of them.

    Returns:
        Dict mapping each era to ``{'file_paths': [...],
        'file_numbers_of_samples': [...]}``. Both lists follow sorted file-name
        order so that index-based splits done downstream are reproducible.

    Raises:
        ValueError: If a file name starts with an era not listed in ``eras``.
    """
    data = {era: {'file_paths': [], 'file_numbers_of_samples': []} for era in eras}

    # os.listdir order is arbitrary and filesystem-dependent; sorting keeps the
    # train/val/test split (which is done by file index) deterministic.
    for file in sorted(os.listdir(root_folder)):
        # Skip hidden entries such as .DS_Store or .ipynb_checkpoints.
        if file.startswith('.'):
            continue

        era = file.split('_')[0]
        if era not in data:
            raise ValueError(
                f"Unrecognised era {era!r} in file name {file!r}; "
                f"expected one of {list(data)}"
            )

        data[era]['file_paths'].append(os.path.join(root_folder, file))
        data[era]['file_numbers_of_samples'].append(int(file.split('_')[-1].split('.')[0]))

    return data

In [3]:
def _first_index_reaching(numbers_of_samples, threshold):
    """Return the first index ``i`` with ``sum(numbers_of_samples[:i]) >= threshold``.

    Falls back to ``len(numbers_of_samples)`` when no index inside the list
    qualifies, so the result is always a valid integer slice boundary.
    """
    count = 0
    for index, num_samples in enumerate(numbers_of_samples):
        if count >= threshold:
            return index
        count += num_samples
    return len(numbers_of_samples)


def split_era(root_folder, train_split=0.8, val_split=0.1):
    """Compute per-era file indexes that split the shards into train/val/test.

    For every era returned by ``load_dataset_paths(root_folder)`` this adds a
    ``'split_indexes'`` entry ``{'train': t, 'val': v}`` so that the era's
    ``file_paths[:t]`` hold roughly ``train_split`` of its samples,
    ``file_paths[t:v]`` roughly ``val_split``, and ``file_paths[v:]`` the rest.
    Splitting happens at file granularity, weighted by each file's sample count.

    Args:
        root_folder: Directory with the per-era shard files.
        train_split: Fraction of each era's samples to place in the train split.
        val_split: Fraction of each era's samples to place in the val split.

    Returns:
        The dict from ``load_dataset_paths`` with ``'split_indexes'`` added to
        each era.

    Raises:
        ValueError: If a fraction is negative or they sum to more than 1.
    """
    if train_split < 0 or val_split < 0 or train_split + val_split > 1:
        raise ValueError(
            f"Invalid split fractions train_split={train_split}, val_split={val_split}; "
            "both must be >= 0 and sum to at most 1"
        )

    tf_data = load_dataset_paths(root_folder)

    # We do it for each era so that the dataset can be stratified.
    for era_data in tf_data.values():
        numbers = era_data['file_numbers_of_samples']
        total_number_of_samples = sum(numbers)

        # Build a fresh dict per era: a single shared dict made every era
        # report the same indexes.
        era_data['split_indexes'] = {
            'train': _first_index_reaching(numbers, total_number_of_samples * train_split),
            'val': _first_index_reaching(numbers, total_number_of_samples * (train_split + val_split)),
        }

    return tf_data

In [4]:
def train_val_test_split_concatenate(root_folder, output_folder='../data', train_split=0.8, val_split=0.1):
    """Concatenate per-era TFRecord shards into train/val/test TFRecord files.

    Each era is split with its own indexes from ``split_era`` (stratification),
    then the era slices are concatenated in era order and every record is
    copied into ``<output_folder>/geexhp_<split>_samples.tfrecord``.

    Args:
        root_folder: Directory with the per-era shard files.
        output_folder: Directory where the three output files are written;
            created if it does not exist.
        train_split: Fraction of each era's samples for the train split.
        val_split: Fraction of each era's samples for the val split.

    Returns:
        Dict mapping ``'train'``, ``'val'`` and ``'test'`` to the path of the
        TFRecord file written for that split.
    """
    tf_data = split_era(root_folder, train_split=train_split, val_split=val_split)

    # Slice every era with its OWN split indexes: eras can differ in shard
    # count and shard size, so reusing one era's indexes breaks stratification.
    data_types_paths = {'train': [], 'val': [], 'test': []}
    for era_data in tf_data.values():
        paths = era_data['file_paths']
        train_end = era_data['split_indexes']['train']
        val_end = era_data['split_indexes']['val']
        data_types_paths['train'] += paths[:train_end]
        data_types_paths['val'] += paths[train_end:val_end]
        data_types_paths['test'] += paths[val_end:]

    os.makedirs(output_folder, exist_ok=True)

    output_files = {}
    for data_type, tfrecord_files in data_types_paths.items():
        output_tfrecord_file = os.path.join(output_folder, f'geexhp_{data_type}_samples.tfrecord')

        # Write concatenated records to a new TFRecord file
        with tf.io.TFRecordWriter(output_tfrecord_file) as writer:
            for tfrecord_file in tfrecord_files:
                # Stream each source shard record by record
                for record in tf.data.TFRecordDataset(tfrecord_file):
                    writer.write(record.numpy())

        print(f"Concatenated TFRecord file saved to '{output_tfrecord_file}'")
        output_files[data_type] = output_tfrecord_file

    return output_files


In [None]:
# Folder with the per-era TFRecord shards; the call below writes the
# concatenated train/val/test TFRecord files built from them.
root_folder = '../data/TFRecord_data'
data = train_val_test_split_concatenate(root_folder=root_folder)