In [6]:
import os
import random
import shutil
from pathlib import Path

import cv2
import numpy as np
from PIL import Image, ImageEnhance


class DataAugmentator:
    """Create randomized augmented copies of every image in a folder.

    ``process`` copies each source image into ``output_folder`` unchanged and
    then writes ``copies_per_variant`` augmented versions per augmentation
    listed in ``VARIANTS``.  The augmentation methods take and return images
    as NumPy arrays in the layout produced by ``cv2.imread`` (the colour
    conversions used here treat them as BGR).
    """

    # Names of the augmentation methods, in the order ``process`` applies them.
    # ``apply_augmentation`` dispatches on these names.
    VARIANTS = (
        "flip_rotate",
        "color_jittering",
        "pixelation",
        "random_cropping",
        "color_inversion",
        "width_height_shift",
    )

    def __init__(self, input_folder="data_set", output_folder="data_aumentacion"):
        """Store the folder paths and make sure the output folder exists.

        Args:
            input_folder: Folder that holds the source images (str or
                path-like).  Trailing slashes/backslashes are removed.
            output_folder: Folder that receives originals and augmented
                copies.  Created if it does not exist yet.
        """
        input_folder = os.fspath(input_folder)
        output_folder = os.fspath(output_folder)
        # Strip trailing separators, but keep a bare root such as "/" intact
        # instead of collapsing it to an empty string.
        self.input_folder = input_folder.rstrip("/\\") or input_folder
        self.output_folder = output_folder.rstrip("/\\") or output_folder
        self.supported_formats = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']

        os.makedirs(self.output_folder, exist_ok=True)
        print(f"Output folder created: {self.output_folder}")

    # ------------------ AUGMENTATION FUNCTIONS ------------------

    def flip_rotate(self, image):
        """Randomly mirror the image and rotate it by a multiple of 90 degrees.

        The output keeps the input's width and height, so a 90/270 degree turn
        of a non-square image is cropped along the long side and padded with
        black along the short side.
        """
        if random.random() > 0.5:
            image = cv2.flip(image, 1)  # mirror left/right
        if random.random() > 0.5:
            image = cv2.flip(image, 0)  # mirror top/bottom

        angle = random.choice([0, 90, 180, 270])
        if angle != 0:
            h, w = image.shape[:2]
            # Pixel coordinates run from 0 to w-1 / h-1, so the true centre is
            # at ((w-1)/2, (h-1)/2).  Using (w//2, h//2) shifts the result by
            # up to a pixel and leaves a black row/column after a 180 turn.
            center = ((w - 1) / 2.0, (h - 1) / 2.0)
            M = cv2.getRotationMatrix2D(center, angle, 1.0)
            image = cv2.warpAffine(image, M, (w, h))
        return image

    def color_jittering(self, image):
        """Randomly perturb brightness, contrast and colour saturation.

        Brightness is scaled by a factor in [0.7, 1.3]; contrast and
        saturation by factors in [0.8, 1.2].
        """
        # Pillow's enhancers work on RGB images, so convert from BGR and back.
        pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        pil_image = ImageEnhance.Brightness(pil_image).enhance(random.uniform(0.7, 1.3))
        pil_image = ImageEnhance.Contrast(pil_image).enhance(random.uniform(0.8, 1.2))
        pil_image = ImageEnhance.Color(pil_image).enhance(random.uniform(0.8, 1.2))

        return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

    def pixelation(self, image):
        """Pixelate by shrinking with a random factor (2-8) and scaling back up."""
        h, w = image.shape[:2]
        factor = random.randint(2, 8)
        # Clamp to 1 px so images smaller than the factor do not yield an
        # empty target size, which cv2.resize rejects.
        small_size = (max(1, w // factor), max(1, h // factor))
        small = cv2.resize(image, small_size, interpolation=cv2.INTER_LINEAR)
        # Nearest-neighbour upscaling keeps the blocky look.
        return cv2.resize(small, (w, h), interpolation=cv2.INTER_NEAREST)

    def random_cropping(self, image):
        """Cut out a random window (70-95% per side) and resize it to the input size."""
        h, w = image.shape[:2]
        crop_factor = random.uniform(0.7, 0.95)
        # Clamp to 1 px so very small images never produce an empty crop.
        ch = max(1, int(h * crop_factor))
        cw = max(1, int(w * crop_factor))

        sh = random.randint(0, h - ch)
        sw = random.randint(0, w - cw)

        crop = image[sh:sh + ch, sw:sw + cw]
        return cv2.resize(crop, (w, h))

    def color_inversion(self, image):
        """Apply one random colour transform: invert, hue shift or channel shuffle.

        This is best-effort: if the chosen transform cannot be applied to the
        given image (for example it does not have exactly three channels), the
        image is returned unchanged.
        """
        variant = random.choice(['invert', 'hue_shift', 'channel_shift'])

        try:
            if variant == 'invert':
                return 255 - image

            if variant == 'hue_shift':
                hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV).astype(np.float32)
                # Shift the hue channel by up to +/-30 and wrap at 180 so the
                # value stays valid when cast back to uint8.
                hsv[:, :, 0] = (hsv[:, :, 0] + random.randint(-30, 30)) % 180
                return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)

            b, g, r = cv2.split(image)
            channels = [b, g, r]
            random.shuffle(channels)
            return cv2.merge(channels)

        except (cv2.error, ValueError, TypeError):
            # Only swallow failures of the transform itself; KeyboardInterrupt,
            # SystemExit and programming errors still propagate.
            return image

    def width_height_shift(self, image):
        """Stretch width and height independently (0.8x-1.2x), then resize back."""
        h, w = image.shape[:2]
        # Clamp to 1 px so a 1-pixel side scaled by < 1.0 does not become 0.
        nw = max(1, int(w * random.uniform(0.8, 1.2)))
        nh = max(1, int(h * random.uniform(0.8, 1.2)))
        stretched = cv2.resize(image, (nw, nh))
        return cv2.resize(stretched, (w, h))

    # ------------------ APPLY ------------------

    def apply_augmentation(self, image, variant):
        """Run the augmentation named ``variant`` on ``image`` and return the result.

        Raises:
            KeyError: If ``variant`` is not one of ``VARIANTS``.
        """
        if variant not in self.VARIANTS:
            raise KeyError(variant)
        return getattr(self, variant)(image)

    # ------------------ MAIN PROCESS ------------------

    def _list_images(self):
        """Return the sorted image files in the input folder, each listed once.

        A file qualifies when its name ends with one of ``supported_formats``,
        compared case-insensitively.  Matching on the lower-cased name avoids
        listing a file twice on case-insensitive filesystems and also picks up
        mixed-case suffixes such as ``.Jpg``.  A missing input folder yields
        an empty list.
        """
        folder = Path(self.input_folder)
        if not folder.is_dir():
            return []
        suffixes = tuple(ext.lower() for ext in self.supported_formats)
        return sorted(
            entry
            for entry in folder.iterdir()
            if entry.is_file() and entry.name.lower().endswith(suffixes)
        )

    def process(self, copies_per_variant=3):
        """Copy the originals and write augmented copies to the output folder.

        For every image found, the original is copied as-is and
        ``copies_per_variant`` files named ``<stem>_<variant>_<n><suffix>``
        are written per augmentation in ``VARIANTS``.  Files that cannot be
        read or written are reported and skipped.

        Args:
            copies_per_variant: Number of augmented copies per variant and image.
        """
        # Collect images
        images = self._list_images()
        print(f"Found {len(images)} images in {self.input_folder}")

        out_dir = Path(self.output_folder)
        # The folder is created in __init__, but recreate it in case it was
        # removed in the meantime.
        out_dir.mkdir(parents=True, exist_ok=True)

        # Copy originals
        for img in images:
            try:
                shutil.copy2(img, out_dir / img.name)
            except shutil.SameFileError:
                # Input and output folder are the same; the original is
                # already where it needs to be.
                pass

        # Generate augmentations
        for img in images:
            image = cv2.imread(str(img))
            if image is None:
                print(f"Warning: could not read {img}, skipping")
                continue

            for variant in self.VARIANTS:
                for i in range(copies_per_variant):
                    aug = self.apply_augmentation(image, variant)
                    name = f"{img.stem}_{variant}_{i+1}{img.suffix}"
                    if not cv2.imwrite(str(out_dir / name), aug):
                        print(f"Warning: could not write {out_dir / name}")

        print("\n✓ Proceso terminado.")
        print(f"Todas las imágenes están en: {self.output_folder}")


# ------------------ FINAL USAGE ------------------

# Read images from "data_set", write originals plus augmented copies to
# "data_aumentacion", producing three copies per augmentation variant.
augmentor = DataAugmentator(input_folder="data_set", output_folder="data_aumentacion")
augmentor.process(copies_per_variant=3)


Output folder created: data_aumentacion
Found 12 images in data_set

✓ Proceso terminado.
Todas las imágenes están en: data_aumentacion
