In [11]:
"""
Differentiable, batched order parameter calculator for torch_sim SimState.

Computes structural order parameters (tetrahedral, octahedral, bond-orientational)
for specified atoms with automatic neighbor finding and PBC handling.
 "torch_nl_linked_cell",
    "torch_nl_n2",
    "torchsim_nl",
Updated for torch_sim v2  with batched neighbor lists.
"""

import torch
import torch.nn as nn
import torch_sim as ts
from torch_sim.neighbors import torch_nl_linked_cell, torch_nl_n2, torchsim_nl
from typing import List, Dict, Optional
import math


class TorchSimOrderParameters(nn.Module):
    """
    Differentiable order parameter calculator for torch_sim SimState.
    
    Automatically handles:
    - Neighbor finding with PBC 
    - Batched computation
    - Differentiability for ML applications
    
    Updated  batched neighbor list API.
    
    Usage:
        >>> op_calc = TorchSimOrderParameters(device='cuda')
        >>> state = ts.initialize_state(structure, device='cuda', dtype=torch.float32)
        >>> 
        >>> # Compute for specific P atoms
        >>> p_indices = torch.where(state.atomic_numbers == 15)[0]
        >>> results = op_calc(state, p_indices, ['tet', 'cn', 'q4'])
        >>> 
        >>> print(f"Tetrahedral OP: {results['tet']}")  # Shape: (n_p_atoms,)
    """
    
    SUPPORTED_TYPES = [
        'cn', 'tet', 'oct', 'bcc',
        'q2', 'q4', 'q6',
        'tri_plan', 'sq_plan', 'tri_pyr'
    ]
    
    def __init__(self, cutoff: float = 3.5, device: str = 'cpu'):
        """
        Args:
            cutoff: Distance cutoff for neighbor finding (Angstroms)
            device: 'cpu' or 'cuda'
        """
        super().__init__()
        self.cutoff = cutoff
        self.device = torch.device(device)
        self.default_params = self._initialize_default_params()
        
    def _initialize_default_params(self) -> Dict[str, Dict[str, float]]:
        """Initialize default parameters for each order parameter."""
        return {
            'tet': {
                'TA': 0.6081734479693927,
                'delta_theta': 12.0,
            },
            'oct': {
                'min_SPP': 2.792526803190927,
                'delta_theta1': 12.0,
                'delta_theta2': 10.0,
                'w_SPP': 3.0,
            },
            'bcc': {
                'min_SPP': 2.792526803190927,
                'delta_theta': 19.47,
                'w_SPP': 6.0,
            },
        }
    
    def forward(
        self,
        state: ts.SimState,
        atom_indices: torch.Tensor,
        order_params: List[str],
        element_filter: Optional[List[int]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute order parameters for specified atoms in a SimState.
        
        Args:
            state: torch_sim SimState containing atomic structure
            atom_indices: (M,) tensor of atom indices to compute OPs for
            order_params: List of order parameter types to compute
            element_filter: Optional list of atomic numbers to include as neighbors
                          (e.g., [15, 16] for P and S only)
        
        Returns:
            Dictionary mapping OP names to (M,) tensors of values
            
        Example:
            >>> # Compute tetrahedral OP for all P atoms
            >>> p_mask = state.atomic_numbers == 15
            >>> p_indices = torch.where(p_mask)[0]
            >>> results = op_calc(state, p_indices, ['tet', 'cn'])
        """
        # Extract and normalize cell matrix for neighbor list
        if state.cell.ndim == 3:
            # Batched systems: shape (n_systems, 3, 3)
            cell_matrix = state.cell
        elif state.cell.ndim == 2:
            if state.cell.shape == (3, 3):
                # Single system: expand to (1, 3, 3)
                cell_matrix = state.cell.unsqueeze(0)
            else:
                raise ValueError(f"Expected cell shape (3, 3), got {state.cell.shape}")
        else:
            raise ValueError(f"Unexpected cell dimensionality: {state.cell.ndim}")
        
        # Ensure PBC has correct shape
        if state.pbc.ndim == 1:
            pbc = state.pbc.unsqueeze(0)  # (3,) -> (1, 3)
        else:
            pbc = state.pbc  # Already (n_systems, 3)
        
        # Get system_idx (required for new API)
        if hasattr(state, 'system_idx') and state.system_idx is not None:
            system_idx = state.system_idx
        else:
            # Single system case: all atoms belong to system 0
            system_idx = torch.zeros(len(state.positions), dtype=torch.long, device=self.device)
        
        # Convert cutoff to tensor
        cutoff_tensor = torch.tensor(self.cutoff, device=self.device, dtype=state.positions.dtype)
        
        # Get neighbor list with new API
        # Returns: (mapping, system_mapping, shifts_idx)
        mapping, system_mapping, shifts_idx = torch_nl_linked_cell(
            positions=state.positions,
            cell=cell_matrix,
            pbc=pbc,
            cutoff=cutoff_tensor,
            system_idx=system_idx,
            self_interaction=False
        )
        
        # Build neighbor list for each target atom
        neighbor_indices, valid_mask = self._build_neighbor_tensor(
            atom_indices, mapping, state.atomic_numbers, element_filter
        )
        
        # Compute positions for neighbors including PBC
        positions_with_pbc = self._apply_pbc_to_neighbors(
            state.positions, neighbor_indices, mapping, shifts_idx, 
            system_mapping, cell_matrix, system_idx, atom_indices
        )
        
        # Compute order parameters
        results = {}
        
        # Get geometric quantities
        vectors, distances, thetas, phis = self._compute_geometry(
            state.positions, atom_indices, neighbor_indices, positions_with_pbc, valid_mask
        )
        
        for op_type in order_params:
            if op_type not in self.SUPPORTED_TYPES:
                raise ValueError(f"Unsupported order parameter: {op_type}")
            
            if op_type == 'cn':
                results[op_type] = self._compute_cn(valid_mask)
            elif op_type == 'tet':
                results[op_type] = self._compute_tetrahedral(vectors, valid_mask)
            elif op_type == 'oct':
                results[op_type] = self._compute_octahedral(vectors, valid_mask)
            elif op_type == 'bcc':
                results[op_type] = self._compute_bcc(vectors, valid_mask)
            elif op_type == 'q2':
                results[op_type] = self._compute_q2(thetas, phis, valid_mask)
            elif op_type == 'q4':
                results[op_type] = self._compute_q4(thetas, phis, valid_mask)
            elif op_type == 'q6':
                results[op_type] = self._compute_q6(thetas, phis, valid_mask)
        
        return results
    
    def _build_neighbor_tensor(
        self,
        atom_indices: torch.Tensor,
        mapping: torch.Tensor,
        atomic_numbers: torch.Tensor,
        element_filter: Optional[List[int]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Build padded neighbor tensor for specified atoms.
        
        Returns:
            neighbor_indices: (M, K) tensor of neighbor indices (padded with -1)
            valid_mask: (M, K) boolean mask of valid neighbors
        """
        M = len(atom_indices)
        
        # Find neighbors for each target atom
        neighbors_list = []
        max_neighbors = 0
        
        for atom_idx in atom_indices:
            # Find all neighbors of this atom
            mask = mapping[0] == atom_idx
            neighs = mapping[1, mask]
            
            # Apply element filter if specified
            if element_filter is not None:
                element_mask = torch.isin(atomic_numbers[neighs], 
                                         torch.tensor(element_filter, device=self.device))
                neighs = neighs[element_mask]
            
            neighbors_list.append(neighs)
            max_neighbors = max(max_neighbors, len(neighs))
        
        # Pad to uniform size
        neighbor_indices = torch.full((M, max_neighbors), -1, 
                                     dtype=torch.long, device=self.device)
        
        for i, neighs in enumerate(neighbors_list):
            neighbor_indices[i, :len(neighs)] = neighs
        
        valid_mask = neighbor_indices >= 0
        
        return neighbor_indices, valid_mask
    
    def _apply_pbc_to_neighbors(
        self,
        positions: torch.Tensor,
        neighbor_indices: torch.Tensor,
        mapping: torch.Tensor,
        shifts_idx: torch.Tensor,
        system_mapping: torch.Tensor,
        cell_matrix: torch.Tensor,
        system_idx: torch.Tensor,
        atom_indices: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply periodic boundary corrections to neighbor positions.
        
        Args:
            positions: (N, 3) atomic positions
            neighbor_indices: (M, K) padded neighbor indices
            mapping: (2, n_neighbors) neighbor pairs from torch_nl
            shifts_idx: (n_neighbors, 3) integer shift indices
            system_mapping: (n_neighbors,) system index for each pair
            cell_matrix: (n_systems, 3, 3) or (1, 3, 3) unit cell
            system_idx: (N,) system index for each atom
            atom_indices: (M,) actual atom indices being computed
        
        Returns:
            positions_with_pbc: (M, K, 3) neighbor positions with PBC applied
        """
        M, K = neighbor_indices.shape
        
        # Initialize with regular positions
        valid_mask = neighbor_indices >= 0
        safe_indices = neighbor_indices.clone()
        safe_indices[~valid_mask] = 0
        
        positions_pbc = positions[safe_indices]  # (M, K, 3)
        
        # Apply shifts if needed
        if shifts_idx.abs().sum() > 0:
            # Convert shift indices to actual Cartesian shifts
            # shifts_idx is (n_neighbors, 3) with integer indices
            
            # For each neighbor pair in mapping, find corresponding shift
            for i in range(M):
                center_atom_idx = atom_indices[i]
                
                for j in range(K):
                    if not valid_mask[i, j]:
                        continue
                    
                    neigh_idx = neighbor_indices[i, j]
                    
                    # Find this pair in the mapping
                    pair_mask = (mapping[0] == center_atom_idx) & (mapping[1] == neigh_idx)
                    
                    if pair_mask.any():
                        # Get the shift indices for this pair
                        pair_shift_idx = shifts_idx[pair_mask][0]  # (3,) integer indices
                        
                        # Get the correct cell matrix for this system
                        if system_idx is not None and len(system_idx) > 0:
                            sys_idx = system_idx[neigh_idx].item()
                        else:
                            sys_idx = 0
                        
                        if cell_matrix.shape[0] > 1:
                            cell_mat = cell_matrix[sys_idx]
                        else:
                            cell_mat = cell_matrix[0]
                        
                        # Convert integer shift indices to Cartesian shift
                        # shift_vector = Σᵢ shift_idx[i] * cell[i]
                        shift_vec = torch.zeros(3, device=positions.device, dtype=positions.dtype)
                        for dim in range(3):
                            shift_vec += pair_shift_idx[dim].float() * cell_mat[dim]
                        
                        positions_pbc[i, j] += shift_vec
        
        return positions_pbc
    
    def _compute_geometry(
        self,
        center_positions: torch.Tensor,
        center_indices: torch.Tensor,
        neighbor_indices: torch.Tensor,
        neighbor_positions_pbc: torch.Tensor,
        valid_mask: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute geometric quantities for order parameters.
        
        Returns:
            vectors: (M, K, 3) normalized direction vectors
            distances: (M, K) distances
            thetas: (M, K) polar angles
            phis: (M, K) azimuthal angles
        """
        M, K = neighbor_indices.shape
        
        # Get center positions
        center_pos = center_positions[center_indices]  # (M, 3)
        
        # Compute vectors with PBC
        vectors = neighbor_positions_pbc - center_pos.unsqueeze(1)  # (M, K, 3)
        distances = torch.norm(vectors, dim=-1)  # (M, K)
        
        # Normalize with numerical stability
        vectors = vectors / (distances.unsqueeze(-1) + 1e-10)
        
        # Compute angles with stability
        cos_theta = vectors[..., 2].clamp(-1.0 + 1e-7, 1.0 - 1e-7)
        thetas = torch.acos(cos_theta)
        phis = torch.atan2(vectors[..., 1], vectors[..., 0])
        
        # Apply mask
        vectors = vectors * valid_mask.unsqueeze(-1).float()
        distances = distances * valid_mask.float()
        thetas = thetas * valid_mask.float()
        phis = phis * valid_mask.float()
        
        return vectors, distances, thetas, phis
    
    def _compute_cn(self, valid_mask: torch.Tensor) -> torch.Tensor:
        """Compute coordination number."""
        return valid_mask.sum(dim=1).float()
    
    def _compute_tetrahedral(
        self,
        vectors: torch.Tensor,
        valid_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute tetrahedral order parameter (differentiable)."""
        params = self.default_params['tet']
        M, K, _ = vectors.shape
        
        # Compute pairwise angles with numerical stability
        dots = torch.einsum('mki,mji->mkj', vectors, vectors)
        dots = dots.clamp(-1.0 + 1e-7, 1.0 - 1e-7)  # Avoid exact -1 or 1
        angles = torch.acos(dots)
        
        # Target angle for tetrahedral
        target_angle = params['TA'] * math.pi
        delta_theta = params['delta_theta'] * math.pi / 180.0
        
        # Gaussian penalty with numerical stability
        angle_diff = angles - target_angle
        # Clamp exponent to prevent overflow
        exponent = -0.5 * (angle_diff / (delta_theta + 1e-10)) ** 2
        exponent = exponent.clamp(min=-50.0)  # Prevent underflow
        gaussian = torch.exp(exponent)
        
        # Mask valid pairs
        pair_mask = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1)
        diag_mask = ~torch.eye(K, device=self.device, dtype=torch.bool).unsqueeze(0)
        full_mask = pair_mask & diag_mask
        
        # Sum and normalize
        qtet = (gaussian * full_mask.float()).sum(dim=(1, 2))
        n_neighbors = valid_mask.sum(dim=1).float().clamp(min=1.0)
        norm = n_neighbors * (n_neighbors - 1).clamp(min=1.0)
        
        return qtet / (norm + 1e-10)
    
    def _compute_octahedral(
        self,
        vectors: torch.Tensor,
        valid_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute octahedral order parameter (differentiable)."""
        params = self.default_params['oct']
        M, K, _ = vectors.shape
        
        # Compute pairwise angles
        dots = torch.einsum('mki,mji->mkj', vectors, vectors)
        dots = dots.clamp(-1.0, 1.0)
        angles = torch.acos(dots)
        
        theta_thr = params['min_SPP']
        delta_theta1 = params['delta_theta1'] * math.pi / 180.0
        delta_theta2 = params['delta_theta2'] * math.pi / 180.0
        
        # South pole (180°)
        is_south = angles >= theta_thr
        angle_diff_180 = angles - math.pi
        gauss_180 = torch.exp(-0.5 * (angle_diff_180 / delta_theta1) ** 2)
        contrib_south = 3 * gauss_180 * is_south.float()
        
        # Equatorial (90°)
        is_equat = angles < theta_thr
        angle_diff_90 = angles - math.pi / 2
        gauss_90 = torch.exp(-0.5 * (angle_diff_90 / delta_theta2) ** 2)
        contrib_equat = gauss_90 * is_equat.float()
        
        # Mask valid pairs
        pair_mask = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1)
        diag_mask = ~torch.eye(K, device=self.device, dtype=torch.bool).unsqueeze(0)
        full_mask = pair_mask & diag_mask
        
        total = ((contrib_south + contrib_equat) * full_mask.float()).sum(dim=(1, 2))
        
        # Normalize
        n_neighbors = valid_mask.sum(dim=1).float()
        norm = n_neighbors * (3 + (n_neighbors - 2) * (n_neighbors - 3))
        
        return total / (norm + 1e-10)
    
    def _compute_bcc(
        self,
        vectors: torch.Tensor,
        valid_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute BCC order parameter (differentiable)."""
        params = self.default_params['bcc']
        M, K, _ = vectors.shape
        
        dots = torch.einsum('mki,mji->mkj', vectors, vectors)
        dots = dots.clamp(-1.0, 1.0)
        angles = torch.acos(dots)
        
        theta_thr = params['min_SPP']
        delta_theta = params['delta_theta'] * math.pi / 180.0
        
        is_south = angles >= theta_thr
        angle_diff_180 = angles - math.pi
        gauss_180 = torch.exp(-0.5 * (angle_diff_180 / delta_theta) ** 2)
        contrib_south = params['w_SPP'] * gauss_180 * is_south.float()
        
        pair_mask = valid_mask.unsqueeze(2) & valid_mask.unsqueeze(1)
        diag_mask = ~torch.eye(K, device=self.device, dtype=torch.bool).unsqueeze(0)
        full_mask = pair_mask & diag_mask
        
        total = (contrib_south * full_mask.float()).sum(dim=(1, 2))
        
        n_neighbors = valid_mask.sum(dim=1).float()
        norm = n_neighbors * (6 + (n_neighbors - 2) * (n_neighbors - 3))
        
        return total / (norm + 1e-10)
    
    def _compute_q2(
        self,
        thetas: torch.Tensor,
        phis: torch.Tensor,
        valid_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute q2 bond orientational order parameter (differentiable)."""
        n_neighbors = valid_mask.sum(dim=1, keepdim=True).float()
        
        sin_t = torch.sin(thetas)
        cos_t = torch.cos(thetas)
        sin_t2 = sin_t ** 2
        cos_t2 = cos_t ** 2
        
        cos_2p = torch.cos(2 * phis)
        sin_2p = torch.sin(2 * phis)
        cos_p = torch.cos(phis)
        sin_p = torch.sin(phis)
        
        sqrt_15_2pi = math.sqrt(15 / (2 * math.pi))
        sqrt_5_pi = math.sqrt(5 / math.pi)
        
        acc = torch.zeros(thetas.shape[0], device=self.device)
        
        # Y_2_-2
        pre_y_2_2 = 0.25 * sqrt_15_2pi * sin_t2
        real = (pre_y_2_2 * cos_2p * valid_mask.float()).sum(dim=1)
        imag = -(pre_y_2_2 * sin_2p * valid_mask.float()).sum(dim=1)
        acc += real**2 + imag**2
        
        # Y_2_-1
        pre_y_2_1 = 0.5 * sqrt_15_2pi * sin_t * cos_t
        real = (pre_y_2_1 * cos_p * valid_mask.float()).sum(dim=1)
        imag = -(pre_y_2_1 * sin_p * valid_mask.float()).sum(dim=1)
        acc += real**2 + imag**2
        
        # Y_2_0
        real = (0.25 * sqrt_5_pi * (3 * cos_t2 - 1.0) * valid_mask.float()).sum(dim=1)
        acc += real**2
        
        # Y_2_1
        real = -(pre_y_2_1 * cos_p * valid_mask.float()).sum(dim=1)
        imag = -(pre_y_2_1 * sin_p * valid_mask.float()).sum(dim=1)
        acc += real**2 + imag**2
        
        # Y_2_2
        real = (pre_y_2_2 * cos_2p * valid_mask.float()).sum(dim=1)
        imag = (pre_y_2_2 * sin_2p * valid_mask.float()).sum(dim=1)
        acc += real**2 + imag**2
        
        return torch.sqrt(4 * math.pi * acc / (5 * n_neighbors.squeeze()**2 + 1e-10))
    
    def _compute_q4(
        self,
        thetas: torch.Tensor,
        phis: torch.Tensor,
        valid_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute q4 bond orientational order parameter (differentiable)."""
        n_neighbors = valid_mask.sum(dim=1, keepdim=True).float()
        
        sin_t = torch.sin(thetas)
        cos_t = torch.cos(thetas)
        
        acc = torch.zeros(thetas.shape[0], device=self.device)
        
        # Simplified q4 calculation - full version would include all m=-4 to 4 terms
        # Y_4_0 term
        sqrt_1_pi = math.sqrt(1 / math.pi)
        pre = (3/16.0) * sqrt_1_pi * (35 * cos_t**4 - 30 * cos_t**2 + 3.0)
        real = (pre * valid_mask.float()).sum(dim=1)
        acc += real**2
        
        return torch.sqrt(4 * math.pi * acc / (9 * n_neighbors.squeeze()**2 + 1e-10))
    
    def _compute_q6(
        self,
        thetas: torch.Tensor,
        phis: torch.Tensor,
        valid_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute q6 bond orientational order parameter (differentiable)."""
        n_neighbors = valid_mask.sum(dim=1, keepdim=True).float()
        
        cos_t = torch.cos(thetas)
        
        acc = torch.zeros(thetas.shape[0], device=self.device)
        
        # Y_6_0 term
        sqrt_13_pi = math.sqrt(13 / math.pi)
        pre = (1/32.0) * sqrt_13_pi * (
            231 * cos_t**6 - 315 * cos_t**4 + 105 * cos_t**2 - 5.0
        )
        real = (pre * valid_mask.float()).sum(dim=1)
        acc += real**2
        
        return torch.sqrt(4 * math.pi * acc / (13 * n_neighbors.squeeze()**2 + 1e-10))


# Example usage
if __name__ == "__main__":
    from pymatgen.core import Structure, Lattice
    
    # Setup
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")
    


    # Load the CIF file
    structure = Structure.from_file("/home/advaitgore/PycharmProjects/torchdisorder/testing/Li3PS4.cif")

    
    # Convert to SimState with float64 for better gradient stability
    state = ts.initialize_state(structure, device=device, dtype=torch.float64)
    
    # Initialize calculator with optimal cutoff
    # Note: P-S distance in this structure is ~5 Å, cutoff=5.5 captures 4 nearest S neighbors
    # Using cutoff=6.0+ will include additional periodic images (CN=6+)
    op_calc = TorchSimOrderParameters(cutoff=3.5, device=device)
    
    # Find P atoms
    p_mask = state.atomic_numbers == 15
    p_indices = torch.where(p_mask)[0]
    
    print(f"\nFound {len(p_indices)} P atoms at indices: {p_indices.tolist()}")
    
    # Compute order parameters (filter for S neighbors only for tetrahedral PS4)
    results = op_calc(
        state,
        p_indices,
        order_params=['cn', 'tet', 'oct', 'q4'],
        element_filter=[16]  # Only S neighbors (not P)
    )
    
    print("\nOrder Parameters:")
    for key, value in results.items():
        if value.numel() == 1:
            print(f"  {key}: {value.item():.4f}")
        else:
            print(f"  {key}: {value.tolist()}")
    
    # Test differentiability
    state.positions.requires_grad = True
    results = op_calc(state, p_indices, ['tet'], element_filter=[16])
    loss = results['tet'].sum()
    loss.backward()
    
    print(f"\nDifferentiability test:")
    print(f"  Gradient exists: {state.positions.grad is not None}")
    print(f"  Gradient shape: {state.positions.grad.shape}")
    if state.positions.grad is not None:
        grad_norm = state.positions.grad.norm()
        has_nan = torch.isnan(state.positions.grad).any()
        print(f"  Gradient norm: {grad_norm.item():.6f}")
        print(f"  Contains NaN: {has_nan.item()}")
        if has_nan:
            print("  Warning: NaN in gradients - check numerical stability!")
    
    print("\n" + "=" * 60)
    print("Expected Results for tetrahedral PS4:")
    print("  CN  = 4 (four S neighbors)")
    print("  tet ~ 0.9+ (high tetrahedral order)")
    print("  If CN=6, cutoff is too large (catching periodic images)")
    print("  If grad=NaN, use float64 dtype or check angle computations")
    print("=" * 60)


Using device: cuda

Found 4 P atoms at indices: [12, 13, 14, 15]

Order Parameters:
  cn: [4.0, 4.0, 4.0, 4.0]
  tet: [0.9489298978999229, 0.9489298978999231, 0.948929897899923, 0.9489298978999227]
  oct: [0.10837966861637124, 0.10837966861637131, 0.10837966861637124, 0.10837966861637141]
  q4: [0.13109755516052246, 0.13109755516052246, 0.13109755516052246, 0.13109755516052246]

Differentiability test:
  Gradient exists: True
  Gradient shape: torch.Size([32, 3])
  Gradient norm: 1.096701
  Contains NaN: False

Expected Results for tetrahedral PS4:
  CN  = 4 (four S neighbors)
  tet ~ 0.9+ (high tetrahedral order)
  If CN=6, cutoff is too large (catching periodic images)
  If grad=NaN, use float64 dtype or check angle computations


In [12]:
"""
Cooper-based Constrained Minimization Problem with Per-Atom Order Parameter Constraints
Uses TorchSimOrderParameters and JSON constraint files

"""

from __future__ import annotations
import json
import torch
import cooper
from pathlib import Path
from typing import Dict, List, Optional, Union
from collections import defaultdict
from ase.data import chemical_symbols

import torch_sim as ts
from cooper.penalty_coefficients import PenaltyCoefficient

# Import your existing modules
from torchdisorder.common.target_rdf import TargetRDFData

# Import order parameter calculator for constraints
from torchdisorder.engine.order_params import TorchSimOrderParameters


class ConstantPenalty(PenaltyCoefficient):
    """Constant penalty coefficient for AugmentedLagrangian."""

    def __init__(self, value: float, device='cuda'):
        super().__init__(init=torch.tensor(value, device=device))
        self.expects_constraint_features = False

    def __call__(self, constraint_features=None):
        """Return the constant penalty value."""
        return self.init


class StructureFactorCMPWithConstraints(cooper.ConstrainedMinimizationProblem):
    """
    Constrained Minimization Problem for RDF/Structure Factor matching
    with per-atom order parameter constraints.
    
    Uses Cooper library for augmented Lagrangian optimization with:
    - RDF/Structure factor matching objective
    - Per-atom order parameter constraints from JSON file
    - Environment-specific constraint targets
    
    Updated for torch_sim   with batched neighbor lists.
    
    Usage:
        >>> cmp = StructureFactorCMPWithConstraints(
        ...     model=xrd_model,
        ...     base_state=state,
        ...     target_vec=rdf_data,
        ...     constraints_file='my_structure_constraints.json',
        ...     loss_fn=loss_fn,
        ...     device='cuda'
        ... )
        >>> 
        >>> # In optimization loop
        >>> cmp_state = cmp.compute_cmp_state(state.positions, state.cell)
        >>> # Cooper handles the rest
        
    Note:
        The order parameter calculator automatically handles the new torch_sim 
        neighbor list API internally. No changes needed in your optimization loop.
    """
    
    def __init__(
        self,
        model,
        base_state: ts.SimState,
        target_vec: TargetRDFData,
        constraints_file: Union[str, Path],
        loss_fn,
        target_kind: str = 'S_Q',
        q_bins: Optional[torch.Tensor] = None,
        device: str = 'cuda',
        penalty_rho: float = 40.0,
        violation_type: str = 'soft'  # 'soft' or 'hard'
    ):
        super().__init__()
        
        self.model = model
        self.base_state = base_state
        self.target = target_vec.F_q_target
        self.target_uncert = target_vec.F_q_uncert
        self.kind = target_kind
        self.q_bins = q_bins
        self.loss_fn = loss_fn
        self.device = device
        self.violation_type = violation_type
        
        # Load constraints from JSON
        self.constraints_data = self._load_constraints(constraints_file)
        
        # Initialize order parameter calculator
        cutoff = self.constraints_data['cutoff']
        self.op_calc = TorchSimOrderParameters(cutoff=cutoff, device=str(device))
        
        # Detect placeholder region (where dF was originally zero)
        self.placeholder_mask = (self.target_uncert < 1e-6).to(device)
        n_placeholder = self.placeholder_mask.sum().item()
        print(f"Found {n_placeholder} placeholder points where dF = 1e-7")
        
        # Set up constraints based on JSON file
        self._setup_constraints(penalty_rho)
        
        print(f"\nInitialized StructureFactorCMPWithConstraints:")
        print(f"  Constrained atoms: {len(self.constraints_data['atom_constraints'])}")
        print(f"  Order parameters: {', '.join(self.constraints_data['metadata']['order_parameter_types'])}")
        print(f"  Total Cooper constraints: {len(self.constraint_dict)}")
    
    def _load_constraints(self, constraints_file: Union[str, Path]) -> Dict:
        """Load constraint specifications from JSON file."""
        with open(constraints_file, 'r') as f:
            constraints = json.load(f)
        return constraints
    
    def _setup_constraints(self, penalty_rho: float):
        """
        Set up Cooper constraint objects for all atom-constraint pairs.
        
        Strategy: Group constraints by order parameter type to reduce
        the number of Cooper constraint objects.
        """
        # Group atoms by order parameter type
        op_to_atoms = defaultdict(list)
        
        for atom_idx_str, atom_constraint in self.constraints_data['atom_constraints'].items():
            atom_idx = int(atom_idx_str)
            for op_name in atom_constraint['order_parameters'].keys():
                op_to_atoms[op_name].append(atom_idx)
        
        # Create one Cooper constraint per order parameter type
        self.constraint_dict = {}
        
        for op_name, atom_list in op_to_atoms.items():
            n_atoms = len(atom_list)
            
            # Create multiplier for all atoms with this order parameter
            multiplier = cooper.multipliers.DenseMultiplier(
                num_constraints=n_atoms,
                device=self.device
            )
            
            # Create penalty coefficient
            penalty_coeff = ConstantPenalty(penalty_rho, device=self.device)
            
            # Create constraint
            constraint = cooper.Constraint(
                multiplier=multiplier,
                constraint_type=cooper.ConstraintType.INEQUALITY,
                formulation_type=cooper.formulations.AugmentedLagrangian,
                penalty_coefficient=penalty_coeff,
            )
            
            self.constraint_dict[op_name] = {
                'constraint': constraint,
                'atom_indices': atom_list
            }
            
            print(f"  Created constraint for {op_name}: {n_atoms} atoms")
    
    def _compute_violations(
        self,
        op_results: Dict[str, torch.Tensor],
        constrained_atom_indices: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Compute constraint violations for all atoms.
        
        Returns:
            violations: {op_name: tensor of violations per atom}
        """
        violations = {}
        
        for op_name, constraint_info in self.constraint_dict.items():
            atom_list = constraint_info['atom_indices']
            op_violations = []
            
            for atom_idx in atom_list:
                atom_idx_str = str(atom_idx)
                atom_constraint = self.constraints_data['atom_constraints'][atom_idx_str]
                op_params = atom_constraint['order_parameters'][op_name]
                
                # Find position in constrained_atom_indices tensor
                pos = (constrained_atom_indices == atom_idx).nonzero(as_tuple=True)[0]
                if len(pos) == 0:
                    # Atom not in computed results (shouldn't happen)
                    op_violations.append(0.0)
                    continue
                
                pos = pos[0].item()
                value = op_results[op_name][pos].item()
                
                # Compute violation based on constraint type
                if 'min' in op_params and 'max' in op_params:
                    # Box constraint: violation if outside [min, max]
                    if self.violation_type == 'soft':
                        # Soft constraint: smooth penalty outside bounds
                        if value < op_params['min']:
                            violation = op_params['min'] - value
                        elif value > op_params['max']:
                            violation = value - op_params['max']
                        else:
                            violation = 0.0
                    else:
                        # Hard constraint: violation = distance to nearest bound
                        mid = (op_params['min'] + op_params['max']) / 2
                        if value < mid:
                            violation = max(0, op_params['min'] - value)
                        else:
                            violation = max(0, value - op_params['max'])
                else:
                    # Equality constraint with tolerance
                    target = op_params['target']
                    tolerance = op_params.get('tolerance', 0.1)
                    violation = max(0.0, abs(value - target) - tolerance)
                
                # Apply weight
                weight = op_params.get('weight', 1.0)
                op_violations.append(weight * violation)
            
            # Convert to tensor
            violations[op_name] = torch.tensor(op_violations, device=self.device, dtype=torch.float32)
        
        return violations
    
    def compute_cmp_state(
        self,
        positions: torch.Tensor,
        cell: torch.Tensor,
        step: Optional[int] = None
    ) -> cooper.CMPState:
        """
        Compute the constrained minimization problem state.
        
        Args:
            positions: Atomic positions
            cell: Unit cell
            step: Current optimization step (for logging)
        
        Returns:
            Cooper CMPState with loss and constraint violations
        """
        # Update base state
        self.base_state.positions = positions
        self.base_state.cell = cell
        
        # Forward pass through model
        out = self.model(self.base_state)
        
        # Get predicted S(Q) or T(r)
        pred_sq = out[self.kind].squeeze()
        
        # Clamp prediction to experimental value in placeholder region
        pred_sq_clamped = pred_sq.clone()
        pred_sq_clamped[self.placeholder_mask] = self.target[self.placeholder_mask]
        
        # Replace in output dict
        out[self.kind] = pred_sq_clamped.unsqueeze(0) if pred_sq_clamped.dim() == 1 else pred_sq_clamped
        
        # Compute chi-squared loss
        loss_dict = self.loss_fn(out)
        chi2_loss = loss_dict.get("chi2_scatt", loss_dict.get("chi2_corr", loss_dict.get("loss")))
        
        # ====================================================================
        # Compute order parameter constraints
        # ====================================================================
        
        # Get element filter and atom indices
        element_filter = self.constraints_data['element_filter']
        atom_mask = torch.zeros(len(self.base_state.atomic_numbers), dtype=torch.bool, device=self.device)
        for Z in element_filter:
            atom_mask |= (self.base_state.atomic_numbers == Z)
        
        constrained_atom_indices = torch.where(atom_mask)[0]
        
        # Determine which order parameters to compute
        all_op_types = set()
        for atom_constraint in self.constraints_data['atom_constraints'].values():
            all_op_types.update(atom_constraint['order_parameters'].keys())
        
        # Compute all needed order parameters at once
        op_results = self.op_calc(
            self.base_state,
            constrained_atom_indices,
            list(all_op_types),
            element_filter=element_filter
        )
        
        # Compute violations for each constraint group
        violations = self._compute_violations(op_results, constrained_atom_indices)
        
        # ====================================================================
        # Create Cooper constraint states
        # ====================================================================
        
        observed_constraints = {}
        
        for op_name, violation_tensor in violations.items():
            constraint = self.constraint_dict[op_name]['constraint']
            constraint_state = cooper.ConstraintState(violation=violation_tensor)
            observed_constraints[constraint] = constraint_state
        
        # ====================================================================
        # Create misc dict with diagnostics
        # ====================================================================
        
        # Compute violation statistics
        all_violations = torch.cat([v for v in violations.values()])
        n_violations = (all_violations > 1e-4).sum().item()
        avg_violation = all_violations.mean().item()
        max_violation = all_violations.max().item()
        
        misc = dict(
            Q=self.q_bins,
            Y=pred_sq_clamped,
            loss=chi2_loss,
            chi2_loss=chi2_loss,
            kind=self.kind,
            n_violations=n_violations,
            avg_violation=avg_violation,
            max_violation=max_violation,
            op_results={k: v.detach() for k, v in op_results.items()},
            step=step
        )
        
        # Add per-environment statistics if step is provided
        if step is not None and step % 50 == 0:
            env_stats = self._get_environment_statistics(op_results, constrained_atom_indices)
            misc['environment_stats'] = env_stats
        
        return cooper.CMPState(
            loss=chi2_loss,
            observed_constraints=observed_constraints,
            misc=misc
        )
    
    def _get_environment_statistics(
        self,
        op_results: Dict[str, torch.Tensor],
        constrained_atom_indices: torch.Tensor
    ) -> Dict:
        """
        Get detailed statistics about constraint satisfaction by environment type.
        """
        stats_by_env = defaultdict(lambda: {
            'total': 0,
            'satisfied': 0,
            'violations': []
        })
        
        for atom_idx in constrained_atom_indices.tolist():
            atom_idx_str = str(atom_idx)
            if atom_idx_str not in self.constraints_data['atom_constraints']:
                continue
            
            atom_constraint = self.constraints_data['atom_constraints'][atom_idx_str]
            env_type = atom_constraint.get('environment_type', 'unknown')
            
            stats_by_env[env_type]['total'] += 1
            
            # Check if all constraints satisfied for this atom
            pos = (constrained_atom_indices == atom_idx).nonzero(as_tuple=True)[0][0].item()
            
            has_violation = False
            for op_name, op_params in atom_constraint['order_parameters'].items():
                if op_name not in op_results:
                    continue
                
                value = op_results[op_name][pos].item()
                
                # Check violation
                if 'min' in op_params and 'max' in op_params:
                    if value < op_params['min'] or value > op_params['max']:
                        has_violation = True
                        violation = min(op_params['min'] - value, value - op_params['max'])
                        stats_by_env[env_type]['violations'].append(abs(violation))
                else:
                    target = op_params['target']
                    tolerance = op_params.get('tolerance', 0.1)
                    if abs(value - target) > tolerance:
                        has_violation = True
                        stats_by_env[env_type]['violations'].append(abs(value - target))
            
            if not has_violation:
                stats_by_env[env_type]['satisfied'] += 1
        
        # Compute satisfaction rates
        results = {}
        for env_type, stats in stats_by_env.items():
            satisfied_pct = 100.0 * stats['satisfied'] / stats['total'] if stats['total'] > 0 else 0
            avg_viol = sum(stats['violations']) / len(stats['violations']) if stats['violations'] else 0
            
            results[env_type] = {
                'total_atoms': stats['total'],
                'satisfied_atoms': stats['satisfied'],
                'satisfaction_rate': satisfied_pct,
                'avg_violation': avg_viol,
                'max_violation': max(stats['violations']) if stats['violations'] else 0
            }
        
        return results


# ============================================================================
# Backward-compatible wrapper 
# ============================================================================

def create_cmp_with_constraints(
    model,
    base_state: ts.SimState,
    target_vec: TargetRDFData,
    constraints_file: str,
    loss_fn,
    target_kind: str = 'S_Q',
    device: str = 'cuda',
    penalty_rho: float = 40.0,
    **kwargs
) -> StructureFactorCMPWithConstraints:
    """
    Factory function to create CMP with constraints.
    
    Drop-in replacement for your old StructureFactorCMP initialization.
    
    Example:
        >>> # OLD CODE:
        >>> # cmp = StructureFactorCMP(
        >>> #     model, base_state, target_vec, 'S_Q', q_bins, loss_fn,
        >>> #     q_threshold=0.7, device='cuda', penalty_rho=40.0
        >>> # )
        >>> 
        >>> # NEW CODE:
        >>> cmp = create_cmp_with_constraints(
        ...     model, base_state, target_vec,
        ...     constraints_file='my_structure_constraints.json',
        ...     loss_fn=loss_fn,
        ...     target_kind='S_Q',
        ...     device='cuda',
        ...     penalty_rho=40.0
        ... )
    """
    return StructureFactorCMPWithConstraints(
        model=model,
        base_state=base_state,
        target_vec=target_vec,
        constraints_file=constraints_file,
        loss_fn=loss_fn,
        target_kind=target_kind,
        device=device,
        penalty_rho=penalty_rho,
        **kwargs
    )


# ============================================================================
# Example usage
# ============================================================================

def example_usage():
    """Example showing how to use with Cooper."""
    
    print("="*70)
    print("EXAMPLE: Cooper CMP with Per-Atom Constraints")
    print("="*70)
    
    # Your existing setup (unchanged)
    import torch_sim as ts
    from pymatgen.core import Structure
    
    structure = Structure.from_file('/home/advaitgore/PycharmProjects/torchdisorder/testing/Li3PS4.cif')
    state = ts.initialize_state(structure, device='cuda', dtype=torch.float64)
    
    # Your model and RDF data (unchanged)
    # model = YourXRDModel(...)
    # rdf_data = TargetRDFData(...)
    # loss_fn = your_loss_function
    
    # ===================================================================
    # CHANGE: Use new CMP with constraints from JSON
    # ===================================================================
    
    # OLD CODE (commented out):
    # cmp = StructureFactorCMP(
    #     model, state, rdf_data, 'S_Q', q_bins, loss_fn,
    #     q_threshold=0.7, device='cuda', penalty_rho=40.0
    # )
    
    # NEW CODE:
    cmp = StructureFactorCMPWithConstraints(
        model=model,
        base_state=state,
        target_vec=rdf_data,
        constraints_file='glass_structure_constraints.json',  # Generated by glass_generator.py
        loss_fn=loss_fn,
        target_kind='S_Q',
        device='cuda',
        penalty_rho=40.0
    )
    
    # Rest of Cooper setup (unchanged)
    formulation = cooper.formulations.AugmentedLagrangian()
    constrained_optimizer = cooper.optim.ConstrainedOptimizer(
        optimizer_class=torch.optim.Adam,
        params=[state.positions],
        lr=0.01,
        constrained_optimizer_kwargs={'primal_optimizer_kwargs': {'lr': 0.01}}
    )
    
    # Optimization loop (unchanged except no manual q_tet)
    for iteration in range(1000):
        # Compute CMP state (order parameters computed automatically!)
        cmp_state = cmp.compute_cmp_state(
            state.positions,
            state.cell,
            step=iteration
        )
        
        # Cooper optimization step
        constrained_optimizer.zero_grad()
        
        lagrangian = formulation.composite_objective(
            cmp_state.loss,
            cmp_state.observed_constraints
        )
        
        formulation.custom_backward(lagrangian)
        constrained_optimizer.step()
        
        # Logging
        if iteration % 10 == 0:
            misc = cmp_state.misc
            print(f"\nIter {iteration}:")
            print(f"  Loss: {cmp_state.loss.item():.6f}")
            print(f"  Violations: {misc['n_violations']}")
            print(f"  Avg: {misc['avg_violation']:.4f}")
            print(f"  Max: {misc['max_violation']:.4f}")
        
        # Detailed environment statistics
        if iteration % 50 == 0 and 'environment_stats' in cmp_state.misc:
            print("\n  Environment satisfaction:")
            for env, stats in cmp_state.misc['environment_stats'].items():
                print(f"    {env}: {stats['satisfaction_rate']:.1f}%")
    
    print("\nOptimization complete!")


if __name__ == "__main__":
    example_usage()


EXAMPLE: Cooper CMP with Per-Atom Constraints


NameError: name 'model' is not defined