In [23]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

from src import *

from tqdm import tqdm
import pickle
import numpy as np
import matplotlib.pyplot as plt

import os

# Root directory holding the large external artefacts this notebook reads and writes
# (latent activations, concept vectors, boundary/cluster results). Set the
# LIOT_AIDOS_EXTERNAL environment variable to point elsewhere; otherwise the original
# author's local path is used. Sub-paths below are joined with Windows separators.
external_path=os.environ.get('LIOT_AIDOS_EXTERNAL','c:\\Users\\thoma\\Documents\\working_docs\\LIoT_aidos_external\\cnn2')

In [3]:
# Rebuild the trained MNIST CNN (class presumably provided by `from src import *`) and restore its weights.
model=CNN()
# map_location='cpu' lets a checkpoint saved from a GPU session load on a CPU-only machine;
# load_state_dict copies the tensors into the freshly built (CPU) model either way.
model.load_state_dict(torch.load('cnn_mnist.pth',map_location='cpu'))
# Evaluation mode; the model is only used for inference below.
model.eval()

CNN(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=3136, out_features=64, bias=True)
  )
  (out_layer): Linear(in_features=64, out_features=10, bias=True)
)

In [9]:
def instance_perturbation(model,instance,target_digit,perturbation_vector,perturbation_start,perturbation_increment,perturbation_max):
    """Scan perturbation amplitudes upward until ``model.out_layer`` predicts ``target_digit``.

    Starting at ``perturbation_start``, the amplitude ``a`` grows in steps of
    ``perturbation_increment`` until ``model.out_layer(instance + a*perturbation_vector)``
    has its argmax at ``target_digit`` or ``a >= perturbation_max``.

    Args:
        model: object exposing ``out_layer`` (the classifier head applied to latents).
        instance: latent tensor to perturb.
        target_digit: class index the perturbation should reach.
        perturbation_vector: direction added to ``instance`` (same shape as ``instance``).
        perturbation_start: first amplitude tested.
        perturbation_increment: step between tested amplitudes.
        perturbation_max: amplitude at which the scan gives up.

    Returns:
        The amplitude one step *before* the one where the scan stopped, i.e. the last
        tested amplitude that was not (yet) classified as ``target_digit``. Callers can
        re-run with a smaller increment from this value for a coarse-to-fine search.
        If the scan hit ``perturbation_max`` the result is about
        ``perturbation_max - perturbation_increment``; if ``perturbation_start`` is
        already classified as ``target_digit`` it is ``perturbation_start - perturbation_increment``.
    """
    import torch

    def reaches_target(amplitude):
        return model.out_layer(instance+amplitude*perturbation_vector).argmax()==target_digit

    steps=0
    amplitude=perturbation_start
    # Pure inference: avoid building an autograd graph on every step of the scan.
    with torch.no_grad():
        while not reaches_target(amplitude) and amplitude<perturbation_max:
            steps+=1
            # Derive from the step count instead of accumulating with +=, so float error does not drift.
            amplitude=perturbation_start+steps*perturbation_increment
    return perturbation_start+(steps-1)*perturbation_increment

In [14]:
def perturbation_amplitudes_cav(positive_digit,negative_digit):
    """Per-instance amplitude along the CAV needed to reclassify ``negative_digit`` latents as ``positive_digit``.

    Loads the saved latent activations of ``negative_digit`` and the concept activation
    vector for the (positive, negative) pair from ``external_path``. For every latent row
    it runs a coarse-to-fine scan (steps 1, 0.1, 0.01) with ``instance_perturbation``,
    using the notebook-level ``model``.

    Returns:
        np.ndarray with one entry per latent row: the smallest scanned amplitude (0.01
        resolution) at which ``model.out_layer`` predicts ``positive_digit``, or ``np.inf``
        if the scan reached the 1e3 cap without reclassifying the instance.
    """
    perturbation_max=1e3
    increments=(1,1e-1,1e-2)  # coarse-to-fine scan resolutions

    # presumably (n_instances, latent_dim) rows of out_layer inputs — TODO confirm
    negative_digit_latents=torch.load(f'{external_path}\\latent_activations\\{negative_digit}.pt')
    perturbation=torch.load(f'{external_path}\\concept_activation_vectors\\{positive_digit}_{negative_digit}.pt')

    perturbations=np.zeros(negative_digit_latents.shape[0])
    with torch.no_grad():
        for k in range(negative_digit_latents.shape[0]):
            perturbation_amplitude=0
            for increment in increments:
                # Each pass returns the last non-reclassified amplitude at this resolution.
                perturbation_amplitude=instance_perturbation(model,negative_digit_latents[k,:],positive_digit,perturbation,perturbation_amplitude,increment,perturbation_max)
            # Step past it to the first amplitude that was classified as positive_digit.
            perturbation_amplitude+=increments[-1]
            # The scan hit the cap: mark as unreachable. The check must come BEFORE the store
            # (it used to come after, so inf was never recorded); the half-step tolerance
            # absorbs float error from the repeated fractional steps.
            if perturbation_amplitude>=perturbation_max-increments[-1]/2:
                perturbation_amplitude=np.inf
            perturbations[k]=perturbation_amplitude
    return perturbations

In [15]:
# Sweep every ordered (positive, negative) digit pair and cache the CAV-direction
# reclassification amplitudes as one .npy file per pair.
pbar=tqdm(range(10))
for positive_digit in pbar:
    for negative_digit in range(10):
        if negative_digit!=positive_digit:
            pbar.set_description(f'{negative_digit}/9')
            perturbations=perturbation_amplitudes_cav(positive_digit,negative_digit)
            target_file=f'{external_path}\\boundary_info\\perturbations_cav\\{positive_digit}_{negative_digit}.npy'
            np.save(target_file,perturbations)

