In [None]:
# Step 1: flatten the multi-level directory tree of .mat k-space files into .h5 files.
# NOTE(review): this cell expects argparse, os, glob, os.path.join, numpy (np), torch, h5py,
# tqdm and the project helpers remove_bad_files, load_kdata, remove_bad_slices, to_tensor,
# ifft2c and rss_complex to be imported elsewhere — confirm before running.
# NOTE(review): parse_args() reads sys.argv; inside a Jupyter kernel that is the kernel's own
# command line, so run this as a script or pass an explicit list to parse_args().
parser = argparse.ArgumentParser(description='Prepare H5 dataset for CMRxRecon series dataset')
parser.add_argument('--output_h5_folder', type=str,
                    default='/common/users/bx64/dataset/CMRxRecon2024/h5_dataset',
                    help='path to save H5 dataset')
parser.add_argument('--input_matlab_folder', type=str,
                    default='/common/users/bx64/dataset/CMRxRecon2024/home2/Raw_data/MICCAIChallenge2024/ChallengeData/MultiCoil',
                    help='path to the original matlab data')
parser.add_argument('--split_json', type=str, default='configs/data_split/cmr24-cardiac.json', help='path to the split json file')
parser.add_argument('--year', type=int, required=True, choices=[2024, 2023, 2025], help='year of the dataset')
args = parser.parse_args()

save_folder = args.output_h5_folder
mat_folder = args.input_matlab_folder
year = args.year
# --split_json is still accepted on the command line but is not used by this step.

print('matlab data folder: ', mat_folder)
print('h5 save folder: ', save_folder)

os.makedirs(save_folder, exist_ok=True)

print('## step 1: convert matlab training dataset to h5 dataset')

# Layout under mat_folder: <modal>/TrainingSet/FullSample/Center*/<device>/P*/<file>.mat
file_list = sorted(glob.glob(join(mat_folder, '*/TrainingSet/FullSample/Center*/*/P*/*.mat')))
print('number of total matlab files: ', len(file_list))

# Keep the preference for the second GPU, but fall back to the first one when only one GPU is
# present (torch.device('cuda:1') would then fail on first use) and to the CPU without CUDA.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

for ff in tqdm(file_list):
    # Metadata comes from the path relative to mat_folder. The glob pattern above fixes the
    # depth at exactly seven components, so this works whatever the input folder is called.
    (modal, TrainingSet, FullSample, center, mridevice, paid,
     filename) = os.path.relpath(ff, mat_folder).split(os.sep)
    ftype = os.path.splitext(filename)[0]  # file name without extension
    save_name = f'{modal}@{TrainingSet}@{FullSample}@{center}@{mridevice}@{paid}@{ftype}'

    # Skip files for which remove_bad_files returns True (only for the 2024 release).
    if remove_bad_files(save_name) and year == 2024:
        continue

    kdata = load_kdata(ff)

    # Heuristic: a leading dimension > 100 is taken to mean the array is still in MATLAB
    # (reversed) axis order, so all axes are reversed.
    if kdata.shape[0] > 100:
        kdata = kdata.transpose(tuple(range(kdata.ndim)[::-1]))

    # Swap the last two axes (phase-encoding <-> readout).
    kdata = kdata.swapaxes(-1, -2)

    # Drop slices via remove_bad_slices (only for the 2024 release).
    if year == 2024:
        kdata = remove_bad_slices(kdata, save_name)

    # Root-sum-of-squares image reduced over dim=-3 (presumably the coil axis — confirm).
    # The tensor is moved to `device` before the inverse FFT so the FFT also runs there.
    # NOTE(review): assumes ifft2c / rss_complex are device-agnostic torch ops — confirm.
    img_coil = ifft2c(to_tensor(kdata).to(device))
    img_rss = rss_complex(img_coil, dim=-3).cpu().numpy()

    # The context manager closes the file even if a write raises part-way through.
    with h5py.File(join(save_folder, save_name + '.h5'), 'w') as h5f:
        h5f.create_dataset('kspace', data=kdata)
        h5f.create_dataset('reconstruction_rss', data=img_rss)

        h5f.attrs['max'] = img_rss.max()
        h5f.attrs['norm'] = np.linalg.norm(img_rss)
        h5f.attrs['acquisition'] = modal
        h5f.attrs['shape'] = kdata.shape
        h5f.attrs['padding_left'] = 0
        h5f.attrs['padding_right'] = kdata.shape[-1]
        h5f.attrs['encoding_size'] = (kdata.shape[-2], kdata.shape[-1], 1)
        h5f.attrs['recon_size'] = (kdata.shape[-2], kdata.shape[-1], 1)
        h5f.attrs['patient_id'] = paid
        h5f.attrs['center'] = center
        h5f.attrs['mridevice'] = mridevice

In [2]:
# step2: 生成数据集分割json

import glob
import json
import os
from sklearn.model_selection import train_test_split

