In [26]:
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
from map_mhr_mrr import calculate_metric
import faiss
from prettytable import PrettyTable
from collections import defaultdict
from models.gcn_molclr import GCN
from dataset.dataset_contrastive import USPTO50_contrastive
from torch_geometric.loader import DataLoader

In [27]:
uspto_triplets_dataset_original = pd.read_pickle('dataset/uspto_50_retrieval.pickle')
# USPTO_triplets_dataclass = USPTO50_contrastive(uspto_triplets_dataset_original, return_index=True, split='all')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
uspto_triplets_dataset_embeddings = uspto_triplets_dataset_original.copy()
USPTO_triplets_dataclass = USPTO50_contrastive(uspto_triplets_dataset_original, return_index=True, split='all')
uspto_triplets_dataset_original

Unnamed: 0,reactants_mol,products_mol,reaction_type,set,exclude_indices
0,<rdkit.Chem.rdchem.Mol object at 0x7fe779f5d760>,<rdkit.Chem.rdchem.Mol object at 0x7fe93cc98180>,<RX_1>,train,"[0, 1]"
1,<rdkit.Chem.rdchem.Mol object at 0x7fe7749fcc20>,<rdkit.Chem.rdchem.Mol object at 0x7fe93cc98180>,<RX_1>,train,"[0, 1]"
2,<rdkit.Chem.rdchem.Mol object at 0x7fe779f5ddf0>,<rdkit.Chem.rdchem.Mol object at 0x7fe93cc981d0>,<RX_6>,train,[2]
3,<rdkit.Chem.rdchem.Mol object at 0x7fe779f5cc70>,<rdkit.Chem.rdchem.Mol object at 0x7fe93cc98220>,<RX_9>,train,"[3, 4]"
4,<rdkit.Chem.rdchem.Mol object at 0x7fe779f5f880>,<rdkit.Chem.rdchem.Mol object at 0x7fe93cc98220>,<RX_9>,train,"[3, 4]"
...,...,...,...,...,...
85533,<rdkit.Chem.rdchem.Mol object at 0x7fe93cc6ffb0>,<rdkit.Chem.rdchem.Mol object at 0x7fe930ca9260>,<RX_7>,test,[85533]
85534,<rdkit.Chem.rdchem.Mol object at 0x7fe93cc78040>,<rdkit.Chem.rdchem.Mol object at 0x7fe930ca92b0>,<RX_10>,test,"[85534, 85535]"
85535,<rdkit.Chem.rdchem.Mol object at 0x7fe93cac8090>,<rdkit.Chem.rdchem.Mol object at 0x7fe930ca92b0>,<RX_10>,test,"[85534, 85535]"
85536,<rdkit.Chem.rdchem.Mol object at 0x7fe93cc980e0>,<rdkit.Chem.rdchem.Mol object at 0x7fe930ca9300>,<RX_1>,test,"[85536, 85537]"


In [28]:
def update_exclude_indices(df, old_to_new_index):
    df['exclude_indices'] = df['exclude_indices'].apply(lambda x: [old_to_new_index.get(i, i) for i in x])
    return df

old_to_new_index = {old_index: new_index for new_index, old_index in enumerate(uspto_triplets_dataset_original[uspto_triplets_dataset_original['set'] == 'test'].index)}

uspto_triplets_dataset_original_filtered = uspto_triplets_dataset_original[uspto_triplets_dataset_original['set'] == 'test'].reset_index(drop=True)
uspto_triplets_dataset_original_filtered = update_exclude_indices(uspto_triplets_dataset_original_filtered, old_to_new_index)
uspto_triplets_dataset_embeddings = uspto_triplets_dataset_original_filtered.copy()

USPTO_triplets_dataclass = USPTO50_contrastive(uspto_triplets_dataset_original_filtered, return_index=True, split='all')
uspto_triplets_dataset_original_filtered

