In [47]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader

In [56]:
class LazyLoadDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []  # To store tuples of (X_paths, Y_path)

        for class_name in os.listdir(root_dir):
            class_dir = os.path.join(root_dir, class_name)
            if os.path.isdir(class_dir):
                ink_label_dir = os.path.join(class_dir, "inklabel_crops")
                surface_volume_dir = os.path.join(class_dir, "surface_volume")
                mask_path = os.path.join(class_dir, "mask.png")

                for i in range(65):  # Assuming there are 65 crop images
                    Y_path = os.path.join(ink_label_dir, f"crop_{i}.png")
                    X_paths = [
                        os.path.join(surface_volume_dir, f"{j:02d}_crops", f"crop_{i}.png")
                        for j in range(65)
                    ]
                    self.samples.append((X_paths, Y_path))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        X_paths, Y_path = self.samples[idx]

        return X_paths, Y_path

In [57]:
# Example usage
root_dir = "../../Datasets/vesuvius-challenge-ink-detection/cropped_train_1024"
transform = None  # You can add transformations if needed

dataset = LazyLoadDataset(root_dir=root_dir)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

for batch in data_loader:
    X_paths, Y_path = batch
    print("X paths:", X_paths)
    print("Y path:", Y_path)
    break  # Print the first batch for demonstration

X paths: [('../../Datasets/vesuvius-challenge-ink-detection/cropped_train_1024\\3\\surface_volume\\00_crops\\crop_50.png',), ('../../Datasets/vesuvius-challenge-ink-detection/cropped_train_1024\\3\\surface_volume\\01_crops\\crop_50.png',), ('../../Datasets/vesuvius-challenge-ink-detection/cropped_train_1024\\3\\surface_volume\\02_crops\\crop_50.png',), ('../../Datasets/vesuvius-challenge-ink-detection/cropped_train_1024\\3\\surface_volume\\03_crops\\crop_50.png',), ('../../Datasets/vesuvius-challenge-ink-detection/cropped_train_1024\\3\\surface_volume\\04_crops\\crop_50.png',), ('../../Datasets/vesuvius-challenge-ink-detection/cropped_train_1024\\3\\surface_volume\\05_crops\\crop_50.png',), ('../../Datasets/vesuvius-challenge-ink-detection/cropped_train_1024\\3\\surface_volume\\06_crops\\crop_50.png',), ('../../Datasets/vesuvius-challenge-ink-detection/cropped_train_1024\\3\\surface_volume\\07_crops\\crop_50.png',), ('../../Datasets/vesuvius-challenge-ink-detection/cropped_train_1024\\

In [58]:
X_paths[64]

('../../Datasets/vesuvius-challenge-ink-detection/cropped_train_1024\\3\\surface_volume\\64_crops\\crop_50.png',)

In [59]:
Y_path[0]

'../../Datasets/vesuvius-challenge-ink-detection/cropped_train_1024\\3\\inklabel_crops\\crop_50.png'

In [68]:
class CustomImageLoader(Dataset):
    def __init__(self, samples, transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x_paths, ink_label_path = self.samples[idx]

        # Load Y (inklabel crop image)
        print(x_paths)
        x_images = [Image.open(path) for path in x_paths]

        # Load X (list of crop images from subfolders)
        x_images = []
        for subfolder_paths in x_paths:
            x_images.append([Image.open(path) for path in subfolder_paths])

        # Apply transformations if specified
        if self.transform:
            ink_label_images = [self.transform(image) for image in ink_label_images]
            x_images = [[self.transform(image) for image in subfolder] for subfolder in x_images]

        return ink_label_images, x_images

In [69]:
# Create LazyLoadDataset
lazy_loader = LazyLoadDataset(root_dir=root_dir)

# Create CustomImageLoader using the samples from LazyLoadDataset
custom_loader = CustomImageLoader(samples=lazy_loader.samples, transform=transform)

# Create DataLoader for batch processing
batch_size = 1  # Adjust as needed
data_loader = DataLoader(custom_loader, batch_size=batch_size, shuffle=True)

In [70]:
for batch_idx, (ink_label_batch, x_batch) in enumerate(data_loader):
   print(ink_label_batch.shape)

../../Datasets/vesuvius-challenge-ink-detection/cropped_train_1024\2\inklabel_crops\crop_47.png


PermissionError: [Errno 13] Permission denied: '.'