In [None]:
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import itertools


In [None]:
def compute_jacobian(model, image):
    """Return the gradient of the summed sigmoid output with respect to the input image.

    Despite the name, this is not the full Jacobian matrix: ``torch.autograd.grad``
    is called with an all-ones ``grad_outputs``, so the result is the
    vector-Jacobian product ``1^T J`` -- for every input element, the sum of the
    derivatives of all output elements with respect to it.

    Args:
        model: Callable applied to ``image.unsqueeze(0)``; its output is passed
            through ``torch.sigmoid`` before differentiation.
        image: Input tensor without a batch axis (the batch axis is added here).
            The tensor itself is left untouched.

    Returns:
        A tensor with the same shape as ``image`` holding the gradient.
    """
    # Differentiate with respect to a private leaf copy so the caller's tensor
    # does not come back with ``requires_grad`` switched on as a side effect.
    leaf = image.detach().clone().requires_grad_(True)

    output = torch.sigmoid(model(leaf.unsqueeze(0))).squeeze(0)  # Forward pass
    (element_wise_grad,) = torch.autograd.grad(
        output, leaf, grad_outputs=torch.ones_like(output)
    )
    return element_wise_grad
    
def saliency_map(jacobian, target_label):
    """Score every pixel by weighting its gradient with the target label.

    The score is the element-wise product ``jacobian * target_label`` (with the
    usual broadcasting), so wherever ``target_label`` is 0 the score is 0, and
    elsewhere the score carries the sign of the gradient scaled by the label.

    Args:
        jacobian: Gradient of the model output with respect to the input.
        target_label: Desired output mask, broadcastable against ``jacobian``.

    Returns:
        The saliency scores (same broadcast shape as the product).
    """
    return jacobian * target_label

def jsma_attack(model, image, target_label, epsilon, alpha,
                decay_factor=0.8, target_acc=0.9, max_iters=1000):
    """
    Iteratively perturb ``image`` so the thresholded model output matches ``target_label``.

    Every iteration:
      1. computes the input gradient (``compute_jacobian``) and weights it by the
         target (``saliency_map``);
      2. takes a signed step of the current step size -- positions whose saliency
         is zero are left unchanged because ``sign(0) == 0``;
      3. projects the result back into the ``epsilon``-ball around ``image`` and
         clamps it to ``[0, 1]``;
      4. measures the fraction of predicted pixels (sigmoid > 0.5) equal to
         ``target_label``.

    The loop ends as soon as that fraction reaches ``target_acc``, fails to
    improve on the previous iteration, or ``max_iters`` iterations have run.
    Otherwise the step size is multiplied by ``decay_factor`` and the loop repeats.

    Args:
        model: The PyTorch model.
        image: The input image (tensor of shape [C, H, W]).
        target_label: The desired segmentation output (tensor of shape [H, W]).
        epsilon: Maximum perturbation magnitude per element.
        alpha: The initial perturbation step size.
        decay_factor: Multiplier applied to the step size after each iteration.
        target_acc: Target-match fraction at which the attack stops.
        max_iters: Upper bound on the number of iterations.

    Returns:
        The adversarial image (detached). If the final step lowered the
        target-match fraction, the image from the preceding step is returned
        instead of the regressed one.
    """
    adversarial = image.clone().detach()
    best_adversarial = adversarial
    best_acc = None  # target-match fraction of ``best_adversarial``; None before the first step
    step_size = alpha

    with torch.no_grad():
        # Bounds of the epsilon-ball are loop-invariant.
        lower_bound = image - epsilon
        upper_bound = image + epsilon

    for _ in range(max_iters):
        jacobian = compute_jacobian(model, adversarial)  # Needs autograd, so outside no_grad
        saliency = saliency_map(jacobian, target_label)

        with torch.no_grad():
            adversarial = adversarial + step_size * saliency.sign()
            adversarial = torch.max(torch.min(adversarial, upper_bound), lower_bound)
            adversarial = torch.clamp(adversarial, 0, 1)

            # Evaluation only: no graph is needed to threshold the prediction.
            pred = (torch.sigmoid(model(adversarial.unsqueeze(0))) > 0.5).float()
            acc = (pred == target_label).sum() / torch.numel(target_label)

        # Only the previous score is needed to detect a stall, so no fixed-size
        # history buffer (which would overflow on long runs) is kept.
        stalled = best_acc is not None and bool(acc <= best_acc)
        if best_acc is None or bool(acc >= best_acc):
            # Ties keep the newest image; a regression keeps the previous one.
            best_adversarial, best_acc = adversarial, acc

        if stalled or not bool(acc < target_acc):
            break

        step_size *= decay_factor

    # detach() guarantees the result carries no requires_grad flag even if the
    # gradient helper toggled it on an intermediate image.
    return best_adversarial.detach()


