# VASPbook

This notebook is designed to:
- Visualize the electronic properties of a system with and without Spin-Orbit Coupling (SOC)
- Assess the contributions of individual atomic species and specific orbitals
- Compare the band structures of two or more configurations
- Visualize unfolded band structures obtained using bands4vasp
- Examine shift currents calculated with Wannier90

The notebook was developed in Python 3.12.7.

## 0. Packages

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib import gridspec
import matplotlib.pyplot as plt
import warnings
import matplotlib.cm as cm
import matplotlib as mpl
from matplotlib import colors
import warnings
from matplotlib.lines import Line2D


# 1. Functions

This section contains the functions. Do not modify them unless strictly necessary.

## 1.1 Data collection

### 1.1.1 Density Of States (DOS) & Projected Density Of States (PDOS)

In [None]:
def extract_DOS(percorso, num_ions):
    """
    Extract the total and per-ion (projected) Density of States from a VASP DOSCAR file.

    Parameters
    ----------
    percorso : str
        Path to the DOSCAR file.
    num_ions : int
        Number of per-ion PDOS blocks to read after the total DOS block.

    Returns
    -------
    dict
        Dictionary containing the following keys:

        - `DOS_total` : np.ndarray of shape (N_points, n_columns)
            Total DOS block exactly as stored in the file, with the first
            column (energy) shifted by the Fermi energy. The remaining columns
            are typically DOS and integrated DOS.
        - `Efermi` : float
            Fermi energy read from the fourth field of the sixth line.
        - `PDOS_arrays` : tuple of np.ndarray
            Raw per-ion blocks (old workflow format), energy column shifted.
        - `PDOS` : tuple of dict
            Structured per-ion data (new workflow format). Each dict contains:

            - `'energy'` : np.ndarray
                Energy grid (shifted by Efermi).
            - `'total'` : np.ndarray
                Only for 2- or 3-column blocks.
            - `'idos'` : np.ndarray
                Only for 3-column blocks.
            - Orbital-resolved DOS, e.g. `'s'`, `'p_x'`, `'d_xy'`, etc.:
                - 10-column blocks: one array per orbital
                - 19-column blocks: dict with `'up'` and `'down'` arrays

    Raises
    ------
    RuntimeError
        If a per-ion block has a number of columns other than 2, 3, 10 or 19.

    Notes
    -----
    - Only the sixth line of the header is parsed (third field: number of
      energy points, fourth field: Fermi energy). The other header lines are
      not used, so their content cannot make the parsing fail.
    - Each per-ion block is assumed to be preceded by one header line, which
      is skipped.
    - The `'energy'` entry of each PDOS dict is a view on the corresponding
      array in `PDOS_arrays`.
    """
    # Column order used for the orbital-resolved (10- and 19-column) blocks.
    orbital_names = ('s', 'p_y', 'p_z', 'p_x',
                     'd_xy', 'd_yz', 'd_z2-r2', 'd_xz', 'd_x2-y2')

    with open(percorso, 'r') as f:
        lines = f.readlines()

    # ---- header: only the grid size and the Fermi energy are needed ----
    dos_header = lines[5].split()
    num_points = int(float(dos_header[2]))
    E_f = float(dos_header[3])

    # ---- total DOS ----
    DOS_total = np.array([
        [float(x) for x in lines[i].split()]
        for i in range(6, 6 + num_points)
    ])
    DOS_total[:, 0] -= E_f  # energy shift

    # ---- per-ion PDOS ----
    PDOS_arrays = []
    PDOS_dicts = []

    for i in range(num_ions):
        # Every block is (1 header line + num_points data lines) long;
        # `start` points just past the header of block i.
        start = ((i + 1) * (num_points + 1)) + 6
        end = ((i + 2) * (num_points + 1)) + 5

        arr = np.array([
            [float(x) for x in lines[j].split()]
            for j in range(start, end)
        ])
        arr[:, 0] -= E_f
        PDOS_arrays.append(arr)

        # ---- structured version ----
        ion_pdos = {"energy": arr[:, 0]}
        n_columns = arr.shape[1]

        if n_columns == 2:
            ion_pdos["total"] = arr[:, 1]

        elif n_columns == 3:
            ion_pdos["total"] = arr[:, 1]
            ion_pdos["idos"] = arr[:, 2]

        elif n_columns == 10:
            for k, orb in enumerate(orbital_names):
                ion_pdos[orb] = arr[:, k + 1]

        elif n_columns == 19:
            # Columns alternate up/down for each orbital.
            for k, orb in enumerate(orbital_names):
                ion_pdos[orb] = {
                    "up": arr[:, 1 + 2 * k],
                    "down": arr[:, 2 + 2 * k]
                }

        else:
            raise RuntimeError(
                f"Formato PDOS non riconosciuto per ione {i} "
                f"(colonne = {arr.shape[1]})"
            )

        PDOS_dicts.append(ion_pdos)

    return {
        "DOS_total": DOS_total,
        "Efermi": E_f,
        "PDOS_arrays": tuple(PDOS_arrays),  # kept for the old workflow
        "PDOS": tuple(PDOS_dicts)           # new structured format
    }


### 1.1.2 Band Structure & Projected Band Structure

