In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

from src import *

import os
import pickle

import numpy as np
from tqdm import tqdm

# Root directory of the external data (pairs, latent activations, concept vectors, results).
# Set the LIOT_AIDOS_EXTERNAL environment variable to run the notebook on another machine;
# otherwise the original hard-coded location is used.
external_path=os.environ.get('LIOT_AIDOS_EXTERNAL','c:\\Users\\thoma\\Documents\\working_docs\\LIoT_aidos_external\\cnn2')

In [2]:
# Load the trained MNIST CNN and switch it to inference mode.
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only machine; the
# latent activations used below are also loaded without a device mapping, so all the
# perturbation searches run on the CPU.
model=CNN()
model.load_state_dict(torch.load('cnn_mnist.pth',map_location='cpu'))
model.eval()

CNN(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=3136, out_features=64, bias=True)
  )
  (out_layer): Linear(in_features=64, out_features=10, bias=True)
)

In [9]:
def perturbation_amplitudes_cav(positive_digit,negative_digit,perturbation_increment=1e-2,max_amplitude=2e2):
    """For each selected negative-digit latent, find the smallest amplitude along the CAV that
    makes ``model.out_layer`` predict ``positive_digit``.

    The sample indices come from column 1 of each row of
    ``boundary_info/pairs/{positive_digit}_{negative_digit}.npy``. They select rows of
    ``latent_activations/{negative_digit}.pt``, and the direction is loaded from
    ``concept_activation_vectors/{positive_digit}_{negative_digit}.pt``. The amplitude is
    searched linearly from 0 in steps of ``perturbation_increment``.

    Uses the module-level ``model`` and ``external_path``.

    Args:
        positive_digit: class the perturbed latent should be predicted as.
        negative_digit: class whose latent activations are perturbed.
        perturbation_increment: step of the linear amplitude search (must be > 0).
        max_amplitude: the search only tries amplitudes strictly below this value.

    Returns:
        np.ndarray with one entry per pair: the first amplitude at which the prediction
        equals ``positive_digit``, or ``np.inf`` if none was found below ``max_amplitude``.

    Raises:
        ValueError: if ``perturbation_increment`` is not positive (the search would never end).
    """
    if perturbation_increment<=0:
        raise ValueError('perturbation_increment must be positive')

    pairs=np.load(f'{external_path}\\boundary_info\\pairs\\{positive_digit}_{negative_digit}.npy')
    negative_idxs=np.array([int(pair[1]) for pair in pairs],dtype=int)

    negative_digit_latents=torch.load(f'{external_path}\\latent_activations\\{negative_digit}.pt')[negative_idxs]
    perturbation=torch.load(f'{external_path}\\concept_activation_vectors\\{positive_digit}_{negative_digit}.pt')

    perturbations=np.zeros(negative_digit_latents.shape[0])
    # Inference only: avoid building an autograd graph on each of the many forward passes.
    with torch.no_grad():
        for k in range(negative_digit_latents.shape[0]):
            perturbation_amplitude=0
            while perturbation_amplitude<max_amplitude:
                logits=model.out_layer(negative_digit_latents[k,:]+perturbation_amplitude*perturbation)
                if logits.argmax()==positive_digit:
                    break
                perturbation_amplitude+=perturbation_increment
            else:
                # No flip below max_amplitude. The inf must be set before storing; previously it
                # was assigned to the local after the store, so failures were saved as ~max_amplitude.
                perturbation_amplitude=np.inf
            perturbations[k]=perturbation_amplitude
    return perturbations

In [10]:
# Compute and cache the CAV flip amplitudes for every ordered (positive, negative) digit pair.
cav_dir=f'{external_path}\\boundary_info\\perturbations_cav'
# np.save does not create missing directories.
os.makedirs(cav_dir,exist_ok=True)
pbar=tqdm(range(10))
for positive_digit in pbar:
    for negative_digit in range(10):
        if negative_digit==positive_digit:
            continue
        pbar.set_description(f'{negative_digit}/9')
        perturbations=perturbation_amplitudes_cav(positive_digit,negative_digit)
        np.save(f'{cav_dir}\\{positive_digit}_{negative_digit}.npy',perturbations)

8/9: 100%|██████████| 10/10 [03:37<00:00, 21.78s/it]


