In [1]:
import tensorflow as tf
from pathlib import Path
import random
import numpy as np
import h5py
import glob
import random

In [2]:
# Parameters: edit these to point the notebook at a different dataset.
# expanduser() makes home-relative values such as '~/Workspace/data' work;
# without it, Path('~/...') is treated as a literal directory named '~'.
work_dir = Path('/home/abslon/Workspace/pointnet2-tf2/data').expanduser()  # path to data folder ex) ~/Workspace/data
env = 'DamBreak'  # environment name, used to build the 'data_<env>' folder name
data_names = ['positions', 'velocities']  # dataset names read from each .h5 file

In [3]:
# Derive every dataset location from work_dir and env.
data_dir = work_dir.joinpath(f'data_{env}')
stat_path = data_dir.joinpath('stat.h5')
train_dir = data_dir.joinpath('train')
# NOTE(review): 'vaild' looks like a misspelling of 'valid' — confirm it matches
# the folder name on disk (and any later uses of vaild_dir) before renaming.
vaild_dir = data_dir.joinpath('vaild')

## consume all data files - train_data, valid_data

In [4]:
def read_h5(data_names, path):
    """Load the named datasets from an HDF5 file as NumPy arrays.

    Args:
        data_names: Sequence of dataset names to read from the file.
        path: Location of the HDF5 file; passed straight to ``h5py.File``.

    Returns:
        A list of ``np.ndarray`` objects, one per entry of ``data_names``
        and in the same order.

    Raises:
        KeyError: If a requested dataset name is not present in the file.
    """
    data = []
    # The context manager closes the file even when a read fails part-way,
    # so an exception no longer leaks the open handle.
    with h5py.File(path, 'r') as hf:
        for name in data_names:
            dataset = hf.get(name)
            if dataset is None:
                # Without this check np.array(None) would silently yield a
                # 0-d object array and corrupt the sample list downstream.
                raise KeyError(f'dataset {name!r} not found in {path!r}')
            data.append(np.array(dataset))
    return data

In [5]:
# Load one file (12.h5 of rollout folder '0') to inspect what read_h5 returns.
example_data_path = train_dir.joinpath('0', '12.h5')
example_data = read_h5(data_names, str(example_data_path))
example_data  # list of data as np.ndarray

[array([[0.09320369, 0.01      , 0.02530662],
        [0.097744  , 0.01      , 0.07167163],
        [0.0566079 , 0.01      , 0.14564845],
        ...,
        [0.72503525, 1.0285343 , 0.5305556 ],
        [0.73040456, 1.0201541 , 0.5734387 ],
        [0.7225195 , 1.0150986 , 0.6175292 ]], dtype=float32),
 array([[-1.476855  ,  0.        , -0.36068556],
        [-1.0073737 ,  0.        , -0.45203176],
        [-1.4364111 ,  0.        , -0.22117494],
        ...,
        [-0.13065577, -2.0351171 , -0.08110284],
        [-0.13343096, -2.1802855 , -0.15206693],
        [-0.21788119, -2.1926022 , -0.20386575]], dtype=float32)]

In [6]:
# Sample containers. Each entry is (data, next_data), where data = [pos, vel].
train_data = []
valid_data = []

In [7]:
def sort_key(path):
    """Return the first all-digit, dot-separated part of a path's final component.

    Used as a ``key`` for ``sorted`` so that names order numerically
    (``2.h5`` before ``10.h5``) instead of lexicographically.

    Args:
        path: A path given as ``str`` or ``pathlib.Path``, e.g.
            ``'.../train/21/12.h5'`` or ``'.../train/21'``.

    Returns:
        The integer parsed from the first purely numeric part of the name
        (``12`` and ``21`` for the examples above).

    Raises:
        IndexError: If the final component has no purely numeric part.
    """
    # Path(...).name ignores a trailing separator and accepts Path objects,
    # neither of which splitting the raw string on '/' could handle.
    name = Path(path).name
    for part in name.split('.'):
        if part.isdigit():
            return int(part)
    # Same exception type the empty-list lookup used to raise, with a clearer message.
    raise IndexError(f'no numeric component in name {name!r} of path {str(path)!r}')

In [8]:
# Collect the entries directly under the train folder, ordered by numeric name.
rollout_list = glob.glob(str(train_dir.joinpath('*')))
rollout_list.sort(key=sort_key)
rollout_list

['/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/0',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/1',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/2',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/3',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/4',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/5',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/6',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/7',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/8',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/9',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/10',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/11',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/12',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/13',
 '/home/abslon/Workspace/pointnet2-tf2/data/

In [9]:
# List the .h5 files of a single rollout, ordered by their numeric file name.
example_rollout = rollout_list[21]
example_data_list = glob.glob(str(Path(example_rollout).joinpath('*.h5')))
example_data_list.sort(key=sort_key)
example_data_list

['/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/21/0.h5',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/21/1.h5',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/21/2.h5',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/21/3.h5',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/21/4.h5',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/21/5.h5',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/21/6.h5',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/21/7.h5',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/21/8.h5',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/21/9.h5',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/21/10.h5',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/21/11.h5',
 '/home/abslon/Workspace/pointnet2-tf2/data/data_DamBreak/train/21/12.h5',
 '/home/abslon/Workspace/pointnet2-

In [10]:
# Build training samples as pairs of consecutive frames:
# (data, data_nxt), where each element is the list returned by read_h5.

# Start from empty so re-running this cell does not append duplicate samples.
train_data.clear()

# Shuffle a copy of the rollout list, so rollout_list itself keeps its sorted
# order and indexing into it elsewhere stays stable across re-runs.
# NOTE(review): no seed is set here, so the ordering differs between runs.
shuffled_rollouts = list(rollout_list)
random.shuffle(shuffled_rollouts)

for ro in shuffled_rollouts:
    data_path_list = glob.glob( str(Path(ro) / '*.h5') )
    data_path_list = sorted(data_path_list, key=sort_key)

    # Read every file exactly once, in sorted order. The same loaded frame is
    # shared as the "next" half of one pair and the "current" half of the
    # following pair instead of being read from disk twice.
    frames = [read_h5(data_names, data_path) for data_path in data_path_list]

    # Pair frame i with frame i+1. The paths must not be shuffled before this
    # step: doing so would make data_nxt an arbitrary frame, not the successor.
    # The last frame has no successor, so it only appears as a data_nxt.
    pairs = list(zip(frames[:-1], frames[1:]))

    # Shuffling whole pairs randomises sample order while keeping each pair intact.
    random.shuffle(pairs)
    train_data.extend(pairs)

KeyboardInterrupt: 