In [None]:

# Read the picture that serves as the attack's target mask, as 8-bit grayscale.
image_path = "./ball.jpeg"  # Change to your image path
image = Image.open(image_path).convert("L")

# Grayscale cut-off on the 0-255 scale: pixels darker than this map to 1, the rest to 0.
threshold = 100

# Pipeline stages: resize -> tensor in [0, 1] -> binarize (no normalization is applied).
resize_step = transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH))
to_tensor_step = transforms.ToTensor()
binarize_step = transforms.Lambda(lambda x: (x < threshold / 255).float())
transform = transforms.Compose([resize_step, to_tensor_step, binarize_step])

# Build the mask, list its distinct values as a sanity check, and preview it.
image_tensor = transform(image)
print(image_tensor.unique())
mask_hwc = image_tensor.permute(1, 2, 0).numpy()  # channel-last layout for imshow
plt.imshow(mask_hwc, cmap="gray")


In [None]:
# Attack a few validation images and plot one row per image:
# clean image | clean prediction | adversarial image | adversarial prediction.
n_examples = 4            # images attacked / rows in the figure
attack_epsilon = 0.062    # maximum per-element perturbation handed to jsma_attack
attack_alpha = 4 / 255    # initial step size handed to jsma_attack

# squeeze=False keeps ``axes`` two-dimensional whatever n_examples is.
fig, axes = plt.subplots(n_examples, 4, figsize=(20, 30), squeeze=False)
# islice(..., 1, None) skips the first batch, so this is the loader's second batch.
val_batch = next(itertools.islice(iter(val_loader), 1, None))

model.eval()  # Set to evaluation mode once; nothing in the loop changes it
for i in range(n_examples):
    image, target = val_batch[0][i], val_batch[1][i]

    # The attack differentiates through the model, so it runs with autograd enabled.
    adv_img = jsma_attack(model, image, image_tensor, attack_epsilon, attack_alpha)

    # Plain inference: thresholded masks need no autograd graph.
    with torch.no_grad():
        Adv_pred = (torch.sigmoid(model(adv_img.unsqueeze(0))) > 0.5).float()
        pred = (torch.sigmoid(model(image.unsqueeze(0))) > 0.5).float()

    # ``image`` is un-batched (jsma_attack adds the batch axis itself), so it is
    # permuted directly, exactly like ``adv_img``; predictions drop their batch axis first.
    panels = [
        ("original_image", image.permute(1, 2, 0), None),
        ("pred", pred.squeeze(0).permute(1, 2, 0), "gray"),
        ("adv_image", adv_img.permute(1, 2, 0), None),
        ("Adv_pred", Adv_pred.squeeze(0).permute(1, 2, 0), "gray"),
    ]
    for ax, (title, tensor, cmap) in zip(axes[i], panels):
        ax.imshow(tensor.detach().cpu().numpy(), cmap=cmap)
        ax.set_title(title)

    # Fraction of adversarial-prediction pixels that still equal the ground-truth mask.
    Adv_pred = Adv_pred.cpu()
    num_correct = (Adv_pred == target).sum()
    num_pixels = torch.numel(target)
    print((num_correct/num_pixels).item())
