In [3]:
import os
from pathlib import Path
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from skimage.color import rgb2lab
from PIL import Image

# Ruta base al dataset (ajustala según tu entorno)
DATA_DIR = Path("/imagewoof2-160")  # ejemplo: /content/imagewoof2-320

# --- Definimos una clase Dataset personalizada ---
class ImagewoofColorizationDataset(Dataset):
    def __init__(self, root_dir, img_size=224, split="train"):
        self.root_dir = Path(root_dir) / split
        self.img_paths = list(self.root_dir.rglob("*.JPEG"))
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        img = Image.open(img_path).convert("RGB")
        img = np.array(self.transform(img)).transpose(1, 2, 0)  # HWC

        # Convertir RGB → Lab
        lab = rgb2lab(img).astype("float32")
        L = lab[:, :, 0] / 100.0                     # Normalizamos 0–1
        ab = lab[:, :, 1:] / 128.0                   # Normalizamos -1–1 aprox.

        # Convertimos a tensores
        L = torch.from_numpy(L).unsqueeze(0)         # (1, H, W)
        ab = torch.from_numpy(ab).permute(2, 0, 1)   # (2, H, W)
        return L, ab

# --- Crear datasets ---
train_dataset = ImagewoofColorizationDataset(DATA_DIR, split="train")
val_dataset = ImagewoofColorizationDataset(DATA_DIR, split="val")

# --- Crear DataLoaders ---
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")

# --- Probar una muestra ---
L, ab = next(iter(train_loader))
print("Shape canal L:", L.shape)
print("Shape canales ab:", ab.shape)


ValueError: num_samples should be a positive integer value, but got num_samples=0