# Crystal Toolkit Relaxation Viewer

This notebook shows how to visualize a CHGNet relaxation trajectory in a Plotly Dash app using Crystal Toolkit.

In [None]:
from __future__ import annotations

# install CHGNet with extra dependency Crystal Toolkit to run the dash app in this notebook
# https://github.com/materialsproject/crystaltoolkit
# (only needed on Google Colab or if you didn't install these packages yet)
!git clone --depth 1 https://github.com/CederGroupHub/chgnet
!pip install ./chgnet[crystal-toolkit]

In [None]:
import numpy as np
from pymatgen.core import Structure

In [None]:
# Load the o-LiMnO2 example structure: prefer the CIF shipped with a local CHGNet
# checkout, otherwise download the same file from the CHGNet GitHub repo.
try:
    from chgnet import ROOT

    structure = Structure.from_file(f"{ROOT}/examples/o-LiMnO2_unit.cif")
except (ImportError, OSError):
    # chgnet (or its ROOT) is not importable, or it was installed without the
    # examples/ folder (missing file -> FileNotFoundError, a subclass of OSError)
    from urllib.request import urlopen

    url = "https://github.com/CederGroupHub/chgnet/raw/main/examples/o-LiMnO2_unit.cif"
    # close the HTTP response deterministically and don't hang forever offline
    with urlopen(url, timeout=30) as response:
        cif = response.read().decode("utf-8")
    structure = Structure.from_str(cif, fmt="cif")

In [None]:
print(structure.get_space_group_info())

# Seeded generator so the perturbation, and therefore the relaxation trajectory
# and plots below, are reproducible on Restart & Run All.
# NOTE: this cell modifies `structure` in place, so re-running it perturbs and
# stretches the already-perturbed structure again.
rng = np.random.default_rng(seed=0)

# perturb all atom positions by a small amount (Gaussian noise on Cartesian coords)
for site in structure:
    site.coords += rng.normal(size=3) * 0.3

# stretch the cell by a small amount (10% larger volume)
structure.scale_lattice(structure.volume * 1.1)

structure.get_space_group_info()

('Pmmn', 59)


('P1', 1)

In [None]:
import pandas as pd

from chgnet.model import StructOptimizer

# Relax the perturbed structure with CHGNet using the default StructOptimizer
# settings. Only the recorded trajectory is kept: later cells read its per-step
# energies, forces, cells and atom positions.
relaxer = StructOptimizer()
trajectory = relaxer.relax(structure)["trajectory"]

CHGNet initialized with 400,438 parameters
CHGNet will run on cpu
      Step     Time          Energy         fmax
*Force-consistent energies used in optimization.
FIRE:    0 16:42:01      -52.137783*      22.3727
FIRE:    1 16:42:01      -54.654541*       7.3535
FIRE:    2 16:42:01      -55.012409*       6.0275
FIRE:    3 16:42:01      -55.455532*       3.5359
FIRE:    4 16:42:01      -55.792427*       5.8862
FIRE:    5 16:42:01      -56.233017*       4.8161
FIRE:    6 16:42:01      -56.729580*       3.8116
FIRE:    7 16:42:01      -57.105877*       3.7746
FIRE:    8 16:42:01      -57.478939*       4.3329
FIRE:    9 16:42:01      -57.859035*       2.4245
FIRE:   10 16:42:01      -57.905327*       5.3127
FIRE:   11 16:42:01      -57.972210*       4.0170
FIRE:   12 16:42:02      -58.060028*       2.1152
FIRE:   13 16:42:02      -58.125294*       1.7038
FIRE:   14 16:42:02      -58.166508*       2.5232
FIRE:   15 16:42:02      -58.206398*       2.7542
FIRE:   16 16:42:02      -58.255352*

In [None]:
e_col = "Energy (eV)"
force_col = "Force (eV/Å)"

# One row per relaxation step: the energy, plus the per-atom force magnitudes
# (norm along the last axis) averaged over all atoms.
df_traj = (
    pd.DataFrame(trajectory.energies, columns=[e_col])
    .assign(
        **{
            force_col: [
                np.linalg.norm(step_forces, axis=1).mean()
                for step_forces in trajectory.forces
            ]
        }
    )
    .rename_axis("step")
)

