In [None]:
# load Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Import necessary libraries
import os
import random
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm import tqdm  # Used to display progress bar

# CropMix: blend several random multi-scale crops of one image into a single sample
class CropMix:
    """Multi-scale crop-and-mix augmentation for a single PIL image.

    Each call draws ``n`` random resized crops whose area fractions come from
    consecutive, equal-width sub-ranges of ``[scale, 1]``. Each crop is
    converted to a tensor, and the crops are folded left-to-right into one
    tensor (mixup or cutmix). The result is passed through ``post_aug``.

    Args:
        scale: Lower bound of the crop-area fraction range ``[scale, 1]``.
        mix_ratio: Beta concentration; every blend weight is drawn from
            ``Beta(mix_ratio / n, mix_ratio / n)`` where ``n`` is the number
            of crops used for that call.
        number: Number of crops to mix (positive int), or the special value
            ``234`` to pick 2, 3 or 4 at random on every call.
        operation: ``0`` = mixup (pixel-wise convex blend), ``1`` = cutmix
            (paste a random rectangle from the next crop).
        inter_aug: Stored for API compatibility; not used by this class.
        post_aug: Callable applied to the mixed tensor (e.g. ``ToPILImage``);
            ``None`` returns the mixed tensor unchanged.
        size: Output side length of every crop (default 224).
    """

    def __init__(self, scale, mix_ratio, number, operation, inter_aug, post_aug, size=224):
        self.scale = scale
        self.mix_ratio = mix_ratio
        self.number = number
        self.operation = operation
        self.inter_aug = inter_aug
        self.post_aug = post_aug
        self.size = size

    def _resolve_number(self):
        """Return the crop count for this call without mutating ``self.number``.

        Raises:
            ValueError: if the resolved count is below 1.
        """
        # Draw into a local: overwriting self.number would freeze the first
        # random choice for every later image.
        number = random.choice([2, 3, 4]) if self.number == 234 else self.number
        if number < 1:
            raise ValueError(f"number must be a positive int or 234, got {self.number!r}")
        return number

    @staticmethod
    def _scale_bounds(scale, number):
        """Split ``[scale, 1]`` into ``number`` equal, contiguous ``(lo, hi)`` ranges."""
        step = (1 - scale) / number
        bounds = [(scale + step * i, scale + step * (i + 1)) for i in range(number)]
        # Pin the last upper edge to exactly 1.0 so float drift cannot exclude
        # the full-area crop.
        bounds[-1] = (bounds[-1][0], 1.0)
        return bounds

    @staticmethod
    def _cutmix(base, other, lam):
        """Return a copy of ``base`` with a random box copied in from ``other``.

        The box side is ``sqrt(1 - lam)`` of each image side (clipped at the
        borders), so roughly a ``1 - lam`` fraction of the area is replaced.
        Both tensors must share the same shape, with H and W as the last two
        dimensions (``ToTensor`` yields ``(C, H, W)``).
        """
        height, width = base.shape[-2], base.shape[-1]
        cut_ratio = np.sqrt(1.0 - lam)
        cut_h, cut_w = int(height * cut_ratio), int(width * cut_ratio)
        cy, cx = np.random.randint(height), np.random.randint(width)
        y0, y1 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, height)
        x0, x1 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, width)
        out = base.clone()  # never modify the caller's tensor in place
        out[..., y0:y1, x0:x1] = other[..., y0:y1, x0:x1]
        return out

    def _mix(self, views, number):
        """Fold ``views`` left-to-right into one tensor using ``self.operation``.

        Raises:
            ValueError: if ``self.operation`` is neither 0 nor 1.
        """
        if self.operation not in (0, 1):
            raise ValueError(f"operation must be 0 (mixup) or 1 (cutmix), got {self.operation!r}")
        alpha = self.mix_ratio / number
        mixed = views[0]
        for view in views[1:]:
            lam = np.random.beta(alpha, alpha)
            if self.operation == 0:
                mixed = lam * mixed + (1 - lam) * view
            else:
                mixed = self._cutmix(mixed, view, lam)
        return mixed

    def __call__(self, x):
        """Augment image ``x`` and return ``post_aug(mixed)`` (or the raw tensor)."""
        number = self._resolve_number()
        views = []
        for lo, hi in self._scale_bounds(self.scale, number):
            crop = transforms.Compose([
                transforms.RandomResizedCrop(self.size, scale=(lo, hi)),
                transforms.ToTensor(),
            ])
            views.append(crop(x))
        mixed = self._mix(views, number)
        return mixed if self.post_aug is None else self.post_aug(mixed)

