# Crystal Toolkit Relaxation Viewer

This notebook shows how to visualize a CHGNet relaxation trajectory in a Plotly Dash app using Crystal Toolkit.

In [None]:
from __future__ import annotations

import numpy as np
from pymatgen.core import Structure

In [None]:
# Load the reference structure and report its space group before any distortion.
structure = Structure.from_file("./o-LiMnO2_unit.cif")
print(structure.get_space_group_info())

# Seeded generator so that the random perturbation (and therefore everything
# derived from it further down) is identical after Restart Kernel -> Run All.
rng = np.random.default_rng(seed=0)

# perturb all atom positions by a small amount: independent Gaussian noise with
# standard deviation 0.3 is added to every component of each site's coords
for site in structure:
    site.coords += rng.normal(scale=0.3, size=3)

# stretch the cell by a small amount (target volume = 1.1 x current volume)
structure.scale_lattice(structure.volume * 1.1)

# last expression of the cell: space group of the distorted structure
structure.get_space_group_info()

('Pmmn', 59)


('P1', 1)

In [None]:
import pandas as pd

from chgnet.model import StructOptimizer

# Build a default-configured CHGNet optimizer and relax the distorted structure.
# The returned mapping is read below through its "trajectory" entry.
chgnet_optimizer = StructOptimizer()
results = chgnet_optimizer.relax(structure)

CHGNet initialized with 400,438 parameters
CHGNet will run on cpu


Exception: input should either be a graph or list of graphs!

In [None]:
trajectory = results["trajectory"]

# Rebuild one Structure per relaxation step from the positions and the cell
# recorded for that step; species never change during the relaxation, so they are
# taken from the input structure.
struct_traj: list[Structure] = [
    Structure(lattice, structure.species, coords, coords_are_cartesian=True)
    for lattice, coords in zip(trajectory.cells, trajectory.atom_positions)
]

# Column labels are shared with the plotting code further down.
e_col = "Energy (eV)"
force_col = "Force (eV/Å)"
vol_col = "Volume (Å<sup>3</sup>)"  # closing tag was previously malformed ("</sup")
spg_col = "Spacegroup"
df_traj = pd.DataFrame(
    {
        e_col: trajectory.energies,
        # plot_energy_and_forces() reads df[force_col], so the column must exist.
        # One scalar per step: mean over atoms of the norm of the force vector.
        # Assumes each entry of trajectory.forces is an (n_atoms, 3) array --
        # TODO confirm against the installed CHGNet version.
        force_col: [
            np.linalg.norm(forces, axis=1).mean() for forces in trajectory.forces
        ],
        vol_col: [struct.volume for struct in struct_traj],
        spg_col: [struct.get_space_group_info() for struct in struct_traj],
    }
)
df_traj.index.name = "step"

In [None]:
from pymatgen.ext.matproj import MPRester

# Materials Project ID whose thermo document supplies the DFT reference energy
# (presumably the entry matching the CIF loaded above -- verify).
mp_id = "mp-18767"

with MPRester() as mpr:
    thermo_docs = mpr.thermo.search(material_ids=[mp_id])

# Only the first returned document is used.
mp_doc = thermo_docs[0]

# Per-atom energy scaled by the number of sites in the document.
dft_energy = mp_doc.nsites * mp_doc.energy_per_atom

Problem loading MPContribs client: Client.__init__() got an unexpected keyword argument 'session'


Retrieving ThermoDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
import crystal_toolkit.components as ctc
import plotly.graph_objects as go
from crystal_toolkit.settings import SETTINGS
from dash import dcc, html
from dash.dependencies import Input, Output
from jupyter_dash import JupyterDash
from pymatgen.core import Structure

app = JupyterDash(assets_folder=SETTINGS.ASSETS_PATH, prevent_initial_callbacks=True)

# Build the structure component (and its layout) only the first time this cell
# runs in a kernel session: re-creating it when the Dash app is restarted from
# the notebook causes duplicate ID errors.
if "struct_layout" not in vars():
    struct_comp = ctc.StructureMoleculeComponent(
        struct_or_mol=structure, id="structure"
    )
    struct_layout = struct_comp.layout()


