In [1]:
import base64
import os
from pathlib import Path

from ase.eos import EquationOfState
from ase.io import read
from dash import Dash, dcc, html, Input, Output, callback, ctx, dash_table
import dash_bootstrap_components as dbc
from IPython.display import update_display, display, DisplayHandle
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from weas_widget import WeasWidget

from mlip_testing import analysis
from mlip_testing.analysis import mof_eos
from mlip_testing.analysis.utils import get_config

In [2]:
# Location of the analysis config. Set the MLIP_TESTING_CONFIG environment
# variable to point elsewhere instead of editing this machine-specific path.
CONFIG_PATH = os.environ.get(
    "MLIP_TESTING_CONFIG",
    "/Users/elliottkasoar/Documents/PSDI/mlip-testing/mlip_testing/analysis/config.yml",
)
config = get_config(CONFIG_PATH)

test = "equation_of_state"
name = config[test]["name"]

# Resolve, by name from the config, the analysis module and the MLIPs to compare.
module = getattr(analysis, config[test]["module"])
mlips = config[test]["mlips"]

# The config also names the function in `module` that loads the test structures.
get_structs = getattr(module, config[test]["structs"])
structs = get_structs()

In [3]:
metrics = config[test]["score"]["metrics"]

# For each configured metric: look up and run its input calculation, then
# feed those inputs to the metric calculation itself.
input_results = {}
metric_results = {}
for metric, metric_config in metrics.items():
    inputs_name = metric_config["inputs"]
    get_inputs = getattr(module, metric_config[inputs_name]["calc"])
    input_results[metric] = get_inputs(mlips)

    get_metric = getattr(module, metric_config["calc"])
    metric_results[metric] = get_metric(mlips, input_results[metric])

# Collapse all metric results into a single score per MLIP.
get_score = getattr(module, config[test]["score"]["calc"])
score = get_score(mlips, metric_results)



In [4]:
# Rows for the top-level score table: the MLIP name and its overall score.
score_columns = ("MLIP", f"{name} Score")
score_data = [dict(zip(score_columns, (mlip, score[mlip]))) for mlip in mlips]

# Rows for the metrics table: the MLIP name followed by one column per metric.
metrics_columns = ("MLIP", *metric_results)

metrics_data = []
for mlip in mlips:
    row = {"MLIP": mlip}
    row.update((key, value[mlip]) for key, value in metric_results.items())
    metrics_data.append(row)

In [5]:
def build_parity(mlips, metric, results):
    """Build a parity plot of MLIP values against reference values for one metric.

    Parameters
    ----------
    mlips : iterable of str
        MLIP names; each must be a key of ``results[metric]``.
    metric : str
        Key into ``results`` selecting the quantity to plot; also used in the
        axis titles.
    results : dict
        ``results[metric]`` maps each MLIP name, plus ``"ref"``, to a sequence
        of values. The sequences are assumed (not checked here) to be aligned
        element-wise; ``None`` entries are ignored when sizing the y = x guide.

    Returns
    -------
    go.Figure
        One marker trace per MLIP, in ``mlips`` order (so a point's
        ``curveNumber`` is the MLIP's index), followed by a dashed y = x
        guide line when there is at least one non-None value.
    """
    fig = go.Figure()
    ref_values = list(results[metric]["ref"])
    # Every plotted value (reference and all MLIPs) so the guide spans them all.
    all_values = list(ref_values)

    for mlip in mlips:
        mlip_values = list(results[metric][mlip])
        all_values.extend(mlip_values)
        fig.add_trace(go.Scatter(
            x=mlip_values,
            y=ref_values,
            name=mlip,
            mode="markers",
        ))

    # Dashed y = x guide from the smallest to the largest plotted value.
    finite_values = [value for value in all_values if value is not None]
    if finite_values:
        lower, upper = min(finite_values), max(finite_values)
        fig.add_trace(go.Scatter(
            x=[lower, upper],
            y=[lower, upper],
            mode="lines",
            line=dict(dash="dash", color="grey"),
            showlegend=False,
            # "skip" suppresses hover *and* click events, so clicking the guide
            # is never mistaken for selecting a structure.
            hoverinfo="skip",
        ))

    fig.update_layout(
        xaxis=dict(title=dict(text=f"MLIP {metric}")),
        yaxis=dict(title=dict(text=f"Reference {metric}")),
    )

    return fig

In [6]:
def view_struct(struct_file):
    """Show every frame of ``struct_file`` in a WEAS viewer in the notebook output.

    The file is read with ASE (``index=":"``, i.e. all frames). The output
    registered under display id ``"weas"`` is first updated to ``None``, then
    the new viewer is displayed under that same id.
    """
    viewer = WeasWidget()
    all_frames = read(struct_file, index=":")
    viewer.from_ase(all_frames)

    previous_output = DisplayHandle("weas")
    previous_output.update(None)
    display(viewer, display_id="weas")