def generate_dataset_split_json(input_h5file_path, output_json_path="dataset_split.json", test_size=0.2,
                                random_seed=42, group_by_patient=False):
    """Split the .h5 files in a directory into train/val lists and save them as JSON.

    Args:
        input_h5file_path (str): Directory scanned (non-recursively) for ``*.h5`` files.
        output_json_path (str): Path of the JSON file to write; missing parent directories
            are created.
        test_size (float): Fraction held out for validation (default 0.2, i.e. 20%). With
            ``group_by_patient=True`` it is the fraction of patients rather than of files.
        random_seed (int): ``random_state`` passed to ``train_test_split`` (default 42).
        group_by_patient (bool): If True, every file of one patient lands on the same side of
            the split, avoiding train/val leakage. The patient key is
            ``center@mridevice@paid`` taken from names of the form
            ``modal@TrainingSet@FullSample@center@mridevice@paid@ftype.h5``. Any other name
            is treated as its own group.

    Returns:
        None. Writes ``{"train": [...], "val": [...]}`` (paths as produced by glob) to
        ``output_json_path``. If no .h5 file is found, prints a message and writes nothing.
    """
    # Sort so that, for a given seed, the split depends only on the file names and not on the
    # arbitrary order in which the filesystem lists the directory.
    h5_files = sorted(glob.glob(os.path.join(input_h5file_path, "*.h5")))

    if not h5_files:
        print("当前路径下没有找到任何 .h5 文件！")
        return

    if group_by_patient:
        def patient_key(path):
            # Expected name: modal@TrainingSet@FullSample@center@mridevice@paid@ftype.h5
            parts = os.path.splitext(os.path.basename(path))[0].split("@")
            return "@".join(parts[3:6]) if len(parts) == 7 else path

        key_of = {f: patient_key(f) for f in h5_files}
        train_keys, _ = train_test_split(sorted(set(key_of.values())),
                                         test_size=test_size, random_state=random_seed)
        train_keys = set(train_keys)
        train_files = [f for f in h5_files if key_of[f] in train_keys]
        val_files = [f for f in h5_files if key_of[f] not in train_keys]
    else:
        train_files, val_files = train_test_split(h5_files, test_size=test_size, random_state=random_seed)

    dataset_split = {
        "train": train_files,
        "val": val_files
    }

    # Create the destination folder (e.g. configs/data_split) if it does not exist yet.
    out_dir = os.path.dirname(output_json_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(output_json_path, "w", encoding="utf-8") as json_file:
        json.dump(dataset_split, json_file, indent=4)

    print(f"数据集分割已完成，结果已保存到 {output_json_path}")

# Build the train/val split JSON for the converted CMRxRecon2025 .h5 files.
H5_DATASET_DIR = "/home/hulabdl/CMRxRecon2025/h5_dataset"
SPLIT_JSON_PATH = "configs/data_split/cmr25-cardiac.json"

generate_dataset_split_json(
    input_h5file_path=H5_DATASET_DIR,
    output_json_path=SPLIT_JSON_PATH,
    test_size=0.2,
    random_seed=42,
)

数据集分割已完成，结果已保存到 configs/data_split/cmr25-cardiac.json


# step3: 根据json生成软连接
import os
import json
import glob
from os.path import join, split, realpath

def create_symbolic_links(split_json_path, save_folder):
    """Create ``train/`` and ``val/`` folders of symbolic links from a split JSON file.

    Reads a JSON object with "train" and "val" lists of file paths and, for every listed file,
    creates ``<save_folder>/<subset>/<basename>`` pointing at the file's real absolute path.
    The function can be re-run: a symlink already at a destination is replaced.

    Args:
        split_json_path (str): Path to the JSON file with "train" and "val" path lists.
        save_folder (str): Directory under which ``train/`` and ``val/`` are created.

    Raises:
        FileExistsError: If a non-symlink already occupies a destination path, or two entries
            of the same subset resolve to the same file name.
        KeyError: If the JSON lacks a "train" or "val" key.

    NOTE: links from an earlier run whose files are no longer listed are not removed.
    """
    with open(split_json_path, 'r', encoding="utf-8") as f:
        split_dict = json.load(f)

    print('JSON 文件中训练文件数量: ', len(split_dict['train']))
    print('JSON 文件中验证文件数量: ', len(split_dict['val']))

    train_save_folder = join(save_folder, 'train')
    val_save_folder = join(save_folder, 'val')
    os.makedirs(train_save_folder, exist_ok=True)
    os.makedirs(val_save_folder, exist_ok=True)

    def link_into(paths, dest_folder):
        # Symlink each path (resolved to its real absolute path) into dest_folder.
        created = set()
        for path in paths:
            # Relative entries are resolved against the current working directory.
            target = realpath(path)
            link_path = join(dest_folder, split(target)[-1])
            if link_path in created:
                raise FileExistsError(f'duplicate file name in split: {link_path}')
            # Replace a link left by an earlier run so the cell can be re-executed. A regular
            # file at link_path is not touched, and os.symlink raises FileExistsError instead.
            if os.path.islink(link_path):
                os.remove(link_path)
            os.symlink(target, link_path)
            created.add(link_path)

    link_into(split_dict['train'], train_save_folder)
    link_into(split_dict['val'], val_save_folder)

    print('完成！')
    print('训练集文件夹中的符号链接文件数量: ', len(glob.glob(join(train_save_folder, '*.h5'))))
    print('验证集文件夹中的符号链接文件数量: ', len(glob.glob(join(val_save_folder, '*.h5'))))

# Materialise the split as train/ and val/ symlink folders next to the h5 dataset.
LINK_SPLIT_JSON = "configs/data_split/cmr25-cardiac.json"
LINK_ROOT_DIR = "/home/hulabdl/CMRxRecon2025/"
create_symbolic_links(split_json_path=LINK_SPLIT_JSON, save_folder=LINK_ROOT_DIR)