In [None]:
from google.colab import drive
import os
import random
from PIL import Image, ImageFilter
import torch
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F


# Make Google Drive visible inside the Colab runtime; the dataset lives there.
drive.mount('/content/drive')

# Root folder scanned below: each subfolder holds one group of images
# (presumably one class per subfolder -- confirm against the dataset layout).
dataset_path = '/content/drive/My Drive/contenteye_diseases/Training'
# Name of the subfolder created inside each group folder for generated images.
output_folder_name = "SnapMix"


def ensure_dir(directory):
    """Create *directory*, including missing parents, if it does not exist yet.

    ``exist_ok=True`` makes the call idempotent without a separate
    ``os.path.exists`` check. That check-then-create pattern is racy: another
    process can create the directory in between, and ``os.makedirs`` would then
    raise ``FileExistsError``. If *directory* exists but is a regular file, an
    error is raised now rather than later when writing into it.
    """
    os.makedirs(directory, exist_ok=True)

def save_image(tensor, path):
    """Convert an image tensor to a PIL image and write it to *path*."""
    to_pil = transforms.ToPILImage()
    # Bring the data to host memory first so device tensors can be converted.
    pil_image = to_pil(tensor.cpu())
    pil_image.save(path)

def is_valid_image(img_tensor):
    """Return True if *img_tensor* contains at least one strictly positive value.

    Used as a sanity check that an augmented image is not all zeros (e.g. fully
    black). The comparison is done with torch ops, so it works for tensors on
    any device and for tensors that require grad. ``Tensor.numpy()``, used
    previously, raises for both.

    Returns:
        A plain Python ``bool``.
    """
    return bool((img_tensor > 0).any())

def rand_bbox(size, lam):
    """Sample a random axis-aligned box for CutMix-style mixing.

    Each box side is ``int(extent * sqrt(1 - lam))`` along dims 2 and 3 of
    *size*. The centre is drawn uniformly with ``np.random.randint``, and the
    box is clipped to the image, so it can be smaller near the borders.

    Args:
        size: a shape with at least 4 entries (e.g. ``tensor.size()``); only
            ``size[2]`` and ``size[3]`` are read.
        lam: mixing ratio, expected in [0, 1]; smaller values give larger boxes.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``. ``bbx1 <= bbx2`` lie in ``[0, size[2]]``
        and ``bby1 <= bby2`` lie in ``[0, size[3]]``.
    """
    extent_x, extent_y = size[2], size[3]
    side_scale = np.sqrt(1. - lam)
    half_x = int(extent_x * side_scale) // 2
    half_y = int(extent_y * side_scale) // 2

    # Centre for dim 2 is drawn before dim 3; keep this order for seeded runs.
    centre_x = np.random.randint(extent_x)
    centre_y = np.random.randint(extent_y)

    x_lo, x_hi = (np.clip(centre_x + d, 0, extent_x) for d in (-half_x, half_x))
    y_lo, y_hi = (np.clip(centre_y + d, 0, extent_y) for d in (-half_y, half_y))
    return x_lo, y_lo, x_hi, y_hi

def random_color_jitter(image):
    """Apply a freshly built ``ColorJitter`` with fixed strengths to *image*.

    Brightness, contrast and saturation use strength 0.4; hue uses 0.2.
    """
    jitter_strengths = dict(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2)
    jitter = transforms.ColorJitter(**jitter_strengths)
    return jitter(image)

def random_blur(image):
    """With probability 1/2, Gaussian-blur *image* using a radius drawn from U(0.1, 2.0).

    Otherwise *image* is returned unchanged (the same object).
    """
    if random.random() >= 0.5:
        return image
    radius = random.uniform(0.1, 2.0)
    return image.filter(ImageFilter.GaussianBlur(radius=radius))

def random_noise(image):
    """With probability 1/2, add zero-mean Gaussian noise (std 25) to *image*.

    The noise is truncated to int16 and added to the pixel array. The result
    is clipped to [0, 255] and converted back to a uint8 PIL image. Otherwise
    *image* is returned unchanged (the same object).
    """
    if random.random() >= 0.5:
        return image
    pixels = np.array(image)
    gaussian = np.random.normal(0, 25, pixels.shape).astype(np.int16)
    # The sum is computed in int16, so it cannot wrap before clipping.
    noisy = np.clip(pixels + gaussian, 0, 255).astype(np.uint8)
    return Image.fromarray(noisy)

def apply_random_augmentations(image):
    """Run colour jitter, then random blur, then random noise on *image*."""
    for augment in (random_color_jitter, random_blur, random_noise):
        image = augment(image)
    return image

