In [1]:
def estimate_mrr(top_k_scores):
    """
    DON'T USE THIS, NOT RIGHT -- prefer estimate_mrr_discrete.

    Crude estimate of the Mean Reciprocal Rank (MRR) based on top-k accuracy scores.
    Every bucket of ranks between two consecutive cutoffs is credited with the reciprocal
    of its *average* rank. Since 1 / mean(rank) <= mean(1 / rank) for positive ranks, this
    underestimates the MRR implied by the uniform-within-bucket assumption; that is why
    the function is kept only for reference.

    Args:
        top_k_scores (dict): A dictionary where keys are 'top-k' labels (e.g., 'top-1', 'top-3') and
                             values are the probabilities (as fractions) of the correct answer being within that top-k.
                             Keys may be given in any order.

    Returns:
        float: An estimated MRR based on the provided top-k scores.

    Raises:
        KeyError: If a key is not of the form 'top-<positive integer>'.
    """
    def _cutoff(label):
        # Parse 'top-<k>' into the integer k; anything else is rejected with KeyError,
        # the same exception type an unknown label produced with the old lookup table.
        prefix = 'top-'
        if isinstance(label, str) and label.startswith(prefix):
            digits = label[len(prefix):]
            if digits.isdigit() and int(digits) >= 1:
                return int(digits)
        raise KeyError(label)

    # Process cutoffs in ascending k so that differencing the cumulative probabilities
    # does not depend on the insertion order of the dictionary.
    by_cutoff = sorted((_cutoff(label), prob) for label, prob in top_k_scores.items())

    weighted_sum_reciprocal_ranks = 0.0
    previous_top_k_prob = 0
    previous_cutoff = 0
    for cutoff, prob in by_cutoff:
        # Probability mass that falls in this bucket only (excluding better ranks).
        adjusted_prob = prob - previous_top_k_prob
        # Mean of the ranks previous_cutoff+1 .. cutoff, assuming they are equally likely.
        # For the usual 1/3/5/10 cutoffs this gives 1, 2.5, 4.5 and 8.
        average_rank = (previous_cutoff + 1 + cutoff) / 2
        weighted_sum_reciprocal_ranks += adjusted_prob / average_rank
        previous_top_k_prob = prob
        previous_cutoff = cutoff

    # The estimated MRR is the weighted sum of reciprocal (average) ranks.
    return weighted_sum_reciprocal_ranks

def estimate_mrr_discrete(top_k_scores):
    """
    Estimates the Mean Reciprocal Rank (MRR) from top-k scores assuming a uniform distribution of ranks within intervals.

    The probability mass gained between two consecutive cutoffs is spread evenly over the
    ranks in that interval, and each interval contributes its mass times the mean of 1/rank
    over its ranks. Mass beyond the largest cutoff contributes nothing, so the result is a
    lower bound under that assumption.

    Args:
        top_k_scores (dict): A dictionary with keys as the top-k cutoffs (e.g., 1, 3, 5) and values as the corresponding
                             probabilities of the correct answer being within the top k, as fractions in [0, 1]
                             (not percentages).

    Returns:
        float: The estimated MRR based on the provided top-k scores (0 for an empty dict).

    Raises:
        ValueError: If any cutoff is smaller than 1.

    Warns:
        UserWarning: If a probability lies outside [0, 1] or the cumulative probabilities
            decrease as k grows. The value is still computed, but it is not a meaningful MRR.
    """
    import warnings  # imported locally so the function is self-contained

    sorted_keys = sorted(top_k_scores.keys())  # Ensure keys are sorted in ascending order
    if sorted_keys and sorted_keys[0] < 1:
        raise ValueError(f"top-k cutoffs must be >= 1, got {sorted_keys[0]!r}")

    estimated_mrr = 0
    prev_prob = 0  # To keep track of the previous cumulative probability
    prev_cutoff = 0  # Largest rank already covered by earlier intervals

    for k in sorted_keys:
        # Calculate the incremental probability for the current interval
        curr_prob = top_k_scores[k]
        if not 0 <= curr_prob <= 1:
            warnings.warn(
                f"top-{k} score {curr_prob!r} is outside [0, 1]; "
                "scores must be fractions, not percentages",
                stacklevel=2,
            )
        delta_prob = curr_prob - prev_prob
        if delta_prob < 0:
            warnings.warn(
                f"top-{k} score {curr_prob!r} is lower than the score at the previous cutoff "
                f"({prev_prob!r}); cumulative top-k scores should be non-decreasing",
                stacklevel=2,
            )
        prev_prob = curr_prob

        # The interval starts right after the previous cutoff (rank 1 for the first interval)
        start_rank = prev_cutoff + 1

        # Mean reciprocal rank over the ranks start_rank..k, assumed equally likely
        sum_reciprocal_ranks = sum(1/r for r in range(start_rank, k+1))
        interval_length = k - start_rank + 1
        E_i = sum_reciprocal_ranks / interval_length

        # Update the estimated MRR with the weighted contribution of this interval
        estimated_mrr += delta_prob * E_i
        prev_cutoff = k

    return estimated_mrr

