In [1]:
from pathlib import Path
import numpy as np
import torch
from tools.import_utils import load_config
from tools.scripts_utils import generate_g2m_dataset_from_paths, get_model_dataset, init_mace_g2m_model
from tools.tools import get_basis_from_structures_paths, load_model
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import warnings
import sisl
from tools.debug import create_sparse_matrix
from joblib import dump, load

from graph2mat import (
    BasisTableWithEdges,
)

  _Jd, _W3j_flat, _W3j_indices = torch.load(os.path.join(os.path.dirname(__file__), 'constants.pt'))


In [None]:
debug_mode = False
# *********************************** #
# * VARIABLES TO CHANGE BY THE USER * #
# *********************************** #
model_dir = Path("../results/h_crystalls_1") # Results directory
filename = "train_best_model.tar" # Model name (or relative path to the results directory)
compute_calculations = False # Save or Load calculations.
# NOTE(review): `filename` and `compute_calculations` are not read anywhere in
# this cell; the calculations are always loaded from disk below.

# *********************************** #

# Hide some warnings
warnings.filterwarnings("ignore", message="The TorchScript type system doesn't support")
warnings.filterwarnings("ignore", message=".*is not a known matrix type key.*")
if debug_mode:
    print("**************************************************")
    print("*                                                *")
    print("*              DEBUG MODE ACTIVATED              *")
    print("*                                                *")
    print("**************************************************")

# Everything produced by the analysis lives in <model_dir>/results.
savedir = model_dir / "results"
savedir.mkdir(exist_ok=True, parents=True)
calculations_path = savedir / "calculations_alldataset.joblib"

# Define orbital labels (for now we will assume that all atoms have the same orbitals). Use the same order as appearance in the hamiltonian.
orbitals = {
    0: "s1",
    1: "s2",
    2: "py1",
    3: "pz1",
    4: "px1",
    5: "py2",
    6: "pz2",
    7: "px2",
    8: "Pdxy",
    9: "Pdyz",
    10: "Pdz2",
    11: "Pdxz",
    12: "Pdx2-y2",
}
n_orbs = len(orbitals)

# Load the config of the model
config = load_config(model_dir / "config.yaml")
device = torch.device("cpu")

# Load the same dataset used to train/validate the model (paths)
train_paths, val_paths = get_model_dataset(model_dir, verbose=True)

# Load the results.
# SECURITY NOTE: joblib.load unpickles the file, so only load calculation
# files that were produced by a trusted run.
print("Loading the results...")
try:
    data = load(calculations_path)
except FileNotFoundError as err:
    # Re-raise with the offending path in the message, keeping the original
    # exception chained for debugging.
    raise FileNotFoundError(f"Could not find the saved calculations at {calculations_path}") from err
print("Results loaded!")

# (true, predicted) matrices for each split, plus the per-sample labels.
train_true, train_pred = data['train_true'], data['train_pred']
val_true,   val_pred   = data['val_true'],   data['val_pred']
train_data   = (train_true, train_pred)
val_data     = (val_true,   val_pred)
train_labels = data['train_labels']
val_labels   = data['val_labels']

# Per-matrix means, stored as (true, predicted).
train_means = (
    np.array([m.mean() for m in train_true]),
    np.array([m.mean() for m in train_pred]),
)
val_means = (
    np.array([m.mean() for m in val_true]),
    np.array([m.mean() for m in val_pred]),
)

# Per-matrix population standard deviations (ddof=0), computed on the dense
# form returned by .toarray(); stored as (true, predicted).
train_stds = (
    np.array([np.std(m.toarray(), ddof=0) for m in train_true]),
    np.array([np.std(m.toarray(), ddof=0) for m in train_pred]),
)
val_stds = (
    np.array([np.std(m.toarray(), ddof=0) for m in val_true]),
    np.array([np.std(m.toarray(), ddof=0) for m in val_pred]),
)

# Max absolute error between each true matrix and its prediction.
maxae_train = np.array([
    np.max(np.abs(t - p))
    for t, p in zip(train_true, train_pred)
])
maxae_val = np.array([
    np.max(np.abs(t - p))
    for t, p in zip(val_true, val_pred)
])
maxaes = [maxae_train, maxae_val]


