# Interactive Phonon Visualization with Matplotlib

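This notebook provides interactive visualization of phonon modes from Phonopy calculations, with real-time controls for the following (live updates require an interactive Matplotlib backend such as `%matplotlib widget`):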
- Q-point selection
- Band selection
- Vector scaling
- Animation controls
- Color mapping (planned; atoms and displacement arrows currently use fixed colors)

Based on the `extract_vectors_phonopy.py` functions.


In [None]:
# Import required libraries
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.widgets import Slider, Button, RadioButtons
import ipywidgets as widgets
from IPython.display import display, clear_output
import os
import sys
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.patches as patches


In [None]:
def read_files(band_yaml_path='band.yaml', vesta_path='POSCAR.vesta'):
    """
    Read phonon data from band.yaml and POSCAR.vesta files

    Parameters:
    -----------
    band_yaml_path : str
        Path to band.yaml file
    vesta_path : str
        Path to POSCAR.vesta file

    Returns:
    --------
    tuple : (band_yaml, vesta)
        band_yaml is the text of band.yaml following its first ``phonon:``
        key; vesta is the full text of POSCAR.vesta. Both are None when
        either file cannot be read or band.yaml has no ``phonon:`` section
        (a message is printed instead of raising).
    """
    try:
        with open(band_yaml_path, 'r', encoding='utf-8') as f:
            band_text = f.read()
        with open(vesta_path, 'r', encoding='utf-8') as f:
            vesta = f.read()
    except FileNotFoundError as e:
        print(f"File not found: {e}")
        return None, None
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error reading files: {e}")
        return None, None

    # Keep everything after the first 'phonon:' key. Unlike split(...)[1],
    # partition does not truncate the data at a later occurrence of the key.
    _, found, band_yaml = band_text.partition('phonon:')
    if not found:
        print(f"Error reading files: no 'phonon:' section in {band_yaml_path}")
        return None, None

    print(f"Successfully loaded {band_yaml_path} and {vesta_path}")
    return band_yaml, vesta


In [None]:
def extract(band_yaml):
    """
    Extract phonon data from band.yaml content

    Parameters:
    -----------
    band_yaml : str or None
        Text of band.yaml following its ``phonon:`` key (as returned by
        ``read_files``).

    Returns:
    --------
    tuple : (displacements, qpoint_band, q_position)
        displacements : ndarray, shape (nqpoint, nbands, natoms, 3)
            Real part of each eigenvector component
            [qpoint][band][atom][direction].
        qpoint_band : ndarray, shape (nqpoint, nbands) - frequencies as
            written in band.yaml.
        q_position : list of str - raw text inside each ``q-position: [...]``.
        All three are None when band_yaml is None or contains no q-points.
    """
    if band_yaml is None:
        return None, None, None

    # One chunk of text per q-point.
    qpoint_chunks = band_yaml.split('q-position:')[1:]
    if not qpoint_chunks:
        print("No q-points found in band.yaml content")
        return None, None, None

    q_position = [chunk.split('[')[1].split(']')[0] for chunk in qpoint_chunks]

    # band_chunks[q][b] is the text following the b-th 'frequency:' key of q-point q.
    band_chunks = [chunk.split('frequency:')[1:] for chunk in qpoint_chunks]
    nq_point = len(band_chunks)
    nbands = len(band_chunks[0])

    # The frequency is the first whitespace-separated token; whatever follows
    # it (the eigenvector block or any other per-band key) is ignored.
    qpoint_band = np.array(
        [[float(text.split()[0]) for text in bands] for bands in band_chunks],
        dtype=float,
    )

    # Per-atom eigenvector text, introduced by the '# atom N' comments.
    atom_chunks = [[text.split('atom')[1:] for text in bands] for bands in band_chunks]
    natoms = len(atom_chunks[0][0]) if nbands else 0

    displacements = np.zeros((nq_point, nbands, natoms, 3), dtype=float)
    for q, bands in enumerate(atom_chunks):
        for b, atoms in enumerate(bands):
            for a, vector in enumerate(atoms):
                # Each '[re, im]' line is one Cartesian direction; keep the real part.
                pieces = vector.split('[')
                for direction in range(3):
                    displacements[q, b, a, direction] = float(
                        pieces[direction + 1].split(',')[0]
                    )

    print(f"Extracted data: {nq_point} q-points, {nbands} bands, {natoms} atoms")
    return displacements, qpoint_band, q_position


In [None]:
class PhononVisualizer:
    """
    Interactive phonon mode visualizer with matplotlib

    Shows one phonon mode (q-point, band) as atoms with displacement arrows,
    projected onto the x-y plane, together with ipywidgets controls.

    Live redraws and the animation timer need an interactive Matplotlib
    backend (e.g. ``%matplotlib widget``); with a static backend the figure
    is rendered once and the timer never fires.
    """

    # meV per THz (1 THz times the Planck constant = 4.135667696 meV).
    # Assumes band.yaml frequencies are in THz, phonopy's default unit.
    THZ_TO_MEV = 4.135667696
    # Delay between animation frames, in milliseconds.
    FRAME_INTERVAL_MS = 50

    def __init__(self, displacements, qpoint_band, q_position, atom_positions=None):
        """
        Parameters
        ----------
        displacements : array-like indexed [qpoint][band][atom][direction]
        qpoint_band : ndarray indexed [qpoint, band]
            Mode frequencies (THz assumed).
        q_position : sequence
            Label shown for each q-point.
        atom_positions : array-like of shape (natoms, 2 or more), optional
            Atom positions; only the first two columns are drawn. When None,
            a placeholder grid is used (see get_atom_positions).
        """
        self.displacements = displacements
        self.qpoint_band = qpoint_band
        self.q_position = q_position
        self.atom_positions = atom_positions

        self.nqpoint = len(displacements)
        self.nbands = len(displacements[0])
        self.natoms = len(displacements[0][0])

        # Default parameters
        self.current_qpoint = 0
        self.current_band = 0
        self.scale = 5.0
        self.animation_speed = 0.1  # time advance per frame; one oscillation = 1.0
        self.is_animating = False
        self.time = 0
        self._timer = None  # canvas timer, created lazily by animate()

        # Create figure and axes
        self.fig, self.ax = plt.subplots(figsize=(12, 8))
        self.setup_plot()

    def setup_plot(self):
        """Clear the axes and apply labels, title, grid and equal aspect."""
        self.ax.clear()
        self.ax.set_xlabel('X')
        self.ax.set_ylabel('Y')
        self.ax.set_title('Interactive Phonon Mode Visualization')
        self.ax.grid(True, alpha=0.3)
        self.ax.set_aspect('equal')

    def get_atom_positions(self):
        """Return atom_positions, or a placeholder grid 4 atoms wide when none were given."""
        if self.atom_positions is not None:
            return self.atom_positions
        idx = np.arange(self.natoms)
        return np.column_stack((idx % 4, idx // 4)).astype(float)

    def _frame_vectors(self):
        """
        Compute what the current frame draws.

        Returns
        -------
        positions : ndarray, shape (natoms, 2) - x-y atom positions.
        vectors : ndarray, shape (natoms, 2) - x-y displacement arrows,
            multiplied by scale and, while animating, by sin(2*pi*time).
        """
        positions = np.asarray(self.get_atom_positions(), dtype=float)[:, :2]
        disp = np.asarray(
            self.displacements[self.current_qpoint][self.current_band], dtype=float
        ).reshape(-1, 3)[:, :2]
        if self.is_animating:
            disp = disp * np.sin(2 * np.pi * self.time)
        return positions, disp * self.scale

    def update_visualization(self, qpoint=None, band=None, scale=None, time=None):
        """Apply any parameters that were given, then redraw the current mode."""
        if qpoint is not None:
            self.current_qpoint = qpoint
        if band is not None:
            self.current_band = band
        if scale is not None:
            self.scale = scale
        if time is not None:
            self.time = time

        self.setup_plot()  # also clears the previous frame

        positions, vectors = self._frame_vectors()

        # Plot all atoms with a single scatter call
        self.ax.scatter(positions[:, 0], positions[:, 1], c='blue', s=100, alpha=0.7, zorder=2)

        # Displacement vectors, skipping negligible ones
        for (x, y), (dx, dy) in zip(positions, vectors):
            if np.hypot(dx, dy) > 0.001:
                self.ax.arrow(x, y, dx, dy, head_width=0.1, head_length=0.1,
                              fc='red', ec='red', alpha=0.8, zorder=1)

        # Add information text
        freq = self.qpoint_band[self.current_qpoint][self.current_band] * self.THZ_TO_MEV
        qpos = self.q_position[self.current_qpoint]
        info_text = f"Q-point: {qpos}\nBand: {self.current_band + 1}\nFrequency: {freq:.2f} meV\nScale: {self.scale:.1f}"
        self.ax.text(0.02, 0.98, info_text, transform=self.ax.transAxes,
                     verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))

        self.fig.canvas.draw_idle()

    def create_interactive_controls(self):
        """Build the ipywidgets control panel wired to this visualizer and return it."""

        # Q-point selector
        qpoint_slider = widgets.IntSlider(
            value=self.current_qpoint,
            min=0,
            max=self.nqpoint - 1,
            step=1,
            description='Q-point:',
            style={'description_width': 'initial'}
        )

        # Band selector
        band_slider = widgets.IntSlider(
            value=self.current_band,
            min=0,
            max=self.nbands - 1,
            step=1,
            description='Band:',
            style={'description_width': 'initial'}
        )

        # Scale slider
        scale_slider = widgets.FloatSlider(
            value=self.scale,
            min=0.1,
            max=20.0,
            step=0.1,
            description='Scale:',
            style={'description_width': 'initial'}
        )

        # Animation controls
        animate_button = widgets.Button(
            description='Stop Animation' if self.is_animating else 'Start Animation'
        )
        speed_slider = widgets.FloatSlider(
            value=self.animation_speed,
            min=0.01,
            max=1.0,
            step=0.01,
            description='Speed:',
            style={'description_width': 'initial'}
        )

        # View options: no view-switching logic exists yet, so show the
        # buttons disabled instead of as controls that silently do nothing.
        view_2d = widgets.Button(description='2D View', disabled=True, tooltip='Not implemented yet')
        view_3d = widgets.Button(description='3D View', disabled=True, tooltip='Not implemented yet')

        # Connect callbacks
        def on_qpoint_change(change):
            self.update_visualization(qpoint=change['new'])

        def on_band_change(change):
            self.update_visualization(band=change['new'])

        def on_scale_change(change):
            self.update_visualization(scale=change['new'])

        def on_animate_click(_button):
            if self.is_animating:
                self.stop_animation()
                animate_button.description = 'Start Animation'
            else:
                self.is_animating = True
                animate_button.description = 'Stop Animation'
                self.animate()

        def on_speed_change(change):
            self.animation_speed = change['new']

        qpoint_slider.observe(on_qpoint_change, names='value')
        band_slider.observe(on_band_change, names='value')
        scale_slider.observe(on_scale_change, names='value')
        animate_button.on_click(on_animate_click)
        speed_slider.observe(on_speed_change, names='value')

        # Layout controls
        controls = widgets.VBox([
            widgets.HBox([qpoint_slider, band_slider]),
            widgets.HBox([scale_slider, speed_slider]),
            widgets.HBox([animate_button, view_2d, view_3d])
        ])

        return controls

    def animate(self):
        """Start the frame timer while is_animating is True (no-op otherwise)."""
        if not self.is_animating:
            return
        if getattr(self, '_timer', None) is None:
            self._timer = self.fig.canvas.new_timer(interval=self.FRAME_INTERVAL_MS)
            self._timer.add_callback(self._advance_frame)
        self._timer.start()

    def _advance_frame(self):
        """Timer callback: advance time by animation_speed and redraw, or stop the timer."""
        if not self.is_animating:
            self._timer.stop()
            return
        self.time += self.animation_speed
        self.update_visualization(time=self.time)

    def stop_animation(self):
        """Stop the animation timer and redraw the static (un-animated) mode."""
        self.is_animating = False
        timer = getattr(self, '_timer', None)
        if timer is not None:
            timer.stop()
        self.update_visualization()

    def plot_frequency_spectrum(self):
        """Bar-plot every band's frequency (meV) at the current q-point, highlighting the current band."""
        fig, ax = plt.subplots(figsize=(10, 6))

        frequencies = np.asarray(self.qpoint_band[self.current_qpoint], dtype=float) * self.THZ_TO_MEV
        bands = np.arange(1, len(frequencies) + 1)

        bars = ax.bar(bands, frequencies, alpha=0.7, color='skyblue', edgecolor='navy')

        # Highlight current band
        if 0 <= self.current_band < len(frequencies):
            bars[self.current_band].set_color('red')
            bars[self.current_band].set_alpha(1.0)

        ax.set_xlabel('Band Index')
        ax.set_ylabel('Frequency (meV)')
        ax.set_title(f'Frequency Spectrum at Q-point {self.current_qpoint + 1}')
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    def plot_dispersion_curve(self):
        """Plot every band's frequency (meV) against q-point index, marking the current mode."""
        fig, ax = plt.subplots(figsize=(12, 8))

        q_indices = np.arange(self.nqpoint)
        # With a 2-D y array, matplotlib draws one line per column, i.e. per band.
        ax.plot(q_indices, np.asarray(self.qpoint_band, dtype=float) * self.THZ_TO_MEV,
                'b-', alpha=0.7, linewidth=1)

        # Highlight current q-point and band
        if 0 <= self.current_qpoint < self.nqpoint and 0 <= self.current_band < self.nbands:
            current_freq = self.qpoint_band[self.current_qpoint, self.current_band] * self.THZ_TO_MEV
            ax.plot(self.current_qpoint, current_freq, 'ro', markersize=10, zorder=5)

        ax.set_xlabel('Q-point Index')
        ax.set_ylabel('Frequency (meV)')
        ax.set_title('Phonon Dispersion Curve')
        ax.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()


In [None]:
# Load data (both files are expected in the current working directory)
band_yaml, vesta = read_files('band.yaml', 'POSCAR.vesta')

if band_yaml is not None and vesta is not None:
    # Extract phonon data; report parse failures instead of raising a traceback.
    try:
        displacements, qpoint_band, q_position = extract(band_yaml)
    except (IndexError, ValueError) as err:
        print(f"Could not parse band.yaml: {err}")
        displacements = qpoint_band = q_position = None

    if displacements is not None:
        print("Data loaded successfully!")
        print(f"Q-points: {len(displacements)}")
        print(f"Bands: {len(displacements[0])}")
        print(f"Atoms: {len(displacements[0][0])}")
    else:
        print("Failed to extract data from band.yaml")
else:
    print("Please ensure band.yaml and POSCAR.vesta files are in the current directory")


In [None]:
# Build the interactive visualizer once the load cell has produced phonon data.
# At notebook top level, globals() is the same namespace that locals() returns.
if globals().get('displacements') is not None:
    visualizer = PhononVisualizer(displacements, qpoint_band, q_position)

    # Render the first frame, then show the figure
    visualizer.update_visualization()
    plt.show()

    # Slider/button panel driving the figure above
    controls = visualizer.create_interactive_controls()
    display(controls)


In [None]:
# Supplementary plots; skipped when the visualizer cell has not created one.
if 'visualizer' in globals():
    visualizer.plot_frequency_spectrum()  # all bands at the current q-point
    visualizer.plot_dispersion_curve()    # every band versus q-point index


In [None]:
# Export functionality
def export_animation(visualizer, filename='phonon_animation.gif', duration=5, fps=10):
    """
    Export an animation of the visualizer's current mode as a GIF.

    Parameters
    ----------
    visualizer : PhononVisualizer
        Visualizer whose figure and current (q-point, band) are animated.
    filename : str
        Output path; written with matplotlib's 'pillow' writer.
    duration : float
        Length of the animation in seconds.
    fps : int
        Frames per second. Frame k uses time k / fps, so the mode completes
        one oscillation per second of output (the default of 10 fps
        reproduces the previous 0.1 time step).

    The visualizer's is_animating flag and time are restored afterwards and
    the figure is redrawn in its prior state.
    """
    import matplotlib.animation as animation

    # update_visualization only applies the sin(2*pi*t) factor while
    # is_animating is True; without forcing it every frame would be identical.
    was_animating, saved_time = visualizer.is_animating, visualizer.time
    visualizer.is_animating = True

    def animate_frame(frame):
        visualizer.update_visualization(time=frame / fps)
        return [*visualizer.ax.collections, *visualizer.ax.patches]

    try:
        anim = animation.FuncAnimation(
            visualizer.fig, animate_frame, frames=int(duration * fps),
            interval=1000 / fps, blit=False, repeat=True
        )
        anim.save(filename, writer='pillow', fps=fps)
    finally:
        visualizer.is_animating = was_animating
        visualizer.time = saved_time
        visualizer.update_visualization()

    print(f"Animation saved as {filename}")

def export_static_image(visualizer, filename='phonon_mode.png'):
    """Save the visualizer's figure, as currently drawn, to *filename* at 300 dpi with a tight bounding box."""
    figure = visualizer.fig
    figure.savefig(filename, bbox_inches='tight', dpi=300)
    print(f"Image saved as {filename}")

# Example usage (uncomment to use)
# if 'visualizer' in locals():
#     export_static_image(visualizer)
#     export_animation(visualizer)


## Usage Instructions

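1. **Load Data**: Run the data loading cell to read your `band.yaml` and `POSCAR.vesta` files. Both files must be in the working directory. `POSCAR.vesta` is read but not yet used, so atoms are drawn on a placeholder grid rather than at their crystal positions.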
2. **Interactive Controls**: Use the sliders and buttons to:
   - Select different q-points and bands
   - Adjust vector scaling
   - Start/stop animation
   - Control animation speed
3. **Analysis Plots**: View frequency spectra and dispersion curves
4. **Export**: Save static images or animated GIFs of your visualizations

## Features

- **Real-time parameter adjustment** with sliders
- **Animation controls** for dynamic visualization
- **Multiple view modes** (2D/3D buttons are placeholders; only the 2D x-y projection is currently drawn)
- **Frequency analysis** with spectrum plots
- **Dispersion curve** visualization
- **Export capabilities** for images and animations
- **Interactive highlighting** of current selections
