In [None]:
import torch
import os
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from torchvision import datasets
from tqdm import tqdm
# Switch into ../src before the `models.*` imports below, presumably so those
# project packages resolve from the working directory. Every relative path in
# later cells (data, checkpoints) is therefore relative to src/.
os.chdir('../src')

from models.diffusion import Diffusion, CondDiffusion
from models.adversary import Adversary, ConditionalAdversary, GuidedAdversary
from models.guidance import Guidance
from models.unet import Unet
from models.condunet import condUnet

# Seed torch's RNGs so sampled / perturbed images are reproducible on
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Convert a tensor in the models' [-1, 1] range into a PIL image for display.
# Clamp to [0, 1] first: ToPILImage multiplies float tensors by 255 and casts
# to uint8, so any value that drifts outside the range (common for diffusion
# samples) would wrap around to the wrong intensity instead of saturating.
reverse_transform = transforms.Compose([
    transforms.Lambda(lambda x: ((x + 1) / 2).clamp(0, 1)),
    transforms.ToPILImage(),
])


In [None]:
# Load CIFAR-10 (train split) and convert each image to a single-channel
# 28x28 tensor so it has the same shape as the models' data_shape/channels.
# NOTE(review): the checkpoints below are Fashion-MNIST models; confirm that
# CIFAR-10 images are intentionally used as (out-of-distribution) targets.
os.makedirs('../data/CIFAR10', exist_ok=True)

dataset_train = datasets.CIFAR10(
    root='../data/CIFAR10',
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.Grayscale(),
        transforms.CenterCrop(28),
        transforms.ToTensor(),                   # uint8 PIL -> float in [0, 1]
        # Rescale to [-1, 1]: reverse_transform maps [-1, 1] back to [0, 1],
        # so targets must live in the same range as the diffusion samples they
        # are compared against and displayed alongside.
        transforms.Lambda(lambda t: t * 2 - 1),
    ]),
)

In [None]:
##### Unconditional Diffusion #####

model_path = '../outputs/models/fashion_mnist_diffusion_epoch_19.pt'

# Pick the best available backend: CUDA, then Apple MPS, then CPU.
# Never assume an accelerator exists, or the model cannot be moved anywhere.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

# model params (must match the ones the checkpoint was trained with)
data_shape = (28, 28)
channels = 1
dim_mults = (1, 2, 4, )
T = 300  # passed to Diffusion as T (presumably the number of diffusion steps)

net = Unet(
    dim=data_shape[0],
    channels=channels,
    dim_mults=dim_mults,
)
# map_location remaps tensors saved on another device (e.g. a CUDA-trained
# checkpoint) onto `device`, so the notebook also runs on CPU/MPS machines.
net.load_state_dict(torch.load(model_path, map_location=device))

fashion_mnist_diffusion = Diffusion(
    model=net,
    data_shape=data_shape,
    noise_schedule='linear',
    T=T,
    device=device,
)

In [None]:
# Wrap the unconditional diffusion model in an Adversary on the same device
# and with the same data shape.
_adversary_kwargs = dict(
    device=device,
    data_shape=data_shape,
    diffusion_model=fashion_mnist_diffusion,
)
adversary = Adversary(**_adversary_kwargs)

In [None]:
# Draw a clean (unperturbed) sample and show the first element of the batch.
x = adversary.sample()
first_image = reverse_transform(x[0].cpu())
plt.imshow(first_image, cmap='gray')

In [None]:
# Ask the adversary for a Gaussian-perturbed sample at t=150, then plot the
# perturbed image and its unperturbed original, each in its own figure.
perturbed_sample, original_sample = adversary.gaussian_perturb(t=150)
for batch in (perturbed_sample, original_sample):
    plt.figure()
    plt.imshow(reverse_transform(batch[0].cpu()), cmap='gray')

In [None]:
# Take the first training image as the adversarial target: a batch of one on
# the model device.
target_image, _ = dataset_train[0]
adversarial_target = target_image.unsqueeze(0).to(device)
plt.figure()
plt.imshow(reverse_transform(adversarial_target[0].cpu()), cmap='gray')

# Run gradient_perturb at t=150 against that target; show perturbed, then original.
perturbed_sample, original_sample = adversary.gradient_perturb(t=150, target=adversarial_target)
for batch in (perturbed_sample, original_sample):
    plt.figure()
    plt.imshow(reverse_transform(batch[0].cpu()), cmap='gray')

In [None]:
# Same first training image as the target, this time for gradient_descent_perturb.
target_image, _ = dataset_train[0]
adversarial_target = target_image.unsqueeze(0).to(device)
plt.figure()
plt.imshow(reverse_transform(adversarial_target[0].cpu()), cmap='gray')

# Perturb at t=50 with scale=5; show perturbed, then original.
perturbed_sample, original_sample = adversary.gradient_descent_perturb(t=50, target=adversarial_target, scale=5)
for batch in (perturbed_sample, original_sample):
    plt.figure()
    plt.imshow(reverse_transform(batch[0].cpu()), cmap='gray')

In [None]:
##### Conditional Diffusion #####

model_path = '../outputs/models/cond_fashion_mnist_diffusion_linear_sched_epoch_19.pt'

# model params (must match the ones the checkpoint was trained with)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data_shape = (28, 28)
channels = 1
dim_mults = (1, 2, 4, )
T = 300  # passed to CondDiffusion as T (presumably the number of diffusion steps)

