In [36]:
# -*- coding: utf-8 -*-
"""
Interactive Phonon Dispersion + Eigenvector Visualizer
Author: Han-Hsuan Wu
"""

import numpy as np
import matplotlib.pyplot as plt
import time
from mpl_toolkits.mplot3d import Axes3D  # for 3D lattice plotting
from ipywidgets import interact, IntSlider, FloatSlider, Button, VBox, HBox
from IPython.display import display
from ase.io import read  # for POSCAR parsing (uses ASE)
import os

# ===============================
# 1. File Parsers
# ===============================

def parse_band_yaml(filename="band.yaml", thz_to_mev=4.135667696):
    """
    Parse q-points, frequencies and eigenvectors from a phonopy ``band.yaml``.

    The file is scanned with regular expressions rather than a full YAML
    parser. Every q-point section must contain eigenvectors and the same
    number of bands/atoms as the first section.

    Parameters
    ----------
    filename : str
        Path to the ``band.yaml`` file.
    thz_to_mev : float, optional
        Factor applied to every frequency read from the file. Frequencies are
        read as THz and converted to meV by default (1 THz = 4.135667696 meV).
        Pass 1.0 to keep THz, or 4.14 to reproduce the earlier rounded factor.

    Returns
    -------
    qpoints : ndarray, shape (nqpoints, 3)
        ``q-position`` of every q-point, as written in the file.
    frequencies : ndarray, shape (nqpoints, nbands)
        Frequencies multiplied by ``thz_to_mev``.
    eigenvectors : complex ndarray, shape (nqpoints, nbands, natoms, 3)
        Eigenvector components ordered atom by atom, then x, y, z.

    Raises
    ------
    ValueError
        If the layout is not recognized, or a q-point has a different number
        of bands (or too few eigenvector components) than expected.
    """
    import re

    start_time = time.perf_counter()

    with open(filename, "r") as f:
        text = f.read()

    # Every q-point header starts a new section of the file.
    qpos_matches = list(re.finditer(r'(?m)^\s*-\s*q-position:\s*\[\s*([^\]]+)\]', text))
    if not qpos_matches:
        raise ValueError('Could not find any "q-position" entries in band.yaml')
    nqpoints = len(qpos_matches)

    qpoints = np.empty((nqpoints, 3), dtype=np.float64)
    for qi, m in enumerate(qpos_matches):
        vals = [float(x) for x in m.group(1).split(",")]
        if len(vals) != 3:
            raise ValueError("q-position does not have 3 components")
        qpoints[qi] = vals

    # Section qi runs from the end of header qi to the start of header qi+1 (or EOF).
    section_ends = [m.start() for m in qpos_matches[1:]] + [len(text)]
    spans = [(m.end(), end) for m, end in zip(qpos_matches, section_ends)]

    # The first section defines the expected number of bands and atoms.
    first_section = text[spans[0][0]:spans[0][1]]
    nbands = len(re.findall(r'(?m)^\s*frequency:\s*([-\d.+Ee]+)', first_section))
    if nbands == 0:
        raise ValueError('Could not determine "nbands" from frequency entries')

    first_ev_block_match = re.search(
        r'(?ms)eigenvector:\s*(.*?)(?=^\s*frequency:|\Z)', first_section
    )
    if not first_ev_block_match:
        raise ValueError('Could not find "eigenvector" block in first q-point section')
    first_ev_block = first_ev_block_match.group(1)

    # One "[real, imag]" pair per Cartesian component.
    pair_re = re.compile(r'\[\s*([-\d.+Ee]+)\s*,\s*([-\d.+Ee]+)\s*\]')

    natoms = len(re.findall(r'(?m)^\s*-\s*#\s*atom', first_ev_block))
    if natoms == 0:
        # Fallback when the "# atom" comments are absent: 3 complex pairs per atom.
        n_pairs = len(pair_re.findall(first_ev_block))
        if n_pairs == 0 or n_pairs % 3 != 0:
            raise ValueError("Could not infer natoms from eigenvector pairs")
        natoms = n_pairs // 3

    print(f"Data dimensions: {nqpoints} q-points, {nbands} bands, {natoms} atoms")
    frequencies = np.empty((nqpoints, nbands), dtype=np.float64)
    eigenvectors = np.empty((nqpoints, nbands, natoms, 3), dtype=np.complex128)
    ncomp = natoms * 3

    # One match per band: group 1 = frequency, group 2 = text of its eigenvector block.
    band_block_re = re.compile(
        r'(?ms)^\s*frequency:\s*([-\d.+Ee]+)\s*.*?\beigenvector:\s*(.*?)(?=^\s*frequency:|\Z)'
    )

    for qi, (s, e) in enumerate(spans):
        band_blocks = band_block_re.findall(text[s:e])
        # Check before writing so an extra band raises ValueError, not IndexError.
        if len(band_blocks) != nbands:
            raise ValueError(f"Expected {nbands} bands, found {len(band_blocks)} at q-index {qi}")

        for band_index, (freq_str, ev_block) in enumerate(band_blocks):
            frequencies[qi, band_index] = float(freq_str) * thz_to_mev

            pairs = pair_re.findall(ev_block)
            if len(pairs) < ncomp:
                raise ValueError(
                    f"Not enough eigenvector components at q={qi}, band={band_index}"
                )

            # Components are listed atom by atom as x, y, z; rows of re_im are (real, imag).
            re_im = np.array([(float(r), float(i)) for r, i in pairs[:ncomp]])
            eigenvectors[qi, band_index] = (re_im[:, 0] + 1j * re_im[:, 1]).reshape(natoms, 3)

    elapsed = time.perf_counter() - start_time
    print(f"Text parsing took: {elapsed:.2f} seconds")

    return qpoints, frequencies, eigenvectors