Unnamed: 0,reactants_mol,products_mol,reaction_type,set,exclude_indices
0,<rdkit.Chem.rdchem.Mol object at 0x7fe93d098310>,<rdkit.Chem.rdchem.Mol object at 0x7fe930f9b150>,<RX_1>,test,"[0, 1]"
1,<rdkit.Chem.rdchem.Mol object at 0x7fe93d098360>,<rdkit.Chem.rdchem.Mol object at 0x7fe930f9b150>,<RX_1>,test,"[0, 1]"
2,<rdkit.Chem.rdchem.Mol object at 0x7fe93d0983b0>,<rdkit.Chem.rdchem.Mol object at 0x7fe930fb31a0>,<RX_4>,test,"[2, 3]"
3,<rdkit.Chem.rdchem.Mol object at 0x7fe93d098400>,<rdkit.Chem.rdchem.Mol object at 0x7fe930fb31a0>,<RX_4>,test,"[2, 3]"
4,<rdkit.Chem.rdchem.Mol object at 0x7fe93cee0450>,<rdkit.Chem.rdchem.Mol object at 0x7fe930fb31f0>,<RX_2>,test,"[4, 5]"
...,...,...,...,...,...
8558,<rdkit.Chem.rdchem.Mol object at 0x7fe93cc6ffb0>,<rdkit.Chem.rdchem.Mol object at 0x7fe930ca9260>,<RX_7>,test,[8558]
8559,<rdkit.Chem.rdchem.Mol object at 0x7fe93cc78040>,<rdkit.Chem.rdchem.Mol object at 0x7fe930ca92b0>,<RX_10>,test,"[8559, 8560]"
8560,<rdkit.Chem.rdchem.Mol object at 0x7fe93cac8090>,<rdkit.Chem.rdchem.Mol object at 0x7fe930ca92b0>,<RX_10>,test,"[8559, 8560]"
8561,<rdkit.Chem.rdchem.Mol object at 0x7fe93cc980e0>,<rdkit.Chem.rdchem.Mol object at 0x7fe930ca9300>,<RX_1>,test,"[8561, 8562]"


In [29]:
def get_GCN_model(path):
    gcn_model = GCN(feat_dim=512)
    if path is not "random_weights":
        try:
            gcn_model.load_state_dict(torch.load(path))
        except:
            gcn_model.load_state_dict(torch.load(path, map_location='cuda:0'))
    gcn_model.eval()
    gcn_model.to(device)
    
    return gcn_model

In [30]:
model_checkpoint_paths = {
    "triplet_from_scratch": "ckpt/TripletMarginFromScratchCheckpoints/checkpoints/model.pth",
    "pretrained": "ckpt/pretrained_gcn/checkpoints/model.pth",
    "triplet_loss_finetuned": "ckpt/TripletMarginCheckpoints/checkpoints/model.pth",
    "triplet_cosine_loss_finetuned": "ckpt/TripletMarginCosineDistanceCheckpoints/checkpoints/model.pth",
    "random_weights": "random_weights"
    # triplet_loss_trained: ,
    # triplet_cosine_loss_trained: ,
}

GCN_models = defaultdict()

for model_checkpoint_path in model_checkpoint_paths.keys():
    GCN_models[model_checkpoint_path] = get_GCN_model(model_checkpoint_paths[model_checkpoint_path])

In [31]:
# uspto_graph_retrieval_dataloader = DataLoader(USPTO_triplets_dataclass, batch_size=32, shuffle=False, num_workers=8, pin_memory=True)

In [32]:
def save_embeddings_to_dataframe(model_name, model):
    uspto_graph_retrieval_dataloader = DataLoader(USPTO_triplets_dataclass, batch_size=32, shuffle=False, num_workers=36, pin_memory=True)
    print("Getting embeddings for ", model_name)
    reactants_embedding = []
    products_embedding = []

    with torch.no_grad():
        for bn, (anchor, positive, negative, index) in enumerate(tqdm(uspto_graph_retrieval_dataloader)):
            anchor = anchor.pin_memory().to(device, non_blocking=True)
            positive = positive.pin_memory().to(device, non_blocking=True)

            _, anchor_embedding = model(anchor)
            _, positive_embedding = model(positive)

            reactants_embedding.extend(positive_embedding.cpu().detach().numpy())
            products_embedding.extend(anchor_embedding.cpu().detach().numpy())

    uspto_triplets_dataset_embeddings[model_name + "_reactants_embedding"] = reactants_embedding
    uspto_triplets_dataset_embeddings[model_name + "_products_embedding"] = products_embedding

    return reactants_embedding, products_embedding

In [33]:
for model_name in GCN_models.keys():
    reactants_embedding, products_embedding = save_embeddings_to_dataframe(model_name, GCN_models[model_name])

Getting embeddings for  triplet_from_scratch


