In [1]:
import os

# OpenCV ships with its OpenEXR codec disabled; this environment flag enables
# it. Set it before anything (presumably utils.image_utils) imports cv2 and
# tries to read .exr files.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

# Show the directory the kernel was started in.
print(os.getcwd())

d:\minorProject\WDSS\jupyter_notebooks


In [2]:
# Run from the project root (one level above jupyter_notebooks/) so that the
# local `config` and `utils` packages import and relative dataset paths
# resolve. Guarded so re-running this cell does not keep walking up the tree.
if os.path.basename(os.getcwd()) == "jupyter_notebooks":
    os.chdir("..")
print(os.getcwd())

d:\minorProject\WDSS


In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
from config import device
import config
import os
import numpy as np
from utils.image_utils import ImageUtils

In [26]:
import os
from torch.utils.data import Dataset
from enum import Enum

class GBufferType(Enum):
    """Enum class for GBuffer types.

    Each member's value is used by ``WDSSdataset.__getitem__`` as the index of
    that buffer's file inside a per-frame group of G-buffer paths. The
    lower-cased member name becomes the key in the sample's ``'g_buffers'``
    dict. The values must therefore match the position each buffer type
    occupies in the ``g_buffers`` folder listing.

    NOTE(review): the members are in alphabetical order, which presumably
    mirrors the exported file-name prefixes; confirm against the files.

    Member order also fixes the enum's iteration order, and hence the key
    order of the ``'g_buffers'`` dict.
    """
    BASE_COLOR = 0
    BASE_COLOR_AA = 1
    METALLIC = 2
    MOTION_VECTOR = 3
    NOV = 4
    POST_TONEMAP_HDR_COLOR = 5
    SCENE_DEPTH = 6
    WORLD_NORMAL = 7

class WDSSdataset(Dataset):
    """Dataset class for the WDSS dataset.

    Expects ``settings.dataset_path`` to contain three sub-folders:
    ``high_res``, ``low_res`` and ``g_buffers``. Samples are paired purely by
    position in the sorted directory listings, so every folder must list its
    frames in the same order.
    """

    def __init__(self, settings):
        self.settings = settings
        self.data = []
        self.high_res_path = self._get_file_paths('high_res')
        self.low_res_path = self._get_file_paths('low_res')
        self.all_g_buffer_path = self._get_file_paths('g_buffers')
        # One G-buffer group per frame. The frame count comes from the
        # high-res folder rather than a hard-coded constant.
        self.buffer_paths = self._group_g_buffers(len(self.high_res_path))

        print(f"Found {len(self.high_res_path)} high res images")
        print(f"Found {len(self.low_res_path)} low res images")
        print(f"Found {len(self.buffer_paths)} g buffer groups")

    def _get_file_paths(self, subfolder):
        """Return the sorted full paths of every entry in ``dataset_path/subfolder``.

        ``os.listdir`` returns entries in arbitrary, platform-dependent order.
        Sorting makes index ``i`` refer to the same frame in every sub-folder.
        NOTE: this is a plain lexicographic sort, so frame numbers in file
        names must be zero-padded for it to match frame order.
        """
        folder_path = os.path.join(self.settings.dataset_path, subfolder)
        return [os.path.join(folder_path, name) for name in sorted(os.listdir(folder_path))]

    def _group_g_buffers(self, group_size):
        """Split the flat G-buffer path list into ``group_size`` per-frame groups.

        Despite its name, ``group_size`` is the number of groups (frames), not
        the length of each group. Path ``j`` goes to group ``j % group_size``.
        With a listing ordered buffer-type-major (all frames of one type, then
        all frames of the next), group ``i`` therefore holds frame ``i`` of
        every buffer type, in listing order.

        Returns:
            list[list[str]]: ``group_size`` lists of paths (empty list when
            ``group_size`` is 0).
        """
        return [self.all_g_buffer_path[i::group_size] for i in range(group_size)]

    def __len__(self):
        """Number of complete samples.

        Uses the shortest of the three listings so ``__getitem__`` never
        indexes past the end of one of them (e.g. when a folder holds a stray
        extra file).
        """
        return min(len(self.high_res_path), len(self.low_res_path), len(self.buffer_paths))

    def __getitem__(self, idx):
        """Load sample ``idx`` as a dict of CHW arrays.

        Returns:
            dict with keys ``'high_res'``, ``'low_res'`` and ``'g_buffers'``.
            ``'g_buffers'`` maps each lower-cased ``GBufferType`` name to
            its array.
        """
        # Assumes the loader returns HWC arrays; transpose to CHW.
        high_res = ImageUtils.load_exr_image_opencv(self.high_res_path[idx]).transpose(2, 0, 1)
        low_res = ImageUtils.load_exr_image_opencv(self.low_res_path[idx]).transpose(2, 0, 1)

        # Each enum value is the position of that buffer type within the
        # frame's group of paths.
        frame_buffer_paths = self.buffer_paths[idx]
        g_buffers = {}
        for g_buffer in GBufferType:
            image = ImageUtils.load_exr_image_opencv(frame_buffer_paths[g_buffer.value])
            g_buffers[g_buffer.name.lower()] = image.transpose(2, 0, 1)

        return {
            'high_res': high_res,
            'low_res': low_res,
            'g_buffers': g_buffers,
        }

    


In [33]:
# Loader configuration; change these rather than editing the call below.
BATCH_SIZE = 10
SEED = 0  # fixes the shuffle order so batches are reproducible across restarts

setting = config.Settings()
dataset = WDSSdataset(settings=setting)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # Dedicated seeded generator: shuffling no longer depends on global RNG state.
    generator=torch.Generator().manual_seed(SEED),
)

Found 60 high res images
Found 61 low res images
Found 60 g buffer groups


In [37]:
# Sanity-check tensor shapes on a single shuffled batch. Every size printed is
# for the whole batch, i.e. (batch, channels, height, width).
# zip() with range(1) pulls exactly one batch and binds i_batch to 0; nothing
# runs if the loader is empty.
for i_batch, sample_batched in zip(range(1), dataloader):
    print(f"Batch {i_batch}")
    print(f"High-res: {sample_batched['high_res'].size()}")
    print(f"Low-res: {sample_batched['low_res'].size()}")
    # One line per G-buffer type, in GBufferType order.
    g_buffer_batch = sample_batched['g_buffers']
    for key in g_buffer_batch:
        value = g_buffer_batch[key]
        print(f"{key}:  {value.size()}")

Batch 0
High-res: torch.Size([10, 3, 720, 1280])
Low-res: torch.Size([10, 3, 360, 640])
base_color:  torch.Size([10, 3, 576, 1024])
base_color_aa:  torch.Size([10, 3, 576, 1024])
metallic:  torch.Size([10, 3, 576, 1024])
motion_vector:  torch.Size([10, 3, 576, 1024])
nov:  torch.Size([10, 3, 576, 1024])
post_tonemap_hdr_color:  torch.Size([10, 3, 576, 1024])
scene_depth:  torch.Size([10, 3, 576, 1024])
world_normal:  torch.Size([10, 3, 576, 1024])
