In [None]:
import sys
sys.path.insert(0, '/home/alapena/GitHub/graph2mat4abn')
import os
os.chdir('/home/alapena/GitHub/graph2mat4abn') # Change to the root directory of the project

from graph2mat4abn.tools.import_utils import load_config, get_object_from_module
from graph2mat4abn.tools.tools import get_basis_from_structures_paths, get_kwargs, load_model
from graph2mat4abn.tools.scripts_utils import get_model_dataset, generate_g2m_dataset_from_paths, init_mace_g2m_model
from pathlib import Path
from e3nn import o3
from mace.modules import MACE, RealAgnosticResidualInteractionBlock
from graph2mat.models import MatrixMACE
from graph2mat.bindings.e3nn import E3nnGraph2Mat
import torch
import warnings
from graph2mat import BasisTableWithEdges
import sisl
from torch_geometric.loader import DataLoader
import numpy as np
from plotly.subplots import make_subplots

# Silence two known, noisy warnings (matched by message) so the cell outputs stay readable.
warnings.filterwarnings("ignore", message="The TorchScript type system doesn't support")
warnings.filterwarnings("ignore", message=".*is not a known matrix type key.*")

# *********************************** #
# * VARIABLES TO CHANGE BY THE USER * #
# *********************************** #
# Number of k-points requested for each segment of the band-structure path
# (passed to tb.gen_kpath once per segment).
n_k_bands = 80
# Size of the uniform k-mesh along each of the three directions used for the DOS
# (passed to tb.gen_kmesh as (n_k_dos, n_k_dos, n_k_dos)).
n_k_dos = 50
# Directory of the structure to analyse; "aiida.fdf" and "aiida.HSX" are read from it.
path = Path("data_hBN")
# Directory of the trained model; "config.yaml" and the checkpoint `filename` are read from it.
model_dir = Path("results/h_crystalls_5")
# Checkpoint file name inside `model_dir`.
filename = "val_best_model.tar"
# NOTE(review): `savedir` is only copied into `savedir_struct` further down and nothing
# visible here writes to it. It also points to a different run ("h_crystalls_8") than
# `model_dir` ("h_crystalls_5") - confirm this is intended.
savedir = Path("results/h_crystalls_8/results/hBN")
# *********************************** #

In [None]:
# === Model init ===

# Load the configuration stored next to the checkpoint; everything runs on CPU.
config = load_config(model_dir / "config.yaml")
device = torch.device("cpu")

# Generate the G2M basis
# Load the same dataset used to train/validate the model (paths)
train_paths, val_paths = get_model_dataset(model_dir, verbose=True) # Load model's dataset but just to compute the basis.
paths = train_paths + val_paths
basis = get_basis_from_structures_paths(paths, verbose=True, num_unique_z=config["dataset"].get("num_unique_z", None))
table = BasisTableWithEdges(basis)

# Convert data path to dataset (a single structure: the one in `path`)
dataset, processor = generate_g2m_dataset_from_paths(config, basis, table, [path], device=device, verbose=True)

# Init model
model, optimizer, lr_scheduler, loss_fn = init_mace_g2m_model(config, table)

# Load model
model_path = model_dir / filename
# `initial_lr` is optional in the config: only convert it when present, since
# float(None) would raise a TypeError.
# NOTE(review): assumes load_model accepts initial_lr=None - confirm.
initial_lr = config["optimizer"].get("initial_lr", None)
if initial_lr is not None:
    initial_lr = float(initial_lr)
model, checkpoint, optimizer, lr_scheduler = load_model(model, optimizer, model_path, lr_scheduler=lr_scheduler, initial_lr=initial_lr, device=device)
# Single quotes inside the replacement fields keep this f-string valid on Python < 3.12.
print(f"Loaded model in epoch {checkpoint['epoch']} with training loss {checkpoint['train_loss']} and validation loss {checkpoint['val_loss']}.")

Loaded 465 training paths and 117 validation paths.
Basis computation.
Number of structures to look on: 582
Looking for unique atoms in each structure...


1it [00:00, 59.82it/s]


