In [23]:
import plotly.graph_objects as go
import plotly.io as pio

# Make "plotly_white" the default template for figures created in this notebook.
pio.templates.default = "plotly_white"
# NOTE(review): this figure is built and serialized but the result is discarded;
# presumably a warm-up before the default is re-applied below — confirm it is
# still needed.
fig = go.Figure()
fig.to_dict()
# Same assignment as above, repeated after the throwaway figure.
pio.templates.default = "plotly_white"

In [24]:
import dill
import nshutils as nu
import numpy as np

# Load the stored relaxer results from disk.
# The file is opened in a `with` block so the handle is closed once loading
# finishes (the bare `open(...)` passed to `dill.load` was never closed).
# NOTE(review): dill, like pickle, can execute arbitrary code while loading;
# only point this at result files from a trusted source.
with open(
    "/mnt/datasets/jmp-mptrj-checkpoints/relaxer-results-8k/mptrj-jmps-s2ef_s2re_s2e_energy.dill",
    "rb",
) as f:
    preds_targets = dill.load(f)

# Each entry of preds_targets["e_above_hull"] is unpacked as a
# (true, predicted) pair, then both columns become NumPy arrays.
e_above_hull_true, e_above_hull_pred = zip(*preds_targets["e_above_hull"])
e_above_hull_true = np.array(e_above_hull_true)
e_above_hull_pred = np.array(e_above_hull_pred)

# Element-wise absolute error between target and prediction. The name `maes`
# is kept because later cells read it.
maes = np.abs(e_above_hull_true - e_above_hull_pred)
nu.display(maes)

In [25]:
import numpy as np
import plotly.graph_objects as go
from scipy.stats import gaussian_kde

# Threshold above which an error belongs to the worst 5%; drawn on the plot
# below and reused by the following cell.
q95 = np.quantile(maes, 0.95)

# Gaussian kernel density estimate of the errors, evaluated on 1000 evenly
# spaced points spanning the observed range.
kde = gaussian_kde(maes)
x_range = np.linspace(min(maes), max(maes), 1000)
y_kde = kde(x_range)

# Density curve, filled down to the x-axis.
density_trace = go.Scatter(
    x=x_range,
    y=y_kde,
    mode="lines",
    fill="tozeroy",
    name="MAE Density",
)

fig = go.Figure()
fig.add_trace(density_trace)

# Dashed red vertical marker at the 95% quantile.
fig.add_vline(
    x=q95,
    line_dash="dash",
    line_color="red",
    annotation_text="95% quantile",
    annotation_position="top right",
)

# Titles and legend settings, collected in one place.
layout_options = {
    "title": "Density Plot of Mean Absolute Errors (MAEs)",
    "xaxis_title": "MAE",
    "yaxis_title": "Density",
    "showlegend": False,
}
fig.update_layout(**layout_options)

# Render the figure.
fig.show()

In [26]:
# Flag every error strictly above the 95% quantile, then derive the matching
# positions from that mask (assumes `maes` is one-dimensional, as built from
# the per-sample pairs earlier in the notebook).
bool_mask = maes > q95
top_5pct_indices = np.flatnonzero(bool_mask)
nu.display(bool_mask)

In [29]:
from pathlib import Path

import ase
import ase.io
import ase.visualize
from ase.build import make_supercell
from jmppeft.modules.relaxer._relaxer import RelaxationOutput
from jmppeft.utils.render import render_trajectory

# Directory holding the stored trajectories, one `<index>.dill` file per
# relaxation; `_to_atoms_list` reads from it by index.
problematic_traj_base_dir = Path(
    "/mnt/datasets/jmp-mptrj-checkpoints/relaxer-results-8k_problematic/trajs/"
)


def _to_atoms_list(index: int, supercell: bool = True):
    """Load a stored relaxation and rebuild its trajectory as ASE atoms.

    Args:
        index: Stem of the ``<index>.dill`` file inside
            ``problematic_traj_base_dir``.
        supercell: When true, every frame is expanded with the diagonal
            2x2x2 supercell matrix.

    Returns:
        A list with one atoms object per frame of the stored trajectory.

    Raises:
        AssertionError: If the loaded object is not a ``RelaxationOutput``.
    """
    traj_file = problematic_traj_base_dir / f"{index}.dill"
    with open(traj_file, "rb") as handle:
        data = dill.load(handle)

    assert isinstance(data, RelaxationOutput)

    frames = []
    for frame in data.trajectory.frames:
        # Atomic numbers and periodicity come from the stored structure;
        # positions and cell come from the individual frame.
        atoms = ase.Atoms(
            numbers=data.atoms.numbers,
            positions=frame.pos.numpy(),
            cell=frame.cell.numpy(),
            pbc=data.atoms.pbc,
        )
        if supercell:
            atoms = make_supercell(atoms, [[2, 0, 0], [0, 2, 0], [0, 0, 2]])
        frames.append(atoms)

    return frames


