# Crystal Toolkit Relaxation Viewer

This notebook shows how to visualize a CHGNet relaxation trajectory in a Plotly Dash app using Crystal Toolkit.

In [None]:
from __future__ import annotations

import numpy as np
from pymatgen.core import Structure

In [None]:
structure = Structure.from_file("./o-LiMnO2_unit.cif")
print(structure.get_space_group_info())

# Seed the RNG so that Restart & Run All reproduces the same perturbed starting
# structure (and therefore the same relaxation trajectory further down).
rng = np.random.default_rng(seed=42)
PERTURB_STD = 0.3  # std. dev. of the random Cartesian displacement added per axis
LATTICE_STRETCH = 1.1  # factor applied to the cell volume (1.1 -> +10% volume)

# perturb all atom positions by a small amount
for site in structure:
    site.coords += rng.normal(size=3) * PERTURB_STD

# stretch the cell by a small amount
structure.scale_lattice(structure.volume * LATTICE_STRETCH)

structure.get_space_group_info()

('Pmmn', 59)


('P1', 1)

In [None]:
import pandas as pd

from chgnet.model import StructOptimizer

# Relax the perturbed structure with CHGNet; the "trajectory" entry of the
# returned results is used below to rebuild every intermediate structure.
relaxer = StructOptimizer()
results = relaxer.relax(structure)

CHGNet initialized with 400,438 parameters
CHGNet will run on cpu
      Step     Time          Energy         fmax
*Force-consistent energies used in optimization.
FIRE:    0 06:56:46      -49.977566*      35.0804
FIRE:    1 06:56:46      -54.445950*       6.6021
FIRE:    2 06:56:46      -53.259274*      17.6979
FIRE:    3 06:56:46      -54.708866*      13.8175
FIRE:    4 06:56:46      -55.995808*       3.6483
FIRE:    5 06:56:46      -55.250874*      16.6567
FIRE:    6 06:56:46      -55.627010*      13.1561
FIRE:    7 06:56:46      -56.112236*       7.6342
FIRE:    8 06:56:46      -56.430168*       3.1914
FIRE:    9 06:56:46      -56.472912*       3.9039
FIRE:   10 06:56:46      -56.479626*       3.8009
FIRE:   11 06:56:47      -56.492607*       3.5976
FIRE:   12 06:56:47      -56.511051*       3.3000
FIRE:   13 06:56:47      -56.533810*       2.9178
FIRE:   14 06:56:47      -56.559662*       2.4660
FIRE:   15 06:56:47      -56.587429*       2.4747
FIRE:   16 06:56:47      -56.616158*

In [None]:
traj = results["trajectory"]
# One saved frame per optimizer step: positions, cell and energy must line up.
assert len(traj.atom_positions) == len(traj.cells) == len(traj.energies)

# Rebuild a pymatgen Structure for every frame from its Cartesian positions and cell.
struct_traj: list[Structure] = [
    Structure(cell, structure.species, positions, coords_are_cartesian=True)
    for positions, cell in zip(traj.atom_positions, traj.cells)
]

# The trajectory energies are whole-cell energies. They are on the same scale as the
# FIRE log above, and they are compared against an MP energy scaled by a site count.
# They are not per-atom values, so label the column in eV.
e_col = "energy (eV)"
vol_col = "volume (A^3)"
spg_col = "spacegroup"
df_traj = pd.DataFrame(
    {
        e_col: traj.energies,
        vol_col: [struct.volume for struct in struct_traj],
        spg_col: [struct.get_space_group_info() for struct in struct_traj],
    }
)
df_traj.index.name = "step"

In [None]:
from pymatgen.ext.matproj import MPRester

mp_id = "mp-18767"

with MPRester() as mpr:
    thermo_docs = mpr.thermo.search(material_ids=[mp_id])

if not thermo_docs:
    raise ValueError(f"No ThermoDoc found for {mp_id}")
# NOTE(review): a material may have several ThermoDocs (e.g. different thermo types);
# the first one returned is used here. Confirm it is the entry you intend to compare to.
mp_doc = thermo_docs[0]

# The CHGNet trajectory energies are whole-cell energies of `structure`. Scale MP's
# per-atom energy by the number of sites in that same cell, not by MP's own `nsites`,
# which may describe a differently sized cell.
mp_energy = mp_doc.energy_per_atom * len(structure)

Retrieving ThermoDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
import crystal_toolkit.components as ctc
import plotly.express as px
import plotly.graph_objects as go
from crystal_toolkit.settings import SETTINGS
from dash import dcc, html
from dash.dependencies import Input, Output
from jupyter_dash import JupyterDash
from pymatgen.core import Structure

# prevent_initial_callbacks: the initial figure/structure are set directly below,
# so the slider callback does not need to fire on page load.
app = JupyterDash(prevent_initial_callbacks=True, assets_folder=SETTINGS.ASSETS_PATH)

