In [1]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
from operators import Radon, Radon_torch
from utils import random_weighted_norm, tv
import numpy as np

# Image size and the set of projection angles used throughout the notebook.
im_size = 128
num_thetas = 10
min_angle, max_angle = 0, 90

# num_thetas evenly spaced angles, both end points included.
theta = torch.linspace(start=min_angle, end=max_angle, steps=num_thetas)

# Project-local Radon operator for these angles; later cells use R(...) and
# R.inv(...).
R = Radon_torch(theta=theta)

KeyboardInterrupt: 

In [None]:
from models import UNet

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Operating on device: {device}')

# Build the network on `device` and restore the stored weights.
# NOTE(review): torch.load unpickles the checkpoint, so only load files from a
# trusted source (recent torch versions offer `weights_only=True`).
model = UNet().to(device)
# Lazy iterator over the trainable parameters; not consumed by the cells shown.
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
model.load_state_dict(torch.load('models/la-2', map_location=device))

# The cells shown only evaluate the network (the optimizer updates delta, not
# the weights), so put any dropout / batch-norm layers into inference mode.
# eval() returns the model, so the cell still displays the module summary.
model.eval()

In [None]:
def adv_loss(delta, k, xtarget, x, sgn=1, lamda=(1, 1)):
    """Scalar loss for a perturbation `delta` added to the measurement `k`.

    The perturbed data ``k + delta`` is mapped through ``R.inv`` and then
    through ``model``. The returned value is::

        sgn * ||ref - model(R.inv(k + delta))||
            + lamda[0] * tv(model(R.inv(k + delta)))
            + lamda[1] * ||R.inv(k + delta) - x||

    where ``ref`` is `x` when ``sgn < 0`` and `xtarget` otherwise.

    Relies on the notebook-level names ``R``, ``model`` and ``tv``.
    """
    # With a negative sign the distance is measured to x itself, not xtarget.
    reference = x if sgn < 0 else xtarget

    net_input = R.inv(k + delta)
    net_output = model(net_input)

    output_dist = torch.linalg.vector_norm(reference - net_output)
    input_dist = torch.linalg.vector_norm(net_input - x)

    return sgn * output_dist + lamda[0] * tv(net_output) + lamda[1] * input_dist

def erange(x):
    """Clamp the values of `x` to the interval [0, 1].

    The clamped tensor is written back through ``x.data``; nothing is returned.
    """
    lower, upper = 0., 1.
    x.data = x.data.clamp(lower, upper)

def dnorm(z, p):
    """p-norm of `z` taken along its first dimension.

    Args:
        z: tensor; the norm is reduced over ``dim=0``.
        p: norm order. ``float('inf')`` selects the maximum absolute value.

    Returns:
        Tensor with the first dimension of `z` reduced away.
    """
    magnitude = z.abs()
    # The generic formula breaks down for p = inf: the outer exponent 1/p is 0,
    # so the result would be a constant 1. Use the limit, max |z|, instead.
    if isinstance(p, (int, float)) and p == float('inf'):
        return magnitude.amax(dim=0)
    return torch.sum(magnitude**p, dim=0)**(1/p)

def diffable_wn(xy, m, w, p, r, c=0.5, power=20):
    """Smooth indicator built from a weighted p-norm of ``xy - m``.

    Computes ``z = dnorm((xy - m) * w[:, None, None], p)`` and returns
    ``exp(-(z / r) ** power)`` clamped from below at `c` and rescaled so that
    the value `c` maps to 0 and the value 1 maps to 1.

    Args:
        xy: coordinates; assumed to be shaped (2, H, W) given the
            ``w[:, None, None]`` broadcast -- TODO confirm against callers.
        m: offset subtracted from `xy`.
        w: 1-D weights, one per entry of the first dimension of `xy`.
        p: norm order forwarded to `dnorm`.
        r: scale dividing the norm before exponentiation.
        c: lower clamp level; must differ from 1 because the result is
            divided by ``1 - c``.
        power: exponent applied to ``z / r``; larger values give a sharper
            transition. Defaults to 20, the value previously hard-coded.

    Returns:
        Tensor of values in [0, 1] (for ``c < 1``).
    """
    z = dnorm((xy - m) * w[:, None, None], p)
    bump = torch.exp(-(z / r)**power)
    return (torch.clamp(bump, min=c) - c) / (1 - c)

In [None]:
# Arguments for rwn.weighted_norm below. `r` is defined for convenience but the
# calls in this cell pass their last argument as a literal.
m, w, r = (torch.zeros((2,1,1)), torch.tensor([1.,1.]), 0.3)

# Reuse the notebook-wide im_size instead of repeating the literal 128.
rwn = random_weighted_norm(im_size=im_size)


def _load_gray(path):
    """Read an image file and return the weighted sum of its first three
    channels as a float tensor on `device`."""
    rgb = plt.imread(path)[..., :3]
    gray = np.dot(rgb, [0.2989, 0.5870, 0.1140])
    return torch.tensor(gray, dtype=torch.float).to(device)


# mode selects the pair (x, xtarget):
#   0 -> two weighted-norm images (p = 2 vs. p = inf)
#   1 -> two image files, each rotated by 45
#   otherwise -> two weighted-norm images (p = inf vs. p = 2, different last argument)
mode = 0
if mode == 0:
    xtarget = rwn.weighted_norm(m, w, float('inf'), 0.5).to(device)
    x = rwn.weighted_norm(m, w, 2, 0.5).to(device)
