In [2]:
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import pickle
import os

from src import *

external_path=''

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

In [4]:
model=CNN()
model.load_state_dict(torch.load('cnn_mnist.pth'))
model.eval()

CNN(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=3136, out_features=64, bias=True)
  )
  (out_layer): Linear(in_features=64, out_features=10, bias=True)
)

In [5]:
correctly_classified_test_indices_file=open(f'{external_path}\\correctly_classified_test_indices','rb')
correctly_classified_test_indices=pickle.load(correctly_classified_test_indices_file)
correctly_classified_test_indices_file.close()

def get_adversarial_noise_cav(positive_digit,negative_digit):
    if not(os.path.exists(f'{external_path}\\adversarial_noise_cav\\{positive_digit}_{negative_digit}')):
        os.mkdir(f'{external_path}\\adversarial_noise_cav\\{positive_digit}_{negative_digit}')

    if os.path.exists(f'{external_path}\\adversarial_noise_cav\\{positive_digit}_{negative_digit}\\log.txt'):
        log_file=open(f'{external_path}\\adversarial_noise_cav\\{positive_digit}_{negative_digit}\\log.txt','a')
    else:
        log_file=open(f'{external_path}\\adversarial_noise_cav\\{positive_digit}_{negative_digit}\\log.txt','w')


    pairs=np.load(f'{external_path}\\boundary_info\\pairs\\{positive_digit}_{negative_digit}.npy')
    perturbation_vector=torch.load(f'{external_path}\\concept_activation_vectors\\{positive_digit}_{negative_digit}.pt').unsqueeze(0)
    perturbations=np.load(f'{external_path}\\boundary_info\\perturbations_cav\\{positive_digit}_{negative_digit}.npy')
    test_idxs=[correctly_classified_test_indices[negative_digit][pair[1]] for pair in pairs]

    success=0
    failure=0

    pbar=tqdm(range(len(pairs)))
    for k in pbar:
        if os.path.exists(f'{external_path}\\adversarial_noise_cav\\{positive_digit}_{negative_digit}\\noise_for_test_image_{test_idxs[k]}.npy'):
            continue
        for n,(img,label) in enumerate(test_dataset):
            if n==test_idxs[k]:
                test_image=img.unsqueeze(0)
        perturbation=perturbations[k]
        target_perturbation=(1.02*perturbation)*perturbation_vector

        adversarial_noise=(2*torch.rand((1,1,28,28))-1).requires_grad_()
        optimizer=torch.optim.Adam([adversarial_noise],lr=1e-1)

        for epoch in range(5000):
            if epoch==0:
                optimizer=torch.optim.Adam([adversarial_noise],lr=1e-1)
            if epoch==400:
                optimizer=torch.optim.Adam([adversarial_noise],lr=1e-2)
            if epoch==800:
                optimizer=torch.optim.Adam([adversarial_noise],lr=1e-3)
            optimizer.zero_grad()
            latent_representation=model.encoder(test_image+adversarial_noise)-model.encoder(test_image)
            loss=torch.norm(latent_representation-target_perturbation)
            loss.backward()
            optimizer.step()
            if loss<1e-2:
                break
        if model(test_image+adversarial_noise).argmax()==positive_digit:
            np.save(f'{external_path}\\boundary_info\\adversarial_noise_cav\\{positive_digit}_{negative_digit}\\noise_for_test_image_{test_idxs[k]}.npy',adversarial_noise.detach().numpy())
            success+=1
            log_file.write('image{test_idxs[k]},success\n')
        else:
            failure+=1
            log_file.write('image{test_idxs[k]},failure\n')
        pbar.set_description(f'...{positive_digit}_{negative_digit}...success={success},failure={failure}...')
    log_file.close()

In [7]:
for positive_digit in range(10):
    for negative_digit in range(10):
        if negative_digit==positive_digit:
            continue
        get_adversarial_noise_cav(positive_digit,negative_digit)

  0%|          | 0/12 [00:00<?, ?it/s]