def turn_topk_list_to_dict_1_3_10(topk_list):
    """Convert a three-element score list into a cutoff -> score dict.

    NOTE: despite the ``_1_3_10`` suffix, the three entries are keyed as
    top-1, top-5 and top-10 (there is no top-3 key in the result).

    Args:
        topk_list: Indexable with at least three entries; only the first three are read.

    Returns:
        dict: ``{1: topk_list[0], 5: topk_list[1], 10: topk_list[2]}``.
    """
    top1 = topk_list[0]
    top5 = topk_list[1]
    top10 = topk_list[2]
    return {1: top1, 5: top5, 10: top10}

def turn_topk_list_to_dict(topk_list):
    """Convert a four-element score list into a cutoff -> score dict.

    Args:
        topk_list: Indexable with at least four entries, read in order as the
            top-1, top-3, top-5 and top-10 scores; further entries are ignored.

    Returns:
        dict: ``{1: topk_list[0], 3: topk_list[1], 5: topk_list[2], 10: topk_list[3]}``.
    """
    top1 = topk_list[0]
    top3 = topk_list[1]
    top5 = topk_list[2]
    top10 = topk_list[3]
    return {1: top1, 3: top3, 5: top5, 10: top10}

In [2]:
# Each list holds top-1 / top-3 / top-5 / top-10 accuracies as fractions in [0, 1].
# `our_topks` is reassigned several times below; only the last assignment is used.
our_topks = [0.5245, 0.7382, 0.7957, 0.8472]
retrobridge_topks = [0.503, 0.74, 0.803, 0.851]
dualtf_topks = [0.536, 0.707, 0.746, 0.770]
# our_topks = [0.4687815720347545, 0.657506566983229, 0.7138815922408568, 0.789452414629218]
# our_topks = [0.545968882602546, 0.7229743382501516, 0.7710648615881996, 0.8086482117599515]# with cis/trans transfer and chirality transfer, but no deduplication with the correct SMILES
our_topks = [0.5471812487371186, 0.7330773893715902, 0.7781369973732067, 0.8106688219842393]#, with cis/trans transfer, chirality transfer, deduplication
our_topks_250 = [0.5489997979389776, 0.7334815114164478, 0.7759143261264903, 0.8086482117599515]
# Fixed: this was [0.546, 76.4, 83.3, 88.5], i.e. a fraction mixed with percentages,
# which produced an impossible "MRR" of 34.4. All four scores are now fractions.
our_topks = [0.546, 0.764, 0.833, 0.885]
print(estimate_mrr_discrete(turn_topk_list_to_dict(our_topks)))
# print(estimate_mrr_discrete(turn_topk_list_to_dict(our_topks_250)))
# `their_topks` is not defined in this cell; presumably retrobridge_topks was meant.
# print(estimate_mrr_discrete(turn_topk_list_to_dict(retrobridge_topks)))
# print(estimate_mrr_discrete(turn_topk_list_to_dict(dualtf_topks)))

0.6590729365079365


