In [1]:
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import pickle
import os

from src import *

# Root prefix prepended to every data/artefact path this notebook reads or writes.
# The cells below append Windows-style '\\' separated sub-paths directly onto it.
external_path=''

In [2]:
# Preprocessing applied to every MNIST image: convert to a tensor, then
# normalise the single channel with mean 0.1307 and std 0.3081.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.1307,), std=(0.3081,))]
)

# MNIST test split read from ./data, plus an unshuffled loader yielding one image at a time.
test_dataset = datasets.MNIST('./data', train=False, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [3]:
# Build the CNN and restore its trained weights.
model=CNN()
# map_location='cpu' lets a checkpoint that was saved from a CUDA model load on a
# CPU-only machine; nothing in this notebook moves the model or inputs off the CPU.
model.load_state_dict(torch.load('cnn_mnist.pth',map_location='cpu'))
# Switch to inference mode; as the last expression this also displays the architecture.
model.eval()

CNN(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=3136, out_features=64, bias=True)
  )
  (out_layer): Linear(in_features=64, out_features=10, bias=True)
)

In [13]:
# Load the pickled lookup of correctly classified test indices (indexed by digit below).
# NOTE: pickle.load executes arbitrary code from the file - only load trusted artefacts.
# The context manager guarantees the handle is closed even if unpickling raises.
with open(f'{external_path}\\correctly_classified_test_indices','rb') as correctly_classified_test_indices_file:
    correctly_classified_test_indices=pickle.load(correctly_classified_test_indices_file)

def spatial_dependency_cav(positive_digit,negative_digit,epsilon=1e-2,epochs=500):
    """Optimise, for each boundary-pair test image, a unit-norm input perturbation aligned with the CAV.

    For every pair stored for ``{positive_digit}_{negative_digit}``, the second entry of the
    pair indexes ``correctly_classified_test_indices[negative_digit]`` to obtain a test-set
    index. Starting from random noise, Adam maximises the dot product between the concept
    activation vector and the finite-difference change of ``model.encoder`` when the
    normalised noise (scaled by ``epsilon``) is added to the image.

    Side effects: creates the output directory for this digit pair, saves one
    ``noise_for_test_image_<idx>.npy`` per image, and pickles a ``dots`` dict mapping each
    test index to the final dot product.

    Args:
        positive_digit: digit naming the concept direction (first part of the file stem).
        negative_digit: digit whose correctly classified test images are perturbed.
        epsilon: finite-difference step applied to the unit-norm noise.
        epochs: number of Adam steps per image.

    Returns:
        None.
    """
    out_dir=f'{external_path}\\spatial_dependency_cav\\{positive_digit}_{negative_digit}'
    # makedirs also creates a missing parent directory and is a no-op when re-run.
    os.makedirs(out_dir,exist_ok=True)

    pairs=np.load(f'{external_path}\\boundary_info\\pairs\\{positive_digit}_{negative_digit}.npy')
    perturbation_vector=torch.load(f'{external_path}\\concept_activation_vectors\\{positive_digit}_{negative_digit}.pt')
    test_idxs=[correctly_classified_test_indices[negative_digit][pair[1]] for pair in pairs]

    dots={}

    pbar=tqdm(range(len(pairs)))
    for k in pbar:
        test_idx=test_idxs[k]
        # Index the dataset directly instead of scanning (and transforming) every test
        # image; an out-of-range index now raises instead of silently reusing a stale image.
        test_image=test_dataset[int(test_idx)][0].unsqueeze(0)

        # The clean activation does not depend on the noise, so compute it once without autograd.
        with torch.no_grad():
            clean_activation=model.encoder(test_image)
        # torch.dot needs matching dtypes; this is a no-op when they already agree.
        direction=perturbation_vector.to(clean_activation.dtype)

        adversarial_noise=(2*torch.rand((1,1,28,28))-1)
        adversarial_noise=(adversarial_noise/torch.norm(adversarial_noise)).requires_grad_()
        optimizer=torch.optim.Adam([adversarial_noise],lr=1e-3)
        # Keeps the progress description valid even when epochs == 0.
        loss=torch.tensor(float('nan'))
        for epoch in range(epochs):
            optimizer.zero_grad()
            gradient=(model.encoder(test_image+epsilon*adversarial_noise/torch.norm(adversarial_noise))-clean_activation)/epsilon
            # Negated so that minimising the loss maximises alignment with the concept vector.
            loss=-torch.dot(gradient.squeeze(0),direction)
            loss.backward()
            optimizer.step()
        pbar.set_description(f'...{positive_digit}_{negative_digit}...loss={loss:.4f}...')

        # Final evaluation only: no graph is needed.
        with torch.no_grad():
            unit_noise=adversarial_noise/torch.norm(adversarial_noise)
            final_gradient=(model.encoder(test_image+epsilon*unit_noise)-clean_activation)/epsilon
            dots[test_idx]=torch.dot(final_gradient.squeeze(0),direction).item()

        np.save(f'{out_dir}\\noise_for_test_image_{test_idx}.npy',unit_noise.numpy())

    with open(f'{out_dir}\\dots','wb') as dots_file:
        pickle.dump(dots,dots_file)

