# PyTorch Data Augmentation Notebook

Imports the necessary Python libraries, including 
- os
- random
- torch, torchvision for image processing and transformations - `pip install torch torchvision torchaudio`
- OpenCV (`cv2`) for reading images from disk - `pip install opencv-python`
- PIL for handling images - `pip install pillow`

### Imports

In [1]:
# imports
import os
import cv2
import random
import torch
import torchvision.utils as t_utils
import torchvision.transforms.v2 as t_v2

#### Variables

In [2]:
# number of augmented copies generated for every source image
num_of_iterations_per_image = 50

# root folder of the dataset that will be walked for images
dataset_path = 'ConstellationDataset'

# only warn when the folder is missing; execution is not stopped here
if not os.path.exists(dataset_path):
    print("!!! <Dataset Not Found> !!!")

### Accessing/Loading Images

Defines a function to load images from a folder. The function goes through all files in the specified folder and subfolders, checking for image files (JPEG, PNG) and loading them. It saves the images in a list along with their file paths

In [5]:
# function to load images from a specified folder
def load_images_from_folder(folder, skip_dirs=None):
    """Load every JPEG/PNG image found under ``folder``, subfolders included.

    Args:
        folder: Root directory to walk. Files located directly in ``folder``
            are loaded as well as files in its subfolders.
        skip_dirs: Optional list of strings. A directory is skipped when its
            path contains any of these strings (substring match, so the
            subfolders of a skipped directory are skipped too).

    Returns:
        A list of ``(image, path)`` tuples, where ``image`` is the value
        returned by ``cv2.imread`` and ``path`` is the full file path.
        Files that OpenCV could not read are left out.
    """
    # suffixes include the dot and are compared lower-cased, so 'photo.JPG'
    # is accepted while a name that merely ends in 'jpg' is not
    image_suffixes = ('.jpg', '.jpeg', '.png')

    # if no directories to skip are specified, fall back to an empty list
    if skip_dirs is None:
        skip_dirs = []

    # collected (image, path) tuples
    images = []

    # walking/going through the folder, including all subfolders
    for root, _, files in os.walk(folder):
        # skip any directory whose path contains one of the skip_dirs entries
        if any(skip in root for skip in skip_dirs):
            continue

        for file in files:
            # ignore anything that does not carry an image extension
            if not file.lower().endswith(image_suffixes):
                continue

            img_path = os.path.join(root, file)     # form the full file path
            img = cv2.imread(img_path)              # read image using OpenCV

            # cv2.imread returns None when the file could not be loaded
            if img is not None:
                images.append((img, img_path))

    return images

# directories to leave out while loading (matched against the walked paths)
skip_directories = ['TargetImages']

# walk the dataset and keep every loaded image together with its file path
all_images = load_images_from_folder(dataset_path, skip_directories)

# report how many images were loaded
print(f"----- <Loaded {len(all_images)} images from {dataset_path}> -----")

----- <Loaded 60 images from ConstellationDataset> -----


### Defining and Setting Transformations

 Contains helper functions to control random transformations. `coin_toss()` returns a random True/False value to decide whether to apply a transformation, and `getRandomNumber()` returns a random number for adjusting brightness, contrast, saturation, and hue

In [12]:
# function to return True or False at random; used to decide whether a transformation is added
def coin_toss():
    """Return either True or False, picked by ``random.choice``."""
    outcomes = (True, False)
    return random.choice(outcomes)

# function to return a random strength for the brightness, contrast, saturation or hue setting
def getRandomNumber(type):
    """Return a random value rounded to two decimals.

    The value is drawn from 0 to 0.5 when ``type`` is ``'hue'`` and from
    0 to 1 for any other ``type``.
    """
    upper_bound = 0.5 if type == 'hue' else 1
    return round(random.uniform(0, upper_bound), 2)

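Builds a list of image transformations `(resize, flips, crops, rotation, color adjustments, sharpness, blur)`. The resize and the final tensor conversion are always included; every other transformation is added or left out by a coin toss made once when this cell runs, so the resulting pipeline is the same for every image until the cell is re-run. Finally, it combines all transformations into a single transformation pipeline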

In [13]:
# define transform - starts with the one step that is always applied:
# resizing every image to 224x224 pixels for a consistent size
transform_list = [t_v2.Resize((224, 224))]

# NOTE: each coin toss below is made once, while this cell runs, so it decides
# whether a transformation is part of the pipeline at all (not per image)

# random horizontal flip
if coin_toss():
    transform_list.append(t_v2.RandomHorizontalFlip())

# random vertical flip
if coin_toss():
    transform_list.append(t_v2.RandomVerticalFlip())

# random 224x224 crop (the same size the image was just resized to)
if coin_toss():
    transform_list.append(t_v2.RandomCrop((224, 224)))

# rotation by a random angle taken from the 0 to 180 degree range
if coin_toss():
    transform_list.append(t_v2.RandomRotation(degrees=(0, 180)))

# random change of brightness, contrast, saturation and hue; the four
# strengths are drawn only when this transformation is selected
if coin_toss():
    transform_list.append(
        t_v2.ColorJitter(
            brightness=getRandomNumber('brightness'),
            contrast=getRandomNumber('contrast'),
            saturation=getRandomNumber('saturation'),
            hue=getRandomNumber('hue'),
        )
    )

# random sharpness adjustment
if coin_toss():
    transform_list.append(t_v2.RandomAdjustSharpness(sharpness_factor=2))

# Gaussian blur
if coin_toss():
    transform_list.append(t_v2.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)))

# the last step always converts the result to a scaled float32 image tensor
transform_list.append(
    t_v2.Compose([t_v2.ToImage(), t_v2.ToDtype(torch.float32, scale=True)])
)

# combine all selected transformations into a single pipeline
transform = t_v2.Compose(transform_list)

# define transformToTensor separately - converts an image to a scaled float32 tensor.
# ToImage + ToDtype is used in place of t_v2.ToTensor(), which gives a warning
# recommending this pair of transformations
transformToTensor = t_v2.Compose([
    t_v2.ToImage(),
    t_v2.ToDtype(torch.float32, scale=True),
])

### Saving Augmented Images

Loops through each loaded image, applies transformations, and saves `num_of_iterations_per_image` (50 as configured above) augmented versions of each image. Each saved image is named with an `_aug_{i + 1}.jpg` suffix to differentiate it from the original

In [14]:
# generate the augmented copies and save them next to their source image
for img, img_path in all_images:
    # folder of the source image (save target) and its name without extension
    dir_path = os.path.dirname(img_path)
    filename, _ = os.path.splitext(os.path.basename(img_path))

    # OpenCV images are BGR; switch to RGB before turning them into a tensor
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    tensor_img = transformToTensor(img)

    # every pass runs the pipeline again on the original tensor
    for i in range(num_of_iterations_per_image):
        # naming convention: originalname_aug_index.jpg (index starts at 1)
        augmented_filename = f"{filename}_aug_{i + 1}.jpg"
        augmented_img = transform(tensor_img)

        # write the augmented image into the source image's folder
        t_utils.save_image(augmented_img, os.path.join(dir_path, augmented_filename))
    

### Clean Up

Defines a function to delete previously saved augmented images. It checks each directory in the dataset for files containing `_aug_` in the name and deletes them

In [4]:
# function to remove the augmented images from a save directory
def empty_directory(directory):
    """Delete the files directly inside ``directory`` whose name contains '_aug_'.

    Subdirectories are never removed or entered, and files without '_aug_'
    in their name are left untouched.
    """
    # full paths of all entries whose name carries the '_aug_' marker
    candidates = [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if '_aug_' in name
    ]

    for candidate in candidates:
        # only regular files are deleted; matching subdirectories are kept
        if os.path.isfile(candidate):
            os.remove(candidate)

# obtain all dataset directories - each directory that holds a loaded image is
# kept exactly once (in first-seen order), so it is cleaned a single time
# instead of once for every image it contains
dataset_directories = dict.fromkeys(
    os.path.dirname(img_path) for _, img_path in all_images
)

for dir_path in dataset_directories:
    # remove the '_aug_' files from this directory
    empty_directory(dir_path)

## References

- https://pytorch.org/
- https://pytorch.org/vision/stable/transforms.html
- https://discuss.pytorch.org/t/save-transformed-resized-images-after-dataloader/56464/12