In [1]:
! pip install torch
! pip install torchvision



In [2]:
import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image

In [3]:
class CustomDataset(Dataset):
    """Paired image dataset coupling a "drawn" image with a "normal" image.

    Both folders are listed, filtered by extension, and sorted. Item ``i``
    pairs the i-th drawn file with the i-th normal file, so the two folders
    must use filenames that sort into corresponding order (e.g. identical
    names). NOTE(review): confirm this holds for the data on disk.

    Args:
        drawn_folder: directory containing the drawn images.
        normal_folder: directory containing the paired normal images.
        transform: optional callable applied to each PIL image after resizing
            (the notebook passes ``transforms.Compose([transforms.ToTensor()])``).
        size: ``(width, height)`` that every image is resized to before ``transform``.
        extensions: filename suffix or tuple of suffixes (case-sensitive) that
            identify image files.
        mode: optional PIL mode (e.g. ``"RGB"``) each image is converted to, so
            that all samples share a channel count. ``None`` keeps the file's mode.
    """

    def __init__(self, drawn_folder, normal_folder, transform=None,
                 size=(256, 256), extensions=('.jpg',), mode=None):
        # os.listdir returns entries in arbitrary order. Sorting both lists is
        # what makes index i refer to corresponding files in the two folders.
        self.drawn_images = self._list_images(drawn_folder, extensions)
        self.normal_images = self._list_images(normal_folder, extensions)
        self.transform = transform
        self.size = tuple(size)
        self.mode = mode

    @staticmethod
    def _list_images(folder, extensions):
        """Return the sorted full paths of files in ``folder`` ending with ``extensions``."""
        # str.endswith needs a str or a tuple; a bare string must not be split
        # into single characters by tuple().
        suffixes = (extensions,) if isinstance(extensions, str) else tuple(extensions)
        return sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.endswith(suffixes)
        )

    def __len__(self):
        """Number of complete pairs.

        The shorter folder bounds the length, so ``__getitem__`` never indexes
        past the end of either list.
        """
        return min(len(self.drawn_images), len(self.normal_images))

    def _load(self, path):
        """Open ``path``, optionally convert its mode, resize it, and close the file."""
        # The with-block releases the file handle. convert/resize return new
        # in-memory images that remain valid after the source file is closed.
        with Image.open(path) as src:
            img = src.convert(self.mode) if self.mode is not None else src
            return img.resize(self.size)

    def __getitem__(self, idx):
        """Return ``{'drawn_img': ..., 'normal_img': ...}`` for pair ``idx``.

        Values are resized PIL images, or whatever ``transform`` returns when
        a transform is set.
        """
        drawn_img = self._load(self.drawn_images[idx])
        normal_img = self._load(self.normal_images[idx])

        # Apply the same transform to both images, if one was provided.
        if self.transform:
            drawn_img = self.transform(drawn_img)
            normal_img = self.transform(normal_img)

        return {'drawn_img': drawn_img, 'normal_img': normal_img}

# Build the dataset.
# The data root defaults to the original author's local path. Set the
# SKETCH_DATA_DIR environment variable to point at your own copy instead.
DATA_DIR = os.environ.get('SKETCH_DATA_DIR', '/Users/macbookpro/Downloads/all-in-one')
drawn_folder = os.path.join(DATA_DIR, 'sketch-rendered', 'width-5')
normal_folder = os.path.join(DATA_DIR, 'sketch')
transform = transforms.Compose([transforms.ToTensor()])
dataset = CustomDataset(drawn_folder, normal_folder, transform)
# Fail early on a misconfigured path instead of silently producing an empty dataset.
assert len(dataset) > 0, f'no image pairs found under {DATA_DIR!r}'