In [None]:
def extract_BANDS_full(percorso_eigenval, percorso_procar, Efermi, soc=0):
    """
    Extract the band structure and the orbital projections from VASP EIGENVAL and PROCAR files.

    Parameters
    ----------
    percorso_eigenval : str
        Path to the EIGENVAL file.
    percorso_procar : str
        Path to the PROCAR file.
    Efermi : float
        Fermi energy subtracted from all band energies.
    soc : int, optional
        Spin-Orbit Coupling flag. 0 selects the non-SOC layout; any other
        value selects the SOC layout. Default is 0.

    Returns
    -------
    dict
        Dictionary containing:

        - `bands` : np.ndarray, shape (Nb, Nk)
            Band energies shifted by Efermi.
        - `kpoints` : np.ndarray, shape (Nk, 3)
            K-point coordinates exactly as listed in EIGENVAL
            (NOTE(review): no conversion is applied here; these are presumably
            fractional reciprocal coordinates — confirm before interpreting
            `k_distances` as physical lengths).
        - `k_distances` : np.ndarray, shape (Nk,)
            Cumulative Euclidean distance between consecutive `kpoints`.
        - `occupazioni` : np.ndarray
            Orbital projections:
              * No SOC: shape (Nion, Nb, Nk, Norb)
              * SOC:    shape (Nion, Nb, Nk, 4*Norb), i.e. charge, sigma_x,
                sigma_y, sigma_z concatenated along the last axis
        - `charge`, `sigma_x`, `sigma_y`, `sigma_z` : np.ndarray, only if SOC
            Each of shape (Nion, Nb, Nk, Norb).
        - `meta` : dict
            Metadata: `Nk`, `Nb`, `num_ions`, `num_orbitals`, `Efermi`, `soc`.

    Raises
    ------
    RuntimeError
        If the number of ions/orbitals cannot be found in PROCAR, or if the
        number of numeric PROCAR rows does not match the selected layout.

    Notes
    -----
    - The EIGENVAL header (sixth line of the file) is read before blank lines
      are discarded, so a blank system-name line does not shift the header.
    - `num_orbitals` is the number of column labels following `ion` in the
      PROCAR header row; this count includes the `tot` column when listed.
    - PROCAR lines containing `tot`, comment lines, and lines that are not
      purely numeric are ignored.
    - In the SOC layout the numeric rows of each (k-point, band) are read as
      four consecutive groups of `num_ions` rows.
    """

    # ==========================================================
    # ===================== EIGENVAL ===========================
    # ==========================================================
    with open(percorso_eigenval, "r") as f:
        raw_lines = f.readlines()

    # ---- header: "<n_electrons> <Nk> <Nb>" on the sixth raw line ----
    header = raw_lines[5].split()
    Nk = int(header[1])
    Nb = int(header[2])

    # Blank separator lines between k-point blocks carry no information.
    lines = [line.strip() for line in raw_lines[6:] if line.strip()]

    kpoints = []
    bands = np.zeros((Nb, Nk))

    line_idx = 0
    k_index = 0
    while k_index < Nk and line_idx < len(lines):
        # ---- k-point line: first three fields are the coordinates ----
        kpoints.append([float(x) for x in lines[line_idx].split()[:3]])
        line_idx += 1

        # ---- band lines: second field is the energy ----
        for b in range(Nb):
            bands[b, k_index] = float(lines[line_idx].split()[1])
            line_idx += 1

        k_index += 1

    kpoints = np.array(kpoints)

    # ---- cumulative k-path ----
    dk = np.linalg.norm(np.diff(kpoints, axis=0), axis=1)
    k_distances = np.concatenate([[0.0], np.cumsum(dk)])

    # ---- Fermi shift ----
    bands -= Efermi

    # ==========================================================
    # ====================== PROCAR ============================
    # ==========================================================
    with open(percorso_procar, "r") as f:
        pro_lines = f.readlines()

    # -------- metadata --------
    num_ions = None
    num_orbitals = None
    for line in pro_lines:
        if "# of ions:" in line:
            num_ions = int(line.split()[-1])
        if line.strip().startswith("ion"):
            num_orbitals = len(line.split()) - 1
        if num_ions is not None and num_orbitals is not None:
            break

    if num_ions is None or num_orbitals is None:
        raise RuntimeError("Impossibile determinare num_ions o num_orbitals dal PROCAR.")

    # -------- numeric rows --------
    rows = []
    for line in pro_lines:
        line = line.strip()
        if not line or line.startswith("#") or "tot" in line or "PROCAR" in line:
            continue
        try:
            rows.append([float(x) for x in line.split()])
        except ValueError:
            # k-point / band header lines are not purely numeric
            continue

    # Drop the leading ion-index column.
    data = np.array(rows)[:, 1:1+num_orbitals]

    meta = {
        "Nk": Nk,
        "Nb": Nb,
        "num_ions": num_ions,
        "num_orbitals": num_orbitals,
        "Efermi": Efermi,
        "soc": soc
    }

    # ==========================================================
    # ===================== NO SOC =============================
    # ==========================================================
    if soc == 0:
        expected = Nk * Nb * num_ions
        if data.shape[0] != expected:
            raise RuntimeError(f"Dati PROCAR incompatibili con soc=0: attesi {expected} righe, trovate {data.shape[0]}")

        # (Nk, Nb, Nion, Norb) -> (Nion, Nb, Nk, Norb)
        occupazioni = data.reshape(Nk, Nb, num_ions, num_orbitals).transpose(2, 1, 0, 3)

        return {
            "bands": bands,
            "kpoints": kpoints,
            "k_distances": k_distances,
            "occupazioni": occupazioni,
            "meta": meta
        }

    # ==========================================================
    # ====================== SOC ===============================
    # ==========================================================
    block = num_ions * 4
    expected = Nk * Nb * block
    if data.shape[0] != expected:
        raise RuntimeError(f"Dati PROCAR incompatibili con soc=1: attesi {expected} righe, trovate {data.shape[0]}")

    data = data.reshape(Nk, Nb, block, num_orbitals)

    # Four consecutive groups of num_ions rows, each moved to (Nion, Nb, Nk, Norb).
    charge, sigma_x, sigma_y, sigma_z = (
        data[:, :, c * num_ions:(c + 1) * num_ions, :].transpose(2, 1, 0, 3)
        for c in range(4)
    )

    occupazioni = np.concatenate([charge, sigma_x, sigma_y, sigma_z], axis=-1)

    return {
        "bands": bands,
        "kpoints": kpoints,
        "k_distances": k_distances,
        "occupazioni": occupazioni,
        "charge": charge,
        "sigma_x": sigma_x,
        "sigma_y": sigma_y,
        "sigma_z": sigma_z,
        "meta": meta
    }


### 1.1.3 Unfolded Band Structure

In [None]:
def unfolded_BANDS(filepath):
    """
    Read unfolded band data from a bands4vasp output file.

    Parameters
    ----------
    filepath : str
        Path to the bands4vasp unfolded band file.

    Returns
    -------
    np.ndarray
        2D array of floats with one row per data line of the file and one
        column per whitespace-separated field. The meaning of the columns
        depends on the bands4vasp file being read. An empty 1D array is
        returned if the file holds no data lines.

    Notes
    -----
    - Expects a plain text file with numerical data.
    - Blank lines and comment lines starting with '#' are skipped, so a
      trailing newline or a header comment no longer produces a ragged
      (and therefore failing) array.
    - All data lines must have the same number of fields.
    """
    rows = []
    with open(filepath, 'r') as f:
        for line in f:
            stripped = line.strip()
            if not stripped or stripped.startswith('#'):
                continue  # nothing numeric on this line
            rows.append([float(x) for x in stripped.split()])
    return np.array(rows, dtype=float)

def read_data(filepath):
    """
    Read numeric band data from a bands4vasp output file.

    Parameters
    ----------
    filepath : str
        Path to the bands4vasp band file.

    Returns
    -------
    np.ndarray
        Array of floats with one row per data line and one column per
        whitespace-separated field.

    Notes
    -----
    - Empty lines and lines whose first non-blank character is '#' are skipped.
    - The meaning of rows and columns depends on the file being read.
    """
    with open(filepath, 'r') as handle:
        cleaned = (raw.strip() for raw in handle)
        table = [
            [float(token) for token in text.split()]
            for text in cleaned
            if text and not text.startswith('#')
        ]
    return np.array(table, dtype=float)

### 1.1.4 Wannier Bands & Shift Currents

In [None]:
def wannier_bands(bandpath, kpath, numkpoints):
    """
    Load a Wannier90-style band structure and the corresponding k-path distances.

    Parameters
    ----------
    bandpath : str
        Path to the band structure file: blocks of numeric lines, one block
        per band, separated by blank lines. The second column of each line
        is taken as the band energy.
    kpath : str
        Path to the k-point file. Its first line is skipped; the first three
        fields of each of the following `numkpoints` lines are read as the
        k-point coordinates.
    numkpoints : int
        Number of k-points to read from `kpath`.

    Returns
    -------
    k_distances : list of float, length numkpoints
        Cumulative Euclidean distance between consecutive k-points,
        starting at 0.
    bande : np.ndarray, shape (N_bands, N_points_per_band)
        Band energies, one row per band.

    Notes
    -----
    - Any number of consecutive blank lines is treated as a single band
      separator, so double blank lines do not create empty bands.
    - All bands must contain the same number of lines.
    - No check is made that `numkpoints` matches the number of points per band.
    """
    with open(bandpath, "r") as f:
        band_lines = f.read().strip().split('\n')

    bands_raw = []
    current_band = []

    for line in band_lines:
        if line.strip():
            current_band.append([float(val) for val in line.split()])
        elif current_band:
            # close the running band only if it actually holds data
            bands_raw.append(current_band)
            current_band = []

    if current_band:
        bands_raw.append(current_band)

    bands_array = np.array(bands_raw)
    bande = bands_array[:, :, 1]

    with open(kpath, 'r') as f:
        k_lines = f.readlines()

    # k-points: the first line of the file is skipped
    kpoints = np.array([
        [float(num) for num in k_lines[idx].split()[:3]]
        for idx in range(1, numkpoints + 1)
    ])

    k_distances = [0]
    for previous, current in zip(kpoints[:-1], kpoints[1:]):
        k_distances.append(k_distances[-1] + np.linalg.norm(current - previous))

    return k_distances, bande

