In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import time
import glob
import numpy as np
import pandas as pd
import h5py
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import lightning as L
from collections import OrderedDict
from electric_images_dataset import ElectricImagesDataset
from EndToEndConvNN import EndToEndConvNN
from EndToEndConvNN_PL import EndToEndConvNN_PL

import sys
# Register the efish-physics-model source folders on the import search path
# (appended at the end, in this order).
sys.path.extend(
    [
        "../../efish-physics-model/objects",
        "../../efish-physics-model/helper_functions",
        "../../efish-physics-model/uniform_points_generation",
    ]
)

In [None]:
# Directory that contains `dataset.pkl` and `responses.hdf5`.
data_dir_name = "../../efish-physics-model/data/processed/data-2024_06_13-characterization_dataset"
# Alternative data directory (mockup variant); swap in by uncommenting:
# data_dir_name = "../../efish-physics-model/data/processed/data-2024_06_13-characterization_dataset_mockup"

# NOTE(review): unpickling can execute arbitrary code -- only load pickles from a trusted source.
dataset = pd.read_pickle(f"{data_dir_name}/dataset.pkl")

# Keep the File object in its own variable so it can be closed explicitly
# (`responses_file.close()`) when the notebook is done with it; indexing the
# File inline would leave no reference to close.
responses_file = h5py.File(f"{data_dir_name}/responses.hdf5", "r")
# Despite its name, `h5py_file` is the "responses" dataset inside that file.
h5py_file = responses_file["responses"]

In [None]:
# Create the output file (mode "w" truncates any existing file) holding an
# empty "test-chunking" dataset with the same shape and dtype as the source.
# Each chunk covers exactly one index along the first axis and the full extent
# of every remaining axis.
with h5py.File("hdf5-chunking/responses-chunked.hdf5", "w") as chunked_out:
    chunked_out.create_dataset(
        "test-chunking",
        shape=h5py_file.shape,
        dtype=h5py_file.dtype,
        chunks=(1, *h5py_file.shape[1:]),
    )

In [None]:
# Copy the source dataset into "test-chunking" in slabs of `write_size` entries
# along the first axis, printing the slab counter as a progress indicator
# (line break after every 50 slabs).
with h5py.File("hdf5-chunking/responses-chunked.hdf5", "r+") as chunked_out:
    write_size = 100_000
    target = chunked_out["test-chunking"]
    slab_starts = range(0, h5py_file.shape[0], write_size)
    for slab_no, start in enumerate(slab_starts):
        print(slab_no, end=", ")
        stop = start + write_size
        target[start:stop] = h5py_file[start:stop]
        if (slab_no + 1) % 50 == 0:
            print()

In [None]:
# Spot-check the copy: read 10 random entries from both files and compare them.
# The index list is unique (taken from a permutation) and sorted because h5py
# only accepts index lists in increasing order.
ids = np.sort(np.random.permutation(h5py_file.shape[0])[:10])

# Open the chunked file locally: this cell must not rely on a handle that is
# only created by a later cell, otherwise it fails on a fresh kernel.
with h5py.File("hdf5-chunking/responses-chunked.hdf5", "r") as chunked_file:
    chunked_rows = chunked_file["test-chunking"][ids]
original_rows = h5py_file[ids]

# Fail loudly if the copy differs instead of relying on visual inspection.
np.testing.assert_array_equal(original_rows, chunked_rows)
ids, original_rows, chunked_rows

In [None]:
# Build the dataset object from the "hdf5-chunking" directory.
# NOTE(review): the meaning of fish_t / fish_u is defined by
# ElectricImagesDataset -- confirm against that class.
dset = ElectricImagesDataset(
    data_dir_name="hdf5-chunking",
    fish_t=20,
    fish_u=30,
)

In [None]:
# Random 80% / 20% split of the dataset into training and validation subsets.
train_dset, valid_dset = torch.utils.data.random_split(dset, [0.8, 0.2])

# Training loader: large shuffled batches fetched by 12 worker processes;
# a trailing incomplete batch is dropped.
train_loader = DataLoader(
    train_dset,
    batch_size=27000,
    shuffle=True,
    drop_last=True,
    num_workers=12,
)

# Validation loader: small batches in fixed order, loaded in the main process;
# a trailing incomplete batch is dropped here as well.
valid_loader = DataLoader(
    valid_dset,
    batch_size=64,
    shuffle=False,
    drop_last=True,
)

In [None]:
# Walk through one full pass of the training loader, printing the shape of the
# first element of every batch; start a new output line after every 50 batches.
for batch_no, batch in enumerate(train_loader, start=1):
    print(batch[0].shape, end=" ")
    if batch_no % 50 == 0:
        print()

In [None]:
# Pull a single batch from a fresh iterator and show the shape of its first element.
first_batch = next(iter(train_loader))
first_batch[0].shape

In [3]:
# Directory that contains the original (non-rechunked) `responses.hdf5`.
data_dir_name = "../../efish-physics-model/data/processed/data-2024_06_13-characterization_dataset"

# Keep both File objects in named variables so they can be closed explicitly
# (`responses_file.close()` / `chunked_responses_file.close()`). A read handle
# that is never closed can prevent the chunked file from being re-created with
# mode "w" in the same kernel session.
responses_file = h5py.File(f"{data_dir_name}/responses.hdf5", "r")
chunked_responses_file = h5py.File("hdf5-chunking/responses-chunked.hdf5", "r")

# Despite their names, these are the datasets inside the files, not the files.
h5py_file = responses_file["responses"]
h5py_file_chunked = chunked_responses_file["test-chunking"]

# Both datasets are expected to report the same shape.
h5py_file.shape, h5py_file_chunked.shape

((51004800, 600, 2), (51004800, 600, 2))

In [9]:
def time_row_reads(source, row_ids):
    """Return the elapsed seconds for reading `source[i]` once per id in `row_ids`.

    Parameters
    ----------
    source : any object supporting `source[i]` indexing (here: an h5py dataset).
    row_ids : iterable of indices, read one at a time in the given order.

    Returns
    -------
    float
        Elapsed time in seconds, measured with `time.perf_counter()`, which is
        monotonic and high-resolution (unlike `time.time()`, which can jump
        when the system clock is adjusted).
    """
    started = time.perf_counter()
    for row_id in row_ids:
        _ = source[row_id]
    return time.perf_counter() - started


# 27000 distinct random indices along the first axis, in random (unsorted) order.
ids = np.random.permutation(h5py_file.shape[0])[:27000]
# Uncomment to benchmark reads in increasing index order instead:
# ids = np.sort(ids)

# Same indices, same order, against both layouts.
non_chunked_seconds = time_row_reads(h5py_file, ids)
chunked_seconds = time_row_reads(h5py_file_chunked, ids)
print(f"Time for non-chunked: {non_chunked_seconds:.2f} s")
print(f"Time for     chunked: {chunked_seconds:.2f} s")

Time for non-chunked: 8.94 s
Time for     chunked: 4.88 s