8/9: 100%|██████████| 10/10 [04:19<00:00, 25.94s/it]


In [18]:
def perturbation_amplitudes_cbv(positive_digit,negative_digit):
    """Per-instance amplitude along the CBV needed to reclassify ``negative_digit`` latents as ``positive_digit``.

    Loads the saved latent activations of ``negative_digit`` and the concept boundary
    vector (a ``.npy`` array) for the (positive, negative) pair from ``external_path``.
    For every latent row it runs a coarse-to-fine scan (steps 1, 0.1, 0.01) with
    ``instance_perturbation``, using the notebook-level ``model``.

    Returns:
        np.ndarray with one entry per latent row: the smallest scanned amplitude (0.01
        resolution) at which ``model.out_layer`` predicts ``positive_digit``, or ``np.inf``
        if the scan reached the 1e3 cap without reclassifying the instance.
    """
    perturbation_max=1e3
    increments=(1,1e-1,1e-2)  # coarse-to-fine scan resolutions

    # presumably (n_instances, latent_dim) rows of out_layer inputs — TODO confirm
    negative_digit_latents=torch.load(f'{external_path}\\latent_activations\\{negative_digit}.pt')
    # Match the latents' dtype: a float64 .npy would otherwise promote the perturbed input
    # and clash with the (float32) Linear weights in out_layer.
    perturbation=torch.as_tensor(np.load(f'{external_path}\\concept_boundary_vectors\\{positive_digit}_{negative_digit}.npy'),dtype=negative_digit_latents.dtype)

    perturbations=np.zeros(negative_digit_latents.shape[0])
    with torch.no_grad():
        for k in range(negative_digit_latents.shape[0]):
            perturbation_amplitude=0
            for increment in increments:
                # Each pass returns the last non-reclassified amplitude at this resolution.
                perturbation_amplitude=instance_perturbation(model,negative_digit_latents[k,:],positive_digit,perturbation,perturbation_amplitude,increment,perturbation_max)
            # Step past it to the first amplitude that was classified as positive_digit.
            perturbation_amplitude+=increments[-1]
            # The scan hit the cap: mark as unreachable. The check must come BEFORE the store
            # (it used to come after, so inf was never recorded); the half-step tolerance
            # absorbs float error from the repeated fractional steps.
            if perturbation_amplitude>=perturbation_max-increments[-1]/2:
                perturbation_amplitude=np.inf
            perturbations[k]=perturbation_amplitude
    return perturbations

In [19]:
# Sweep every ordered (positive, negative) digit pair and cache the CBV-direction
# reclassification amplitudes as one .npy file per pair.
pbar=tqdm(range(10))
for positive_digit in pbar:
    for negative_digit in range(10):
        if negative_digit!=positive_digit:
            pbar.set_description(f'{negative_digit}/9')
            perturbations=perturbation_amplitudes_cbv(positive_digit,negative_digit)
            target_file=f'{external_path}\\boundary_info\\perturbations_cbv\\{positive_digit}_{negative_digit}.npy'
            np.save(target_file,perturbations)

8/9: 100%|██████████| 10/10 [07:36<00:00, 45.62s/it]


In [21]:
# For each ordered digit pair, divide the cached CAV amplitudes by the CBV amplitudes
# element-wise (a ratio > 1 means the CAV direction needed the larger push) and save the result.
for positive_digit in range(10):
    for negative_digit in range(10):
        if positive_digit!=negative_digit:
            perturbations_cav=np.load(f'{external_path}\\boundary_info\\perturbations_cav\\{positive_digit}_{negative_digit}.npy')
            perturbations_cbv=np.load(f'{external_path}\\boundary_info\\perturbations_cbv\\{positive_digit}_{negative_digit}.npy')
            ratios=np.divide(perturbations_cav,perturbations_cbv)
            np.save(f'{external_path}\\cluster_info\\ratio_perturbations_cav_cbv\\arrays\\{positive_digit}_{negative_digit}.npy',ratios)

In [27]:
# Save one density histogram of the CAV/CBV amplitude ratio per ordered digit pair.
colors=plt.cm.jet(np.linspace(0,1,2))
for positive_digit in range(10):
    for negative_digit in range(10):
        if positive_digit==negative_digit:
            continue
        ratios=np.load(f'{external_path}\\cluster_info\\ratio_perturbations_cav_cbv\\arrays\\{positive_digit}_{negative_digit}.npy')
        # Unreachable instances can produce inf (and inf/inf or x/0 give nan/inf ratios);
        # histogram binning fails on non-finite values, so drop them and report the count.
        finite_ratios=ratios[np.isfinite(ratios)]
        n_dropped=ratios.size-finite_ratios.size

        fig,ax=plt.subplots()
        ax.hist(finite_ratios,color=colors[0],density=True)
        ax.set_title(f'Ratio Between CAV and CBV Based Reclassification Perturbations\nFor {negative_digit} to {positive_digit} Concept')
        ax.set_xlabel('Ratio')
        ax.set_ylabel('Density')
        if n_dropped:
            ax.annotate(f'{n_dropped} non-finite ratios omitted',xy=(0.98,0.95),xycoords='axes fraction',ha='right',va='top')
        # Keep ratio 1 (equal CAV and CBV amplitudes) in view on either side of the data.
        left,right=ax.get_xlim()
        ax.set_xlim(min(1,left),max(1,right))
        fig.savefig(f'{external_path}\\cluster_info\\ratio_perturbations_cav_cbv\\plots\\{positive_digit}_{negative_digit}.png')
        plt.close(fig)