Found enough basis points. Breaking the search...
Found enough basis points. Breaking the search...
Found the following atomic numbers: [5, 7]
Corresponding path indices: [0, 0]
Basis with 2 elements built!

Basis for atom 0.
	Atom type: 5
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [3.02420918 2.02341372 3.73961942 3.73961942 3.73961942 2.51253945
 2.51253945 2.51253945 3.73961942 3.73961942 3.73961942 3.73961942
 3.73961942]

Basis for atom 1.
	Atom type: 7
	Basis: ((2, 0, 1), (2, 1, -1), (1, 2, 1))
	Basis convention: siesta_spherical
	R: [2.25704422 1.4271749  2.78012609 2.78012609 2.78012609 1.75309697
 1.75309697 1.75309697 2.78012609 2.78012609 2.78012609 2.78012609
 2.78012609]
Generating dataset...
Generating split 0...


1it [00:00, 46.04it/s]

Keeping all the dataset in memory.




To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



Using Optimizer Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0001
    maximize: False
    weight_decay: 0
)
LR Scheduler:  ReduceLROnPlateau
Arguments:  None
Keyword arguments:  {'cooldown': 0, 'eps': 0.0, 'factor': 0.9, 'min_lr': 1e-09, 'mode': 'min', 'patience': 100}
Using Loss function <class 'graph2mat.core.data.metrics.block_type_mse'>
Loaded model in epoch 110870 with training loss 350.75750732421875 and validation loss 39898.26953125.


In [None]:
# Prediction

# Inference must run in evaluation mode so that any train-only layer behaviour is disabled.
# Calling eval() is harmless if the model is already in that mode.
model.eval()

# Wrap the single structure in a DataLoader (batch size 1) to obtain a collated batch.
loader = DataLoader([dataset[0]], 1)
data = next(iter(loader))

model_predictions = model(data=data)

# Reconstruct matrices: the predicted one from the model's node/edge labels,
# the reference one from the labels already stored in the batch.
h_pred = processor.matrix_from_data(data, predictions={"node_labels": model_predictions["node_labels"], "edge_labels": model_predictions["edge_labels"]})[0]
h_true = processor.matrix_from_data(data)[0]

In [None]:
# 1. Plot structure
# NOTE(review): `savedir_struct` is not used by any code visible in this notebook.
savedir_struct = savedir
# Open the structure's fdf input with sisl and show its geometry along the x, y and z axes.
# `file` is rebound later to the "aiida.HSX" sile.
file = sisl.get_sile(path / "aiida.fdf")
fig = file.plot.geometry(axes="xyz")
fig.show()

INFO	Task(Task-3) nodify.node.140451552435216:node.py:get()- Evaluated because inputs changed.
INFO	Task(Task-3) nodify.node.140451552429840:node.py:get()- Evaluated because inputs changed.
INFO	Task(Task-3) nodify.node.140451552685728:node.py:get()- Evaluated because inputs changed.
INFO	Task(Task-3) nodify.node.140451552442080:node.py:get()- Evaluated because inputs changed.
INFO	Task(Task-3) nodify.node.140451541118240:node.py:get()- Evaluated because inputs changed.
INFO	Task(Task-3) nodify.node.140451552682944:node.py:get()- Evaluated because inputs changed.
INFO	Task(Task-3) nodify.node.140451552429840:node.py:get()- No need to evaluate
INFO	Task(Task-3) nodify.node.140451552685728:node.py:get()- No need to evaluate
INFO	Task(Task-3) nodify.node.140451552442080:node.py:get()- No need to evaluate
INFO	Task(Task-3) nodify.node.140451541118240:node.py:get()- No need to evaluate
INFO	Task(Task-3) nodify.node.140451552682944:node.py:get()- No need to evaluate
INFO	Task(Task-3) nodify.