In [7]:
def build_eos_plots(
    mlips,
    struct,
    results_dir=Path(
        "/Users/elliottkasoar/Documents/PSDI/mlip-testing/mof_eos_results"
    ),
):
    """Plot the raw energy-volume points and the fitted EOS curve for each MLIP.

    For every MLIP, ``{struct.stem}-{mlip}-eos-raw.dat`` is read from
    ``results_dir``. The first line is skipped as a header. Each remaining
    non-blank line must hold exactly three numbers: the second is used as the
    energy and the third as the volume. The points are fitted with ASE's
    ``EquationOfState``.

    Parameters
    ----------
    mlips : iterable of str
        MLIP names, used in the file names and as trace names.
    struct : path-like object
        Presumably a ``pathlib.Path``; only its ``stem`` is used here.
    results_dir : str or pathlib.Path, optional
        Directory holding the ``*-eos-raw.dat`` files. Defaults to the
        previously hard-coded location, so existing calls behave the same.

    Returns
    -------
    go.Figure
        Two traces per MLIP, in ``mlips`` order: the raw points (markers)
        followed by the fitted curve (lines). Code that maps a trace's
        ``curveNumber`` back to an MLIP relies on this order
        (MLIP index == ``curveNumber // 2``).
    """
    fig = go.Figure()
    results_dir = Path(results_dir)

    for mlip in mlips:
        raw_file = results_dir / f"{struct.stem}-{mlip}-eos-raw.dat"

        energies = []
        volumes = []
        with open(raw_file, encoding="utf8") as f:
            next(f, None)  # skip the header line
            for line in f:
                if not line.strip():
                    continue  # tolerate blank (e.g. trailing) lines
                _, energy, volume = (float(x) for x in line.split())
                energies.append(energy)
                volumes.append(volume)

        eos = EquationOfState(volumes=volumes, energies=energies)
        # getplotdata() -> (eos_string, e0, v0, B, fit_x, fit_y, raw_v, raw_e)
        _, _, _, _, fit_volumes, fit_energies, raw_volumes, raw_energies = (
            eos.getplotdata()
        )

        fig.add_trace(go.Scatter(
            x=raw_volumes,
            y=raw_energies,
            name=mlip,
            mode="markers",
            legendgroup=mlip,
        ))
        # The fit shares the markers' legend entry, so each MLIP is listed once
        # and toggling it hides both traces.
        fig.add_trace(go.Scatter(
            x=fit_volumes,
            y=fit_energies,
            name=mlip,
            mode="lines",
            legendgroup=mlip,
            showlegend=False,
        ))

    fig.update_layout(
        xaxis=dict(title=dict(text="Volume")),
        yaxis=dict(title=dict(text="Energy")),
    )

    return fig

In [8]:
def build_eos_structs(
    mlip,
    struct,
    results_dir=Path(
        "/Users/elliottkasoar/Documents/PSDI/mlip-testing/mof_eos_results"
    ),
):
    """Display the structures generated for ``struct`` by ``mlip`` in the EOS run.

    Parameters
    ----------
    mlip : str
        MLIP name, used in the file name.
    struct : path-like object
        Presumably a ``pathlib.Path``; only its ``stem`` is used here.
    results_dir : str or pathlib.Path, optional
        Directory holding the ``*-generated.extxyz`` files. Defaults to the
        previously hard-coded location, so existing calls behave the same.
    """
    struct_file = Path(results_dir) / f"{struct.stem}-{mlip}-generated.extxyz"
    view_struct(struct_file)

In [16]:
external_stylesheets = ["https://codepen.io/chriddyp/pen/bWLwgP.css"]

# Callbacks target components (e.g. the metrics table) that only exist after
# a click, so Dash must not reject them at start-up.
app = Dash(
    __name__,
    external_stylesheets=external_stylesheets,
    suppress_callback_exceptions=True,
)

styles = {"pre": {"border": "thin lightgrey solid", "overflowX": "scroll"}}

# Shared styles: a full-width centred flex row, and a full-width inline block.
_FLEX_ROW_STYLE = {"display": "flex", "width": "99%", "alignItems": "center"}
_BLOCK_STYLE = {"display": "inline-block", "width": "99%"}

score_table = dbc.Container([
    dbc.Label(f'{name} scores'),
    dash_table.DataTable(
        score_data, [{"name": col, "id": col} for col in score_columns], id='score'
    ),
])

# Empty containers that the callbacks fill in as the user drills down.
detail_placeholders = [
    html.Div(id=placeholder_id, style=dict(_BLOCK_STYLE))
    for placeholder_id in (
        "metrics-scatter-placeholder",
        "metrics-plot-placeholder",
        "metrics-struct-placeholder",
    )
]

