# Crystal Toolkit Relaxation Viewer

This notebook shows how to visualize a CHGNet relaxation trajectory in a Plotly Dash app using Crystal Toolkit.

In [None]:
from __future__ import annotations

import numpy as np
from pymatgen.core import Structure

In [None]:
structure = Structure.from_file("./o-LiMnO2_unit.cif")
print(structure.get_space_group_info())

# Seeded generator: the perturbed starting structure, and hence the whole
# relaxation trajectory below, is reproducible on Restart & Run All.
rng = np.random.default_rng(seed=0)

# perturb all atom positions by a small amount
# (Gaussian noise with std 0.3 on each Cartesian coordinate)
for site in structure:
    site.coords += rng.normal(size=3) * 0.3

# stretch the cell by a small amount (volume scaled up by 10%)
structure.scale_lattice(structure.volume * 1.1)

structure.get_space_group_info()

('Pmmn', 59)


('P1', 1)

In [None]:
import pandas as pd

from chgnet.model import StructOptimizer

# Relax the perturbed structure with CHGNet. The returned dict's "trajectory"
# entry is used below to replay every optimization step.
optimizer = StructOptimizer()
results = optimizer.relax(structure)

CHGNet initialized with 400,438 parameters
CHGNet will run on cpu
      Step     Time          Energy         fmax
*Force-consistent energies used in optimization.
FIRE:    0 15:57:43      -53.022331*      21.0648
FIRE:    1 15:57:43      -54.552948*      10.8460
FIRE:    2 15:57:43      -54.989319*       7.2947
FIRE:    3 15:57:43      -55.410568*       3.5495
FIRE:    4 15:57:43      -55.484802*       7.7296
FIRE:    5 15:57:43      -55.577263*       6.4933
FIRE:    6 15:57:43      -55.720509*       4.2097
FIRE:    7 15:57:43      -55.859726*       2.9528
FIRE:    8 15:57:43      -55.956528*       2.6131
FIRE:    9 15:57:43      -56.020336*       3.7622
FIRE:   10 15:57:44      -56.096500*       4.8202
FIRE:   11 15:57:44      -56.210766*       4.9993
FIRE:   12 15:57:44      -56.382362*       4.2564
FIRE:   13 15:57:44      -56.594078*       2.5059
FIRE:   14 15:57:44      -56.781483*       2.2923
FIRE:   15 15:57:44      -56.902622*       4.4940
FIRE:   16 15:57:44      -57.080509*

In [None]:
# Rebuild one pymatgen Structure per recorded relaxation frame from the
# trajectory's Cartesian atom positions and lattice matrices.
trajectory = results["trajectory"]
struct_traj: list[Structure] = [
    Structure(lattice, structure.species, coords, coords_are_cartesian=True)
    for coords, lattice in zip(trajectory.atom_positions, trajectory.cells)
]

e_col = "energy (eV/atom)"
vol_col = "volume (A^3)"
spg_col = "spacegroup"
# The trajectory stores energies for the whole cell (the FIRE log above shows
# roughly -53 to -57 eV for this small cell). Normalize by the atom count so the
# values actually match the "eV/atom" column label.
n_atoms = len(structure)
df_traj = pd.DataFrame(
    {
        e_col: np.asarray(trajectory.energies, dtype=float) / n_atoms,
        vol_col: [struct.volume for struct in struct_traj],
        spg_col: [struct.get_space_group_info() for struct in struct_traj],
    }
)
df_traj.index.name = "step"

In [None]:
import crystal_toolkit.components as ctc
import plotly.express as px
import plotly.graph_objects as go
from crystal_toolkit.settings import SETTINGS
from dash import dcc, html
from dash.dependencies import Input, Output
from jupyter_dash import JupyterDash
from pymatgen.core import Structure

app = JupyterDash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH)

struct_comp = ctc.StructureMoleculeComponent(id="structure", struct_or_mol=structure)


n_frames = len(struct_traj)
max_marks = 20
# ceiling division so there are at most `max_marks` regularly spaced marks
step_size = max(1, -(-n_frames // max_marks))
# With a plain `step`, Dash only allows values min + k * step, so the final
# (fully relaxed) frame was unreachable whenever (n_frames - 1) % step_size != 0.
# Use explicit marks with step=None (selection snaps to marks only) and always
# include the last frame.
slider_marks = {idx: str(idx) for idx in range(0, n_frames, step_size)}
slider_marks[n_frames - 1] = str(n_frames - 1)
slider = dcc.Slider(
    id="slider",
    min=0,
    max=n_frames - 1,
    value=0,
    step=None,
    marks=slider_marks,
)


def plot_energy(df: pd.DataFrame, step: int) -> go.Figure:
    """Build a line chart of the energy column over relaxation steps.

    Args:
        df: Trajectory table holding the energy and space-group columns.
        step: Index of the currently selected frame. It is marked by a dashed
            vertical line, and its space group is shown in the title.

    Returns:
        The Plotly figure.
    """
    spacegroup_at_step = df[spg_col][step]
    figure = px.line(
        df,
        y=e_col,
        template="plotly_white",
        title=f"{spg_col} = {spacegroup_at_step}",
    )
    figure.add_vline(x=step, line={"dash": "dash", "width": 1})
    return figure


# energy-vs-step chart, initially highlighting the first frame
graph = dcc.Graph(
    id="fig", figure=plot_energy(df_traj, 0), style=dict(maxWidth="50%")
)

app.layout = html.Div(
    children=[
        html.H1(
            "Structure Relaxation Trajectory",
            style={"margin": "1em", "fontSize": "2em"},
        ),
        html.P("Drag slider to see structure at different relaxation steps."),
        slider,
        # structure viewer and energy chart side by side
        html.Div(
            [struct_comp.layout(), graph],
            style={"display": "flex", "gap": "2em"},
        ),
    ],
    style={
        "margin": "2em auto",
        "placeItems": "center",
        "textAlign": "center",
        "maxWidth": "1200px",
    },
)

# register Crystal Toolkit with the app, passing the finished layout
ctc.register_crystal_toolkit(app=app, layout=app.layout)


@app.callback(
    Output(struct_comp.id(), "data"),
    Output(graph, "figure"),
    Input(slider, "value"),
)
def update_structure(step: int) -> tuple[Structure, go.Figure]:
    """Keep the structure viewer and energy chart in sync with the slider.

    Returns the trajectory frame at ``step`` for the structure component, plus
    a freshly drawn energy figure whose dashed line sits at ``step``.
    """
    frame = struct_traj[step]
    energy_fig = plot_energy(df_traj, step)
    return frame, energy_fig


app.run_server(height=800, mode="inline")

  warn("The TEMDiffractionComponent requires the py4DSTEM package.")


Dash is running on http://127.0.0.1:8050/