In [None]:
def shift_currents(filepath, component):
    """
    Load shift current data written by a Wannier90 calculation.

    Parameters
    ----------
    filepath : str
        Path to the folder containing the shift current files.
    component : str
        Component label inserted into the file name
        (`wannier90-sc_{component}.dat`).
        NOTE(review): the label must match the suffix actually produced by
        the calculation — confirm the available components in the folder.

    Returns
    -------
    np.ndarray
        2D array of floats with one row per data line of the file and one
        column per whitespace-separated field.

    Notes
    -----
    - The path is built with `os.path.join`, so `filepath` may be given with
      or without a trailing separator.
    - Blank lines and lines starting with '#' are skipped.
    - All data lines must have the same number of fields.
    """
    import os

    sc_file = os.path.join(filepath, "wannier90-sc_" + component + ".dat")

    with open(sc_file, "r") as f:
        rows = [
            [float(x) for x in line.split()]
            for line in f
            if line.strip() and not line.lstrip().startswith('#')
        ]

    return np.array(rows)

## 1.2 Plotting

### 1.2.1 Electronic properties

#### 1.2.1.1 Helper DOS

In [None]:
def select_pdos_channels(DOS_data, species=None, orbitals="all", factor_map=None):
    """
    Select individual PDOS channels for plotting or analysis.

    Each channel corresponds to a single ion and a single PDOS entry.

    Parameters
    ----------
    DOS_data : dict
        Dictionary with a `"PDOS"` entry, as returned by `extract_DOS`.
    species : int, iterable of int, or None, optional
        Ion index (or indices) to include. If None, all ions are included.
    orbitals : str or list of str, optional
        Entries to select for each ion. `"all"` selects every key of the
        ion's PDOS dict except `"energy"`. Any other string is treated as a
        single orbital name.
    factor_map : dict, optional
        Mapping from ion index to a scaling factor. Ions missing from the
        mapping get a factor of 1.

    Returns
    -------
    list of dict
        Each dict contains:
        - `energy` : np.ndarray
            Energy grid of the ion.
        - `dos` : np.ndarray or dict
            Unscaled DOS values for the given ion and orbital.
        - `factor` : float
            Scaling factor associated with the ion (stored, not applied).
        - `species` : int
            Ion index.
        - `orbital` : str
            Orbital name.
        - `label` : str
            String label for the channel, e.g., "Ion 0 s".

    Raises
    ------
    KeyError
        If a requested orbital is not present for a selected ion.

    Notes
    -----
    - The `dos` values are references to the arrays stored in `DOS_data`;
      they are not copied and not multiplied by `factor`.
    """

    PDOS = DOS_data["PDOS"]

    if species is None:
        species = range(len(PDOS))
    elif isinstance(species, (int, np.integer)):
        species = [species]

    # A bare orbital name would otherwise be iterated character by character.
    select_all = isinstance(orbitals, str) and orbitals == "all"
    if isinstance(orbitals, str) and not select_all:
        orbitals = [orbitals]

    channels = []

    for i in species:
        ion_pdos = PDOS[i]
        energies = ion_pdos["energy"]
        factor = 1 if factor_map is None else factor_map.get(i, 1)

        if select_all:
            orb_list = [k for k in ion_pdos if k != "energy"]
        else:
            orb_list = orbitals

        for orb in orb_list:
            channels.append({
                "energy": energies,
                "dos": ion_pdos[orb],
                "factor": factor,
                "species": i,
                "orbital": orb,
                "label": f"Ion {i} {orb}"
            })

    return channels


In [None]:
def group_by_species(channels):
    """
    Collapse PDOS channels into one summed channel per ion index.

    All channels sharing the same `species` value are added together, so each
    returned channel holds the sum over the orbitals that were passed in for
    that ion.

    Parameters
    ----------
    channels : list of dict
        PDOS channels, as returned by `select_pdos_channels`. Each dict must
        provide the keys 'energy', 'dos', 'factor' and 'species'.

    Returns
    -------
    list of dict
        One channel per distinct `species` value, in order of first
        appearance. Each dict contains:
        - `energy` : np.ndarray
            Energy grid of the first channel of that ion.
        - `dos` : np.ndarray or dict
            Summed DOS. For spin-resolved input, a dict with 'up' and 'down'.
        - `factor` : float
            Scaling factor of the first channel of that ion.
        - `label` : str
            "Ion <index>".
        - `orbital` : str
            Always "total".
        - `species` : int
            Ion index.

    Notes
    -----
    - The first contribution of each ion is copied before accumulation, so
      the DOS arrays of the input channels are left untouched.
    - Both plain-array and 'up'/'down' dict DOS values are supported.
    """
    per_ion = {}

    for channel in channels:
        ion = channel["species"]
        contribution = channel["dos"]
        spin_resolved = isinstance(contribution, dict)

        if ion not in per_ion:
            per_ion[ion] = {
                "energy": channel["energy"],
                "dos": None,
                "factor": channel["factor"],
                "label": f"Ion {ion}",
                "orbital": "total",
                "species": ion,
            }

        entry = per_ion[ion]

        # First contribution: start the accumulator from a private copy.
        if entry["dos"] is None:
            if spin_resolved:
                entry["dos"] = {
                    "up": contribution["up"].copy(),
                    "down": contribution["down"].copy(),
                }
            else:
                entry["dos"] = contribution.copy()
            continue

        # Later contributions: accumulate in place.
        if spin_resolved:
            entry["dos"]["up"] += contribution["up"]
            entry["dos"]["down"] += contribution["down"]
        else:
            entry["dos"] += contribution

    return list(per_ion.values())


In [None]:
def group_by_orbital(channels):
    """
    Aggregate PDOS channels by orbital name.

    All channels sharing the same `orbital` value are summed, producing a
    single channel per orbital.

    Parameters
    ----------
    channels : list of dict
        List of PDOS channels, as returned by `select_pdos_channels` or `group_by_species`.
        Each channel dict must contain keys 'energy', 'dos', 'factor', and 'orbital'.

    Returns
    -------
    list of dict
        Aggregated PDOS channels, one per orbital, in order of first
        appearance. Each dict contains:
        - `energy` : np.ndarray
            Energy grid of the first channel of that orbital.
        - `dos` : np.ndarray or dict
            Summed DOS for the orbital. For spin-resolved input, a dict with 'up' and 'down'.
        - `factor` : float
            Scaling factor (from the first channel of that orbital).
        - `label` : str
            Label for the orbital, e.g., "s", "p_x", "d_xy".
        - `orbital` : str
            Orbital name.

    Notes
    -----
    - Handles both plain-array and 'up'/'down' dict DOS data.
    - The accumulator starts from a copy of the first contribution. For
      spin-resolved data the 'up' and 'down' arrays are copied individually:
      a shallow copy of the dict would share the arrays with the input
      channel, and the in-place sums would then corrupt the input.
    """

    grouped = {}

    for ch in channels:
        key = ch["orbital"]
        dos = ch["dos"]
        spin_resolved = isinstance(dos, dict)

        if key not in grouped:
            grouped[key] = {
                "energy": ch["energy"],
                "dos": None,
                "factor": ch["factor"],
                "label": key,
                "orbital": key
            }

        entry = grouped[key]

        if entry["dos"] is None:
            if spin_resolved:
                # copy every spin array, not just the dict holding them
                entry["dos"] = {spin: values.copy() for spin, values in dos.items()}
            else:
                entry["dos"] = dos.copy()
        elif spin_resolved:
            for s in ["up", "down"]:
                entry["dos"][s] += dos[s]
        else:
            entry["dos"] += dos

    return list(grouped.values())