In [1]:
def perturbation_amplitudes_cbv(positive_digit,negative_digit,perturbation_increment=1e-2,max_amplitude=2e2):
    """For each selected negative-digit latent, find the smallest amplitude along the concept
    boundary vector (CBV) that makes ``model.out_layer`` predict ``positive_digit``.

    The sample indices come from column 1 of each row of
    ``boundary_info/pairs/{positive_digit}_{negative_digit}.npy``. They select rows of
    ``latent_activations/{negative_digit}.pt``, and the direction is loaded from
    ``concept_boundary_vectors/{positive_digit}_{negative_digit}.npy``. The amplitude is
    searched linearly from 0 in steps of ``perturbation_increment``.

    Uses the module-level ``model`` and ``external_path``.

    Args:
        positive_digit: class the perturbed latent should be predicted as.
        negative_digit: class whose latent activations are perturbed.
        perturbation_increment: step of the linear amplitude search (must be > 0).
        max_amplitude: the search only tries amplitudes strictly below this value.

    Returns:
        np.ndarray with one entry per pair: the first amplitude at which the prediction
        equals ``positive_digit``, or ``np.inf`` if none was found below ``max_amplitude``.

    Raises:
        ValueError: if ``perturbation_increment`` is not positive (the search would never end).
    """
    if perturbation_increment<=0:
        raise ValueError('perturbation_increment must be positive')

    pairs=np.load(f'{external_path}\\boundary_info\\pairs\\{positive_digit}_{negative_digit}.npy')
    negative_idxs=np.array([int(pair[1]) for pair in pairs],dtype=int)

    negative_digit_latents=torch.load(f'{external_path}\\latent_activations\\{negative_digit}.pt')[negative_idxs]
    # Match the latents' dtype: a float64 .npy would otherwise promote the perturbed input to
    # float64, which may not match out_layer's weights (presumably float32 - not visible here).
    perturbation=torch.as_tensor(np.load(f'{external_path}\\concept_boundary_vectors\\{positive_digit}_{negative_digit}.npy'),dtype=negative_digit_latents.dtype)

    perturbations=np.zeros(negative_digit_latents.shape[0])
    # Inference only: avoid building an autograd graph on each of the many forward passes.
    with torch.no_grad():
        for k in range(negative_digit_latents.shape[0]):
            perturbation_amplitude=0
            while perturbation_amplitude<max_amplitude:
                logits=model.out_layer(negative_digit_latents[k,:]+perturbation_amplitude*perturbation)
                if logits.argmax()==positive_digit:
                    break
                perturbation_amplitude+=perturbation_increment
            else:
                # No flip below max_amplitude. The inf must be set before storing; previously it
                # was assigned to the local after the store, so failures were saved as ~max_amplitude.
                perturbation_amplitude=np.inf
            perturbations[k]=perturbation_amplitude
    return perturbations

In [5]:
# Compute and cache the CBV flip amplitudes for every ordered (positive, negative) digit pair.
cbv_dir=f'{external_path}\\boundary_info\\perturbations_cbv'
# np.save does not create missing directories.
os.makedirs(cbv_dir,exist_ok=True)
pbar=tqdm(range(10))
for positive_digit in pbar:
    for negative_digit in range(10):
        if negative_digit==positive_digit:
            continue
        pbar.set_description(f'{negative_digit}/9')
        perturbations=perturbation_amplitudes_cbv(positive_digit,negative_digit)
        np.save(f'{cbv_dir}\\{positive_digit}_{negative_digit}.npy',perturbations)

8/9: 100%|██████████| 10/10 [03:04<00:00, 18.41s/it]


