In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import csv
from sklearn.manifold import TSNE
from tqdm import tqdm
import os
import gudhi as gd

# Base directory for the embedding files and persistence-diagram outputs.
# The cells below interpolate it as f'{external_path}\\...', i.e. the path
# components are joined with Windows-style backslashes.
external_path=''

In [None]:
# Fixed seed so the 2-D t-SNE layouts are reproducible across kernel restarts.
TSNE_SEED=0

# Load the text embeddings and convert them to NumPy arrays.
# map_location='cpu' makes tensors that were saved from a GPU loadable on any
# machine; without it .numpy() fails on a CUDA tensor.
# NOTE(review): torch.load unpickles the file, so only load .pt files that come
# from a trusted source.
bert_latents=torch.load(f'{external_path}\\bert_text_embeds_prompts.pt',map_location='cpu').detach().numpy()
clip_latents=torch.load(f'{external_path}\\clip_text_embeds_prompts.pt',map_location='cpu').detach().numpy()

# Project each embedding set to 2-D. fit_transform refits the estimator on
# every call, so the same (seeded) TSNE object is reused for both sets.
model=TSNE(n_components=2,random_state=TSNE_SEED)
bert_transformed_latents=model.fit_transform(bert_latents)
clip_transformed_latents=model.fit_transform(clip_latents)

# Persist the 2-D projections next to the source embeddings.
np.save(f'{external_path}\\bert_transformed_latents.npy',bert_transformed_latents)
np.save(f'{external_path}\\clip_transformed_latents.npy',clip_transformed_latents)

In [None]:
# Side-by-side scatter plots of the 2-D t-SNE projections, one panel per model.
fig,axs=plt.subplots(nrows=1,ncols=2,figsize=(8,4))
panels=(
    (axs[0],bert_transformed_latents,'BERT Embeddings'),
    (axs[1],clip_transformed_latents,'CLIP Embeddings'),
)
for ax,points,title in panels:
    ax.scatter(points[:,0],points[:,1],s=2)
    # The t-SNE axes carry no meaning, so the frame and ticks are hidden.
    ax.axis('off')
    ax.set_title(title)
plt.show()

In [None]:
# Read similar_from_COCO.csv into a dict of column name -> list of parsed values.
# 'Sample 1'/'Sample 2' stay strings, the two '... Index' columns become ints,
# and every other column is parsed as float.
_COLUMN_PARSERS={
    'Sample 1':str,
    'Sample 2':str,
    'Sample 1 Index':int,
    'Sample 2 Index':int,
}

# newline='' is what the csv module expects, so that line breaks inside quoted
# fields are handled by the reader rather than by the file object.
with open('similar_from_COCO.csv','r',newline='') as file:
    csvFile=csv.reader(file)
    # The first row is the header; fail loudly on an empty file instead of
    # leaving similar_COCO undefined (or stale from an earlier run).
    dictionary_keys=next(csvFile,None)
    if dictionary_keys is None:
        raise ValueError('similar_from_COCO.csv is empty: expected a header row')
    similar_COCO={key:[] for key in dictionary_keys}
    for lines in csvFile:
        for key,value in zip(dictionary_keys,lines):
            similar_COCO[key].append(_COLUMN_PARSERS.get(key,float)(value))

In [None]:
# Number of nearest neighbours (the query sample itself included, at distance 0)
# that make up each local point cloud.
N_NEIGHBOURS=128
# Upper bound on how many collision samples are newly computed per run of this cell.
MAX_NEW_SAMPLES=50
# Number of random baseline indices drawn per run of this cell.
N_RANDOM_SAMPLES=50


def _pd_path(prefix,model_name,dim,sample):
    """Return the .npy path of one persistence diagram.

    prefix is 'sample' or 'random', model_name is 'bert' or 'clip', dim is the
    homology dimension (0 or 1) and sample is the row index of the query point.
    """
    return f'{external_path}\\{prefix}_{model_name}_pd\\dim{dim}\\sample{sample}.npy'