def display_structure(index: int):
    """Return an NGL view of the supercell trajectory stored under ``index``.

    The loaded frames are also assigned to the module-level ``atoms_list`` so
    they stay available after the call.

    Args:
        index: Index of the trajectory file to load.

    Returns:
        Whatever ``ase.visualize.view`` returns for the ``"ngl"`` viewer.
    """
    global atoms_list
    atoms_list = _to_atoms_list(index)

    # An earlier variant configured the widget by hand through
    # nglview.show_asetraj (unit cell, spacefill, orthographic camera);
    # the ASE viewer wrapper is used instead.
    return ase.visualize.view(atoms_list, viewer="ngl")


def _save_trajectory(
    atoms_list: list[ase.Atoms],
    save_path: Path,
    zip_results: bool = True,
):
    # Create a 2x2x2 supercell
    for i, atoms in enumerate(atoms_list):
        ase.io.write(save_path / f"{i}.cif", atoms)

    if zip_results:
        import zipfile

        with zipfile.ZipFile(save_path.with_suffix(".zip"), "w") as z:
            for i in range(len(atoms_list)):
                z.write(save_path / f"{i}.cif")
    print(f"Saved trajectory to {save_path}")


def save_trajectory(idx: int, save_dir: Path, supercell: bool = True):
    """Export the trajectory with index ``idx`` to ``save_dir / "<idx>.xyz"``.

    Args:
        idx: Index of the stored trajectory to load.
        save_dir: Directory that receives the output file.
        supercell: Forwarded to ``_to_atoms_list``.
    """
    frames = _to_atoms_list(idx, supercell=supercell)
    out_file = save_dir / f"{idx}.xyz"
    ase.io.write(out_file, frames)


# Index of the trajectory inspected in this cell.
INDEX = 205

# Manual switches for the optional steps below.
RENDER_TRAJECTORY = False
SAVE_TRAJECTORY = True

if RENDER_TRAJECTORY:
    atoms_list = _to_atoms_list(INDEX)
    render_trajectory(atoms_list)

if SAVE_TRAJECTORY:
    save_dir = Path("./trajectory_outputs/")
    save_dir.mkdir(exist_ok=True)
    # A catch-all .gitignore keeps the exported files out of version control.
    (save_dir / ".gitignore").write_text("*\n")

    save_trajectory(INDEX, save_dir, supercell=True)

# Show the interactive viewer for the selected trajectory.
view = display_structure(INDEX)
display(view)

HBox(children=(NGLWidget(max_frame=35), VBox(children=(Dropdown(description='Show', options=('All', 'Pu', 'Mg'…

In [31]:
from tqdm.auto import tqdm


def save_all(
    indices: list[int],
    save_dir: Path,
    supercell: bool = True,
    zip_results: bool = True,
):
    """Export every listed trajectory and optionally archive the output folder.

    Args:
        indices: Trajectory indices to export, processed in order.
        save_dir: Directory receiving the exported files.
        supercell: Forwarded to ``save_trajectory``.
        zip_results: When true, every entry found directly in ``save_dir`` is
            added to a zip archive named like ``save_dir`` with a ``.zip``
            suffix.
    """
    for traj_index in tqdm(indices):
        save_trajectory(traj_index, save_dir, supercell=supercell)

    if not zip_results:
        return

    import zipfile

    archive_path = save_dir.with_suffix(".zip")
    with zipfile.ZipFile(archive_path, "w") as archive:
        # Everything currently in the directory goes into the archive.
        for entry in save_dir.iterdir():
            archive.write(entry)


# Collect the integer stem of every `<index>.dill` file in the trajectory
# directory.
all_indices = []
for dill_path in problematic_traj_base_dir.glob("*.dill"):
    all_indices.append(int(dill_path.stem))

# Prepare the output folder and keep its contents out of version control.
save_dir = Path("./trajectory_outputs/")
save_dir.mkdir(exist_ok=True)
gitignore_path = save_dir / ".gitignore"
gitignore_path.write_text("*\n")

save_all(all_indices, save_dir, supercell=True, zip_results=True)

  0%|          | 0/410 [00:00<?, ?it/s]

In [None]:
# NOTE(review): no visible cell assigns a module-level `data` (the only `data`
# shown is local to `_to_atoms_list`), so this presumably relies on leftover
# kernel state — confirm it survives Restart & Run All.
nu.display(data)


Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /opt/conda/conda-bld/pytorch_1708025845868/work/c10/core/TensorImpl.h:1908.)

