In [None]:
# Colab-only: mount Google Drive at /content/drive so the dataset paths under
# "/content/drive/My Drive/..." used further down in this notebook resolve.
from google.colab import drive
drive.mount('/content/drive')

import os
import torch
from PIL import Image
import torchvision.transforms as transforms
from tqdm import tqdm

def YOCO(images, aug_flip, aug_color, h, w):
    """Apply a "You Only Cut Once" (YOCO) style augmentation to a batch.

    The batch is cut once through the middle. With probability 0.5 the cut
    splits the width (left | right halves); otherwise it splits the height
    (top / bottom halves). The first half (left or top) is always passed
    through ``aug_flip``. The second half is passed through ``aug_color``
    with probability 0.5 and is otherwise left as-is. The halves are then
    concatenated back along the axis that was cut.

    Args:
        images: Tensor indexed as ``images[:, :, rows, cols]``, i.e. dim 2 is
            height and dim 3 is width (e.g. ``(N, C, H, W)``; the caller in
            this file passes a single unsqueezed image).
        aug_flip: Callable applied to the first half.
        aug_color: Callable applied, half of the time, to the second half.
        h: Height used to place the horizontal cut at ``int(h / 2)``.
        w: Width used to place the vertical cut at ``int(w / 2)``.

    Returns:
        The augmented tensor, re-assembled with ``torch.cat`` along the cut axis.
    """
    # First coin flip: cut across the width (dim 3) or the height (dim 2).
    if torch.rand(1) > 0.5:
        cat_dim, length = 3, w
    else:
        cat_dim, length = 2, h
    midpoint = int(length / 2)

    def segment(start, stop):
        # Take [start:stop) along cat_dim; the other three indices stay full.
        index = [slice(None)] * 4
        index[cat_dim] = slice(start, stop)
        return images[tuple(index)]

    # aug_flip runs before the second coin is drawn. Keep this order: if
    # aug_flip draws random numbers itself, reordering changes seeded runs.
    first_half = aug_flip(segment(0, midpoint))
    second_half = segment(midpoint, length)
    # Second coin flip: colour-augment the second half only half of the time.
    if torch.rand(1) > 0.5:
        second_half = aug_color(second_half)
    return torch.cat((first_half, second_half), dim=cat_dim)


def augment_and_save(image_path, output_dir, aug_flip, aug_color):
    """Apply YOCO augmentation to one image file and save the result.

    The image is loaded and converted to RGB, turned into a float tensor with
    ``transforms.ToTensor`` plus a leading batch dimension, and augmented with
    ``YOCO``. It is then converted back to a PIL image and written to
    ``output_dir`` under the same file name as the input.

    Args:
        image_path: Path of the source image file.
        output_dir: Directory that receives the augmented image; it is
            created if it does not exist.
        aug_flip: Transform passed to ``YOCO`` for the first half.
        aug_color: Transform passed to ``YOCO`` for the second half.
    """
    # Image.open is lazy and keeps the file open. Use a context manager so the
    # handle is closed once convert() has decoded the pixels into a new image.
    with Image.open(image_path) as source:
        original_image = source.convert("RGB")

    transform_to_tensor = transforms.ToTensor()
    # unsqueeze(0) adds the batch axis so YOCO can index [:, :, rows, cols].
    original_tensor = transform_to_tensor(original_image).unsqueeze(0)
    _, _, h, w = original_tensor.shape

    augmented_tensor = YOCO(original_tensor, aug_flip, aug_color, h, w)

    transform_to_pil = transforms.ToPILImage()
    augmented_image = transform_to_pil(augmented_tensor.squeeze(0))

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, os.path.basename(image_path))
    augmented_image.save(output_path)

def process_all_images(
    input_dir,
    output_dir,
    aug_flip,
    aug_color,
    *,
    class_names=("glioma_tumor", "meningioma_tumor", "no_tumor", "pituitary_tumor"),
    extensions=(".jpg", ".jpeg", ".png"),
):
    """Augment every image found in the class folders under ``input_dir``.

    ``input_dir`` is walked recursively. For each directory whose base name is
    in ``class_names``, every file whose extension is in ``extensions`` is
    augmented with ``augment_and_save``. Extension matching is
    case-insensitive. Results go to
    ``output_dir/<path of that directory relative to input_dir>/YOCO/``,
    keeping the original file name. Because "YOCO" is not a class name, those
    output folders are never picked up and re-augmented.

    Args:
        input_dir: Root directory to scan.
        output_dir: Root directory for the augmented copies (may equal
            ``input_dir``).
        aug_flip: Transform forwarded to ``augment_and_save``.
        aug_color: Transform forwarded to ``augment_and_save``.
        class_names: Directory base names whose files are processed.
        extensions: File extensions to process, compared case-insensitively.
    """
    wanted_classes = set(class_names)
    suffixes = tuple(ext.lower() for ext in extensions)

    for root, _, files in os.walk(input_dir):
        class_name = os.path.basename(root)
        if class_name not in wanted_classes:
            continue

        # Same destination for every file in this directory; compute it once.
        relative_path = os.path.relpath(root, input_dir)
        save_dir = os.path.join(output_dir, relative_path, "YOCO")

        for file in tqdm(files, desc=f"Processing {class_name}"):
            # Lower-case the name so e.g. "SCAN.JPG" is not silently skipped.
            if not file.lower().endswith(suffixes):
                continue
            image_path = os.path.join(root, file)
            augment_and_save(image_path, save_dir, aug_flip, aug_color)

# Source and destination share one root: each class folder under it gets a
# "YOCO" sub-folder holding its augmented copies.
input_dir = output_dir = "/content/drive/My Drive/MRI/Training"

# Applied to the first (left/top) half on every call: a guaranteed mirror.
aug_flip = torch.nn.Sequential(transforms.RandomHorizontalFlip(p=1.0))

# Applied to the second (right/bottom) half when YOCO's coin flip selects it:
# random brightness, contrast and saturation jitter.
aug_color = torch.nn.Sequential(
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
)

process_all_images(input_dir, output_dir, aug_flip, aug_color)
