In [None]:
class SWDataset(Dataset):
    """A custom dataset to read the shallow water (SW) simulation data.

    The HDF5 file is expected to contain one group per time step, named
    ``timestep_<k>``, each holding an ``elevation`` and a ``velocity``
    dataset. A sample pairs ``order`` consecutive time steps (the input
    window) with the time step that immediately follows them (the target).

    Samples are split chronologically: the first ``train_frac`` of the
    windows are served in "train" mode, the next ``valid_frac`` in "valid"
    mode, and the remainder in "test" mode.

    When ``normalize`` is true, elevation is min-max scaled to ``[0, 1]`` and
    velocity to ``[-1, 1]`` using extrema taken over all ``numtime`` steps.
    The extrema are stored as ``zeta_min``/``zeta_max`` and
    ``velocity_min``/``velocity_max`` and are re-read on every access.
    """

    _MODES = ("train", "valid", "test")

    def __init__(
        self,
        file_path,
        order=1,
        numtime=1200,
        mode="train",
        train_frac=0.5,
        valid_frac=0.2,
        normalize=True
    ):
        """
        Initialize the dataset.

        Parameters:
            - file_path (str): Path to the data file.
            - order (int): Autoregressive model order.
            - numtime (int): Total number of time steps in the data.
            - mode (str): Mode of operation ("train", "valid", "test").
            - train_frac (float): Fraction of data to use for training.
            - valid_frac (float): Fraction of data to use for validation.
            - normalize (bool): If True, scan the file for the global
              min/max of elevation and velocity and use them for scaling.
              If False, the extrema stay at 0 and 1, which leaves elevation
              unchanged and maps velocity ``v`` to ``2 * v - 1``.

        Raises:
            ValueError: If ``mode`` is not one of the supported modes.
        """
        # Explicit raise instead of `assert`: asserts vanish under `python -O`
        # and an unknown mode would then be handled inconsistently by
        # __len__ and __getitem__.
        if mode not in self._MODES:
            raise ValueError("Mode should be either 'train', 'valid', or 'test'")

        super().__init__()

        self.file_path = file_path
        self.order = order
        self.numtime = numtime
        self.mode = mode

        # Determine split indices based on dataset size and provided fractions.
        # NOTE(review): `numtime - order` windows fit in `numtime` steps; this
        # count is one smaller, so the final window is never served. Kept
        # as-is so existing split boundaries do not move.
        total_samples = self.numtime - 1 - self.order
        self.train_end = int(train_frac * total_samples)
        self.valid_end = self.train_end + int(valid_frac * total_samples)

        # Identity-like defaults used when normalization is disabled.
        self.zeta_min, self.velocity_min = 0.0, 0.0
        self.zeta_max, self.velocity_max = 1.0, 1.0

        with h5py.File(self.file_path, "r") as hdf_file:
            first_step = hdf_file["timestep_0"]
            self.init_zeta = torch.tensor(first_step["elevation"][:])
            self.init_vel = torch.tensor(first_step["velocity"][:])
            if normalize:
                self._scan_min_max(hdf_file)

    def _scan_min_max(self, hdf_file):
        """Set the global min/max of both fields from an open HDF5 file."""
        # Only one time step is resident at a time; the per-step extrema are
        # scalars, so memory use is independent of the field size.
        zeta_lows, zeta_highs = [], []
        vel_lows, vel_highs = [], []
        for index in range(self.numtime):
            step = hdf_file[f"timestep_{index}"]
            zeta = step["elevation"][:]
            vel = step["velocity"][:]
            zeta_lows.append(zeta.min())
            zeta_highs.append(zeta.max())
            vel_lows.append(vel.min())
            vel_highs.append(vel.max())

        self.zeta_min = np.min(zeta_lows)
        self.zeta_max = np.max(zeta_highs)
        self.velocity_min = np.min(vel_lows)
        self.velocity_max = np.max(vel_highs)

    @staticmethod
    def _span(low, high):
        """Return ``high - low``, or 1.0 when the two are equal."""
        # A constant field would otherwise divide by zero and yield NaN/inf.
        span = high - low
        return span if span != 0 else 1.0

    def _normalize_zeta(self, zeta):
        """Min-max scale an elevation array using the stored extrema."""
        return (zeta - self.zeta_min) / self._span(self.zeta_min, self.zeta_max)

    def _normalize_velocity(self, vel):
        """Min-max scale a velocity array, then map the result to 2*x - 1."""
        scaled = (vel - self.velocity_min) / self._span(
            self.velocity_min, self.velocity_max
        )
        return 2 * scaled - 1

    def __len__(self):
        """Return the length of the dataset based on mode."""
        if self.mode == "train":
            return self.train_end
        elif self.mode == "valid":
            return self.valid_end - self.train_end
        else:  # mode is 'test'
            return self.numtime - 1 - self.order - self.valid_end

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair at ``index`` within the split.

        Parameters:
            - index (int): Position inside the current split. Negative values
              count from the end of the split.

        Returns:
            ``((input_zeta, input_vel), (target_zeta, target_vel))`` as
            float32 tensors. The inputs stack ``order`` time steps along a
            new leading axis; the targets carry a leading axis of size 1.

        Raises:
            IndexError: If ``index`` falls outside the current split.
        """
        # Without this check an out-of-range index would silently read a
        # window that belongs to another split, and plain iteration over the
        # dataset would never receive the IndexError it needs to stop.
        length = self.__len__()
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise IndexError(
                f"index out of range for '{self.mode}' split of length {length}"
            )

        # Adjust the index based on mode (train/valid/test)
        if self.mode == "valid":
            index += self.train_end
        elif self.mode == "test":
            index += self.valid_end

        zetas, velocities = [], []

        # The file is opened per access, so no handle is held between calls.
        with h5py.File(self.file_path, "r") as hdf_file:
            for offset in range(self.order):
                step = hdf_file[f"timestep_{index + offset}"]
                zetas.append(self._normalize_zeta(step["elevation"][:]))
                velocities.append(self._normalize_velocity(step["velocity"][:]))

            target = hdf_file[f"timestep_{index + self.order}"]
            target_zeta = self._normalize_zeta(target["elevation"][:])
            target_vel = self._normalize_velocity(target["velocity"][:])

        # Stack the window along a new leading (time) axis.
        input_zeta = np.stack(zetas, axis=0)
        input_vel = np.stack(velocities, axis=0)

        # Convert input and target data to PyTorch tensors
        return (
            torch.tensor(input_zeta, dtype=torch.float32),
            torch.tensor(input_vel, dtype=torch.float32),
        ), (
            torch.tensor(target_zeta[None, :], dtype=torch.float32),
            torch.tensor(target_vel[None, :], dtype=torch.float32),
        )

    def get_initial_conditions(self):
        """Return the raw fields of ``timestep_0`` with two leading axes added.

        The tensors are the values read at construction time; no min-max
        scaling is applied to them here.
        """
        return (
            self.init_zeta.unsqueeze(0).unsqueeze(0),
            self.init_vel.unsqueeze(0).unsqueeze(0),
        )