In [1]:
import numpy as np
import tbplas as tb
import sisl
from pathlib import Path
import time

def add_orbitals(cell: tb.PrimitiveCell, positions, onsites, labels) -> None:
    """
    Add every orbital of every atom to the model.

    All orbitals of an atom are placed at that atom's position. The on-site
    energies are read from a flat sequence ordered atom by atom and, within an
    atom, orbital by orbital. A running index is used so that atoms carrying a
    different number of orbitals are still matched with the right entry.

    Parameters:
    - cell: primitive cell that receives the orbitals (modified in place)
    - positions: array with one row per atom, passed to the cell with unit=tb.ANG
    - onsites: flat sequence of on-site energies, one entry per orbital
    - labels: one sequence of orbital labels per atom; its length gives the
      number of orbitals of that atom
    """
    orbital_index = 0  # position of the current orbital in the flat `onsites`
    for i in range(positions.shape[0]):
        for label in labels[i]:
            cell.add_orbital_cart(positions[i], unit=tb.ANG, energy=onsites[orbital_index], label=label)
            orbital_index += 1


def add_hopping_terms(cell: tb.PrimitiveCell, iscs, orbs_in, orbs_out, hoppings) -> None:
    """
    Register a list of hopping terms on the model.

    The four sequences are read in parallel: entry ``k`` of each one describes
    the k-th hopping term. The number of terms is taken from ``iscs``.

    Parameters:
    - cell: primitive cell that receives the hopping terms (modified in place)
    - iscs: cell index of each term, forwarded as ``rn``
    - orbs_in: orbital index forwarded as ``orb_i``
    - orbs_out: orbital index forwarded as ``orb_j``
    - hoppings: hopping value forwarded as ``energy``
    """
    for term, cell_index in enumerate(iscs):
        cell.add_hopping(
            rn=cell_index,
            orb_i=orbs_in[term],
            orb_j=orbs_out[term],
            energy=hoppings[term],
        )



In [2]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots


def plot_bands(k_len, bands, k_idx, k_label, predicted_bands=None, filepath=None):
    """
    Plot band structure using Plotly.

    Parameters:
    - k_len: array-like, k-point distances
    - bands: 2D array, shape (n_kpoints, n_bands), true bands
    - k_idx: list of indices where vertical lines are drawn
    - k_label: list of labels for xticks
    - predicted_bands: optional 2D array, same shape as `bands`
    - filepath: Path or str; if not None the figure is written to this file
      (.html or .png) instead of being shown

    Returns:
    - The Plotly figure

    Raises:
    - ValueError: if `filepath` has an extension other than .html or .png
    """
    # Normalise and validate the target before any plotting work is done, so a
    # bad extension fails fast and plain strings are accepted as well as Paths.
    if filepath is not None:
        filepath = Path(filepath)
        if filepath.suffix.lower() not in (".html", ".png"):
            raise ValueError(f"Unsupported file extension: {filepath.suffix}")

    fig = go.Figure()
    num_bands = bands.shape[1]
    has_prediction = predicted_bands is not None

    # True bands: solid black. Their legend entries are only shown when there
    # is no prediction; otherwise the predicted traces carry the legend.
    for i in range(num_bands):
        fig.add_trace(go.Scatter(
            x=k_len,
            y=bands[:, i],
            mode='lines',
            name=f'True Band {i+1}',
            line=dict(color='black', width=1, dash='solid'),
            legendgroup=f'Band {i+1}',
            showlegend=not has_prediction
        ))

    # Predicted bands: dashed black
    y_low, y_high = bands.min(), bands.max()
    if has_prediction:
        for i in range(predicted_bands.shape[1]):
            fig.add_trace(go.Scatter(
                x=k_len,
                y=predicted_bands[:, i],
                mode='lines',
                name=f'Predicted Band {i+1}',
                line=dict(color='black', width=1, dash='dash'),
                legendgroup=f'Band {i+1}',
                showlegend=True
            ))
        # Let the vertical markers span the predicted data as well.
        y_low = min(y_low, predicted_bands.min())
        y_high = max(y_high, predicted_bands.max())

    # Vertical lines
    for idx in k_idx:
        fig.add_shape(type="line",
                      x0=k_len[idx], y0=y_low, x1=k_len[idx], y1=y_high,
                      line=dict(color="black", width=1))

    # Layout
    fig.update_layout(
        xaxis=dict(
            title="k (1/nm)",
            tickmode='array',
            tickvals=[k_len[i] for i in k_idx],
            ticktext=k_label,
            ticks='',  # Hide tick marks
            showticklabels=True,
            range=[0, np.max(k_len)]
        ),
        yaxis=dict(title="Energy (eV)"),
        margin=dict(l=50, r=20, t=20, b=50),
        showlegend=True
    )

    # === Output ===
    if filepath is None:
        fig.show()
    elif filepath.suffix.lower() == ".html":
        fig.write_html(str(filepath))
    else:
        fig.write_image(str(filepath), height=1200, width=900)

    return fig



