# Ring Dataset Augmentation for CNN Training

This notebook augments the ring images using various transformations:
- Random rotations (90°, 180°, 270°)
- Horizontal and vertical flipping
- Upside-down images (180° rotation)
- Random scaling and cropping
- Color adjustments (brightness, contrast)
- Minor distortions (shear, zoom, affine transformations)

In [None]:
import os
import random
from PIL import Image, ImageEnhance, ImageOps, ImageFilter
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

In [None]:
# Configuration: where images are read from, where results are written,
# and how many augmented variants each original should yield
INPUT_FOLDER = r"ring"  # Source folder with original images
OUTPUT_FOLDER = r"ring_augmented"  # Output folder for augmented images
AUGMENTATIONS_PER_IMAGE = 10  # Number of augmented versions per original image

# Make sure the destination exists before any cell tries to write into it
Path(OUTPUT_FOLDER).mkdir(parents=True, exist_ok=True)
print(f"Output folder created/verified: {OUTPUT_FOLDER}")

In [None]:
# Collect every entry of the input folder whose (case-insensitive)
# extension marks it as an image
valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp')
image_files = list(
    filter(
        lambda name: name.lower().endswith(valid_extensions),
        os.listdir(INPUT_FOLDER),
    )
)
print(f"Found {len(image_files)} images in '{INPUT_FOLDER}' folder")
print(f"Sample files: {image_files[:5]}")

## Augmentation Functions

In [None]:
def rotate_image(img, angle):
    """Return ``img`` rotated by ``angle`` degrees.

    The canvas is expanded so the whole rotated image fits, and any area
    not covered by the source is filled with white.
    """
    white = (255, 255, 255)
    rotated = img.rotate(angle, expand=True, fillcolor=white)
    return rotated

def flip_horizontal(img):
    """Return a copy of ``img`` mirrored left-to-right."""
    mirrored = ImageOps.mirror(img)
    return mirrored

def flip_vertical(img):
    """Return a copy of ``img`` flipped top-to-bottom."""
    flipped = ImageOps.flip(img)
    return flipped

def upside_down(img):
    """Return ``img`` turned upside down via a 180-degree rotation.

    Uses the same expand-and-fill-white settings as the other rotation
    helpers in this notebook.
    """
    half_turn = 180
    return img.rotate(half_turn, expand=True, fillcolor=(255, 255, 255))

def random_rotation(img):
    """Rotate ``img`` by an angle drawn uniformly from [-30, 30] degrees.

    The canvas grows to fit the rotated image and uncovered corners are
    filled with white.
    """
    white = (255, 255, 255)
    degrees = random.uniform(-30, 30)
    return img.rotate(degrees, expand=True, fillcolor=white)

def adjust_brightness(img, factor=None):
    """Return ``img`` with its brightness scaled by ``factor``.

    Args:
        img: Image to adjust.
        factor: Enhancement factor handed to ``ImageEnhance.Brightness``.
            When omitted, a value is drawn uniformly from [0.6, 1.4].
    """
    strength = random.uniform(0.6, 1.4) if factor is None else factor
    return ImageEnhance.Brightness(img).enhance(strength)

def adjust_contrast(img, factor=None):
    """Return ``img`` with its contrast scaled by ``factor``.

    Args:
        img: Image to adjust.
        factor: Enhancement factor handed to ``ImageEnhance.Contrast``.
            When omitted, a value is drawn uniformly from [0.7, 1.3].
    """
    strength = random.uniform(0.7, 1.3) if factor is None else factor
    return ImageEnhance.Contrast(img).enhance(strength)

def adjust_saturation(img, factor=None):
    """Return ``img`` with its color intensity scaled by ``factor``.

    Args:
        img: Image to adjust.
        factor: Enhancement factor handed to ``ImageEnhance.Color``.
            When omitted, a value is drawn uniformly from [0.7, 1.3].
    """
    strength = random.uniform(0.7, 1.3) if factor is None else factor
    return ImageEnhance.Color(img).enhance(strength)

def adjust_sharpness(img, factor=None):
    """Return ``img`` with its sharpness scaled by ``factor``.

    Args:
        img: Image to adjust.
        factor: Enhancement factor handed to ``ImageEnhance.Sharpness``.
            When omitted, a value is drawn uniformly from [0.5, 2.0].
    """
    strength = random.uniform(0.5, 2.0) if factor is None else factor
    return ImageEnhance.Sharpness(img).enhance(strength)

