In [None]:
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
import h5py

# Paths
# Image folders and per-folder metadata files of the pathfinder32 release,
# plus the HDF5 file that the merged arrays are written to.
imgs_path = "lra_release (2)/lra_release/pathfinder32/curv_contour_length_14/imgs"
metadata_path = "lra_release (2)/lra_release/pathfinder32/curv_contour_length_14/metadata"
output_file = "merged_data.h5"

# Preprocessing and storing
images_list = []
labels_list = []

# os.listdir returns entries in arbitrary order; sorting makes the sample
# order of the merged file identical on every run.
for folder in sorted(os.listdir(imgs_path)):
    folder_path = os.path.join(imgs_path, folder)
    if not os.path.isdir(folder_path):
        continue

    # The metadata file carries a ".npy" suffix but is read as plain text:
    # one whitespace-separated record per line.
    metadata_file = os.path.join(metadata_path, f"{folder}.npy")
    with open(metadata_file, "r") as f:
        metadata_lines = f.readlines()

    # Parse metadata and load corresponding images
    for line in metadata_lines:
        parts = line.split()  # Split by whitespace
        if not parts:
            # A blank line carries no record; indexing it would raise
            # IndexError, so skip it.
            continue

        # NOTE(review): an earlier, disabled check skipped "sample_172.png";
        # re-add a skip here if a specific sample turns out to be unreadable.
        img_name = parts[1]  # Second column: image filename inside the folder
        label = int(parts[3])  # Fourth column: label

        # Load the image as an 8-bit grayscale array; the context manager
        # releases the underlying file as soon as the pixels are copied.
        img_path = os.path.join(folder_path, img_name)
        with Image.open(img_path) as img:
            image = np.array(img.convert("L"), dtype="uint8")

        images_list.append(image)
        labels_list.append(label)

if not images_list:
    # np.stack would fail with a less helpful message on an empty list.
    raise ValueError(f"No samples found under {imgs_path!r}")

# Convert to arrays
images_array = np.stack(images_list, axis=0)  # Shape: (N, H, W)
# Fixed 64-bit width so the stored dtype does not depend on the platform's
# default integer size.
labels_array = np.array(labels_list, dtype="int64")  # Shape: (N,)

# Save to HDF5
with h5py.File(output_file, "w") as h5_file:
    h5_file.create_dataset("images", data=images_array, dtype="uint8")
    h5_file.create_dataset("labels", data=labels_array, dtype="int64")

print("Merged data preprocessing and storage complete!")

# Dataset class
class CustomDataset(Dataset):
    """Dataset serving (image, label) pairs stored in an HDF5 file.

    The file is opened read-only in the constructor and must contain the
    datasets ``"images"`` and ``"labels"``. The handle stays open until
    ``close()`` is called; the instance also works as a context manager so
    the file can be released deterministically.

    NOTE(review): the open handle is stored on the instance; if this dataset
    is ever used with multi-process loading, confirm that h5py handles can be
    shared with the worker processes.

    Args:
        h5_file: Path of the HDF5 file to read.
    """

    def __init__(self, h5_file):
        self.h5_file = h5py.File(h5_file, "r")
        try:
            self.images = self.h5_file["images"]
            self.labels = self.h5_file["labels"]
        except Exception:
            # Do not leave the file open when an expected dataset is missing.
            self.h5_file.close()
            raise

    def __len__(self):
        """Return the number of samples, i.e. the length of ``labels``."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``(image, label)`` tensors.

        The image is converted to ``torch.uint8`` and gets a leading
        dimension of size 1 (channel); the label is a ``torch.long`` tensor.
        """
        image = torch.tensor(self.images[idx], dtype=torch.uint8).unsqueeze(0)  # Add channel dimension
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label

    def close(self):
        """Close the underlying HDF5 file; further calls are no-ops."""
        handle = getattr(self, "h5_file", None)
        if handle is not None:
            handle.close()
            self.h5_file = None

    def __enter__(self):
        """Return the dataset itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the file on leaving the ``with`` block."""
        self.close()

# Wrap the merged HDF5 file in a dataset and iterate it in shuffled
# mini-batches of 32, printing the shape of each image and label batch.
dataset = CustomDataset(output_file)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
)

for batch in loader:
    images, labels = batch
    print(images.shape, labels.shape)


Merged data preprocessing and storage complete!
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.Size([32, 1, 32, 32]) torch.Size([32])
torch.S