def parse_poscar(filename="POSCAR"):
    """Load a structure file (``POSCAR`` by default) through ``ase.io.read``.

    The object returned by ASE is handed back unchanged.
    """
    return read(filename)


# ===============================
# 2. Plotting
# ===============================

def plot_dispersion(qpoints, frequencies, aspect=5):
    """
    Draw the phonon dispersion: every branch against the running q-point index.

    Parameters
    ----------
    qpoints : sized sequence
        Only ``len(qpoints)`` is used; the x axis runs 0 .. len(qpoints) - 1.
    frequencies : ndarray, shape (nqpoints, nbands)
        Each column is drawn as one solid black line (y axis labelled in meV).
    aspect : float or str, optional
        Forwarded to ``Axes.set_aspect`` (a number or ``'equal'``); larger
        numbers stretch the plot vertically.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
    """
    fig, ax = plt.subplots(figsize=(10, 6))  # overall figure size in inches
    q_axis = range(len(qpoints))

    # One black line per branch (columns of `frequencies`).
    for branch in frequencies.T:
        ax.plot(q_axis, branch, 'k-')

    ax.set_xlabel("Q-point index")
    ax.set_ylabel("Frequency (meV)")
    ax.set_title("Phonon Dispersion")
    ax.set_aspect(aspect)

    return fig, ax


