In [2]:
import os
import glob
import hashlib
import numpy as np
from PIL import Image, ImageOps
import torch
from torchvision import transforms

In [3]:
def convert_png(directory, label, target):
    for i, image in enumerate(os.listdir(directory)):
        f = os.path.join(directory, image)
        if os.path.isfile(f):
            image = Image.open(f)
            image.save(f'data/{i}{label}.{target}')
            
    os.system('rmdir /S /Q "{}"'.format(directory))

In [4]:
"""
    DO NOT RUN
"""

directory_1 = 'data/cancer'
directory_2 = 'data/non-cancer'
target = 'png'

convert_png(directory_1, 1, target)
convert_png(directory_2, 0, target)

In [17]:
class ImagePreprocessor:
    """
    A class for preprocessing images.

    Args:
        image_dir (str): The directory containing the images.
        target_size (tuple): The target size of the images after resizing.
        normalize_range (tuple): The range to normalize the pixel values to.

    Attributes:
        image_dir (str): The directory containing the images.
        target_size (tuple): The target size of the images after resizing.
        normalize_range (tuple): The range to normalize the pixel values to.
        image_hashes (dict): A dictionary to store the image hashes.

    Methods:
        hash_image(image): Computes the hash value of an image.
        check_duplicates(): Checks for duplicate images and performs augmentation if necessary.
        augment_image(image): Applies a random augmentation to an image.
        resize_image(image): Resizes an image to the target size.
        normalize_image(image): Normalizes the pixel values of an image.
        convert_to_tensor(image): Converts an image to a tensor.
        preprocess_image(): Preprocesses the images in the image directory.

    """

    def __init__(self, image_dir, target_size, normalize_range):
            """
            Initializes a SmileSavior object.

            Args:
                image_dir (str): The directory path where the images are stored.
                target_size (tuple): The desired size of the images after resizing.
                normalize_range (bool): Flag indicating whether to normalize the pixel values of the images.

            Attributes:
                image_dir (str): The directory path where the images are stored.
                target_size (tuple): The desired size of the images after resizing.
                normalize_range (bool): Flag indicating whether to normalize the pixel values of the images.
                image_hashes (dict): A dictionary to store the hashes of the images.

            """
            self.image_dir = image_dir
            self.target_size = target_size
            self.normalize_range = normalize_range
            self.image_hashes = {}

    def hash_image(self, image):
        """
        Calculates the MD5 hash of the given image.

        Parameters:
        image (numpy.ndarray): The image to be hashed.

        Returns:
        str: The MD5 hash of the image.
        """
        return hashlib.md5(image.tobytes()).hexdigest()

    def check_duplicates(self):
        """
        Check for duplicate images in the specified image directory and perform image augmentation if duplicates are found.

        This method iterates over all the images in the specified image directory and checks if each image has a duplicate.
        If a duplicate is found, the image is augmented and saved, replacing the original image. If no duplicate is found,
        the image is added to the list of image hashes for future comparison.

        Returns:
            None
        """
        images = glob.glob(os.path.join(self.image_dir, '*'))
        for image_path in images:
            image = Image.open(image_path).convert('RGB')
            image_hash = self.hash_image(image)
            
            if image_hash in self.image_hashes:
                image = self.augment_image(image)
                image.save(image_path)
            else:
                self.image_hashes[image_hash] = image_path

    def augment_image(self, image):
            """
            Apply a random augmentation to the given image.

            Parameters:
            image (PIL.Image.Image): The input image to be augmented.

            Returns:
            PIL.Image.Image: The augmented image.
            """
            augmentations = [
                ImageOps.mirror,
                ImageOps.flip,
                lambda img: img.rotate(90)
            ]
            augmentation = np.random.choice(augmentations)
            return augmentation(image)
    
    def resize_image(self, image):
        """
        Resizes the given image to the target size.

        Parameters:
        - image: The image to be resized.

        Returns:
        - The resized image.
        """
        return image.resize(self.target_size)

    def normalize_image(self, image):
        """
        Normalize the given image.

        Args:
            image (PIL.Image.Image): The input image to be normalized.

        Returns:
            PIL.Image.Image: The normalized image.

        """
        image_array = np.array(image).astype(np.float32)
        image_array /= 255.0
        if self.normalize_range == (-1, 1):
            image_array = image_array * 2 - 1
        return Image.fromarray((image_array * 255).astype(np.uint8))
    
    def convert_to_tensor(self, image):
        """
        Converts an image to a PyTorch tensor.

        Args:
            image (PIL.Image.Image): The input image.

        Returns:
            torch.Tensor: The converted tensor representation of the image.
        """
        transform = transforms.ToTensor()
        return transform(image)
    
    def preprocess_image(self):
            """
            Preprocesses the images in the specified directory.

            Returns:
            processed_images (list): A list of processed images.
            """
            self.check_duplicates()
            images = glob.glob(os.path.join(self.image_dir, '*'))
            if not images:
                print("No images found in the directory.")

            processed_images = []
            labels = []
            for image_path in images:
                print(f"Processing image: {image_path}")
                image = Image.open(image_path).convert('RGB')
                image = self.resize_image(image)
                print(f"Resized image: {image.size}")
                image = self.normalize_image(image)
                print(f"Normalized image: {np.array(image).shape}")
                image = self.convert_to_tensor(image)
                print(f"Converted to tensor: {image.shape}")
                processed_images.append(image)
                
                target = image_path.split('.')[0][-1]
                labels.append(target)
                
            return processed_images, labels


In [18]:
preprocessor = ImagePreprocessor('data', (224, 224), (0, 1))
data, labels = preprocessor.preprocess_image()
print("-------------------------------------------------------------Completed Preprocessing-------------------------------------------------------------")

Processing image: data\00.png
Resized image: (224, 224)
Normalized image: (224, 224, 3)
Converted to tensor: torch.Size([3, 224, 224])
Processing image: data\01.png
Resized image: (224, 224)
Normalized image: (224, 224, 3)
Converted to tensor: torch.Size([3, 224, 224])
Processing image: data\10.png
Resized image: (224, 224)
Normalized image: (224, 224, 3)
Converted to tensor: torch.Size([3, 224, 224])
Processing image: data\100.png
Resized image: (224, 224)
Normalized image: (224, 224, 3)
Converted to tensor: torch.Size([3, 224, 224])
Processing image: data\101.png
Resized image: (224, 224)
Normalized image: (224, 224, 3)
Converted to tensor: torch.Size([3, 224, 224])
Processing image: data\11.png
Resized image: (224, 224)
Normalized image: (224, 224, 3)
Converted to tensor: torch.Size([3, 224, 224])
Processing image: data\110.png
Resized image: (224, 224)
Normalized image: (224, 224, 3)
Converted to tensor: torch.Size([3, 224, 224])
Processing image: data\111.png
Resized image: (224, 

In [20]:
len(data), len(labels)

(131, 131)