In [41]:
# Load the pickled lookup of correctly classified test indices (indexed by digit below).
# NOTE: pickle.load executes arbitrary code from the file - only load trusted artefacts.
# The context manager guarantees the handle is closed even if unpickling raises.
with open(f'{external_path}\\correctly_classified_test_indices','rb') as correctly_classified_test_indices_file:
    correctly_classified_test_indices=pickle.load(correctly_classified_test_indices_file)

def spatial_dependency_cbv(positive_digit,negative_digit,epsilon=1e-2,epochs=500):
    """Optimise, for each boundary-pair test image, a unit-norm input perturbation aligned with the CBV.

    For every pair stored for ``{positive_digit}_{negative_digit}``, the second entry of the
    pair indexes ``correctly_classified_test_indices[negative_digit]`` to obtain a test-set
    index. Starting from random noise, Adam maximises the dot product between the concept
    boundary vector and the finite-difference change of ``model.encoder`` when the
    normalised noise (scaled by ``epsilon``) is added to the image.

    Side effects: creates the output directory for this digit pair, saves one
    ``noise_for_test_image_<idx>.npy`` per image, and pickles a ``dots`` dict mapping each
    test index to the final dot product.

    Args:
        positive_digit: digit naming the concept direction (first part of the file stem).
        negative_digit: digit whose correctly classified test images are perturbed.
        epsilon: finite-difference step applied to the unit-norm noise.
        epochs: number of Adam steps per image.

    Returns:
        None.
    """
    out_dir=f'{external_path}\\spatial_dependency_cbv\\{positive_digit}_{negative_digit}'
    # makedirs also creates a missing parent directory and is a no-op when re-run.
    os.makedirs(out_dir,exist_ok=True)

    pairs=np.load(f'{external_path}\\boundary_info\\pairs\\{positive_digit}_{negative_digit}.npy')
    perturbation_vector=torch.tensor(np.load(f'{external_path}\\concept_boundary_vectors\\{positive_digit}_{negative_digit}.npy'))
    test_idxs=[correctly_classified_test_indices[negative_digit][pair[1]] for pair in pairs]

    dots={}

    pbar=tqdm(range(len(pairs)))
    for k in pbar:
        test_idx=test_idxs[k]
        # Index the dataset directly instead of scanning (and transforming) every test
        # image; an out-of-range index now raises instead of silently reusing a stale image.
        test_image=test_dataset[int(test_idx)][0].unsqueeze(0)

        # The clean activation does not depend on the noise, so compute it once without autograd.
        with torch.no_grad():
            clean_activation=model.encoder(test_image)
        # The vector comes from a NumPy file and inherits its dtype (possibly float64);
        # torch.dot needs matching dtypes, and this is a no-op when they already agree.
        direction=perturbation_vector.to(clean_activation.dtype)

        adversarial_noise=(2*torch.rand((1,1,28,28))-1)
        adversarial_noise=(adversarial_noise/torch.norm(adversarial_noise)).requires_grad_()
        optimizer=torch.optim.Adam([adversarial_noise],lr=1e-3)
        # Keeps the progress description valid even when epochs == 0.
        loss=torch.tensor(float('nan'))
        for epoch in range(epochs):
            optimizer.zero_grad()
            gradient=(model.encoder(test_image+epsilon*adversarial_noise/torch.norm(adversarial_noise))-clean_activation)/epsilon
            # Negated so that minimising the loss maximises alignment with the concept vector.
            loss=-torch.dot(gradient.squeeze(0),direction)
            loss.backward()
            optimizer.step()
        pbar.set_description(f'...{positive_digit}_{negative_digit}...loss={loss:.4f}...')

        # Final evaluation only: no graph is needed.
        with torch.no_grad():
            unit_noise=adversarial_noise/torch.norm(adversarial_noise)
            final_gradient=(model.encoder(test_image+epsilon*unit_noise)-clean_activation)/epsilon
            dots[test_idx]=torch.dot(final_gradient.squeeze(0),direction).item()

        np.save(f'{out_dir}\\noise_for_test_image_{test_idx}.npy',unit_noise.numpy())

    with open(f'{out_dir}\\dots','wb') as dots_file:
        pickle.dump(dots,dots_file)

