In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.stats import pearsonr
from tqdm import tqdm

#https://github.com/nrkarthikeyan/topology-decision-boundaries/tree/master
from src.TopologicalData import *

# Root prefix for every file the cells below touch: it is prepended (with
# backslash separators) to the `latent_activations`, `boundary_info` and
# `cluster_info` paths that are read from and written to.
external_path=''

In [None]:
# Sweep configuration.
# NOTE(review): NUM_ITERATIONS is not referenced by any cell visible here.
NUM_ITERATIONS = 50
# Number of scale values handed to TopologicalData via `scale=R`.
NUM_FILTRATIONS = 100
# Evenly spaced scale values from 0.1 to 2 (both endpoints included).
R = np.linspace(start=0.1, stop=2, num=NUM_FILTRATIONS)

In [None]:
# For every ordered pair of distinct digits, build a labelled point cloud from
# the two digits' latent activations, run TopologicalData on it, and save the
# Betti counts plus the finite birth/death pairs under
# boundary_info\topology\<positive>_<negative>.

# Per-digit latent activations, filled on first use so each .pt file is read
# from disk once instead of once per digit pair.
latents_by_digit={}

def _load_latents(digit):
    """Return the latent activations saved for `digit` as a NumPy array (cached)."""
    if digit not in latents_by_digit:
        latents_by_digit[digit]=torch.load(f'{external_path}\\latent_activations\\{digit}.pt').detach().numpy()
    return latents_by_digit[digit]

pbar=tqdm(range(10))
for positive_digit in pbar:
    positive_latents=_load_latents(positive_digit)
    positive_labels=np.ones(positive_latents.shape[0])
    for negative_digit in range(10):
        if positive_digit==negative_digit:
            continue
        pbar.set_description(f'{negative_digit}')

        # Output directory for this ordered pair; makedirs also creates any
        # missing parent folders and does not fail if the directory exists.
        pair_dir=f'{external_path}\\boundary_info\\topology\\{positive_digit}_{negative_digit}'
        os.makedirs(pair_dir,exist_ok=True)

        negative_latents=_load_latents(negative_digit)
        negative_labels=np.zeros(negative_latents.shape[0])

        # np.concatenate returns a fresh array, so the in-place scaling below
        # never modifies the cached per-digit arrays.
        latents=np.concatenate((positive_latents,negative_latents))
        latents/=np.max(latents)
        labels=np.concatenate((positive_labels,negative_labels))

        t=TopologicalData(latents,labels,graphtype="knn_rho",scale=R,k=3,showComplexes=False,saveComplexes=False,use_cy=True,N=20,exptid="linear",PH_program="ripser",maxdim=1)
        t.run()

        # One .npy file per key of t.bc. For key 0, t.nTriv is subtracted first
        # (presumably the number of trivial components — confirm against
        # TopologicalData).
        for key,values in t.bc.items():
            counts=values-t.nTriv if key==0 else values
            np.save(f'{pair_dir}\\betti_{int(key)}_count.npy',counts)

        # Keep only features with a finite death value and dimension <= 1,
        # grouped by dimension.
        dim_birth_deaths={'0':{'birth':[],'death':[]},'1':{'birth':[],'death':[]}}
        for dim,birth,death in zip(t.dims,t.birth_values,t.death_values):
            if death==np.inf or dim>1:
                continue
            bucket=dim_birth_deaths[str(int(dim))]
            bucket['birth'].append(birth)
            bucket['death'].append(death)

        # The context manager closes the file even if pickling raises.
        with open(f'{pair_dir}\\dim_birth_deaths','wb') as dim_birth_deaths_file:
            pickle.dump(dim_birth_deaths,dim_birth_deaths_file)