def plot_dos(energies, dos, predicted_dos=None, filepath=None):
    """
    Plot Density of States (DOS) vertically with Energy on y-axis.

    Parameters:
    - energies: array-like, energy values (eV)
    - dos: array-like, true DOS values
    - predicted_dos: optional array-like, predicted DOS values (same shape as dos)
    - filepath: Path or str; if not None the figure is written to this file
      (.html or .png) instead of being shown

    Returns:
    - The Plotly figure

    Raises:
    - ValueError: if `filepath` has an extension other than .html or .png
    """
    # Normalise and validate the target up front so plain strings work and an
    # unsupported extension is reported before any figure is built.
    if filepath is not None:
        filepath = Path(filepath)
        if filepath.suffix.lower() not in (".html", ".png"):
            raise ValueError(f"Unsupported file extension: {filepath.suffix}")

    fig = go.Figure()

    # True DOS
    fig.add_trace(go.Scatter(
        x=dos,
        y=energies,
        mode='lines',
        name='True',
        line=dict(color='black', width=1, dash='solid')
    ))

    # Predicted DOS
    if predicted_dos is not None:
        fig.add_trace(go.Scatter(
            x=predicted_dos,
            y=energies,
            mode='lines',
            name='Pred',
            line=dict(color='black', width=1, dash='dash')
        ))

    # Layout
    fig.update_layout(
        xaxis=dict(title='DOS (1/eV)'),
        yaxis=dict(title='Energy (eV)'),
        margin=dict(l=50, r=50, t=20, b=20),
        showlegend=True
    )

    # === Output ===
    if filepath is None:
        fig.show()
    elif filepath.suffix.lower() == ".html":
        fig.write_html(str(filepath))
    else:
        fig.write_image(str(filepath), height=1200, width=900)

    return fig


