In [1]:
import os
import shutil
import numpy as np
from itertools import groupby
from tqdm import tqdm

# Names of the directories (same location)
source_dirs = ['data/third_data']
destination_dir = 'data/third_data_split/'

# Set to an int to make the random split reproducible; None (the default)
# leaves NumPy's global generator unseeded, so every run splits differently.
random_seed = None

# Destination directories, one per split
train_dir = os.path.join(destination_dir, 'Train')
eval_dir = os.path.join(destination_dir, 'Eval')
test_dir = os.path.join(destination_dir, 'Test')


def group_files_by_row(files):
    """Group file names that share the same second '_'-separated field.

    The field is ``name.split('_')[1]`` (the "row number" of the file).
    Grouping uses a dict rather than ``itertools.groupby`` so that files of
    one row end up together even when they are not adjacent in sorted
    order, e.g. ``img_1_a`` and ``mask_1_a``.

    Args:
        files: iterable of file names.

    Returns:
        A list of groups (lists of names). Groups appear in the order in
        which their row is first met in the sorted names; names inside a
        group are sorted.

    Raises:
        IndexError: if a name contains no underscore.
    """
    groups = {}
    for name in sorted(files):
        groups.setdefault(name.split('_')[1], []).append(name)
    return list(groups.values())


def split_groups(groups, train_fraction=0.7, eval_fraction=0.15):
    """Cut ``groups`` into consecutive train / eval / test slices.

    Both boundaries are truncated with ``int()``, so any remainder goes to
    the test slice.

    Args:
        groups: sequence to split (shuffle it beforehand for a random split).
        train_fraction: share of the groups assigned to training.
        eval_fraction: share of the groups assigned to evaluation.

    Returns:
        Tuple ``(train, eval, test)`` of slices of ``groups``.
    """
    n_groups = len(groups)
    train_end = int(train_fraction * n_groups)
    eval_end = int(eval_fraction * n_groups) + train_end
    return groups[:train_end], groups[train_end:eval_end], groups[eval_end:]


def copy_files(file_groups, destination, source_dir=None):
    """Copy every file of every group from ``source_dir`` into ``destination``.

    Args:
        file_groups: iterable of groups, each an iterable of file names.
        destination: directory that receives the copies.
        source_dir: directory holding the files. When omitted, the
            module-level ``source`` of the split loop is used, which keeps
            the original two-argument call working.
    """
    if source_dir is None:
        source_dir = source
    for file_group in file_groups:
        for file in file_group:
            shutil.copy(os.path.join(source_dir, file), os.path.join(destination, file))


if __name__ == "__main__":
    # Read every source listing BEFORE touching the destination, so that a
    # missing or empty source cannot wipe a previously generated split.
    source_listings = []
    for source in source_dirs:
        # Only regular files can be copied; sub-directories are ignored.
        files = [name for name in os.listdir(source)
                 if os.path.isfile(os.path.join(source, name))]
        if not files:
            print("Source directory is empty.")
            # Non-zero status: nothing was split and nothing was deleted.
            raise SystemExit(1)
        source_listings.append((source, files))

    # Create the destination directories, or empty the existing ones
    for split_dir in [train_dir, eval_dir, test_dir]:
        if not os.path.exists(split_dir):
            os.makedirs(split_dir)
            continue
        stale_entries = os.listdir(split_dir)
        if stale_entries:
            # Replace the existing files with new split
            print("Destination is not empty. Replacing with new splitted files.")
        for entry in stale_entries:
            entry_path = os.path.join(split_dir, entry)
            # os.remove cannot delete directories; leave them in place.
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.remove(entry_path)

    if random_seed is not None:
        np.random.seed(random_seed)

    # Totals over all sources, reported once the copy is finished
    train_groups, eval_groups, test_groups = [], [], []

    # Go through each directory in source
    for source, files in tqdm(source_listings):
        # Group all the files with the same row number
        file_groups = group_files_by_row(files)
        print(f"Number of groups in {source}: {len(file_groups)}")

        # Random split: shuffle whole groups so a row never straddles splits
        np.random.shuffle(file_groups)

        # Split the list into train, eval and test (70%, 15%, 15%)
        source_train, source_eval, source_test = split_groups(file_groups)

        # Copy the file groups to the split directories
        copy_files(source_train, train_dir, source)
        copy_files(source_eval, eval_dir, source)
        copy_files(source_test, test_dir, source)

        train_groups.extend(source_train)
        eval_groups.extend(source_eval)
        test_groups.extend(source_test)

    print(f"Number of splitted groups in training set: {len(train_groups)}")
    print(f"Number of splitted groups in evaluation set: {len(eval_groups)}")
    print(f"Number of splitted groups in test set: {len(test_groups)}")

    print("Split complete.")

  0%|          | 0/1 [00:00<?, ?it/s]

Number of groups in data/third_data: 10170


100%|██████████| 1/1 [01:25<00:00, 85.82s/it]

Number of splitted groups in training set: 7119
Number of splitted groups in evaluation set: 1525
Number of splitted groups in test set: 1526
Split complete.



