In [None]:
import os
import seaborn as sns
import matplotlib.pyplot as plt

import sys
# Add the parent directory to the module search path; presumably this is what
# makes the `analyses` package imported further down resolvable when the
# notebook is run from its own directory — TODO confirm.
sys.path.append('..')

In [None]:
# Load IPython's autoreload extension; the reload mode is selected in the next cell.
%load_ext autoreload

In [None]:
# Autoreload mode 2: modules are reloaded before each cell execution, so edits
# to the project modules imported here take effect without a kernel restart.
%autoreload 2
import analyses.analysis as analysis
import analyses.process_generated_molecules as process_generated_molecules

In [None]:
# Directory handed to the results loader in the next cell.
# Exactly one assignment is active. The other experiment directories are kept
# as comments: previously all of them were assigned in sequence and only the
# last one took effect, which hid the setting actually in use.
# To analyse a different experiment, swap the active line with one below.
#   BASEDIR = "../potato_workdirs/tetris/"
#   BASEDIR = "../potato_workdirs/platonic_solids/"
#   BASEDIR = "../potato_workdirs/qm9_10JUL/"
#   BASEDIR = "../potato_workdirs/qm9_global_embedding_exp_with_noise/nequip/interactions=3/l=4/position_channels=5/channels=64/global_embed=False/"
BASEDIR = "../potato_workdirs/"

In [None]:
# Load the results found under BASEDIR into a DataFrame and display it.
# NOTE(review): how the directory is traversed and which columns come back is
# decided by analyses.analysis.get_results_as_dataframe, which is not visible
# here — later cells rely on "num_params", "max_l" and
# "val_eval_final.position_loss" being present.
basedir = os.path.abspath(BASEDIR)  # absolute form of the relative BASEDIR
results = analysis.get_results_as_dataframe(basedir)
results


In [None]:
# Validation position loss against parameter count, one line per max_l.
sns.set_theme(style="darkgrid")
loss_ax = sns.lineplot(
    data=results,
    x="num_params",
    y="val_eval_final.position_loss",
    # Colour and marker/dash pattern both encode max_l.
    hue="max_l",
    style="max_l",
    markers=True,
    dashes=True,
    # Marker size is a single fixed value for every series.
    markersize=10,
)
# Legend, axis labels and title are set on the axes seaborn drew into.
loss_ax.legend(title="Max L", loc="upper right")
loss_ax.set_xlabel("Number of Parameters")
loss_ax.set_ylabel("Validation Loss")
loss_ax.set_title("QM9")
plt.show()


# EDM Analyses

In [None]:
# NOTE(review): absolute, user-specific path — this cell only runs on a
# machine with this exact directory layout. Consider deriving it from a
# configurable root, as is done for BASEDIR.
molecules_basedir = "/Users/ameyad/Documents/spherical-harmonic-net/analyses/analysed_workdirs/qm9_bessel_embedding"

# Switch for the processing step; set to False to skip the
# process_molecules_dirs call and go straight to loading the analyses.
preprocess = True
if preprocess:
    process_generated_molecules.process_molecules_dirs(
        molecules_basedir,
        relax_structures=False,
    )

# Unfiltered EDM analysis table; later cells split it on the "step" column.
edm_analyses_results_orig = analysis.get_edm_analyses_results_as_dataframe(
    molecules_basedir,
    extract_hyperparams_from_path=True,
    read_as_sdf=True,
)

In [None]:
# NOTE(review): scratch arithmetic on two hard-coded numbers whose origin is
# not recorded anywhere in this notebook — presumably values read off a results
# table; TODO confirm, then either name them or remove this cell.
0.944685	* 0.960964

In [None]:
# Rows for the "best" checkpoint, ordered from most to least atom-stable.
# The filter is applied before sorting so the boolean mask shares the frame's
# index order. Masking after the sort made pandas realign the mask and warn
# that the boolean key would be reindexed.
(
    edm_analyses_results_orig
    .loc[edm_analyses_results_orig["step"] == "best"]
    .sort_values(by="fraction_atoms_stable", ascending=False)
)

In [None]:
# Keep every row except the "best" checkpoint; what remains has a "step" that
# later cells cast to int.
# The mask is built from edm_analyses_results_orig: the previous version read
# edm_analyses_results, which does not exist yet on a fresh kernel (NameError).
# .copy() makes this an independent frame so later column assignments do not
# trigger SettingWithCopyWarning.
edm_analyses_results = edm_analyses_results_orig.loc[
    edm_analyses_results_orig["step"] != "best"
].copy()

In [None]:
edm_analyses_results["step"] = edm_analyses_results["step"].astype(int)

sns.set_theme(style="darkgrid")
# Set marker sizes depending on max_l.
sns.lineplot(
    data=edm_analyses_results,
    x="step",
    y="fraction_atoms_stable",
    hue="max_l",
    style="max_l",
    marker="o",
)
plt.xscale("log")
plt.xticks(ticks=edm_analyses_results["step"].unique(), labels=edm_analyses_results["step"].unique())
plt.legend(title="Max L")
plt.xlabel("Training Step")
plt.ylabel("Fraction of Atoms Stable")
plt.title("Generated Atom Stability During Training")
# plt.savefig("../plots/qm9_atom_stability.pdf")
plt.show();

sns.cubehelix_palette(n_colors=5)[:]