def _neighbourhood_persistence(latents,sample,n_neighbours=N_NEIGHBOURS):
    """Compute the dimension-0 and dimension-1 Rips persistence pairs around one sample.

    The n_neighbours rows of `latents` closest (Euclidean) to row `sample` form
    the point cloud. Returns (dim_0, dim_1), each an array of (birth, death)
    pairs; dim_1 may be empty.
    """
    # One vectorised norm over all rows replaces a per-row Python loop.
    distances=np.linalg.norm(latents-latents[sample],axis=1)
    neighbourhood=latents[np.argsort(distances)[:n_neighbours],:]
    barcodes=gd.RipsComplex(points=neighbourhood).create_simplex_tree(max_dimension=2).persistence()
    dim_0=np.asarray([np.array(x) for dim,x in barcodes if dim==0])
    dim_1=np.asarray([np.array(x) for dim,x in barcodes if dim==1])
    return dim_0,dim_1


def _save_diagram(path,diagram):
    """Save one diagram with np.save, creating the target directory if it is missing."""
    directory=os.path.dirname(path)
    if directory:
        os.makedirs(directory,exist_ok=True)
    np.save(path,diagram)


def _compute_and_save_pds(samples,prefix,max_new=None):
    """Compute and cache BERT and CLIP persistence diagrams for each index in `samples`.

    Diagrams go to '<prefix>_bert_pd' and '<prefix>_clip_pd'. A sample is skipped
    only when all four of its files already exist, so an interrupted run is
    completed on the next pass. When max_new is given, the loop stops after that
    many samples have been newly computed.

    NOTE: each model's neighbourhood is taken from that model's own latents.
    Files written by an earlier version of this cell (which built the "BERT"
    neighbourhood from clip_latents) are still treated as cached; delete them
    to have them recomputed.
    """
    latents_by_model=(('bert',bert_latents),('clip',clip_latents))
    pbar=tqdm(samples)
    count=0
    for sample in pbar:
        # The cache is checked in the directory this call writes to.
        cached=all(
            os.path.exists(_pd_path(prefix,model_name,dim,sample))
            for model_name,_ in latents_by_model
            for dim in (0,1)
        )
        if cached:
            continue

        for model_name,latents in latents_by_model:
            pbar.set_description(f'Sample {sample}...computing {model_name.upper()} PD...')
            diagrams=_neighbourhood_persistence(latents,sample)
            for dim,diagram in enumerate(diagrams):
                _save_diagram(_pd_path(prefix,model_name,dim,sample),diagram)

        count+=1
        if max_new is not None and count>=max_new:
            break


# Collision points: every index that appears as either side of a similar pair.
samples=np.concatenate((np.array(similar_COCO['Sample 1 Index']),np.array(similar_COCO['Sample 2 Index'])))
_compute_and_save_pds(samples,'sample',max_new=MAX_NEW_SAMPLES)

# Random baseline: indices drawn uniformly (with replacement, unseeded) from all rows.
samples=np.random.choice(clip_latents.shape[0],size=N_RANDOM_SAMPLES)
_compute_and_save_pds(samples,'random')

In [None]:
def _total_finite_lifetime(diagram_path):
    """Load a saved diagram and sum death-birth over its bars, skipping infinite deaths."""
    diagram=np.load(diagram_path)
    return sum(death-birth for (birth,death) in diagram if death!=np.inf)


# Total H0 lifetime of each CLIP diagram saved for the collision samples.
samples=os.listdir(f'{external_path}\\sample_clip_pd\\dim0')
clip_H0_complexity=[
    _total_finite_lifetime(f'{external_path}\\sample_clip_pd\\dim0\\{sample}')
    for sample in samples
]

# Same statistic for the CLIP diagrams saved for the random baseline.
random_samples=os.listdir(f'{external_path}\\random_clip_pd\\dim0')
random_clip_H0_complexity=[
    _total_finite_lifetime(f'{external_path}\\random_clip_pd\\dim0\\{random_sample}')
    for random_sample in random_samples
]

# Overlay the two density-normalised histograms on a single axis.
fig,axs=plt.subplots(nrows=1,ncols=1)
colors=plt.cm.jet(np.linspace(0,1,2))
histograms=(
    (clip_H0_complexity,colors[0],'Collision Points'),
    (random_clip_H0_complexity,colors[1],'Random'),
)
for values,color,label in histograms:
    axs.hist(values,color=color,bins=15,alpha=0.5,density=True,label=label)
axs.set_xlabel('Sum of H0 Lifetimes')
axs.legend()
plt.show()