In [1]:
import numpy as np
import pandas as pd
import os 
import pickle
from collections import Counter
from tqdm.notebook import tqdm

def load_data(file):
    """Load and return the object stored in a pickle file.

    Parameters
    ----------
    file : str or os.PathLike
        Path of the pickle file to read.

    Returns
    -------
    object
        Whatever object was pickled into ``file``.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code while unpickling; only use
    this on files from a trusted source.
    """
    # f-string instead of `+` so pathlib.Path (or any path-like) also works.
    print(f'loading file: {file}')
    with open(file, 'rb') as f:
        data = pickle.load(f)

    return data

def dump_data(data, filename):
    """Pickle ``data`` to ``filename`` using the highest pickle protocol.

    Parameters
    ----------
    data : object
        Any picklable object.
    filename : str or os.PathLike
        Destination path; an existing file is overwritten.
    """
    # f-string instead of `+` so pathlib.Path (or any path-like) also works.
    print(f'writing file: {filename}')
    with open(filename, 'wb') as f:
        pickle.dump(data, f, pickle.HIGHEST_PROTOCOL)
        
def upper(df):
    """Return the strictly-upper-triangular entries of a square matrix.

    Parameters
    ----------
    df : np.ndarray or pd.DataFrame
        Matrix to read; a DataFrame is converted with ``.values`` first.
        The index grid is built from the number of rows, so the input is
        expected to be square.

    Returns
    -------
    np.ndarray
        1-D array of the entries above the main diagonal (``k=1``), in
        row-major order.

    Raises
    ------
    TypeError
        If ``df`` is neither an ndarray nor a DataFrame.
    """
    # Explicit isinstance checks: the type test must not rely on `assert`
    # (stripped under `python -O`) nor on a bare `except`.
    if isinstance(df, pd.DataFrame):
        df = df.values
    elif not isinstance(df, np.ndarray):
        raise TypeError('Must be np.ndarray or pd.DataFrame')
    mask = np.triu_indices(df.shape[0], k=1)
    return df[mask]


In [2]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.backends.backend_pdf import PdfPages

# Global matplotlib styling for every figure in this notebook, applied in a
# single update of the shared rcParams mapping.
mpl.rcParams.update({
    # hide the right/top frame lines
    'axes.spines.right': False,
    'axes.spines.top': False,
    # default colormap for imshow & co.
    'image.cmap': 'viridis',
    # use serif/main font for text elements
    'font.family': 'serif',
    'font.size': 20,
    'font.weight': 'bold',
    # heavier axes and bold, larger axis labels
    'axes.linewidth': 2.5,
    'axes.labelweight': 'bold',
    'axes.labelsize': 20,
    # legends use a smaller font than the rest of the text
    'legend.fontsize': 12,
})

def plot_images(paths):
    """Show the images at ``paths`` side by side in one row.

    Each panel is titled with the file's base name up to its first dot and
    drawn without axes.

    Parameters
    ----------
    paths : sequence of str
        Image file paths using ``/`` as separator. May be a list or a
        NumPy array of strings. An empty sequence draws nothing.
    """
    num_images = len(paths)
    if num_images == 0:
        # plt.subplots rejects a grid with zero columns; nothing to draw.
        return

    # squeeze=False keeps `axes` two-dimensional even for a single image;
    # otherwise subplots returns a bare Axes and indexing it fails.
    fig, axes = plt.subplots(
        1, num_images, figsize=(num_images * 5, 5), squeeze=False
    )

    for ax, path in zip(axes[0], paths):
        im_name = path.split("/")[-1].split(".")[0]
        img = mpimg.imread(path)
        ax.imshow(img)
        ax.set_title(im_name, fontsize=16)
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    

def plot_images_to_pdf(paths, pdf_filename='images.pdf'):
    """Write rows of images to a multi-page PDF, one page per row.

    Parameters
    ----------
    paths : iterable of sequences of str
        Each inner sequence holds the image paths drawn side by side on
        one page. Empty rows are skipped.
    pdf_filename : str, optional
        Output PDF path (default ``'images.pdf'``).
    """
    with PdfPages(pdf_filename) as pdf:
        for row in paths:
            if len(row) == 0:
                # plt.subplots rejects a grid with zero columns.
                continue
            # squeeze=False keeps `axes` two-dimensional so a one-image row
            # can be indexed the same way as a longer one.
            fig, axes = plt.subplots(
                1, len(row), figsize=(len(row) * 5, 5), squeeze=False
            )
            try:
                for ax, path in zip(axes[0], row):
                    img = mpimg.imread(path)
                    ax.imshow(img)
                    ax.axis('off')
                pdf.savefig(fig)
            finally:
                # Always release the figure, even if an image fails to load.
                plt.close(fig)

