# In [None]:
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# Define the YOCO function
def YOCO(images, aug, h=None, w=None):
    """Cut an image tensor in half once, augment each half, and re-join them.

    With probability 0.5 the cut is made across the width (left/right halves);
    otherwise it is made across the height (top/bottom halves). ``aug`` is
    called on the first half and then on the second half, and the two results
    are concatenated back along the axis that was cut.

    Args:
        images: Tensor whose last two dimensions are (height, width), for
            example ``(N, C, H, W)`` or an unbatched ``(C, H, W)``.
        aug: Callable applied to each half independently. Its outputs must
            agree in every dimension except the one being concatenated.
        h: Height to split over. Defaults to ``images.shape[-2]``.
        w: Width to split over. Defaults to ``images.shape[-1]``.

    Returns:
        Tensor produced by concatenating the two augmented halves.
    """
    if h is None:
        h = images.shape[-2]
    if w is None:
        w = images.shape[-1]

    # Ellipsis indexing and negative dims keep this independent of how many
    # leading (batch/channel) dimensions the input carries.
    if torch.rand(1) > 0.5:
        # Split across the width: left half, right half.
        mid = w // 2
        halves = (images[..., :mid], images[..., mid:w])
        cat_dim = -1
    else:
        # Split across the height: top half, bottom half.
        mid = h // 2
        halves = (images[..., :mid, :], images[..., mid:h, :])
        cat_dim = -2

    # The list is built in order, so aug sees the first half before the second.
    return torch.cat([aug(half) for half in halves], dim=cat_dim)

# Expose Google Drive under /content/drive (requires the Google Colab runtime)
from google.colab import drive
drive.mount('/content/drive')

# Read the sample image from the mounted Drive and force it to RGB
image_path = '/content/drive/My Drive/Eye_rgb/1144_left.jpg'
image = Image.open(image_path).convert('RGB')

# PIL image -> tensor, then prepend a batch dimension at position 0
transform_to_tensor = transforms.Compose([transforms.ToTensor()])
image_tensor = torch.unsqueeze(transform_to_tensor(image), 0)

# The last two dimensions of the batched tensor are height and width
h, w = image_tensor.shape[-2:]

# Augmentation applied to each half; p=1.0 makes the flip deterministic so
# the effect is always visible in the plot
aug = transforms.Compose([transforms.RandomHorizontalFlip(p=1.0)])

# Run YOCO on the batched image tensor
augmented_image_tensor = YOCO(images=image_tensor, aug=aug, h=h, w=w)

# Convert tensors back to numpy for visualization
def tensor_to_image(tensor):
    """Convert a single-image tensor into an (H, W, C) NumPy array.

    Args:
        tensor: Image tensor laid out as ``(1, C, H, W)`` or ``(C, H, W)``.

    Returns:
        NumPy array with the channel axis moved last, i.e. ``(H, W, C)``.
    """
    # Only drop the leading axis when it is a batch axis; squeezing a 3-D
    # single-channel (1, H, W) tensor would remove the channel axis instead.
    if tensor.dim() == 4:
        tensor = tensor.squeeze(0)
    tensor = tensor.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
    # .numpy() rejects tensors that require grad or live off the CPU, so
    # detach from the autograd graph and move to host memory first.
    return tensor.detach().cpu().numpy()

original_image_np = tensor_to_image(image_tensor)
augmented_image_np = tensor_to_image(augmented_image_tensor)

# Show the original and the YOCO-augmented image side by side
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))

# One (title, pixels) pair per subplot, left to right
panels = (
    ('Original Image', original_image_np),
    ('YOCO-Augmented Image', augmented_image_np),
)
for ax, (title, pixels) in zip(axs, panels):
    ax.imshow(pixels)
    ax.set_title(title)
    ax.axis('off')  # hide ticks and frame

plt.tight_layout()
plt.show()
