# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import random
import torch  # use PyTorch implement partial enhancement logic
from google.colab import drive

# Mount Google Drive at /content/drive so files stored there can be opened by path
drive.mount('/content/drive')

# load image
def load_image(image_path):
    """Read the image at ``image_path`` and return it in RGB channel order.

    Raises:
        FileNotFoundError: if ``cv2.imread`` returns ``None`` for the path.
    """
    bgr = cv2.imread(image_path)
    if bgr is not None:
        # Convert from BGR to RGB channel order
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise FileNotFoundError(f"Unable to load image, please check the path：{image_path}")

# AugMix Augmentation
def blur(image, severity):
    """Gaussian-blur a channel-first tensor; a higher severity gives a larger kernel.

    The tensor is moved to channel-last layout for OpenCV and the result is
    returned as a float32 tensor in channel-first layout again.
    """
    # Kernel side length 2 * severity + 1 (1, 3, 5, ... for integer severity)
    side = 2 * severity + 1
    channel_last = image.permute(1, 2, 0).numpy()
    smoothed = cv2.GaussianBlur(channel_last, (side, side), 0)
    as_tensor = torch.tensor(smoothed, dtype=torch.float32)
    return as_tensor.permute(2, 0, 1)

def sharpen(image, severity):
    """Sharpen a channel-first image tensor with a 3x3 filter.

    Args:
        image: tensor in (channels, height, width) layout; intensities are
            expected in [0, 1], matching the range the other operations in
            this file clamp to.
        severity: added to the centre weight of the kernel; larger values
            give a stronger effect.

    Returns:
        A float32 tensor in (channels, height, width) layout, clamped to
        [0, 1].
    """
    # NOTE: the weights sum to 1 + severity, so besides sharpening the filter
    # also scales overall brightness by that factor.
    kernel = np.array(
        [[0, -1, 0], [-1, 5 + severity, -1], [0, -1, 0]],
        dtype=np.float32,
    )
    channel_last = image.permute(1, 2, 0).numpy()
    sharpened = cv2.filter2D(channel_last, -1, kernel)
    result = torch.tensor(sharpened, dtype=torch.float32).permute(2, 0, 1)
    # The filter overshoots the valid range; clamp so later steps (and the
    # final uint8 conversion) never see values outside [0, 1].
    return torch.clamp(result, 0, 1)

def gamma_correction(image, severity):
    """Apply gamma correction to an image tensor with intensities in [0, 1].

    The exponent is ``1 / gamma`` with ``gamma = 1.0 + severity * 0.2``, so a
    positive severity brightens mid-tones while 0 and 1 stay fixed.

    The correction is computed directly on the float tensor. Casting a
    [0, 1] image to uint8 for a 256-entry lookup table would collapse almost
    every pixel to 0 and return values on a 0-255 scale.

    Args:
        image: tensor of intensities; values outside [0, 1] are clamped first.
        severity: strength of the correction.

    Returns:
        A float32 tensor with the same shape as ``image`` and values in [0, 1].
    """
    gamma = 1.0 + severity * 0.2
    inv_gamma = 1.0 / gamma
    # Clamp first: a negative base with a fractional exponent would give NaN.
    normalized = torch.clamp(image, 0, 1).to(torch.float32)
    return normalized ** inv_gamma

def color_jitter(image, severity):
    """Add Gaussian noise scaled by ``severity * 0.2`` and clamp to [0, 1]."""
    noise = torch.randn_like(image) * (severity * 0.2)
    noisy = image + noise
    return noisy.clamp(0, 1)

def contrast(image, severity):
    """Scale each channel's deviation from its own mean and clamp to [0, 1].

    The scale factor is ``1 + severity * 0.5``; the mean is taken over the
    last two dimensions of a three-dimensional tensor, separately per index
    of the first dimension.
    """
    gain = 1 + severity * 0.5
    channel_mean = image.mean(dim=(1, 2), keepdim=True)
    stretched = (image - channel_mean) * gain + channel_mean
    return stretched.clamp(0, 1)

# Pool of augmentation operations that AugMix picks from at random
augmentations = [blur, sharpen, gamma_correction, color_jitter, contrast]

# AugMix Function
def apply_op(image, op, severity):
    """Call the augmentation ``op`` on ``image`` with ``severity`` and return its result."""
    augmented = op(image, severity)
    return augmented

def augmix(image, severity=3, width=3, depth=-1, alpha=1.):
    """Perform AugMix augmentation and blend the results.

    Several augmentation chains are applied to copies of the image, combined
    with Dirichlet-distributed weights, and the combination is blended with
    the original image using a Beta-distributed coefficient.

    Args:
        image: array in (height, width, channels) layout on a 0-255 scale
            (it is divided by 255 before augmentation).
        severity: strength passed to every augmentation operation.
        width: number of augmentation chains to mix.
        depth: number of operations per chain; when not positive, a random
            depth of 1, 2 or 3 is drawn for each chain.
        alpha: concentration parameter for both the Dirichlet and the Beta
            distribution.

    Returns:
        The augmented image as a uint8 array in (height, width, channels)
        layout.
    """
    ws = np.float32(np.random.dirichlet([alpha] * width))
    m = np.float32(np.random.beta(alpha, alpha))

    # Work on a channel-first float tensor scaled to [0, 1]
    image_torch = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0
    mix = torch.zeros_like(image_torch)

    for weight in ws:
        image_aug = image_torch.clone()
        d = depth if depth > 0 else np.random.randint(1, 4)
        for _ in range(d):
            op = np.random.choice(augmentations)
            image_aug = apply_op(image_aug, op, severity)
        mix += weight * image_aug

    mixed = (1 - m) * image_torch + m * mix
    # Clamp before the uint8 cast: out-of-range values would otherwise wrap
    # around and show up as speckle artifacts.
    mixed = torch.clamp(mixed, 0, 1)
    # Round instead of truncating so float error cannot shift a pixel down by one.
    return np.rint(mixed.permute(1, 2, 0).numpy() * 255).astype(np.uint8)

# load the sample image
image_path = '/content/drive/My Drive/Eye_rgb/1144_left.jpg'  # Replace with your image path
original_image = load_image(image_path)

# run AugMix on it
augmix_image = augmix(original_image, severity=5, width=3)

# show the original and the augmented image side by side
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
panels = ((original_image, "Original"), (augmix_image, "AugMix Enhanced"))
for ax, (picture, title) in zip(axes, panels):
    ax.imshow(picture)
    ax.set_title(title)
    ax.axis("off")

plt.show()