In [42]:
# Run the CAV spatial-dependency optimisation for every ordered pair of distinct digits.
for positive_digit in range(10):
    for negative_digit in range(10):
        if negative_digit!=positive_digit:
            spatial_dependency_cav(positive_digit,negative_digit)

...0_1...loss=-2.2098...: 100%|██████████| 12/12 [01:09<00:00,  5.76s/it]


In [16]:
# Run the CBV spatial-dependency optimisation for every ordered pair of distinct digits.
for positive_digit in range(10):
    for negative_digit in range(10):
        if negative_digit!=positive_digit:
            spatial_dependency_cbv(positive_digit,negative_digit)

100%|██████████| 12/12 [00:00<00:00, 12087.33it/s]
100%|██████████| 35/35 [00:00<00:00, 17440.97it/s]
100%|██████████| 11/11 [00:00<00:00, 5492.54it/s]
100%|██████████| 24/24 [00:00<00:00, 11999.44it/s]
100%|██████████| 20/20 [00:00<00:00, 12951.38it/s]
100%|██████████| 38/38 [00:00<00:00, 18775.30it/s]
100%|██████████| 18/18 [00:00<00:00, 18040.02it/s]
100%|██████████| 28/28 [00:00<00:00, 14001.01it/s]
100%|██████████| 30/30 [00:00<00:00, 14965.40it/s]
100%|██████████| 12/12 [00:00<00:00, 11980.87it/s]
100%|██████████| 25/25 [00:00<00:00, 4166.80it/s]
100%|██████████| 16/16 [00:00<00:00, 15993.53it/s]
100%|██████████| 44/44 [00:00<00:00, 22022.60it/s]
100%|██████████| 16/16 [00:00<00:00, 8002.49it/s]
100%|██████████| 15/15 [00:00<00:00, 14834.84it/s]
100%|██████████| 18/18 [00:00<00:00, 17979.87it/s]
100%|██████████| 26/26 [00:00<00:00, 11643.38it/s]
100%|██████████| 16/16 [00:00<00:00, 16012.61it/s]
100%|██████████| 35/35 [00:00<00:00, 16446.41it/s]
100%|██████████| 25/25 [00:00<00:0

In [None]:
import matplotlib
# Non-interactive backend: figures are only written to disk, never shown inline.
matplotlib.use('agg')

# Number of highest-scoring test images to plot per digit pair.
top=3

pbar=tqdm(range(10))
for positive_digit in pbar:
    for negative_digit in range(10):
        if negative_digit==positive_digit:
            continue
        pbar.set_description(f'{negative_digit}/9')
        pair_dir=f'{external_path}\\spatial_dependency_cbv\\{positive_digit}_{negative_digit}'
        # NOTE: pickle.load executes arbitrary code from the file - only load trusted artefacts.
        with open(f'{pair_dir}\\dots','rb') as dots_file:
            dots=pickle.load(dots_file)

        # Test indices with the largest final dot product first (stable for ties).
        top_idxs=sorted(dots,key=dots.get,reverse=True)[:top]

        # squeeze=False keeps axs two-dimensional so axs[row][col] also works when top == 1.
        fig,axs=plt.subplots(nrows=2,ncols=top,squeeze=False)
        fig.suptitle(f'Spatial Dependency {positive_digit}_{negative_digit} CBV')
        for n,idx in enumerate(top_idxs):
            # Top row: the original test image, fetched by index rather than by scanning the dataset.
            img,_=test_dataset[int(idx)]
            axs[0][n].imshow(img.squeeze(0))
            axs[0][n].xaxis.set_visible(False)
            axs[0][n].tick_params(left=False, labelleft=False)
            axs[0][n].set_title(str(idx))

            # Bottom row: the optimised unit-norm perturbation saved for that image.
            axs[1][n].imshow(np.load(f'{pair_dir}\\noise_for_test_image_{idx}.npy').squeeze(0).squeeze(0))
            axs[1][n].xaxis.set_visible(False)
            axs[1][n].tick_params(left=False, labelleft=False)

            if n==0:
                axs[0][n].set_ylabel('Image Original')
                axs[1][n].set_ylabel('Image Perturbation')
        fig.savefig(f'{pair_dir}\\top{top}_image_perturbations.png')
        # Close explicitly so 90 figures do not accumulate in memory.
        plt.close(fig)

In [17]:
import matplotlib
# Non-interactive backend: figures are only written to disk, never shown inline.
matplotlib.use('agg')

# Number of highest-scoring test images to plot per digit pair.
top=3

pbar=tqdm(range(10))
for positive_digit in pbar:
    for negative_digit in range(10):
        if negative_digit==positive_digit:
            continue
        pbar.set_description(f'{negative_digit}/9')
        pair_dir=f'{external_path}\\spatial_dependency_cav\\{positive_digit}_{negative_digit}'
        # NOTE: pickle.load executes arbitrary code from the file - only load trusted artefacts.
        with open(f'{pair_dir}\\dots','rb') as dots_file:
            dots=pickle.load(dots_file)

        # Test indices with the largest final dot product first (stable for ties).
        top_idxs=sorted(dots,key=dots.get,reverse=True)[:top]

        # squeeze=False keeps axs two-dimensional so axs[row][col] also works when top == 1.
        fig,axs=plt.subplots(nrows=2,ncols=top,squeeze=False)
        fig.suptitle(f'Spatial Dependency {positive_digit}_{negative_digit} CAV')
        for n,idx in enumerate(top_idxs):
            # Top row: the original test image, fetched by index rather than by scanning the dataset.
            img,_=test_dataset[int(idx)]
            axs[0][n].imshow(img.squeeze(0))
            axs[0][n].xaxis.set_visible(False)
            axs[0][n].tick_params(left=False, labelleft=False)
            axs[0][n].set_title(str(idx))

            # Bottom row: the optimised unit-norm perturbation saved for that image.
            axs[1][n].imshow(np.load(f'{pair_dir}\\noise_for_test_image_{idx}.npy').squeeze(0).squeeze(0))
            axs[1][n].xaxis.set_visible(False)
            axs[1][n].tick_params(left=False, labelleft=False)

            if n==0:
                axs[0][n].set_ylabel('Image Original')
                axs[1][n].set_ylabel('Image Perturbation')
        fig.savefig(f'{pair_dir}\\top{top}_image_perturbations.png')
        # Close explicitly so 90 figures do not accumulate in memory.
        plt.close(fig)

8/9: 100%|██████████| 10/10 [03:20<00:00, 20.07s/it]
