# In [6]:
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
import h5py

# Path of the HDF5 file that the dataset class below opens read-only.
# It is a relative path, so it is resolved against the current working directory.
file_path = "binary/cropped_ions.h5"


class IonImageDatasetWithNormalizationTensors(Dataset):
    """Map-style dataset over the top-level entries of an HDF5 file.

    Each top-level key of the file is one sample. Samples are standardized
    with a single mean and standard deviation that are computed once, at
    construction time, over the elements of all entries.

    The file stays open for the lifetime of the object. Release it with
    ``close()`` or by using the dataset as a context manager::

        with IonImageDatasetWithNormalizationTensors(path) as dataset:
            image, label = dataset[0]

    NOTE(review): the h5py handle is opened in ``__init__``; sharing it with
    ``DataLoader`` worker processes (``num_workers > 0``) is presumably
    unsafe -- confirm before enabling workers.
    """

    def __init__(self, file_path):
        """Open ``file_path`` read-only and precompute the global statistics.

        Args:
            file_path: Location of the HDF5 file, passed to ``h5py.File``.

        Raises:
            RuntimeError: If the file has no top-level keys.
        """
        super().__init__()
        self.file = h5py.File(file_path, "r")
        try:
            # Keys are listed once so that integer indices map to a fixed order.
            self.keys = list(self.file.keys())
            self.global_mean, self.global_std = self.compute_global_mean_std()
        except BaseException:
            # Do not leak the open handle when initialization fails part-way.
            self.file.close()
            raise

    def compute_global_mean_std(self):
        """Return ``(mean, std)`` over every element of every entry.

        The statistics are merged one entry at a time, so only a single
        entry is held in memory instead of a concatenated copy of the whole
        file. The standard deviation uses Bessel's correction (``n - 1``),
        like the default of ``torch.std``.

        Returns:
            A tuple of two Python floats. The standard deviation is NaN when
            fewer than two elements exist; both values are NaN when the
            entries contain no elements at all.

        Raises:
            RuntimeError: If there are no entries to compute statistics from.
        """
        if not self.keys:
            raise RuntimeError(
                "cannot compute global mean/std: the file contains no entries"
            )

        count = 0
        mean = 0.0
        m2 = 0.0  # running sum of squared deviations from the running mean
        for key in self.keys:
            # Round to float32 first so the statistics describe exactly the
            # values that __getitem__ normalizes, then accumulate in float64
            # to keep the running sums accurate.
            values = np.asarray(self.file[key][:], dtype=np.float32)
            chunk = values.astype(np.float64).ravel()
            n = chunk.size
            if n == 0:
                continue

            chunk_mean = chunk.mean()
            chunk_m2 = np.square(chunk - chunk_mean).sum()

            # Pairwise merge of (count, mean, m2) with the chunk's statistics.
            delta = chunk_mean - mean
            total = count + n
            mean += delta * n / total
            m2 += chunk_m2 + delta * delta * count * n / total
            count = total

        if count == 0:
            return float("nan"), float("nan")
        if count == 1:
            # The sample standard deviation of one element is undefined.
            return float(mean), float("nan")
        return float(mean), float((m2 / (count - 1)) ** 0.5)

    def __len__(self) -> int:
        """Return the number of top-level entries in the file."""
        return len(self.keys)

    def __getitem__(self, idx):
        """Return the standardized entry at position ``idx`` and its label.

        Args:
            idx: Position in ``self.keys``.

        Returns:
            A tuple ``(image_tensor, label)``: the entry as a float32 tensor
            after ``(x - global_mean) / global_std``, and a placeholder
            label tensor holding ``-1``.

        Raises:
            IndexError: If ``idx`` is outside the range of ``self.keys``.
        """
        data = self.file[self.keys[idx]]
        image_tensor = torch.tensor(data[:], dtype=torch.float32)

        # A zero or NaN standard deviation (constant data, or fewer than two
        # elements) would turn every sample into NaN/inf; only center then.
        scale = self.global_std
        if not scale > 0:
            scale = 1.0
        image_tensor = (image_tensor - self.global_mean) / scale

        # Placeholder for actual label determination logic
        label = torch.tensor(-1)

        return image_tensor, label

    def close(self):
        """Close the underlying HDF5 file."""
        self.file.close()

    def __enter__(self):
        """Return the dataset itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the file on leaving the ``with`` block; never suppresses errors."""
        self.close()


# Initialize the dataset with tensor-based normalization. This opens the HDF5
# file and computes the global mean/std over all of its entries.
dataset_normalized_tensors = IonImageDatasetWithNormalizationTensors(file_path)

try:
    # Fetch one sample from the dataset to demonstrate normalization with tensors
    normalized_tensor_sample_image, normalized_tensor_sample_label = (
        dataset_normalized_tensors[0]
    )

    # Show the shape of one sample image and its mean and std after
    # normalization. The statistics are global, so a single sample is not
    # expected to have exactly zero mean and unit std.
    print(
        (
            normalized_tensor_sample_image.shape,
            normalized_tensor_sample_image.mean(),
            normalized_tensor_sample_image.std(),
            normalized_tensor_sample_label,
        )
    )
finally:
    # Close the file even if fetching or printing the sample fails, so the
    # handle is never leaked.
    dataset_normalized_tensors.close()

# Output:
# (torch.Size([5, 5]), tensor(-0.0400), tensor(0.9086), tensor(-1))
