In [1]:
import plotly.graph_objects as go
import plotly.io as pio

# Use the "plotly_white" template for every figure in this notebook.
pio.templates.default = "plotly_white"
# Build and serialise a throwaway figure so plotly resolves the default
# template up front. NOTE(review): presumably a warm-up workaround so the
# first real figure renders with the template — confirm it is still needed.
# No name is bound, so nothing leaks into the notebook namespace.
go.Figure().to_dict()
# Emit output for both the plotly mimetype renderer and the CDN-backed
# `notebook_connected` renderer.
pio.renderers.default = "plotly_mimetype+notebook_connected"

In [2]:
# | output: false
# | echo: false


import dill
import nshutils as nu
import numpy as np

# Results file produced by an earlier evaluation run.
# NOTE: dill (like pickle) can execute arbitrary code on load — only point
# this at trusted files.
PREDS_TARGETS_PATH = "/mnt/datasets/jmp-mptrj-checkpoints/relaxer-results-8k/mptrj-jmps-s2ef_s2re_s2e_energy.dill"

# Context manager guarantees the file handle is closed after loading.
with open(PREDS_TARGETS_PATH, "rb") as preds_targets_file:
    preds_targets = dill.load(preds_targets_file)

# Each entry under "e_above_hull" is unpacked as a (true, predicted) pair.
e_above_hull_true, e_above_hull_pred = zip(*preds_targets["e_above_hull"])
e_above_hull_true = np.array(e_above_hull_true)
e_above_hull_pred = np.array(e_above_hull_pred)
# Per-sample absolute error. The name `maes` is kept because later cells use it.
maes = np.abs(e_above_hull_true - e_above_hull_pred)
nu.display(maes)

Type checking the following modules: ('jmppeft',)


Loading 'wbm_summary' from cached file at '/root/.cache/matbench-discovery/1.0.0/wbm/2023-12-13-wbm-summary.csv.gz'


In [3]:
import numpy as np
import plotly.graph_objects as go
from scipy.stats import gaussian_kde

# Kernel density estimate of the per-sample absolute errors, evaluated on a
# 1000-point grid spanning the observed error range. gaussian_kde cannot fit
# degenerate input (e.g. all-identical errors).
kde = gaussian_kde(maes)
# ndarray.min/max reduce in C instead of iterating element-by-element in Python.
x_range = np.linspace(maes.min(), maes.max(), 1000)
y_kde = kde(x_range)

# Filled density curve.
fig = go.Figure()

fig.add_trace(
    go.Scatter(x=x_range, y=y_kde, mode="lines", fill="tozeroy", name="MAE Density")
)

# Single dashed vertical line at the 95th percentile. `q95` is also used by
# the following cell to flag the worst ~5% of samples.
q95 = np.quantile(maes, 0.95)
fig.add_vline(
    x=q95,
    line_dash="dash",
    line_color="red",
    annotation_text="95% quantile",
    annotation_position="top right",
)

# Each value in `maes` is one sample's absolute error (not a mean), so label
# the axes accordingly.
fig.update_layout(
    title="Density Plot of Per-Sample Absolute Errors (E_above_hull)",
    xaxis_title="Absolute error",
    yaxis_title="Density",
    showlegend=False,
)

fig.show()

In [4]:
# Find the indices of the top 5% of MAEs
top_5pct_indices = np.argwhere(maes > q95).flatten()
bool_mask = np.zeros_like(maes, dtype=bool)
bool_mask[top_5pct_indices] = True
nu.display(bool_mask)

In [5]:
from pathlib import Path

import ase
import nglview
from jmppeft.modules.relaxer._relaxer import RelaxationOutput
from jmppeft.utils.render import render_trajectory

# Directory containing the ``{index}.dill`` trajectory files that the
# helpers below load.
problematic_traj_base_dir: Path = Path(
    "/mnt/datasets/jmp-mptrj-checkpoints/relaxer-results-8k_problematic/trajs/"
)


def _to_atoms_list(index: int) -> "list[ase.Atoms]":
    """Load trajectory ``index`` and convert each frame to an ``ase.Atoms``.

    Args:
        index: Trajectory index; the file ``{index}.dill`` is read from
            ``problematic_traj_base_dir``.

    Returns:
        One ``ase.Atoms`` per entry of ``data.trajectory.frames``. Each entry
        reuses the atomic numbers and PBC flags of ``data.atoms`` and takes
        its positions and cell from the frame.

    Raises:
        TypeError: If the loaded object is not a ``RelaxationOutput``.
    """
    # NOTE: dill (like pickle) can execute arbitrary code on load; only use
    # this on trusted trajectory files.
    with open(problematic_traj_base_dir / f"{index}.dill", "rb") as f:
        data = dill.load(f)

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not isinstance(data, RelaxationOutput):
        raise TypeError(
            f"Expected a RelaxationOutput in {index}.dill, "
            f"got {type(data).__name__}"
        )

    # `.numpy()` presumably converts torch tensors; this assumes they are on
    # CPU and do not require grad — TODO confirm.
    return [
        ase.Atoms(
            numbers=data.atoms.numbers,
            positions=frame.pos.numpy(),
            cell=frame.cell.numpy(),
            pbc=data.atoms.pbc,
        )
        for frame in data.trajectory.frames
    ]


def display_structure(index: int):
    """Build an nglview widget that animates the trajectory of one structure.

    Args:
        index: Trajectory index, forwarded to ``_to_atoms_list`` (which loads
            ``{index}.dill`` from ``problematic_traj_base_dir``).

    Returns:
        The configured ``nglview`` widget. It is not displayed here; the
        caller decides how to show it.
    """
    # Reuse the shared loader instead of duplicating the dill -> ase.Atoms
    # conversion (and its type check) here.
    atoms_list = _to_atoms_list(index)

    # default=False: no default representation is requested, so the unit cell
    # and a spacefill representation are added explicitly below.
    view = nglview.show_asetraj(atoms_list, default=False)
    view.add_unitcell()
    view.add_spacefill()
    view.camera = "orthographic"
    view.parameters = {"clipDist": 5}
    view.center()
    # Shrink the spacefill spheres to half their covalent radii.
    view.update_spacefill(radiusType="covalent", radiusScale=0.5, color_scale="rainbow")

    return view


# Index of the problematic trajectory to visualise.
STRUCTURE_INDEX = 0
# Set to True to also render the trajectory with `render_trajectory`
# (from jmppeft.utils.render) before showing the nglview widget.
RENDER_STATIC_TRAJECTORY = False

if RENDER_STATIC_TRAJECTORY:
    atoms_list = _to_atoms_list(STRUCTURE_INDEX)
    render_trajectory(atoms_list)

view = display_structure(STRUCTURE_INDEX)
# `display` is provided by IPython in the notebook namespace.
# For a static snapshot of the widget, `view.render_image()` is an option.
display(view)



NGLWidget(max_frame=53)