In [None]:
from pymatgen.ext.matproj import MPRester

mp_id = "mp-18767"

try:
    # new MPRester expects len(api_key) == 32
    with MPRester(use_document_model=False) as mpr:
        mp_docs = mpr.thermo.search(material_ids=[mp_id])
except AttributeError:
    # old MPRester expects len(api_key) ~= 16
    with MPRester() as mpr:
        mp_docs = mpr.query(mp_id, ["energy_per_atom", "nsites"])

if not mp_docs:
    raise ValueError(f"No Materials Project energy data found for {mp_id}")

# A material can have several thermo docs (one per energy mixing scheme, e.g.
# GGA_GGA+U / GGA_GGA+U_R2SCAN / R2SCAN) and their order is not guaranteed, so
# prefer the GGA/GGA+U one explicitly; otherwise (e.g. legacy API docs without a
# thermo_type field) fall back to the first doc as before.
mp_doc = next(
    (doc for doc in mp_docs if str(doc.get("thermo_type")) == "GGA_GGA+U"),
    mp_docs[0],
)

# total energy of the MP reference cell = per-atom energy * number of sites
dft_energy = mp_doc["energy_per_atom"] * mp_doc["nsites"]
print(f"{dft_energy=:.2f}")

Problem loading MPContribs client: Client.__init__() got an unexpected keyword argument 'session'


Retrieving ThermoDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

dft_energy=-59.09


In [None]:
import crystal_toolkit.components as ctc
import plotly.graph_objects as go
from crystal_toolkit.settings import SETTINGS
from dash import dcc, html
from dash.dependencies import Input, Output
from jupyter_dash import JupyterDash
from pymatgen.core import Structure

app = JupyterDash(
    assets_folder=SETTINGS.ASSETS_PATH,
    prevent_initial_callbacks=True,
)

# Build the structure component and its layout only once per kernel session:
# creating the layout again when this cell is re-run causes duplicate-ID errors
# when restarting the Dash app in the notebook.
if "struct_layout" not in vars():
    struct_comp = ctc.StructureMoleculeComponent(
        struct_or_mol=structure,
        id="structure",
    )
    struct_layout = struct_comp.layout()


