# Crystal Toolkit Relaxation Viewer

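This notebook randomly perturbs the atom positions of an orthorhombic LiMnO2 cell, relaxes it with CHGNet, and visualizes the resulting relaxation trajectory in a Plotly Dash app using Crystal Toolkit. A slider selects the relaxation step, and an energy plot marks the current step.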

In [None]:
from __future__ import annotations

import numpy as np
import plotly.graph_objects as go
from pymatgen.core import Structure

In [None]:
structure = Structure.from_file("./o-LiMnO2_unit.cif")
print(structure)

# Perturb every atom position with Gaussian noise (std. dev. 0.1 Å per Cartesian
# component) so the relaxation has something to do. A seeded generator makes the
# perturbed starting structure reproducible across runs.
rng = np.random.default_rng(seed=0)
for site in structure:
    site.coords += rng.normal(size=3) * 0.1

# Optionally also stretch the cell by a small amount:
# structure.scale_lattice(structure.volume * 1.1)

# Optionally inspect the symmetry of the perturbed structure:
# structure.get_space_group_info()

Full Formula (Li2 Mn2 O4)
Reduced Formula: LiMnO2
abc   :   2.868779   4.634475   5.832507
angles:  90.000000  90.000000  90.000000
pbc   :       True       True       True
Sites (8)
  #  SP      a    b         c
---  ----  ---  ---  --------
  0  Li+   0.5  0.5  0.37975
  1  Li+   0    0    0.62025
  2  Mn3+  0.5  0.5  0.863252
  3  Mn3+  0    0    0.136747
  4  O2-   0.5  0    0.360824
  5  O2-   0    0.5  0.098514
  6  O2-   0.5  0    0.901486
  7  O2-   0    0.5  0.639176


In [None]:
import pandas as pd

from chgnet.model import StructOptimizer

# Relax the perturbed structure with CHGNet. The returned dict exposes the
# recorded relaxation trajectory under the "trajectory" key, used below.
optimizer = StructOptimizer()
results = optimizer.relax(structure)

CHGNet initialized with 400,438 parameters
CHGNet will run on cpu
      Step     Time          Energy         fmax
*Force-consistent energies used in optimization.
FIRE:    0 16:56:55      -58.602581*       2.7941
FIRE:    1 16:56:55      -58.687492*       3.6510
FIRE:    2 16:56:55      -58.745514*       1.8767
FIRE:    3 16:56:55      -58.778030*       1.3059
FIRE:    4 16:56:55      -58.782005*       1.2717
FIRE:    5 16:56:55      -58.788986*       1.2080
FIRE:    6 16:56:55      -58.797585*       1.1220
FIRE:    7 16:56:55      -58.806774*       1.0215
FIRE:    8 16:56:55      -58.816338*       0.9132
FIRE:    9 16:56:55      -58.826687*       0.8259
FIRE:   10 16:56:56      -58.838154*       0.7950
FIRE:   11 16:56:56      -58.851761*       0.7490
FIRE:   12 16:56:56      -58.866394*       0.6859
FIRE:   13 16:56:56      -58.880363*       0.6057
FIRE:   14 16:56:56      -58.893337*       0.5547
FIRE:   15 16:56:56      -58.906376*       0.4945
FIRE:   16 16:56:56      -58.918114*

In [None]:
# Rebuild one pymatgen Structure per relaxation step. CHGNet's trajectory records
# ASE `Atoms.get_positions()`, i.e. Cartesian coordinates, so they must be flagged
# as Cartesian rather than being misread as fractional coordinates.
traj = results["trajectory"]
struct_traj: list[Structure] = [
    Structure(lattice, structure.species, coords, coords_are_cartesian=True)
    for coords, lattice in zip(traj.atom_positions, traj.cells)
]

# The trajectory energies are total energies of the whole cell (compare the FIRE
# log above, ~-58.6 eV for 8 sites), so normalize by the number of sites to match
# the per-atom column label.
e_col = "energy (eV/atom)"
df_traj = pd.DataFrame({e_col: np.asarray(traj.energies) / len(structure)})
df_traj.index.name = "step"

In [None]:
import crystal_toolkit.components as ctc
import plotly.express as px
from crystal_toolkit.settings import SETTINGS
from dash import dcc, html
from dash.dependencies import Input, Output
from jupyter_dash import JupyterDash
from pymatgen.core import Structure

app = JupyterDash(assets_folder=SETTINGS.ASSETS_PATH, prevent_initial_callbacks=True)

# 3D structure viewer, initially showing the perturbed input structure.
struct_comp = ctc.StructureMoleculeComponent(struct_or_mol=structure, id="structure")

# Slider over relaxation steps; long trajectories are coarsened so the slider
# has at most ~50 positions.
n_steps = len(struct_traj)
step_size = max(n_steps // 50, 1)
slider = dcc.Slider(id="slider", min=0, max=n_steps - 1, step=step_size, value=0)

# Energy curve with a dashed marker at the currently selected step.
fig = px.line(
    df_traj, y=e_col, title="Energy during relaxation", template="plotly_white"
)
fig.add_vline(x=0, line={"dash": "dash", "width": 1})
graph = dcc.Graph(id="fig", figure=fig, style={"maxWidth": "50%"})

heading = html.H1(
    "Structure Relaxation Trajectory", style={"margin": "1em", "fontSize": "2em"}
)
instructions = html.P("Drag slider to see structure at different relaxation steps.")
viewer_row = html.Div(
    [struct_comp.layout(), graph], style={"display": "flex", "gap": "2em"}
)
app.layout = html.Div(
    [heading, instructions, slider, viewer_row],
    style={"margin": "10vw", "placeItems": "center", "textAlign": "center"},
)

ctc.register_crystal_toolkit(app=app, layout=app.layout)


@app.callback(
    Output(struct_comp.id(), "data"), Output(graph, "figure"), Input(slider, "value")
)
def update_structure(step: int) -> tuple[Structure, go.Figure]:
    """Update the structure viewer and the energy plot when the slider moves.

    Args:
        step (int): Relaxation step selected on the slider, used as an index
            into `struct_traj`.

    Returns:
        tuple[Structure, go.Figure]: The structure at `step` and the energy
            figure with a dashed vertical line at x=step.
    """
    # Rebuild the figure with the same column and title as the initial plot so
    # the graph keeps its appearance once the slider is moved.
    fig = px.line(
        df_traj, y=e_col, template="plotly_white", title="Energy during relaxation"
    )
    fig.add_vline(x=step, line=dict(dash="dash", width=1))
    return struct_traj[step], fig


app.run_server(mode="inline")

  warn("The TEMDiffractionComponent requires the py4DSTEM package.")
  warn(


Dash is running on http://127.0.0.1:8050/