app.layout = html.Div([
    html.H1(f"{name}", style={"color": "black"}),
    html.Div(score_table, style=dict(_FLEX_ROW_STYLE)),
    html.Div(id='metrics-placeholder', style=dict(_FLEX_ROW_STYLE)),
    *detail_placeholders,
])

@callback(
    Output('metrics-placeholder', 'children'),
    Input('score', 'active_cell'),
)
def update_metrics(active_cell):
    """Show the per-metric table when a cell in the score column is selected.

    ``active_cell`` is the score DataTable's ``active_cell`` value (``None``
    until a cell is clicked). Any selection outside the score column, such as
    the ``"MLIP"`` column, shows the instruction text instead of leaving the
    area empty.
    """
    if active_cell is None or active_cell.get("column_id") != f"{name} Score":
        return html.Div(
            "Click on a column to view more details.",
            style={"display": "flex", "width": "99%", "alignItems": "center",}
        )
    return dbc.Container([
        dbc.Label(f"{name} metrics"),
        dash_table.DataTable(metrics_data, [{"name": i, "id": i} for i in metrics_columns], id="metrics"),
    ])

@callback(
    Output('metrics-scatter-placeholder', 'children'),
    Input('metrics', 'active_cell'),
)
def update_scatter(active_cell):
    """Show the parity plot for the metric column selected in the metrics table.

    Only columns that are keys of ``input_results`` have parity data. Any other
    selection, notably the ``"MLIP"`` column (previously a KeyError inside
    ``build_parity``), shows the instruction text instead.
    """
    hint = html.Div(
        "Click on a metric to view more details.",
        style={"display": "flex", "width": "99%", "alignItems": "center",}
    )
    if active_cell is None:
        return hint

    column_id = active_cell.get("column_id")
    if column_id not in input_results:
        return hint

    return dcc.Graph(
        id="metrics-scatter",
        figure=build_parity(mlips, column_id, input_results)
    )

@callback(
    Output("metrics-plot-placeholder", "children"),
    Input("metrics-scatter", "clickData"),
)
def update_plot_from_scatter(clickData):
    """Show the EOS plots for the structure clicked in the parity scatter.

    The clicked point's ``pointNumber`` is used as an index into ``structs``
    (assumed to be in the same order as the per-MLIP result lists). Only the
    first ``len(mlips)`` traces are per-MLIP data; clicks on any later trace
    (e.g. a y = x guide) are ignored rather than misread as a structure index.
    """
    hint = html.Div(
        "Click on a point to view plot.",
        style={"display": "flex", "width": "99%", "alignItems": "center",}
    )
    if not clickData or not clickData.get("points"):
        return hint

    point = clickData["points"][0]
    if point["curveNumber"] >= len(mlips):
        return hint

    idx = point["pointNumber"]
    return dcc.Graph(
        id="metrics-plot", figure=build_eos_plots(mlips, structs[idx])
    )


@callback(
    Output("metrics-struct-placeholder", "children"),
    Input("metrics-scatter", "clickData"),
    Input("metrics-plot", "clickData"),
)
def update_struct(clickDataScatter, clickDataPlot):
    """Display a structure in the notebook in response to a click.

    The structure always comes from the parity-scatter click (its
    ``pointNumber`` indexes ``structs``). The callback then branches on which
    input fired it:

    * EOS-plot click: show the structures generated for that structure by the
      clicked MLIP. The EOS figure has two traces per MLIP (markers, then
      fit), so the MLIP index is ``curveNumber // 2``.
    * otherwise (parity-scatter click): show the input structure itself.

    Returns the instruction text when no usable scatter point is selected.
    Otherwise it returns ``None``, clearing the placeholder, and the structure
    appears via the notebook display side effect.
    """
    hint = html.Div(
        "Click on a point to view structure.",
        style={"display": "flex", "width": "99%", "alignItems": "center",}
    )
    if not clickDataScatter or not clickDataScatter.get("points"):
        return hint

    scatter_point = clickDataScatter["points"][0]
    # Only the first len(mlips) parity traces map onto structures.
    if scatter_point["curveNumber"] >= len(mlips):
        return hint
    struct = structs[scatter_point["pointNumber"]]

    if ctx.triggered_id == "metrics-plot" and clickDataPlot:
        plot_point = clickDataPlot["points"][0]
        mlip = mlips[plot_point["curveNumber"] // 2]
        build_eos_structs(mlip, struct)
    else:
        view_struct(struct)

    return None

if __name__ == "__main__":
    # Serve the dashboard on port 1234 with Dash debug mode disabled.
    app.run(debug=False, port=1234)