In [4]:
# Published top-1 / top-3 / top-5 / top-10 accuracies (fractions) per method.
topks = {
    "RSMILES": [0.563, 0.792, 0.862, 0.910],
    "PMSR": [0.620, 0.784, 0.829, 0.869],
    "Retrosym": [0.373, 0.547, 0.644, 0.741],
    "GLN": [0.525, 0.747, 0.812, 0.879],
    "LocalRetro": [0.525, 0.760, 0.844, 0.906],
    "GraphRetro": [0.537, 0.683, 0.722, 0.755],
    "SCROP": [0.437, 0.600, 0.652, 0.687],
    "Tied Transformer": [0.471, 0.671, 0.731, 0.763],
    "MEGAN": [0.48, 0.709, 0.781, 0.854],
    "G2G": [0.489, 0.676, 0.725, 0.755],
    "Retrobridge": [0.503, 0.740, 0.803, 0.851],
    "GTA_aug": [0.511, 0.676, 0.748, 0.816],
    "Graph2SMILES": [0.529, 0.665, 0.700, 0.729],
    # Alternative Retroformer numbers, kept for reference:
    # "Retroformer": [0.529, 0.682, 0.725, 0.764],
    "Retroformer": [0.532, 0.711, 0.766, 0.821],
    "DualTF_aug": [0.536, 0.707, 0.746, 0.77],
    "RetroDiff": [0.526, 0.712, 0.810, 0.833],
    "DiffRxn": [0.0406, 0.0648, 0.0786, 0.0978],
    "DiffRxn-PE-dY-100": [0.5235, 0.7309, 0.7836, 0.8320],
    "DiffRxn-input": [0.4411, 0.659, 0.7220, 0.7872],
    "DiffRxn-PE-Laplacian": [0.4904, 0.7068, 0.7660, 0.8181],
    "DiffRxn-PE-dY": [0.5259, 0.7302, 0.7864, 0.8319],
    "DiffRxn-PE-dY-150samples": [0.5264, 0.7371, 0.7880, 0.8402],
    "DiffRxn-PE-dY-200samples": [0.5245, 0.7382, 0.7957, 0.8472],
    "MARS": [0.546, 0.764, 0.833, 0.885],
    "retrocomposer": [0.545, 0.772, 0.832, 0.877],
    "retroknn": [0.572, 0.789, 0.864, 0.927],
}

# 0.526, 0.712, 0.810, 0.833  (same numbers as the RetroDiff row)

# Methods that report only three cutoffs; the list is top-1 / top-5 / top-10.
special_topks = {
    "Aug. Transformer": [0.483, 0.734, 0.774],  # missing top-3
}

# Estimated MRR for every method with the full four-cutoff list.
for method_name, scores in topks.items():
    print(method_name, estimate_mrr_discrete(turn_topk_list_to_dict(scores)))

# The three-cutoff entry goes through the 1/5/10 converter instead.
for method_name, scores in special_topks.items():
    print(method_name, estimate_mrr_discrete(turn_topk_list_to_dict_1_3_10(scores)))

RSMILES 0.6803647619047619
PMSR 0.7036234126984127
Retrosym 0.4798503174603175
GLN 0.640776507936508
LocalRetro 0.6498225396825398
GraphRetro 0.6108695238095238
SCROP 0.5211361111111111
Tied Transformer 0.5719653968253968
MEGAN 0.6010429365079365
G2G 0.5818154761904762
Retrobridge 0.6221230952380953
GTA_aug 0.604730634920635
Graph2SMILES 0.5972863492063492
Retroformer 0.6260603174603175
DualTF_aug 0.6191240476190476
RetroDiff 0.6285199206349207
DiffRxn 0.05626757142857143
DiffRxn-PE-dY-100 0.6280239126984126
DiffRxn-input 0.5544857460317459
DiffRxn-PE-Laplacian 0.6006141825396826
DiffRxn-PE-dY 0.6295452777777778
DiffRxn-PE-dY-150samples 0.6323845952380952
DiffRxn-PE-dY-200samples 0.6331292063492063
MARS 0.6590729365079365
retrocomposer 0.6588940476190477
retroknn 0.6874266666666666
Aug. Transformer 0.568694246031746