# Class-conditional U-Net with 10 label embeddings (assumed to be the 10
# Fashion-MNIST classes - confirm against the training script).
net = condUnet(
    dim=data_shape[0],
    channels=channels,
    dim_mults=dim_mults,
    num_classes=10,
)
# map_location remaps tensors saved on another device (e.g. a CUDA-trained
# checkpoint) onto `device`; without it the CPU fallback above fails to load.
net.load_state_dict(torch.load(model_path, map_location=device))

fashion_mnist_cond_diffusion = CondDiffusion(
    model=net,
    data_shape=data_shape,
    noise_schedule='linear',
    T=T,
    device=device,
)

In [None]:
# Wrap the class-conditional diffusion model in a ConditionalAdversary on the
# same device and with the same data shape.
_adversary_kwargs = dict(
    device=device,
    data_shape=data_shape,
    diffusion_model=fashion_mnist_cond_diffusion,
)
adversary = ConditionalAdversary(**_adversary_kwargs)

In [None]:
# One-element label batches created directly on the model device:
#   y     - class the adversary samples from
#   y_adv - class used later to generate adversarial targets
y = torch.tensor([0], device=device)
y_adv = torch.tensor([1], device=device)

# Clean class-conditional sample for label y; show the first batch element.
x = adversary.sample(y)
first_image = reverse_transform(x[0].cpu())
plt.imshow(first_image, cmap='gray')

In [None]:
# Gaussian-perturbed sample of class y at t=150; plot perturbed, then original.
perturbed_sample, original_sample = adversary.gaussian_perturb(t=150, y=y)
for batch in (perturbed_sample, original_sample):
    plt.figure()
    plt.imshow(reverse_transform(batch[0].cpu()), cmap='gray')

In [None]:
# Generate the adversarial target by sampling class y_adv, and show it.
adversarial_target = adversary.sample(y_adv)
plt.figure()
plt.imshow(reverse_transform(adversarial_target[0].cpu()), cmap='gray')

# Run gradient_perturb on a class-y sample at t=150 toward that target;
# show perturbed, then original.
perturbed_sample, original_sample = adversary.gradient_perturb(t=150, target=adversarial_target, y=y)
for batch in (perturbed_sample, original_sample):
    plt.figure()
    plt.imshow(reverse_transform(batch[0].cpu()), cmap='gray')

In [None]:
# Fresh class-y_adv sample as the adversarial target, displayed first.
adversarial_target = adversary.sample(y_adv)
plt.figure()
plt.imshow(reverse_transform(adversarial_target[0].cpu()), cmap='gray')

# gradient_descent_perturb on class y at t=50, scale=5; show perturbed, then original.
perturbed_sample, original_sample = adversary.gradient_descent_perturb(t=50, target=adversarial_target, scale=5, y=y)
for batch in (perturbed_sample, original_sample):
    plt.figure()
    plt.imshow(reverse_transform(batch[0].cpu()), cmap='gray')

In [None]:
##### Guided Diffusion #####
model_path = '../outputs/models/guided_fashion_mnist_diffusion_epoch_19.pt'

# model params (must match the ones the checkpoint was trained with)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data_shape = (28, 28)
channels = 1
dim_mults = (1, 2, 4, )
T = 300  # passed to Guidance as T (presumably the number of diffusion steps)

# Conditional U-Net with 11 label embeddings - presumably the 10 classes plus
# one "unconditional" label used for guidance; confirm against training code.
unet = condUnet(
    dim=data_shape[0],
    channels=channels,
    dim_mults=dim_mults,
    num_classes=11,
)
# map_location remaps tensors saved on another device (e.g. a CUDA-trained
# checkpoint) onto `device`; without it the CPU fallback above fails to load.
unet.load_state_dict(torch.load(model_path, map_location=device))

guided_diffusion = Guidance(
    model=unet,
    data_shape=data_shape,
    T=T,
    device=device,
    noise_schedule='linear'
)

In [None]:
# Wrap the guided diffusion model in a GuidedAdversary on the same device and
# with the same data shape.
_adversary_kwargs = dict(
    device=device,
    data_shape=data_shape,
    diffusion_model=guided_diffusion,
)
adversary = GuidedAdversary(**_adversary_kwargs)

In [None]:
# Sampling inputs:
#   y     - one-element label batch for the class to sample
#   w     - weight forwarded to the sampler (presumably the guidance strength)
#   y_adv - class used later to generate adversarial targets
y = torch.tensor([0], device=device)
w = 1
y_adv = torch.tensor([1], device=device)

# Clean guided sample for class y; show the first batch element.
x = adversary.sample(y, w)
first_image = reverse_transform(x[0].cpu())
plt.imshow(first_image, cmap='gray')

In [None]:
# Gaussian-perturbed guided sample (class y, weight w) at t=150;
# plot perturbed, then original.
perturbed_sample, original_sample = adversary.gaussian_perturb(t=150, y=y, w=w)
for batch in (perturbed_sample, original_sample):
    plt.figure()
    plt.imshow(reverse_transform(batch[0].cpu()), cmap='gray')

In [None]:
# Generate the adversarial target from class y_adv with the same weight w, and show it.
adversarial_target = adversary.sample(y_adv, w=w)
plt.figure()
plt.imshow(reverse_transform(adversarial_target[0].cpu()), cmap='gray')

# Run gradient_perturb on a class-y guided sample at t=150 toward that target;
# show perturbed, then original.
perturbed_sample, original_sample = adversary.gradient_perturb(t=150, target=adversarial_target, y=y, w=w)
for batch in (perturbed_sample, original_sample):
    plt.figure()
    plt.imshow(reverse_transform(batch[0].cpu()), cmap='gray')