In [1]:
import torch
import numpy as np
from tqdm import tqdm
import os
import pickle

external_path=''

In [2]:
def boundary_pairs(positive_digit,negative_digit):
    positive_digit_latents=torch.load(f'{external_path}\\latent_activations\\{positive_digit}.pt')
    negative_digit_latents=torch.load(f'{external_path}\\latent_activations\\{negative_digit}.pt')

    pairs_from_positive=[]
    negative_idxs=[]

    for k in range(positive_digit_latents.shape[0]):
        positive_latent=positive_digit_latents[k,:]
        distances=np.zeros(negative_digit_latents.shape[0])
        for l in range(negative_digit_latents.shape[0]):
            distances[l]=torch.norm(positive_latent-negative_digit_latents[l,:])
        pairs_from_positive.append([k,distances.argmin()])
        negative_idxs.append(distances.argmin())

    negative_idxs=list(set(negative_idxs))

    positive_idxs=[]
    pairs_from_negative=[]

    for k in range(negative_digit_latents.shape[0]):
        negative_latent=negative_digit_latents[k,:]
        distances=np.zeros(positive_digit_latents.shape[0])
        for l in range(positive_digit_latents.shape[0]):
            distances[l]=torch.norm(negative_latent-positive_digit_latents[l,:])
        pairs_from_negative.append([distances.argmin(),k])
        positive_idxs.append(distances.argmin())

    positive_idxs=list(set(positive_idxs))

    pairs=[]

    for pair in pairs_from_positive:
        if pair[0] in positive_idxs:
            pairs.append(pair)
        
    for pair in pairs_from_negative:
        if pair[1] in negative_idxs:
            if not(pair in pairs):
                pairs.append(pair)

    return pairs

In [4]:
pbar=tqdm(range(10))
for positive_digit in pbar:
    for negative_digit in range(10):
        if positive_digit==negative_digit:
            continue
        pbar.set_description(f'{negative_digit}/9')
        if os.path.exists(f'{external_path}\\boundary_info\\pairs\\{negative_digit}_{positive_digit}.npy'):
            pairs_reversed=np.load(f'{external_path}\\boundary_info\\pairs\\{negative_digit}_{positive_digit}.npy')
            pairs=[[pair[1],pair[0]] for pair in pairs_reversed]
        else:
            pairs=boundary_pairs(positive_digit,negative_digit)
        np.save(f'{external_path}\\boundary_info\\pairs\\{positive_digit}_{negative_digit}.npy',np.array(pairs))

8/9: 100%|██████████| 10/10 [36:05<00:00, 216.51s/it]


In [13]:
def boundary_distances(positive_digit,negative_digit):
    pairs=np.load(f'{external_path}\\boundary_info\\pairs\\{positive_digit}_{negative_digit}.npy')

    positive_digit_latents=torch.load(f'{external_path}\\latent_activations\\{positive_digit}.pt')
    negative_digit_latents=torch.load(f'{external_path}\\latent_activations\\{negative_digit}.pt')

    distances=np.zeros(len(pairs))
    for k,pair in enumerate(pairs):
        positive_digit_latent=positive_digit_latents[pair[0],:]
        negative_digit_latent=negative_digit_latents[pair[1],:]
        distances[k]=torch.norm(positive_digit_latent-negative_digit_latent).item()
    np.save(f'{external_path}\\boundary_info\\distances\\{positive_digit}_{negative_digit}.npy',distances)

In [14]:
pbar=tqdm(range(10))
for positive_digit in pbar:
    for negative_digit in range(10):
        if positive_digit==negative_digit:
            continue
        boundary_distances(positive_digit,negative_digit)

100%|██████████| 10/10 [00:00<00:00, 14.45it/s]


In [14]:
def boundary_normals(positive_digit,negative_digit):
    pairs=np.load(f'{external_path}\\boundary_info\\pairs\\{positive_digit}_{negative_digit}.npy')

    positive_digit_latents=torch.load(f'{external_path}\\latent_activations\\{positive_digit}.pt')
    negative_digit_latents=torch.load(f'{external_path}\\latent_activations\\{negative_digit}.pt')

    for k,pair in enumerate(pairs):
        positive_digit_latent=positive_digit_latents[pair[0],:]
        negative_digit_latent=negative_digit_latents[pair[1],:]
        if k==0:
            normal_vector=positive_digit_latent-negative_digit_latent
            normal_vector/=torch.norm(normal_vector).item()
            normals=normal_vector.unsqueeze(0).detach().numpy()
        else:
            normal_vector=positive_digit_latent-negative_digit_latent
            normal_vector/=torch.norm(normal_vector).item()
            normals=np.concatenate([normals,normal_vector.unsqueeze(0).detach().numpy()])
    np.save(f'{external_path}\\boundary_info\\normals\\{positive_digit}_{negative_digit}.npy',normals)

In [15]:
pbar=tqdm(range(10))
for positive_digit in pbar:
    for negative_digit in range(10):
        if positive_digit==negative_digit:
            continue
        boundary_normals(positive_digit,negative_digit)

100%|██████████| 10/10 [00:00<00:00, 21.85it/s]


In [None]:
def boundary_dots_cav(positive_digit,negative_digit):
    normals=np.load(f'{external_path}\\boundary_info\\normals\\{positive_digit}_{negative_digit}.npy')
    cav=torch.load(f'{external_path}\\concept_activation_vectors\\{positive_digit}_{negative_digit}.pt')
    dots=np.zeros(len(normals))
    for k,normal in enumerate(normals):
        dots[k]=torch.dot(torch.tensor(normal),cav)
    np.save(f'{external_path}\\boundary_info\\dots_cav\\{positive_digit}_{negative_digit}.npy',dots)

def boundary_dots_cbv(positive_digit,negative_digit):
    normals=np.load(f'{external_path}\\boundary_info\\normals\\{positive_digit}_{negative_digit}.npy')
    cbv=np.load(f'{external_path}\\concept_boundary_vectors\\{positive_digit}_{negative_digit}.npy')
    dots=np.zeros(len(normals))
    for k,normal in enumerate(normals):
        dots[k]=np.dot(normal,cbv)
    np.save(f'{external_path}\\boundary_info\\dots_cbv\\{positive_digit}_{negative_digit}.npy',dots)

In [None]:
pbar=tqdm(range(10))
for positive_digit in pbar:
    for negative_digit in range(10):
        if positive_digit==negative_digit:
            continue
        boundary_dots_cav(positive_digit,negative_digit)
        boundary_dots_cbv(positive_digit,negative_digit)