In [None]:
from functools import partial
from io import BytesIO

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pax
import requests
from PIL import Image

from pretrained_resnet18 import IMAGENET_MEAN, IMAGENET_STD, load_pretrained_resnet18

In [None]:
# Seed pax's RNG key with a fixed value (presumably so that runs are reproducible).
pax.seed_rng_key(42)

In [None]:
def prepare_image(img):
    """Turn a batch of raw pixels into the normalized tensor fed to the model.

    Args:
        img: a 4-D array; the final transpose moves its last axis to
            position 1 (presumably NHWC -> NCHW, i.e. channels-last input).

    Returns:
        The batch cast to float32, divided by 255, standardized with
        ``IMAGENET_MEAN`` / ``IMAGENET_STD`` and with axes reordered to
        ``(0, 3, 1, 2)``.
    """
    scaled = img.astype(np.float32) / 255.0
    standardized = (scaled - IMAGENET_MEAN) / IMAGENET_STD
    # Move the trailing axis right behind the batch axis.
    return jnp.transpose(standardized, axes=(0, 3, 1, 2))

In [None]:
# Download the ImageNet class names. The position of a name in LABELS is the
# class index used to decode the model's arg-max prediction.
labels_response = requests.get(
    "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
    timeout=30,  # do not hang the notebook forever on a stalled connection
)
# Fail loudly on an HTTP error instead of treating an error page as labels.
labels_response.raise_for_status()
# splitlines() drops the empty trailing entry that split("\n") would leave
# when the file ends with a newline, and also copes with "\r\n" endings.
LABELS = labels_response.content.decode("utf-8").splitlines()

In [None]:
def prediction(net, img):
    """Return the label in ``LABELS`` that ``net`` scores highest for ``img``.

    Args:
        net: the model; ``net.eval()`` is called and the result is applied to
            the normalized image.
        img: raw image batch, passed through ``prepare_image`` first. The
            ``.item()`` call below only succeeds when the arg-max yields a
            single value, so the batch must hold exactly one image.

    Returns:
        The entry of ``LABELS`` at the arg-max index of the model output.
    """
    model_input = prepare_image(img)
    scores = net.eval()(model_input)
    best_index = jnp.argmax(scores, axis=-1).item()
    return LABELS[best_index]

In [None]:
# Build the classifier via the project-local helper (presumably a ResNet-18
# with ImageNet-pretrained weights, going by the helper's name — confirm in
# pretrained_resnet18.py).
resnet18 = load_pretrained_resnet18()

In [None]:
# download an image of a cat from the Internet.
URL = "https://i.natgeofe.com/n/3861de2a-04e6-45fd-aec8-02e7809f9d4e/02-cat-training-NationalGeographic_1484324_square.jpg"
response = requests.get(URL, timeout=30)  # bounded wait instead of hanging
# Surface HTTP errors here rather than as a confusing PIL decoding failure.
response.raise_for_status()
# Force 3-channel RGB: the normalization downstream expects an
# (H, W, 3) array, which greyscale / palette / RGBA files would not yield.
img = Image.open(BytesIO(response.content)).convert("RGB")
img = img.resize((224, 224))
# Last expression of the cell: display the resized image.
img

In [None]:
# Sanity check: wrap the image in a batch of one (leading axis), cast it to
# float32 and make sure the model assigns it a plausible label.
input_image = np.expand_dims(np.array(img), axis=0).astype(np.float32)
predicted_label = prediction(resnet18, input_image)
print(predicted_label)

In [None]:
def loss_fn(net, image, label):
    """Mean negative log-likelihood of ``label`` under the model's softmax.

    Args:
        net: the model; ``net.eval()`` is applied to the normalized image.
        image: raw image batch, normalized here with ``prepare_image``.
        label: integer class index/indices, one-hot encoded against the
            size of the model's last output axis.

    Returns:
        A scalar: minus the batch mean of the log-probability of ``label``.
    """
    model_input = prepare_image(image)
    log_probs = jax.nn.log_softmax(net.eval()(model_input), axis=-1)
    one_hot = jax.nn.one_hot(label, num_classes=log_probs.shape[-1])
    # Select the log-probability of the requested class for each example.
    picked = jnp.sum(one_hot * log_probs, axis=-1)
    return -jnp.mean(picked)

