# In [None]:
import torch
import random
import numpy as np
from PIL import Image, ImageFilter
import torchvision.transforms.functional as F
import torchvision.transforms as T

class CustomAugment:
    """Random train-time augmentation that maps an image to a resized PIL image.

    Steps, in order: optional horizontal flip, optional vertical flip, random
    affine (rotation, translation, scale, shear), reflective padding, colour
    jitter, optional Gaussian blur, resize to ``size``, optional
    salt-and-pepper noise. The result is still a PIL image, so follow it with
    ``T.ToTensor()`` in a ``T.Compose`` pipeline.

    Randomness is drawn from three generators: Python's ``random`` (flip /
    blur / noise decisions and the affine and blur parameters), torch's RNG
    (used internally by ``T.ColorJitter``) and NumPy's global RNG (noise
    locations). Seed all three for reproducible output.

    Args:
        size: Final size passed to ``F.resize``; a ``(height, width)`` pair
            gives an exact output size.
        p_flip: Probability of a horizontal flip.
        p_vflip: Probability of a vertical flip.
        p_blur: Probability of applying a Gaussian blur.
        p_saltpepper: Probability of adding salt-and-pepper noise.
        degrees: Rotation angle is drawn uniformly from ``[-degrees, degrees]``.
        max_translate: Max shift as a fraction of the image width / height.
        scale_range: ``(low, high)`` bounds for the affine scale factor.
        shear: Shear angle is drawn uniformly from ``[-shear, shear]`` degrees.
        pad: Reflective padding in pixels added to every side before resizing.
        blur_radius: ``(low, high)`` bounds for the Gaussian blur radius.
        sp_amount: Fraction of pixels hit by salt-and-pepper noise.
    """

    def __init__(self, size=(224, 224), p_flip=0.5, p_vflip=0.2, p_blur=0.3,
                 p_saltpepper=0.3, *, degrees=15.0, max_translate=0.1,
                 scale_range=(0.9, 1.1), shear=5.0, pad=10,
                 blur_radius=(0.5, 1.5), sp_amount=0.005):
        self.size = size
        self.p_flip = p_flip
        self.p_vflip = p_vflip
        self.p_blur = p_blur
        self.p_saltpepper = p_saltpepper
        self.degrees = degrees
        self.max_translate = max_translate
        self.scale_range = scale_range
        self.shear = shear
        self.pad = pad
        self.blur_radius = blur_radius
        self.sp_amount = sp_amount

        # Built once and reused; ColorJitter samples fresh factors per call.
        self.color_jitter = T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(size={self.size!r}, p_flip={self.p_flip}, "
            f"p_vflip={self.p_vflip}, p_blur={self.p_blur}, "
            f"p_saltpepper={self.p_saltpepper}, degrees={self.degrees}, "
            f"max_translate={self.max_translate}, scale_range={self.scale_range!r}, "
            f"shear={self.shear}, pad={self.pad}, blur_radius={self.blur_radius!r}, "
            f"sp_amount={self.sp_amount})"
        )

    def __call__(self, img: Image.Image) -> Image.Image:
        """Apply the augmentation pipeline to ``img`` and return a PIL image.

        Non-PIL inputs (e.g. tensors or arrays) are converted with
        ``F.to_pil_image`` first.
        """
        if not isinstance(img, Image.Image):
            img = F.to_pil_image(img)

        # 1-2. Independent horizontal / vertical flips.
        if random.random() < self.p_flip:
            img = F.hflip(img)
        if random.random() < self.p_vflip:
            img = F.vflip(img)

        # 3. Random affine. The shift is sampled as a fraction of the image size
        # and rounded to whole pixels, as T.RandomAffine does and as F.affine
        # documents ("sequence of integers").
        dx = random.uniform(-self.max_translate, self.max_translate) * img.width
        dy = random.uniform(-self.max_translate, self.max_translate) * img.height
        translate = [int(round(dx)), int(round(dy))]
        angle = random.uniform(-self.degrees, self.degrees)
        scale = random.uniform(*self.scale_range)
        shear = random.uniform(-self.shear, self.shear)
        img = F.affine(img, angle=angle, translate=translate, scale=scale, shear=shear)

        # 4. Reflective padding before the resize. NOTE: F.affine's default
        # fill leaves black regions outside the transformed image, and the
        # reflect pad mirrors whatever is at the border, including those.
        img = F.pad(img, padding=self.pad, padding_mode='reflect')

        # 5. Colour jitter.
        img = self.color_jitter(img)

        # 6. Gaussian blur with a random radius.
        if random.random() < self.p_blur:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(*self.blur_radius)))

        # 7. Resize to the final size.
        img = F.resize(img, self.size)

        # 8. Salt-and-pepper noise, applied after resizing so the noise is
        # not smoothed away by interpolation.
        if random.random() < self.p_saltpepper:
            img = self.add_salt_pepper(img, amount=self.sp_amount)

        return img

    def add_salt_pepper(self, img: Image.Image, amount=0.005):
        """Return a copy of ``img`` with salt-and-pepper noise added.

        ``amount`` is the fraction of *pixels* (not array elements) to
        corrupt: half of the draws set a pixel to 255 (salt) and half to 0
        (pepper), across all channels. Locations are drawn independently with
        replacement from NumPy's global RNG, so fewer distinct pixels may
        change, and a pepper draw can overwrite an earlier salt draw.

        Args:
            img: Input PIL image.
            amount: Fraction of pixels to corrupt.

        Returns:
            A new PIL image built from the noisy uint8 array.
        """
        np_img = np.array(img)
        height, width = np_img.shape[:2]
        # Count pixels, not elements: np_img.size includes the channel axis
        # and would triple the noise on RGB images.
        num_each = int(np.ceil(amount * height * width * 0.5))

        if num_each > 0:
            # randint's upper bound is exclusive, so pass height / width (not
            # height - 1 / width - 1) so the last row and column can be hit.
            for value in (255, 0):  # salt first, then pepper
                rows = np.random.randint(0, height, num_each)
                cols = np.random.randint(0, width, num_each)
                np_img[rows, cols] = value

        return Image.fromarray(np_img.astype(np.uint8))

# --------------------
# Usage in pipeline
# --------------------
# ImageNet RGB channel mean / std (the normalisation torchvision documents for
# its ImageNet-pretrained models).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Single source of truth for the network input resolution, so the train and
# validation pipelines cannot drift apart.
image_size = (224, 224)

# Train: random augmentation (yields a PIL image), then tensor + normalise.
train_transforms = T.Compose([
    CustomAugment(size=image_size),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),
])

# Validation: deterministic resize only, with the same tensor conversion and
# normalisation as training.
val_transforms = T.Compose([
    T.Resize(image_size),
    T.ToTensor(),
    T.Normalize(mean=mean, std=std),
])
