# Dataset Processing Utilities

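Reusable functions for processing VASP POSCAR/CONTCAR structures: computing adsorption heights, converting between Cartesian and direct coordinates, computing (optionally periodic) interatomic distances, extracting atomic indices such as nearest neighbors, and other dataset processing tasks.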

## Imports

In [None]:
import numpy as np
import pandas as pd
from typing import List, Tuple, Dict, Optional
import os

## Adsorption Height Functions

In [None]:
def get_adsorption_heights_per_atom(
    contcar_path: str,
    adsorbate_indices: List[int],
    surface_indices: List[int]
) -> List[float]:
    """
    Computes adsorption heights for each adsorbate atom.

    The height of an adsorbate atom is ``z_ads - z_surf``, where ``z_surf`` is
    the z coordinate of the surface atom whose z value is closest to ``z_ads``
    (closeness is measured along z only, not in 3D). Heights are therefore
    signed: negative when the adsorbate lies below that surface atom. No
    periodic wrapping is applied along z.

    Parameters:
    -----------
    contcar_path : str
        Path to a VASP POSCAR/CONTCAR file. Direct or Cartesian coordinates,
        an optional "Selective dynamics" line, a missing species line
        (VASP 4 format) and a negative scaling factor (interpreted as the
        cell volume) are handled.
    adsorbate_indices : list[int]
        Indices of adsorbate atoms (0-based).
    surface_indices : list[int]
        Indices of surface atoms (0-based).

    Returns:
    --------
    list[float]
        Adsorption heights (in Å), one per adsorbate atom in input order;
        an empty list if ``adsorbate_indices`` is empty.

    Raises:
    -------
    ValueError
        If adsorbate atoms are given but ``surface_indices`` is empty.
    """
    with open(contcar_path, 'r') as f:
        lines = f.readlines()

    # Lattice vectors are stored one per row (lines 3-5): cart = direct @ lattice.
    scale = float(lines[1].strip())
    lattice_vectors = np.array([list(map(float, lines[i].split())) for i in range(2, 5)])
    if scale < 0:
        # VASP convention: a negative scaling factor is the target cell volume.
        scale = (-scale / abs(np.linalg.det(lattice_vectors))) ** (1.0 / 3.0)
    lattice_vectors *= scale

    # The species line (line 6) is optional. If it already parses as integers,
    # it is the atom-count line (VASP 4). Testing for integers rather than
    # "all letters" also copes with species names such as "Pt_pv".
    counts_idx = 5
    try:
        num_atoms = [int(tok) for tok in lines[counts_idx].split()]
    except ValueError:
        counts_idx += 1
        num_atoms = [int(tok) for tok in lines[counts_idx].split()]
    total_atoms = sum(num_atoms)

    # VASP only inspects the first character of the "Selective dynamics" and
    # coordinate-mode lines; 'C'/'c'/'K'/'k' selects Cartesian, anything else direct.
    coord_start_idx = counts_idx + 1
    if lines[coord_start_idx].strip()[:1].lower() == 's':
        coord_start_idx += 1
    is_cartesian = lines[coord_start_idx].strip()[:1].lower() in ('c', 'k')
    coord_start_idx += 1

    coords = np.array([
        list(map(float, lines[i].split()[:3]))
        for i in range(coord_start_idx, coord_start_idx + total_atoms)
    ], dtype=float).reshape(-1, 3)
    # Cartesian positions are scaled by the same universal factor as the lattice.
    cart_coords = coords * scale if is_cartesian else coords @ lattice_vectors

    # np.asarray(..., dtype=int) so tuples/arrays are treated as index lists,
    # not as multi-dimensional numpy indices.
    ads_idx = np.asarray(adsorbate_indices, dtype=int).reshape(-1)
    surf_idx = np.asarray(surface_indices, dtype=int).reshape(-1)
    if ads_idx.size == 0:
        return []
    if surf_idx.size == 0:
        raise ValueError("surface_indices must contain at least one atom index.")

    surface_z = cart_coords[surf_idx, 2]
    adsorption_heights = []
    for atom in ads_idx:
        gaps = cart_coords[atom, 2] - surface_z
        # Surface atom nearest in z (first one on ties, as np.argmin does).
        adsorption_heights.append(float(gaps[np.argmin(np.abs(gaps))]))
    return adsorption_heights


def get_adsorption_heights_average(
    contcar_path: str,
    adsorbate_indices: List[int],
    surface_indices: List[int]
) -> float:
    """
    Return the mean adsorption height over all listed adsorbate atoms.

    The per-atom heights come from ``get_adsorption_heights_per_atom`` and are
    averaged with ``np.mean``.

    Parameters:
    -----------
    contcar_path : str
        Path to the VASP CONTCAR file.
    adsorbate_indices : list[int]
        Indices of adsorbate atoms (0-based).
    surface_indices : list[int]
        Indices of surface atoms (0-based).

    Returns:
    --------
    float
        Mean adsorption height in Å, or 0.0 when no per-atom heights are
        produced (e.g. an empty ``adsorbate_indices``). In that case 0.0 is
        indistinguishable from a genuine zero height.
    """
    per_atom_heights = get_adsorption_heights_per_atom(
        contcar_path, adsorbate_indices, surface_indices
    )
    if not per_atom_heights:
        # Nothing to average; fall back to 0.0.
        return 0.0
    return np.mean(per_atom_heights)


def get_adsorption_heights_min(
    contcar_path: str,
    adsorbate_indices: List[int],
    surface_indices: List[int]
) -> float:
    """
    Return the smallest (signed) adsorption height among the adsorbate atoms.

    The per-atom heights come from ``get_adsorption_heights_per_atom`` and are
    reduced with ``np.min``.

    Parameters:
    -----------
    contcar_path : str
        Path to the VASP CONTCAR file.
    adsorbate_indices : list[int]
        Indices of adsorbate atoms (0-based).
    surface_indices : list[int]
        Indices of surface atoms (0-based).

    Returns:
    --------
    float
        Minimum adsorption height in Å, or 0.0 when no per-atom heights are
        produced (e.g. an empty ``adsorbate_indices``).
    """
    per_atom_heights = get_adsorption_heights_per_atom(
        contcar_path, adsorbate_indices, surface_indices
    )
    if not per_atom_heights:
        # No adsorbate atoms; fall back to 0.0.
        return 0.0
    return np.min(per_atom_heights)

## Coordinate Transformation Functions

In [None]:
def read_poscar_coordinates(
    poscar_path: str,
    coordinate_type: str = 'auto'
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """
    Read Cartesian coordinates from VASP POSCAR/CONTCAR file.

    Parameters:
    -----------
    poscar_path : str
        Path to POSCAR or CONTCAR file.
    coordinate_type : str
        How to interpret the coordinate block (case-insensitive):
        'direct' or 'd' for direct (fractional) coordinates; 'cartesian'
        (or any value starting with 'c' or 'k') for Cartesian; 'auto' to
        follow the file's own coordinate-mode line. Default: 'auto'

    Returns:
    --------
    tuple
        (cartesian_coordinates, lattice_vectors, element_symbols)
        - cartesian_coordinates: (N x 3) array.
        - lattice_vectors: (3 x 3) array with one scaled lattice vector per
          row, so that cartesian = direct @ lattice_vectors.
        - element_symbols: species names exactly as written in the header, or
          an empty list for VASP 4 files that have no species line.

    Raises:
    -------
    ValueError
        If ``coordinate_type`` is not one of the values listed above.

    Notes:
    ------
    Follows the VASP POSCAR rules. A negative scaling factor is the cell
    volume. Only the first character of the "Selective dynamics" and
    coordinate-mode lines matters, and 'C'/'c'/'K'/'k' selects Cartesian mode.
    """
    mode = coordinate_type.lower()
    if mode not in ('auto', 'direct', 'd') and not mode.startswith(('c', 'k')):
        raise ValueError(
            "coordinate_type must be 'auto', 'direct' or 'cartesian', "
            f"got {coordinate_type!r}."
        )

    with open(poscar_path, 'r') as f:
        lines = f.readlines()

    # Parse lattice scale; lattice vectors are stored one per row (lines 3-5).
    scale = float(lines[1].strip())
    lattice_vectors = np.array([list(map(float, lines[i].split())) for i in range(2, 5)])
    if scale < 0:
        # VASP convention: a negative scaling factor is the target cell volume.
        scale = (-scale / abs(np.linalg.det(lattice_vectors))) ** (1.0 / 3.0)
    lattice_vectors *= scale

    # Line 6 holds species names (VASP 5+) unless it already holds the integer
    # atom counts (VASP 4, no species line). Testing for integers rather than
    # "all letters" also copes with names such as "Pt_pv".
    counts_idx = 5
    try:
        num_atoms = [int(tok) for tok in lines[counts_idx].split()]
        element_symbols: List[str] = []
    except ValueError:
        element_symbols = lines[counts_idx].split()
        counts_idx += 1
        num_atoms = [int(tok) for tok in lines[counts_idx].split()]
    total_atoms = sum(num_atoms)

    # Optional "Selective dynamics" line: VASP only checks the first character.
    coord_start_idx = counts_idx + 1
    if lines[coord_start_idx].strip()[:1].lower() == 's':
        coord_start_idx += 1

    if mode == 'auto':
        # Only the first character of the mode line is significant:
        # 'c'/'k' means Cartesian, anything else means direct. A substring
        # test is wrong here ("Cartesian coordinates" contains a 'd').
        is_direct = lines[coord_start_idx].strip()[:1].lower() not in ('c', 'k')
    else:
        is_direct = mode in ('direct', 'd')
    coord_start_idx += 1

    # Read coordinates (extra columns such as selective-dynamics flags are ignored).
    coordinates = np.array([
        list(map(float, lines[i].split()[:3]))
        for i in range(coord_start_idx, coord_start_idx + total_atoms)
    ], dtype=float).reshape(-1, 3)

    if is_direct:
        cart_coords = coordinates @ lattice_vectors
    else:
        # Cartesian positions are scaled by the same universal factor.
        cart_coords = coordinates * scale

    return cart_coords, lattice_vectors, element_symbols


def cart_to_direct(
    cartesian_coords: np.ndarray,
    lattice_vectors: np.ndarray
) -> np.ndarray:
    """
    Convert Cartesian coordinates to direct (fractional) coordinates.

    Each row of ``lattice_vectors`` is one lattice vector, so
    ``cartesian = direct @ lattice_vectors`` and therefore
    ``direct = cartesian @ inv(lattice_vectors)``.

    Parameters:
    -----------
    cartesian_coords : np.ndarray
        Cartesian coordinates (N x 3).
    lattice_vectors : np.ndarray
        Lattice vectors (3 x 3), one per row.

    Returns:
    --------
    np.ndarray
        Direct coordinates (N x 3).
    """
    inverse_lattice = np.linalg.inv(lattice_vectors)
    return cartesian_coords @ inverse_lattice

## Atomic Index Functions

def compute_distances(
    coords1: np.ndarray,
    coords2: np.ndarray,
    cell: Optional[np.ndarray] = None
) -> np.ndarray:
    """
    Compute pairwise distances between atoms, with optional periodic boundary conditions.

    Parameters:
    -----------
    coords1 : np.ndarray
        First set of Cartesian coordinates (N x 3). A single point of shape
        (3,) is treated as (1 x 3).
    coords2 : np.ndarray
        Second set of Cartesian coordinates (M x 3). A single point of shape
        (3,) is treated as (1 x 3).
    cell : np.ndarray, optional
        Lattice vectors (3 x 3), one lattice vector per ROW, so that
        cartesian = direct @ cell. This is the same layout that
        ``read_poscar_coordinates`` returns. If None, no PBC is applied.

    Returns:
    --------
    np.ndarray
        Distance matrix (N x M). With ``cell`` given, each entry is the
        shortest distance over the periodic images examined (see below).
    """
    coords1 = np.atleast_2d(np.asarray(coords1, dtype=float))
    coords2 = np.atleast_2d(np.asarray(coords2, dtype=float))
    diff = coords1[:, np.newaxis, :] - coords2[np.newaxis, :, :]   # (N, M, 3)

    if cell is None:
        # Simple Euclidean distance
        return np.linalg.norm(diff, axis=2)

    cell = np.asarray(cell, dtype=float)
    # Row convention: cart = frac @ cell, hence frac = cart @ inv(cell).
    # (Using inv(cell).T / cell.T would wrap in the wrong basis for any
    # non-diagonal cell.)
    frac = diff @ np.linalg.inv(cell)
    frac -= np.round(frac)
    wrapped = frac @ cell                                           # (N, M, 3)

    # Rounding fractional components alone does not always give the minimum
    # image in skewed cells (e.g. 60-degree hexagonal surface cells), so also
    # try the 27 neighbouring lattice translations and keep the shortest.
    offsets = np.array(
        np.meshgrid((-1, 0, 1), (-1, 0, 1), (-1, 0, 1), indexing='ij')
    ).reshape(3, -1).T                                              # (27, 3)
    shifts = offsets @ cell                                         # (27, 3)
    images = wrapped[:, :, np.newaxis, :] + shifts                  # (N, M, 27, 3)
    return np.linalg.norm(images, axis=3).min(axis=2)


def get_nearest_neighbors(
    coords: np.ndarray,
    center_idx: int,
    n_neighbors: int = 12,
    cell: Optional[np.ndarray] = None
) -> Tuple[List[int], np.ndarray]:
    """
    Find nearest neighbors of a given atom.

    Parameters:
    -----------
    coords : np.ndarray
        Atomic coordinates (N x 3).
    center_idx : int
        Index of center atom. Negative values count from the end, as in
        normal Python indexing.
    n_neighbors : int
        Maximum number of nearest neighbors to return. If fewer other atoms
        exist, all of them are returned. Default: 12
    cell : np.ndarray, optional
        Lattice vectors for PBC, passed through to ``compute_distances``.
        Default: None

    Returns:
    --------
    tuple
        (neighbor_indices, neighbor_distances), sorted by increasing distance.
        Equal distances are ordered by atom index. The center atom itself is
        never included.

    Raises:
    -------
    IndexError
        If ``center_idx`` is out of range.
    ValueError
        If ``n_neighbors`` is negative.
    """
    coords = np.asarray(coords, dtype=float)
    if n_neighbors < 0:
        raise ValueError("n_neighbors must be non-negative.")
    # range(...)[i] normalises negative indices and raises IndexError when out
    # of range (slicing coords[-1:0] would silently give an empty array).
    center = range(len(coords))[center_idx]

    distances = compute_distances(coords[center:center + 1], coords, cell=cell)[0]

    # Exclude the center by index instead of assuming it sorts first, which
    # fails when another atom sits at distance 0. A stable sort makes the order
    # deterministic for equidistant neighbor shells.
    order = np.argsort(distances, kind='stable')
    neighbor_indices = order[order != center][:n_neighbors]
    neighbor_distances = distances[neighbor_indices]

    return neighbor_indices.tolist(), neighbor_distances

# Example 1: Calculate adsorption heights (per atom, average, and minimum).
# The usage code below sits inside a bare string literal, so it is not run.
# To use it, copy it out and replace the placeholder path and atom indices.
"""
path = "/path/to/CONTCAR"
adsorbate_indices = [120, 121, 122, 123, 124, 125]  # Example indices
surface_indices = [29, 89, 9, 49, 99, 59, 24, 4, 94, 54, 39, 109, 19, 69, 119, 79, 34, 14, 114, 74]

heights = get_adsorption_heights_per_atom(path, adsorbate_indices, surface_indices)
avg_height = get_adsorption_heights_average(path, adsorbate_indices, surface_indices)
min_height = get_adsorption_heights_min(path, adsorbate_indices, surface_indices)

print(f"Per-atom heights: {heights}")
print(f"Average height: {avg_height:.3f} Å")
print(f"Minimum height: {min_height:.3f} Å")
"""

# Example 2: Compute nearest neighbors with periodic boundary conditions.
# The usage code below sits inside a bare string literal, so it is not run.
# To use it, copy it out and replace the placeholder path.
"""
path = "/path/to/CONTCAR"

cart_coords, lattice, elements = read_poscar_coordinates(path)

# Find 12 nearest neighbors of first atom with periodic BC
neighbors, distances = get_nearest_neighbors(
    cart_coords,
    center_idx=0,
    n_neighbors=12,
    cell=lattice
)

print(f"Nearest neighbors: {neighbors}")
print(f"Distances: {distances}")
"""