def _structure_label(path):
    """Return '<parent dir name minus its first 14 characters>/<last path component>'."""
    # NOTE(review): the 14-character offset presumably strips a fixed
    # directory-name prefix of the dataset layout -- confirm against the data.
    return path.parts[-2][14:] + "/" + path.parts[-1]


# One label per structure, as ([training labels], [validation labels]).
maxaes_labels = (
    [_structure_label(path) for path in train_paths],
    [_structure_label(path) for path in val_paths],
)

# Colour cycle used for the per-sample markers.
colors = [
    '#1f77b4',  # muted blue
    '#ff7f0e',  # safety orange
    '#2ca02c',  # cooked asparagus green
    '#d62728',  # brick red
    '#9467bd',  # muted purple
    '#8c564b',  # chestnut brown
    '#e377c2',  # raspberry yogurt pink
    '#7f7f7f',  # medium gray
    '#bcbd22',  # curry yellow-green
    '#17becf',  # blue-teal
    '#fdae61',  # sandy orange
    '#66c2a5',  # seafoam green
    '#fc8d62',  # coral
    '#a6d854',  # light lime
    '#ffd92f',  # sunflower
    '#e5c494',  # beige
    '#b3b3b3'   # soft gray
]

NameError: name 'a' is not defined

In [None]:
import plotly.graph_objects as go

def _marker_trace(x, y, *, color, symbol=None, line_width=0, **scatter_kwargs):
    """Build a markers-only ``go.Scatter`` with size-5 markers of the given colour."""
    marker = dict(size=5, color=color, line=dict(width=line_width))
    if symbol is not None:
        # Without an explicit symbol Plotly uses its default marker.
        marker["symbol"] = symbol
    return go.Scatter(x=x, y=y, mode='markers', marker=marker, **scatter_kwargs)


def _padded_range(lo, hi, fraction):
    """Return ``[lo, hi]`` widened outwards by ``fraction`` of each bound's magnitude."""
    lo, hi = float(lo), float(hi)
    # abs() keeps the padding pointing outwards for negative bounds as well;
    # `lo - fraction * lo` would move a negative lower bound inwards and clip data.
    return [lo - fraction * abs(lo), hi + fraction * abs(hi)]