In [None]:
def sum_H0(positive_digit,negative_digit):
    """Return the summed lifetimes (death - birth) of the saved H0 features.

    Loads the pickled ``dim_birth_deaths`` dictionary stored in the topology
    output folder of the ``positive_digit``/``negative_digit`` pair and adds up
    ``death - birth`` over the paired lists under key ``'0'``.

    Returns 0 when the lists are empty. Raises whatever ``open`` or
    ``pickle.load`` raise if the file is missing or unreadable.
    """
    # NOTE: pickle.load can execute arbitrary code; only point this at files
    # written by this notebook.
    # The context manager closes the file even if unpickling fails.
    with open(f'{external_path}\\boundary_info\\topology\\{positive_digit}_{negative_digit}\\dim_birth_deaths','rb') as dim_birth_death_file:
        dim_birth_death=pickle.load(dim_birth_death_file)
    h0=dim_birth_death['0']
    return sum(death-birth for birth,death in zip(h0['birth'],h0['death']))

In [None]:
# Correlate the summed H0 lifetimes of every ordered digit pair with the mean
# target gradients loaded from the CAV and CBV folders, then plot and save.
target_gradient_cav_means=[]
target_gradient_cbv_means=[]
sums_H0=[]

for positive_digit in range(10):
    digit_grad_cav_means=[]
    digit_grad_cbv_means=[]
    for negative_digit in range(10):
        if positive_digit==negative_digit:
            continue
        sums_H0.append(sum_H0(positive_digit,negative_digit))

        gradients_cav=np.load(f'{external_path}\\cluster_info\\gradients_on_target_cav\\{positive_digit}_{negative_digit}.npy')
        gradients_cbv=np.load(f'{external_path}\\cluster_info\\gradients_on_target_cbv\\{positive_digit}_{negative_digit}.npy')
        digit_grad_cav_means.append(np.mean(gradients_cav))
        digit_grad_cbv_means.append(np.mean(gradients_cbv))

    # Scale each positive digit's means by that digit's largest mean, floored
    # at 0 (same as a running max that starts from 0).
    # NOTE(review): if no mean is positive the scale is 0 and the division
    # below yields inf/nan — confirm this cannot happen for the data.
    max_grad_cav=max(0,*digit_grad_cav_means)
    max_grad_cbv=max(0,*digit_grad_cbv_means)

    target_gradient_cav_means+=[grad/max_grad_cav for grad in digit_grad_cav_means]
    target_gradient_cbv_means+=[grad/max_grad_cbv for grad in digit_grad_cbv_means]

fig,axs=plt.subplots(nrows=1,ncols=1)
colors=plt.cm.jet([0,0.1,0.9,1])
fig.set_figwidth(5)
fig.set_figheight(3)
# One-sided test: the alternative hypothesis is a negative correlation.
sum_H0_cav_cor,sum_H0_cav_pvalue=pearsonr(sums_H0,target_gradient_cav_means,alternative='less')
sum_H0_cbv_cor,sum_H0_cbv_pvalue=pearsonr(sums_H0,target_gradient_cbv_means,alternative='less')
# Least-squares lines through the same points used for the correlations.
# The original fitted against undefined names (`target_gradient_cav` and
# `target_gradient_cbv`), which raised NameError on a fresh kernel.
a_cav,b_cav=np.polyfit(sums_H0,target_gradient_cav_means,deg=1)
a_cbv,b_cbv=np.polyfit(sums_H0,target_gradient_cbv_means,deg=1)
axs.scatter(sums_H0,target_gradient_cav_means,color=colors[0],label='CAV')
axs.scatter(sums_H0,target_gradient_cbv_means,color=colors[3],label='CBV')
axs.plot(sums_H0,a_cav*np.array(sums_H0)+b_cav,color=colors[1])
axs.plot(sums_H0,a_cbv*np.array(sums_H0)+b_cbv,color=colors[2])
axs.ticklabel_format(axis='x',style='sci',scilimits=(0,3))
axs.set_xlabel('Sum H0 Lifetimes')
axs.set_ylabel('Logit Influence')
axs.set_title(f'CAV: Cor Coef {sum_H0_cav_cor:.3f}, p-Value {sum_H0_cav_pvalue:.3f}\nCBV: Cor Coef {sum_H0_cbv_cor:.3f}, p-Value {sum_H0_cbv_pvalue:.3f}')
axs.legend()

fig.savefig(f'{external_path}\\boundary_info\\topology\\boundary_topology_logit_influence.png',bbox_inches='tight')