# Slider over relaxation steps with about 20 evenly spaced stops plus the final
# (relaxed) step. A plain `step=step_size` makes `max` unreachable whenever
# len(trajectory) - 1 is not a multiple of step_size, so the stops are given as
# marks and `step=None` restricts the slider to exactly those values.
step_size = max(1, len(trajectory) // 20)  # ensure slider has ~20 stops at most
last_step = len(trajectory) - 1
slider_stops = sorted({*range(0, last_step + 1, step_size), last_step})
slider = dcc.Slider(
    id="slider",
    min=0,
    max=last_step,
    step=None,  # only allow values that appear in `marks`
    marks={stop: str(stop) for stop in slider_stops},
    value=0,  # the figure is initially drawn with its marker at step 0
    updatemode="drag",
)


def plot_energy_and_forces(
    df: pd.DataFrame,
    step: int,
    e_col: str,
    force_col: str,
    title: str,
    ref_energy: float | None = None,
) -> go.Figure:
    """Plot energy and forces as a function of relaxation step.

    Energy is drawn on the primary (left) y-axis and forces on a secondary
    (right) y-axis, with a dashed vertical line at ``step`` and a dotted
    horizontal line at the DFT reference energy in the energy trace's color.

    Args:
        df: Data indexed by relaxation step, with ``e_col`` and ``force_col`` columns.
        step: Relaxation step to mark with the vertical line.
        e_col: Name of the energy column (also the left axis title).
        force_col: Name of the force column (also the right axis title).
        title: Figure title (may contain HTML such as a link).
        ref_energy: Energy of the horizontal reference line. Defaults to the
            notebook-level ``dft_energy``.

    Returns:
        The assembled Plotly figure.
    """
    if ref_energy is None:
        ref_energy = dft_energy

    # Colors are set explicitly: Plotly only fills in colorway colors at render
    # time, so reading `fig.data[0].line.color` back gives None and the reference
    # line would not actually match the energy trace. These are the first two
    # colors of Plotly's default colorway, so the look is unchanged.
    energy_color, force_color = "#636efa", "#EF553B"

    fig = go.Figure()
    # energy trace = primary y-axis
    fig.add_trace(
        go.Scatter(
            x=df.index,
            y=df[e_col],
            mode="lines",
            name="Energy",
            line=dict(color=energy_color),
        )
    )
    # forces trace = secondary y-axis
    fig.add_trace(
        go.Scatter(
            x=df.index,
            y=df[force_col],
            mode="lines",
            name="Forces",
            yaxis="y2",
            line=dict(color=force_color),
        )
    )

    fig.update_layout(
        template="plotly_white",
        title=title,
        xaxis=dict(title="Relaxation Step"),
        yaxis=dict(title=e_col),
        yaxis2=dict(title=force_col, overlaying="y", side="right"),
        legend=dict(yanchor="top", y=1, xanchor="right", x=1),
    )

    # vertical line at the specified step
    fig.add_vline(x=step, line=dict(dash="dash", width=1))

    # horizontal line for DFT final energy, colored like the energy trace
    anno = dict(text="DFT final energy", yanchor="top")
    fig.add_hline(
        y=ref_energy,
        line=dict(dash="dot", width=1, color=energy_color),
        annotation=anno,
    )

    return fig


def make_title(spg: tuple[str, int]) -> str:
    """Build the figure title: a link to the Materials Project page of ``mp_id``
    followed by the space group symbol and number from ``spg``.
    """
    symbol, number = spg[0], spg[1]
    url = f"https://materialsproject.org/materials/{mp_id}/"
    link = f"<a href={url!r}>{mp_id}</a>"
    return f"{link} - {symbol} ({number})"


title = make_title(spg=structure.get_space_group_info())

# energy/force figure, initially marking relaxation step 0
graph = dcc.Graph(
    figure=plot_energy_and_forces(
        df=df_traj, step=0, e_col=e_col, force_col=force_col, title=title
    ),
    id="fig",
    style=dict(maxWidth="50%"),
)

app.layout = html.Div(
    [
        html.H1(
            "Structure Relaxation Trajectory", style=dict(margin="1em", fontSize="2em")
        ),
        html.P("Drag slider to see structure at different relaxation steps."),
        slider,
        # Reuse `struct_layout`, which the `if "struct_layout" not in vars()` guard
        # builds only once; calling struct_comp.layout() here again would defeat
        # that guard against duplicate component IDs.
        html.Div([struct_layout, graph], style=dict(display="flex", gap="2em")),
    ],
    style=dict(margin="auto", textAlign="center", maxWidth="1200px", padding="2em"),
)

# let Crystal Toolkit hook the components used in this layout into the app
ctc.register_crystal_toolkit(app=app, layout=app.layout)


@app.callback(
    Output(struct_comp.id(), "data"), Output(graph, "figure"), Input(slider, "value")
)
def update_structure(step: int) -> tuple[Structure, go.Figure]:
    """Update the structure displayed in the StructureMoleculeComponent and the
    dashed vertical line in the figure when the slider is moved.

    A new Structure is built for the selected frame instead of mutating the
    notebook-level ``structure`` in place, so dragging the slider does not
    silently change the input that other cells read when they are re-run.

    Args:
        step: Index into the relaxation trajectory selected on the slider.

    Returns:
        The structure at that step and the figure with its marker moved to it.
    """
    coords = trajectory.atom_positions[step]
    assert len(structure) == len(coords), (
        f"trajectory frame has {len(coords)} atoms, structure has {len(structure)}"
    )
    frame = Structure(
        lattice=trajectory.cells[step],
        species=[site.species for site in structure],
        coords=coords,
        coords_are_cartesian=True,
        site_properties=structure.site_properties,
    )

    title = make_title(frame.get_space_group_info())
    fig = plot_energy_and_forces(df_traj, step, e_col, force_col, title)

    return frame, fig


# Serve the Dash app inline in this notebook's output area (height 800),
# with the reloader turned off.
app.run_server(use_reloader=False, mode="inline", height=800)

Dash is running on http://127.0.0.1:8050/

