## Augmentation of grouped PNGs in a folder: save the originals and the augmented results together in one new folder

In [26]:
import os
import csv
from PIL import Image, ImageEnhance
import numpy as np
import random
import shutil

# Define augmentation functions
def flip_image(img, mode='horizontal'):
    """Return a mirrored copy of ``img``.

    Args:
        img: PIL image to flip.
        mode: Which mirror to apply:
            - 'horizontal' (default): swap left and right.
            - 'vertical': swap top and bottom.
            - 'diagonal': mirror both ways, which is equivalent to a
              180-degree rotation.

    Returns:
        The flipped image produced by ``Image.transpose``.

    Raises:
        ValueError: If ``mode`` is not one of the three names above.
    """
    if mode not in ('horizontal', 'vertical', 'diagonal'):
        raise ValueError("Invalid mode. Use 'horizontal', 'vertical', or 'diagonal'.")

    if mode == 'vertical':
        return img.transpose(Image.FLIP_TOP_BOTTOM)

    # Both 'horizontal' and 'diagonal' begin with a left/right mirror.
    mirrored = img.transpose(Image.FLIP_LEFT_RIGHT)
    if mode == 'horizontal':
        return mirrored
    return mirrored.transpose(Image.FLIP_TOP_BOTTOM)

def rotate_and_zoom_image(img, angle=15, zoom_factor=1.2):
    """Rotate the image and zoom into its centre, keeping its size and aspect ratio.

    The image is rotated with ``expand=True`` so no content is clipped. A
    centred window of ``width / zoom_factor`` by ``height / zoom_factor``
    pixels is then cropped from the rotated canvas and resized back to the
    original ``(width, height)``. The window has the original aspect ratio,
    so the final resize does not stretch the content.

    Larger ``zoom_factor`` values crop tighter and hide more of the empty
    corners that rotation introduces. Whether the corners disappear
    completely depends on ``angle`` and on the image's aspect ratio. Values
    below 1 make the window larger than the image; PIL pads the area outside
    the canvas.

    Args:
        img: PIL image to transform.
        angle: Rotation in degrees (PIL's ``rotate`` turns counter-clockwise).
        zoom_factor: Magnification relative to the original image; must be > 0.

    Returns:
        A new PIL image with the same size as ``img``.

    Raises:
        ValueError: If ``zoom_factor`` is not positive.
    """
    if zoom_factor <= 0:
        raise ValueError("zoom_factor must be positive")

    width, height = img.size
    img_rotated = img.rotate(angle, expand=True)
    rotated_width, rotated_height = img_rotated.size

    # Size the crop window from the ORIGINAL dimensions, not the rotated canvas:
    # the expanded canvas has a different aspect ratio for non-square images,
    # and resizing such a window back to (width, height) would distort it.
    crop_width = max(1, round(width / zoom_factor))
    crop_height = max(1, round(height / zoom_factor))
    left = (rotated_width - crop_width) // 2
    top = (rotated_height - crop_height) // 2
    img_cropped = img_rotated.crop((left, top, left + crop_width, top + crop_height))
    return img_cropped.resize((width, height), Image.LANCZOS)

def adjust_brightness(img, factor=1.5):
    """Return a copy of ``img`` with its brightness scaled by ``factor``.

    Delegates to ``PIL.ImageEnhance.Brightness``: a factor of 1.0 leaves the
    image unchanged, smaller values darken it and larger values brighten it.
    """
    return ImageEnhance.Brightness(img).enhance(factor)

def add_noise(img, noise_level=0.05, rng=None):
    """Add zero-mean Gaussian noise to the image.

    Pixel values are scaled to [0, 1]. Noise with standard deviation
    ``noise_level`` (in that [0, 1] scale) is added, and the result is
    clipped to [0, 1]. It is then rounded back to 8-bit values and wrapped
    in a new PIL image.

    NOTE(review): the fixed /255 scale assumes 8-bit channels (e.g. L, RGB,
    RGBA). 16-bit images would saturate, and palette ('P') images would get
    noise on their palette indices. Every channel, including alpha, receives
    noise.

    Args:
        img: PIL image (or anything ``np.asarray`` accepts) to perturb.
        noise_level: Standard deviation of the noise on the [0, 1] scale.
        rng: Optional ``numpy.random.Generator`` for reproducible noise.
            Defaults to the global ``np.random`` state (the original behaviour).

    Returns:
        A new PIL image built from the uint8 noisy array.
    """
    img_array = np.asarray(img).astype(np.float32) / 255.0
    noise_source = np.random if rng is None else rng
    noise = noise_source.normal(0, noise_level, img_array.shape)
    noisy_img_array = np.clip(img_array + noise, 0, 1) * 255
    # Round rather than truncate: astype(uint8) floors, which would bias
    # every pixel downward by about half a grey level on average.
    return Image.fromarray(np.rint(noisy_img_array).astype(np.uint8))