In [None]:
def plot_dos_channel(ax, channel, color=None, use_fill=True, group_type=None, species_palette=None):
    """
    Plot a single DOS (Density of States) channel, handling spin-resolved and plain data.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Matplotlib Axes object where the DOS will be plotted.
    channel : dict
        A PDOS channel dict, containing keys:
        - `energy` : np.ndarray
        - `dos` : np.ndarray or dict (with 'up'/'down' for spin-resolved data)
        - `factor` : float
        - `orbital` : str
        - `label` : str
        - `species` : int, optional
    color : None or str, optional
        If specified, uses this color for the plot. If None, the color is
        looked up in `species_palette` (see below).
    use_fill : bool, optional
        If True and the DOS is spin-resolved, uses `fill_between` for
        spin-up/down; otherwise two lines are drawn. Default is True.
    group_type : None or str, optional
        Selects the palette key: "species" uses the channel's species,
        "orbital" uses the channel's orbital, anything else uses the tuple
        `(species, orbital)`.
    species_palette : dict, optional
        Palette mapping the key described above to a color. Keys missing from
        the palette, or a missing palette, result in "black".

    Notes
    -----
    - For plain-array DOS, the orbitals `p_x` and `d_x2-y2` are drawn with a
      thicker, dotted line; all other orbitals use a solid line.
    - Spin-down data is plotted with a negative sign.
    - `factor` multiplies the DOS values before plotting.
    - A `dos` value that is neither an array nor a dict is silently ignored.
    - The function does **not** return any value; it plots directly on the provided Axes object.
    """

    dashed_orbitals = ["p_x", "d_x2-y2"]

    energy = channel["energy"]
    dos = channel["dos"]
    factor = channel["factor"]
    label = channel["label"]
    orb = channel["orbital"]
    species = channel.get("species", None)

    # --- automatic color ---
    if color is None:
        if species_palette is None:
            # no palette to look into: same fallback as an unknown key
            color = "black"
        elif group_type == "species":
            color = species_palette.get(species, "black")
        elif group_type == "orbital":
            color = species_palette.get(orb, "black")
        else:
            # default: palette keyed by (species, orbital)
            color = species_palette.get((species, orb), "black")

    # -------- plain (non spin-resolved) DOS --------
    if isinstance(dos, np.ndarray):
        if orb in dashed_orbitals:
            line, = ax.plot(
                energy, factor * dos,
                linewidth=5, label=label, color=color
            )
            line.set_dashes([1, 2])
        else:
            ax.plot(
                energy, factor * dos,
                linewidth=2, label=label, color=color
            )

    # -------- spin-resolved DOS --------
    elif isinstance(dos, dict):
        if use_fill:
            ax.fill_between(
                energy, 0, factor * dos["up"],
                color=color, alpha=0.5, label=label + r" $\uparrow$"
            )
            ax.fill_between(
                energy, 0, -factor * dos["down"],
                color=color, alpha=0.5, label=label + r" $\downarrow$"
            )
        else:
            ax.plot(
                energy, factor * dos["up"],
                linewidth=2, label=label + r" $\uparrow$", color=color
            )
            ax.plot(
                energy, -factor * dos["down"],
                linewidth=2, linestyle="--",
                label=label + r" $\downarrow$", color=color
            )


#### 1.2.1.2 Band helpers

In [None]:
def select_band_channels(bands_data, projection):
    """
    Build projection channels for the bands, one per (ion, orbital, band).

    Parameters
    ----------
    bands_data : dict
        Dictionary returned by `extract_BANDS_full`; the following keys are read:
        - `occupazioni` : np.ndarray, indexed as [ion, band, k-point, orbital]
        - `bands` : np.ndarray, indexed as [band, k-point]
        - `k_distances` : np.ndarray (Nk,)
    projection : dict
        Projection specification, can contain:
        - `species` : int, iterable of int, or None, optional
            Ion indices to include (default: all ions)
        - `orbitals` : int, iterable of int, or "all", optional
            Indices along the last axis of `occupazioni` (default: all)

    Returns
    -------
    list of dict
        Each dict represents a single band projection channel and contains:
        - `k` : np.ndarray
            k-path distances (Nk,)
        - `energy` : np.ndarray
            Energies of the band along the path (Nk,)
        - `weight` : np.ndarray
            Projection weights (Nk,) for the selected ion and orbital
        - `band` : int
            Band index
        - `species` : int
            Ion index
        - `orbital` : int
            Orbital index
        - `spin` : str
            Always "total"
        - `label` : str
            String label, e.g., "Ion 0 orb 2 total"

    Notes
    -----
    - `occupazioni` is a plain array, so there is no separate spin-up/down
      data to select from; a `spin` entry in `projection` has no effect.
    - Channels are ordered by ion, then orbital, then band.
    - `weight` and `energy` are views on the input arrays, not copies.
    """

    occ = bands_data["occupazioni"]   # (Nion, Nb, Nk, Norb)
    bands = bands_data["bands"]
    k_dist = bands_data["k_distances"]

    Nion, Nb, _, Norb = occ.shape[:4]

    # ---------- input normalization ----------
    species = projection.get("species")
    if species is None:
        species = range(Nion)
    elif isinstance(species, (int, np.integer)):
        species = [species]

    orbitals = projection.get("orbitals", range(Norb))
    # isinstance guard: comparing an array with "all" is not a plain boolean
    if isinstance(orbitals, str) and orbitals == "all":
        orbitals = range(Norb)
    elif isinstance(orbitals, (int, np.integer)):
        orbitals = [orbitals]

    channels = []

    for i in species:
        for o in orbitals:
            for b in range(Nb):
                channels.append({
                    "k": k_dist,
                    "energy": bands[b],
                    "weight": occ[i, b, :, o],
                    "band": b,
                    "species": i,
                    "orbital": o,
                    "spin": "total",
                    "label": f"Ion {i} orb {o} total"
                })

    return channels


In [None]:
def normalize_weights(weights, min_size=5, max_size=120):
    """
    Rescale an array of weights linearly to a size range, e.g., for plot markers.

    Parameters
    ----------
    weights : array-like
        Input array of weights (e.g., band projection weights).
    min_size : float, optional
        Size assigned to the smallest weight (default: 5).
    max_size : float, optional
        Size assigned to the largest weight (default: 120).

    Returns
    -------
    np.ndarray
        Float array with the same shape as `weights`, scaled so that the
        smallest weight maps to `min_size` and the largest to `max_size`.

    Notes
    -----
    - The result is filled with `min_size` when the largest weight is
      (nearly) zero, when all weights are identical (a min-max rescaling
      would otherwise divide by zero and produce NaN), or when the input
      is empty.
    - The mapping is a min-max rescaling: differences between weights are
      preserved up to a common factor, ratios are not.
    """
    w = np.asarray(weights, dtype=float)

    # Nothing to rescale: avoids max()/min() on an empty array.
    if w.size == 0:
        return np.full(w.shape, float(min_size))

    lowest = w.min()
    highest = w.max()
    span = highest - lowest

    if np.isclose(highest, 0) or span == 0:
        return np.full(w.shape, float(min_size))

    w_norm = (w - lowest) / span
    return min_size + w_norm * (max_size - min_size)


In [None]:
def get_band_color(channel, group_type=None, palette=None):
    """
    Pick the plotting color of a band/PDOS channel according to its grouping.

    Parameters
    ----------
    channel : dict
        Channel dictionary. Must contain the key named by `group_type`
        (`"species"` or `"orbital"`) when a palette lookup is performed.
    group_type : None, "species", or "orbital", optional
        Which channel field selects the color:
        - None (or any unrecognized value) : always returns "black"
        - "species" : looks up `channel["species"]` in `palette`
        - "orbital" : looks up `channel["orbital"]` in `palette`
    palette : dict, optional
        Mapping from species index / orbital identifier to a color.

    Returns
    -------
    object
        The palette entry for the channel, or the string "black" when no
        palette is given, the grouping is not recognized, or the key is
        missing from the palette.

    Raises
    ------
    KeyError
        If a palette lookup is requested and `channel` lacks the
        `group_type` key.
    """
    # Without a palette there is nothing to look up: fall back to the default
    # color instead of failing on `None.get(...)`.
    if palette is None or group_type not in ("species", "orbital"):
        return "black"

    return palette.get(channel[group_type], "black")