def combine_band_and_dos(fig_band, fig_dos, filepath=None):
    """
    Combine band structure and DOS plots side by side into a single figure.

    Besides the traces, the x-axis tick configuration and the layout shapes
    (the vertical marker lines) of the band figure are carried over, so the
    combined band panel keeps its labelled k-points.

    Parameters:
    - fig_band: Plotly figure from plot_bands()
    - fig_dos: Plotly figure from plot_dos()
    - filepath: Path or str; if not None the figure is written to this file
      (.html or .png) instead of being shown

    Returns:
    - Combined Plotly figure

    Raises:
    - ValueError: if `filepath` has an extension other than .html or .png
    """
    # Normalise and validate the target up front so plain strings work and an
    # unsupported extension is reported before any figure is built.
    if filepath is not None:
        filepath = Path(filepath)
        if filepath.suffix.lower() not in (".html", ".png"):
            raise ValueError(f"Unsupported file extension: {filepath.suffix}")

    # Create 1-row, 2-column subplot
    fig = make_subplots(rows=1, cols=2, shared_yaxes=True,
                        column_widths=[0.75, 0.25],
                        horizontal_spacing=0.02,
                        specs=[[{"type": "xy"}, {"type": "xy"}]])

    # Add band traces to subplot (1,1)
    for trace in fig_band.data:
        fig.add_trace(trace, row=1, col=1)

    # Layout shapes are not part of `data`, so they have to be copied
    # explicitly onto the band subplot.
    for shape in fig_band.layout.shapes:
        fig.add_shape(shape, row=1, col=1)

    # Add DOS traces to subplot (1,2)
    for trace in fig_dos.data:
        fig.add_trace(trace, row=1, col=2)

    # Update layout
    fig.update_layout(
        xaxis=dict(title='k (1/nm)'),  # subplot (1,1)
        xaxis2=dict(title='DOS (1/eV)'),  # subplot (1,2)
        yaxis=dict(title='Energy (eV)'),  # shared y-axis
        showlegend=True,
        margin=dict(l=50, r=20, t=20, b=40)
    )

    # Re-apply the tick labelling of the band figure on the band subplot.
    band_xaxis = fig_band.layout.xaxis
    fig.update_xaxes(
        tickmode=band_xaxis.tickmode,
        tickvals=band_xaxis.tickvals,
        ticktext=band_xaxis.ticktext,
        ticks=band_xaxis.ticks,
        range=band_xaxis.range,
        row=1, col=1
    )

    # === Output ===
    if filepath is None:
        fig.show()
    elif filepath.suffix.lower() == ".html":
        fig.write_html(str(filepath))
    else:
        fig.write_image(str(filepath), height=1200, width=900)

    return fig

In [3]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

def plot_bands_matplotlib(k_len, bands, k_idx, k_label, predicted_bands=None, filepath=None):
    """
    Plot band structure using Matplotlib.

    Parameters:
    - k_len: array-like, k-point distances
    - bands: 2D array, shape (n_kpoints, n_bands), true bands
    - k_idx: list of indices where vertical lines are drawn
    - k_label: list of labels for xticks
    - predicted_bands: optional 2D array, same shape as `bands`
    - filepath: Path or str, if not None, save to this file (supports .png, .pdf)

    Returns:
    - The Matplotlib figure (already closed when the function returns)

    Raises:
    - ValueError: if `filepath` has an extension other than .png or .pdf
    """
    # Validate the target before a figure exists, so an unsupported extension
    # cannot leave an open figure behind.
    if filepath is not None:
        filepath = Path(filepath)
        if filepath.suffix.lower() not in (".png", ".pdf"):
            raise ValueError(f"Unsupported file extension: {filepath.suffix}")

    fig, ax = plt.subplots(figsize=(7, 9))  # Adjust size as needed

    # Plot true bands; only the first one is labelled so the legend gets a
    # single 'True' entry.
    for i in range(bands.shape[1]):
        ax.plot(k_len, bands[:, i], color='black', lw=1.0, label='True' if i == 0 else "")

    # Plot predicted bands if given
    if predicted_bands is not None:
        for i in range(predicted_bands.shape[1]):
            ax.plot(k_len, predicted_bands[:, i], color='black', lw=1.0, ls='--',
                    label='Predicted' if i == 0 else "")

    # Draw vertical lines
    for idx in k_idx:
        ax.axvline(k_len[idx], color='k', lw=1.0)

    # X ticks and labels
    ax.set_xlim(0, np.amax(k_len))
    ax.set_xticks([k_len[i] for i in k_idx])
    ax.set_xticklabels(k_label)
    ax.set_xlabel("k (1/nm)")
    ax.set_ylabel("Energy (eV)")

    # Legend: only useful when true and predicted bands must be told apart.
    if predicted_bands is not None:
        ax.legend(loc='best')

    fig.tight_layout()

    # Save or show; act on `fig` explicitly instead of the implicit current figure.
    if filepath is not None:
        fig.savefig(str(filepath), dpi=300)
    else:
        plt.show()
    plt.close(fig)

    return fig


import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