def random_augmentation(img, used_augmentations):
    """Apply one randomly chosen augmentation that has not been used yet.

    The pool currently contains only the three flips from ``flip_image``.
    ``rotate_and_zoom_image``, ``adjust_brightness`` and ``add_noise`` exist
    in this file but are not in the pool.

    Args:
        img: Image to augment.
        used_augmentations: Set of augmentation names already applied. It is
            updated in place with the name chosen here.

    Returns:
        ``(augmented_image, used_augmentations)``; the second item is the
        same set object that was passed in.

    Raises:
        ValueError: If every augmentation in the pool is already in
            ``used_augmentations``.
    """
    # The names double as the bookkeeping keys stored in used_augmentations.
    catalogue = {
        'horizontal': lambda image: flip_image(image, mode='horizontal'),
        'vertical': lambda image: flip_image(image, mode='vertical'),
        'diagonal': lambda image: flip_image(image, mode='diagonal'),
    }

    candidates = [name for name in catalogue if name not in used_augmentations]
    if not candidates:
        raise ValueError("No more unique augmentations available.")

    choice = random.choice(candidates)
    # Record the choice before applying it, matching the pool bookkeeping order.
    used_augmentations.add(choice)
    return catalogue[choice](img), used_augmentations


# Read CSV and get streak labels
def read_csv(csv_path):
    """Load the label CSV.

    Args:
        csv_path: Path to a UTF-8 CSV with a header row that includes
            ``output`` and ``label`` columns.

    Returns:
        ``(streak_labels, rows)`` where:
        - ``streak_labels`` maps each row's ``output`` value to ``int(label)``;
          on duplicate ``output`` values the last row wins.
        - ``rows`` is the list of all rows as dicts, with values left as strings.
    """
    with open(csv_path, newline='', encoding='utf-8') as handle:
        rows = list(csv.DictReader(handle))
    streak_labels = {row['output']: int(row['label']) for row in rows}
    return streak_labels, rows

# Save the updated CSV
def save_csv(output_csv_path, old_rows, new_rows, fieldnames):
    """Write a header, then ``old_rows``, then ``new_rows`` to a UTF-8 CSV.

    The file at ``output_csv_path`` is overwritten. Every row must be a dict
    whose keys are within ``fieldnames``; ``csv.DictWriter`` raises
    ValueError on extra keys.
    """
    with open(output_csv_path, 'w', newline='', encoding='utf-8') as out_file:
        writer = csv.DictWriter(out_file, fieldnames=fieldnames)
        writer.writeheader()
        # Original content first, then the rows describing augmented files.
        for batch in (old_rows, new_rows):
            writer.writerows(batch)

# Main processing function
def _augment_png(file_path, filename, output_folder, count):
    """Save ``count`` distinct random augmentations of one PNG and return their CSV rows.

    Augmented files are named ``<stem>_<i>.png`` (i = 1..count) inside
    ``output_folder``. Each returned row maps the original PNG name to the
    augmented name with label 1.
    """
    stem = os.path.splitext(filename)[0]
    used_augmentations = set()  # fresh per image so each image can draw from the whole pool
    rows = []
    # Open the source once. The augmentations return new images, so the opened
    # original is reused unchanged for every iteration, and the file handle is
    # released when the block exits.
    with Image.open(file_path) as img:
        for iteration in range(1, count + 1):
            img_augmented, used_augmentations = random_augmentation(img, used_augmentations)
            augmented_filename = f"{stem}_{iteration}.png"
            img_augmented.save(os.path.join(output_folder, augmented_filename))
            rows.append({'input': filename, 'output': augmented_filename, 'label': 1})
    return rows


def process_images(input_folder, csv_path, augmentations_per_image=3):
    """Copy the PNGs of a folder into ``<input_folder>/augmented`` and augment streak images.

    Every ``*.png`` file (case-insensitive, not recursive) directly inside
    ``input_folder`` is copied to the output folder. A PNG is augmented when
    its label in the CSV is 1; labels are looked up by the CSV's ``output``
    column. For each such PNG, ``augmentations_per_image`` distinct random
    augmentations are also written there. Finally, a CSV with the same file
    name as ``csv_path`` is written to the output folder. It contains the
    original rows followed by one row per augmented image.

    NOTE(review): rows are written with the fixed columns input/output/label.
    An input CSV with additional columns makes ``csv.DictWriter`` raise.

    Args:
        input_folder: Folder containing the PNGs.
        csv_path: Label CSV read by ``read_csv``.
        augmentations_per_image: Number of distinct augmentations per streak
            image. ``random_augmentation`` raises ValueError once its pool is
            exhausted, so this must not exceed the pool size (currently 3).
    """
    streak_labels, old_rows = read_csv(csv_path)

    output_folder = os.path.join(input_folder, "augmented")
    os.makedirs(output_folder, exist_ok=True)

    output_csv_path = os.path.join(output_folder, os.path.basename(csv_path))
    fieldnames = ['input', 'output', 'label']
    new_rows = []

    # sorted() makes the processing order, and therefore the appended CSV rows,
    # reproducible; os.listdir order is arbitrary.
    for filename in sorted(os.listdir(input_folder)):
        if not filename.lower().endswith('.png'):
            continue
        file_path = os.path.join(input_folder, filename)
        if not os.path.isfile(file_path):  # e.g. a directory whose name ends in .png
            continue

        shutil.copy(file_path, os.path.join(output_folder, filename))

        if streak_labels.get(filename, 0) == 1:
            new_rows.extend(
                _augment_png(file_path, filename, output_folder, augmentations_per_image)
            )

    save_csv(output_csv_path, old_rows, new_rows, fieldnames)
    print(f"Processing complete. Outputs saved in '{output_folder}'.")

# Dataset location: the PNGs and their label CSV live side by side in this folder.
input_folder = './Data/fits_filtered2'
csv_path = os.path.join(input_folder, 'dictionary_0.csv')
process_images(input_folder=input_folder, csv_path=csv_path)


Processing complete. Outputs saved in './Data/fits_filtered2\augmented'.