In [None]:
def plot_band_channel(ax, channel, cmap, norm, size_scale):
    """
    Draw one band projection channel as a weighted scatter plot.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes that receives the scatter plot.
    channel : dict
        Channel dictionary providing the keys:
        - `k` : x coordinates (k-path distances)
        - `energy` : y coordinates (band energies)
        - `weight` : per-point weights
    cmap : matplotlib.colors.Colormap
        Colormap applied to the raw weights.
    norm : matplotlib.colors.Normalize
        Normalization applied to the raw weights before color mapping.
    size_scale : sequence of two floats
        `(min, max)` marker sizes, forwarded to `normalize_weights`.

    Returns
    -------
    matplotlib.collections.PathCollection
        The object returned by `ax.scatter`, usable e.g. for a colorbar.

    Notes
    -----
    - Marker sizes are the weights rescaled by `normalize_weights` into
      `[size_scale[0], size_scale[1]]`; colors use the unscaled weights.
    - Markers are drawn with alpha 0.7 and no edge lines.
    """
    weights = channel["weight"]
    smallest = size_scale[0]
    largest = size_scale[1]
    marker_sizes = normalize_weights(weights, min_size=smallest, max_size=largest)

    scatter_kwargs = {
        "c": weights,
        "s": marker_sizes,
        "cmap": cmap,
        "norm": norm,
        "alpha": 0.7,
        "linewidths": 0,
    }
    return ax.scatter(channel["k"], channel["energy"], **scatter_kwargs)


#### 1.2.1.3 Plot 

