# Analyze the datasets

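This notebook analyzes the datasets and generates statistics about them.
Most importantly, it computes the number of trajectories in each dataset, along with
its size on disk, the number of time steps, the grid size (x, y), and the number of fields.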


In [1]:
# Import the necessary libraries
from pathlib import Path
import h5py


In [None]:
# Find all datasets
data_dir = Path(r"/hpcwork/rwth1802/coding/Large-Physics-Foundation-Model/data/datasets")

# Every sub-directory of data_dir is treated as one dataset. Sorting keeps the
# report order reproducible, since iterdir() yields entries in arbitrary order.
datasets = sorted(entry for entry in data_dir.iterdir() if entry.is_dir())

# Statistics of each dataset, keyed by the dataset (directory) name.
# Each value is a dict with the keys: size, n_traj, n_timesteps, n_x, n_y,
# n_fields_0, n_fields_1.
dataset_stats = {}

for dataset in datasets:
    dataset_name = dataset.name
    dataset_stats[dataset_name] = {}

    # Total size of all files below the dataset directory (recursive),
    # converted from bytes to GB for better readability
    dataset_size = sum(
        file.stat().st_size for file in dataset.glob("**/*") if file.is_file()
    )
    dataset_stats[dataset_name]["size"] = f"{dataset_size / (1024**3):.2f} GB"

    # Find all HDF5 files in the dataset (recursive). Sorted so that the
    # "last" file used below does not depend on the file system's listing order.
    h5_files = sorted(dataset.glob("**/*.hdf5"))

    # The number of trajectories is stored as an attribute on each file;
    # sum it over all files of the dataset
    n_traj = 0
    for h5_file in h5_files:
        with h5py.File(h5_file, "r") as f:
            n_traj += int(f.attrs["n_trajectories"])

    # Shape statistics stay None when the dataset contains no HDF5 files, so a
    # single empty directory does not abort the whole report with an IndexError.
    n_timesteps = n_x = n_y = n_fields_0 = n_fields_1 = None

    if h5_files:
        # Use the last HDF5 file to get the number of timesteps, x, and y.
        # NOTE(review): this assumes all files of a dataset share the same
        # dimensions and fields -- confirm.
        with h5py.File(h5_files[-1], "r") as f:
            # time, x and y are entries of the group "dimensions"
            dimensions = f["dimensions"]
            n_timesteps = len(dimensions["time"])
            n_x = len(dimensions["x"])
            n_y = len(dimensions["y"])
            # The field names are stored in the "field_names" attribute of
            # the "t0_fields" and "t1_fields" entries
            n_fields_0 = len(f["t0_fields"].attrs["field_names"])
            n_fields_1 = len(f["t1_fields"].attrs["field_names"])

    dataset_stats[dataset_name]["n_traj"] = n_traj
    dataset_stats[dataset_name]["n_timesteps"] = n_timesteps
    dataset_stats[dataset_name]["n_x"] = n_x
    dataset_stats[dataset_name]["n_y"] = n_y
    dataset_stats[dataset_name]["n_fields_0"] = n_fields_0
    dataset_stats[dataset_name]["n_fields_1"] = n_fields_1

# Print the statistics of each dataset
for dataset_name, stats in dataset_stats.items():
    print(f"{dataset_name}:")
    for key, value in stats.items():
        print(f"  {key}: {value}")
    print()