100%|██████████| 268/268 [00:09<00:00, 27.14it/s]


Getting embeddings for  pretrained


100%|██████████| 268/268 [00:10<00:00, 25.99it/s]


Getting embeddings for  triplet_loss_finetuned


100%|██████████| 268/268 [00:10<00:00, 26.55it/s]


Getting embeddings for  triplet_cosine_loss_finetuned


100%|██████████| 268/268 [00:10<00:00, 26.40it/s]


Getting embeddings for  random_weights


100%|██████████| 268/268 [00:10<00:00, 26.75it/s]


In [34]:
uspto_triplets_dataset_embeddings

Unnamed: 0,reactants_mol,products_mol,reaction_type,set,exclude_indices,triplet_from_scratch_reactants_embedding,triplet_from_scratch_products_embedding,pretrained_reactants_embedding,pretrained_products_embedding,triplet_loss_finetuned_reactants_embedding,triplet_loss_finetuned_products_embedding,triplet_cosine_loss_finetuned_reactants_embedding,triplet_cosine_loss_finetuned_products_embedding,random_weights_reactants_embedding,random_weights_products_embedding
0,<rdkit.Chem.rdchem.Mol object at 0x7fe93d098310>,<rdkit.Chem.rdchem.Mol object at 0x7fe930f9b150>,<RX_1>,test,"[0, 1]","[0.07329723, -0.036533702, -0.13493162, -0.121...","[0.062163208, -0.045985818, -0.15452383, -0.11...","[-0.0001417166, -2.6986338e-05, -9.015823e-05,...","[-0.00011566309, -5.8518326e-05, -9.1130445e-0...","[-0.0016309423, -0.0010638223, 7.022045e-05, -...","[-0.0014088374, -0.0008960876, 0.00030777595, ...","[-0.0007551508, -8.454921e-05, -0.00064842013,...","[-0.0007358731, -1.0567495e-05, -0.00057030155...","[-0.14416327, -0.23207569, -0.12935998, -0.140...","[-0.15209006, -0.22431044, -0.123266794, -0.14..."
1,<rdkit.Chem.rdchem.Mol object at 0x7fe93d098360>,<rdkit.Chem.rdchem.Mol object at 0x7fe930f9b150>,<RX_1>,test,"[0, 1]","[0.05798794, -0.049530365, -0.16187091, -0.117...","[0.06216321, -0.045985818, -0.15452383, -0.118...","[-0.00010589301, -7.0342816e-05, -9.149502e-05...","[-0.0001156631, -5.8518337e-05, -9.1130445e-05...","[-0.0013255478, -0.00083318714, 0.00039685922,...","[-0.0014088373, -0.0008960877, 0.00030777598, ...","[-0.000728644, 1.7175647e-05, -0.0005410071, -...","[-0.0007358732, -1.0567495e-05, -0.00057030155...","[-0.15506265, -0.22139843, -0.120981835, -0.14...","[-0.15209007, -0.22431044, -0.12326681, -0.147..."
2,<rdkit.Chem.rdchem.Mol object at 0x7fe93d0983b0>,<rdkit.Chem.rdchem.Mol object at 0x7fe930fb31a0>,<RX_4>,test,"[2, 3]","[0.058492642, -0.049101904, -0.16098283, -0.11...","[0.075878836, -0.03241189, -0.12553608, -0.105...","[-0.000107074, -6.8913505e-05, -9.145095e-05, ...","[-0.00016480377, -4.688659e-05, -9.865948e-05,...","[-0.0013356158, -0.00084079057, 0.00038609098,...","[-0.0011423073, -0.00077340833, 0.00022383843,...","[-0.0007295179, 1.382208e-05, -0.00054454815, ...","[-0.0008902296, 3.5511268e-05, -0.00064315065,...","[-0.15470335, -0.22175041, -0.121258035, -0.14...","[-0.13787062, -0.20371185, -0.1423492, -0.1255..."
3,<rdkit.Chem.rdchem.Mol object at 0x7fe93d098400>,<rdkit.Chem.rdchem.Mol object at 0x7fe930fb31a0>,<RX_4>,test,"[2, 3]","[0.081615575, -0.024887748, -0.10876873, -0.08...","[0.07587884, -0.032411896, -0.12553608, -0.105...","[-0.00020166622, -6.805564e-05, -0.00011015776...","[-0.00016480377, -4.688659e-05, -9.865948e-05,...","[-0.0005140618, -0.00040703668, 0.0003884005, ...","[-0.0011423075, -0.0007734083, 0.00022383845, ...","[-0.0010797498, 0.0001860623, -0.00065124984, ...","[-0.0008902296, 3.5511268e-05, -0.00064315065,...","[-0.12766126, -0.16623697, -0.1614062, -0.1043...","[-0.13787067, -0.20371185, -0.1423492, -0.1255..."
4,<rdkit.Chem.rdchem.Mol object at 0x7fe93cee0450>,<rdkit.Chem.rdchem.Mol object at 0x7fe930fb31f0>,<RX_2>,test,"[4, 5]","[0.06465318, -0.043152895, -0.1483344, -0.1126...","[0.06735051, -0.040359665, -0.14232247, -0.109...","[-0.00012784016, -6.160422e-05, -9.4164134e-05...","[-0.00013859726, -6.106161e-05, -9.6204385e-05...","[-0.0012572818, -0.00081091665, 0.00033240052,...","[-0.0011702304, -0.0007656742, 0.00032927585, ...","[-0.0007888426, 2.4006566e-05, -0.0005790605, ...","[-0.00082757353, 4.186712e-05, -0.00059188774,...","[-0.14865777, -0.21480931, -0.12894222, -0.140...","[-0.14557573, -0.20882416, -0.13343729, -0.135..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8558,<rdkit.Chem.rdchem.Mol object at 0x7fe93cc6ffb0>,<rdkit.Chem.rdchem.Mol object at 0x7fe930ca9260>,<RX_7>,test,[8558],"[0.08219527, -0.030393315, -0.12498383, -0.130...","[0.08219526, -0.03039332, -0.12498384, -0.1301...","[-0.00014001525, -1.1374301e-05, -6.007977e-05...","[-0.00014001524, -1.1374301e-05, -6.007977e-05...","[-0.0014816881, -0.00095057586, 6.0529415e-05,...","[-0.001481688, -0.00095057586, 6.0529408e-05, ...","[-0.0008192311, -0.00022404816, -0.00084567675...","[-0.0008192311, -0.00022404816, -0.00084567675...","[-0.14080426, -0.19668728, -0.1292082, -0.1333...","[-0.14080428, -0.19668727, -0.12920819, -0.133..."
8559,<rdkit.Chem.rdchem.Mol object at 0x7fe93cc78040>,<rdkit.Chem.rdchem.Mol object at 0x7fe930ca92b0>,<RX_10>,test,"[8559, 8560]","[0.08393727, -0.010955384, -0.109863505, -0.11...","[0.07877298, -0.016210381, -0.11928492, -0.115...","[-0.00015914187, -1.681137e-05, -7.815085e-05,...","[-0.00014745073, -3.0386442e-05, -7.9184865e-0...","[-0.0011356545, -0.0010345613, 0.00058439036, ...","[-0.0010698744, -0.00096673745, 0.0006555661, ...","[-0.0006866289, -0.00035001856, -0.00090476684...","[-0.00068226334, -0.00030545253, -0.0008589701...","[-0.13278896, -0.1968321, -0.13848296, -0.1192...","[-0.13666563, -0.19547582, -0.13548306, -0.123..."
8560,<rdkit.Chem.rdchem.Mol object at 0x7fe93cac8090>,<rdkit.Chem.rdchem.Mol object at 0x7fe930ca92b0>,<RX_10>,test,"[8559, 8560]","[-0.014184419, -0.110800356, -0.28887045, -0.0...","[0.07877297, -0.016210381, -0.11928493, -0.115...","[6.298968e-05, -0.00027473771, -9.779702e-05, ...","[-0.00014745073, -3.0386458e-05, -7.918486e-05...","[0.00011417031, 0.0002540932, 0.0019367283, -0...","[-0.0010698743, -0.00096673745, 0.00065556605,...","[-0.0006036839, 0.0004967357, -3.4631114e-05, ...","[-0.0006822634, -0.00030545253, -0.00085897016...","[-0.20644553, -0.17106286, -0.08148491, -0.196...","[-0.13666561, -0.19547582, -0.13548307, -0.123..."
8561,<rdkit.Chem.rdchem.Mol object at 0x7fe93cc980e0>,<rdkit.Chem.rdchem.Mol object at 0x7fe930ca9300>,<RX_1>,test,"[8561, 8562]","[0.075758025, -0.018317888, -0.13094226, -0.13...","[0.082724474, -0.030398643, -0.1252674, -0.134...","[-0.000108777975, -1.7502069e-05, -5.007727e-0...","[-0.0001369042, -4.4505005e-06, -5.7214547e-05...","[-0.001291107, -0.00095358386, 0.0007577774, -...","[-0.0015953815, -0.0010198255, 1.118468e-05, -...","[-0.000672513, -0.00050645653, -0.0009683486, ...","[-0.0007941968, -0.00025921117, -0.00085882033...","[-0.14248128, -0.18556345, -0.121501505, -0.13...","[-0.14146711, -0.20203795, -0.12688655, -0.135..."


In [35]:
np.array(reactants_embedding).shape

(8563, 300)

In [36]:
reactants_embeddings = defaultdict()
products_embeddings = defaultdict()
exclude_indices = defaultdict()

for model_name in GCN_models.keys():
    reactants_embeddings[model_name] = torch.tensor(uspto_triplets_dataset_embeddings[f'{model_name}_reactants_embedding'].tolist())
    products_embeddings[model_name] = torch.tensor(uspto_triplets_dataset_embeddings[f'{model_name}_products_embedding'].tolist())
    exclude_indices[model_name] = uspto_triplets_dataset_embeddings['exclude_indices'].tolist()

In [37]:
# check if the shape of all reactant embeddings is the same
for key, value in reactants_embeddings.items():
    print(key, value.shape)

d = reactants_embeddings['pretrained'].shape[1]

triplet_from_scratch torch.Size([8563, 300])
pretrained torch.Size([8563, 300])
triplet_loss_finetuned torch.Size([8563, 300])
triplet_cosine_loss_finetuned torch.Size([8563, 300])
random_weights torch.Size([8563, 300])


In [38]:
indices = defaultdict()

for model_name in GCN_models.keys():
    indices[model_name] = faiss.IndexFlatL2(d)

for model_name in GCN_models.keys():
    indices[model_name] = faiss.index_cpu_to_all_gpus(indices[model_name])

## Normalising the embeddings

In [39]:
for model_name in GCN_models.keys():
    reactants_embeddings[model_name] = reactants_embeddings[model_name].numpy() / np.linalg.norm(reactants_embeddings[model_name].numpy(), axis=1, keepdims=True)
    products_embeddings[model_name] = products_embeddings[model_name].numpy() / np.linalg.norm(products_embeddings[model_name].numpy(), axis=1, keepdims=True)
    indices[model_name].add(reactants_embeddings[model_name])

In [40]:
k_retrieved_indices = defaultdict()
skip_indices = defaultdict()

for model_name in GCN_models.keys():
    skip_indices[model_name] = []

for model_name in GCN_models.keys():
    k_retrieved_indices[model_name] = []

for idx, model_name in enumerate(list(GCN_models.keys())):
    if idx in skip_indices[model_name]:
            continue
    skip_indices[model_name].extend(exclude_indices[model_name][idx])
    for i in tqdm(range(products_embeddings[model_name].shape[0])):
        D, I = indices[model_name].search(products_embeddings[model_name][i][np.newaxis, ...], 5)
        k_retrieved_indices[model_name].append(I[0])

100%|██████████| 8563/8563 [00:01<00:00, 8045.36it/s]
100%|██████████| 8563/8563 [00:01<00:00, 8077.37it/s]
100%|██████████| 8563/8563 [00:01<00:00, 8056.50it/s]
100%|██████████| 8563/8563 [00:01<00:00, 8008.46it/s]
100%|██████████| 8563/8563 [00:01<00:00, 8049.84it/s]


In [41]:
skip_indices = defaultdict()
targets = defaultdict()

for model_name in GCN_models.keys():
    skip_indices[model_name] = []
    targets[model_name] = []

In [42]:
retrieval_metrics = PrettyTable()
retrieval_metrics.field_names = ["Model", "MAP", "MHR", "MRR"]

In [43]:
for model_name in GCN_models.keys():
    print("Calculating metrics for", model_name)
    for idx, row in tqdm(uspto_triplets_dataset_embeddings.iterrows(), total=len(uspto_triplets_dataset_embeddings)):
        if idx in skip_indices[model_name]:
            continue

        true_reactants_indices = row['exclude_indices']
        retrieved_reactants_indices = k_retrieved_indices[model_name][idx]

        skip_indices[model_name].extend(exclude_indices[model_name][idx])

        targets_idx = [0 for _ in range(len(retrieved_reactants_indices))]

        for idx, retrieved_idx in enumerate(retrieved_reactants_indices):
            if retrieved_idx in true_reactants_indices:
                targets_idx[idx] = 1
        
        targets[model_name].append(targets_idx)
    
    targets[model_name] = np.array(targets[model_name])
    map, mhr, mrr = calculate_metric(targets[model_name])

    retrieval_metrics.add_row([model_name, map, mhr, mrr])

Calculating metrics for triplet_from_scratch


100%|██████████| 8563/8563 [00:00<00:00, 12696.57it/s]


Calculating metrics for pretrained


100%|██████████| 8563/8563 [00:00<00:00, 13071.36it/s]


Calculating metrics for triplet_loss_finetuned


100%|██████████| 8563/8563 [00:00<00:00, 13257.95it/s]


Calculating metrics for triplet_cosine_loss_finetuned


100%|██████████| 8563/8563 [00:00<00:00, 12923.33it/s]


Calculating metrics for random_weights


100%|██████████| 8563/8563 [00:00<00:00, 13331.96it/s]


In [44]:
retrieval_metrics

Model,MAP,MHR,MRR
triplet_from_scratch,0.6859461599366586,0.1680655475619504,0.6855727308759414
pretrained,0.7038378611470462,0.1542765787370104,0.7034542314335062
triplet_loss_finetuned,0.701825293350717,0.1530775379696243,0.7014360313315926
triplet_cosine_loss_finetuned,0.6982445008460237,0.1570743405275779,0.6981128074639525
random_weights,0.6950281803542673,0.1650679456434852,0.6948950766747377


## Brute force cosine similarity search

In [45]:
# def find_top_k_indices(arr, e, k):
#     # Ensure the embeddings and the comparison embedding are tensors
#     # arr = torch.stack(arr)
#     e = e.unsqueeze(0)

#     # Compute cosine similarity between e and each embedding in arr
#     similarities = cosine_similarity(arr, e)

#     # Get the values and indices of the top k similarities
#     top_k_values, top_k_indices = torch.topk(similarities, k, largest=True)

#     return top_k_values, top_k_indices

In [46]:
# k = 5
# skip_indices = []
# query_count = 0

# metrics_queries = torch.tensor([], dtype=torch.long)
# metrics_predictions = torch.tensor([], dtype=torch.float)
# metrics_targets = torch.tensor([], dtype=torch.bool)

# for idx, product_embedding in tqdm(enumerate(product_embeddings), total=len(product_embeddings)):
#     if idx in skip_indices:
#         continue
#     skip_indices.extend(exclude_indices[idx])

#     top_k_similar_values, top_k_similar_indices = find_top_k_indices(reactants_embeddings, product_embeddings[idx], k)
    
#     target_bools = torch.tensor([False for i in range(k)], dtype=torch.bool)
#     for j in range(k):
#         if top_k_similar_indices[j] in exclude_indices[idx]:
#             target_bools[j] = True

#     # Append the index, prediction (similarity values), and target (boolean values) to the respective tensors
#     metrics_queries = torch.cat((metrics_queries, torch.tensor([query_count for i in range(k)], dtype=torch.long)))
#     metrics_predictions = torch.cat((metrics_predictions, top_k_similar_values))
#     metrics_targets = torch.cat((metrics_targets, target_bools))
    
#     query_count += 1

In [47]:
# map_metric = RetrievalMAP()
# mrr_metric = RetrievalMRR()
# mhr_metric = RetrievalHitRate()

# # Compute the metrics
# map_value = map_metric(metrics_predictions, metrics_targets, indexes=metrics_queries)
# mrr_value = mrr_metric(metrics_predictions, metrics_targets, indexes=metrics_queries)
# mhr_value = mhr_metric(metrics_predictions, metrics_targets, indexes=metrics_queries)

# # Print the metric values
# print("Mean Average Precision (MAP):", map_value)
# print("Mean Reciprocal Rank (MRR):", mrr_value)
# print("Mean Hit Rate (MHR):", mhr_value)