# In [None]:
import torch
import numpy as np
import cv2
from matplotlib import pyplot as plt
from torchvision import transforms
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Load Image from Google Drive
image_path = '/content/drive/My Drive/Eye_rgb/1144_left.jpg'  # Replace with your image path
image = cv2.imread(image_path)
# cv2.imread does not raise on a missing/unreadable file: it returns None,
# which would otherwise surface later as an opaque cvtColor assertion.
if image is None:
    raise FileNotFoundError(f"Could not read image at {image_path!r}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; convert to RGB

# Preprocess Image: HxWxC uint8 array -> CxHxW float tensor in [0, 1], then resize.
transform = transforms.Compose([
    transforms.ToTensor(),
    # antialias=True pins tensor-resize behaviour across torchvision versions
    # (the default for tensor inputs changed to True in v0.17).
    transforms.Resize((224, 224), antialias=True),  # Resize to a fixed size
])
input_image = transform(image).unsqueeze(0)  # Add batch dimension -> (1, 3, 224, 224)

# Simulate Target Labels
target = torch.tensor([1])  # Example target label

# Configuration for SnapMix
class Config:
    """Hyper-parameters that ``snapmix`` reads as ``conf.prob`` / ``conf.beta``."""

    # Probability of applying SnapMix to a batch; 1.0 means every call mixes.
    prob: float = 1.0
    # Passed as both shape parameters of np.random.beta when sampling box sizes.
    beta: float = 1.0


conf = Config()

# Define Dummy Functions for get_spm and utils.rand_bbox
def get_spm(input, target, conf, model=None):
    """Return a placeholder saliency map for ``snapmix``.

    Dummy stand-in for SnapMix's ``get_spm``: rather than deriving a map from
    ``model``, it fills the map with uniform random values in [0, 1).
    ``target``, ``conf`` and ``model`` are accepted only for signature
    compatibility and are ignored.

    Returns:
        ``(maps, None)`` where ``maps`` has shape ``(batch, height, width)``
        (taken from ``input.size()``) and lives on ``input.device``.
    """
    batch, _channels, height, width = input.size()
    # Draw on the default generator first, then transfer, so a given torch
    # seed reproduces the same maps regardless of the input's device.
    maps = torch.rand(batch, height, width)
    return maps.to(input.device), None

def rand_bbox(size, lam):
    """Sample a random box whose side lengths scale with ``sqrt(1 - lam)``.

    Args:
        size: Shape sequence; only the last two entries are read, as
            ``(..., H, W)``.
        lam: Mixing ratio. The nominal box is ``int(W * sqrt(1 - lam))`` wide
            and ``int(H * sqrt(1 - lam))`` tall (each halved with floor
            division around the centre).

    Returns:
        ``(x1, y1, x2, y2)`` where x runs along the width axis (``size[-1]``)
        and y along the height axis (``size[-2]``). The centre is uniform over
        the image and the corners are clipped to ``[0, W]`` / ``[0, H]``, so
        boxes near the border are truncated and may have zero area.
    """
    height, width = size[-2], size[-1]
    side_scale = np.sqrt(1. - lam)
    half_w = int(width * side_scale) // 2
    half_h = int(height * side_scale) // 2

    # Keep this draw order (column, then row) so seeded runs reproduce boxes.
    centre_x = np.random.randint(width)
    centre_y = np.random.randint(height)

    left, right = np.clip(centre_x - half_w, 0, width), np.clip(centre_x + half_w, 0, width)
    top, bottom = np.clip(centre_y - half_h, 0, height), np.clip(centre_y + half_h, 0, height)
    return left, top, right, bottom

# Add Dummy Functions to utils Namespace
class utils:
    """Namespace standing in for the ``utils`` module that ``snapmix`` calls.

    Exposes the module-level ``rand_bbox`` as ``utils.rand_bbox``.
    """

    # staticmethod keeps the call signature unchanged even when accessed
    # through an instance (``utils().rand_bbox``); a plain function attribute
    # would be bound there and receive the instance as ``size``.
    rand_bbox = staticmethod(rand_bbox)

# SnapMix Function
def snapmix(input, target, conf, model=None):
    """Apply SnapMix augmentation to a batch (modifies ``input`` in place).

    With probability ``conf.prob``, two boxes are sampled. The first is a box in each
    image. The second is a box in a shuffled partner image. The partner's patch is
    resized into the first box. Label weights come from the share of saliency
    (from ``get_spm``) that each image keeps or contributes.

    Args:
        input: 4-D image batch sliced as ``(batch, channels, height, width)``.
            NOTE: pasted patches are written into this tensor in place.
        target: Labels, indexed along dim 0 by batch position. NOTE(review):
            assumed to be on the same device as ``input`` (it is indexed with a
            device-resident permutation) — confirm for other callers.
        conf: Object exposing ``prob`` (mixing probability) and ``beta``
            (Beta-distribution parameter for the box-size ratios).
        model: Forwarded unchanged to ``get_spm``.

    Returns:
        ``(input, target, target_b, lam_a, lam_b)``: the (possibly mixed) batch,
        the original labels, the partner labels, and per-sample weights for
        ``target`` and ``target_b``. Both weight tensors are on ``input.device``.
        When no mixing happens, ``lam_a`` is all ones and ``lam_b`` all zeros.
        In that case ``target_b`` is a copy of ``target``.
    """
    device = input.device
    lam_a = torch.ones(input.size(0), device=device)
    lam_b = 1 - lam_a
    target_b = target.clone()

    if np.random.rand() < conf.prob:
        # Assumes get_spm returns maps shaped (batch, height, width) matching
        # input's spatial size. Move them to the device before indexing with
        # the device-resident permutation below.
        wfmaps, _ = get_spm(input, target, conf, model)
        wfmaps = wfmaps.to(device)
        bs = input.size(0)
        lam = np.random.beta(conf.beta, conf.beta)
        lam1 = np.random.beta(conf.beta, conf.beta)
        rand_index = torch.randperm(bs).to(device)
        wfmaps_b = wfmaps[rand_index, :, :]
        target_b = target[rand_index]

        same_label = target == target_b
        # The rand_bbox in this file takes W from size[-1] and H from size[-2],
        # so x-coords index the last (width) axis and y-coords the height axis:
        # every slice below is [..., y1:y2, x1:x2]. Cast to int for slicing and
        # for interpolate's `size`.
        bbx1, bby1, bbx2, bby2 = (int(v) for v in utils.rand_bbox(input.size(), lam))
        bbx1_1, bby1_1, bbx2_1, bby2_1 = (int(v) for v in utils.rand_bbox(input.size(), lam1))

        area = (bby2 - bby1) * (bbx2 - bbx1)
        area1 = (bby2_1 - bby1_1) * (bbx2_1 - bbx1_1)

        # Boxes can be clipped to zero area at the border; skip mixing then.
        if area1 > 0 and area > 0:
            # Clone the partner patch before writing into `input`, since the
            # source and destination can overlap (e.g. batch size 1).
            ncont = input[rand_index, :, bby1_1:bby2_1, bbx1_1:bbx2_1].clone()
            ncont = torch.nn.functional.interpolate(
                ncont, size=(bby2 - bby1, bbx2 - bbx1), mode='bilinear', align_corners=True
            )
            input[:, :, bby1:bby2, bbx1:bbx2] = ncont
            # lam_a: fraction of each image's own saliency that was NOT covered.
            lam_a = 1 - wfmaps[:, bby1:bby2, bbx1:bbx2].sum(2).sum(1) / (wfmaps.sum(2).sum(1) + 1e-8)
            # lam_b: fraction of the partner's saliency carried by the pasted patch.
            lam_b = wfmaps_b[:, bby1_1:bby2_1, bbx1_1:bbx2_1].sum(2).sum(1) / (wfmaps_b.sum(2).sum(1) + 1e-8)
            # When both images share a label, fold the two weights together.
            tmp = lam_a.clone()
            lam_a[same_label] += lam_b[same_label]
            lam_b[same_label] += tmp[same_label]
            # Fall back to area-based CutMix weights if a ratio came out NaN.
            lam = 1 - area / (input.size(-1) * input.size(-2))
            lam_a[torch.isnan(lam_a)] = lam
            lam_b[torch.isnan(lam_b)] = 1 - lam

    # Follow the input's device instead of hard-coding CUDA.
    return input, target, target_b, lam_a.to(device), lam_b.to(device)

# Apply SnapMix
# Use the GPU when present; fall back to CPU so the demo also runs on a
# CPU-only runtime instead of failing on a hard-coded .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_image = input_image.to(device)
target = target.to(device)
input, target, target_b, lam_a, lam_b = snapmix(input_image, target, conf)

# Visualize Results
# (1, C, H, W) -> (H, W, C) NumPy array on the CPU for matplotlib.
input_np = input.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
# Clamp to [0, 1] before scaling so no value can overflow the uint8 range.
input_np = (np.clip(input_np, 0.0, 1.0) * 255).astype(np.uint8)

plt.figure(figsize=(10, 5))

# Original Image (the RGB array as loaded, at its native resolution)
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title("Original Image")
plt.axis('off')

# SnapMix Image (the resized, mixed batch element)
plt.subplot(1, 2, 2)
plt.imshow(input_np)
plt.title("SnapMix Image")
plt.axis('off')

plt.tight_layout()
plt.show()