def plot_dataset_results(
        train_data, val_data,
        colors, title,
        train_labels, val_labels,
        train_means, val_means,
        train_stds, val_stds,
        maxaes, maxaes_labels,
        filepath
):
    """Build an interactive true-vs-predicted figure with a view selector.

    The figure has a dropdown with four views: the individual matrix
    elements, the per-sample means, the per-sample standard deviations and
    the per-structure max absolute error.

    Parameters
    ----------
    train_data, val_data :
        Pairs ``(true, predicted)`` of equally long sequences whose items
        expose a ``.data`` array of values (presumably scipy sparse matrices
        -- confirm against the caller).
    colors :
        Sequence of colours, cycled by sample index.
    title :
        Figure title.
    train_labels, val_labels :
        Per-sample hover text for the matrix-element markers.
    train_means, val_means, train_stds, val_stds :
        Pairs ``(true, predicted)`` of numpy arrays with one value per sample.
    maxaes :
        ``[training, validation]`` max absolute errors, one value per sample.
    maxaes_labels :
        ``([training labels], [validation labels])``; used as hover text of
        the mean/std markers and as the categories of the error view.
    filepath :
        If truthy, the figure is written as an HTML fragment to ``filepath``
        and as JSON to ``str(filepath)[:-4] + ".json"``; otherwise the figure
        is shown.

    Returns
    -------
    The Plotly figure.
    """
    n_train_samples = len(train_data[0])
    n_val_samples = len(val_data[0])

    # Keep only the stored values (`.data`) of every matrix.
    train_data = [[series[i].data for i in range(n_train_samples)] for series in train_data]
    val_data = [[series[i].data for i in range(n_val_samples)] for series in val_data]

    # Per-split inputs and marker styling; training first so that the trace
    # order (training, then validation) matches the visibility masks below.
    splits = (
        dict(
            n=n_train_samples, values=train_data, labels=train_labels,
            means=train_means, stds=train_stds, stat_labels=maxaes_labels[0],
            matrix=dict(name='Training sample {}', legendgroup='training', symbol=None, line_width=0),
            mean=dict(name='Mean training {}', legendgroup='training_mean', symbol='square', line_width=0),
            std=dict(name='Std training {}', legendgroup='training_std', symbol='triangle-up', line_width=0),
        ),
        dict(
            n=n_val_samples, values=val_data, labels=val_labels,
            means=val_means, stds=val_stds, stat_labels=maxaes_labels[1],
            matrix=dict(name='Validation sample {}', legendgroup='validation', symbol='circle-open', line_width=1),
            mean=dict(name='Mean val {}', legendgroup='val_mean', symbol='square-open', line_width=1),
            std=dict(name='Std validation {}', legendgroup='val_std', symbol='triangle-up-open', line_width=0),
        ),
    )

    matrix_traces = []
    mean_traces = []
    std_traces = []
    for split in splits:
        for i in range(split["n"]):
            color = colors[i % len(colors)]

            # Matrix elements: predicted vs true values of sample i.
            style = split["matrix"]
            matrix_traces.append(_marker_trace(
                split["values"][0][i], split["values"][1][i],
                color=color, symbol=style["symbol"], line_width=style["line_width"],
                name=style["name"].format(i),
                text=split["labels"][i],
                legendgroup=style["legendgroup"],
                showlegend=True,
            ))

            # Single-point summaries (mean, std); hidden until selected in the dropdown.
            for key, stats, bucket in (
                ("mean", split["means"], mean_traces),
                ("std", split["stds"], std_traces),
            ):
                style = split[key]
                bucket.append(_marker_trace(
                    [stats[0][i]], [stats[1][i]],
                    color=color, symbol=style["symbol"], line_width=style["line_width"],
                    name=style["name"].format(i),
                    text=split["stat_labels"][i],
                    legendgroup=style["legendgroup"],
                    visible=False,
                    showlegend=True,
                ))

    # Identity line spanning the range of the true training values.
    # Both bounds iterate over the *training* samples (the upper bound used to
    # iterate over the number of validation samples).
    diag_lo = np.min([np.min(train_data[0][i]) for i in range(n_train_samples)])
    diag_hi = np.max([np.max(train_data[0][i]) for i in range(n_train_samples)])
    diagonal_trace = go.Scatter(
        x=[diag_lo, diag_hi],
        y=[diag_lo, diag_hi],
        mode='lines',
        line=dict(color='black', dash='solid'),
        name='Ideal'
    )

    # Last dropdown: max absolute error per structure (training / validation).
    error_trace_train = go.Scatter(
        x=maxaes[0],
        y=maxaes_labels[0],
        mode='markers',
        marker=dict(
            symbol='x',
            size=6,
            color="blue"
        ),
        name='Training',
        showlegend=False  # optional: hide legend for this simple plot
    )
    error_trace_val = go.Scatter(
        x=maxaes[1],
        y=maxaes_labels[1],
        mode='markers',
        marker=dict(
            symbol='x',
            size=6,
            color="red"
        ),
        name='Validation',
        showlegend=False  # optional: hide legend for this simple plot
    )

    # Create figure and update layout
    traces = matrix_traces + mean_traces + std_traces + [error_trace_train] + [error_trace_val] + [diagonal_trace]
    fig = go.Figure(data=traces)
    fig.update_layout(
        width=1000,
        height=1000,
        title=title,
        legend_title='Legend',
        xaxis=dict(
            title='True Values',
            tickformat=".2f"
        ),
        yaxis=dict(
            title='Predicted Values',
            tickformat=".2f"
        )
    )

    # Axis ranges of each view, taken over training and validation together.
    data_true = np.concatenate(train_data[0] + val_data[0], axis=0)
    data_pred = np.concatenate(train_data[1] + val_data[1], axis=0)
    x_data_range = _padded_range(data_true.min(), data_true.max(), 0.05)
    y_data_range = _padded_range(data_pred.min(), data_pred.max(), 0.05)

    def joint_range(train_values, val_values):
        lo = np.min([train_values.min(), val_values.min()])
        hi = np.max([train_values.max(), val_values.max()])
        return _padded_range(lo, hi, 0.0005)

    x_mean_range = joint_range(train_means[0], val_means[0])
    y_mean_range = joint_range(train_means[1], val_means[1])
    x_std_range = joint_range(train_stds[0], val_stds[0])
    y_std_range = joint_range(train_stds[1], val_stds[1])

    def visibility(matrix, mean, std, errors, diagonal):
        # One flag per trace, in the same order as `traces`.
        return (
            [matrix] * len(matrix_traces)
            + [mean] * len(mean_traces)
            + [std] * len(std_traces)
            + [errors] * 2
            + [diagonal]
        )

    # Add dropdown
    fig.update_layout(
        updatemenus=[
            dict(
                buttons=[
                    dict(
                        label="SISL Hamiltonian elements",
                        method="update",
                        args=[
                            {"visible": visibility(True, False, False, False, True)},
                            {
                                "xaxis": {"range": x_data_range},
                                "yaxis": {"range": y_data_range},
                            }
                        ]
                    ),
                    dict(
                        label="Mean",
                        method="update",
                        args=[
                            {"visible": visibility(False, True, False, False, True)},
                            {
                                "xaxis": {"range": x_mean_range},
                                "yaxis": {"range": y_mean_range},
                            }
                        ]
                    ),
                    dict(
                        label="Std",
                        method="update",
                        args=[
                            {"visible": visibility(False, False, True, False, True)},
                            {
                                "xaxis": {"range": x_std_range},
                                "yaxis": {"range": y_std_range},
                            }
                        ]
                    ),
                    dict(
                        label="Max Absolute Error",
                        method="update",
                        args=[
                            {"visible": visibility(False, False, False, True, False)},
                            {
                                "xaxis": {"title": "Max Absolute Error"},
                                "yaxis": {
                                    "title": "Structures",
                                    "type": "category",
                                    # Flat list of categories: training labels, then validation labels.
                                    "categoryarray": list(maxaes_labels[0]) + list(maxaes_labels[1]),
                                    "autorange": "reversed"
                                },
                                # NOTE(review): a per-trace list is passed here inside the
                                # *layout* update; confirm whether it was meant to go in the
                                # trace update (first dict) instead.
                                "showlegend": [False] * len(traces),
                            }
                        ]
                    )
                ],
                direction="down",
                pad={"r": 10, "t": 10},
                showactive=True,
                x=0.3,
                xanchor="left",
                y=1.1,
                yanchor="top"
            ),
        ]
    )

    # Save to HTML (and JSON) if a path is provided, otherwise display the figure.
    if filepath:
        with open(filepath, "w") as f:
            f.write(fig.to_html(full_html=False, include_plotlyjs='cdn'))

        # NOTE: the slice drops the last four characters, so for a path ending
        # in ".html" the JSON name ends in "..json" (e.g. "plot.html" ->
        # "plot..json"). The name is kept unchanged because existing readers of
        # the export may rely on it.
        with open(f"{str(filepath)[:-4]}.json", "w") as f:
            f.write(fig.to_json())
    else:
        fig.show()

    return fig

In [None]:
# Destination of the HTML export; the plotting function also writes a JSON
# export derived from this path.
filepath = savedir / "dataset_analysis.html"
title = f"Dataset analysis. Used model {model_dir.parts[-1]}"
print("Generating results...")
plot_dataset_results(
    train_data=train_data, val_data=val_data,
    colors=colors, title=title,
    train_labels=train_labels, val_labels=val_labels,
    train_means=train_means, val_means=val_means,
    train_stds=train_stds, val_stds=val_stds,
    maxaes=maxaes, maxaes_labels=maxaes_labels,
    # Pass the real path: with filepath=None the figure is only shown and
    # nothing is written, which contradicts the message printed below.
    filepath=filepath
)
print(f"Results saved at {filepath}!")

# Loading the figure from the saved JSON file

In [None]:
import plotly.io as pio

# Rebuild the figure from its JSON export in the results directory and display it.
json_export_path = savedir / "dataset_analysis..json"
fig = pio.read_json(json_export_path)
fig.show()