# Coarsen the slider for long trajectories: the step is 1/20th of the number of
# frames (integer division), but never smaller than one frame.
step_size = max(len(struct_traj) // 20, 1)
slider = dcc.Slider(
    id="slider",
    min=0,
    max=len(struct_traj) - 1,
    step=step_size,
    updatemode="drag",
)


def plot_energy_and_forces(
    df: pd.DataFrame, step: int, e_col: str, force_col: str
) -> go.Figure:
    """Plot energy and forces as a function of relaxation step.

    Args:
        df: Trajectory table indexed by relaxation step. Must contain the columns
            named by ``e_col`` and ``force_col`` as well as the notebook-level
            ``spg_col`` column.
        step: Relaxation step to mark with a dashed vertical line; it also selects
            the space group shown in the figure title.
        e_col: Name of the column drawn on the primary (left) y-axis.
        force_col: Name of the column drawn on the secondary (right) y-axis.

    Returns:
        Figure holding both traces, the vertical marker at ``step`` and a dotted
        horizontal line at the notebook-level ``dft_energy``.

    Note:
        Besides its arguments, this function reads the notebook globals ``mp_id``,
        ``spg_col`` and ``dft_energy``.
    """
    from plotly.colors import qualitative

    # A trace created without an explicit color reports line.color as None (the
    # colorway is only resolved when the figure is rendered), so the color cannot
    # be read back from the trace. Pin the first default colorway entry instead
    # and reuse it for the DFT reference line.
    energy_color = qualitative.Plotly[0]

    fig = go.Figure()
    # energy trace = primary y-axis
    fig.add_trace(
        go.Scatter(
            x=df.index,
            y=df[e_col],
            mode="lines",
            name="Energy",
            line=dict(color=energy_color),
        )
    )

    # forces trace = secondary y-axis
    fig.add_trace(
        go.Scatter(x=df.index, y=df[force_col], mode="lines", name="Forces", yaxis="y2")
    )

    # Set up the layout
    fig.update_layout(
        template="plotly_white",
        title=f"{mp_id} - {spg_col} = {df.loc[step, spg_col]}",
        xaxis=dict(title="Relaxation Step"),
        yaxis=dict(title=e_col),
        yaxis2=dict(title=force_col, overlaying="y", side="right"),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )

    # Add vertical line at the specified step
    fig.add_vline(x=step, line=dict(dash="dash", width=1))

    # Add horizontal line for MP final energy, colored like the energy trace
    anno = dict(text="DFT final energy", yanchor="top")
    fig.add_hline(
        y=dft_energy,
        line=dict(dash="dot", width=1, color=energy_color),
        annotation=anno,
    )

    return fig


# Initial figure highlights step 0; the slider callback replaces it afterwards.
graph = dcc.Graph(
    id="fig",
    figure=plot_energy_and_forces(df_traj, 0, e_col, force_col),
    style={"maxWidth": "50%"},
)

app.layout = html.Div(
    [
        html.H1(
            "Structure Relaxation Trajectory", style=dict(margin="1em", fontSize="2em")
        ),
        html.P("Drag slider to see structure at different relaxation steps."),
        slider,
        # Reuse the layout cached by the "struct_layout" guard earlier in this cell
        # rather than calling struct_comp.layout() a second time.
        html.Div([struct_layout, graph], style=dict(display="flex", gap="2em")),
    ],
    style=dict(
        margin="2em auto", placeItems="center", textAlign="center", maxWidth="1200px"
    ),
)

# Let Crystal Toolkit register itself with this app and the layout above.
ctc.register_crystal_toolkit(app=app, layout=app.layout)


@app.callback(
    Output(struct_comp.id(), "data"),
    Output(graph, "figure"),
    Input(slider, "value"),
)
def update_structure(step: int) -> tuple[Structure, go.Figure]:
    """Refresh the viewer and the plot for the relaxation step chosen on the slider.

    Args:
        step: Current slider value, used as an index into ``struct_traj`` and as
            the step highlighted in the figure.

    Returns:
        The structure at ``step`` (sent to the structure component's "data"
        property) and a freshly drawn energy/force figure whose dashed vertical
        line sits at ``step``.
    """
    selected_struct = struct_traj[step]
    refreshed_fig = plot_energy_and_forces(df_traj, step, e_col, force_col)
    return selected_struct, refreshed_fig


# Start the app embedded in the notebook output (800 px tall), with the
# auto-reloader switched off.
app.run_server(
    mode="inline",
    use_reloader=False,
    height=800,
)

The TEMDiffractionComponent requires the py4DSTEM package.


Dash is running on http://127.0.0.1:8050/