# 3D viewer, initially showing the perturbed input structure.
struct_comp = ctc.StructureMoleculeComponent(id="structure", struct_or_mol=structure)


step_size = max(1, len(struct_traj) // 20)  # ensure slider has max 20 steps
# Start the slider at step 0 so its state matches the initial energy figure
# (plot_energy(df_traj, 0)) instead of being undefined until first dragged.
# NOTE(review): when step_size > 1 the last frame (max) may not lie on the step
# grid; confirm the final relaxed structure is still reachable.
slider = dcc.Slider(
    id="slider",
    min=0,
    max=len(struct_traj) - 1,
    step=step_size,
    value=0,
    updatemode="drag",
)


def plot_energy(df: pd.DataFrame, step: int) -> go.Figure:
    """Plot energy as a function of relaxation step.

    Args:
        df: Trajectory table with one row per step, indexed by step label and
            holding the ``e_col`` (energy) and ``spg_col`` (space group) columns.
        step: Step label to highlight with a dashed vertical line; its space
            group is shown in the title.

    Returns:
        Plotly line figure of energy vs. step, with a dotted horizontal line at
        the Materials Project reference energy ``mp_energy``.
    """
    href = f"https://materialsproject.org/materials/{mp_id}"
    # Single label-based lookup instead of chained df[col][row] indexing.
    spacegroup = df.loc[step, spg_col]
    # `{href=}` renders as href='...', i.e. a quoted HTML attribute.
    title = f"<a {href=}>{mp_id}</a> - {spg_col} = {spacegroup}"
    fig = px.line(df, y=e_col, template="plotly_white", title=title)
    fig.add_vline(x=step, line=dict(dash="dash", width=1))
    anno = dict(text="MP final energy", yanchor="top")
    fig.add_hline(
        y=mp_energy, line=dict(dash="dot", width=0.5, color="darkblue"), annotation=anno
    )
    return fig


# Energy-vs-step figure, initially marking step 0.
graph = dcc.Graph(id="fig", figure=plot_energy(df_traj, 0), style={"maxWidth": "50%"})

# Page pieces, assembled top to bottom: heading, hint, slider, viewer + plot side by side.
page_heading = html.H1(
    "Structure Relaxation Trajectory", style={"margin": "1em", "fontSize": "2em"}
)
usage_hint = html.P("Drag slider to see structure at different relaxation steps.")
viewer_row = html.Div(
    [struct_comp.layout(), graph], style={"display": "flex", "gap": "2em"}
)
page_style = {
    "margin": "2em auto",
    "placeItems": "center",
    "textAlign": "center",
    "maxWidth": "1200px",
}

app.layout = html.Div([page_heading, usage_hint, slider, viewer_row], style=page_style)

# Hook Crystal Toolkit's components into the app using the layout just built.
ctc.register_crystal_toolkit(app=app, layout=app.layout)


@app.callback(
    Output(struct_comp.id(), "data"),
    Output(graph, "figure"),
    Input(slider, "value"),
)
def update_structure(step: int) -> tuple[Structure, go.Figure]:
    """Sync the structure viewer and the energy plot to the slider position.

    Args:
        step: Trajectory index currently selected on the slider.

    Returns:
        The structure at that index, pushed to the StructureMoleculeComponent, and
        a redrawn energy figure whose dashed vertical marker sits at that index.
    """
    frame = struct_traj[step]
    figure = plot_energy(df_traj, step)
    return frame, figure


# Serve the app inline in the notebook output area.
# NOTE(review): this cell's saved output shows a DuplicateIdError raised while Flask
# handled a request to the app. Re-run from a fresh kernel and check that the layout
# has no duplicate component ids before sharing the notebook.
app.run_server(mode="inline", height=800)

Dash is running on http://127.0.0.1:8050/

[1;31m---------------------------------------------------------------------------[0m
[1;31mDuplicateIdError[0m                          Traceback (most recent call last)
File [1;32m~/.venv/py310/lib/python3.10/site-packages/flask/app.py:1821[0m, in [0;36mFlask.full_dispatch_request[1;34m(self=<Flask '__main__'>)[0m
[0;32m   1819[0m [38;5;28;01mtry[39;00m:
[0;32m   1820[0m     request_started[38;5;241m.[39msend([38;5;28mself[39m)
[1;32m-> 1821[0m     rv [38;5;241m=[39m [38;5;28;43mself[39;49m[38;5;241;43m.[39;49m[43mpreprocess_request[49m[43m([49m[43m)[49m
        self [1;34m= <Flask '__main__'>[0m
[0;32m   1822[0m     [38;5;28;01mif[39;00m rv [38;5;129;01mis[39;00m [38;5;28;01mNone[39;00m:
[0;32m   1823[0m         rv [38;5;241m=[39m [38;5;28mself[39m[38;5;241m.[39mdispatch_request()

File [1;32m~/.venv/py310/lib/python3.10/site-packages/flask/app.py:2312[0m, in [0;36mFlask.preprocess_reques