In [None]:
import torch
import os
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm import tqdm
# The local `models` package and every relative output path used below
# ('../outputs/...') resolve against ../src, so the working directory is switched
# before the project imports. Re-running this cell is harmless while the cwd is a
# sibling of src or src itself ('../src' seen from src is src again).
os.chdir('../src')

from models.diffusion import Diffusion, CondDiffusion
from models.unet import Unet
from models.condunet import condUnet

# Convert a model-space tensor back into a displayable PIL image.
# Assumes samples are scaled to roughly [-1, 1] -- TODO confirm against the
# training transform. After rescaling to [0, 1] the values are clamped, because
# ToPILImage turns float tensors into bytes via `mul(255).byte()`, which does not
# saturate out-of-range values (sampled images can overshoot slightly).
# `detach()` makes the transform safe on tensors that belong to an autograd graph.
reverse_transform = transforms.Compose([
    transforms.Lambda(lambda x: ((x.detach() + 1) / 2).clamp(0, 1)),
    transforms.ToPILImage(),
])


In [None]:
model_path = '../outputs/models/cond_fashion_mnist_diffusion_linear_sched_epoch_19.pt'

# model params
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data_shape = (28, 28)
channels = 1
dim_mults = (1, 2, 4, )
T = 300  # number of diffusion steps; presumably must match the training setup
num_classes = 10

# Fix the RNG so sampled images (and the files saved below) are reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

net = condUnet(
    dim=data_shape[0],
    channels=channels,
    dim_mults=dim_mults,
    num_classes=num_classes,
)
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only
# machine, which the `device` fallback above is meant to support.
net.load_state_dict(torch.load(model_path, map_location=device))
# Put the weights on the sampling device and switch to inference mode
# (turns off any dropout / training-only behaviour the network may have).
net.to(device)
net.eval()

fashion_mnist_cond_diffusion = CondDiffusion(
    model=net,
    data_shape=data_shape,
    noise_schedule='linear',
    T=T,
    device=device,
)

In [None]:
### Adversarial attacks via backpropagation ###
os.makedirs('../outputs/adversary', exist_ok=True)
# Labels go on the configured `device` (CUDA when available, otherwise CPU);
# hard-coding 'cuda:0' fails on CPU-only machines.
y_0 = torch.tensor([0], device=device)  # class of the sample being attacked
y_1 = torch.tensor([1], device=device)  # class of the adversarial target
deterministic = False

# Full reverse trajectory for class 0. The next cell uses noises[-1] as the final
# sample and noises[T - t] as (presumably) the intermediate state at step t.
noises = fashion_mnist_cond_diffusion.explicit_sample(y_0, deterministic=deterministic)
adversarial_target = fashion_mnist_cond_diffusion.sample(y_1).squeeze(0)

# Plot the adversarial target
fig_target, ax_target = plt.subplots()
ax_target.imshow(reverse_transform(adversarial_target.cpu()), cmap='gray')
ax_target.set_title('Adversarial target (class 1)')
fig_target.savefig('../outputs/adversary/adversarial_target.png')

# Plot the original (unperturbed) class-0 sample
fig_orig, ax_orig = plt.subplots()
ax_orig.imshow(reverse_transform(noises[-1].squeeze(0).cpu()), cmap='gray')
ax_orig.set_title('Original sample (class 0)')
fig_orig.savefig('../outputs/adversary/original.png')

In [None]:
t = 250  # diffusion step whose intermediate state gets perturbed
latent_noise = noises[fashion_mnist_cond_diffusion.T - t]

# Distance from the class-0 sample to the class-1 target. Its gradient w.r.t. the
# step-t state is the direction that pushes the decoded image toward the target.
# This only works if `noises` was produced with autograd tracking enabled.
loss = torch.norm(noises[-1] - adversarial_target)
grad = torch.autograd.grad(loss, latent_noise)[0]

eps = 1e-2
n_perturbations = 5  # j = 0 decodes the unperturbed state as a baseline

for j in range(n_perturbations):
    step_size = j * eps
    # Move against the gradient to *decrease* the distance to the target.
    x_t = latent_noise - step_size * grad
    for i in reversed(range(0, t)):
        t_cur = torch.tensor([i], device=device)
        x_t = fashion_mnist_cond_diffusion.sample_p_t_grad(x_t, t_cur, y_0, deterministic=deterministic)

    # A fresh figure per perturbation: drawing on the implicit current axes
    # would stack every image on top of the previous ones.
    fig, ax = plt.subplots()
    ax.imshow(reverse_transform(x_t.detach().squeeze(0).cpu()), cmap='gray')
    ax.set_title(f't = {t}, step = {step_size:g}')
    # Rounding strips float noise from the filename when j * eps is not exactly
    # representable (e.g. 3 * 0.1 == 0.30000000000000004) while leaving the
    # current names (eps_0.0, eps_0.01, ...) unchanged.
    fig.savefig(f'../outputs/adversary/decoded_adversarial_perturbation_t_{t}_eps_{round(step_size, 10)}.png')

In [None]:
### Decoding from latent noise ###
os.makedirs('../outputs/latent_noise', exist_ok=True)
# Class-0 condition: built once on the configured device instead of being
# recreated on a hard-coded 'cuda:0' at every iteration.
y_0 = torch.tensor([0], device=device)
for t in range(0, fashion_mnist_cond_diffusion.T, 50):
    # Fresh Gaussian noise shaped like a single model input
    # (batch 1, `channels`, *data_shape); only the last t reverse steps are run.
    noise = torch.randn(1, channels, *data_shape, device=device)
    x_t = noise
    for i in reversed(range(0, t)):
        t_cur = torch.tensor([i], device=device)
        x_t = fashion_mnist_cond_diffusion.sample_p_t(x_t, t_cur, y_0)

    # Fresh figure per image so successive imshow calls don't pile up on one axes.
    fig, ax = plt.subplots()
    ax.imshow(reverse_transform(x_t.squeeze(0).cpu()), cmap='gray')
    ax.set_title(f'Decoded from noise, t = {t}')
    fig.savefig(f'../outputs/latent_noise/decoded_latent_noise_t_{t}.png')

In [None]:
### Adversarial attacks via latent space interpolation ###
os.makedirs('../outputs/adversary', exist_ok=True)
# Labels go on the configured `device`; hard-coding 'cuda:0' breaks CPU-only runs.
y_0 = torch.tensor([0], device=device)  # class of the sample being attacked
y_1 = torch.tensor([1], device=device)  # class of the adversarial target
deterministic = False

noises = fashion_mnist_cond_diffusion.explicit_sample(y_0, deterministic=deterministic)
adversarial_target = fashion_mnist_cond_diffusion.sample(y_1).squeeze(0)