def plot_dos_matplotlib(energies, dos, predicted_dos=None, filepath=None):
    """
    Plot Density of States (DOS) vertically with Energy on y-axis using matplotlib.

    Parameters:
    - energies: array-like, energy values (eV)
    - dos: array-like, true DOS values
    - predicted_dos: optional array-like, predicted DOS values (same shape as dos)
    - filepath: Path or str, if not None, save to this file (supports .png, .pdf)

    Returns:
    - The Matplotlib figure (already closed when the function returns)

    Raises:
    - ValueError: if `filepath` has an extension other than .png or .pdf
    """
    # Validate the target before a figure exists, so an unsupported extension
    # cannot leave an open figure behind.
    if filepath is not None:
        filepath = Path(filepath)
        if filepath.suffix.lower() not in (".png", ".pdf"):
            raise ValueError(f"Unsupported file extension: {filepath.suffix}")

    fig, ax = plt.subplots(figsize=(6, 8))  # Adjust size as needed

    # Plot true DOS
    ax.plot(dos, energies, color='black', lw=1, label='True')

    # Plot predicted DOS if provided
    if predicted_dos is not None:
        ax.plot(predicted_dos, energies, color='black', lw=1, ls='--', label='Predicted')

    # Labels
    ax.set_xlabel("DOS (1/eV)")
    ax.set_ylabel("Energy (eV)")

    ax.legend(loc='best')
    fig.tight_layout()

    # Save or show; act on `fig` explicitly instead of the implicit current figure.
    if filepath is not None:
        fig.savefig(str(filepath), dpi=300)
    else:
        plt.show()
    plt.close(fig)

    return fig



import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

def plot_bands_and_dos_matplotlib(
    k_len, bands, k_idx, k_label,
    energies, dos,
    predicted_bands=None, predicted_dos=None,
    filepath=None
):
    """
    Plot band structure and DOS side-by-side with shared energy axis.

    Parameters:
    - k_len: array-like, k-point distances
    - bands: 2D array, shape (n_kpoints, n_bands), true bands
    - k_idx: list of indices where vertical lines are drawn
    - k_label: list of labels for xticks
    - energies: array-like, energy values (eV)
    - dos: array-like, true DOS values
    - predicted_bands: optional 2D array, same shape as `bands`
    - predicted_dos: optional array-like, predicted DOS values (same shape as dos)
    - filepath: Path or str, if not None, save to this file (supports .png, .pdf)

    Returns:
    - The Matplotlib figure (already closed when the function returns)

    Raises:
    - ValueError: if `filepath` has an extension other than .png or .pdf
    """
    # Validate the target before a figure exists, so an unsupported extension
    # cannot leave an open figure behind.
    if filepath is not None:
        filepath = Path(filepath)
        if filepath.suffix.lower() not in (".png", ".pdf"):
            raise ValueError(f"Unsupported file extension: {filepath.suffix}")

    fig, (ax_band, ax_dos) = plt.subplots(
        nrows=1, ncols=2,
        figsize=(11, 8),
        gridspec_kw={'width_ratios': [3, 1], 'wspace': 0.10},
        sharey=True
    )

    # --- Band Structure ---
    # Only the first band of each set is labelled so the legend stays compact.
    for i in range(bands.shape[1]):
        ax_band.plot(k_len, bands[:, i], color='black', lw=1.0, label='True' if i == 0 else "")
    if predicted_bands is not None:
        for i in range(predicted_bands.shape[1]):
            ax_band.plot(k_len, predicted_bands[:, i], color='black', lw=1.0, ls='--',
                         label='Predicted' if i == 0 else "")
    for idx in k_idx:
        ax_band.axvline(k_len[idx], color='k', lw=1.0)

    ax_band.set_xlim(0, np.amax(k_len))
    ax_band.set_xticks([k_len[i] for i in k_idx])
    ax_band.set_xticklabels(k_label)
    ax_band.set_xlabel("k (1/nm)")
    ax_band.set_ylabel("Energy (eV)")

    # The band legend is only useful when two sets of bands are drawn.
    if predicted_bands is not None:
        ax_band.legend(loc='best')

    # --- DOS ---
    ax_dos.plot(dos, energies, color='black', lw=1, label='True')
    if predicted_dos is not None:
        ax_dos.plot(predicted_dos, energies, color='black', lw=1, ls='--', label='Predicted')
    ax_dos.set_xlabel("DOS (1/eV)")
    ax_dos.yaxis.set_tick_params(labelleft=False)  # energy labels live on the band panel
    ax_dos.legend(loc='best')

    fig.tight_layout()

    # Save or show; act on `fig` explicitly instead of the implicit current figure.
    if filepath is not None:
        fig.savefig(str(filepath), dpi=300)
    else:
        plt.show()
    plt.close(fig)

    return fig



