In [1]:
import sys
sys.path.append('../..')

import numpy as np
import torch
import pandas as pd
import os
import tqdm
from src.metrics import evaluate_prediction

# k values for the top-k summary columns.
_TOP_KS = [1, 5, 10, 100]
# Per-prediction metrics aggregated into the summary (max for correctness, min for RMSEs).
_SUMMARY_METRICS = ['heavy_correctness', 'correctness', 'heavy_coord_rmse', 'coord_rmse', 'unsigned_coords_rmse', 'moments_rmse']


def _load_chunks(path, n_chunks):
    """Merge the dicts stored in ``{path}/chunk-{i}/results.pt`` for ``i`` in ``range(n_chunks)``.

    Later chunks overwrite earlier ones on duplicate keys (plain ``dict.update``).
    """
    merged = dict()
    for chunk_id in range(n_chunks):
        # SECURITY: torch.load unpickles arbitrary Python objects - only load result files you trust.
        # NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True, which may reject
        # pickled molecule objects; pass weights_only=False there if loading fails.
        merged.update(torch.load(os.path.join(path, f'chunk-{chunk_id}', 'results.pt')))
    return merged


def _re_evaluate(full_dict):
    """Recompute every prediction's metrics with ``evaluate_prediction``.

    Returns a dict keyed by ``M_true.id_as_int`` whose values hold the true molecule
    (with ``graph=None``) and a list of per-prediction metric dicts. Each of those dicts
    also stores the aligned coordinates under ``'coords'``.
    """
    results = dict()
    for generated in tqdm.tqdm(full_dict.values(), total=len(full_dict)):
        M_true = generated['mol']
        # presumably drops the graph to keep the cached file small - TODO confirm
        log = {"mol": M_true.replace(graph=None), "preds": []}
        for pred in generated['preds']:
            pred_metrics, M_aligned = evaluate_prediction(
                M_pred=M_true.replace(coords=pred['coords']),
                M_true=M_true,
                return_aligned_mol=True,
                keep_coords_pred=True,
            )
            pred_metrics["coords"] = M_aligned.coords
            log["preds"].append(pred_metrics)
        results[M_true.id_as_int] = log
    return results


def _metric_aggregator(metric):
    """Return the top-k reducer for ``metric``: ``max`` for correctness flags, ``min`` for RMSEs."""
    if 'correctness' in metric:
        return max
    if 'rmse' in metric:
        return min
    raise ValueError(f"Don't know how to aggregate metric {metric!r}")


def _print_report(df):
    """Print headline top-1/top-5 correctness (%) and median heavy-atom RMSD, plus a LaTeX row."""
    avg_top_1_correctness = '%#.3g' % (df['top_1_correctness'].mean() * 100)
    avg_top_5_correctness = '%#.3g' % (df['top_5_correctness'].mean() * 100)
    median_top_1_rmse = '%#.3g' % df['top_1_heavy_coord_rmse'].median()
    median_top_5_rmse = '%#.3g' % df['top_5_heavy_coord_rmse'].median()

    print(f"Top 1 correctness: {avg_top_1_correctness}")
    print(f"Top 5 correctness: {avg_top_5_correctness}")
    print(f"Median Top 1 RMSD: {median_top_1_rmse}")
    print(f"Median Top 5 RMSD: {median_top_5_rmse}")
    print(f"{avg_top_1_correctness} & {avg_top_5_correctness} & {median_top_1_rmse} & {median_top_5_rmse}")