def plot_eigenvectors(atoms, eigvecs, scale=1.0, expand_cell=(1, 1, 1), plane="xz",
                        show_vectors=True, show_cell=True, numbering=False, cell_boundary=False,
                        colors=None, atom_scale=0.45, legend=True,
                        head_width=0.1, head_length=0.1, arrow_linewidth=1,
                        z_range=None, rotation_angle=0, rotation_axis='z'):
    """
    Plot a 2D projection of a (super)cell with one phonon eigenvector drawn as arrows.

    Parameters
    ----------
    atoms : ase.Atoms
        Unit cell; it is repeated with ``atoms * expand_cell`` before plotting.
    eigvecs : array-like or None
        Eigenvector of the mode, reshaped to rows of (x, y, z) components. The
        row count must equal either the unit-cell atom count (the vector is
        then copied into every repeated cell) or the expanded atom count.
        Only the real part is drawn.
    scale : float
        Multiplier applied to the arrow length.
    expand_cell : tuple of int (or anything ``Atoms.__mul__`` accepts)
        Cell repetitions.
    plane : {"xy", "xz", "yz"}
        Cartesian plane to project onto.
    show_vectors : bool
        Draw eigenvector arrows.
    show_cell, cell_boundary : bool
        Currently unused; kept for interface compatibility.
    numbering : bool
        Annotate each atom with its index in the (z-filtered) atom list.
    colors : dict or None
        Element symbol -> matplotlib color; missing elements fall back to Set3
        with a printed warning.
    atom_scale : float
        Circle radius = ASE covalent radius * atom_scale.
    legend : bool
        Show an element legend.
    head_width, head_length, arrow_linewidth : float
        Arrow styling forwarded to ``Axes.arrow``.
    z_range : (float, float) or None
        Keep only atoms whose Cartesian z, taken BEFORE any rotation, lies in
        [z_min, z_max] (inclusive).
    rotation_angle : float
        Rotation angle in degrees, applied to positions and eigenvectors.
    rotation_axis : str
        Axis to rotate around ('x', 'y', or 'z'); any other value means no rotation.

    Returns
    -------
    matplotlib.figure.Figure or None
        The figure, or None when no atoms survive the z-range filter.

    Raises
    ------
    ValueError
        If ``plane`` is not one of "xy", "xz", "yz".

    Notes
    -----
    NOTE(review): when the eigenvector is copied into repeated cells no Bloch
    phase factor is applied, and the arrows are not divided by sqrt(mass).
    The supercell picture is therefore presumably only faithful at q = Gamma;
    confirm this is intended.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.patches import Circle
    from matplotlib.collections import PatchCollection
    from matplotlib.lines import Line2D
    from ase.data import covalent_radii

    # plane -> (horizontal, vertical, depth) Cartesian indices, plus axis labels.
    plane_layout = {
        "xy": ([0, 1, 2], "x [Å]", "y [Å]"),
        "xz": ([0, 2, 1], "x [Å]", "z [Å]"),
        "yz": ([1, 2, 0], "y [Å]", "z [Å]"),
    }
    if plane not in plane_layout:
        raise ValueError(f"plane must be 'xy', 'xz' or 'yz', got {plane!r}")
    axes, xlabel, ylabel = plane_layout[plane]

    def rotation_matrix(angle, axis):
        """Return the 3x3 rotation matrix for `angle` degrees about Cartesian `axis`;
        an unrecognized axis gives the identity."""
        angle_rad = np.radians(angle)
        cos_a, sin_a = np.cos(angle_rad), np.sin(angle_rad)

        if axis == 'x':
            return np.array([[1, 0, 0],
                             [0, cos_a, -sin_a],
                             [0, sin_a, cos_a]])
        elif axis == 'y':
            return np.array([[cos_a, 0, sin_a],
                             [0, 1, 0],
                             [-sin_a, 0, cos_a]])
        elif axis == 'z':
            return np.array([[cos_a, -sin_a, 0],
                             [sin_a, cos_a, 0],
                             [0, 0, 1]])
        else:
            return np.eye(3)

    # Same rotation for positions and eigenvectors; None means "no rotation".
    rot_matrix = rotation_matrix(rotation_angle, rotation_axis) if rotation_angle != 0 else None

    # Unit-cell size is needed to copy a unit-cell eigenvector into every repeated cell.
    original_n_atoms = len(atoms)
    atoms = atoms * expand_cell

    positions = atoms.get_positions()
    symbols = atoms.get_chemical_symbols()
    numbers = atoms.get_atomic_numbers()

    # Optional slab selection on Cartesian z, applied BEFORE rotation.
    z_mask = None
    if z_range is not None:
        z_min, z_max = z_range
        z_coords = positions[:, 2]
        z_mask = (z_coords >= z_min) & (z_coords <= z_max)

        positions = positions[z_mask]
        symbols = [sym for sym, keep in zip(symbols, z_mask) if keep]
        numbers = numbers[z_mask]

        print(f"Filtered to {len(positions)} atoms within z-range [{z_min:.2f}, {z_max:.2f}] Å")
        print(f"Original z-range: [{z_coords.min():.2f}, {z_coords.max():.2f}] Å")

    if rot_matrix is not None:
        positions = positions @ rot_matrix.T

    if eigvecs is not None and show_vectors:
        # Normalize any input layout to rows of (x, y, z) components.
        eigvecs = np.asarray(eigvecs)
        if eigvecs.ndim != 2:
            eigvecs = eigvecs.reshape(-1, 3)

        if rot_matrix is not None:
            eigvecs = eigvecs @ rot_matrix.T

        # Unit-cell eigenvector: identical copy per repeated cell (no phase factor).
        if original_n_atoms and eigvecs.shape[0] == original_n_atoms:
            n_cells = len(atoms) // original_n_atoms
            eigvecs = np.tile(eigvecs, (n_cells, 1))

        # Keep eigenvector rows aligned with the (possibly filtered) atoms.
        if z_mask is not None:
            if eigvecs.shape[0] == len(z_mask):
                eigvecs = eigvecs[z_mask]
            else:
                print(f"Warning: eigvecs shape {eigvecs.shape} doesn't match expanded atoms after filtering")
                show_vectors = False
        elif eigvecs.shape[0] != len(positions):
            print(f"Warning: eigvecs shape {eigvecs.shape} doesn't match expanded atoms {len(positions)}")
            show_vectors = False

    if len(positions) == 0:
        print("No atoms found in the specified z-range!")
        return None

    fig, ax = plt.subplots(figsize=(8, 8))

    positions_2d = positions[:, axes[:2]]

    # Draw in descending order of the out-of-plane coordinate, so atoms with the
    # smallest value are drawn last (on top).
    order = np.argsort(-positions[:, axes[2]])
    positions_2d = positions_2d[order]
    ordered_numbers = numbers[order]
    ordered_symbols = [symbols[i] for i in order]

    # Per-element colors: user-supplied where given, otherwise evenly spaced Set3 entries.
    unique_symbols = sorted(set(symbols))
    colormap = plt.cm.Set3
    color_map = {}
    for i, sym in enumerate(unique_symbols):
        if colors is not None and sym in colors:
            color_map[sym] = colors[sym]
        else:
            color_map[sym] = colormap(i / len(unique_symbols))
            if colors is not None:
                print(f"Warning: No color specified for element '{sym}', using default color")

    atom_colors = [color_map[sym] for sym in ordered_symbols]
    sizes = covalent_radii[ordered_numbers] * atom_scale

    circles = [Circle(position, size) for position, size in zip(positions_2d, sizes)]
    ax.add_collection(
        PatchCollection(circles, facecolors=atom_colors, edgecolors="black", linewidths=1)
    )

    # Arrows: real part of each atom's eigenvector row, projected onto the plot plane.
    if show_vectors and eigvecs is not None:
        for pos, vec in zip(positions, eigvecs):
            vec_2d = (np.real(vec) * scale)[axes[:2]]
            pos_2d = pos[axes[:2]]
            if np.linalg.norm(vec_2d) > 1e-10:
                ax.arrow(pos_2d[0], pos_2d[1], vec_2d[0], vec_2d[1],
                         head_width=head_width, head_length=head_length, fc='red', ec='red',
                         linewidth=arrow_linewidth, alpha=0.8)

    ax.set_aspect('equal')
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    title = "Phonon Mode Eigenvectors"
    if rotation_angle != 0:
        title += f" (rotated {rotation_angle}° around {rotation_axis}-axis)"
    if z_range is not None:
        title += f" (z: {z_range[0]:.1f} - {z_range[1]:.1f} Å)"
    ax.set_title(title)

    if legend:
        legend_elements = [
            Line2D([0], [0], marker="o", color="w", markeredgecolor="k",
                   label=sym, markerfacecolor=color_map[sym], markersize=12)
            for sym in unique_symbols
        ]
        ax.legend(handles=legend_elements, loc="upper right")

    if numbering:
        for i, position in enumerate(positions_2d):
            ax.annotate(f"{order[i]}", xy=position, ha="center", va="center",
                        fontsize=8, fontweight='bold')

    # Pad the view by 1.5x the largest drawn radius.
    margin = sizes.max() * 1.5
    ax.set_xlim(positions_2d[:, 0].min() - margin, positions_2d[:, 0].max() + margin)
    ax.set_ylim(positions_2d[:, 1].min() - margin, positions_2d[:, 1].max() + margin)

    fig.tight_layout()
    return fig

# ===============================
# 3. Interactive Controls
# ===============================

def interactive_viewer(qpoints, frequencies, eigenvectors, atoms, scale=2.0, expand_cell=(1, 1, 1), plane="xz",
                        q_index=0, band_index=0, colors=None, legend=True, atom_scale=0.4,
                        head_width=0.1, arrow_linewidth=0.1, z_range=None,
                        rotation_angle=0, rotation_axis='z',
                        dispersion_y_range=None, aspect_ratio=5.0,
                        save_dir=None):
    """
    Interactive widget to explore phonon modes, with a button that saves the current plots.

    Parameters
    ----------
    qpoints, frequencies, eigenvectors
        Output of ``parse_band_yaml``; ``frequencies`` and ``eigenvectors`` are
        indexed as ``[q_index, band_index]``.
    atoms : ase.Atoms
        Structure passed to ``plot_eigenvectors``.
    scale, head_width, arrow_linewidth, aspect_ratio : float
        Initial slider values.
        NOTE(review): the default arrow_linewidth=0.1 lies below the slider
        minimum of 1, so ipywidgets presumably clamps it to 1 — confirm.
    expand_cell, plane, colors, legend, atom_scale, z_range, rotation_angle, rotation_axis
        Forwarded unchanged to ``plot_eigenvectors``.
    q_index, band_index : int
        Initially selected mode.
    dispersion_y_range : (float, float) or None
        y-limits applied to the dispersion plot.
    save_dir : str or None
        Directory for the PNGs written by "Save Plots"; created if missing.
        None (or "") means the current working directory.

    Returns
    -------
    The value returned by ``ipywidgets.interact``.
    """
    from datetime import datetime

    # Figures from the most recent redraw, shared with the save callback via closure.
    current_dispersion_fig = None
    current_eigenvector_fig = None

    def view_mode(q_index=0, band_index=0, head_width=head_width, arrow_linewidth=arrow_linewidth,
                  scale=scale, aspect_ratio=aspect_ratio):
        """Redraw the dispersion (selected mode marked in red) and the eigenvector plot."""
        nonlocal current_dispersion_fig, current_eigenvector_fig

        # Close only this viewer's previous figures; other notebook figures stay alive.
        for old_fig in (current_dispersion_fig, current_eigenvector_fig):
            if old_fig is not None:
                plt.close(old_fig)

        fig_disp, ax_disp = plot_dispersion(qpoints, frequencies, aspect=aspect_ratio)
        ax_disp.plot(q_index, frequencies[q_index, band_index], 'ro', markersize=10)
        ax_disp.set_title(f"Phonon Dispersion (selected mode {frequencies[q_index, band_index]:.2f} meV)")
        if dispersion_y_range is not None:
            ax_disp.set_ylim(dispersion_y_range)
        current_dispersion_fig = fig_disp
        plt.show()

        eigvecs = eigenvectors[q_index, band_index]
        current_eigenvector_fig = plot_eigenvectors(atoms, eigvecs, scale=scale, expand_cell=expand_cell,
                                                    colors=colors, legend=legend,
                                                    head_width=head_width, arrow_linewidth=arrow_linewidth,
                                                    z_range=z_range, plane=plane,
                                                    rotation_angle=rotation_angle, rotation_axis=rotation_axis,
                                                    atom_scale=atom_scale)
        plt.show()

    def save_plots(button):
        """Button callback: write both current figures as 300-dpi PNGs into save_dir."""
        if current_dispersion_fig is None or current_eigenvector_fig is None:
            print("No plots to save. Please generate plots first.")
            return

        # os.path.join(None, ...) would raise, so fall back to the working directory.
        out_dir = save_dir or os.getcwd()
        os.makedirs(out_dir, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        current_q = q_slider.value
        current_band = band_slider.value
        current_freq = frequencies[current_q, current_band]

        disp_filename = f"phonon_dispersion_q{current_q}_band{current_band}_{current_freq:.1f}meV_{timestamp}.png"
        current_dispersion_fig.savefig(os.path.join(out_dir, disp_filename), dpi=300, bbox_inches='tight')
        print(f"Dispersion plot saved as: {disp_filename}")

        eigvec_filename = f"phonon_eigenvectors_q{current_q}_band{current_band}_{current_freq:.1f}meV_{timestamp}.png"
        current_eigenvector_fig.savefig(os.path.join(out_dir, eigvec_filename), dpi=300, bbox_inches='tight')
        print(f"Eigenvector plot saved as: {eigvec_filename}")

        print("Both plots saved successfully!")

    save_button = Button(description='Save Plots', button_style='success', icon='save')
    # on_click passes only the button instance, which save_plots accepts.
    save_button.on_click(save_plots)

    q_slider = IntSlider(min=0, max=len(qpoints)-1, step=1, value=q_index, description='Q-point')
    band_slider = IntSlider(min=0, max=frequencies.shape[1]-1, step=1, value=band_index, description='Band')
    head_width_slider = FloatSlider(min=0.1, max=0.5, step=0.1, value=head_width, description='Head width')
    arrow_linewidth_slider = FloatSlider(min=1, max=5, step=1, value=arrow_linewidth, description='Arrow width')
    scale_slider = FloatSlider(min=1, max=10.0, step=0.1, value=scale, description='Scale')
    aspect_ratio_slider = FloatSlider(min=1, max=20.0, step=0.1, value=aspect_ratio, description='Aspect ratio')

    # The save button is displayed above the sliders.
    display(save_button)

    return interact(view_mode,
                    q_index=q_slider,
                    band_index=band_slider,
                    head_width=head_width_slider,
                    arrow_linewidth=arrow_linewidth_slider,
                    scale=scale_slider,
                    aspect_ratio=aspect_ratio_slider)

In [3]:
import os

# Directory holding the phonopy band.yaml and the matching POSCAR.
# Set the PHONON_DATA_DIR environment variable to point elsewhere without editing the notebook.
path = os.environ.get(
    "PHONON_DATA_DIR",
    r"D:\OneDrive - personalmicrosoftsoftware.uci.edu\BAs\BAsPDOS\Zhe Wang DFT TB BAs\twin-4\data-twin-4\twin",
)
print("Load band.yaml and POSCAR")
qpoints, freqs, eigvecs = parse_band_yaml(os.path.join(path, "band-twin-4.yaml"))
atoms = parse_poscar(os.path.join(path, "POSCAR"))

# Optional step 1: a quick static dispersion plot.
# plot_dispersion(qpoints, freqs)

Load band.yaml and POSCAR
Data dimensions: 707 q-points, 144 bands, 48 atoms
Text parsing took: 77.08 seconds


In [35]:
%matplotlib widget
# Step 2: start interactive viewer (in Jupyter)
colors={    'B': 'blue',           # Silicon in blue
            'As': 'lime',  }
twin_z=((atoms.get_positions()[:,2].max()+3*atoms.get_positions()[:,2].min())/4,(3*atoms.get_positions()[:,2].max()+atoms.get_positions()[:,2].min())/4)
interactive_viewer(qpoints, freqs, eigvecs, atoms,expand_cell=(5, 1, 1), 
                    q_index=200, band_index=120, colors=colors,legend=False, 
                    scale = 5.6, head_width=0.2, arrow_linewidth=3.0,
                    z_range=twin_z,
                    plane="xz",
                    rotation_angle=-90,
                    rotation_axis='y',
                    dispersion_y_range=(70, 95),
                    aspect_ratio=20.0,
                    save_dir=path)

TypeError: Button.on_click() got an unexpected keyword argument 'save_dir'

In [None]:
# Usage notes for the interactive viewer started above:
# the "Save Plots" button is displayed ABOVE the sliders. Clicking it saves both the
# dispersion plot and the eigenvector plot as PNG files whose names include the
# q-point index, band index, frequency (meV), and a timestamp.

print("Testing the new interactive viewer with save functionality...")
print("Use the sliders to adjust parameters, then click the 'Save Plots' button to save both images.")
