In [None]:
import numpy as np
import pandas as pd
import pickle
from scipy import stats
from collections import Counter
from tqdm.notebook import tqdm

def load_data(file):
    """Unpickle and return the single object stored in ``file``.

    The path is printed before reading. Only use this on trusted files:
    unpickling can execute arbitrary code.
    """
    print('loading file: ' + file)
    with open(file, 'rb') as handle:
        return pickle.load(handle)

def dump_data(data, filename):
    """Pickle ``data`` into ``filename`` with the highest available protocol.

    The target path is printed before writing.
    """
    print('writing file: ' + filename)
    protocol = pickle.HIGHEST_PROTOCOL
    with open(filename, 'wb') as handle:
        pickle.dump(data, handle, protocol)
        
def upper(df):
    """Return the strict upper triangle (diagonal excluded) of a matrix as a 1-D array.

    Elements are taken in ``np.triu_indices(n, k=1)`` order (row by row),
    where ``n`` is the size of the first dimension. For a square RDM this is
    the usual vectorisation of its ``n * (n - 1) / 2`` unique off-diagonal entries.

    Parameters
    ----------
    df : np.ndarray or pd.DataFrame
        The matrix. A DataFrame is converted via ``.values``; ndarray
        subclasses are accepted and viewed as plain ndarrays.

    Returns
    -------
    np.ndarray
        The entries above the main diagonal.

    Raises
    ------
    TypeError
        If ``df`` is neither an ``np.ndarray`` nor a ``pd.DataFrame``.
    """
    # Explicit isinstance checks instead of try/assert/bare-except: asserts are
    # stripped under `python -O`, and a bare except hides unrelated errors.
    if isinstance(df, pd.DataFrame):
        df = df.values
    elif not isinstance(df, np.ndarray):
        raise TypeError('Must be np.ndarray or pd.DataFrame')
    mask = np.triu_indices(df.shape[0], k=1)
    # asarray turns subclasses (e.g. np.matrix) into a plain ndarray so the
    # fancy index yields a flat 1-D result.
    return np.asarray(df)[mask]


In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.backends.backend_pdf import PdfPages

# Global figure style for every plot in this notebook: no top/right spines,
# viridis colour map, bold serif text, thick axes, and a smaller legend font.
plt.rcParams.update({
    'axes.spines.right': False,
    'axes.spines.top': False,
    'image.cmap': 'viridis',
    'font.family': 'serif',  # use serif/main font for text elements
    'font.size': 20,
    'font.weight': 'bold',
    'axes.linewidth': 2.5,
    'axes.labelweight': 'bold',
    'axes.labelsize': 20,
    'legend.fontsize': 12,
})

def plot_images(paths):
    """Show images side by side in a single row, each titled with its file name.

    Parameters
    ----------
    paths : sequence of str
        Image file paths readable by ``matplotlib.image.imread``. Each title is
        the part after the last "/" with its final extension removed.

    An empty ``paths`` draws nothing.
    """
    num_images = len(paths)
    if num_images == 0:
        # plt.subplots(1, 0) raises; there is nothing to draw.
        return

    # squeeze=False keeps `axes` 2-D even for a single image. With the default
    # squeeze=True, subplots(1, 1) returns a bare Axes and indexing it fails.
    fig, axes = plt.subplots(1, num_images, figsize=(num_images * 5, 5), squeeze=False)

    for ax, path in zip(axes[0], paths):
        # rsplit strips only the extension, so dotted names keep their stem.
        im_name = path.split("/")[-1].rsplit(".", 1)[0]
        ax.imshow(mpimg.imread(path))
        ax.set_title(im_name, fontsize=16)
        ax.axis('off')

    fig.tight_layout()
    plt.show()
    

def plot_images_to_pdf(paths, pdf_filename='images.pdf'):
    """Write one PDF page per row of images, laid out side by side without axes.

    Parameters
    ----------
    paths : iterable of sequences of str
        Each inner sequence of image paths becomes one page. Empty rows are
        skipped.
    pdf_filename : str, default 'images.pdf'
        Path of the PDF to write.
    """
    with PdfPages(pdf_filename) as pdf:
        for row in paths:
            n_cols = len(row)
            if n_cols == 0:
                # plt.subplots(1, 0) raises; nothing to put on the page.
                continue
            # squeeze=False: a one-image row still yields indexable axes.
            fig, axes = plt.subplots(1, n_cols, figsize=(n_cols * 5, 5), squeeze=False)
            try:
                for ax, path in zip(axes[0], row):
                    ax.imshow(mpimg.imread(path))
                    ax.axis('off')
                pdf.savefig(fig)
            finally:
                # Release the figure even if an image fails to load.
                plt.close(fig)