INFO	Task(Task-3) nodify.node.140451552678384:node.py:get()- Evaluated because inputs changed.
INFO	Task(Task-3) nodify.node.140451552443040:node.py:get()- Evaluated because inputs changed.
INFO	Task(Task-3) nodify.node.140451552435936:node.py:get()- No need to evaluate
INFO	Task(Task-3) nodify.node.140457023470864:node.py:get()- No need to evaluate
INFO	Task(Task-3) nodify.node.140451552673968:node.py:get()- No need to evaluate
INFO	Task(Task-3) nodify.node.140451552441792:node.py:get()- Evaluated because inputs changed.
INFO	Task(Task-3) nodify.node.140451552678432:node.py:get()- Evaluated because inputs changed.
INFO	Task(Task-3) nodify.node.140451552678096:node.py:get()- Evaluated because inputs changed.
INFO	Task(Task-3) nodify.node.140451552435936:node.py:get()- No need to evaluate
INFO	Task(Task-3) nodify.node.140457023470864:node.py:get()- No need to evaluate
INFO	Task(Task-3) nodify.node.140451552673968:node.py:get()- No need to evaluate
INFO	Task(Task-3) nodify.node.140451552

In [None]:
from graph2mat4abn.tools.script_plots import plot_hamiltonian


# Compare the reference and predicted Hamiltonians as dense matrices.
# No file path is given, so presumably the figure is only displayed - confirm in plot_hamiltonian.
plot_hamiltonian(
    h_true.todense(), h_pred.todense(),
    matrix_label="Hamiltonian",
    figure_title="Standard hBN",
    predicted_matrix_text=None,
    filepath=None
)
# Trailing None so the cell has no value of its own to display.
None

In [None]:
# Feed tbplas
# Builds three TBPLaS primitive cells on the same lattice: one from the reference
# Hamiltonian, one from the predicted Hamiltonian, and one holding the reference overlap.

import tbplas as tb
from tools.tbplas_tools import add_hopping_terms, add_orbitals, get_hoppings, get_onsites


# Geometry (and, further down, the overlap) are read from the HSX file of the structure.
file = sisl.get_sile(path / "aiida.HSX")
geometry = file.read_geometry()

# Feed TBPLaS with our data

# Empty cell
# NOTE(review): hard-coded for the 2-atom hBN cell; presumably this should follow the
# geometry (e.g. geometry.na) - confirm against what get_hoppings expects.
n_atoms=2
vectors = geometry.cell
cell_true = tb.PrimitiveCell(vectors, unit=tb.ANG)
cell_pred = tb.PrimitiveCell(vectors, unit=tb.ANG)

# Add orbitals
positions = geometry.xyz
# One list of orbital names per atom of the geometry.
labels = [[orb.name() for orb in atom] for atom in geometry.atoms]
onsites_true = get_onsites(h_true.tocsr().tocoo())
onsites_pred = get_onsites(h_pred.tocsr().tocoo())

add_orbitals(cell_true, positions, onsites_true, labels)
add_orbitals(cell_pred, positions, onsites_pred, labels)

# Add hoppings
iscs_true, orbs_in_true, orbs_out_true, hoppings_true = get_hoppings(h_true.tocsr().tocoo(), n_atoms, geometry)
iscs_pred, orbs_in_pred, orbs_out_pred, hoppings_pred = get_hoppings(h_pred.tocsr().tocoo(), n_atoms, geometry)

add_hopping_terms(cell_true, iscs_true, orbs_in_true, orbs_out_true, hoppings_true)
add_hopping_terms(cell_pred, iscs_pred, orbs_in_pred, orbs_out_pred, hoppings_pred)

# Create overlap matrix
# The overlap is stored in a PrimitiveCell of its own that reuses the lattice vectors and
# origin of `cell_true`. The third argument (1.0) is the unit factor - presumably nm, the
# unit `lat_vec` is stored in; confirm against the TBPLaS API.
overlap_true = tb.PrimitiveCell(cell_true.lat_vec, cell_true.origin, 1.0)

o_true = file.read_overlap()
onsites_overlap_true = get_onsites(o_true.tocsr().tocoo())
# Mirror every orbital of `cell_true`, using the matching overlap on-site value
# in place of an on-site energy.
for k in range(cell_true.num_orb):
    orbital = cell_true.orbitals[k]
    overlap_true.add_orbital(orbital.position, onsites_overlap_true[k])

