# In [None]:
import os
import shutil
import random

def split_data_with_masks(data_source_dir, mask_source_dir, train_data_dir, val_data_dir, test_data_dir,
                          train_mask_dir, val_mask_dir, test_mask_dir,
                          train_ratio, val_ratio, test_ratio, mask_suffix='_mask'):
    """
    Randomly split ``.npy`` data files into train/validation/test directories and move
    each file's mask from ``mask_source_dir`` into the mask directory of the same split.

    Files are MOVED, not copied, so the source directories are emptied of everything
    that gets split. The mask for ``name.npy`` is expected to be called
    ``name<mask_suffix>.npy``. A data file without a mask is still moved; a notice is
    printed for it.

    Args:
        data_source_dir (str): The path of the directory containing the data files.
        mask_source_dir (str): The path of the directory containing the mask files.
            May be the same directory as ``data_source_dir``.
        train_data_dir (str): The path of the directory where the training data files will be moved.
        val_data_dir (str): The path of the directory where the validation data files will be moved.
        test_data_dir (str): The path of the directory where the testing data files will be moved.
        train_mask_dir (str): The path of the directory where the corresponding training mask files will be moved.
        val_mask_dir (str): The path of the directory where the corresponding validation mask files will be moved.
        test_mask_dir (str): The path of the directory where the corresponding testing mask files will be moved.
        train_ratio (float): The proportion of data to be used for training (e.g., 0.7 for 70%).
        val_ratio (float): The proportion of data to be used for validation (e.g., 0.15 for 15%).
        test_ratio (float): The proportion of data to be used for testing (e.g., 0.15 for 15%).
            Only range-checked: the test set always receives every file left over
            after the training and validation sets have been filled.
        mask_suffix (str): Suffix added to the file name to denote the corresponding mask file.

    Raises:
        ValueError: If any ratio is outside [0, 1] or if ``train_ratio + val_ratio``
            exceeds 1. Raised before any file is touched.
    """
    data_ext = '.npy'

    # Validate up front: moving files is destructive, so never start on bad ratios.
    for ratio_name, ratio in (('train_ratio', train_ratio),
                              ('val_ratio', val_ratio),
                              ('test_ratio', test_ratio)):
        if not 0 <= ratio <= 1:
            raise ValueError(f"{ratio_name} must be between 0 and 1, got {ratio}")
    # Small tolerance so sums such as 0.85 + 0.15 are not rejected by float rounding.
    if train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            f"train_ratio + val_ratio must not exceed 1, got {train_ratio + val_ratio}"
        )

    mask_ending = mask_suffix + data_ext

    def is_data_file(name):
        """Return True for .npy files that are not themselves mask files."""
        if not name.endswith(data_ext):
            return False
        # With an empty suffix masks share the data file's name, so nothing can be
        # told apart by name and nothing is excluded.
        return not (mask_suffix and name.endswith(mask_ending))

    # Get all data files (excluding mask files). os.listdir order is unspecified, so
    # sort first: a seeded `random` then reproduces the same split.
    all_files = sorted(f for f in os.listdir(data_source_dir) if is_data_file(f))

    # Shuffle the files randomly
    random.shuffle(all_files)

    # Calculate the number of files for each set; the remainder goes to the test set.
    total_files = len(all_files)
    train_size = int(total_files * train_ratio)
    val_size = int(total_files * val_ratio)

    # Ensure the destination directories exist
    for directory in (train_data_dir, val_data_dir, test_data_dir,
                      train_mask_dir, val_mask_dir, test_mask_dir):
        os.makedirs(directory, exist_ok=True)

    # Define sets
    train_files = all_files[:train_size]
    val_files = all_files[train_size:train_size + val_size]
    test_files = all_files[train_size + val_size:]

    def move_files(file_list, data_dest_dir, mask_dest_dir):
        """Move each data file and, when present, its mask into the given directories."""
        for file_name in file_list:
            # Move the data file
            source_file = os.path.join(data_source_dir, file_name)
            destination_file = os.path.join(data_dest_dir, file_name)
            shutil.move(source_file, destination_file)
            print(f"Moved data: {source_file} -> {destination_file}")

            # Build the mask name by swapping only the trailing extension, so a
            # '.npy' occurring elsewhere in the name is left alone.
            mask_file_name = file_name[:-len(data_ext)] + mask_ending
            source_mask_file = os.path.join(mask_source_dir, mask_file_name)
            if os.path.exists(source_mask_file):
                destination_mask_file = os.path.join(mask_dest_dir, mask_file_name)
                shutil.move(source_mask_file, destination_mask_file)
                print(f"Moved mask: {source_mask_file} -> {destination_mask_file}")
            else:
                # Best effort: a missing mask does not abort the split, but say so.
                print(f"No mask found for {file_name} (expected {source_mask_file})")

    # Move files and corresponding masks to their respective directories
    move_files(train_files, train_data_dir, train_mask_dir)
    move_files(val_files, val_data_dir, val_mask_dir)
    move_files(test_files, test_data_dir, test_mask_dir)

# Example usage: split the synthesized maps and their masks into the U-Net
# train/validate/test folders. Running this MOVES the files out of the source
# directories.
data_source_dir = "/home/max/CMB_plot/code/ILC/synthesized_ILC_MW_maps"
train_data_dir = "/home/max/CMB_plot/code/U-Net/data/train_data"
val_data_dir = "/home/max/CMB_plot/code/U-Net/data/validate_data"
test_data_dir = "/home/max/CMB_plot/code/U-Net/data/test_data"

mask_source_dir = "/home/max/CMB_plot/code/U-Net/data/all_masks"
train_mask_dir = "/home/max/CMB_plot/code/U-Net/data/train_masks"
val_mask_dir = "/home/max/CMB_plot/code/U-Net/data/validate_masks"
test_mask_dir = "/home/max/CMB_plot/code/U-Net/data/test_masks"

# Define the ratios for training, validation, and testing
train_ratio = 0.7
val_ratio = 0.15
test_ratio = 0.15

# Pass everything by keyword: the function takes many same-typed path arguments,
# and leaving one out positionally (mask_source_dir) shifts all the others.
split_data_with_masks(
    data_source_dir=data_source_dir,
    mask_source_dir=mask_source_dir,
    train_data_dir=train_data_dir,
    val_data_dir=val_data_dir,
    test_data_dir=test_data_dir,
    train_mask_dir=train_mask_dir,
    val_mask_dir=val_mask_dir,
    test_mask_dir=test_mask_dir,
    train_ratio=train_ratio,
    val_ratio=val_ratio,
    test_ratio=test_ratio,
)