In [4]:
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

def combine_band_and_dos_from_figures(fig_band, fig_dos, filepath=None):
    """
    Combine band structure and DOS matplotlib figures into one figure with side-by-side subplots.

    Lines of the first axes of each input figure are re-drawn on the new
    figure. Vertical marker lines created with ``axvline`` are re-created with
    ``axvline`` rather than copied as data, because their y-data is expressed
    in axes coordinates (0 to 1) and would otherwise show up as short segments.

    Parameters:
    - fig_band: matplotlib Figure from plot_bands_matplotlib
    - fig_dos: matplotlib Figure from plot_dos_matplotlib
    - filepath: Path or str, if not None, save to this file (supports .png, .pdf)

    Returns:
    - The combined Matplotlib figure (already closed when the function returns)

    Raises:
    - ValueError: if `filepath` has an extension other than .png or .pdf
    """
    # Validate the target before a figure exists, so an unsupported extension
    # cannot leave an open figure behind.
    if filepath is not None:
        filepath = Path(filepath)
        if filepath.suffix.lower() not in (".png", ".pdf"):
            raise ValueError(f"Unsupported file extension: {filepath.suffix}")

    # Extract source axes
    ax_band = fig_band.axes[0]
    ax_dos = fig_dos.axes[0]

    # Create new combined figure
    fig_comb, (ax1, ax2) = plt.subplots(
        nrows=1, ncols=2,
        figsize=(9, 8),
        gridspec_kw={'width_ratios': [3, 1], 'wspace': 0.06},
        sharey=True
    )

    # --- Band structure traces ---
    # axvline attaches the axes' x-axis (blended) transform to its line; that
    # is what separates the vertical markers from ordinary data lines here.
    marker_transform = ax_band.get_xaxis_transform()
    for line in ax_band.get_lines():
        if line.get_transform() is marker_transform:
            ax1.axvline(line.get_xdata()[0],
                        color=line.get_color(),
                        lw=line.get_linewidth(),
                        ls=line.get_linestyle())
        else:
            ax1.plot(line.get_xdata(), line.get_ydata(),
                     color=line.get_color(),
                     lw=line.get_linewidth(),
                     ls=line.get_linestyle(),
                     label=line.get_label())

    # X/Y labels, limits and ticks for bands
    ax1.set_xlabel(ax_band.get_xlabel())
    ax1.set_ylabel(ax_band.get_ylabel())
    ax1.set_xlim(ax_band.get_xlim())
    ax1.set_xticks(ax_band.get_xticks())
    ax1.set_xticklabels([lbl.get_text() for lbl in ax_band.get_xticklabels()])

    # --- DOS traces ---
    for line in ax_dos.get_lines():
        ax2.plot(line.get_xdata(), line.get_ydata(),
                 color=line.get_color(),
                 lw=line.get_linewidth(),
                 ls=line.get_linestyle(),
                 label=line.get_label())

    # X label for DOS; the energy tick labels are shown on the band panel only
    ax2.set_xlabel(ax_dos.get_xlabel())
    ax2.yaxis.set_tick_params(labelleft=False)

    # Legends: one entry per unique label, skipped when nothing is labelled
    for axis in (ax1, ax2):
        handles, labels = axis.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        if by_label:
            axis.legend(by_label.values(), by_label.keys(), loc='best')

    fig_comb.tight_layout()

    # Save or show; act on `fig_comb` explicitly instead of the implicit current figure.
    if filepath is not None:
        fig_comb.savefig(str(filepath), dpi=300)
    else:
        plt.show()
    plt.close(fig_comb)

    return fig_comb