In [None]:
# Project location and the stimulus class whose file is analysed.
project_dir = "/projects/crunchie/boyanova/EEG_Things/Grouping-Embeddings"
class_type = "animate"

# Read the pickled bundle for this class and pull out the two embedding
# entries used by the rest of the notebook.
top_file = load_data(
    os.path.join(project_dir, "files", class_type, "top25_CLIP_vis_blip.pickle")
)
top_image_embeddings, top_txt_embeddings = (
    top_file["image_embeddings"],
    top_file["text_embeddings"],
)

In [None]:
# Summarise, per cluster, how often each "bigger_concept" label occurs and
# print the most frequent one. `cluster_categories` maps cluster id -> Counter.
# NOTE(review): this cell reads top_file['cluster'] while the selection cell
# further down reads top_file["clusters"] -- confirm which key the pickle has.
clusters = top_file['cluster']
bigger_concepts = top_file["bigger_concept"]
cluster_categories = dict()
for cl in np.unique(clusters):
    indices = np.where(clusters == cl)

    # Count occurrences of each concept, ignoring the "None" placeholder
    concept = [x for x in bigger_concepts[indices] if x != "None"]
    word_counts = Counter(concept)
    cluster_categories[cl] = word_counts

    if not word_counts:
        # Every entry was "None": there is no most-common word to report.
        print(f"Cluster {cl}: no labelled concept")
        continue

    # Find the most frequent word
    most_common_word, count = word_counts.most_common(1)[0]
    print(f"Cluster {cl}: {most_common_word}/{count}")

In [None]:
# Greedy selection of a small stimulus set for every query image.
#
# For each image q, candidates are all images in q's cluster. N_SELECT items
# are then picked one at a time, each minimising the summed visual + textual
# correlation distance to the query and to everything picked so far.
# Result: data_struct[<file name>] = {"stimuli_paths", "indexes_masked"}.
N_SELECT = 4        # picks per query (the original code did 1 + 3)
MASK_VALUE = 1000   # sentinel distance, far above any 1 - r value (max 2)

data_struct = {}
vis = top_image_embeddings
# NOTE(review): the cluster-summary cell reads top_file['cluster'] (singular);
# confirm which key actually exists in the pickle.
vis_cluster = top_file["clusters"]

txt = top_txt_embeddings
txt_cluster = top_file["clusters"]
stim_paths = top_file["stimuli_paths"]
stim = [x.split("/")[-1] for x in stim_paths]

for q in tqdm(range(len(vis))):

    # Step 1: Get query
    data_struct[stim[q]] = {}
    vis_cluster_q = vis_cluster[q]

    # Step 2: Keep all images from the same cluster, query placed first.
    # NOTE(review): no category filter is applied here, so q is itself part
    # of the np.where() result and ends up in the candidate list twice
    # (row 0 and once more). The duplicate is at distance ~0 from row 0 and
    # is therefore expected to be the first pick -- confirm this is intended.
    overlap_indexes = np.where(vis_cluster == vis_cluster_q)[0]
    overlap_indexes = np.insert(overlap_indexes, 0, q, axis=0)

    # Step 3: Compute correlation-distance RDMs over the candidates
    rdm_vis = 1 - np.corrcoef(vis[overlap_indexes])
    rdm_txt = 1 - np.corrcoef(txt[overlap_indexes])

    # Make the diagonal high so an item is never its own nearest neighbour
    np.fill_diagonal(rdm_vis, MASK_VALUE)
    np.fill_diagonal(rdm_txt, MASK_VALUE)

    # Step 4: Greedy selection.
    # Copy the query rows: a plain slice is a view, and the in-place updates
    # below would otherwise write back into the RDMs themselves.
    vec1_vis = rdm_vis[0].copy()
    vec1_txt = rdm_txt[0].copy()
    selected = []
    mask = np.zeros(len(overlap_indexes), dtype=bool)

    for _ in range(N_SELECT):
        # Already-selected candidates are pinned to the sentinel distance
        vec1_vis[mask] = MASK_VALUE
        vec1_txt[mask] = MASK_VALUE
        joint = vec1_vis + vec1_txt

        min_ = np.argmin(joint)
        selected.append(min_)
        mask[min_] = True

        # Accumulate distances to the newly selected item for the next round
        vec1_vis += rdm_vis[min_]
        vec1_txt += rdm_txt[min_]

    data_struct[stim[q]]["stimuli_paths"] = stim_paths[overlap_indexes[selected]]
    data_struct[stim[q]]["indexes_masked"] = overlap_indexes[selected]

