In [1]:
import sys
sys.path.append("../..")

import json
from pathlib import Path

import pandas as pd

# Where the generated-sample runs live and how many repetitions each run has.
# Each run/repetition is expected at SAMPLES_DIR / "ds_qm9_<name>_rep<rep>" / "metrics.json".
SAMPLES_DIR = Path("../../samples/gen")
N_REPS = 3

names = "small_w512 small_w1024 small_w1536 base_w512 base_w1024 base_w1536 big_w512 big_w1024 original random norot".split()

# One record per (run, repetition): the metrics JSON, tagged with the run name and rep index.
rows = []
for name in names:
    for rep in range(N_REPS):
        metrics_path = SAMPLES_DIR / f"ds_qm9_{name}_rep{rep}" / "metrics.json"
        with open(metrics_path) as f:
            data = json.load(f)
        # These tags overwrite any "name"/"rep" keys the JSON itself might contain.
        data["name"] = name
        data["rep"] = rep
        rows.append(data)
df = pd.DataFrame(rows)

df

Unnamed: 0,validity,uniqueness,valid_uniq,atom_stability,mol_stability,edm_validity,edm_valid_uniq,name,rep
0,0.9595,0.956644,0.9179,0.960487,0.741,0.8757,0.8441,small_w512,0
1,0.957,0.955277,0.9142,0.960143,0.7387,0.877,0.8451,small_w512,1
2,0.958,0.955324,0.9152,0.958851,0.7333,0.8733,0.842,small_w512,2
3,0.983,0.959308,0.943,0.973178,0.8138,0.9148,0.8816,small_w1024,0
4,0.9823,0.95704,0.9401,0.974946,0.8221,0.9185,0.8832,small_w1024,1
5,0.9813,0.954142,0.9363,0.973614,0.8131,0.9156,0.878,small_w1024,2
6,0.9843,0.954181,0.9392,0.977154,0.8361,0.9239,0.8872,small_w1536,0
7,0.9842,0.954887,0.9398,0.976288,0.8325,0.9256,0.8885,small_w1536,1
8,0.983,0.955849,0.9396,0.976741,0.8334,0.9281,0.89,small_w1536,2
9,0.9655,0.952978,0.9201,0.970062,0.8017,0.9066,0.8678,base_w512,0


In [2]:
import numpy as np

# Collapse the repetitions of each run into a mean and a (sample) std per metric;
# the bookkeeping columns "name" and "rep" are not aggregated.
summary = df.groupby("name").agg(
    {metric: ["mean", "std"] for metric in [c for c in df.columns if c not in {"name", "rep"}]}
)

# agg() with a dict of lists produces (metric, stat) MultiIndex columns; flatten to "metric_stat".
summary.columns = [f"{metric}_{stat}".strip() for metric, stat in summary.columns]

# Turn the group key back into a column and order the runs as they are listed in `names`
# (groupby sorts alphabetically otherwise).
summary = (
    summary.reset_index()
    .assign(name=lambda frame: pd.Categorical(frame["name"], categories=names, ordered=True))
    .sort_values("name")
    .reset_index(drop=True)
)

summary

Unnamed: 0,name,validity_mean,validity_std,uniqueness_mean,uniqueness_std,valid_uniq_mean,valid_uniq_std,atom_stability_mean,atom_stability_std,mol_stability_mean,mol_stability_std,edm_validity_mean,edm_validity_std,edm_valid_uniq_mean,edm_valid_uniq_std
0,small_w512,0.958167,0.001258,0.955748,0.000776,0.915767,0.001914,0.959827,0.000863,0.737667,0.003953,0.875333,0.001877,0.843733,0.001582
1,small_w1024,0.9822,0.000854,0.95683,0.002589,0.9398,0.00336,0.973913,0.000921,0.816333,0.005006,0.9163,0.001947,0.880933,0.002663
2,small_w1536,0.983833,0.000723,0.954972,0.000838,0.939533,0.000306,0.976728,0.000433,0.834,0.001873,0.925867,0.002113,0.888567,0.001401
3,base_w512,0.967533,0.001767,0.953042,0.001904,0.9221,0.002427,0.971203,0.001009,0.806033,0.003765,0.910833,0.003723,0.873167,0.005452
4,base_w1024,0.9841,0.000693,0.954003,0.002289,0.938833,0.001595,0.976164,0.000387,0.829133,0.003365,0.925167,0.00135,0.8872,0.000872
5,base_w1536,0.988533,0.000814,0.950972,0.002346,0.940067,0.001665,0.980473,0.000934,0.8567,0.005761,0.938267,0.003953,0.897533,0.003814
6,big_w512,0.9708,0.000361,0.944753,0.001423,0.917167,0.001665,0.966949,0.000723,0.785667,0.003669,0.903267,0.000702,0.8592,0.001493
7,big_w1024,0.984333,0.001106,0.949814,0.002565,0.934933,0.00275,0.979494,0.000388,0.857733,0.001617,0.936467,0.002281,0.8926,0.001908
8,original,0.9912,0.000173,0.948379,0.001003,0.940033,0.000862,0.983008,0.000486,0.876,0.0026,0.947167,0.00195,0.900933,0.001343
9,random,0.806533,0.00212,0.993386,0.000599,0.8012,0.002524,0.823099,0.00119,0.259133,0.002386,0.630667,0.005123,0.624767,0.0043


