In [None]:
import pickle
import json

import numpy as np
import pandas as pd

from rdkit import Chem
from rdkit.Chem import Draw

from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Load the evaluation results pickled by the training callback.
# NOTE(review): pickle.load can execute arbitrary code -- only load files you produced yourself.
RESULTS_PATH = "./outputs/eval_exp_vibraclip_graph_ir_raman_mass_allpairs_09_3.pkl"
with open(RESULTS_PATH, "rb") as p_file:
    data = pickle.load(p_file)
# No explicit close() needed: leaving the `with` block closes the file.

In [None]:
# Number of target molecules (queries) in the evaluation results.
# len() on the dict directly avoids materialising a throwaway list of keys.
num_labels = len(data)
print(num_labels)

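### **1) General Retrieval Accuracy Metrics**

Top-K accuracy is the fraction of targets whose true SMILES appears among the first K retrieved candidates.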

In [None]:
# Metrics
def top_k_accuracy(retrieved, relevant_item, k):
    """Top-K hit indicator for a single query.

    Args:
        retrieved: Ranked list of candidate labels (best match first).
        relevant_item: Ground-truth label for this query.
        k: Cut-off rank.

    Returns:
        Tuple ``(acc, index)``. ``acc`` is 1.0 if ``relevant_item`` is among the
        first ``k`` entries of ``retrieved`` and 0.0 otherwise. ``index`` is the
        1-based rank of ``relevant_item`` in ``retrieved``, or ``None`` when it
        does not appear in ``retrieved`` at all.
    """
    try:
        index = retrieved.index(relevant_item) + 1
    except ValueError:
        # Target absent from the candidate list: a miss at every K
        # (previously this raised and aborted the whole evaluation).
        return 0.0, None
    acc = 1.0 if index <= k else 0.0
    return acc, index

def mean_average_precision(retrieved, relevant_item):
    """Reciprocal rank of ``relevant_item`` in ``retrieved`` for ONE query.

    With a single relevant item per query, average precision reduces to
    ``1 / rank`` (rank is 1-based). This function does not average over
    queries; the caller must take the mean across queries to obtain
    MAP / MRR.

    Args:
        retrieved: Ranked list of candidate labels (best match first).
        relevant_item: Ground-truth label for this query.

    Returns:
        ``1 / rank`` as a float, or 0.0 when ``relevant_item`` is not in
        ``retrieved``.
    """
    try:
        rank = retrieved.index(relevant_item) + 1
    except ValueError:
        return 0.0
    return 1 / rank

In [None]:
# Top-K cut-offs to evaluate, and per-K hit counters
top_k_values = [1, 5, 10, 15, 20, 25]
total_acc = {k: 0 for k in top_k_values}

# For each K, keep the queries whose target sits at *exactly* rank K;
# these are sampled later to draw example grids.
top_k_dict = {k: [] for k in top_k_values}
for smile_target, smile_cand in tqdm(data.items(), total=len(data)):
    # Candidate labels in stored order.
    # NOTE(review): assumes the callback stored candidates best-first -- confirm.
    sorted_labels = list(smile_cand.keys())

    # A target missing from its own candidate list is a miss at every K;
    # skip it rather than letting list.index() raise inside top_k_accuracy.
    if smile_target not in sorted_labels:
        continue

    for k in top_k_values:
        top_k_acc, top_k_idx = top_k_accuracy(retrieved=sorted_labels, relevant_item=smile_target, k=k)
        total_acc[k] += top_k_acc
        if top_k_acc == 1 and top_k_idx == k:
            top_k_dict[k].append({smile_target: smile_cand})

# Average accuracy per K (keys stored as strings); guard against empty results
top_k_acc_dict = {}
for k in top_k_values:
    avg_acc = total_acc[k] / num_labels if num_labels else 0.0
    top_k_acc_dict[str(k)] = avg_acc
    print(f"Average Acc. Top {k}: {avg_acc * 100:.2f}%")

In [None]:
# Per-K average accuracy (bare last expression -> Jupyter rich display)
top_k_acc_dict

In [None]:
# Retrieval accuracy vs. K.
# The dict keys are strings; matplotlib would treat them as categories and
# space K=1, 5, 10, ... evenly, so convert them back to ints for a numeric axis.
k_vals = [int(k) for k in top_k_acc_dict]
acc_vals = list(top_k_acc_dict.values())

fig, ax = plt.subplots()
ax.plot(k_vals, acc_vals, "--")
ax.scatter(k_vals, acc_vals)
ax.set_xticks(k_vals)
ax.set_xlabel("Top K")
ax.set_ylabel("Retrieval Accuracy")
ax.set_title("Top-K retrieval accuracy")
plt.show()

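### **2) Grid Molecules at Top K**

Each grid shows a randomly chosen target (top-left) whose true molecule was retrieved at *exactly* rank K, followed by its ranked candidates and their similarity scores.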

In [None]:
def _get_grid_plot(smiles_target, smiles_cand, similarity):
    """Build a figure with a target molecule above a 5x5 grid of candidates.

    The figure is a 6x5 grid of axes. The first row holds only the target
    (top-left cell). The remaining 5 rows hold up to 25 candidates in the
    given order, each captioned with its ``similarity`` value (3 decimals).
    Candidates beyond the 25th are not drawn, and unused cells are hidden.
    SMILES that RDKit cannot parse are shown as a text placeholder instead of
    aborting the whole figure.

    Args:
        smiles_target: SMILES string of the target molecule.
        smiles_cand: Sequence of candidate SMILES strings.
        similarity: Sequence of scores aligned index-by-index with ``smiles_cand``.

    Returns:
        The matplotlib ``Figure``.
    """
    def _draw_molecule(ax, smiles):
        # Chem.MolFromSmiles returns None for unparsable input and
        # Draw.MolToImage rejects None, so draw a placeholder instead.
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            ax.text(0.5, 0.5, "invalid SMILES", ha="center", va="center",
                    transform=ax.transAxes, fontsize=9)
        else:
            ax.imshow(Draw.MolToImage(mol, size=(250, 250)))
        ax.set_axis_off()

    # 1 target row + 5 candidate rows, 5 columns each (6x5 axes)
    fig, axes = plt.subplots(6, 5, figsize=(10, 10))

    # Top-left cell: the target molecule; hide the rest of the first row
    _draw_molecule(axes[0, 0], smiles_target)
    for ax in axes[0, 1:]:
        ax.set_visible(False)

    # Candidate cells, filled row by row
    for i, ax in enumerate(axes[1:, :].flat):
        if i < len(smiles_cand):
            _draw_molecule(ax, smiles_cand[i])
            ax.text(0.5, -0.1, f"{similarity[i]:.3f}", ha="center", va="top",
                    transform=ax.transAxes, fontsize=11)
        else:
            ax.set_visible(False)

    plt.tight_layout()
    return fig

In [None]:
def plot_random_top_k(top_k, top_k_dict, seed=None):
    """Plot a randomly chosen query whose target was retrieved at exactly rank ``top_k``.

    Args:
        top_k: Key into ``top_k_dict`` (the K cut-off).
        top_k_dict: Mapping ``K -> list of {target_smiles: {candidate_smiles: similarity}}``.
        seed: Optional seed for a reproducible choice. ``None`` keeps the
            original behaviour of drawing from NumPy's global RNG.

    Returns:
        The matplotlib ``Figure`` built by ``_get_grid_plot``, or ``None`` when
        the list for this K is empty (nothing to plot).

    Raises:
        KeyError: if ``top_k`` is not a key of ``top_k_dict``.
    """
    top_k_list = top_k_dict[top_k]
    if not top_k_list:
        # np.random.randint(0, 0) would raise ValueError; report and return
        # None so a Restart & Run All can continue past this cell.
        print(f"No target retrieved at exactly rank {top_k}; nothing to plot.")
        return None

    if seed is None:
        rand_idx = np.random.randint(0, len(top_k_list))
    else:
        rand_idx = int(np.random.default_rng(seed).integers(0, len(top_k_list)))
    print(f"Top-{top_k}: showing entry {rand_idx} of {len(top_k_list)}")

    # Each entry is a single-item dict {target: {candidate: similarity}}
    smile_target, smile_cands_dict = next(iter(top_k_list[rand_idx].items()))
    return _get_grid_plot(smile_target, list(smile_cands_dict.keys()), list(smile_cands_dict.values()))

In [None]:
# Example grid for a target retrieved at exactly rank 1
plot_k_1 = plot_random_top_k(top_k=1, top_k_dict=top_k_dict)
# To export the figure, uncomment:
# plot_k_1.savefig("./exp_top1.svg")


In [None]:
# Example grid for a target retrieved at exactly rank 5
plot_k_5 = plot_random_top_k(top_k=5, top_k_dict=top_k_dict)
# To export the figure, uncomment:
# plot_k_5.savefig("./exp_top5.svg")

In [None]:
# Example grid for a target retrieved at exactly rank 10
plot_k_10 = plot_random_top_k(top_k=10, top_k_dict=top_k_dict)
# To export the figure, uncomment:
# plot_k_10.savefig("./exp_top10.svg")

In [None]:
# Example grid for a target retrieved at exactly rank 15
plot_k_15 = plot_random_top_k(top_k=15, top_k_dict=top_k_dict)
# To export the figure, uncomment:
# plot_k_15.savefig("./exp_top15.svg")

In [None]:
# Example grid for a target retrieved at exactly rank 20
plot_k_20 = plot_random_top_k(top_k=20, top_k_dict=top_k_dict)
# To export the figure, uncomment:
# plot_k_20.savefig("./exp_top20.svg")

In [None]:
# Example grid for a target retrieved at exactly rank 25
plot_k_25 = plot_random_top_k(top_k=25, top_k_dict=top_k_dict)
# To export the figure, uncomment:
# plot_k_25.savefig("./exp_top25.svg")