In [None]:
edm_analyses_results["step"] = edm_analyses_results["step"].astype(int)

sns.set_theme(style="darkgrid")
# Set marker sizes depending on max_l.
sns.lineplot(
    data=edm_analyses_results,
    x="step",
    y="fraction_molecules_stable",
    hue="max_l",
    style="max_l",
    marker="o",
)
plt.xscale("log")
plt.xticks(ticks=edm_analyses_results["step"].unique(), labels=edm_analyses_results["step"].unique())
plt.legend(title="Max L")
plt.xlabel("Training Step")
plt.ylabel("Fraction of Molecules Stable")
plt.title("Generated Molecule Stability During Training")
plt.savefig("../plots/qm9_molecule_stability.pdf")
plt.show();

In [None]:
edm_analyses_results["step"] = edm_analyses_results["step"].astype(int)

sns.set_theme(style="darkgrid")
# Set marker sizes depending on max_l.
sns.lineplot(
    data=edm_analyses_results,
    x="step",
    y="fraction_valid",
    hue="max_l",
    style="max_l",
    marker="o",
)
plt.xscale("log")
plt.xticks(ticks=edm_analyses_results["step"].unique(), labels=edm_analyses_results["step"].unique())
plt.legend(title="Max L")
plt.xlabel("Training Step")
plt.ylabel("Fraction of Molecules Valid")
plt.title("Generated Molecule Validity During Training")
plt.savefig("../plots/qm9_molecule_validity.pdf")
plt.show();

# Merged Results

In [None]:
# Join the training results with the EDM generation metrics. Rows are matched
# on max_l, the model name, and the channel / interaction-count settings
# listed in `on`.
# Suffixes: with both suffixes empty, pandas raises
# "columns overlap but no suffix specified" as soon as any non-key column
# exists in both frames. Columns from `results` keep their names; clashing
# columns from the EDM table get "_edm". When nothing overlaps the output is
# unchanged.
merged_results = results.merge(
    edm_analyses_results,
    on=[
        "max_l",
        "config.target_position_predictor.num_channels",
        "model",
        "config.num_channels",
        "config.num_interactions",
    ],
    suffixes=("", "_edm"),
)
merged_results

In [None]:
sns.set_theme(style="darkgrid")
# Set marker sizes depending on max_l.
sns.lineplot(
    data=merged_results,
    x="num_params",
    y="fraction_molecules_stable",
    hue="max_l",
    style="max_l",
    markersize=10,
    markers=True,
    dashes=True,
)
# Set legend title as Max L.
plt.legend(title="Max L")
# Set axis labels.
plt.xlabel("Number of Parameters")
plt.ylabel("Fraction of Molecules Stable")
plt.title("QM9")
plt.show();

In [None]:
sns.set_theme(style="darkgrid")
# Set marker sizes depending on max_l.
sns.lineplot(
    data=merged_results,
    x="num_params",
    y="fraction_atoms_stable",
    hue="max_l",
    style="max_l",
    markersize=10,
    markers=True,
    dashes=True,
)
# Set legend title as Max L.
plt.legend(title="Max L")
# Set axis labels.
plt.xlabel("Number of Parameters")
plt.ylabel("Fraction of Atoms Stable")
plt.title("QM9")
plt.show();

In [None]:
sns.set_theme(style="darkgrid")
# Set marker sizes depending on max_l.
sns.lineplot(
    data=merged_results,
    x="num_params",
    y="num_generated",
    hue="max_l",
    style="max_l",
    markersize=10,
    markers=True,
    dashes=True,
)
# Set legend title as Max L.
plt.legend(title="Max L")
# Set axis labels.
plt.xlabel("Number of Parameters")
plt.ylabel("Number of Molecules Generated")
plt.title("QM9")
plt.show();

In [None]:
sns.set_theme(style="darkgrid")
# Set marker sizes depending on max_l.
sns.lineplot(
    data=merged_results,
    x="num_params",
    y="fraction_valid",
    hue="max_l",
    style="max_l",
    markersize=10,
    markers=True,
    dashes=True,
)
# Set legend title as Max L.
plt.legend(title="Max L")
# Set axis labels.
plt.xlabel("Number of Parameters")
plt.ylabel("Fraction Valid of All Molecules Generated")
plt.title("QM9")
plt.show();

In [None]:
sns.set_theme(style="darkgrid")
sns.lineplot(
    data=merged_results,
    x="num_params",
    y="fraction_unique_of_valid",
    hue="max_l",
    style="max_l", 
    markersize=10,
    markers=True,
    dashes=True,
)
# Set legend title as Max L.
plt.legend(title="Max L")
# Set axis labels.
plt.xlabel("Number of Parameters")
plt.ylabel("Fraction Unique of Valid Molecules Generated")
plt.title("QM9")
plt.show();

In [None]:
sns.set_theme(style="darkgrid")
# Set marker sizes depending on max_l.
sns.lineplot(
    data=merged_results,
    x="num_params",
    y="fraction_novel_of_valid_and_unique",
    hue="max_l",
    style="max_l",
    markersize=10,
    markers=True,
    dashes=True,
)
# Set legend title as Max L.
plt.legend(title="Max L")
# Set axis labels.
plt.xlabel("Number of Parameters")
plt.ylabel("Fraction Novel of Valid and Unique Molecules Generated")
plt.title("QM9")
plt.show();