In [5]:
# Build a TBPLaS primitive cell from the Hamiltonian stored in a SIESTA HSX file.
path = Path("../../dataset/SHARE_OUTPUTS_8_ATOMS/0a2a-fbca-4649-8012-f5aa640bfd1d")
structure = path.parts[-1]  # last path component, used later in output file names
file = sisl.get_sile(path / "aiida.HSX")
geometry = file.read_geometry()

# Empty cell
vectors = geometry.cell
cell = tb.PrimitiveCell(vectors, unit=tb.ANG)

# Orbital labels, grouped per atom
positions = geometry.xyz
labels = [[orb.name() for orb in atom] for atom in geometry.atoms]

# To add the orbitals we need the onsite energies.
# NOTE(review): only the matrix returned by h.tocsr() is transferred to the
# model; if the HSX basis is non-orthogonal, the overlap matrix is presumably
# not accounted for here — confirm this is intended.
h = file.read_hamiltonian()
h_mat = h.tocsr().tocoo()

rows = h_mat.row
cols = h_mat.col
data = h_mat.data

# Main diagonal length:
n_diag = min(h_mat.shape[0], h_mat.shape[1])

# On-site energies are the stored diagonal entries; orbitals without one stay
# at 0. Should a diagonal position be stored more than once, the first stored
# value is kept (np.unique reports the first occurrence).
onsites = np.zeros(n_diag, dtype=data.dtype)
diag_entries = np.flatnonzero(rows == cols)
diag_orbitals, first_hit = np.unique(rows[diag_entries], return_index=True)
onsites[diag_orbitals] = data[diag_entries[first_hit]]

add_orbitals(cell, positions, onsites, labels)

# Add hopping terms.
# Every off-diagonal non-zero element of h becomes one hopping term, described
# by the cell index (isc) of its column, the two orbital indices and its value.
nnz = len(data)
n_orbs = len(labels[0])  # orbitals on the first atom
n_atoms = len(positions)
# Number of orbitals added to the cell; column indices are folded back into
# this range. Counting the labels avoids assuming that every atom carries the
# same number of orbitals.
n_orbs_total = sum(len(atom_labels) for atom_labels in labels)

off_diag = np.flatnonzero(rows != cols)  # Only off-diagonal elements are hoppings
iscs = [geometry.o2isc(cols[k]) for k in off_diag]
orbs_in = [cols[k] % n_orbs_total for k in off_diag]
orbs_out = [rows[k] for k in off_diag]
hoppings = [data[k] for k in off_diag]

add_hopping_terms(cell, iscs, orbs_in, orbs_out, hoppings)

In [None]:
# Calculations: time the band and DOS calculations for a range of k-point
# densities and save one combined figure per density.

# Define a path in k-space
# NOTE(review): the columns of geometry.rcell are passed to tb.gen_kpath as
# the path vertices; gen_kpath may expect fractional coordinates — confirm
# that these are the intended high-symmetry points.
k_dir_x = geometry.rcell[:,0]
k_dir_y = geometry.rcell[:,1]
k_dir_z = geometry.rcell[:,2]
k_points = np.array([
    [0.0, 0.0, 0.0],    # Gamma
    k_dir_x,
    k_dir_x + k_dir_y,
    k_dir_x + k_dir_y + k_dir_z
])
k_label = ["G", "X", "Y", "Z-G"]

# The output directory does not depend on n_ks, so create it once.
savedir = Path("kmesh_study")
savedir.mkdir(exist_ok=True, parents=True)