In [None]:
def perturbation_amplitudes_boundary_normals(positive_digit,negative_digit,perturbation_increment=1e-2,max_amplitude=2e2):
    """For each selected negative-digit latent, find the smallest amplitude along its own
    boundary normal that makes ``model.out_layer`` predict ``positive_digit``.

    The sample indices come from column 1 of each row of
    ``boundary_info/pairs/{positive_digit}_{negative_digit}.npy``. They select rows of
    ``latent_activations/{negative_digit}.pt``. Row ``k`` of
    ``boundary_info/normals/{positive_digit}_{negative_digit}.npy`` is used as the direction for
    the k-th selected latent; this assumes the normals are stored in pair order (TODO confirm).
    The amplitude is searched linearly from 0 in steps of ``perturbation_increment``.

    Uses the module-level ``model`` and ``external_path``.

    Args:
        positive_digit: class the perturbed latent should be predicted as.
        negative_digit: class whose latent activations are perturbed.
        perturbation_increment: step of the linear amplitude search (must be > 0).
        max_amplitude: the search only tries amplitudes strictly below this value.

    Returns:
        np.ndarray with one entry per pair: the first amplitude at which the prediction
        equals ``positive_digit``, or ``np.inf`` if none was found below ``max_amplitude``.

    Raises:
        ValueError: if ``perturbation_increment`` is not positive (the search would never end).
    """
    if perturbation_increment<=0:
        raise ValueError('perturbation_increment must be positive')

    pairs=np.load(f'{external_path}\\boundary_info\\pairs\\{positive_digit}_{negative_digit}.npy')
    negative_idxs=np.array([int(pair[1]) for pair in pairs],dtype=int)

    negative_digit_latents=torch.load(f'{external_path}\\latent_activations\\{negative_digit}.pt')[negative_idxs]
    # Match the latents' dtype: a float64 .npy would otherwise promote the perturbed input to
    # float64, which may not match out_layer's weights (presumably float32 - not visible here).
    normals=torch.as_tensor(np.load(f'{external_path}\\boundary_info\\normals\\{positive_digit}_{negative_digit}.npy'),dtype=negative_digit_latents.dtype)

    perturbations=np.zeros(negative_digit_latents.shape[0])
    # Inference only: avoid building an autograd graph on each of the many forward passes.
    with torch.no_grad():
        for k in range(negative_digit_latents.shape[0]):
            perturbation=normals[k,:]
            perturbation_amplitude=0
            while perturbation_amplitude<max_amplitude:
                logits=model.out_layer(negative_digit_latents[k,:]+perturbation_amplitude*perturbation)
                if logits.argmax()==positive_digit:
                    break
                perturbation_amplitude+=perturbation_increment
            else:
                # No flip below max_amplitude. The inf must be set before storing; previously it
                # was assigned to the local after the store, so failures were saved as ~max_amplitude.
                perturbation_amplitude=np.inf
            perturbations[k]=perturbation_amplitude
    return perturbations

In [None]:
# Compute and cache the boundary-normal flip amplitudes for every ordered (positive, negative) digit pair.
normals_dir=f'{external_path}\\boundary_info\\perturbations_boundary_normals'
# np.save does not create missing directories.
os.makedirs(normals_dir,exist_ok=True)
pbar=tqdm(range(10))
for positive_digit in pbar:
    for negative_digit in range(10):
        if negative_digit==positive_digit:
            continue
        pbar.set_description(f'{negative_digit}/9')
        perturbations=perturbation_amplitudes_boundary_normals(positive_digit,negative_digit)
        np.save(f'{normals_dir}\\{positive_digit}_{negative_digit}.npy',perturbations)

In [7]:
# Per-sample ratio of the CAV flip amplitude to the CBV flip amplitude, for every ordered
# (positive, negative) digit pair, saved next to the other boundary results.
# Expected special values:
# - A CBV amplitude of 0 yields inf (or nan for 0/0).
# - inf/inf (search failed in both directions) yields nan.
# These are kept as-is, and NumPy's divide/invalid warnings for them are silenced.
ratio_dir=f'{external_path}\\boundary_info\\ratio_perturbations_cav_cbv'
# np.save does not create missing directories.
os.makedirs(ratio_dir,exist_ok=True)
for positive_digit in range(10):
    for negative_digit in range(10):
        if positive_digit==negative_digit:
            continue
        perturbations_cav=np.load(f'{external_path}\\boundary_info\\perturbations_cav\\{positive_digit}_{negative_digit}.npy')
        perturbations_cbv=np.load(f'{external_path}\\boundary_info\\perturbations_cbv\\{positive_digit}_{negative_digit}.npy')
        # Both arrays are built from the same pairs file; guard against silent broadcasting.
        assert perturbations_cav.shape==perturbations_cbv.shape, (positive_digit,negative_digit)
        with np.errstate(divide='ignore',invalid='ignore'):
            ratios=perturbations_cav/perturbations_cbv
        np.save(f'{ratio_dir}\\{positive_digit}_{negative_digit}.npy',ratios)