def eval_results(path, n_chunks, sort=True, re_eval=False):
    """Aggregate top-k metrics over the chunked sampling results stored under ``path``.

    Each loaded entry is expected to be a dict with ``'mol'`` (the ground-truth molecule)
    and ``'preds'`` (a list of per-prediction dicts holding ``'coords'`` and the metrics in
    ``_SUMMARY_METRICS``).

    Args:
        path: Directory containing the ``chunk-{i}/results.pt`` files.
        n_chunks: Number of chunk files to merge.
        sort: If True, order each molecule's predictions by ascending ``moments_rmse``
            before taking the top-k; otherwise keep the stored order.
        re_eval: If True, recompute all metrics with ``evaluate_prediction`` and cache them
            in ``{path}/re_eval_results.pt``. An existing cache file is reused without
            reading the chunks.

    Returns:
        ``(samples, df)``.
        - ``samples`` maps ``M_true.id_as_int`` to a dict with the true molecule, the
          (possibly re-ordered) predicted molecules and their correctness flags.
        - ``df`` has one row per molecule and one ``top_{k}_{metric}`` column per (k, metric) pair.
    """
    cache_path = os.path.join(path, 're_eval_results.pt')
    if re_eval and os.path.exists(cache_path):
        # The cached re-evaluation already contains everything needed; skip the raw chunks.
        # NOTE(review): the cache is not invalidated if the chunks or n_chunks change.
        full_dict = torch.load(cache_path)
    else:
        full_dict = _load_chunks(path, n_chunks)
        if re_eval:
            full_dict = _re_evaluate(full_dict)
            torch.save(full_dict, cache_path)

    # Insertion order (k outer, metric inner) fixes the DataFrame column order.
    summary = {f'top_{k}_{metric}': [] for k in _TOP_KS for metric in _SUMMARY_METRICS}

    samples = dict()
    for generated in full_dict.values():
        M_true = generated['mol']
        preds = generated['preds']
        if sort:
            # Rank predictions by moments RMSE (lower first) so "top-k" means the k best-ranked.
            order = np.argsort([pred['moments_rmse'] for pred in preds])
            preds = [preds[i] for i in order]

        for metric in _SUMMARY_METRICS:
            agg = _metric_aggregator(metric)
            values = [pred[metric] for pred in preds]
            for k in _TOP_KS:
                summary[f'top_{k}_{metric}'].append(agg(values[:k]))

        samples[M_true.id_as_int] = {
            'true': M_true,
            'preds': [M_true.replace(coords=pred['coords']) for pred in preds],
            'correct': [pred['correctness'] for pred in preds],
            'heavy_correct': [pred['heavy_correctness'] for pred in preds],
        }

    df = pd.DataFrame(summary)
    _print_report(df)
    return samples, df

In [2]:
import os
n_chunks = 5
# Sanity check before evaluating: print whether each dataset's chunk-{i}/results.pt exists.
# Iterating over the dataset paths replaces the copy-pasted loop; `path` ends as "./geom" as before.
for path in ("./qm9", "./geom"):
    for chunk_id in range(n_chunks):
        print(os.path.exists(os.path.join(path, f"chunk-{chunk_id}", "results.pt")))

True
True
True
True
True
True
True
True
True
True


In [3]:
# QM9: re-score every prediction (cached after the first run), rank by moments RMSE,
# and display per-column statistics of the top-k summary.
n_chunks = 5
path = "./qm9"
qm9_samples, qm9_df = eval_results(path, n_chunks, re_eval=True, sort=True)
qm9_df.describe()

Top 1 correctness: 27.9
Top 5 correctness: 39.4
Median Top 1 RMSD: 1.08
Median Top 5 RMSD: 0.797
27.9 & 39.4 & 1.08 & 0.797


Unnamed: 0,top_1_heavy_correctness,top_1_correctness,top_1_heavy_coord_rmse,top_1_coord_rmse,top_1_unsigned_coords_rmse,top_1_moments_rmse,top_5_heavy_correctness,top_5_correctness,top_5_heavy_coord_rmse,top_5_coord_rmse,...,top_10_heavy_coord_rmse,top_10_coord_rmse,top_10_unsigned_coords_rmse,top_10_moments_rmse,top_100_heavy_correctness,top_100_correctness,top_100_heavy_coord_rmse,top_100_coord_rmse,top_100_unsigned_coords_rmse,top_100_moments_rmse
count,1335.0,1335.0,1335.0,1335.0,1335.0,1335.0,1335.0,1335.0,1335.0,1335.0,...,1335.0,1335.0,1335.0,1335.0,1335.0,1335.0,1335.0,1335.0,1335.0,1335.0
mean,0.292135,0.278652,0.919692,1.076128,0.0,1.025185,0.423221,0.394007,0.630277,0.791284,...,0.547264,0.71708,0.0,1.025185,0.561049,0.48764,0.403775,0.578085,0.0,1.025185
std,0.454914,0.448504,0.606152,0.66156,0.0,0.674092,0.494255,0.48882,0.507921,0.591414,...,0.455084,0.550088,0.0,0.674092,0.496445,0.500035,0.341222,0.44404,0.0,0.674092
min,0.0,0.0,0.003893,0.005751,0.0,0.007793,0.0,0.0,0.00353,0.004898,...,0.00353,0.004436,0.0,0.007793,0.0,0.0,0.00353,0.004436,0.0,0.007793
25%,0.0,0.0,0.124044,0.354605,0.0,0.45041,0.0,0.0,0.031314,0.051589,...,0.02688,0.045633,0.0,0.45041,0.0,0.0,0.025372,0.043953,0.0,0.45041
50%,0.0,0.0,1.083761,1.300408,0.0,0.9619,0.0,0.0,0.797198,1.043096,...,0.668746,0.931991,0.0,0.9619,1.0,0.0,0.457979,0.741551,0.0,0.9619
75%,1.0,1.0,1.370442,1.555768,0.0,1.496111,1.0,1.0,1.054636,1.285394,...,0.948399,1.18637,0.0,1.496111,1.0,1.0,0.709503,0.968749,0.0,1.496111
max,1.0,1.0,2.16925,2.509489,0.0,3.512154,1.0,1.0,1.915482,2.134184,...,1.736487,2.041942,0.0,3.512154,1.0,1.0,1.260371,1.576751,0.0,3.512154


