In [None]:
from torchvision import transforms
import random
import torch
import numpy as np

# Preprocessing configuration: named constants instead of repeated magic numbers.
IMG_SIZE = 112  # side length (pixels) of every frame fed to the model
# Per-channel RGB mean/std of ImageNet, used by `normalize` below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Torchvision transform objects.
# NOTE: `apply_transform_list` samples its own flip/rotation/jitter parameters so
# every frame of a sequence is augmented identically; of the random transforms
# defined here it only calls `random_crop.get_params`.
resize = transforms.Resize((IMG_SIZE, IMG_SIZE))
horizontal_flip = transforms.RandomHorizontalFlip(p=0.5)
rotation = transforms.RandomRotation(degrees=10)
color_jitter = transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1)
random_crop = transforms.RandomResizedCrop(size=(IMG_SIZE, IMG_SIZE), scale=(0.9, 1.0), ratio=(0.9, 1.1))
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# Función para aplicar transformaciones
def _sample_shared_params(reference_img):
    """Draw one set of augmentation parameters to be shared by every frame.

    Ranges mirror the module-level transform settings: rotation in [-10, 10]
    degrees (``RandomRotation(degrees=10)``), brightness/contrast/saturation
    factors in [0.9, 1.1] and hue shift in [-0.1, 0.1] (``ColorJitter(0.1, ...)``).
    ``crop_params`` is the ``(i, j, h, w)`` box returned by
    ``RandomResizedCrop.get_params`` for ``reference_img``.
    """
    return {
        'horizontal_flip': random.random(),
        'rotation': random.uniform(-10, 10),
        'brightness': random.uniform(0.9, 1.1),
        'contrast': random.uniform(0.9, 1.1),
        'saturation': random.uniform(0.9, 1.1),
        'hue': random.uniform(-0.1, 0.1),
        'crop_params': random_crop.get_params(reference_img, scale=(0.9, 1.0), ratio=(0.9, 1.1)),
    }


def _augment(img, params):
    """Apply the pre-sampled flip, rotation, colour jitter and crop to one frame."""
    tf = transforms.functional
    if params['horizontal_flip'] < 0.5:  # flip with probability 0.5
        img = tf.hflip(img)
    img = tf.rotate(img, params['rotation'])
    img = tf.adjust_brightness(img, params['brightness'])
    img = tf.adjust_contrast(img, params['contrast'])
    img = tf.adjust_saturation(img, params['saturation'])
    img = tf.adjust_hue(img, params['hue'])
    return tf.resized_crop(img, *params['crop_params'], size=(112, 112))


def apply_transform_list(imgs, train=True):
    """Resize, optionally augment, and normalize a sequence of frames.

    When ``train`` is true, one set of random parameters is sampled per call and
    applied to every frame, so all frames of a sequence receive exactly the same
    flip, rotation, colour jitter and crop.

    Args:
        imgs: sequence of images accepted by the torchvision functional API
            (the dataset in this notebook passes PIL images).
        train: if False, only resize -> to_tensor -> normalize is applied.

    Returns:
        A list with one normalized tensor per input image, in input order.
        An empty input yields an empty list.
    """
    if len(imgs) == 0:
        return []

    # Parameters come from the global `random`/`torch` RNGs, which are
    # deliberately NOT re-seeded here: resetting process-wide RNG state on every
    # sample would clobber randomness shared with the rest of the program and is
    # not needed for cross-frame consistency. Seed `random` and `torch` yourself
    # for reproducible augmentation. `rotate` keeps the canvas size, so a crop
    # box sampled on the resized first frame fits every resized frame.
    params = _sample_shared_params(resize(imgs[0])) if train else None

    new_imgs = []
    for img in imgs:
        img = resize(img)
        if params is not None:
            img = _augment(img, params)
        new_imgs.append(normalize(to_tensor(img)))
    return new_imgs

In [1]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
import glob
import random

class FireSeriesDataset(Dataset):
    """Dataset of frame sequences laid out as ``root_dir/<label>/<series>/*.jpg``.

    Each item is one series folder: its ``.jpg`` frames are loaded in sorted
    order, converted to RGB, resized, passed through ``apply_transform_list`` and
    concatenated along dim 0. The label is the integer name of the series'
    parent folder.

    Args:
        root_dir: dataset root directory.
        img_size: side length frames are resized to before
            ``apply_transform_list`` (which itself resizes to 112x112).
        transform: stored as ``self.transform`` but not used by this class.
        train: forwarded to ``apply_transform_list``; enables augmentation.
    """

    def __init__(self, root_dir, img_size=112, transform=None, train=True):
        self.transform = transform  # NOTE(review): currently unused
        # Only directories are series (stray files are skipped). Sorting first
        # makes the shuffle reproducible under a fixed `random` seed, since raw
        # glob order depends on the filesystem.
        self.sets = sorted(p for p in glob.glob(f"{root_dir}/**/*") if os.path.isdir(p))
        self.img_size = img_size
        self.train = train  # distinguishes training from validation/test
        random.shuffle(self.sets)

    def __len__(self):
        return len(self.sets)

    @staticmethod
    def _frame_paths(folder):
        """Return the folder's ``.jpg`` files in sorted (lexicographic) order."""
        # glob order is arbitrary; sorting keeps frames in a stable sequence.
        # Lexicographic order equals temporal order only for zero-padded or
        # timestamped file names -- TODO confirm the naming scheme.
        return sorted(glob.glob(f"{folder}/*.jpg"))

    @staticmethod
    def _label_of(folder):
        """Parse the integer class label from the series folder's parent directory name."""
        return int(os.path.basename(os.path.dirname(folder)))

    def __getitem__(self, idx):
        img_folder = self.sets[idx]
        frame_paths = self._frame_paths(img_folder)
        if not frame_paths:
            raise RuntimeError(f"No .jpg frames found in {img_folder}")

        frames = []
        for path in frame_paths:
            # `with` releases the file handle; convert("RGB") guarantees the 3
            # channels expected by the 3-channel normalization in apply_transform_list.
            with Image.open(path) as im:
                frames.append(im.convert("RGB").resize((self.img_size, self.img_size)))

        # Augmentation is applied only when self.train is true.
        tensor_list = apply_transform_list(frames, self.train)

        # Frames are stacked along the channel axis: (3 * n_frames, H, W).
        return torch.cat(tensor_list, dim=0), self._label_of(img_folder)

ModuleNotFoundError: No module named 'custom_tf'