iscs_overlap_true, orbs_in_overlap_true, orbs_out_overlap_true, hoppings_overlap_true = get_hoppings(o_true.tocsr().tocoo(), n_atoms, geometry)
add_hopping_terms(overlap_true, iscs_overlap_true, orbs_in_overlap_true, orbs_out_overlap_true, hoppings_overlap_true)

# Plot Brillouin Zone

import numpy as np
import plotly.graph_objects as go
from itertools import product
from pathlib import Path

def plot_brillouin_zone(a1, a2, a3, filepath=None, points=None, labels=None):
    """
    Draw the parallelepiped spanned by three vectors, optionally with a path of points.

    Despite the name, the shape drawn is the cell spanned by ``a1``, ``a2`` and
    ``a3`` with one corner at the origin; no Wigner-Seitz construction is done.

    Parameters:
    - a1, a2, a3: array-like of length 3, the spanning vectors
    - filepath: optional str or Path ending in ".html" or ".png"; when falsy the
      figure is shown instead of written
    - points: optional array-like of shape (n, 3), drawn as one connected line
    - labels: optional list of text labels, one per entry of `points`

    Returns:
    - Plotly figure

    Raises:
    - ValueError: if `filepath` has an extension other than ".html" or ".png"
    """
    a1, a2, a3 = (np.array(vec) for vec in (a1, a2, a3))
    origin = np.zeros(3)

    # Corner number 4*i + 2*j + k sits at i*a1 + j*a2 + k*a3 (i, j, k in {0, 1}).
    corners = np.array([
        origin + i*a1 + j*a2 + k*a3
        for i, j, k in product([0, 1], repeat=3)
    ])

    # Corner indices of the six faces, grouped by the coefficient that is held fixed.
    faces = [
        [0, 1, 3, 2],  # a1 coefficient = 0
        [4, 5, 7, 6],  # a1 coefficient = 1
        [0, 1, 5, 4],  # a2 coefficient = 0
        [2, 3, 7, 6],  # a2 coefficient = 1
        [0, 2, 6, 4],  # a3 coefficient = 0
        [1, 3, 7, 5],  # a3 coefficient = 1
    ]

    fig = go.Figure()

    # One translucent mesh per face.
    for face in faces:
        quad = corners[face]
        fig.add_trace(go.Mesh3d(
            x=quad[:, 0], y=quad[:, 1], z=quad[:, 2],
            color='lightblue',
            opacity=0.5,
            alphahull=0,
            showscale=False
        ))

    # The three spanning vectors, drawn from the origin and labelled at their tail.
    for vec, vec_name, vec_color in ((a1, 'a1', 'red'), (a2, 'a2', 'green'), (a3, 'a3', 'blue')):
        fig.add_trace(go.Scatter3d(
            x=[0, vec[0]], y=[0, vec[1]], z=[0, vec[2]],
            mode='lines+markers+text',
            marker=dict(size=4, color=vec_color),
            line=dict(width=6, color=vec_color),
            text=[vec_name, ''],
            textposition="top center"
        ))

    # Optional path; text is only attached when labels were supplied.
    if points is not None:
        points = np.array(points)
        path_kwargs = dict(
            x=points[:, 0], y=points[:, 1], z=points[:, 2],
            mode='lines+markers',
            marker=dict(size=6, color='black'),
            line=dict(width=3, color='black'),
            name='k-points'
        )
        if labels is not None:
            path_kwargs.update(
                mode='lines+markers+text',
                text=labels,
                textposition="top center"
            )
        fig.add_trace(go.Scatter3d(**path_kwargs))

    fig.update_layout(
        scene=dict(
            xaxis_title='x', yaxis_title='y', zaxis_title='z',
            aspectmode='data'
        ),
        margin=dict(l=0, r=0, t=0, b=0)
    )

    # === Output ===
    if not filepath:
        fig.show()
        return fig

    filepath = Path(filepath)
    extension = filepath.suffix.lower()
    if extension == ".html":
        fig.write_html(str(filepath))
    elif extension == ".png":
        fig.write_image(str(filepath))
    else:
        raise ValueError(f"Unsupported file extension: {filepath.suffix}")

    return fig