In [3]:
# Metrics shown in the LaTeX table, left to right; each contributes a "_mean" and a "_std" column.
# Note that "uniqueness" is not among them and is therefore dropped here.
column_order = "atom_stability mol_stability edm_validity edm_valid_uniq validity valid_uniq"
summary = summary[
    ["name", *(f"{metric}_{stat}" for metric in column_order.split() for stat in ("mean", "std"))]
]
summary

Unnamed: 0,name,atom_stability_mean,atom_stability_std,mol_stability_mean,mol_stability_std,edm_validity_mean,edm_validity_std,edm_valid_uniq_mean,edm_valid_uniq_std,validity_mean,validity_std,valid_uniq_mean,valid_uniq_std
0,small_w512,0.959827,0.000863,0.737667,0.003953,0.875333,0.001877,0.843733,0.001582,0.958167,0.001258,0.915767,0.001914
1,small_w1024,0.973913,0.000921,0.816333,0.005006,0.9163,0.001947,0.880933,0.002663,0.9822,0.000854,0.9398,0.00336
2,small_w1536,0.976728,0.000433,0.834,0.001873,0.925867,0.002113,0.888567,0.001401,0.983833,0.000723,0.939533,0.000306
3,base_w512,0.971203,0.001009,0.806033,0.003765,0.910833,0.003723,0.873167,0.005452,0.967533,0.001767,0.9221,0.002427
4,base_w1024,0.976164,0.000387,0.829133,0.003365,0.925167,0.00135,0.8872,0.000872,0.9841,0.000693,0.938833,0.001595
5,base_w1536,0.980473,0.000934,0.8567,0.005761,0.938267,0.003953,0.897533,0.003814,0.988533,0.000814,0.940067,0.001665
6,big_w512,0.966949,0.000723,0.785667,0.003669,0.903267,0.000702,0.8592,0.001493,0.9708,0.000361,0.917167,0.001665
7,big_w1024,0.979494,0.000388,0.857733,0.001617,0.936467,0.002281,0.8926,0.001908,0.984333,0.001106,0.934933,0.00275
8,original,0.983008,0.000486,0.876,0.0026,0.947167,0.00195,0.900933,0.001343,0.9912,0.000173,0.940033,0.000862
9,random,0.823099,0.00119,0.259133,0.002386,0.630667,0.005123,0.624767,0.0043,0.806533,0.00212,0.8012,0.002524


In [4]:
def format_mean_std(mean, std):
    r"""Format a mean/std pair given as fractions as one LaTeX table cell, in percent.

    Both values are multiplied by 100 and rounded to one decimal,
    e.g. (0.9598, 0.0009) -> ``96.0$\spm{0.1}$``.
    NOTE(review): ``\spm`` is assumed to be a macro defined in the target LaTeX
    document; it is not defined in this notebook.
    """
    return rf"{mean * 100:.1f}$\spm{{{std * 100:.1f}}}$"


# Build one LaTeX row per run (in `summary` row order) covering the metrics in `column_order`:
# cells joined by " & ", row terminated by " \\". The run name is not part of the row.
# Rows are printed for copy-pasting and also kept in `table_data` for reuse in later cells.
table_data = []
for _, row in summary.iterrows():
    row_data = [format_mean_std(row[f"{col}_mean"], row[f"{col}_std"]) for col in column_order.split()]
    latex_row = " & ".join(row_data) + r" \\"
    table_data.append(latex_row)
    print(latex_row)


96.0$\spm{0.1}$ & 73.8$\spm{0.4}$ & 87.5$\spm{0.2}$ & 84.4$\spm{0.2}$ & 95.8$\spm{0.1}$ & 91.6$\spm{0.2}$ \\
97.4$\spm{0.1}$ & 81.6$\spm{0.5}$ & 91.6$\spm{0.2}$ & 88.1$\spm{0.3}$ & 98.2$\spm{0.1}$ & 94.0$\spm{0.3}$ \\
97.7$\spm{0.0}$ & 83.4$\spm{0.2}$ & 92.6$\spm{0.2}$ & 88.9$\spm{0.1}$ & 98.4$\spm{0.1}$ & 94.0$\spm{0.0}$ \\
97.1$\spm{0.1}$ & 80.6$\spm{0.4}$ & 91.1$\spm{0.4}$ & 87.3$\spm{0.5}$ & 96.8$\spm{0.2}$ & 92.2$\spm{0.2}$ \\
97.6$\spm{0.0}$ & 82.9$\spm{0.3}$ & 92.5$\spm{0.1}$ & 88.7$\spm{0.1}$ & 98.4$\spm{0.1}$ & 93.9$\spm{0.2}$ \\
98.0$\spm{0.1}$ & 85.7$\spm{0.6}$ & 93.8$\spm{0.4}$ & 89.8$\spm{0.4}$ & 98.9$\spm{0.1}$ & 94.0$\spm{0.2}$ \\
96.7$\spm{0.1}$ & 78.6$\spm{0.4}$ & 90.3$\spm{0.1}$ & 85.9$\spm{0.1}$ & 97.1$\spm{0.0}$ & 91.7$\spm{0.2}$ \\
97.9$\spm{0.0}$ & 85.8$\spm{0.2}$ & 93.6$\spm{0.2}$ & 89.3$\spm{0.2}$ & 98.4$\spm{0.1}$ & 93.5$\spm{0.3}$ \\
98.3$\spm{0.0}$ & 87.6$\spm{0.3}$ & 94.7$\spm{0.2}$ & 90.1$\spm{0.1}$ & 99.1$\spm{0.0}$ & 94.0$\spm{0.1}$ \\
82.3$\spm{0.1}$ & 2