100%|██████████| 12/12 [00:00<00:00, 3903.49it/s]
100%|██████████| 35/35 [00:00<?, ?it/s]
100%|██████████| 11/11 [00:00<00:00, 4507.80it/s]
100%|██████████| 24/24 [00:00<00:00, 63270.46it/s]
100%|██████████| 20/20 [00:00<?, ?it/s]
100%|██████████| 38/38 [00:00<00:00, 8467.04it/s]
100%|██████████| 18/18 [00:00<?, ?it/s]
100%|██████████| 28/28 [00:00<00:00, 8092.65it/s]
100%|██████████| 30/30 [00:00<00:00, 19821.85it/s]
100%|██████████| 12/12 [00:00<00:00, 3781.78it/s]
100%|██████████| 25/25 [00:00<00:00, 12382.81it/s]
100%|██████████| 16/16 [00:00<00:00, 9345.34it/s]
100%|██████████| 44/44 [00:00<00:00, 8644.00it/s]
100%|██████████| 16/16 [00:00<00:00, 8538.02it/s]
100%|██████████| 15/15 [00:00<00:00, 1803.02it/s]
100%|██████████| 18/18 [00:00<?, ?it/s]
100%|██████████| 26/26 [00:00<?, ?it/s]
100%|██████████| 16/16 [00:00<00:00, 13094.41it/s]
100%|██████████| 35/35 [00:00<00:00, 15509.84it/s]
100%|██████████| 25/25 [00:00<00:00, 8268.87it/s]
100%|██████████| 48/48 [00:00<00:00, 47005.98

In [6]:
correctly_classified_test_indices_file=open(f'{external_path}\\correctly_classified_test_indices','rb')
correctly_classified_test_indices=pickle.load(correctly_classified_test_indices_file)
correctly_classified_test_indices_file.close()

def get_adversarial_noise_cbv(positive_digit,negative_digit):
    if not(os.path.exists(f'{external_path}\\adversarial_noise_cbv\\{positive_digit}_{negative_digit}')):
        os.mkdir(f'{external_path}\\adversarial_noise_cbv\\{positive_digit}_{negative_digit}')

    if os.path.exists(f'{external_path}\\adversarial_noise_cbv\\{positive_digit}_{negative_digit}\\log.txt'):
        log_file=open(f'{external_path}\\adversarial_noise_cbv\\{positive_digit}_{negative_digit}\\log.txt','a')
    else:
        log_file=open(f'{external_path}\\adversarial_noise_cbv\\{positive_digit}_{negative_digit}\\log.txt','w')


    pairs=np.load(f'{external_path}\\boundary_info\\pairs\\{positive_digit}_{negative_digit}.npy')
    perturbation_vector=torch.tensor(np.load(f'{external_path}\\concept_boundary_vectors\\{positive_digit}_{negative_digit}.npy')).unsqueeze(0)
    perturbations=np.load(f'{external_path}\\boundary_info\\perturbations_cbv\\{positive_digit}_{negative_digit}.npy')
    test_idxs=[correctly_classified_test_indices[negative_digit][pair[1]] for pair in pairs]

    success=0
    failure=0

    pbar=tqdm(range(len(pairs)))
    for k in pbar:
        if os.path.exists(f'{external_path}\\adversarial_noise_cbv\\{positive_digit}_{negative_digit}\\noise_for_test_image_{test_idxs[k]}.npy'):
            continue
        for n,(img,label) in enumerate(test_dataset):
            if n==test_idxs[k]:
                test_image=img.unsqueeze(0)
        perturbation=perturbations[k]
        target_perturbation=(1.02*perturbation)*perturbation_vector

        adversarial_noise=(2*torch.rand((1,1,28,28))-1).requires_grad_()
        optimizer=torch.optim.Adam([adversarial_noise],lr=1e-1)

        for epoch in range(5000):
            if epoch==0:
                optimizer=torch.optim.Adam([adversarial_noise],lr=1e-1)
            if epoch==400:
                optimizer=torch.optim.Adam([adversarial_noise],lr=1e-2)
            if epoch==800:
                optimizer=torch.optim.Adam([adversarial_noise],lr=1e-3)
            optimizer.zero_grad()
            latent_representation=model.encoder(test_image+adversarial_noise)-model.encoder(test_image)
            loss=torch.norm(latent_representation-target_perturbation)
            loss.backward()
            optimizer.step()
            if loss<1e-2:
                break
        if model(test_image+adversarial_noise).argmax()==positive_digit:
            np.save(f'{external_path}\\adversarial_noise_cbv\\{positive_digit}_{negative_digit}\\noise_for_test_image_{test_idxs[k]}.npy',adversarial_noise.detach().numpy())
            success+=1
            log_file.write('image{test_idxs[k]},success\n')
        else:
            failure+=1
            log_file.write('image{test_idxs[k]},failure\n')
        pbar.set_description(f'...{positive_digit}_{negative_digit}...success={success},failure={failure}...')
    log_file.close()

In [1]:
for positive_digit in range(10):
    for negative_digit in range(10):
        if negative_digit==positive_digit:
            continue
        get_adversarial_noise_cbv(positive_digit,negative_digit)

NameError: name 'get_adversarial_noise_opt' is not defined

In [6]:
def get_least_perturbed_images(positive_digit,negative_digit,num_images,perturbation_vector_type='cav'):
    if perturbation_vector_type=='cav':
        folder='adversarial_noise_cav'
    elif perturbation_vector_type=='opt':
        folder='adversarial_noise_cbv'

    adversarial_noises_files=os.listdir(f'{external_path}\\{folder}\\{positive_digit}_{negative_digit}')
    test_idxs=np.array([int(adversarial_noises_file.split('_')[-1].split('.')[0]) for adversarial_noises_file in adversarial_noises_files if not('log' in adversarial_noises_file)])

    distances=[]
    for file in adversarial_noises_files:
        if not('log' in file):
            adversarial_noise=np.load(f'{external_path}\\{folder}\\{positive_digit}_{negative_digit}\\{file}')
            distances.append(np.linalg.norm(adversarial_noise))
    distances=np.array(distances)

    return test_idxs[np.argsort(distances)[:num_images]],distances[np.argsort(distances)[:num_images]]

In [7]:
import matplotlib
matplotlib.use('agg')

pbar=tqdm(range(10))
for positive_digit in pbar:
    for negative_digit in range(10):
        if negative_digit==positive_digit:
            continue
        pbar.set_description(f'{negative_digit}/9')
        num_images=3
        fig,axs=plt.subplots(nrows=2,ncols=num_images,layout='constrained')
        fig.set_figheight(6)
        fig.set_figwidth(num_images*2)
        fig.suptitle(f'Perturbing {negative_digit} for {positive_digit} Reclassification')

        subfigs = fig.subfigures(nrows=2, ncols=1)
        subfigs[0].suptitle('CAV Based Perturbations')
        subfigs[1].suptitle('CBV Based Perturbations')

        test_idxs_to_plot,distances=get_least_perturbed_images(positive_digit,negative_digit,num_images)

        for k,(img,label) in enumerate(test_dataset):
            if k in test_idxs_to_plot:
                relative_idx=np.where(test_idxs_to_plot==k)[0][0]
                adversarial_noise=np.load(f'{external_path}\\adversarial_noise_cav\\{positive_digit}_{negative_digit}\\noise_for_test_image_{k}.npy')
                axs[0][relative_idx].imshow((img.detach().numpy()+adversarial_noise).squeeze(0).squeeze(0))
                axs[0][relative_idx].set_title(str(round(distances[relative_idx],2)))
                axs[0][relative_idx].axis('off')

        test_idxs_to_plot,distances=get_least_perturbed_images(positive_digit,negative_digit,num_images,perturbation_vector_type='opt')

        for k,(img,label) in enumerate(test_dataset):
            if k in test_idxs_to_plot:
                relative_idx=np.where(test_idxs_to_plot==k)[0][0]
                adversarial_noise=np.load(f'{external_path}\\adversarial_noise_cbv\\{positive_digit}_{negative_digit}\\noise_for_test_image_{k}.npy')
                axs[1][relative_idx].imshow((img.detach().numpy()+adversarial_noise).squeeze(0).squeeze(0))
                axs[1][relative_idx].set_title(str(round(distances[relative_idx],2)))
                axs[1][relative_idx].axis('off')

        plt.savefig(f'{external_path}\\comparing_adversarial_images\\{positive_digit}_{negative_digit}.png')

  fig,axs=plt.subplots(nrows=2,ncols=num_images,layout='constrained')
8/9: 100%|██████████| 10/10 [07:08<00:00, 42.83s/it]


In [14]:
fig,axs=plt.subplots(nrows=2,ncols=5,layout='constrained')
fig.set_figwidth(12)
fig.set_figheight(6)
fig.suptitle('Amplitude of Image Perturbations')

for digit,ax in enumerate(axs.flatten()):
    ax.set_xlim((-0.2,9.2))
    for negative_digit in range(10):
        if negative_digit==digit:
            continue
        adversarial_files=os.listdir(f'{external_path}\\adversarial_noise_cav\\{digit}_{negative_digit}')
        distances=[]
        for file in adversarial_files:
            if not('log' in file):
                adversarial_noise=np.load(f'{external_path}\\adversarial_noise_cav\\{digit}_{negative_digit}\\{file}')
                distances.append(np.linalg.norm(adversarial_noise))

        adversarial_files_opt=os.listdir(f'{external_path}\\adversarial_noise_cbv\\{digit}_{negative_digit}')
        distances_opt=[]
        for file in adversarial_files_opt:
            if not('log' in file):
                adversarial_noise=np.load(f'{external_path}\\adversarial_noise_cbv\\{digit}_{negative_digit}\\{file}')
                distances_opt.append(np.linalg.norm(adversarial_noise))
        if digit==4 and negative_digit==0:
            ax.scatter(x=(negative_digit-0.1)*np.ones(len(distances)),y=distances,s=5,color='red',label='CAV')
            ax.scatter(x=(negative_digit+0.1)*np.ones(len(distances_opt)),y=distances_opt,s=5,color='blue',label='CBV')
            ax.legend()
        else:
            ax.scatter(x=(negative_digit-0.1)*np.ones(len(distances)),y=distances,s=5,color='red')
            ax.scatter(x=(negative_digit+0.1)*np.ones(len(distances_opt)),y=distances_opt,s=5,color='blue') 
    ax.set_xticks(range(10))
plt.savefig(f'{external_path}\\comparing_adversarial_images\\image_perturbation_amplitudes.png')