In [None]:
@partial(jax.jit, static_argnames="epsilon")
def adversarial_step(
    net, image, label, original_image, epsilon=1.0, step_size=1e-3
):
    """Run one projected signed-gradient descent step towards ``label``.

    Args:
        net: the model, forwarded to ``loss_fn``.
        image: current adversarial image batch (raw pixel scale; the result
            is clipped to [0, 255]).
        label: target class index/indices the image should be pushed to.
        original_image: the unperturbed image; the result stays within
            ``epsilon`` of it element-wise.
        epsilon: maximum absolute per-element perturbation. Declared static
            for ``jax.jit``, so each distinct value triggers a recompile.
        step_size: magnitude of the signed-gradient update (default 1e-3,
            the value that used to be hard-coded).

    Returns:
        ``(image, loss)``: the updated image and the loss evaluated at the
        input ``image`` (i.e. before this step's update).
    """
    # compute the gradient w.r.t. the image (argnums=1 -> second argument)
    loss, grads = jax.value_and_grad(loss_fn, argnums=1)(net, image, label)

    # Descend on the target-label loss using only the gradient's sign.
    image = image - jnp.sign(grads) * step_size
    # Project back into the epsilon-ball around the original image. Bounds
    # are passed positionally: the a_min/a_max keywords are deprecated in
    # newer JAX releases, while positional bounds work on old and new ones.
    perturbation = jnp.clip(image - original_image, -epsilon, epsilon)
    image = original_image + perturbation
    # Keep the result a valid image in raw pixel range.
    image = jnp.clip(image, 0.0, 255.0)
    return image, loss

In [None]:
# Attack configuration.
# Maximum per-element change allowed, in raw pixel units.
epsilon = 1.0
# Class the perturbed image should be classified as.
new_label = "African elephant"
# The optimisation starts from the unmodified input image.
adversarial_image = input_image
# Target class index, wrapped in a length-1 array (one entry per image).
adversarial_label = jnp.array([LABELS.index(new_label)])

In [None]:
# Iterate the attack; every 100th step, check whether the uint8-cast image is
# already classified as the target label and stop as soon as it is.
for step in range(100_000):
    adversarial_image, loss = adversarial_step(
        resnet18, adversarial_image, adversarial_label, input_image, epsilon=epsilon
    )
    if step % 100 != 0:
        continue
    # Evaluate on the image cast to uint8, not on the float working copy.
    label = prediction(resnet18, adversarial_image.astype(jnp.uint8))
    print(f"step {step:4d}  loss {loss:6.3f}  ->  {label}")
    if label == new_label:
        break

In [None]:
# sanity check with a real image of an african elephant
elephant_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/b/bf/African_Elephant_%28Loxodonta_africana%29_male_%2817289351322%29.jpg/1200px-African_Elephant_%28Loxodonta_africana%29_male_%2817289351322%29.jpg"
response = requests.get(elephant_url, timeout=30)  # bounded wait
# Surface HTTP errors (e.g. a refused request) here rather than as an
# opaque PIL "cannot identify image file" failure.
response.raise_for_status()
# Force 3-channel RGB so the array matches what the normalization expects.
elephant_img = Image.open(BytesIO(response.content)).convert("RGB")
elephant_img = elephant_img.resize((224, 224))

In [None]:
# Classify the three images first; the labels become the panel titles.
label0 = prediction(resnet18, input_image.astype(jnp.uint8))
label1 = prediction(resnet18, adversarial_image.astype(jnp.uint8))
label3 = prediction(resnet18, np.array(elephant_img)[None].astype(np.float32))

# Largest absolute change over the last (channel) axis, per pixel.
diff = jnp.max(jnp.abs(adversarial_image - input_image), axis=-1)

# Four panels: original, adversarial, perturbation map, real elephant.
fig, ax = plt.subplots(1, 4, figsize=(14, 3))
ax[0].imshow(input_image[0].astype(jnp.uint8))
ax[1].imshow(adversarial_image[0].astype(jnp.uint8))
diff_img = ax[2].imshow(diff[0])
fig.colorbar(diff_img, ax=ax[2])
ax[3].imshow(elephant_img)

for panel, title in zip(ax, (label0, label1, "Difference", label3)):
    panel.axis("off")
    panel.set_title(title)

plt.tight_layout()
plt.show()