In [4]:
# GEOM: same evaluation settings as QM9 (re-scored, ranked by moments RMSE).
n_chunks = 5
path = "./geom"
geom_samples, geom_df = eval_results(path, n_chunks, re_eval=True, sort=True)
geom_df.describe()

Top 1 correctness: 0.273
Top 5 correctness: 0.342
Median Top 1 RMSD: 2.14
Median Top 5 RMSD: 1.89
0.273 & 0.342 & 2.14 & 1.89


Unnamed: 0,top_1_heavy_correctness,top_1_correctness,top_1_heavy_coord_rmse,top_1_coord_rmse,top_1_unsigned_coords_rmse,top_1_moments_rmse,top_5_heavy_correctness,top_5_correctness,top_5_heavy_coord_rmse,top_5_coord_rmse,...,top_10_heavy_coord_rmse,top_10_coord_rmse,top_10_unsigned_coords_rmse,top_10_moments_rmse,top_100_heavy_correctness,top_100_correctness,top_100_heavy_coord_rmse,top_100_coord_rmse,top_100_unsigned_coords_rmse,top_100_moments_rmse
count,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,...,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0,1464.0
mean,0.002732,0.002732,2.142256,2.269363,0.0,13.335727,0.004098,0.003415,1.864329,2.018026,...,1.763067,1.934954,0.0,13.335727,0.008197,0.005464,1.525588,1.719747,0.0,13.335727
std,0.052217,0.052217,0.412888,0.380998,0.0,7.51863,0.063909,0.058361,0.35869,0.327752,...,0.341418,0.316437,0.0,7.51863,0.090195,0.073745,0.327645,0.306742,0.0,7.51863
min,0.0,0.0,0.006673,0.010289,0.0,0.259354,0.0,0.0,0.006673,0.010289,...,0.006673,0.010289,0.0,0.259354,0.0,0.0,0.006673,0.010289,0.0,0.259354
25%,0.0,0.0,1.885487,2.03302,0.0,8.53594,0.0,0.0,1.662176,1.845231,...,1.575139,1.765071,0.0,8.53594,0.0,0.0,1.329349,1.553711,0.0,8.53594
50%,0.0,0.0,2.144487,2.259917,0.0,12.673577,0.0,0.0,1.893213,2.032712,...,1.800083,1.957916,0.0,12.673577,0.0,0.0,1.561107,1.749682,0.0,12.673577
75%,0.0,0.0,2.382596,2.493511,0.0,16.870538,0.0,0.0,2.083524,2.216264,...,1.988311,2.131911,0.0,16.870538,0.0,0.0,1.755611,1.91445,0.0,16.870538
max,1.0,1.0,3.853696,3.880698,0.0,137.307251,1.0,1.0,3.085841,3.273874,...,2.925229,3.021736,0.0,137.307251,1.0,1.0,2.438931,2.577977,0.0,137.307251


In [5]:
# One LaTeX table row per dataset. Each cell is the mean top-k value of a 0/1 correctness
# column times 100, i.e. a percentage: first the `correctness` columns, then `heavy_correctness`.
LATEX_COLUMNS = [
    "top_1_correctness", "top_5_correctness", "top_10_correctness", "top_100_correctness",
    "top_1_heavy_correctness", "top_5_heavy_correctness", "top_10_heavy_correctness", "top_100_heavy_correctness",
]
for dataset_name, dataset_df in (("QM9", qm9_df), ("GEOM", geom_df)):
    cells = ['%#.3g' % (dataset_df[column].mean() * 100) for column in LATEX_COLUMNS]
    print(f"& {dataset_name} & " + " & ".join(cells) + " \\\\")

& QM9 & 27.9 & 39.4 & 42.1 & 48.8 & 29.2 & 42.3 & 46.4 & 56.1 \\
& GEOM & 0.273 & 0.342 & 0.342 & 0.546 & 0.273 & 0.410 & 0.478 & 0.820 \\