def combine_band_and_dos(fig_band, fig_dos, filepath=None):
    """
    Combine band structure and DOS plots side by side into a single figure.

    Only the traces of the two input figures are copied. Their layout (tick
    labels, shapes such as vertical separator lines) is not carried over.

    Parameters:
    - fig_band: Plotly figure from plot_bands()
    - fig_dos: Plotly figure from plot_dos()
    - filepath: optional str or Path ending in ".html" or ".png"; when None the
      figure is shown instead of written

    Returns:
    - Combined Plotly figure

    Raises:
    - ValueError: if `filepath` has an extension other than ".html" or ".png"
    """
    # Create 1-row, 2-column subplot; the energy axis is shared between both panels.
    fig = make_subplots(rows=1, cols=2, shared_yaxes=True,
                        column_widths=[0.75, 0.25],
                        horizontal_spacing=0.02,
                        specs=[[{"type": "xy"}, {"type": "xy"}]])

    # Add band traces to subplot (1,1)
    for trace in fig_band.data:
        fig.add_trace(trace, row=1, col=1)

    # Add DOS traces to subplot (1,2)
    for trace in fig_dos.data:
        fig.add_trace(trace, row=1, col=2)

    # Update layout
    fig.update_layout(
        xaxis=dict(title='k (1/nm)'),  # subplot (1,1)
        xaxis2=dict(title='DOS (1/eV)'),  # subplot (1,2)
        yaxis=dict(title='Energy (eV)'),  # shared y-axis
        showlegend=True,
        margin=dict(l=50, r=20, t=20, b=40)
    )

    # === Output ===
    if filepath is None:
        fig.show()
        return fig

    # Accept plain strings as well as Path objects, like plot_brillouin_zone does.
    filepath = Path(filepath)
    extension = filepath.suffix.lower()
    if extension == ".html":
        fig.write_html(str(filepath))
    elif extension == ".png":
        fig.write_image(str(filepath), height=1200, width=900,)
    else:
        raise ValueError(f"Unsupported file extension: {filepath.suffix}")

    return fig

def plot_bands(k_len, bands, k_idx, k_label, predicted_bands=None, filepath=None):
    """
    Plot band structure using Plotly.

    Parameters:
    - k_len: array-like, k-point distances
    - bands: 2D array-like, shape (n_kpoints, n_bands), true bands
    - k_idx: list of indices into `k_len` where vertical lines are drawn
    - k_label: list of labels for xticks, one per entry of `k_idx`
    - predicted_bands: optional 2D array-like with the same number of k-points as `bands`
    - filepath: optional str or Path ending in ".html" or ".png"; when None the
      figure is shown instead of written

    Returns:
    - Plotly figure

    Raises:
    - ValueError: if `filepath` has an extension other than ".html" or ".png"
    """
    # Accept plain lists as well as arrays: `.max()`, `.shape` and 2D slicing are used below.
    k_len = np.asarray(k_len)
    bands = np.asarray(bands)
    if predicted_bands is not None:
        predicted_bands = np.asarray(predicted_bands)

    fig = go.Figure()

    # True bands: solid black. They only get their own legend entry when there
    # is no prediction to share the legend group with.
    for i in range(bands.shape[1]):
        fig.add_trace(go.Scatter(
            x=k_len,
            y=bands[:, i],
            mode='lines',
            name=f'True Band {i+1}',
            line=dict(color='black', width=1, dash='solid'),
            legendgroup=f'Band {i+1}',
            showlegend=predicted_bands is None
        ))

    # Energy range covered by everything that is plotted; used for the separators.
    y_min, y_max = bands.min(), bands.max()

    # Predicted bands: dashed black
    if predicted_bands is not None:
        # Iterate over the prediction's own band count so a mismatch cannot index out of range.
        for i in range(predicted_bands.shape[1]):
            fig.add_trace(go.Scatter(
                x=k_len,
                y=predicted_bands[:, i],
                mode='lines',
                name=f'Predicted Band {i+1}',
                line=dict(color='black', width=1, dash='dash'),
                legendgroup=f'Band {i+1}',
                showlegend=True
            ))
        y_min = min(y_min, predicted_bands.min())
        y_max = max(y_max, predicted_bands.max())

    # Vertical lines spanning both true and predicted bands
    for idx in k_idx:
        fig.add_shape(type="line",
                      x0=k_len[idx], y0=y_min, x1=k_len[idx], y1=y_max,
                      line=dict(color="black", width=1))

    # Layout
    fig.update_layout(
        xaxis=dict(
            title="k (1/nm)",
            tickmode='array',
            tickvals=[k_len[i] for i in k_idx],
            ticktext=k_label,
            ticks='',  # Hide tick marks
            showticklabels=True,
            range=[0, k_len.max()]
        ),
        yaxis=dict(title="Energy (eV)"),
        margin=dict(l=50, r=20, t=20, b=50),
        showlegend=True
    )

    # === Output ===
    if filepath is None:
        fig.show()
        return fig

    # Accept plain strings as well as Path objects, like plot_brillouin_zone does.
    filepath = Path(filepath)
    extension = filepath.suffix.lower()
    if extension == ".html":
        fig.write_html(str(filepath))
    elif extension == ".png":
        fig.write_image(str(filepath), height=1200, width=900,)
    else:
        raise ValueError(f"Unsupported file extension: {filepath.suffix}")

    return fig