In [None]:
def plot_DOS_bands(DOS_data, DOS_color, DOS_label,
                   bands_data, bands_color, bands_label,
                   k_path,
                   ylim,
                   xlim,
                   legend_position,
                   plot_type,
                   projection=None,
                   ax=None,
                   style=None):
    """
    Plot the DOS, the band structure, or both, with optional projections.

    Supported layouts (selected by `plot_type`):
    - "DOS"   : total DOS, optionally with projected DOS (PDOS) channels
    - "bands" : band structure, plain lines or projected "fat bands"
    - "full"  : bands (left) and total DOS (right) sharing the energy axis

    Parameters
    ----------
    DOS_data : dict
        Output from `extract_DOS()`. `DOS_total` (columns: energy, DOS, ...)
        is read directly; the whole dict is forwarded to
        `select_pdos_channels` for projected DOS plots.
    DOS_color : str
        Color of the total DOS curve.
    DOS_label : str
        Legend label of the total DOS curve (used for the unprojected DOS
        plot only; the projected plot labels it "Full DOS").
    bands_data : dict
        Output from `extract_BANDS_full()`. `bands` (indexed `[band, k]`) and
        `k_distances` are read directly; the whole dict is forwarded to
        `select_band_channels` for fat bands.
    bands_color : str
        Color for unprojected bands.
    bands_label : str
        Accepted for interface compatibility; not used by the function.
    k_path : sequence
        `(positions, labels)` of the high-symmetry k-points used as x ticks,
        e.g. `([0, 0.47, 0.84], ['Γ', 'K', 'M'])`.
    ylim : tuple
        y-axis limits.
    xlim : tuple or None
        x-axis limits. For band plots a falsy value selects the full k range.
    legend_position : tuple
        `bbox_to_anchor` of the legend (used by the projected DOS plot only).
    plot_type : str
        "DOS", "bands", or "full". Any other value plots nothing.
    projection : dict, optional
        Projection specification. Keys read here or forwarded to the helpers:
        - `species` : ion indices
        - `orbitals` : orbital indices or "all"
        - `spin` : "up", "down", "both"; "difference" switches the fat-band
          coloring to a diverging colormap centered at zero
        - `factor_map` : per-ion scaling factors (PDOS)
        - `group` : "species" or "orbital" (PDOS grouping/coloring)
        - `size_scale` : `(min, max)` marker sizes for fat bands
          (default `(5, 80)`)
        `projection` is ignored when `plot_type="full"`.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, a new 8x5 inch figure is created. With
        `plot_type="full"` the figure owning `ax` is cleared and rebuilt.
    style : str or None, optional
        "paper" updates the global matplotlib rcParams (serif font, sizes,
        dpi) and disables the grid on `ax`.

    Raises
    ------
    ValueError
        If a fat-band plot is requested and `select_band_channels` returns
        no channels.

    Notes
    -----
    **Orbital indices (0-based) for projections:**
    ```
    0 : 's'
    1 : 'p_y'
    2 : 'p_z'
    3 : 'p_x'
    4 : 'd_xy'
    5 : 'd_yz'
    6 : 'd_z2-r2'
    7 : 'd_xz'
    8 : 'd_x2-y2'
    ```

    **Projected DOS example:**
    ```python
    projection = {
        "species": [0, 2],
        "orbitals": [4, 6],
        "factor_map": {0: 1.0, 2: 0.5},
        "group": "species"
    }
    plot_DOS_bands(DOS_data, "black", "Total DOS",
                   bands_data, "gray", "Bands",
                   k_path, (0, 10), (-5, 5), (1.05, 0.5),
                   plot_type="DOS",
                   projection=projection)
    ```

    **Fat bands example:**
    ```python
    projection = {
        "species": [0, 1, 2],
        "orbitals": "all",
        "spin": "up",
        "size_scale": (5, 80)
    }
    plot_DOS_bands(DOS_data, "black", "Total DOS",
                   bands_data, "gray", "Bands",
                   k_path, (-5, 5), None, (1.05, 0.5),
                   plot_type="bands",
                   projection=projection)
    ```

    - The Fermi level (0 eV) is marked with a dashed red line: vertical in
      DOS plots, horizontal in band plots.
    - In fat-band plots marker size and opacity grow with |weight| relative
      to the largest |weight| over all selected channels.
    - Spin-up/total channels use the viridis colormap and 'o' markers,
      spin-down channels use plasma and 'v' markers; a spin legend is added
      only when both up and down channels are present. The single colorbar
      shows the colormap of the last plotted channel.
    """

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 5))

    # ---------------- STYLE ----------------
    if style == "paper":
        # NOTE: rcParams are global; the change persists after this call.
        mpl.rcParams.update({
            "font.family": "serif",
            "font.size": 14,
            "axes.linewidth": 1.2,
            "xtick.major.size": 6,
            "ytick.major.size": 6,
            "savefig.dpi": 300
        })
        ax.grid(False)

    # ---------------- UNPROJECTED DOS ----------------
    if plot_type == "DOS" and projection is None:
        ax.plot(DOS_data["DOS_total"][:, 0], DOS_data["DOS_total"][:, 1],
                label=DOS_label, color=DOS_color)
        ax.set_xlabel("Energia [eV]")
        ax.set_ylabel("DOS [states/eV]")
        ax.set_ylim(ylim)
        ax.set_xlim(xlim)
        ax.axvline(0, color='red', linestyle='--', label='E$_F$')

    # ---------------- PROJECTED DOS ----------------
    elif plot_type == "DOS" and projection is not None:

        channels = select_pdos_channels(
            DOS_data,
            species=projection.get("species", None),
            orbitals=projection.get("orbitals", "all"),
            factor_map=projection.get("factor_map", None)
        )

        group_type = projection.get("group", None)

        if group_type == "species":
            # Merge the selected orbitals of each species into one channel.
            channels = group_by_species(channels)
            num_species = len(set(ch["species"] for ch in channels))
            colors_list = cm.tab10(np.linspace(0, 1, num_species))
            species_palette = {ch["species"]: colors_list[i] for i, ch in enumerate(channels)}

        elif group_type == "orbital":
            channels = group_by_orbital(channels)
            # One color per grouped orbital channel.
            orbitals_list = [ch["orbital"] for ch in channels]
            colors_list = cm.tab10(np.linspace(0, 1, len(orbitals_list)))
            species_palette = {ch["orbital"]: colors_list[i] for i, ch in enumerate(channels)}

        else:
            # Default: one color per (species, orbital) pair, cycling tab20.
            species_palette = {}
            for i, ch in enumerate(channels):
                species_palette[(ch["species"], ch["orbital"])] = cm.tab20(i % 20)

        for ch in channels:
            plot_dos_channel(ax, ch, color=None, use_fill=True,
                             group_type=group_type, species_palette=species_palette)

        ax.plot(DOS_data["DOS_total"][:, 0], DOS_data["DOS_total"][:, 1],
                label="Full DOS", color=DOS_color)
        ax.axvline(0, color="red", linestyle="--", label=r"$E_F$")
        ax.set_xlabel(r"$\epsilon - \epsilon_F$ [eV]")
        ax.set_ylabel(r"DOS [$\frac{states}{eV}$]")
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.legend(loc='center right', bbox_to_anchor=legend_position)

    # ---------------- UNPROJECTED BANDS ----------------
    if plot_type == "bands" and projection is None:
        for b in range(bands_data["bands"].shape[0]):
            ax.plot(bands_data["k_distances"], bands_data["bands"][b, :],
                    color=bands_color)
        ax.set_ylabel(r"$\epsilon - \epsilon_{F}$ [eV]")
        ax.set_ylim(ylim)
        ax.set_xlim(xlim if xlim else (bands_data["k_distances"][0], bands_data["k_distances"][-1]))
        ax.axhline(0, color='red', linestyle='--', label='E$_F$')
        ax.set_xticks(k_path[0])
        ax.set_xticklabels(k_path[1])
        ax.grid(axis='x')

    # ---------------- PROJECTED BANDS (FAT BANDS) ----------------
    elif plot_type == "bands" and projection is not None:

        channels = select_band_channels(bands_data, projection)
        if not channels:
            raise ValueError(
                "No band channels selected: check the 'species', 'orbitals' "
                "and 'spin' entries of the projection"
            )

        difference_mode = projection.get("spin") == "difference"

        all_weights = np.concatenate([ch["weight"] for ch in channels])
        # Largest magnitude: identical to the plain maximum for non-negative
        # weights, and the correct symmetric bound for signed (difference) ones.
        w_max = np.abs(all_weights).max()

        if w_max < 1e-6:
            warnings.warn("Proiezione quasi nulla: specie/orbitali probabilmente irrilevanti")
        if np.any(all_weights < -1e-8) and not difference_mode:
            warnings.warn("Pesi negativi trovati (non attesi)")

        # Avoid 0/0 in the alpha scaling and a degenerate color norm when
        # every weight is exactly zero.
        scale = w_max if w_max > 0 else 1.0

        # The norm is shared by all channels; only the colormap depends on spin.
        if difference_mode:
            norm = colors.TwoSlopeNorm(vmin=-scale, vcenter=0.0, vmax=scale)
        else:
            norm = colors.Normalize(vmin=0, vmax=scale)

        spin_marker = {"up": "o", "down": "v", "total": "o", "difference": "o"}
        size_min, size_max = projection.get("size_scale", (5, 80))

        cmap = None
        for ch in channels:
            k = ch["k"]
            E = ch["energy"]
            w = ch["weight"]
            spin = ch["spin"]

            # Marker size and opacity both grow with |weight| / max|weight|.
            rel = np.abs(w) / scale
            size = np.interp(np.abs(w), [0, scale], [size_min, size_max])
            alpha = np.clip(0.15 + 0.85 * rel, 0.0, 1.0)

            if difference_mode:
                cmap = cm.seismic
            else:
                cmap = cm.viridis if spin in ("up", "total") else cm.plasma

            ax.scatter(k, E, s=size, c=w, cmap=cmap, norm=norm,
                       marker=spin_marker[spin], alpha=alpha, edgecolors="none")

        # Colorbar (uses the colormap of the last plotted channel).
        sm = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax, pad=0.02)
        cbar.set_label("Projection weight")

        # Axes cosmetics.
        ax.axhline(0, color="red", linestyle="--", linewidth=1)
        ax.set_ylabel(r"$\epsilon - \epsilon_F$ [eV]")
        ax.set_ylim(ylim)
        ax.set_xlim(xlim if xlim else (bands_data["k_distances"][0], bands_data["k_distances"][-1]))
        ax.set_xticks(k_path[0])
        ax.set_xticklabels(k_path[1])
        ax.grid(axis="x", linestyle=":", alpha=0.6)

        # Spin legend only when both spin channels were actually plotted.
        # Channels are tagged "up"/"down"/"total", never "both".
        spins_present = {ch["spin"] for ch in channels}
        if {"up", "down"} <= spins_present:
            legend_elements = [
                Line2D([0], [0], marker='o', color='k', lw=0, label='spin ↑'),
                Line2D([0], [0], marker='v', color='k', lw=0, label='spin ↓'),
            ]
            ax.legend(handles=legend_elements, loc='upper right', frameon=False)

    # ---------------- FULL PLOT (BANDS + DOS) ----------------
    if plot_type == "full":
        # Rebuild the figure that owns `ax` with a 3:1 bands/DOS layout.
        fig = ax.figure
        fig.clf()
        gs = gridspec.GridSpec(1, 2, width_ratios=[3, 1], wspace=0.05)

        # Bands panel.
        ax1 = fig.add_subplot(gs[0])
        for b in range(bands_data["bands"].shape[0]):
            ax1.plot(bands_data["k_distances"], bands_data["bands"][b],
                     color=bands_color, linewidth=1)
        ax1.axhline(0, color='red', linestyle='--', linewidth=1)
        ax1.set_ylabel(r"$\epsilon - \epsilon_F$ [eV]")
        ax1.set_ylim(ylim)
        ax1.set_xlim(xlim if xlim else (bands_data["k_distances"][0], bands_data["k_distances"][-1]))
        ax1.set_xticks(k_path[0])
        ax1.set_xticklabels(k_path[1])
        ax1.grid(axis='x', linestyle=':', alpha=0.6)

        # DOS panel: DOS on x, energy on the y axis shared with the bands.
        ax2 = fig.add_subplot(gs[1], sharey=ax1)
        ax2.plot(DOS_data["DOS_total"][:, 1], DOS_data["DOS_total"][:, 0],
                 color=DOS_color, linewidth=1.5)
        ax2.axhline(0, color='red', linestyle='--', linewidth=1)
        ax2.set_xlabel("DOS [states/eV]")
        ax2.set_xlim(0, np.max(DOS_data["DOS_total"][:, 1])*1.05)
        ax2.tick_params(axis='y', which='both', left=False, right=False, labelleft=False)

        fig.subplots_adjust(left=0.12, right=1.25, bottom=0.12, top=0.95, wspace=0.05)

### 1.2.2 Unfolded bands