# Apply CropMix to a single image, then save the augmented result
def visualize_and_save(image_path, cropmix, output_dir, show=False):
    """
    Apply CropMix to one image and save the augmented result.

    The augmented image is written into ``output_dir`` under the same file name
    as the input. Only the augmented image is saved; the original is not copied.
    When ``show`` is True, a side-by-side figure of the original and augmented
    images is displayed as well.

    Args:
        image_path (str): Path of the source image.
        cropmix (CropMix): Augmentation callable. Its return value must have a
            ``save(path)`` method, e.g. a PIL image when ``post_aug`` is
            ``transforms.ToPILImage()``.
        output_dir (str): Directory to write the augmented image into; created
            if it does not exist.
        show (bool): Display the comparison figure. Defaults to False, which
            skips building the figure entirely (batch runs previously built and
            immediately discarded one figure per image).
    """
    # Context manager releases the file handle Pillow may keep open lazily.
    with Image.open(image_path) as img:
        original_image = img.convert("RGB")

    transformed_image = cropmix(original_image)

    if show:
        fig, ax = plt.subplots(1, 2, figsize=(10, 5))
        ax[0].imshow(original_image)
        ax[0].set_title("Original Image")
        ax[0].axis("off")
        # numpy conversion keeps imshow happy for both PIL images and arrays
        ax[1].imshow(np.array(transformed_image))
        ax[1].set_title("CropMix Augmented Image")
        ax[1].axis("off")
        plt.show()
        plt.close(fig)

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, os.path.basename(image_path))
    transformed_image.save(output_path)

# Walk the input folder tree and augment every image found
def find_images_by_dir(input_dir, skip_dirname="CropMix"):
    """Yield ``(root, image_names)`` for every directory under ``input_dir``.

    Directories named ``skip_dirname`` are pruned from the walk, so previously
    generated augmentations are never re-augmented when outputs live inside
    the input tree. Extensions ``.jpg``/``.jpeg``/``.png`` are matched
    case-insensitively (``photo.JPG`` is included).

    Args:
        input_dir (str): Root directory to walk.
        skip_dirname (str): Directory name to exclude (together with its subtree).
    """
    for root, dirs, files in os.walk(input_dir):
        # Slice assignment mutates the list os.walk uses to decide where to descend.
        dirs[:] = [d for d in dirs if d != skip_dirname]
        images = [f for f in files if f.lower().endswith((".jpg", ".jpeg", ".png"))]
        yield root, images


def process_all_images(input_dir, output_dir, cropmix):
    """
    Traverse the folder, perform CropMix on all images, and save the results.

    Each image ``input_dir/<rel>/<name>`` is augmented and saved as
    ``output_dir/<rel>/CropMix/<name>``. Existing ``CropMix`` folders are not
    traversed, so re-running with ``output_dir == input_dir`` does not
    re-augment earlier outputs.

    Args:
        input_dir (str): Root directory of the source images.
        output_dir (str): Root directory for the augmented images.
        cropmix (CropMix): The instantiated CropMix class.
    """
    subdir = "CropMix"
    for root, images in find_images_by_dir(input_dir, subdir):
        relative_path = os.path.relpath(root, input_dir)
        save_dir = os.path.join(output_dir, relative_path, subdir)
        for file in tqdm(images, desc=f"Processing {os.path.basename(root)}"):
            visualize_and_save(os.path.join(root, file), cropmix, save_dir)

# Source and destination roots point at the same folder: augmented copies are
# written into a "CropMix" subfolder next to each set of originals.
input_dir = "/content/drive/My Drive/contenteye_diseases/Training"
output_dir = input_dir

# Build the augmentation with the settings used for this run
cropmix = CropMix(
    number=2,  # how many crops are blended into each output image
    scale=0.5,  # lower bound of the crop-area fraction range [scale, 1]
    mix_ratio=1.0,  # Beta-distribution parameter for the blend weight
    operation=0,  # 0 requests mixup (pixel-wise blending)
    inter_aug=False,  # stored by CropMix; not read when augmenting
    post_aug=transforms.ToPILImage(),  # turn the mixed tensor back into a PIL image
)

# Augment every image under the training tree
process_all_images(input_dir, output_dir, cropmix)
