## Environment Setup

In [None]:
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import os
import pandas as pd
from pathlib import Path
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode as IM


ModuleNotFoundError: No module named 'torch'

## DataLoader

### Load Image

In [None]:
class MyDataset(Dataset):
    """RGB(-D) dataset whose samples are ``dish_*`` folders.

    Directory layout read by this class::

        root/
            <csv_name>                                   # columns "ID", "Value"
            <data_dir>/<color_dir>/dish_*/<rgb_name>
            <data_dir>/<depth_dir>/dish_*/<depth_name>

    Only dish folders that actually contain ``rgb_name`` become samples.
    In training mode, each sample's target is the CSV ``Value`` whose ``ID``
    equals the dish folder name. Dishes without a CSV row are skipped with
    a warning.

    Args:
        root: Dataset root directory.
        if_train: If True, ``__getitem__`` returns ``(rgb, depth, cal)``;
            otherwise ``(rgb, depth)``.
        data_dir: Sub-directory of ``root`` holding the split.
        color_dir: Sub-directory of ``data_dir`` holding RGB dish folders.
        depth_dir: Sub-directory of ``data_dir`` holding depth dish folders,
            or None to disable depth (depth is then always None).
        rgb_name: RGB image file name inside each dish folder.
        depth_name: Depth image file name inside each dish folder.
        csv_name: CSV file (relative to ``root``) with "ID"/"Value" columns.
            It is read in both modes.
        transform: Stored as ``self.transform`` but NOT applied by this class.

    Raises:
        RuntimeError: If no RGB image is found under the color directory.

    NOTE(review): samples are PIL images (and depth may be None), which
    torch's default collate function cannot batch. A DataLoader over this
    dataset presumably needs a custom ``collate_fn`` or a transform; confirm.
    """

    def __init__(self,
                 root,
                 if_train = True,
                 data_dir = "train",
                 color_dir = "color",
                 depth_dir = "depth_raw",
                 rgb_name = "rgb.png",
                 depth_name="depth_raw.png",
                 csv_name = "nutrition5k_train.csv",
                 transform = False,
                 ):
        self.root = Path(root)
        self.data = self.root / data_dir
        self.color_dir = self.data / color_dir
        # Allow depth to be disabled explicitly; previously None crashed here.
        self.depth_dir = self.data / depth_dir if depth_dir is not None else None
        self.rgb_name = rgb_name
        self.depth_name = depth_name
        self.if_train = if_train
        self.transform = transform

        df = pd.read_csv(self.root / csv_name)
        # Convert column-wise: iterrows() would up-cast all-numeric rows to
        # float, turning an integer ID such as 1 into the key "1.0".
        self.id2cal = dict(zip(df["ID"].astype(str), df["Value"].astype(float)))

        rgb_paths = sorted(
            p / rgb_name
            for p in self.color_dir.glob("dish_*")
            if (p / rgb_name).exists()
        )
        if not rgb_paths:
            raise RuntimeError(f"Found 0 images in {self.color_dir}.")

        # Each sample is (rgb_path, depth_path_or_None, calorie_or_None).
        self.samples = []
        for rgb_path in rgb_paths:
            dish_id = rgb_path.parent.name
            depth_path = (
                self.depth_dir / dish_id / self.depth_name
                if self.depth_dir is not None
                else None
            )
            cal = None
            if self.if_train:
                if dish_id not in self.id2cal:
                    # Skip unlabeled dishes instead of crashing with KeyError.
                    print(f"Warning: {dish_id} not found in CSV.")
                    continue
                cal = self.id2cal[dish_id]
            self.samples.append((rgb_path, depth_path, cal))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns ``(rgb, depth, cal)`` in training mode, else ``(rgb, depth)``.
        ``rgb`` is a PIL image converted to "RGB". ``depth`` is a PIL image
        converted to "L", or None when there is no depth file. ``cal`` is a
        float32 scalar tensor.
        """
        rgb_path, depth_path, cal = self.samples[idx]
        # Use context managers so the underlying file handles are closed;
        # convert() loads the pixels into a new image first.
        with Image.open(rgb_path) as img:
            rgb = img.convert("RGB")

        depth = None
        if depth_path is not None and depth_path.exists():
            # NOTE(review): if the depth PNGs are 16-bit, converting to "L"
            # squeezes values into 0-255 and loses precision. Confirm this
            # is intended.
            with Image.open(depth_path) as img:
                depth = img.convert("L")

        if self.if_train:
            return rgb, depth, torch.tensor(cal, dtype=torch.float32)
        return rgb, depth
            
        

In [None]:
# Training-mode dataset rooted at the current working directory (root="").
# Point `root` at the dataset folder to load real data.
rgbd = MyDataset(root="", if_train=True)