In [None]:
# Load the CLIP embedding file once; the image/text embedding arrays are the
# corresponding entries of that same loaded object.
TOP_CLIP_FILE = "top_CLIP_vis.pkl"
top_file = load_data(TOP_CLIP_FILE)
top_image_embeddings = top_file["image_embeddings"]
top_txt_embeddings = top_file["text_embeddings"]

In [None]:

# For each cluster, count the "bigger_concept" labels of its members (ignoring
# the "None" placeholder string) and print the most frequent one with its count.
clusters = top_file['cluster']
concepts = top_file["bigger_concept"]
cluster_categories = dict()
for cl in np.unique(clusters):
    members = clusters == cl
    word_counts = Counter(x for x in concepts[members] if x != "None")
    cluster_categories[cl] = word_counts

    if not word_counts:
        # Every member is unlabelled; most_common(1) would be an empty list.
        print(f"Cluster {cl}: no concept labels")
        continue

    most_common_word, count = word_counts.most_common(1)[0]
    print(f"Cluster {cl}: {most_common_word}/{count}")

In [None]:
# For every query image, greedily pick N_EXTRA images from *other* clusters.
# Each pick maximises the summed vision + text correlation distance to the
# query and to the images already picked. Results are keyed by the query's
# file name and pickled.

N_EXTRA = 3  # images selected in addition to the query


def _greedy_max_dissimilar(rdm_a, rdm_b, candidates, n_extra=N_EXTRA):
    """Greedy farthest-point selection over two dissimilarity matrices.

    Parameters
    ----------
    rdm_a, rdm_b : np.ndarray
        Square dissimilarity matrices indexed by global item index.
    candidates : np.ndarray of int
        Global indices to choose from; ``candidates[0]`` is the seed item.
    n_extra : int
        Number of items to add after the seed.

    Returns
    -------
    list of int
        Positions into ``candidates``: 0 (the seed) followed by up to
        ``n_extra`` further positions. There are fewer only if there are
        fewer candidates, and positions never repeat.
    """
    selected = [0]
    # Accumulate each modality separately and add them afterwards, i.e. the
    # summed vision distance plus the summed text distance.
    acc_a = np.zeros(len(candidates))
    acc_b = np.zeros(len(candidates))
    for _ in range(min(n_extra, len(candidates) - 1)):
        last = candidates[selected[-1]]
        acc_a += rdm_a[last, candidates]
        acc_b += rdm_b[last, candidates]
        joint = acc_a + acc_b  # fresh array: masking below leaves acc_* intact
        joint[selected] = -np.inf  # never re-pick an already selected item
        selected.append(int(np.argmax(joint)))
    return selected


data_struct = {}
vis = top_image_embeddings
vis_cluster = top_file["cluster"]

txt = top_txt_embeddings
stim_paths = top_file["stimuli_paths"]
stim = [x.split("/")[-1] for x in stim_paths]
stim_paths_arr = np.asarray(stim_paths)  # allows fancy indexing below

# Pearson r between two rows depends only on those two rows. Computing the
# full correlation-distance RDMs once and slicing them per query therefore
# gives the same distances as recomputing corrcoef on each subset.
rdm_vis_all = 1 - np.corrcoef(vis)
rdm_txt_all = 1 - np.corrcoef(txt)

for q in tqdm(range(len(vis))):
    # Candidates: the query first (it seeds the selection), then every image
    # from a different cluster.
    other_indexes = np.where(vis_cluster != vis_cluster[q])[0]
    other_indexes = np.insert(other_indexes, 0, q)

    selected = _greedy_max_dissimilar(rdm_vis_all, rdm_txt_all, other_indexes)
    chosen = other_indexes[selected]

    data_struct[stim[q]] = {
        "stimuli_paths": stim_paths_arr[chosen],
        "indexes": chosen,
    }

dump_data(data_struct, "selections.pickle")



In [None]:
# Spearman correlation between the vision and text RDMs of each selection.
rhos = []
for key in tqdm(data_struct.keys()):
    indexes = data_struct[key]["indexes"]
    rdm_vis = 1 - np.corrcoef(vis[indexes])
    rdm_txt = 1 - np.corrcoef(txt[indexes])
    res = stats.spearmanr(upper(rdm_vis), upper(rdm_txt))
    rhos.append(res.statistic)
rhos = np.array(rhos)

# Keep selections whose vision and text geometries are only weakly related
# (|rho| <= MAX_ABS_RHO), ordered from most to least decorrelated. The stable
# sort keeps dict order among ties; NaN rhos fail the comparison and are dropped.
MAX_ABS_RHO = 0.2
abs_rhos = np.abs(rhos)
sorted_indexes = np.argsort(abs_rhos, kind="stable")
sorted_indexes = sorted_indexes[abs_rhos[sorted_indexes] <= MAX_ABS_RHO]
sorted_keys = np.array(list(data_struct.keys()))[sorted_indexes]