def random_crop(img, crop_percent=None):
    """Cut a random window out of ``img`` and stretch it back to full size.

    Args:
        img: Image to crop.
        crop_percent: Fraction of each dimension to keep. When omitted, a
            value is drawn uniformly from [0.75, 0.95].

    Returns:
        An image with the same dimensions as ``img``.
    """
    keep = random.uniform(0.75, 0.95) if crop_percent is None else crop_percent

    full_w, full_h = img.size
    win_w, win_h = int(full_w * keep), int(full_h * keep)

    # Pick the window's top-left corner so the window stays inside the image
    x0 = random.randint(0, full_w - win_w)
    y0 = random.randint(0, full_h - win_h)

    window = img.crop((x0, y0, x0 + win_w, y0 + win_h))
    return window.resize((full_w, full_h), Image.Resampling.LANCZOS)

def random_scale(img, scale_range=(0.8, 1.2)):
    """Resize ``img`` by a random factor while keeping the output dimensions.

    Args:
        img: Image to scale.
        scale_range: ``(low, high)`` bounds for the uniformly drawn factor.

    Returns:
        An RGB image of the original size: the center of the enlarged
        image when the factor exceeds 1, otherwise the shrunken image
        centered on a white canvas.
    """
    factor = random.uniform(*scale_range)
    orig_w, orig_h = img.size
    scaled_w, scaled_h = int(orig_w * factor), int(orig_h * factor)

    resized = img.resize((scaled_w, scaled_h), Image.Resampling.LANCZOS)

    # White canvas with the dimensions of the input
    canvas = Image.new('RGB', (orig_w, orig_h), (255, 255, 255))

    if factor > 1:
        # Enlarged: keep only the central region that fits the canvas
        off_x = (scaled_w - orig_w) // 2
        off_y = (scaled_h - orig_h) // 2
        center = resized.crop((off_x, off_y, off_x + orig_w, off_y + orig_h))
        canvas.paste(center, (0, 0))
        return canvas

    # Shrunk (or unchanged): center it on the canvas
    canvas.paste(resized, ((orig_w - scaled_w) // 2, (orig_h - scaled_h) // 2))
    return canvas

def affine_transform(img):
    """Apply a random shear to ``img`` along both axes.

    Each shear factor is drawn uniformly from [-0.2, 0.2]. The output has
    the same dimensions as the input and uncovered areas are white.
    """
    w, h = img.size

    # Draw the x shear before the y shear
    kx = random.uniform(-0.2, 0.2)
    ky = random.uniform(-0.2, 0.2)

    # Pillow's AFFINE data is (a, b, c, d, e, f); its documentation
    # describes each output pixel (x, y) as sampled from the input at
    # (a*x + b*y + c, d*x + e*y + f).
    # The constant terms cancel the shear along the middle row/column
    # (e.g. kx * h/2 - kx * h/2 == 0), so the shear pivots about the
    # image center rather than the top-left corner.
    x_row = (1, kx, -kx * h / 2)
    y_row = (ky, 1, -ky * w / 2)

    return img.transform(
        (w, h),
        Image.Transform.AFFINE,
        x_row + y_row,
        resample=Image.Resampling.BILINEAR,
        fillcolor=(255, 255, 255),
    )

def zoom_transform(img, zoom_factor=None):
    """Zoom in or out about the image center, keeping the output size.

    Args:
        img: Image to zoom.
        zoom_factor: Magnification. Values above 1 zoom in (the central
            ``1 / zoom_factor`` portion is cropped and enlarged); values
            below 1 zoom out (the image is shrunk by ``zoom_factor`` and
            centered on a white canvas). When omitted, a value is drawn
            uniformly from [0.85, 1.15].

    Returns:
        An image with the same dimensions as ``img``. The zoom-out path
        (including a factor of exactly 1) returns an RGB image.

    Raises:
        ValueError: If ``zoom_factor`` is not positive.
    """
    if zoom_factor is None:
        zoom_factor = random.uniform(0.85, 1.15)
    if zoom_factor <= 0:
        raise ValueError(f"zoom_factor must be positive, got {zoom_factor}")

    width, height = img.size

    if zoom_factor > 1:
        # Zoom in: crop the central region and stretch it to full size.
        # max(1, ...) keeps the crop non-empty for tiny images.
        crop_w = max(1, int(width / zoom_factor))
        crop_h = max(1, int(height / zoom_factor))
        left = (width - crop_w) // 2
        top = (height - crop_h) // 2
        cropped = img.crop((left, top, left + crop_w, top + crop_h))
        return cropped.resize((width, height), Image.Resampling.LANCZOS)

    # Zoom out: the image must get SMALLER, so multiply by the factor.
    # (Dividing by a factor below 1 enlarges the image; pasting that onto
    # the canvas at a negative offset clips it back to a zoomed-in view.)
    small_w = max(1, int(width * zoom_factor))
    small_h = max(1, int(height * zoom_factor))
    scaled = img.resize((small_w, small_h), Image.Resampling.LANCZOS)

    result = Image.new('RGB', (width, height), (255, 255, 255))
    result.paste(scaled, ((width - small_w) // 2, (height - small_h) // 2))
    return result

def add_gaussian_noise(img, intensity=0.05):
    """Return ``img`` with zero-mean Gaussian noise added to every value.

    Args:
        img: Image to perturb.
        intensity: Noise standard deviation as a fraction of 255.
    """
    sigma = intensity * 255
    pixels = np.array(img)
    pixels = pixels.astype(np.float32)

    noise = np.random.normal(0, sigma, pixels.shape)

    # Clamp to the 0-255 range before converting back to 8-bit
    noisy = np.clip(pixels + noise, 0, 255)
    return Image.fromarray(noisy.astype(np.uint8))

def blur_image(img):
    """Return ``img`` softened by a Gaussian blur.

    The blur radius is drawn uniformly from [0.5, 1.5].
    """
    radius = random.uniform(0.5, 1.5)
    gaussian = ImageFilter.GaussianBlur(radius=radius)
    return img.filter(gaussian)

# Notebook feedback: printed once this cell has run, i.e. once the
# augmentation helpers above have been defined
print("All augmentation functions defined successfully!")

In [None]:
def get_augmentation_pipeline():
    """Build the catalogue of available augmentations.

    Returns:
        A list of ``(short_name, callable)`` pairs. Every callable accepts
        an image as its only required argument and returns the transformed
        image.
    """
    # Fixed-angle turns, mirror flips and the free-angle rotation
    orientation = [
        ('rot90', lambda img: rotate_image(img, 90)),
        ('rot180', lambda img: rotate_image(img, 180)),
        ('rot270', lambda img: rotate_image(img, 270)),
        ('flip_h', flip_horizontal),
        ('flip_v', flip_vertical),
        ('upside_down', upside_down),
        ('rand_rot', random_rotation),
    ]
    # Tone and color adjustments
    color = [
        ('bright', adjust_brightness),
        ('contrast', adjust_contrast),
        ('saturation', adjust_saturation),
        ('sharp', adjust_sharpness),
    ]
    # Framing and geometric distortion
    geometry = [
        ('crop', random_crop),
        ('scale', random_scale),
        ('affine', affine_transform),
        ('zoom', zoom_transform),
    ]
    # Image-quality degradation
    degradation = [
        ('noise', add_gaussian_noise),
        ('blur', blur_image),
    ]
    return orientation + color + geometry + degradation

def apply_random_augmentations(img, num_augmentations=3):
    """Chain several randomly chosen augmentations onto a copy of ``img``.

    Args:
        img: Source image; it is copied, not modified.
        num_augmentations: How many distinct augmentations to draw. Capped
            at the number of augmentations in the pipeline.

    Returns:
        ``(augmented_image, applied_names)`` where ``applied_names`` lists,
        in order, the augmentations that ran without raising. A failing
        augmentation is reported with a warning and skipped.
    """
    catalogue = get_augmentation_pipeline()
    how_many = min(num_augmentations, len(catalogue))
    chosen = random.sample(catalogue, how_many)

    current = img.copy()
    succeeded = []

    for label, transform in chosen:
        try:
            current = transform(current)
        except Exception as exc:
            # Best effort: keep the image produced so far and move on
            print(f"Warning: {label} failed - {exc}")
        else:
            succeeded.append(label)

    return current, succeeded

# Notebook feedback: printed once this cell has run, i.e. once the
# pipeline helpers above have been defined
print("Augmentation pipeline ready!")

## Preview Sample Augmentations

In [None]:
# Preview augmentations on a sample image: the original plus one panel per
# augmentation, laid out on a 3x4 grid
if image_files:
    sample_path = os.path.join(INPUT_FOLDER, image_files[0])
    # Use a context manager so the file handle is released as soon as the
    # pixel data has been converted into an independent RGB image
    with Image.open(sample_path) as opened:
        sample_img = opened.convert('RGB')

    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    axes = axes.flatten()

    # Show original
    axes[0].imshow(sample_img)
    axes[0].set_title('Original')
    axes[0].axis('off')

    # Show various augmentations
    augmentations = [
        ('Rotate 90°', lambda img: rotate_image(img, 90)),
        ('Rotate 180°', lambda img: rotate_image(img, 180)),
        ('Flip Horizontal', flip_horizontal),
        ('Flip Vertical', flip_vertical),
        ('Brightness', adjust_brightness),
        ('Contrast', adjust_contrast),
        ('Random Crop', random_crop),
        ('Affine/Shear', affine_transform),
        ('Zoom', zoom_transform),
        ('Scale', random_scale),
        ('Blur', blur_image),
    ]

    for idx, (name, func) in enumerate(augmentations, 1):
        try:
            aug_img = func(sample_img.copy())
            axes[idx].imshow(aug_img)
            axes[idx].set_title(name)
        except Exception as e:
            # Keep the preview going, but say what went wrong instead of
            # only flagging the panel
            print(f"Warning: preview of {name} failed - {e}")
            axes[idx].set_title(f'{name} (Error)')
        axes[idx].axis('off')

    # Hide any panels left over if the list above is ever shortened
    for spare_ax in axes[len(augmentations) + 1:]:
        spare_ax.axis('off')

    plt.tight_layout()
    plt.suptitle('Sample Augmentations Preview', y=1.02, fontsize=14)
    plt.show()
else:
    print("No images found to preview!")

## Run Full Augmentation

In [None]:
def _unique_base_name(filename, used_names):
    """Return an output stem for ``filename`` that has not been used yet.

    Inputs such as ``a.jpg`` and ``a.png`` share the stem ``a``; the first
    keeps it, later ones get the source extension (and, if still taken, a
    counter) appended. ``used_names`` is updated in place and compared in
    lower case so results stay distinct on case-insensitive filesystems.
    """
    stem, ext = os.path.splitext(filename)
    ext_tag = ext.lstrip('.').lower()

    candidate = stem
    attempt = 1
    while candidate.lower() in used_names:
        attempt += 1
        if attempt == 2:
            candidate = f"{stem}_{ext_tag}"
        else:
            candidate = f"{stem}_{ext_tag}_{attempt - 1}"

    used_names.add(candidate.lower())
    return candidate


def augment_dataset(input_folder, output_folder, augmentations_per_image=10):
    """
    Process all images in the input folder and create augmented versions.

    Each readable image is converted to RGB and saved to the output folder
    as ``<name>_original.jpg`` together with ``augmentations_per_image``
    files named ``<name>_augNN.jpg``, each produced by chaining 2-4 random
    augmentations. An image that fails is reported and skipped; processing
    continues with the next one.

    Args:
        input_folder: Path to folder containing original images
        output_folder: Path to save augmented images
        augmentations_per_image: Number of augmented versions to create per image

    Returns:
        Tuple ``(total_images, total_augmented)``: the number of input
        images found and the number of augmented files written.
    """
    os.makedirs(output_folder, exist_ok=True)

    valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp')
    # os.listdir order is arbitrary; sort so that seeded runs are reproducible
    image_files = sorted(
        f for f in os.listdir(input_folder) if f.lower().endswith(valid_extensions)
    )

    total_images = len(image_files)
    originals_saved = 0
    total_augmented = 0
    failed = 0
    used_names = set()

    print(f"Starting augmentation of {total_images} images...")
    print(f"Each image will have {augmentations_per_image} augmented versions")
    print(f"Expected output: ~{total_images * augmentations_per_image} new images")
    print("-" * 50)

    for img_idx, filename in enumerate(image_files, 1):
        # Distinct stem per input so same-named files with different
        # extensions do not overwrite each other's outputs
        base_name = _unique_base_name(filename, used_names)

        try:
            # Load image; the context manager releases the file handle once
            # the pixel data has been converted to an independent RGB image
            img_path = os.path.join(input_folder, filename)
            with Image.open(img_path) as opened:
                img = opened.convert('RGB')

            # Also copy original to output folder
            original_output = os.path.join(output_folder, f"{base_name}_original.jpg")
            img.save(original_output, 'JPEG', quality=95)
            originals_saved += 1

            # Create augmented versions
            for aug_idx in range(augmentations_per_image):
                # Apply 2-4 random augmentations combined
                num_transforms = random.randint(2, 4)
                aug_img, _ = apply_random_augmentations(img, num_transforms)

                # Save augmented image
                aug_name = f"{base_name}_aug{aug_idx+1:02d}.jpg"
                aug_path = os.path.join(output_folder, aug_name)
                aug_img.save(aug_path, 'JPEG', quality=95)
                total_augmented += 1

        except Exception as e:
            # Per-image boundary: report and continue with the next file
            print(f"Error processing {filename}: {e}")
            failed += 1

        # Progress update (outside the try so it is also shown when the
        # current image failed)
        if img_idx % 20 == 0 or img_idx == total_images:
            print(f"Processed {img_idx}/{total_images} images ({total_augmented} augmented created)")

    print("-" * 50)
    print("\n✅ Augmentation Complete!")
    print(f"   Original images: {total_images}")
    print(f"   Augmented images created: {total_augmented}")
    # Count only originals that were actually written, not failed inputs
    print(f"   Total images in output: {originals_saved + total_augmented}")
    print(f"   Failed: {failed}")
    print(f"   Output folder: {output_folder}")

    return total_images, total_augmented

In [None]:
# Kick off the full run with the settings from the configuration cell;
# keep both counts around for later cells
original_count, augmented_count = augment_dataset(
    INPUT_FOLDER,
    OUTPUT_FOLDER,
    AUGMENTATIONS_PER_IMAGE,
)

In [None]:
# Sanity check: how many files ended up in the output folder, plus the
# first 15 names in sorted order
output_files = os.listdir(OUTPUT_FOLDER)
file_count = len(output_files)
print(f"\nTotal files in output folder: {file_count}")
print(f"\nSample output files:")
first_names = sorted(output_files)[:15]
for f in first_names:
    print(f"  - {f}")

## Display Sample Augmented Results

In [None]:
# Show up to 12 randomly chosen .jpg files from the output folder on a
# 3x4 grid
output_images = [f for f in os.listdir(OUTPUT_FOLDER) if f.endswith('.jpg')]

if output_images:
    sample_outputs = random.sample(output_images, min(12, len(output_images)))

    fig, axes = plt.subplots(3, 4, figsize=(16, 12))
    axes = axes.flatten()

    for idx, filename in enumerate(sample_outputs):
        img_path = os.path.join(OUTPUT_FOLDER, filename)
        # Copy the pixels out and close the file right away instead of
        # leaving up to 12 handles open until garbage collection
        with Image.open(img_path) as opened:
            img = opened.copy()
        axes[idx].imshow(img)
        # Long names are truncated to 25 characters plus an ellipsis
        axes[idx].set_title(filename[:25] + '...' if len(filename) > 25 else filename, fontsize=8)
        axes[idx].axis('off')

    # With fewer than 12 samples, hide the panels that stay empty
    for spare_ax in axes[len(sample_outputs):]:
        spare_ax.axis('off')

    plt.tight_layout()
    plt.suptitle('Sample Augmented Images', y=1.02, fontsize=14)
    plt.show()
else:
    print("No augmented images found!")

## Summary Statistics

In [None]:
# Final summary: compare the image count of the input folder with that of
# the output folder (which holds the copied originals and their augmented
# variants)
original_files = [f for f in os.listdir(INPUT_FOLDER) if f.lower().endswith(valid_extensions)]
augmented_files = [f for f in os.listdir(OUTPUT_FOLDER) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

print("=" * 50)
print("DATASET AUGMENTATION SUMMARY")
print("=" * 50)
print(f"📁 Original folder: {INPUT_FOLDER}")
print(f"   - Images: {len(original_files)}")
print()
print(f"📁 Augmented folder: {OUTPUT_FOLDER}")
print(f"   - Total images: {len(augmented_files)}")
print()
print(f"📈 Dataset expansion: {len(original_files)} → {len(augmented_files)}")
# max(..., 1) avoids a division by zero when the input folder has no images
print(f"   ({len(augmented_files) / max(len(original_files), 1):.1f}x increase)")
print("=" * 50)
print("\n✅ Your augmented dataset is ready for CNN training!")