In [None]:
import random
import numpy as np
import torch
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import torchvision.transforms as transforms

# Define the CropMix class
class CropMix:
    """Generate several random resized crops of an image and blend them.

    Calling an instance converts the image to a tensor with
    ``transforms.ToTensor()``, draws ``number`` independent
    ``RandomResizedCrop`` samples from it, and combines them with either
    mixup or a cutmix-style paste.

    Args:
        number: How many crops ("views") to generate per call. Must be >= 1.
        operation: ``0`` for mixup, ``1`` for cutmix. Other values raise
            ``ValueError`` when the instance is called.
        mix_ratio: Shape parameter (used for both ``a`` and ``b``) of the Beta
            distribution that mixup samples its blending weights from. It is
            not a fixed blend ratio, and cutmix ignores it.
        max_size_ratio: Output size of each view as a fraction of the original
            height and width. This sets the *resolution* of the views. Each
            crop still targets 50-100% of the source image area before it is
            resized (``scale=(0.5, 1.0)``).

    Raises:
        ValueError: If ``number`` is less than 1.
    """

    def __init__(self, number=3, operation=0, mix_ratio=0.5, max_size_ratio=0.25):
        # Fail early: with no views, both mixing methods would index views[0].
        if number < 1:
            raise ValueError("number must be at least 1.")
        self.number = number
        self.operation = operation
        self.mix_ratio = mix_ratio
        # Fraction of the original height/width that every view is resized to.
        self.max_size_ratio = max_size_ratio

    def __call__(self, img):
        """Crop ``img`` into ``self.number`` views and mix them.

        Args:
            img: Anything ``transforms.ToTensor()`` accepts (e.g. a PIL image).

        Returns:
            A tuple ``(views, mixed_img)``. ``views`` is the list of cropped
            view tensors. ``mixed_img`` is a new tensor holding their blend.
            The mixing step does not modify ``views``.

        Raises:
            ValueError: If ``self.operation`` is neither 0 nor 1.
        """
        img_tensor = transforms.ToTensor()(img)
        original_h, original_w = img_tensor.shape[1], img_tensor.shape[2]

        # Output size of each view. Clamp to 1 px so that tiny inputs do not
        # round down to a zero-sized target, which is not a valid resize size.
        out_h = max(1, int(original_h * self.max_size_ratio))
        out_w = max(1, int(original_w * self.max_size_ratio))
        cropper = transforms.RandomResizedCrop(size=(out_h, out_w), scale=(0.5, 1.0))

        # Independent random crops. All share the same output size.
        views = [cropper(img_tensor) for _ in range(self.number)]

        if self.operation == 0:  # mixup
            mixed_img = self.mixup(views)
        elif self.operation == 1:  # cutmix
            mixed_img = self.cutmix(views)
        else:
            raise ValueError("Invalid operation. Use 0 for mixup or 1 for cutmix.")

        return views, mixed_img

    def mixup(self, views):
        """Blend ``views`` one after another with Beta-sampled weights.

        For each view after the first, a weight
        ``alpha ~ Beta(mix_ratio, mix_ratio)`` is drawn, and the running blend
        becomes ``alpha * blend + (1 - alpha) * view``. Because the blend is
        built sequentially, earlier views end up scaled by a product of
        several weights.

        Args:
            views: Non-empty list of same-shaped tensors.

        Returns:
            A new tensor with the same shape as the views.
        """
        # Start from a copy, so a single-view input does not return an alias
        # of the view that is also handed back to the caller.
        mixed_img = views[0].clone()
        for view in views[1:]:
            # Convert the NumPy scalar to a Python float so PyTorch treats it
            # as a plain scalar and keeps the tensor's dtype.
            alpha = float(np.random.beta(self.mix_ratio, self.mix_ratio))
            mixed_img = alpha * mixed_img + (1 - alpha) * view
        return mixed_img

    def cutmix(self, views):
        """Paste the bottom-right region of each later view onto the first.

        For every view after the first, a corner ``(cut_y, cut_x)`` is drawn
        uniformly. The far edge is included, so the pasted region may be
        empty. ``view[:, cut_y:, cut_x:]`` then overwrites the same region of
        the running result.

        Args:
            views: Non-empty list of same-shaped ``(C, H, W)`` tensors.

        Returns:
            A new tensor. The input views are left unmodified.
        """
        # Work on a copy: the first view is also returned to the caller and
        # must not be overwritten in place.
        mixed_img = views[0].clone()
        _, height, width = mixed_img.shape
        for view in views[1:]:
            cut_x = random.randint(0, width)
            cut_y = random.randint(0, height)
            mixed_img[:, cut_y:, cut_x:] = view[:, cut_y:, cut_x:]
        return mixed_img

# Load the image to augment.
image_path = '/content/drive/My Drive/Eye_rgb/1144_left.jpg'  # Replace with the actual path
# Use Image.open as a context manager so the file handle is released once the
# RGB copy has been made.
with Image.open(image_path) as source_image:
    original_image = source_image.convert('RGB')

# Configure the augmentation.
cropmix = CropMix(
    number=3,  # Generate 3 cropped views
    operation=0,  # 0 = mixup, 1 = cutmix
    mix_ratio=0.5,  # Beta(0.5, 0.5) shape parameter for the mixup weights
    max_size_ratio=0.25,  # each view is resized to 1/4 of the original height and width
)

# Apply CropMix: returns the individual views and their blend.
cropped_views, mixed_image_tensor = cropmix(original_image)

# Convert the views and the blended result back to PIL images for display.
to_pil = transforms.ToPILImage()
cropped_views_pil = [to_pil(view) for view in cropped_views]
mixed_image = to_pil(mixed_image_tensor)

# Draw example crop regions on a *copy*, so `original_image` stays unmodified.
# NOTE: CropMix does not report which regions it actually cropped. These boxes
# are fresh samples from the distribution it crops with (RandomResizedCrop with
# scale=(0.5, 1.0) and the default 3/4-4/3 aspect-ratio range). They show
# typical crop regions, not the exact ones used for the views.
annotated_image = original_image.copy()
draw = ImageDraw.Draw(annotated_image)
for _ in cropped_views_pil:
    top, left, box_h, box_w = transforms.RandomResizedCrop.get_params(
        original_image, scale=(0.5, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)
    )
    # PIL rectangle coordinates are inclusive of both endpoints, hence the -1.
    draw.rectangle(
        [left, top, left + box_w - 1, top + box_h - 1], outline="red", width=3
    )

# All views come from one cropper and share one size. Concatenate them along
# the width axis and show them as a single strip; calling imshow repeatedly
# on the same axes would leave only the last view visible.
views_strip = to_pil(torch.cat(cropped_views, dim=2))

fig, (ax_original, ax_views, ax_mixed) = plt.subplots(1, 3, figsize=(15, 5))

ax_original.set_title("Original Image (example crop regions)")
ax_original.imshow(annotated_image)

ax_views.set_title("Cropped Views")
ax_views.imshow(views_strip)

ax_mixed.set_title("Mixed Image")
ax_mixed.imshow(mixed_image)

for ax in (ax_original, ax_views, ax_mixed):
    ax.axis('off')

plt.tight_layout()
plt.show()
