## Data Augmentation Demo

In [26]:
import torch
import torchvision.transforms as transforms
from torchvision.transforms import functional as F
from PIL import Image
import numpy as np
import os
from pathlib import Path
import argparse


class AlexNetAugmentation:
    """
    Implements AlexNet-style data augmentation:
    1. Resize to 256x256
    2. Extract 224x224 patches (5 locations: 4 corners + center)
    3. Horizontal flips (doubles to 10 versions)
    4. PCA color augmentation
    """
    
    def __init__(self, pca_eigenvectors=None, pca_eigenvalues=None, pca_std=0.1):
        self.resize_size = 256
        self.crop_size = 224
        self.pca_eigenvectors = pca_eigenvectors
        self.pca_eigenvalues = pca_eigenvalues
        self.pca_std = pca_std  # Standard deviation for color augmentation
        
    def compute_pca_from_images(self, image_paths, sample_size=1000):
        """
        Compute PCA on RGB pixel values from a sample of images.
        """
        print(f"Computing PCA from {min(sample_size, len(image_paths))} images...")
        
        pixels = []
        sample_paths = np.random.choice(image_paths, 
                                       min(sample_size, len(image_paths)), 
                                       replace=False)
        
        for img_path in sample_paths:
            try:
                img = Image.open(img_path).convert('RGB')
                img_array = np.array(img).reshape(-1, 3) / 255.0
                pixels.append(img_array)
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
                continue
        
        # Combine all pixels
        all_pixels = np.vstack(pixels)
        
        # Compute covariance matrix
        cov_matrix = np.cov(all_pixels.T)
        
        # Compute eigenvalues and eigenvectors
        eigenvalues, eigenvectors = np.linalg.eig(cov_matrix)
        
        # Sort by eigenvalue magnitude
        idx = eigenvalues.argsort()[::-1]
        eigenvalues = eigenvalues[idx]
        eigenvectors = eigenvectors[:, idx]
        
        self.pca_eigenvalues = eigenvalues.real
        self.pca_eigenvectors = eigenvectors.real
        
        print(f"PCA computed. Eigenvalues: {self.pca_eigenvalues}")
        
    def apply_pca_augmentation(self, img_tensor):
        """
        Apply PCA color augmentation to image tensor.
        img_tensor: CHW format, range [0, 1]
        """
        if self.pca_eigenvectors is None or self.pca_eigenvalues is None:
            return img_tensor
        
        # Generate random alphas from N(0, pca_std)
        alphas = np.random.normal(0, self.pca_std, 3)
        
        # Compute the color offset
        delta = np.dot(self.pca_eigenvectors, 
                      alphas * self.pca_eigenvalues).astype(np.float32)
        
        # Debug output
        if np.random.random() < 0.01:  # Print occasionally
            print(f"  PCA augmentation - alphas: {alphas}, delta: {delta}")
        
        # Add to each pixel (broadcast across spatial dimensions)
        img_array = img_tensor.numpy()
        for c in range(3):
            img_array[c] += delta[c]
        
        # Clip to valid range
        img_array = np.clip(img_array, 0, 1)
        
        return torch.from_numpy(img_array)
    
    def get_five_crops_and_flips(self, img_tensor):
        """
        Extract 5 crops (4 corners + center) and their horizontal flips.
        img_tensor: already resized to 256x256 and converted to tensor
        Returns 10 image tensors.
        """
        crops = []
        
        # Top-left
        crops.append(F.crop(img_tensor, 0, 0, self.crop_size, self.crop_size))
        
        # Top-right
        crops.append(F.crop(img_tensor, 0, self.resize_size - self.crop_size, 
                           self.crop_size, self.crop_size))
        
        # Bottom-left
        crops.append(F.crop(img_tensor, self.resize_size - self.crop_size, 0, 
                           self.crop_size, self.crop_size))
        
        # Bottom-right
        crops.append(F.crop(img_tensor, self.resize_size - self.crop_size, 
                           self.resize_size - self.crop_size, 
                           self.crop_size, self.crop_size))
        
        # Center
        center_offset = (self.resize_size - self.crop_size) // 2
        crops.append(F.crop(img_tensor, center_offset, center_offset, 
                           self.crop_size, self.crop_size))
        
        # Add horizontal flips
        augmented = []
        for crop in crops:
            augmented.append(crop)  # Original
            augmented.append(F.hflip(crop))  # Flipped
        
        return augmented
    
    def augment_image(self, img_path):
        """
        Apply full augmentation pipeline to a single image.
        Returns 10 augmented versions, each with independent color augmentation.
        """
        img = Image.open(img_path).convert('RGB')
        
        # Resize to 256x256
        img = F.resize(img, self.resize_size)
        
        # Convert to tensor
        img_tensor = F.to_tensor(img)
        
        all_augmented = []
        
        # Generate 10 augmented versions, each with independent color augmentation
        # Get 5 crop positions (will be doubled with flips)
        crop_positions = [
            (0, 0),  # Top-left
            (0, self.resize_size - self.crop_size),  # Top-right
            (self.resize_size - self.crop_size, 0),  # Bottom-left
            (self.resize_size - self.crop_size, self.resize_size - self.crop_size),  # Bottom-right
            ((self.resize_size - self.crop_size) // 2, (self.resize_size - self.crop_size) // 2)  # Center
        ]
        
        for top, left in crop_positions+crop_positions:
            for flip in [False, True]:
                # Apply independent color augmentation to the full image
                color_augmented = self.apply_pca_augmentation(img_tensor.clone())
                
                # Extract the crop
                crop = F.crop(color_augmented, top, left, self.crop_size, self.crop_size)
                
                # Apply horizontal flip if needed
                if flip:
                    crop = F.hflip(crop)
                
                all_augmented.append(crop)
        
        return all_augmented


def process_directory(input_dir, output_dir, compute_pca=True, sample_size=1000, pca_std=0.1):
    """
    Process all images in input directory and save augmented versions.
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    
    # Get all image files
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
    image_files = [f for f in input_path.rglob('*') 
                   if f.suffix.lower() in image_extensions]
    
    print(f"Found {len(image_files)} images in {input_dir}")
    
    # Initialize augmentation
    augmentor = AlexNetAugmentation(pca_std=pca_std)
    
    # Compute PCA if requested
    if compute_pca:
        augmentor.compute_pca_from_images([str(f) for f in image_files], 
                                         sample_size=sample_size)
    else:
        print("PCA color augmentation disabled")
    
    # Process each image
    for idx, img_path in enumerate(image_files):
        try:
            print(f"Processing {idx+1}/{len(image_files)}: {img_path.name}")
            
            # Get augmented versions (10 per image, each with independent color aug)
            augmented_images = augmentor.augment_image(str(img_path))
            
            # Save each augmented version
            stem = img_path.stem
            for aug_idx, aug_img in enumerate(augmented_images):
                # Convert tensor back to PIL Image
                aug_img_pil = F.to_pil_image(aug_img)
                
                # Save with descriptive filename
                output_filename = f"{stem}_aug{aug_idx:02d}.png"
                output_filepath = output_path / output_filename
                aug_img_pil.save(output_filepath)
            
            print(f"  Saved {len(augmented_images)} augmented versions")
            
        except Exception as e:
            print(f"Error processing {img_path}: {e}")
            continue
    
    print(f"\nDone! Augmented images saved to {output_dir}")


# if __name__ == "__main__":
#     parser = argparse.ArgumentParser(
#         description='AlexNet-style data augmentation pipeline'
#     )
#     parser.add_argument('input_dir', type=str, 
#                        help='Directory containing input images')
#     parser.add_argument('output_dir', type=str, 
#                        help='Directory to save augmented images')
#     parser.add_argument('--no-pca', action='store_true',
#                        help='Disable PCA color augmentation')
#     parser.add_argument('--pca-samples', type=int, default=1000,
#                        help='Number of images to sample for PCA computation (default: 1000)')
#     parser.add_argument('--pca-std', type=float, default=0.1,
#                        help='Standard deviation for PCA color augmentation (default: 0.1, increase for stronger color variations)')
    
#     args = parser.parse_args()
    
#     process_directory(args.input_dir, args.output_dir, 
#                      compute_pca=not args.no_pca,
#                      sample_size=args.pca_samples,
#                      pca_std=args.pca_std)

In [27]:
input_dir='/Users/stephen/Stephencwelch Dropbox/welch_labs/double_descent/hackin/imagenet_favorites'
output_dir='/Users/stephen/Stephencwelch Dropbox/welch_labs/double_descent/hackin/imagenet_favorites_aug'

process_directory(input_dir, output_dir, compute_pca=True, pca_std=1.2)

Found 12 images in /Users/stephen/Stephencwelch Dropbox/welch_labs/double_descent/hackin/imagenet_favorites
Computing PCA from 12 images...
PCA computed. Eigenvalues: [0.19004357 0.0330579  0.00299923]
Processing 1/12: n02099601_185.JPEG
  Saved 20 augmented versions
Processing 2/12: n02123045_474.JPEG
  Saved 20 augmented versions
Processing 3/12: n02123045_507.JPEG
  PCA augmentation - alphas: [ 0.33379083  0.27998214 -0.15854916], delta: [-0.0444839  -0.03545243 -0.02956644]
  Saved 20 augmented versions
Processing 4/12: n02687172_1299.JPEG
  Saved 20 augmented versions
Processing 5/12: n01751748_1216.JPEG
  PCA augmentation - alphas: [-0.37019333 -0.70501712 -0.30939888], delta: [0.05817369 0.03969678 0.02309894]
  Saved 20 augmented versions
Processing 6/12: n02123045_1462.JPEG
  Saved 20 augmented versions
Processing 7/12: n02099601_1299.JPEG
  Saved 20 augmented versions
Processing 8/12: n02687172_137.JPEG
  Saved 20 augmented versions
Processing 9/12: n01751748_43.JPEG
  Saved 