In [None]:
import os
import cv2
import numpy as np
from tqdm import tqdm
from google.colab import drive


# Mount Google Drive so the dataset folder below is reachable from the Colab runtime.
drive.mount('/content/drive')

# Root folder that contains one sub-folder per class (see the processing loop below).
# NOTE(review): "contenteye_diseases" looks like it may be a mangled
# "content/eye_diseases" (missing separator) — confirm the folder name on Drive.
dataset_path = "/content/drive/My Drive/contenteye_diseases/Training/"

# Class sub-folders of dataset_path whose images get augmented.
categories = [
    "cataract",
    "diabetic_retinopathy",
    "glaucoma",
    "normal",
]


def random_crop(img, min_ratio=1/6, max_ratio=0.9):
    """
    Randomly crop a rectangular region of ``img``.

    One ratio ``r`` is drawn uniformly from ``[min_ratio, max_ratio)``. The same
    ratio is used for both axes, so the crop keeps the image's aspect ratio.
    Its height and width are ``r`` times the image's, so its area is ``r**2``
    times the image's. The crop position is uniform over all positions where
    the crop fits inside the image.

    Args:
        img: Image array indexed as (rows, cols) or (rows, cols, channels).
        min_ratio: Lower bound on the side-length ratio.
        max_ratio: Upper bound on the side-length ratio.

    Returns:
        ``(cropped_img, (x1, y1, x2, y2))``, where:

        * ``cropped_img`` is ``img[y1:y2, x1:x2]``, a NumPy view rather than
          a copy.
        * ``(x1, y1)`` is the inclusive top-left corner.
        * ``(x2, y2)`` is the exclusive bottom-right corner.

    Raises:
        ValueError: if the ratios do not satisfy
            ``0 <= min_ratio <= max_ratio <= 1``.
    """
    if not 0 <= min_ratio <= max_ratio <= 1:
        raise ValueError(
            f"expected 0 <= min_ratio <= max_ratio <= 1, got {min_ratio}, {max_ratio}"
        )

    # shape[:2] works for grayscale (2-D) as well as multi-channel (3-D) images.
    h, w = img.shape[:2]

    crop_ratio = np.random.uniform(min_ratio, max_ratio)
    # int() truncates to 0 on tiny images; keep at least one pixel per side,
    # but never exceed the image itself.
    crop_h = min(h, max(1, int(h * crop_ratio)))
    crop_w = min(w, max(1, int(w * crop_ratio)))

    # randint's upper bound is exclusive; +1 lets the crop reach the right/bottom edge.
    x1 = np.random.randint(0, w - crop_w + 1)
    y1 = np.random.randint(0, h - crop_h + 1)
    x2 = x1 + crop_w
    y2 = y1 + crop_h

    # Slice only the two spatial axes so any trailing channel axis is kept as-is.
    cropped_img = img[y1:y2, x1:x2]
    return cropped_img, (x1, y1, x2, y2)


def cutmix(img, cropped_img, crop_coords=None):
    """
    Paste ``cropped_img`` into ``img`` at a uniformly random position.

    This is a CutMix-style copy-paste augmentation. Only pixels are pasted;
    no labels are mixed.

    Args:
        img: Destination image, indexed as (rows, cols[, channels]).
            It is modified in place, so pass a copy if the original must be
            preserved.
        cropped_img: Patch to paste. Its height and width must not exceed
            ``img``'s, and any trailing channel axis must be assignable into
            ``img``.
        crop_coords: Source coordinates of the patch, as returned by
            ``random_crop``. Currently unused; accepted for API compatibility.

    Returns:
        ``img``, after the whole patch has been written into it.

    Raises:
        ValueError: if the patch is larger than ``img`` in either dimension.
    """
    h, w = img.shape[:2]
    crop_h, crop_w = cropped_img.shape[:2]
    if crop_h > h or crop_w > w:
        raise ValueError(
            f"patch of size {crop_h}x{crop_w} does not fit in image of size {h}x{w}"
        )

    # Sample the top-left corner over every position where the full patch fits.
    # The previous centre-based scheme dropped a row/column for odd patch sizes
    # and could never reach the right/bottom edge.
    x1 = np.random.randint(0, w - crop_w + 1)
    y1 = np.random.randint(0, h - crop_h + 1)

    img[y1:y1 + crop_h, x1:x1 + crop_w] = cropped_img
    return img


# For every class folder, write one CutMix-augmented copy of each image into
# a "CutMix" sub-folder. The output is named "cutmix_<original name>".
for category in categories:
    category_path = os.path.join(dataset_path, category)
    cutmix_path = os.path.join(category_path, "CutMix")

    # Skip missing class folders. Without this check, makedirs below would
    # silently create them (plus an empty CutMix/), hiding a wrong dataset_path.
    if not os.path.isdir(category_path):
        tqdm.write(f"Skipping {category}: directory not found at {category_path}")
        continue

    os.makedirs(cutmix_path, exist_ok=True)

    # listdir is not recursive, so earlier outputs inside CutMix/ are not re-processed.
    # Sorting makes the processing order deterministic.
    images = sorted(
        f for f in os.listdir(category_path)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )

    for image_name in tqdm(images, desc=f"Processing {category}"):
        image_path = os.path.join(category_path, image_name)
        img = cv2.imread(image_path)

        # cv2.imread returns None instead of raising on unreadable files.
        if img is None:
            tqdm.write(f"Skipping unreadable image: {image_path}")
            continue

        cropped_img, crop_coords = random_crop(img)

        # cropped_img is a view into img, so paste into a copy to keep the source pixels intact.
        cutmix_img = cutmix(img.copy(), cropped_img, crop_coords)

        save_path = os.path.join(cutmix_path, f"cutmix_{image_name}")
        # cv2.imwrite can signal failure through its boolean return value instead of raising.
        if not cv2.imwrite(save_path, cutmix_img):
            tqdm.write(f"Failed to write {save_path}")