def plot_dos(energies, dos, predicted_dos=None, filepath=None):
    """
    Plot Density of States (DOS) vertically with Energy on y-axis.

    Parameters:
    - energies: array-like, energy values (eV)
    - dos: array-like, true DOS values
    - predicted_dos: optional array-like, predicted DOS values (same shape as dos);
      it is drawn against the same `energies` as the true DOS
    - filepath: optional str or Path ending in ".html" or ".png"; when None the
      figure is shown instead of written

    Returns:
    - Plotly figure

    Raises:
    - ValueError: if `filepath` has an extension other than ".html" or ".png"
    """
    fig = go.Figure()

    # True DOS: solid line
    fig.add_trace(go.Scatter(
        x=dos,
        y=energies,
        mode='lines',
        name='True',
        line=dict(color='black', width=1, dash='solid')
    ))

    # Predicted DOS: dashed line
    if predicted_dos is not None:
        fig.add_trace(go.Scatter(
            x=predicted_dos,
            y=energies,
            mode='lines',
            name='Pred',
            line=dict(color='black', width=1, dash='dash')
        ))

    # Layout
    fig.update_layout(
        xaxis=dict(title='DOS (1/eV)'),
        yaxis=dict(title='Energy (eV)'),
        margin=dict(l=50, r=50, t=20, b=20),
        showlegend=True
    )

    # === Output ===
    if filepath is None:
        fig.show()
        return fig

    # Accept plain strings as well as Path objects, like plot_brillouin_zone does.
    filepath = Path(filepath)
    extension = filepath.suffix.lower()
    if extension == ".html":
        fig.write_html(str(filepath))
    elif extension == ".png":
        fig.write_image(str(filepath), height=1200, width=900,)
    else:
        raise ValueError(f"Unsupported file extension: {filepath.suffix}")

    return fig


In [None]:
# Reciprocal lattice vectors of the reference cell, each of shape (3,).
# The division by 10 presumably converts nm^-1 to Angstrom^-1 - confirm the TBPLaS unit.
reciprocal_vectors = cell_true.get_reciprocal_vectors()/10
b1, b2, b3 = reciprocal_vectors
B = np.vstack((b1, b2, b3))  # shape (3,3), one reciprocal vector per row
Gamma = np.zeros(3)

# High-symmetry points of the path, in Cartesian coordinates, together with their labels.
high_symmetry_points = [
    (r"$\Gamma$", Gamma),
    ("M", b1/2),
    ("K", 1/3*(b1 + 2*b2)),
    (r"$\Gamma$", Gamma),
]
k_label = [label for label, _point in high_symmetry_points]
k_cart = np.array([_point for _label, _point in high_symmetry_points])