#dump_data(data_struct, "selections_txt.pickle")



In [None]:
from scipy import stats

# For every selected stimulus set, correlate (Spearman) the upper triangles
# of its visual and textual correlation-distance RDMs.
rhos = []
for key, entry in tqdm(data_struct.items()):
    indexes = entry["indexes_masked"]
    rdm_vis = 1 - np.corrcoef(vis[indexes])
    rdm_txt = 1 - np.corrcoef(txt[indexes])
    res = stats.spearmanr(upper(rdm_vis), upper(rdm_txt))
    rhos.append(res.statistic)

# One rho per key of data_struct, in insertion order; sorted_indexes orders
# them by ascending |rho|.
rhos = np.array(rhos)
sorted_indexes = np.argsort(np.abs(rhos))
    

In [None]:
# Show the stimulus sets of the first ten queries with |rho| >= 0.2.
# Despite its name, `sorted_keys` keeps data_struct's insertion order; it is
# filtered by the threshold, not sorted by rho.
strong = np.abs(rhos) >= 0.2
sorted_keys = np.array(list(data_struct.keys()))[strong]
for key in sorted_keys[:10]:
    plot_images(data_struct[key]["stimuli_paths"])
    
    

In [None]:
# Detailed figure for one stimulus set whose visual and textual RDMs agree
# strongly (|rho| > 0.9): both RDMs, a scatter of their upper triangles, and
# the stimulus images themselves.
sorted_keys = np.array(list(data_struct.keys()))[np.where(np.abs(rhos) > 0.9)]
if len(sorted_keys) < 2:
    # The example below is the *second* qualifying key; fail with a clear
    # message instead of an anonymous out-of-bounds error.
    raise IndexError(
        f"Need at least two stimulus sets with |rho| > 0.9, found {len(sorted_keys)}"
    )
key = sorted_keys[1]
indexes = data_struct[key]["indexes_masked"]

# Calculate the RDMs (1 - Pearson r) for vision and text embeddings
rdm_vis = 1 - np.corrcoef(vis[indexes])
rdm_txt = 1 - np.corrcoef(txt[indexes])

# Upper triangles are reused for both statistics and the scatter plot
rdm_vis_upper = upper(rdm_vis)
rdm_txt_upper = upper(rdm_txt)

# Spearman and Pearson correlation between the two RDMs
res = stats.spearmanr(rdm_vis_upper, rdm_txt_upper)
corr = np.corrcoef(rdm_vis_upper, rdm_txt_upper)[0, 1]


# Plotting RDMs side by side
fig, axes = plt.subplots(1, 3, figsize=(12, 5))

# Plotting CLIP text RDM
# NOTE(review): the colour scale is clipped to [0, 1] although 1 - r can
# reach 2 -- confirm that is intended.
im1 = axes[0].imshow(rdm_txt, vmin=0.0, vmax=1.0, cmap="viridis")
axes[0].set_title("CLIP Text RDM", fontsize = 12)
fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)

# Plotting CLIP vision RDM
im2 = axes[1].imshow(rdm_vis, vmin=0.0, vmax=1.0, cmap="viridis")
axes[1].set_title("CLIP Vision RDM", fontsize = 12)
fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)

# Plotting scatter
# NOTE(review): the plotted values are distances (1 - r) although the axis
# labels say "Corr".
axes[2].scatter(rdm_txt_upper, rdm_vis_upper)
axes[2].grid(True)
axes[2].set_ylabel("CLIP Vis Corr", fontsize = 12)
axes[2].set_xlabel("CLIP Txt Corr", fontsize = 12)

# Display both correlation coefficients in the figure title
fig.suptitle(f"Spearman Correlation (RDMs): {res.statistic:.2f}\n Pearson's Correlation (RDMs): {corr:.2f}", fontsize=24)
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
save_path = "/projects/crunchie/boyanova/EEG_Things/eeg_prep/figures/00_rdm_desc.png"
# Make sure the target directory exists so savefig does not fail
os.makedirs(os.path.dirname(save_path), exist_ok=True)
fig.savefig(save_path, dpi=300, bbox_inches="tight")
plt.show()

# Plot the stimuli images
plot_images(data_struct[key]["stimuli_paths"])