In [None]:
# Preview every retained selection as one row of images.
for key in sorted_keys:
    selection_paths = data_struct[key]["stimuli_paths"]
    plot_images(selection_paths)

In [None]:
# One entry per retained selection; plot_images_to_pdf renders each entry
# (a row of stimulus paths) as its own PDF page.
all_paths = [data_struct[key]["stimuli_paths"] for key in sorted_keys]

# Flat list of every stimulus path across selections (kept for inspection;
# the PDF export below uses the per-selection rows).
flattened_paths = [path for row in all_paths for path in row]
plot_images_to_pdf(all_paths, 'selected_images_txt.pdf')


In [None]:
# Restrict every field of top_file to the items of one example selection
# (in selection order) and save the result.
key = 'bamboo_01b.jpg'
indexes = data_struct[key]["indexes"]
new_dict = {
    field: [values[i] for i in indexes]
    for field, values in top_file.items()
}
dump_data(new_dict, "exp_stimuli_selection.pkl")

In [None]:
# Tabulate the selected stimuli. The file name is the last path component;
# the category is the file-name prefix before the first "_".
stim_names = [p.split("/")[-1] for p in new_dict['stimuli_paths']]
df = pd.DataFrame({
    'cluster': new_dict['cluster'],
    'stimuli_paths': new_dict['stimuli_paths'],
    'stim_name': stim_names,
    'category': [name.split("_")[0] for name in stim_names],
    'bigger_concept': new_dict['bigger_concept'],
})

# Write the table (including the default integer index) to CSV.
df.to_csv("exp_stimuli_desc.csv")

In [None]:
# Example selection: CLIP text and vision RDMs side by side, a scatter of
# their pairwise entries, and the stimuli themselves.


def _plot_rdm(ax, rdm, labels, title):
    """Draw one RDM as a labelled heat map with its own colour bar.

    The colour scale is fixed to [0, 1], so correlation distances above 1
    (negative correlations) saturate at the top colour.
    """
    im = ax.imshow(rdm, vmin=0.0, vmax=1.0, cmap="viridis")
    ax.set_title(title, fontsize=12)
    ax.set_xticks(range(len(labels)))
    ax.set_yticks(range(len(labels)))
    ax.set_xticklabels(labels, rotation=90, fontsize=8)
    ax.set_yticklabels(labels, fontsize=8)
    ax.figure.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    return im


key = 'bamboo_01b.jpg'
indexes = data_struct[key]["indexes"]
# Derive tick labels from this selection's own paths so they always match
# `indexes`, whichever key is chosen above.
labels = [p.split("/")[-1].split("_")[0] for p in data_struct[key]["stimuli_paths"]]

# Correlation-distance RDMs (1 - Pearson r) for vision and text embeddings
rdm_vis = 1 - np.corrcoef(vis[indexes])
rdm_txt = 1 - np.corrcoef(txt[indexes])

# Rank (Spearman) and linear (Pearson) agreement between the two RDMs
res = stats.spearmanr(upper(rdm_vis), upper(rdm_txt))
corr = np.corrcoef(upper(rdm_vis), upper(rdm_txt))[0, 1]

fig, axes = plt.subplots(1, 3, figsize=(12, 5))
_plot_rdm(axes[0], rdm_txt, labels, "CLIP Text RDM")
_plot_rdm(axes[1], rdm_vis, labels, "CLIP Vision RDM")

# Pairwise dissimilarities: text on x, vision on y
axes[2].scatter(upper(rdm_txt), upper(rdm_vis))
axes[2].grid(True)
axes[2].set_ylabel("CLIP Vis Dissim. (1 - r)", fontsize=12)
axes[2].set_xlabel("CLIP Txt Dissim. (1 - r)", fontsize=12)

# Report both correlation coefficients in the figure title
fig.suptitle(f"Spearman Correlation (RDMs): {res.statistic:.2f}\n Pearson's Correlation (RDMs): {corr:.2f}", fontsize=24)
fig.tight_layout(rect=[0, 0.03, 1, 0.95])
# NOTE(review): absolute, machine-specific output path; consider a configurable figures directory.
save_path = "/projects/crunchie/boyanova/EEG_Things/eeg_prep/figures/00_rdm_desc.png"
fig.savefig(save_path, dpi=300, bbox_inches="tight")
plt.show()

# Plot the stimuli images
plot_images(data_struct[key]["stimuli_paths"])
