In [254]:
# Imports
import os
import shutil
from torchvision import transforms
from PIL import Image, ImageEnhance, ImageOps
import torch
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from torchvision import datasets
from torch.utils.data import DataLoader
import warnings
import numpy as np
import cv2

In [255]:
# Variables
# Define directories
source_dirs = ['./images', './augmented_images']  # List of source directories containing your original and augmented images
output_dir = './preprocessed_images'  # Destination directory for preprocessed images

images_by_label = {}

labels = {}

batch_size = 64

# Suppress the DecompressionBombWarning
warnings.simplefilter('ignore', Image.DecompressionBombWarning)

In [256]:
# Locate all classes in the source directories
for source_dir in source_dirs:
    for class_name in os.listdir(source_dir):
        class_path = os.path.join(source_dir, class_name)
        if os.path.isdir(class_path):
            if class_name not in labels:
                labels[class_name] = []
            print(f'Found class {class_name} in {source_dir}')

Found class Paper in ./images
Found class Rock in ./images
Found class Scissor in ./images
Found class Paper in ./augmented_images
Found class Rock in ./augmented_images
Found class Scissor in ./augmented_images


In [257]:
# Load the data
for source_dir in source_dirs:
    for class_name in os.listdir(source_dir):
        class_path = os.path.join(source_dir, class_name)
        if not os.path.isdir(class_path):
            continue  # Skip if not a directory

        # Initialize list for the class if not already present
        if class_name not in images_by_label:
            images_by_label[class_name] = []

        # Traverse images in the class folder
        for file_name in os.listdir(class_path):
            if file_name.endswith(('.png', '.jpg', '.jpeg')):  # Check for valid image extensions
                file_path = os.path.join(class_path, file_name)

                # Load the image
                image = Image.open(file_path)

                # Append the image to the corresponding label's list
                images_by_label[class_name].append(image)

# Print summary
for label, images in images_by_label.items():
    print(f"Loaded {len(images)} images for label '{label}'.")

Loaded 320 images for label 'Paper'.
Loaded 432 images for label 'Rock'.
Loaded 312 images for label 'Scissor'.


In [258]:
def resize_with_aspect_ratio(image, size):
    # Resize the image while maintaining the aspect ratio
    wpercent = (size / float(image.size[0]))
    hsize = int((float(image.size[1]) * float(wpercent)))
    img = image.resize((size, hsize), Image.Resampling.LANCZOS)
    return img

In [259]:
def auto_exposure(img):
    # Convert the image to a numpy array
    img_array = np.array(img)

    # Split into R, G, B channels
    r, g, b = cv2.split(img_array)

    # Apply CLAHE to each channel
    clahe = cv2.createCLAHE(clipLimit=25.0, tileGridSize=(8, 8))
    r = clahe.apply(r)
    g = clahe.apply(g)
    b = clahe.apply(b)

    # Merge the channels back together
    img_array = cv2.merge([r, g, b])

    # Convert back to PIL Image
    final_img = Image.fromarray(img_array)

    return final_img

In [260]:
# Define preprocessing pipeline
preprocess_pipeline = transforms.Compose([
    transforms.Lambda(lambda img: resize_with_aspect_ratio(img, 256)),  # Resize the image while maintaining the aspect ratio
    transforms.CenterCrop(256),  # Center crop the image
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Random color jitter for generalization
    transforms.Lambda(lambda img: auto_exposure(img)),  # Apply auto exposure
    transforms.ToTensor(),  # Convert the image to PyTorch tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize the image
])

In [261]:
# Dictionary to store combined datasets for each class
datasets_by_class = {class_name: [] for class_name in labels.keys()}

# Accumulate datasets for each class
for source_dir in source_dirs:
    for class_name in labels.keys():
        class_path = os.path.join(source_dir, class_name)
        if not os.path.isdir(class_path):
            continue  # Skip if not a directory

        # Use ImageFolder with a structure that ImageFolder expects
        dataset = ImageFolder(root=source_dir, transform=preprocess_pipeline)

        # Filter dataset to include only images belonging to the current class
        class_indices = [i for i, (_, label) in enumerate(dataset.samples) if dataset.classes[label] == class_name]
        filtered_dataset = torch.utils.data.Subset(dataset, class_indices)
        
        # Append the filtered dataset to the list for this class
        datasets_by_class[class_name].append(filtered_dataset)


In [262]:
data_loaders_by_class = {}

for class_name, dataset_list in datasets_by_class.items():
    # Combine all filtered datasets for this class
    if len(dataset_list) > 1:
        combined_dataset = torch.utils.data.ConcatDataset(dataset_list)
    else:
        combined_dataset = dataset_list[0]  # If only one dataset, no need to concatenate

    # Create a single DataLoader for the combined dataset
    dataloader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    data_loaders_by_class[class_name] = dataloader

    print(f'Loaded {len(combined_dataset)} images for class {class_name}.')

Loaded 320 images for class Paper.
Loaded 432 images for class Rock.
Loaded 312 images for class Scissor.


In [263]:
# Create the output directory if it doesn't exist
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
else:
    # Delete directory if it already exists
    shutil.rmtree(output_dir)
    os.makedirs(output_dir)

print(f"Created/checked directory: {output_dir}")	


Created/checked directory: ./preprocessed_images


In [264]:
print("Processing and saving images...")
for class_name, dataloader in data_loaders_by_class.items():
    print(f"Processing and saving images for class '{class_name}'...")
    
    # Create a subdirectory for the class
    class_dir = os.path.join(output_dir, class_name)
    os.makedirs(class_dir, exist_ok=True)

    # Iterate through the DataLoader batches
    for batch_idx, (images, labels) in enumerate(dataloader):
        print(f"Processing batch {batch_idx} of {label}...")
        for img_idx, image in enumerate(images):
            print(f"Processing image {img_idx+(batch_idx*img_idx)}...")

            # Generate a unique filename
            save_path = os.path.join(
                class_dir,
                f"{class_name}_{img_idx+(batch_idx*img_idx):03d}.jpg"
            )

            # Save the preprocessed image
            save_image(image, save_path)
            print(f"Saved: {save_path}")
            
        print(f"Batch {batch_idx} of {label} complete!")

print("Preprocessing complete!")

Processing and saving images...
Processing and saving images for class 'Paper'...
Processing batch 0...
Processing image 0...
Saved: ./preprocessed_images\Paper\Paper_000.jpg
Processing image 1...
Saved: ./preprocessed_images\Paper\Paper_001.jpg
Processing image 2...
Saved: ./preprocessed_images\Paper\Paper_002.jpg
Processing image 3...
Saved: ./preprocessed_images\Paper\Paper_003.jpg
Processing image 4...
Saved: ./preprocessed_images\Paper\Paper_004.jpg
Processing image 5...
Saved: ./preprocessed_images\Paper\Paper_005.jpg
Processing image 6...
Saved: ./preprocessed_images\Paper\Paper_006.jpg
Processing image 7...
Saved: ./preprocessed_images\Paper\Paper_007.jpg
Processing image 8...
Saved: ./preprocessed_images\Paper\Paper_008.jpg
Processing image 9...
Saved: ./preprocessed_images\Paper\Paper_009.jpg
Processing image 10...
Saved: ./preprocessed_images\Paper\Paper_010.jpg
Processing image 11...
Saved: ./preprocessed_images\Paper\Paper_011.jpg
Processing image 12...
Saved: ./preprocess