# Fractional coordinates c solve c1*b1 + c2*b2 + c3*b3 = k, i.e. B.T @ c = k.
k_frac = np.stack([np.linalg.solve(B.T, k_point) for k_point in k_cart])
# Same number of k-points on every segment between consecutive high-symmetry points.
points_per_segment = [n_k_bands] * (len(k_frac) - 1)
k_path, k_idx = tb.gen_kpath(k_frac, points_per_segment)

# TODO(author): still not sure the k-path corresponds (because of where the atoms are located).
plot_brillouin_zone(b1, b2, b3, points=k_cart, labels=k_label)
None

In [None]:
# Bands
# Reference band structure along `k_path`, solved as a generalized problem with the
# reference overlap.
solver = tb.DiagSolver(cell_true, overlap_true)
solver.config.k_points = k_path
solver.config.prefix = "bands_true"
# A single timer is shared by all calculations in this and the next cell; every
# report_total_time() call prints all the timings accumulated so far.
timer = tb.Timer()
timer.tic("bands_true")
k_len_true, bands_true = solver.calc_bands()
timer.toc("bands_true")
timer.report_total_time()

# Predicted band structure: predicted Hamiltonian, but still the reference overlap.
solver = tb.DiagSolver(cell_pred, overlap_true)
solver.config.k_points = k_path
solver.config.prefix = "bands_pred"
timer.tic("bands_pred")
k_len_pred, bands_pred = solver.calc_bands()
timer.toc("bands_pred")
timer.report_total_time()

# DOS
k_mesh = tb.gen_kmesh((n_k_dos, n_k_dos, n_k_dos))  # Uniform meshgrid
# The DOS energy window is taken from the reference bands only; the same window is
# reused for the predicted DOS in the next cell.
e_min = float(np.min(bands_true))
e_max = float(np.max(bands_true))

solver = tb.DiagSolver(cell_true, overlap_true)
solver.config.k_points = k_mesh
solver.config.e_min = e_min
solver.config.e_max = e_max
solver.config.prefix = "dos_true"
timer.tic("dos_true")
energies_true, dos_true = solver.calc_dos()
timer.toc("dos_true")
timer.report_total_time()


Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

Output details:
  Directory  : ./
  Prefix     : bands_true

Using Eigen backend for diagonalization.
	 bands_true :    0.12220

Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

Output details:
  Directory  : ./
  Prefix     : bands_pred

Using Eigen backend for diagonalization.
	 bands_true :    0.12220
	 bands_pred :    0.12192

Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

Output details:
  Directory  : ./
  Prefix     : dos_true



In [None]:
# Predicted DOS
# Relies on `k_mesh`, `e_min`, `e_max` and `timer` defined in the previous cell.

# The window from the predicted bands is deliberately left disabled: reusing the
# reference window keeps both DOS curves on the same energy range.
# e_min = float(np.min(bands_pred))
# e_max = float(np.max(bands_pred))
solver = tb.DiagSolver(cell_pred, overlap_true)
solver.config.k_points = k_mesh
solver.config.e_min = e_min
solver.config.e_max = e_max
solver.config.prefix = "dos_pred"
print("e_min=", e_min, "e_max=", e_max)
timer.tic("dos_pred")
energies_pred, dos_pred = solver.calc_dos()
timer.toc("dos_pred")
timer.report_total_time()


Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

e_min= -18.97060533086461 e_max= 183.9712611130489
Output details:
  Directory  : ./
  Prefix     : dos_pred

Using Eigen backend for diagonalization.
	 bands_true :    0.11967
	 bands_pred :    0.11520
	 dos_true   :    0.77182
	 dos_pred   :    0.77261


In [None]:
# Reference vs. predicted bands and DOS. Each helper shows its own figure because no
# filepath is given; the combined figure is then shown as well.
# The predicted DOS is plotted against `energies_true`; `energies_pred` is not used here.
fig_bands = plot_bands(k_len_true, bands_true, k_idx, k_label, predicted_bands=bands_pred)
fig_dos = plot_dos(energies_true, dos_true, predicted_dos=dos_pred)

combine_band_and_dos(fig_bands, fig_dos)
# Trailing None so the cell has no value of its own to display.
None