In [None]:
def plot_unfolded_bands(bandsunf,
                        bands, label_DFT,
                        shift,  # energy shift for DFT bands, choose by eye
                        k_path,
                        plot_DFT = True,
                        fontsize=24,
                        figuresize=(12,6),
                        pointsize=50,
                        ylim=(-2.5,1.5),
                        legend_position='upper right'):
    """
    Plot unfolded bands as a weight-colored scatter, optionally with DFT bands on top.

    A new figure is created. The unfolded points are colored by their third
    column (spectral weight); the DFT points are drawn in black above them.

    Parameters
    ----------
    bandsunf : np.ndarray, shape (N_points, >=3)
        Unfolded band data. Columns used: 0 = k distance, 1 = energy,
        2 = spectral weight.
    bands : np.ndarray, shape (N_points, >=2)
        DFT band points. Columns used: 0 = k distance, 1 = energy.
        Its first and last k values define the x-axis limits, so it is
        required even when `plot_DFT` is False.
    label_DFT : str
        Legend label for the DFT points.
    shift : float
        Energy offset subtracted from the DFT energies to align them with
        the unfolded bands (chosen by eye).
    k_path : sequence
        `(positions, labels)` of the high-symmetry k-points used as x ticks.
    plot_DFT : bool, optional
        Whether to overlay the DFT points (default: True).
    fontsize : int, optional
        Font size; applied globally through `plt.rcParams` (default: 24).
    figuresize : tuple, optional
        Figure size (width, height) in inches (default: (12, 6)).
    pointsize : float, optional
        Marker size for the unfolded points; DFT points use one tenth of it
        (default: 50).
    ylim : tuple, optional
        y-axis limits (default: (-2.5, 1.5)).
    legend_position : str, optional
        Legend location (default: 'upper right').

    Returns
    -------
    None

    Notes
    -----
    - The colormap runs white -> beige -> green-yellow -> green -> steel blue
      -> midnight blue with increasing weight, on a gray axes background.
    - The color scale is autoscaled to the data; the colorbar only carries
      tick labels at 0 and 1.
    - The legend is drawn only when the DFT overlay is plotted, since it is
      the only labelled artist.
    """
    # (position, color) stops of the spectral-weight gradient.
    color_stops = [
        (0.00, '#FFFFFF'),       # white
        (0.05, '#F5F5DC'),       # beige
        (0.25, '#ADFF2F'),       # greenyellow
        (0.40, '#008000'),       # web-green
        (0.65, '#4682B4'),       # steelblue
        (0.90, '#191970'),       # midnight-blue
        (1.00, '#191970')        # repeat last color at 1.0 to close the scale
    ]

    custom_cmap = mcolors.LinearSegmentedColormap.from_list("custom_palette", color_stops)
    plt.rcParams.update({'font.size': fontsize})

    fig, ax = plt.subplots(figsize=figuresize)
    ax.set_facecolor('gray')  # background inside the plot area

    # Drawn once; the handle is reused for the colorbar below.
    sc = ax.scatter(bandsunf[:, 0], bandsunf[:, 1], c=bandsunf[:, 2],
                    s=pointsize, marker='o', cmap=custom_cmap)
    if plot_DFT:
        ax.scatter(bands[:, 0], bands[:, 1] - shift, s=pointsize/10, marker='o',
                   color='black', label=label_DFT, zorder=10)

    ax.set_ylim(ylim)
    ax.set_xlim(bands[0, 0], bands[-1, 0])
    ax.set_ylabel(r"$\epsilon$ - $\epsilon_{F}$ [eV]")
    ax.set_xticks(k_path[0], k_path[1])
    ax.grid(axis='x')
    if plot_DFT:
        # Without the DFT overlay there is no labelled artist to list.
        ax.legend(loc=legend_position)

    cbar = plt.colorbar(sc, ax=ax)
    cbar.set_label('Spectral weight')
    cbar.set_ticks([0, 1])  # extreme values only
    cbar.set_ticklabels(['0', '1'])

### 1.2.3 Wannier Bands & Shift currents

In [None]:
def plot_wannier_bands(bande_DFT, color_DFT,
                       k_distances, bande, flag, color_wannier,
                       k_path,
                       plot_DFT = True,
                       pointsize=50,
                       ylim=(-2.5,1.5),
                       legend_position='upper right',
                       ax=None):
    """
    Draw Wannier-interpolated bands on an Axes, optionally with DFT bands.

    Parameters
    ----------
    bande_DFT : dict
        DFT band data. Keys read here:
        - 'meta' -> 'Efermi' : energy subtracted from every Wannier band
        - 'bands' : array indexed `[band]`, plotted as-is when `plot_DFT`
        - 'k_distances' : x coordinates of the DFT bands
    color_DFT : str
        Color of the DFT markers.
    k_distances : array-like
        x coordinates of the Wannier bands; first and last entries set the
        x-axis limits.
    bande : iterable of array-like
        Wannier band energies, one entry per band.
    flag : str
        Legend label of the Wannier bands (attached to the first band only).
    color_wannier : str
        Color of the Wannier markers.
    k_path : sequence
        `(positions, labels)` of the high-symmetry k-points used as x ticks.
    plot_DFT : bool, optional
        Whether to overlay the DFT bands (default: True).
    pointsize : float, optional
        Base marker size: Wannier scatter points use `pointsize*10`, DFT
        markers use `pointsize/2` (default: 50).
    ylim : tuple, optional
        y-axis limits (default: (-2.5, 1.5)).
    legend_position : str, optional
        Legend location (default: 'upper right').
    ax : matplotlib.axes.Axes
        Target Axes; required despite the `None` default.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `ax` is None.
    """
    if ax is None:
        raise ValueError("Devi passare un oggetto Axes (ax=...)")

    # Wannier bands: referenced to the DFT Fermi energy; only the first band
    # carries the legend label so the legend shows a single entry.
    wannier_label = flag
    for energies in bande:
        shifted = energies - bande_DFT['meta']['Efermi']
        ax.scatter(k_distances, shifted,
                   s=pointsize*10, label=wannier_label, color=color_wannier)
        wannier_label = None

    # DFT bands: markers only, no connecting line, single legend entry.
    if plot_DFT:
        n_dft_bands = bande_DFT['bands'].shape[0]
        for band_index in range(n_dft_bands):
            dft_label = None if band_index else 'DFT'
            ax.plot(bande_DFT['k_distances'], bande_DFT['bands'][band_index],
                    marker='o', linestyle='', markersize=pointsize/2,
                    label=dft_label, color=color_DFT)

    # Axes decoration.
    tick_positions, tick_labels = k_path[0], k_path[1]
    ax.set_ylim(ylim)
    ax.set_xlim(k_distances[0], k_distances[-1])
    ax.set_ylabel(r"$\epsilon - \epsilon_F$ [eV]")
    ax.set_xticks(tick_positions, tick_labels)
    ax.grid(axis='x')
    ax.legend(loc=legend_position)


