In [1]:
import torch
import torchvision
from tqdm.notebook import tqdm
import torch.nn.functional as F
import numpy as np
import random
from attack_util import batch_pgd, pgd, fgsm
from data_util import load_mnist, to_device
from util import project_lp, scale_im, accuracy, asr, dsa, expand_first


device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

In [20]:
class MNet(torch.nn.Module):
    def __init__(self):
        super(MNet, self).__init__()
        self.l1 = torch.nn.Linear(28*28, 256)
        self.l2 = torch.nn.Linear(256, 64)
        self.l3 = torch.nn.Linear(64, 10)
    
    def forward(self, x):
        x = x.view(-1, 28*28)
        out = torch.relu(self.l1(x))
        out = torch.relu(self.l2(out))
        return self.l3(out)

In [18]:
train_loader, test_loader = load_mnist(64, 10, device)

In [21]:
# mdl = to_device(MNet(), device)
# mdl.load_state_dict(torch.load('mnist_01.pth'))
# mdl = mdl.eval()

mdl = to_device(MNet(), device)
crit = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mdl.parameters(), lr=0.01) 
for epoch in range(2):
    pbar = tqdm(train_loader)
    for batch_id, (images, labels) in enumerate(pbar):
        outputs = mdl(images)
        loss = F.cross_entropy(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step() 
        if batch_id == len(train_loader) - 1:
            pbar.set_postfix({"Test Accuracy":dsa(test_loader, mdl)})

  0%|          | 0/938 [00:00<?, ?it/s]

  0%|          | 0/938 [00:00<?, ?it/s]

In [45]:
def find_hard_uaps(xs, ys, mdl, adv_alg, xi, n = 10, device = device):
  attack_dirs = []
  for x, y in zip(xs, ys):
    attack_dirs.append(adv_alg(, y, mdl, xi = xi, device = device))
  cos_sim = torch.zeros_like((len(xs), len(xs)))
  for i in range(len(xs)):
    for j in range(i + 1, len(xs)):
      cos_sim[i][j] = F.cosine_similarity(attack_dirs[i], attack_dirs[j])
      cos_sim[j][i] = cos_sim[i][j]
  idx = find_largest(cos_sim, n)
  return xs[idx], ys[idx]

def find_largest(mat, n):
  return find_largest_greed(mat, n)

def find_largest_greed(mat, n):
  return torch.argmax(mat.sum(axis = 0))[:n]

def find_largest_brute(mat, n):
  best_idx = None
  best_value = 0
  for idx in perm(len(mat), n):
    cur_sum = mat[idx, idx].sum()
    if cur_sum > best_value:
      best_value = cur_sum
      best_idx = idx
  return best_idx

In [46]:
x_hard, y_hard = find_hard_uaps(x_test, y_test, mdl, pgd, 0.1)

ValueError: Expected input batch_size (1) to match target batch_size (0).

In [22]:
x_test, y_test = next(iter(train_loader))
#x_test = x_test.view(-1, 28*28)

In [41]:
v = batch_pgd(x_test, y_test, mdl, device = device)

In [42]:
accuracy(mdl(x_test + v), y_test)

tensor(0.7812)

In [43]:
accuracy(mdl(x_test), y_test)

tensor(0.9375)

In [None]:
def local_uap(x, y, mdl, adv_alg, xi = 0.1, n = 10):
  current_x = x
  xs = []
  for i in range(n):
    xs.append(current_x)
    current_x = x - adv_alg(x, y, mdl, xi)
  return xs