elif mode == 1:
    # rotate() is given a leading singleton dimension, removed again afterwards.
    x, xtarget = [
        torchvision.transforms.functional.rotate(_load_gray(path)[None, ...], 45).squeeze()
        for path in ('data/happy.png', 'data/sad.png')
    ]
else:
    x = rwn.weighted_norm(m, w, float('inf'), 0.5).to(device)
    xtarget = rwn.weighted_norm(m, w, 2, 0.4).to(device)

In [None]:
# Side-by-side view of the two images; titles use the variable names so the
# panels can be told apart. Ending on a loop avoids a stray AxesImage repr.
fig, ax = plt.subplots(1, 2)
for axis, image, title in zip(ax, (x, xtarget), ('x', 'xtarget')):
    axis.imshow(image.cpu(), cmap='bone')
    axis.set_title(title)

In [None]:
def proj_linf(delta, k, budget=0.1):
    """Clip every entry of `delta` to ``[-budget, budget]``.

    The result is written back through ``delta.data``; nothing is returned.
    `k` is accepted but not used by this function.
    """
    lo, hi = -budget, budget
    delta.data = delta.data.clamp(min=lo, max=hi)
    
def proj_l2(delta, k, budget=0.1):
    """Project `delta` onto the L2 ball of radius `budget`, masked by `k`.

    Entries of `delta` where `k` is exactly zero are set to zero first. The
    masked perturbation is then scaled down only if its L2 norm exceeds
    `budget`; perturbations already inside the ball keep their length.

    The result is written back through ``delta.data``; nothing is returned.

    Args:
        delta: perturbation tensor, updated in place via ``.data``.
        k: tensor of the same shape; its non-zero entries define the mask.
        budget: radius of the L2 ball.
    """
    masked = delta.data * (k.abs() > 0)
    norm = torch.linalg.vector_norm(masked)
    # Guard the denominator so an all-zero perturbation stays zero instead of
    # becoming NaN, and cap the factor at 1 so short vectors are not inflated.
    scale = torch.clamp(budget / norm.clamp_min(1e-12), max=1.0)
    delta.data = masked * scale
    
# Projection applied to delta after each optimizer step in the attack loop;
# swap in proj_linf here to use the max-norm constraint instead.
project_delta = proj_l2

In [None]:
# Measurement of x and the budget handed to project_delta.
k = R(x)
budget = .75

# The perturbation is the only optimized variable and starts at zero.
delta = nn.Parameter(torch.zeros_like(k))
opt = torch.optim.Adam([delta], lr=0.01)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=50)

n_iters, log_every = 500, 50
for i in range(n_iters):
    opt.zero_grad()
    loss = adv_loss(delta, k, xtarget, x, sgn=-1, lamda=(500, 0))
    loss.backward()
    opt.step()
    sched.step(loss)
    # Bring delta back into the admissible set after the update.
    project_delta(delta, k, budget=budget)

    # Progress report every `log_every` iterations (including iteration 0).
    if i % log_every:
        continue
    print('-' * 30)
    print(f'Iteration: {i}')
    print(loss.item())
    for group in opt.param_groups:
        print(f"Current lr:{group['lr']}")

In [None]:
fig, ax = plt.subplots(1, 4, figsize=(20, 12))

# Only forward passes are needed for plotting, so skip building autograd
# graphs; R.inv(k + delta) is evaluated once and reused for two panels.
with torch.no_grad():
    perturbed_inv = R.inv(k + delta)
    panels = [
        ('x', x),
        ('model(R.inv(k))', model(R.inv(k))),
        ('model(R.inv(k + delta))', model(perturbed_inv)),
        ('R.inv(k + delta)', perturbed_inv),
    ]

# All panels share the fixed colour range [0, 1] so they can be compared.
for i, (title, xx) in enumerate(panels):
    ax[i].imshow(xx.cpu().detach().squeeze(), cmap='bone', vmin=0., vmax=1.)
    ax[i].set_title(title)
    ax[i].axis('off')

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(20, 12))
inv = R.inv(k + delta)
minv = model(inv)
mminv = minv  # third panel shows the same model output without fixed limits

# Panels 0 and 1 use the fixed colour range [0, 1]; the last panel is
# auto-scaled and supplies the image the colourbar is attached to.
fixed_range = dict(vmin=0., vmax=1.)
for i, xx in enumerate((inv, minv, mminv)):
    limits = fixed_range if i < 2 else {}
    shown = ax[i].imshow(xx.cpu().detach().squeeze(), cmap='bone', **limits)
    if i >= 2:
        im = shown
    ax[i].axis('off')

# Narrow colourbar axis directly to the right of the last panel, same height.
last_box = ax[-1].get_position()
cax = fig.add_axes([last_box.x1 + 0.01, last_box.y0, 0.02, last_box.height])
plt.colorbar(im, cax=cax)
plt.savefig('results/adv.pdf')

    

In [None]:
# Mean value of minv, i.e. of model(R.inv(k + delta)) from the previous cell.
minv.mean()

In [None]:
# Perturbed measurement k + delta, drawn without axes and saved as a PDF.
kk = k + delta
perturbed_img = kk.detach().cpu()
plt.imshow(perturbed_img, cmap='bone', aspect='auto', interpolation='nearest')
plt.axis('off')
plt.savefig('results/SinoAttack.pdf')

In [None]:
# Unperturbed measurement k, drawn with the same settings and saved as a PDF.
kk = k
clean_img = kk.detach().cpu()
plt.imshow(clean_img, cmap='bone', aspect='auto', interpolation='nearest')
plt.axis('off')
plt.savefig('results/Sino.pdf')

In [None]:
# L2 norm of the final perturbation, for comparison against `budget`.
torch.linalg.vector_norm(delta)