def snapmix(input, target, conf=None, model=None):
    """Randomly cut-and-paste a box between samples of one batch (CutMix-style).

    With probability 0.5 nothing is mixed. Otherwise a ratio ``lam`` is drawn
    from Beta(1, 1), and a random permutation pairs each sample with a partner.
    The box from ``rand_bbox`` is copied from each partner into the sample.

    NOTE(review): ``conf`` and ``model`` are unused, so the label weights are
    plain box-area ratios, not the CAM-based semantic weights of the SnapMix
    paper.

    Args:
        input: 4-D image batch (presumably N, C, H, W). When mixing happens it
            is modified in place and also returned.
        target: per-sample labels; indexed with a permutation of the batch.
        conf, model: unused; kept for interface compatibility.

    Returns:
        ``(input, target, target_b, lam_a, lam_b)``:
        - ``input``: the possibly mixed batch.
        - ``target``: the labels, moved to ``input.device`` if mixing happened.
        - ``target_b``: the partner labels.
        - ``lam_a``, ``lam_b``: float tensors of shape ``(batch,)`` on
          ``input.device``, with ``lam_a + lam_b == 1``. ``lam_a`` is the
          fraction of each image that kept its own content.
    """
    batch_size = input.size(0)
    device = input.device
    lam_a = torch.ones(batch_size, device=device)
    lam_b = 1 - lam_a
    target_b = target.clone()

    r = np.random.rand(1)
    if r < 0.5:
        lam = np.random.beta(1.0, 1.0)

        # Follow the batch's device instead of hard-coding CUDA, so CPU
        # tensors work and indices always live next to the data they index.
        target = target.to(device)
        rand_index = torch.randperm(batch_size).to(device)
        target_b = target[rand_index]

        bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)

        # A degenerate (empty) box leaves the batch and the weights untouched.
        if bbx2 > bbx1 and bby2 > bby1:
            # Clone first: the source and destination regions overlap in memory.
            patch = input[rand_index, :, bbx1:bbx2, bby1:bby2].clone()
            input[:, :, bbx1:bbx2, bby1:bby2] = patch
            box_fraction = float((bbx2 - bbx1) * (bby2 - bby1)) / (
                input.size(-1) * input.size(-2)
            )
            # Same (batch,) shape and dtype as the no-mix branch, so callers
            # can always weight per-sample losses the same way.
            lam_a = torch.full((batch_size,), 1.0 - box_fraction, device=device)
            lam_b = 1 - lam_a

    return input, target, target_b, lam_a, lam_b

# Fixed 224x224 resize (so images can be stacked into batches), then
# conversion of the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)


# snapmix pairs each sample with a partner from the SAME batch. With a batch of
# one it pastes an image onto itself, so the saved "SnapMix" file would be
# unmixed. Images of each folder are therefore processed in small batches.
# A batch that snapmix's own 50% gate skips, or a sample the random
# permutation maps to itself, is still saved unmixed; that is inherent to
# snapmix.
batch_size = 8
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_extensions = ('.png', '.jpg', '.jpeg')


def _load_augmented(img_path):
    """Load *img_path* as RGB and apply the PIL-level augmentations; None if unreadable."""
    try:
        # The context manager closes the file handle that Image.open keeps open.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
    except Exception as e:
        print(f"无法加载图片 {img_path}: {e}")
        return None
    return apply_random_augmentations(img)


def _batched_names(names, size):
    """Split *names* into chunks of *size*, folding a trailing single name into the previous chunk."""
    chunks = [names[i:i + size] for i in range(0, len(names), size)]
    # A lone last image would have no mixing partner.
    if len(chunks) > 1 and len(chunks[-1]) == 1:
        chunks[-2].extend(chunks.pop())
    return chunks


for subfolder in sorted(os.listdir(dataset_path)):
    subfolder_path = os.path.join(dataset_path, subfolder)

    # Only subfolders are processed; stray files at the top level are ignored.
    if not os.path.isdir(subfolder_path):
        continue

    snapmix_folder = os.path.join(subfolder_path, output_folder_name)
    ensure_dir(snapmix_folder)

    # Sorted for a reproducible processing order (os.listdir order is arbitrary).
    img_names = sorted(
        name for name in os.listdir(subfolder_path)
        if name.lower().endswith(image_extensions)
    )

    for chunk in _batched_names(img_names, batch_size):
        loaded = []
        for img_name in chunk:
            img = _load_augmented(os.path.join(subfolder_path, img_name))
            if img is not None:
                loaded.append((img_name, transform(img)))
        if not loaded:
            continue

        batch = torch.stack([tensor for _, tensor in loaded]).to(device)
        # Labels play no role in offline generation; snapmix only needs a
        # per-sample tensor it can permute.
        target = torch.zeros(len(loaded), dtype=torch.long, device=device)
        augmented_batch, _, _, _, _ = snapmix(batch, target)

        for (img_name, _), augmented in zip(loaded, augmented_batch):
            # Skip outputs with no positive pixel values.
            if not is_valid_image(augmented.cpu()):
                print(f"增强后的图片无效，跳过: {img_name}")
                continue

            output_path = os.path.join(snapmix_folder, f"SnapMix_{img_name}")
            save_image(augmented, output_path)
            print(f"保存增强图片: {output_path}")