In [None]:
def plot_shift_currents(sigma_dict,
                        suffixes,
                        indices,
                        flag=None,
                        ax=None,
                        linewidth=2,
                        linestyle='-',
                        colors=None,
                        legend_position='upper right'):
    """
    Plot selected shift-current components on a single Axes.

    Parameters
    ----------
    sigma_dict : dict
        Maps a component suffix (e.g. 'xxy') to a 2-D array whose column 0 is
        the x value (labelled as photon energy in eV) and column 1 the
        component value.
    suffixes : list of str
        Component names; `indices` refers to positions in this list.
    indices : sequence of int
        Positions in `suffixes` of the components to plot. If empty, the
        function returns without touching the Axes.
    flag : str, optional
        Legend label used for every plotted line instead of its suffix.
    ax : matplotlib.axes.Axes
        Target Axes; required despite the `None` default.
    linewidth : float, optional
        Line width (default: 2).
    linestyle : str, optional
        Line style (default: '-').
    colors : list, optional
        Colors assigned to the lines in order; lines beyond the end of the
        list use matplotlib's default color cycle.
    legend_position : str, optional
        Legend location (default: 'upper right').

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `ax` is None.

    Notes
    -----
    - The plotted quantity is `-value * 1e6`: the stored values are scaled by
      1e6 and their sign is inverted.
      NOTE(review): the sign flip comes from the original implementation;
      confirm it matches the intended convention.
    - Y-axis label handling:
        * empty label: set to the single component label when one index is
          plotted, otherwise to the generic sigma^{abc} label;
        * label is an auto-generated single-component label and this call
          adds different components: replaced by the generic label;
        * any other (customized) label is left untouched.
    - The x-axis label is set only if it is empty.
    - The x limits are set to the first/last x value of the last plotted
      component.
    """
    if ax is None:
        raise ValueError("Devi passare un oggetto Axes (ax=...)")

    if len(indices) == 0:
        # Nothing to draw: leave labels, legend and limits as they are.
        return

    generic_label = r"$\sigma^{abc}$ [$\mu AN^2$]"

    def _component_label(suffix):
        # Auto-generated y label for a single component.
        return rf"$\sigma^{{{suffix}}}$ [$\mu AN^2$]"

    data = None
    for j, idx in enumerate(indices):
        suffix = suffixes[idx]
        data = sigma_dict[suffix]

        color = colors[j] if colors is not None and j < len(colors) else None

        ax.plot(
            data[:, 0],
            -data[:, 1]*1E6,
            linewidth=linewidth,
            linestyle=linestyle,
            color=color,
            label=flag if flag is not None else suffix
        )

    # Label this call would choose on an untouched axis.
    if len(indices) == 1:
        wanted_label = _component_label(suffixes[indices[0]])
    else:
        wanted_label = generic_label

    current_ylabel = ax.get_ylabel()
    if not current_ylabel:
        ax.set_ylabel(wanted_label)
    elif current_ylabel != generic_label and current_ylabel != wanted_label:
        # Only labels this function could have generated are upgraded to the
        # generic one; user-defined labels are never overwritten.
        auto_labels = {_component_label(s) for s in suffixes}
        if current_ylabel in auto_labels:
            ax.set_ylabel(generic_label)

    # Set the x label only if none is present.
    if not ax.get_xlabel():
        ax.set_xlabel(r"$\omega$ [eV]")

    ax.legend(loc=legend_position)
    ax.set_xlim(data[0, 0], data[-1, 0])


# 2. Data plots

## 2.1 Electronic properties

In [None]:
# Load the example DOS (3 ions) and the DFT band structure; the Fermi energy
# stored in the DOS result is passed on to the band extraction.
DOS = extract_DOS("Dati esempio/DOSCAR", num_ions=3)
bande = extract_BANDS_full("Dati esempio/EIGENVAL", "Dati esempio/PROCAR", Efermi=DOS['Efermi'], soc=0)
# High-symmetry path as [tick positions, tick labels], used for the x axis of the band plots.
k_path_DFT = [[0, 0.4714, 0.8441, 1.3441, 1.8441, 2.3155, 2.6882, 3.1882], ['Γ', 'K', 'M', 'Γ', 'A', 'H', 'L', 'A']]

In [None]:
# Fat-band example: projected bands for ion 0, orbital index 7, both spin channels.
# The font size is set globally through rcParams.
plt.rcParams.update({'font.size': 18})
fig, ax = plt.subplots(figsize=(8,5))

# Projection specification consumed by plot_DOS_bands.
# NOTE(review): "group" is the string "none", not None; plot_DOS_bands reads
# "group" only for projected DOS plots, so it presumably has no effect here.
projection = {
    "species": [0],
    "orbitals": [7],
    "spin": "both",
    "size_scale": (5, 100),
    "group" : "none"
}

# xlim=None lets the band plot span the full k range.
plot_DOS_bands(
    DOS,
    DOS_color="black",
    DOS_label="DOS",
    bands_data=bande,
    bands_color="black",
    bands_label="Bands",
    k_path=k_path_DFT,
    ylim=(-4, 4),
    xlim=None,
    legend_position=(1.5, 0.5),
    plot_type="bands",
    projection=projection,
    ax=ax
)

## 2.2 Unfolded bands

In [None]:
# Unfolded-band example: load the unfolded data and the reference band data.
bandsunf = unfolded_BANDS('Dati esempio/blochbanddata.dat')
# NOTE(review): DOS is reloaded here but not passed to the plot call below.
DOS = extract_DOS("Dati esempio/DOSCAR", num_ions=3)
bands= read_data("Dati esempio/banddata.dat")
# High-symmetry path as [tick positions, tick labels].
k_path = [[0, 0.21, 0.315, 0.5],['Γ', 'K', 'M', 'Γ']]

# 0.925 is the energy shift applied to the DFT bands (chosen by eye).
plot_unfolded_bands(bandsunf, 
                    bands, 'DFT',
                    0.925, 
                    k_path, 
                    plot_DFT=True)

## 2.3 Wannier Bands & Shift Currents

In [None]:
# Wannier example: load the Wannier90 band data and the DFT bands to compare against.
# NOTE(review): the meaning of the third argument (247) is defined by wannier_bands — confirm there.
k_points_wannier, bande_wannier = wannier_bands('Dati esempio/wannier90_band.dat', 'Dati esempio/wannier90_band.kpt', 247)
# The third argument is a hard-coded number (presumably the Fermi energy in eV) instead of DOS['Efermi'].
bande = extract_BANDS_full('Dati esempio/EIGENVAL', 'Dati esempio/PROCAR', 0.56653327, soc=0)
# High-symmetry path as [tick positions, tick labels] for the Wannier plot.
k_path_wannier = [[0, 0.4714045207910317, 0.8440825170409967, 1.3440825170409967],['Γ', 'K', 'M', 'Γ']]

In [None]:
# Compare the Wannier bands (green) with the DFT bands (black) on one Axes.
plt.rcParams.update({'font.size': 18})
fig, ax = plt.subplots(figsize=(8,5))

plot_wannier_bands(bande, 'black',
                   k_points_wannier, bande_wannier, 'Wannier', 'green',
                   k_path_wannier, 
                   plot_DFT = True,
                   pointsize=5, ylim=(-2.5,2.5), legend_position= 'upper right', ax=ax)
# Optional: uncomment to overlay further Wannier datasets on the same Axes
# (plot_DFT=False so the DFT bands are not drawn again).
# plot_wannier_bands(bande, 'black',
#                    k_points_wannier, bande_wannier, 'Wannier2', 'blue',
#                    k_path_wannier, 
#                    plot_DFT = False,
#                    pointsize=3, ylim=(-2.5,2.5), legend_position= 'upper right', ax=ax)
# plot_wannier_bands(bande, 'black',
#                    k_points_wannier, bande_wannier, 'Wannier3', 'red',
#                    k_path_wannier, 
#                    plot_DFT = False,
#                    pointsize=2, ylim=(-2.5,2.5), legend_position= 'upper right', ax=ax)


In [None]:
# The 18 component suffixes: first index x/y/z followed by an unordered pair of indices.
suffixes = ["xxx", "xxy", "xxz", "xyy", "xyz", "xzz", 
            "yxx", "yxy", "yxz", "yyy", "yyz", "yzz",
            "zxx", "zxy", "zxz", "zyy", "zyz", "zzz"]
# Directory passed to shift_currents together with each suffix.
path_sc = "Dati esempio/"

# One entry per component, keyed by its suffix.
sigma = {suffix: shift_currents(path_sc, suffix) for suffix in suffixes}

In [None]:
# Positions in `suffixes` of the components to plot: [6, 12] -> "yxx", "zxx"; [1, 2] -> "xxy", "xxz".
index1 = [6, 12]
index2 = [1, 2]

fig, ax = plt.subplots(figsize=(8,5))

# Both groups are drawn on the same Axes.
plot_shift_currents(sigma, suffixes, index1, flag = None, ax=ax) # set flag only when plotting one component per call, to customize its legend label
plot_shift_currents(sigma, suffixes, index2, flag = None, ax=ax)