In [None]:
import torch
import torchvision.transforms as transforms
from torchvision.transforms import functional as F
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
class AdvancedImagePreprocessor:
    """Bundle of image preprocessing / augmentation pipelines plus batch-level MixUp and CutMix.

    Pipelines built by ``setup_transforms`` and stored on the instance:
      * ``basic_transforms`` - torchvision: resize, to-tensor, ImageNet normalization (PIL input).
      * ``train_transforms`` - Albumentations augmentation stack ending in a tensor (numpy HWC input).
      * ``val_transforms``   - Albumentations resize + normalize only.
      * ``sam_transforms``   - longest side to ``SAM_INPUT_SIZE``, pad to a square, normalize.

    Args:
        image_size: ``(height, width)`` target size for the basic / train / val pipelines.
    """

    # ImageNet per-channel statistics shared by every pipeline (RGB order, 0-1 scale).
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    # Longest-side resize target and padded square size for the SAM pipeline.
    SAM_INPUT_SIZE = 1024

    def __init__(self, image_size=(384, 384)):
        self.image_size = image_size
        self.setup_transforms()

    def setup_transforms(self):
        """Build the torchvision and Albumentations pipelines and store them on the instance."""
        height, width = self.image_size
        mean, std = list(self.IMAGENET_MEAN), list(self.IMAGENET_STD)

        # Torchvision pipeline for PIL images: resize -> tensor -> normalize.
        self.basic_transforms = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        # Training-time augmentations (Albumentations operates on numpy HWC images).
        self.train_transforms = A.Compose([
            A.Resize(height, width),
            A.HorizontalFlip(p=0.5),
            A.RandomRotate90(p=0.3),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.4),
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.3),
            # At most one blur is applied (the OneOf fires with p=0.2).
            A.OneOf([
                A.GaussianBlur(blur_limit=(3, 7), p=0.3),
                A.MotionBlur(blur_limit=7, p=0.3),
                A.MedianBlur(blur_limit=7, p=0.3),
            ], p=0.2),
            # At most one geometric distortion is applied (the OneOf fires with p=0.2).
            A.OneOf([
                A.OpticalDistortion(distort_limit=0.1, p=0.3),
                A.GridDistortion(distort_limit=0.1, p=0.3),
                # NOTE(review): `alpha_affine` is deprecated in newer Albumentations releases
                # (affine part moved to A.Affine); confirm against the installed version.
                A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.3),
            ], p=0.2),
            A.CLAHE(clip_limit=3.0, tile_grid_size=(8, 8), p=0.2),
            A.RandomShadow(p=0.1),
            A.RandomSunFlare(flare_roi=(0, 0, 1, 0.5), p=0.1),
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ])

        # Deterministic evaluation pipeline: no augmentation, same resize + normalization.
        self.val_transforms = A.Compose([
            A.Resize(height, width),
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ])

        # SAM pipeline: scale the longest side to SAM_INPUT_SIZE, then zero-pad to a square.
        # The image is anchored top-left (padding only on bottom/right), which matches the
        # reference Segment Anything preprocessing, so prompt coordinates need only the
        # resize scale and no padding offset.
        self.sam_transforms = A.Compose([
            A.LongestMaxSize(max_size=self.SAM_INPUT_SIZE),
            # NOTE(review): newer Albumentations releases rename `value` to `fill`; confirm
            # against the installed version.
            A.PadIfNeeded(min_height=self.SAM_INPUT_SIZE, min_width=self.SAM_INPUT_SIZE,
                          border_mode=cv2.BORDER_CONSTANT, value=0, position="top_left"),
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ])

    def apply_mixup(self, images, labels, alpha=0.2):
        """Apply MixUp: blend every sample with a randomly chosen partner from the same batch.

        Args:
            images: batch tensor; dimension 0 is the batch.
            labels: per-sample labels indexed along dimension 0.
            alpha: Beta(alpha, alpha) concentration; ``alpha <= 0`` disables mixing (lam = 1).

        Returns:
            ``(mixed_images, y_a, y_b, lam)`` with
            ``mixed_images = lam * images + (1 - lam) * images[perm]``, ``y_a = labels`` and
            ``y_b = labels[perm]``. The input tensor is not modified.
        """
        lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
        index = torch.randperm(images.size(0))
        mixed_images = lam * images + (1.0 - lam) * images[index]
        return mixed_images, labels, labels[index], lam

    def apply_cutmix(self, images, labels, alpha=1.0):
        """Apply CutMix: paste a random box from a permuted partner into every sample.

        Args:
            images: 4-D batch tensor (N, C, D2, D3), e.g. NCHW.
            labels: per-sample labels indexed along dimension 0.
            alpha: Beta(alpha, alpha) concentration; ``alpha <= 0`` disables the cut
                (lam = 1 and the output equals the input), mirroring ``apply_mixup``.

        Returns:
            ``(mixed_images, y_a, y_b, lam)`` where ``lam`` is the exact fraction of each
            sample's pixels that were kept (recomputed from the clipped box). The caller's
            ``images`` tensor is left unchanged; a new tensor is returned.
        """
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
        index = torch.randperm(images.size(0))
        bbx1, bby1, bbx2, bby2 = self.rand_bbox(images.size(), lam)

        # Work on a copy so the caller's batch is not mutated as a side effect.
        mixed_images = images.clone()
        mixed_images[:, :, bbx1:bbx2, bby1:bby2] = images[index, :, bbx1:bbx2, bby1:bby2]

        # The box may have been clipped at the border, so re-derive lam from its true area.
        box_area = (bbx2 - bbx1) * (bby2 - bby1)
        lam = 1.0 - box_area / (images.size(-1) * images.size(-2))
        return mixed_images, labels, labels[index], float(lam)

    def rand_bbox(self, size, lam):
        """Sample a CutMix box whose unclipped area is about ``(1 - lam)`` of the image.

        Args:
            size: 4-element shape ``(N, C, D2, D3)``; only ``size[2]`` and ``size[3]`` are used.
            lam: mixing ratio in [0, 1]; each box side is scaled by ``sqrt(1 - lam)``.

        Returns:
            ``(bbx1, bby1, bbx2, bby2)`` as Python ints. ``bbx*`` bound dimension 2 and
            ``bby*`` bound dimension 3 (rows and columns for NCHW), matching the slice
            ``images[:, :, bbx1:bbx2, bby1:bby2]`` used by ``apply_cutmix``. The box is
            clipped to the image, so its area can be smaller than requested.
        """
        dim2, dim3 = size[2], size[3]
        cut_rat = np.sqrt(1.0 - lam)
        cut2 = int(dim2 * cut_rat)
        cut3 = int(dim3 * cut_rat)

        # Box centre is uniform over the image; edges are clipped below.
        c2 = np.random.randint(dim2)
        c3 = np.random.randint(dim3)

        bbx1 = int(np.clip(c2 - cut2 // 2, 0, dim2))
        bby1 = int(np.clip(c3 - cut3 // 2, 0, dim3))
        bbx2 = int(np.clip(c2 + cut2 // 2, 0, dim2))
        bby2 = int(np.clip(c3 + cut3 // 2, 0, dim3))
        return bbx1, bby1, bbx2, bby2

    @staticmethod
    def _random_color():
        """Return a random colour as a 3-tuple of Python ints in [0, 255]."""
        # randint's upper bound is exclusive, so 256 is needed to include 255.
        return tuple(int(c) for c in np.random.randint(0, 256, size=3))

    def generate_synthetic_data(self, num_samples=100, num_classes=10):
        """Generate 224x224 RGB uint8 noise images with a filled square and circle drawn on top.

        Args:
            num_samples: number of images to create.
            num_classes: labels cycle through ``0 .. num_classes - 1``. They are assigned
                round-robin and carry no relation to the image content.

        Returns:
            ``(images, labels)``: a list of ``(224, 224, 3)`` uint8 arrays and a list of ints.

        Raises:
            ValueError: if ``num_classes`` is less than 1.
        """
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")

        synthetic_images, synthetic_labels = [], []
        for i in range(num_samples):
            # Uniform noise background covering the full 0-255 range.
            img = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
            # Filled square, then a filled circle drawn over it, each in a random colour.
            cv2.rectangle(img, (50, 50), (150, 150), self._random_color(), -1)
            cv2.circle(img, (100, 100), 30, self._random_color(), -1)

            synthetic_images.append(img)
            synthetic_labels.append(i % num_classes)

        return synthetic_images, synthetic_labels

    def visualize_augmentations(self, image_path=None):
        """Show a 2x3 grid of the original image and five individual augmentations.

        Args:
            image_path: path to an image file (read with OpenCV, converted BGR -> RGB).
                When ``None``, a 224x224 noise image with a red square and a green circle is used.

        Raises:
            FileNotFoundError: if ``image_path`` is given but OpenCV cannot read it.
        """
        if image_path is None:
            image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
            cv2.rectangle(image, (50, 50), (150, 150), (255, 0, 0), -1)
            cv2.circle(image, (100, 100), 30, (0, 255, 0), -1)
        else:
            image = cv2.imread(str(image_path))
            # cv2.imread reports failure by returning None instead of raising.
            if image is None:
                raise FileNotFoundError(f"Could not read image: {image_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # ``None`` marks the untouched original; every other entry is an Albumentations transform.
        augmentations = [
            ("Original", None),
            ("Horizontal Flip", A.HorizontalFlip(p=1.0)),
            ("Random Brightness", A.RandomBrightnessContrast(brightness_limit=0.3, p=1.0)),
            ("Gaussian Blur", A.GaussianBlur(blur_limit=(5, 5), p=1.0)),
            ("Rotation", A.Rotate(limit=45, p=1.0)),
            ("Elastic Transform", A.ElasticTransform(alpha=50, sigma=5, p=1.0)),
        ]

        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        for ax, (name, aug) in zip(axes.flatten(), augmentations):
            shown = image if aug is None else aug(image=image)['image']
            ax.imshow(shown)
            ax.set_title(name)
            ax.axis('off')

        fig.tight_layout()
        plt.show()

In [None]:
def main(seed=0):
    """Smoke-test every part of ``AdvancedImagePreprocessor`` end to end.

    Builds the preprocessor, runs the torchvision and Albumentations pipelines on a
    synthetic image, shows the augmentation grid, and runs MixUp / CutMix on a random batch.

    Args:
        seed: seeds Python's ``random``, NumPy's global RNG and PyTorch before anything
            random runs, so that Restart & Run All gives the same printout. Pass ``None`` to
            skip seeding. NOTE(review): newer Albumentations versions may keep their own RNG
            instead of the global ones; confirm reproducibility for the installed version.
    """
    import random  # local import: only needed for seeding

    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    preprocessor = AdvancedImagePreprocessor()
    # Every pipeline resizes to image_size and emits a 3-channel CHW tensor.
    expected_chw = (3, *preprocessor.image_size)

    print("Image Preprocessing Pipeline Initialized")
    print("=" * 50)

    # Generate synthetic data for demonstration
    print("Generating synthetic image data...")
    images, labels = preprocessor.generate_synthetic_data(10)
    assert len(images) == len(labels) == 10
    print(f"Generated {len(images)} synthetic images")

    # Test transformations
    print("\nTesting image transformations...")
    sample_image = images[0]

    # torchvision transforms expect a PIL image, Albumentations a numpy HWC array.
    pil_image = Image.fromarray(sample_image)
    basic_transformed = preprocessor.basic_transforms(pil_image)
    assert tuple(basic_transformed.shape) == expected_chw
    print(f"Basic transform output shape: {basic_transformed.shape}")

    train_transformed = preprocessor.train_transforms(image=sample_image)['image']
    assert tuple(train_transformed.shape) == expected_chw
    print(f"Advanced transform output shape: {train_transformed.shape}")

    # Visualize augmentations
    print("\nVisualizing augmentation techniques...")
    preprocessor.visualize_augmentations()

    # Test batch augmentations
    print("\nTesting batch augmentations...")
    batch_images = torch.randn(4, 3, 224, 224)
    batch_labels = torch.tensor([0, 1, 2, 3])

    # MixUp
    mixed_images, y_a, y_b, lam = preprocessor.apply_mixup(batch_images, batch_labels)
    assert mixed_images.shape == batch_images.shape
    print(f"MixUp - Lambda: {lam:.3f}, Output shape: {mixed_images.shape}")

    # CutMix -- pass a copy so batch_images stays pristine regardless of the helper's behaviour.
    cut_images, y_a, y_b, lam = preprocessor.apply_cutmix(batch_images.clone(), batch_labels)
    assert cut_images.shape == batch_images.shape
    print(f"CutMix - Lambda: {lam:.3f}, Output shape: {cut_images.shape}")

    print("\n✓ Image preprocessing pipeline tested successfully!")

main()