In [21]:
import os
import h5py
import numpy as np
from tqdm import tqdm

In [22]:
def split_to_patches(arr, patch_size=256, stride=256):
    """
    Cut a 2-D or channel-first 3-D array into square patches on a sliding grid.

    Patch origins advance by ``stride`` along H and W. When the grid does not
    end flush with the bottom/right border, one extra origin is added at
    ``H - patch_size`` / ``W - patch_size`` so the border is always covered
    (that last patch overlaps its neighbour more than the others do).

    arr: numpy array of shape (H, W) or (C, H, W)
    patch_size: side length of each square patch (must be > 0)
    stride: step between consecutive patch origins (must be > 0)
    return: list of patch arrays in row-major order (y outer, x inner);
            each patch is a slice of ``arr``, not a copy
    raises: ValueError if ``arr`` is not 2-D/3-D, if ``patch_size`` or
            ``stride`` is not positive, or if ``arr`` is smaller than
            ``patch_size`` along H or W
    """
    if patch_size <= 0 or stride <= 0:
        raise ValueError("patch_size and stride must be positive")

    if arr.ndim == 2:
        H, W = arr.shape
        get_patch = lambda y, x: arr[y:y+patch_size, x:x+patch_size]
    elif arr.ndim == 3:
        _, H, W = arr.shape
        get_patch = lambda y, x: arr[:, y:y+patch_size, x:x+patch_size]
    else:
        raise ValueError("Unsupported shape")

    # Without this check a negative origin would be appended below and the
    # slice would wrap around, silently producing undersized patches.
    if H < patch_size or W < patch_size:
        raise ValueError(
            f"Array of size {H}x{W} is smaller than patch_size={patch_size}"
        )

    def axis_origins(length):
        # Regular grid along one axis, plus a final origin flush with the border.
        last = length - patch_size
        origins = list(range(0, last + 1, stride))
        if origins[-1] != last:
            origins.append(last)
        return origins

    y_coords = axis_origins(H)
    x_coords = axis_origins(W)

    return [get_patch(y, x) for y in y_coords for x in x_coords]

def process_h5_file(file_path, save_dir=None, patch_size=256, stride=128):
    """
    Split the datasets of one HDF5 file into patches and write one HDF5 file per patch.

    Every top-level key is read fully into memory. Each dataset except 'EXP'
    is cut with ``split_to_patches``; 'EXP' is copied unchanged into every
    output file. Outputs are named ``<input basename>_p<index>.h5`` with the
    index zero-padded to at least three digits.

    file_path: path of the source .h5 file
    save_dir: output directory, created if missing; defaults to the directory
              containing ``file_path``
    patch_size, stride: forwarded to ``split_to_patches``
    return: None
    raises: KeyError if the file has no 'EXP' dataset;
            ValueError if there is no dataset to patch, or if the datasets
            yield different numbers of patches
    """
    with h5py.File(file_path, 'r') as f:
        # Load everything into memory so the input file can be closed right away.
        data_dict = {k: f[k][()] for k in f.keys()}

    # Fail before any output file is created rather than halfway through writing.
    if 'EXP' not in data_dict:
        raise KeyError(f"'EXP' dataset not found in {file_path}")

    # Split every dataset except 'EXP' into patches.
    patch_lists = [
        (key, split_to_patches(data, patch_size, stride))
        for key, data in data_dict.items()
        if key != 'EXP'
    ]
    if not patch_lists:
        raise ValueError(f"No datasets to patch in {file_path}")

    # Patch i of every dataset goes into the same output file, so all datasets
    # must produce the same number of patches.
    num_patches = len(patch_lists[0][1])
    mismatched = [key for key, patches in patch_lists if len(patches) != num_patches]
    if mismatched:
        raise ValueError(
            f"Datasets {mismatched} in {file_path} do not yield {num_patches} patches"
        )

    # Save the patches.
    base_name = os.path.splitext(os.path.basename(file_path))[0]
    if save_dir is None:
        save_dir = os.path.dirname(file_path)
    if save_dir:  # an empty string means the current directory; nothing to create
        os.makedirs(save_dir, exist_ok=True)

    for i in range(num_patches):
        save_path = os.path.join(save_dir, f"{base_name}_p{i:03d}.h5")
        with h5py.File(save_path, 'w') as f_out:
            f_out.create_dataset('EXP', data=data_dict['EXP'])
            for key, patches in patch_lists:
                f_out.create_dataset(key, data=patches[i])

def process_all_h5_in_dir(h5_dir, save_dir, patch_size=256, stride=128):
    """
    Run ``process_h5_file`` on every ``.h5`` entry directly inside ``h5_dir``.

    h5_dir: directory that is listed (non-recursively) for names ending in '.h5'
    save_dir: directory passed to ``process_h5_file`` for the patch files;
              created here if it does not exist yet
    patch_size, stride: forwarded to ``process_h5_file``
    return: None
    """
    # os.listdir returns entries in arbitrary order; sort so runs are reproducible.
    # The list is built before any output is written, so files produced during
    # this run are never picked up as inputs, even if save_dir equals h5_dir.
    h5_files = sorted(
        os.path.join(h5_dir, name)
        for name in os.listdir(h5_dir)
        if name.endswith('.h5')
    )

    if save_dir:  # None / '' are passed through to process_h5_file unchanged
        os.makedirs(save_dir, exist_ok=True)

    for path in tqdm(h5_files, desc="Processing H5 files"):
        process_h5_file(path, save_dir, patch_size=patch_size, stride=stride)

In [26]:
# Patch every .h5 file under aligned/Training.
# With stride 128 and patch_size 256, neighbouring patches overlap by half.
process_all_h5_in_dir(
    h5_dir='/home/jaekim/ws/data/Kalantari/HDF/aligned/Training',
    save_dir='/home/jaekim/ws/data/Kalantari/HDF/aligned/Training_patch',
    patch_size=256,
    stride=128,
)

Processing H5 files: 100%|██████████| 74/74 [07:26<00:00,  6.04s/it]


In [27]:
import shutil

# Copy the aligned training patches to /data2.
src_dir = "/home/jaekim/ws/data/Kalantari/HDF/aligned/Training_patch"
dst_dir = "/data2/Kalantari/aligned/Training_patch"
# dirs_exist_ok=True (Python 3.8+) makes this cell re-runnable: without it,
# copytree raises FileExistsError once dst_dir exists, e.g. after an earlier
# or interrupted run. Existing files in dst_dir are overwritten from src_dir.
shutil.copytree(src_dir, dst_dir, dirs_exist_ok=True)

'/data2/Kalantari/aligned/Training_patch'

In [28]:
# Patch every .h5 file under org/Training.
# With stride 128 and patch_size 256, neighbouring patches overlap by half.
process_all_h5_in_dir(
    h5_dir='/home/jaekim/ws/data/Kalantari/HDF/org/Training',
    save_dir='/home/jaekim/ws/data/Kalantari/HDF/org/Training_patch',
    patch_size=256,
    stride=128,
)

Processing H5 files: 100%|██████████| 74/74 [06:19<00:00,  5.13s/it]


In [29]:
# Copy the org training patches to /data2.
src_dir = "/home/jaekim/ws/data/Kalantari/HDF/org/Training_patch"
dst_dir = "/data2/Kalantari/Training_patch"
# dirs_exist_ok=True (Python 3.8+) lets an interrupted copy be resumed by
# re-running the cell: without it, copytree raises FileExistsError because
# dst_dir was already created. Existing files in dst_dir are overwritten.
shutil.copytree(src_dir, dst_dir, dirs_exist_ok=True)

KeyboardInterrupt: 