n_ks_vect = np.arange(14, 50 + 1)  # 14, 15, ..., 50
n_ks_bands = []
bands_time = []
bands_list = []
dos_time = []
dos_list = []
for n_ks in n_ks_vect:
    k_path, k_idx = tb.gen_kpath(k_points, [n_ks, n_ks, n_ks])

    print(f"Calculating for {n_ks} k-points...")

    # Bands
    n_ks_bands.append(n_ks)
    solver = tb.DiagSolver(cell)
    solver.config.k_points = k_path
    time1 = time.perf_counter()  # monotonic clock, suited for measuring intervals
    k_len, bands = solver.calc_bands()
    time2 = time.perf_counter()
    bands_list.append((k_len, bands))
    bands_time.append(time2 - time1)

    print(f"BANDS Time taken for {n_ks} k-points: {time2 - time1:.2f} seconds")

    # DOS on a uniform mesh, restricted to the energy window spanned by the bands
    k_mesh = tb.gen_kmesh((3*n_ks, 3*n_ks, 3*n_ks))  # Uniform meshgrid
    e_min = np.min(bands)
    e_max = np.max(bands)
    solver = tb.DiagSolver(cell)
    solver.config.k_points = k_mesh
    solver.config.e_min = e_min
    solver.config.e_max = e_max
    time3 = time.perf_counter()
    energies, dos = solver.calc_dos()
    time4 = time.perf_counter()
    dos_list.append((energies, dos))
    dos_time.append(time4 - time3)

    print(f"DOS Time taken for {n_ks} k-points: {time4 - time3:.2f} seconds")

    # n_ks is part of the file name so each k-point density keeps its own
    # figure instead of overwriting the previous one.
    filepath = savedir / f"{n_atoms}atm_{structure}_nk{n_ks}_bandsdos.png"
    fig = plot_bands_and_dos_matplotlib(
        k_len, bands, k_idx, k_label,
        energies, dos,
        predicted_bands=None,
        predicted_dos=None,
        filepath=filepath
    )



Calculating for 14 k-points...

Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

Output details:
  Directory  : ./
  Prefix     : sample

Using Eigen backend for diagonalization.
BANDS Time taken for 14 k-points: 0.14 seconds

Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

Output details:
  Directory  : ./
  Prefix     : sample

Using Eigen backend for diagonalization.
DOS Time taken for 14 k-points: 1300.79 seconds


  plt.tight_layout()


Calculating for 15 k-points...

Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

Output details:
  Directory  : ./
  Prefix     : sample

Using Eigen backend for diagonalization.
BANDS Time taken for 15 k-points: 0.12 seconds

Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

Output details:
  Directory  : ./
  Prefix     : sample

Using Eigen backend for diagonalization.
DOS Time taken for 15 k-points: 1599.31 seconds


  plt.tight_layout()


Calculating for 16 k-points...

Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

Output details:
  Directory  : ./
  Prefix     : sample

Using Eigen backend for diagonalization.
BANDS Time taken for 16 k-points: 0.14 seconds

Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

Output details:
  Directory  : ./
  Prefix     : sample

Using Eigen backend for diagonalization.
DOS Time taken for 16 k-points: 1949.75 seconds


  plt.tight_layout()


Calculating for 17 k-points...

Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

Output details:
  Directory  : ./
  Prefix     : sample

Using Eigen backend for diagonalization.
BANDS Time taken for 17 k-points: 0.25 seconds

Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

Output details:
  Directory  : ./
  Prefix     : sample

Using Eigen backend for diagonalization.
DOS Time taken for 17 k-points: 2328.10 seconds


  plt.tight_layout()


Calculating for 18 k-points...

Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

Output details:
  Directory  : ./
  Prefix     : sample

Using Eigen backend for diagonalization.
BANDS Time taken for 18 k-points: 0.14 seconds

Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

Output details:
  Directory  : ./
  Prefix     : sample

Using Eigen backend for diagonalization.
DOS Time taken for 18 k-points: 2761.95 seconds


  plt.tight_layout()


Calculating for 19 k-points...

Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

Output details:
  Directory  : ./
  Prefix     : sample

Using Eigen backend for diagonalization.
BANDS Time taken for 19 k-points: 0.14 seconds

Parallelization details:
  MPI disabled    
  OMP_NUM_THREADS  : n/a   
  MKL_NUM_THREADS  : n/a   

Output details:
  Directory  : ./
  Prefix     : sample

