In [3]:
import jax
print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")
print(f"JAX default backend: {jax.default_backend()}")

JAX version: 0.8.0


ERROR:2025-11-01 13:45:27,189:jax._src.xla_bridge:473: Jax plugin configuration error: Exception when calling jax_plugins.xla_cuda12.initialize()
Traceback (most recent call last):
  File "/scicore/home/meuwly/boitti0000/mmml/.venv/lib/python3.12/site-packages/jax/_src/xla_bridge.py", line 471, in discover_pjrt_plugins
    plugin_module.initialize()
  File "/scicore/home/meuwly/boitti0000/mmml/.venv/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py", line 328, in initialize
    _check_cuda_versions(raise_on_first_error=True)
  File "/scicore/home/meuwly/boitti0000/mmml/.venv/lib/python3.12/site-packages/jax_plugins/xla_cuda12/__init__.py", line 285, in _check_cuda_versions
    local_device_count = cuda_versions.cuda_device_count()
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: jaxlib/cuda/versions_helpers.cc:113: operation cuInit(0) failed: Unknown CUDA error 303; cuGetErrorName failed. This probably means that JAX was unable to load the CUDA lib

JAX devices: [CpuDevice(id=0)]
JAX default backend: cpu


In [4]:
import e3x

In [5]:
# Utilities to compute RDF (pair correlation), bond-angle distributions,
# and a simple three-body distribution surface from ASE trajectories.
#
# How to use after running this cell:
#   frames = load_frames("your_file.traj")           # or .xyz, .pdb, etc.
#   rdf = compute_rdf(frames, r_max=10.0, bin_width=0.05)
#   plot_rdf(rdf)
#   ang = compute_bond_angle_distribution(frames, r_cut=3.5)
#   plot_angles(ang)
#   tbd = compute_three_body_surface(frames, r_cut=3.5)
#   plot_three_body_projections(tbd, r_slice=2.8)   # or theta_slice=...
#   write_traj(frames, "output.traj")               # save frames to ASE .traj

from typing import Iterable, List, Optional, Tuple, Dict, Union
import numpy as np
from math import pi
import matplotlib.pyplot as plt
from dataclasses import dataclass
from ase import Atoms
from ase.io import read, Trajectory
from ase.neighborlist import neighbor_list

FrameSource = Union[str, Atoms, List[Atoms], Trajectory]

def load_frames(source: FrameSource) -> List[Atoms]:
    if isinstance(source, Atoms):
        return [source.copy()]
    if isinstance(source, list):
        return [f.copy() for f in source]
    if isinstance(source, Trajectory):
        return [f.copy() for f in source]
    if isinstance(source, str):
        return read(source, index=":")
    raise TypeError("Unsupported frame source type")

def write_traj(source: FrameSource, out_path: str) -> None:
    frames = load_frames(source)
    with Trajectory(out_path, "w") as traj:
        for f in frames:
            traj.write(f)

def _number_density(frame: Atoms) -> float:
    vol = frame.get_volume()
    n = len(frame)
    if vol is None or vol <= 0:
        raise ValueError("Frame has non-positive volume; set a valid periodic cell for normalization.")
    return n / vol

@dataclass
class RDFResult:
    r_edges: np.ndarray
    r_centers: np.ndarray
    g_r: np.ndarray
    counts: np.ndarray
    frames_used: int
    rho_mean: float
    n_atoms_mean: float

def compute_rdf(source: FrameSource,
                r_max: float,
                bin_width: float = 0.05,
                element_pairs: Optional[List[Tuple[str, str]]] = None
                ) -> Union[RDFResult, Dict[Tuple[str, str], RDFResult]]:
    frames = load_frames(source)
    if len(frames) == 0:
        raise ValueError("No frames to process.")
    r_edges = np.arange(0.0, r_max + bin_width, bin_width)
    r_centers = 0.5 * (r_edges[:-1] + r_edges[1:])
    shell = 4.0 * pi * (r_centers ** 2) * bin_width

    if element_pairs is None:
        accum_counts = np.zeros_like(r_centers, dtype=float)
        accum_norm = 0.0
        frames_used = 0
        for frame in frames:
            rho = _number_density(frame)
            N = len(frame)
            i, j, d = neighbor_list("ijd", frame, r_max, self_interaction=False, bothways=True)
            counts, _ = np.histogram(d[d > 0], bins=r_edges)
            accum_counts += counts
            accum_norm += N * rho
            frames_used += 1
        g_r = accum_counts / (accum_norm * shell + 1e-30)
        return RDFResult(r_edges, r_centers, g_r, accum_counts, frames_used,
                         accum_norm / max(frames_used, 1),
                         sum(len(f) for f in frames) / max(frames_used, 1))

    results: Dict[Tuple[str, str], RDFResult] = {}
    symbols_all = [f.get_chemical_symbols() for f in frames]
    for (A, B) in element_pairs:
        accum_counts = np.zeros_like(r_centers, dtype=float)
        accum_norm = 0.0
        frames_used = 0
        for frame, syms in zip(frames, symbols_all):
            vol = frame.get_volume()
            nB = syms.count(B)
            if nB == 0 or vol <= 0:
                continue
            rho_B = nB / vol
            mask_A = np.array([s == A for s in syms])
            mask_B = np.array([s == B for s in syms])
            if not (mask_A.any() and mask_B.any()):
                continue
            i, j, d = neighbor_list("ijd", frame, r_max, self_interaction=False, bothways=True)
            sel = mask_A[i] & mask_B[j]
            d_sel = d[sel]
            counts, _ = np.histogram(d_sel[d_sel > 0], bins=r_edges)
            accum_counts += counts
            accum_norm += mask_A.sum() * rho_B
            frames_used += 1
        g_r = accum_counts / (accum_norm * shell + 1e-30)
        results[(A, B)] = RDFResult(r_edges, r_centers, g_r, accum_counts, frames_used,
                                    None if frames_used == 0 else accum_norm / frames_used,
                                    None)
    return results

def plot_rdf(rdf: Union[RDFResult, Dict[Tuple[str, str], RDFResult]], title: Optional[str] = None) -> None:
    plt.figure()
    if isinstance(rdf, RDFResult):
        plt.plot(rdf.r_centers, rdf.g_r)
        plt.xlabel("r (Å)"); plt.ylabel("g(r)")
        plt.title(title or "Radial Distribution Function")
        plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
        plt.show(); return
    for k, res in rdf.items():
        label = f"g_{k[0]}{k[1]}(r)"
        plt.plot(res.r_centers, res.g_r, label=label)
    plt.xlabel("r (Å)"); plt.ylabel("g(r)")
    plt.title(title or "Partial Radial Distribution Functions")
    plt.legend()
    plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
    plt.show()

@dataclass
class AngleDistribution:
    theta_edges: np.ndarray
    theta_centers: np.ndarray
    histogram: np.ndarray
    frames_used: int
    note: str = "Angles at central atom i among neighbors j,k within r_cut"

def compute_bond_angle_distribution(source: FrameSource,
                                    r_cut: float,
                                    n_theta_bins: int = 180,
                                    center_element: Optional[str] = None,
                                    neighbor_element: Optional[str] = None
                                    ) -> AngleDistribution:
    frames = load_frames(source)
    theta_edges = np.linspace(0.0, pi, n_theta_bins + 1)
    theta_centers = 0.5 * (theta_edges[:-1] + theta_edges[1:])
    hist = np.zeros(n_theta_bins, dtype=float)
    frames_used = 0
    for frame in frames:
        syms = frame.get_chemical_symbols()
        mask_center = np.array([True] * len(frame)) if center_element is None else np.array([s == center_element for s in syms])
        mask_nei = np.array([True] * len(frame)) if neighbor_element is None else np.array([s == neighbor_element for s in syms])
        i_idx, j_idx, S = neighbor_list("ijS", frame, r_cut, self_interaction=False, bothways=True)
        positions = frame.get_positions(); cell = frame.get_cell()
        disp = positions[j_idx] + np.dot(S, cell) - positions[i_idx]
        from collections import defaultdict
        per_i = defaultdict(list)
        for i, vec, j in zip(i_idx, disp, j_idx):
            if mask_center[i] and mask_nei[j]:
                per_i[i].append(vec)
        for vecs in per_i.values():
            if len(vecs) < 2: continue
            V = np.array(vecs)
            norms = np.linalg.norm(V, axis=1)
            keep = norms > 1e-12
            V = V[keep]
            if V.shape[0] < 2: continue
            U = V / norms[keep][:, None]
            m = U.shape[0]
            dots = np.clip(U @ U.T, -1.0, 1.0)
            iu = np.triu_indices(m, k=1)
            thetas = np.arccos(dots[iu])
            h, _ = np.histogram(thetas, bins=theta_edges)
            hist += h
        frames_used += 1
    return AngleDistribution(theta_edges, theta_centers, hist, frames_used)

def plot_angles(ang: AngleDistribution, title: Optional[str] = None, density: bool = True) -> None:
    plt.figure()
    y = ang.histogram.astype(float)
    if density and y.sum() > 0:
        y = y / (y.sum() * (ang.theta_edges[1] - ang.theta_edges[0]))
        ylabel = "P(θ)"
    else:
        ylabel = "Counts"
    plt.step(ang.theta_centers, y, where="mid")
    plt.xlabel("θ (rad)"); plt.ylabel(ylabel)
    plt.title(title or "Bond-Angle Distribution")
    plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
    plt.show()

@dataclass
class ThreeBodySurface:
    r_edges: np.ndarray
    theta_edges: np.ndarray
    histogram: np.ndarray  # shape (n_r_bins, n_theta_bins)
    r_centers: np.ndarray
    theta_centers: np.ndarray
    frames_used: int
    note: str = "Counts of triplets per center i: pair (i,j) at r, angle jik = θ; neighbors within r_cut"

def compute_three_body_surface(source: FrameSource,
                               r_cut: float,
                               n_r_bins: int = 60,
                               n_theta_bins: int = 90,
                               center_element: Optional[str] = None,
                               neighbor_element: Optional[str] = None
                               ) -> ThreeBodySurface:
    frames = load_frames(source)
    r_edges = np.linspace(0.0, r_cut, n_r_bins + 1)
    theta_edges = np.linspace(0.0, pi, n_theta_bins + 1)
    r_centers = 0.5 * (r_edges[:-1] + r_edges[1:])
    theta_centers = 0.5 * (theta_edges[:-1] + theta_edges[1:])
    H = np.zeros((n_r_bins, n_theta_bins), dtype=float)
    frames_used = 0
    for frame in frames:
        syms = frame.get_chemical_symbols()
        mask_center = np.array([True] * len(frame)) if center_element is None else np.array([s == center_element for s in syms])
        mask_nei = np.array([True] * len(frame)) if neighbor_element is None else np.array([s == neighbor_element for s in syms])
        i_idx, j_idx, S = neighbor_list("ijS", frame, r_cut, self_interaction=False, bothways=True)
        positions = frame.get_positions(); cell = frame.get_cell()
        disp = positions[j_idx] + np.dot(S, cell) - positions[i_idx]
        dists = np.linalg.norm(disp, axis=1)
        from collections import defaultdict
        per_i_vecs, per_i_d = defaultdict(list), defaultdict(list)
        for i, vec, d, j in zip(i_idx, disp, dists, j_idx):
            if mask_center[i] and mask_nei[j]:
                per_i_vecs[i].append(vec); per_i_d[i].append(d)
        for i in per_i_vecs:
            vecs = np.array(per_i_vecs[i]); ds = np.array(per_i_d[i])
            if len(ds) < 2: continue
            norms = np.linalg.norm(vecs, axis=1)
            keep = norms > 1e-12
            vecs = vecs[keep]; ds = ds[keep]
            if len(ds) < 2: continue
            U = vecs / norms[keep][:, None]
            m = U.shape[0]
            for a in range(m):
                r = ds[a]
                if r <= 0: continue
                dots = np.clip(U[a:a+1] @ U.T, -1.0, 1.0).ravel()
                thetas = np.arccos(dots)
                thetas = np.delete(thetas, a)
                r_bin = np.searchsorted(r_edges, r, side="right") - 1
                if 0 <= r_bin < H.shape[0]:
                    h_theta, _ = np.histogram(thetas, bins=theta_edges)
                    H[r_bin, :] += h_theta
        frames_used += 1
    return ThreeBodySurface(r_edges, theta_edges, H, r_centers, theta_centers, frames_used)

def plot_three_body_projections(tbd: ThreeBodySurface,
                                r_slice: Optional[float] = None,
                                theta_slice: Optional[float] = None
                                ) -> None:
    plt.figure()
    if r_slice is not None:
        r_bin = np.searchsorted(tbd.r_edges, r_slice, side="right") - 1
        r_bin = np.clip(r_bin, 0, tbd.histogram.shape[0]-1)
        y = tbd.histogram[r_bin, :].astype(float)
        if y.sum() > 0:
            y = y / (y.sum() * (tbd.theta_edges[1] - tbd.theta_edges[0]))
        plt.plot(tbd.theta_centers, y)
        plt.xlabel("θ (rad)"); plt.ylabel(f"P(θ | r≈{tbd.r_centers[r_bin]:.2f} Å)")
        plt.title("Three-Body: θ profile at fixed r")
        plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
        plt.show(); return
    if theta_slice is not None:
        th_bin = np.searchsorted(tbd.theta_edges, theta_slice, side="right") - 1
        th_bin = np.clip(th_bin, 0, tbd.histogram.shape[1]-1)
        y = tbd.histogram[:, th_bin].astype(float)
        dr = tbd.r_edges[1] - tbd.r_edges[0]
        if y.sum() > 0:
            y = y / (y.sum() * dr)
        plt.plot(tbd.r_centers, y)
        plt.xlabel("r (Å)"); plt.ylabel(f"P(r | θ≈{tbd.theta_centers[th_bin]:.2f} rad)")
        plt.title("Three-Body: radial profile at fixed θ")
        plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
        plt.show(); return
    y = tbd.histogram.sum(axis=0).astype(float)
    dtheta = tbd.theta_edges[1] - tbd.theta_edges[0]
    if y.sum() > 0:
        y = y / (y.sum() * dtheta)
    plt.plot(tbd.theta_centers, y)
    plt.xlabel("θ (rad)"); plt.ylabel("P(θ)  (marginal over r)")
    plt.title("Three-Body: θ marginal")
    plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)
    plt.show()

print("Loaded: RDF, angle distribution, three-body surface; and write_traj().")


Loaded: RDF, angle distribution, three-body surface; and write_traj().


In [None]:
frames = load_frames("simulation.traj")   # or "water.xyz" etc.

rdf = compute_rdf(frames, r_max=10.0, bin_width=0.05)
plot_rdf(rdf)

# Partial RDFs (e.g., water)
partials = compute_rdf(frames, r_max=10.0, bin_width=0.05, element_pairs=[("O","O"), ("O","H"), ("H","H")])
plot_rdf(partials)

ang = compute_bond_angle_distribution(frames, r_cut=3.3, center_element="O", neighbor_element="H")
plot_angles(ang)

tbd = compute_three_body_surface(frames, r_cut=3.3, n_r_bins=60, n_theta_bins=90, center_element="O", neighbor_element="H")
plot_three_body_projections(tbd, r_slice=2.8)   # θ profile at ~2.8 Å
# or:
plot_three_body_projections(tbd, theta_slice=1.91)  # r profile at θ≈109.5°

# Convert any input to ASE .traj
write_traj(frames, "output.traj")


In [7]:
#!/usr/bin/env python3
"""
Enumerate relative orientations between two dimers (treated as rigid bodies).

Given two input structures (XYZ files) representing dimers A and B, this
program samples the orientation space and generates transformed coordinates
for dimer B around dimer A. Optionally it also rotates dimer A. Results can
be filtered by a minimum interatomic distance and deduplicated by RMSD.

Usage (examples):

  # Basic: place B around A on a 6.5 Å shell, 200 directions, 24 twists
  python enumerate_dimer_orientations.py A.xyz B.xyz --r 6.5 \
      --n_dirs 200 --n_twists 24 --out out_dir

  # Scan a range of separations (Å), ensure 2.0 Å minimum interatomic distance
  python enumerate_dimer_orientations.py A.xyz B.xyz --r 5.0 8.0 --n_dirs 400 \
      --n_twists 36 --r_step 0.5 --min_dist 2.0 --out placements

  # Also rotate A (useful when A is not spherically symmetric)
  python enumerate_dimer_orientations.py A.xyz B.xyz --r 6.0 --rotate_A \
      --n_dirs 100 --n_twists 12 --out grid --c2_B

Notes
-----
* "All orientations" in continuous space is infinite; here we cover it by a
  uniform-ish grid: a Fibonacci sphere for placement directions and a uniform
  twist (0..2π) around the inter-center axis. You can increase n_dirs and
  n_twists for finer coverage.
* If your dimers have C2 symmetry, pass --c2_A and/or --c2_B to prune
  redundant twists (saves ~×2 for perfect C2 symmetry around the dimer axis).
* Output files are named like: orient_r{R}_k{dirIdx}_t{twistIdx}[...].xyz

Author: (you)
"""

from __future__ import annotations
import argparse
import math
import os
import sys
from dataclasses import dataclass
from typing import List, Tuple, Iterable

import numpy as np

# --------------------------- IO utilities ----------------------------------

def read_xyz(path: str) -> Tuple[np.ndarray, List[str]]:
    """Read a simple XYZ file.

    Returns
    -------
    coords : (N, 3) float array
    elems  : list[str]
    """
    with open(path, 'r') as f:
        lines = [l.rstrip() for l in f]
    try:
        n = int(lines[0].strip())
    except Exception as e:
        raise ValueError(f"{path}: invalid XYZ header: {e}")
    body = lines[2:2+n]
    elems = []
    coords = []
    for i, line in enumerate(body):
        parts = line.split()
        if len(parts) < 4:
            raise ValueError(f"{path}: line {i+3} has <4 fields")
        elems.append(parts[0])
        coords.append([float(parts[1]), float(parts[2]), float(parts[3])])
    return np.array(coords, dtype=float), elems


def write_xyz(path: str, elems: List[str], coords: np.ndarray, comment: str = "") -> None:
    with open(path, 'w') as f:
        f.write(f"{len(elems)}\n")
        f.write(comment + "\n")
        for el, (x, y, z) in zip(elems, coords):
            f.write(f"{el:2s} {x: .8f} {y: .8f} {z: .8f}\n")

# --------------------------- Chemistry bits --------------------------------

ATOMIC_MASS = {
    'H': 1.00784, 'C': 12.0107, 'N': 14.0067, 'O': 15.9994, 'F': 18.9984,
    'P': 30.9738, 'S': 32.065, 'Cl': 35.453, 'Br': 79.904, 'I': 126.90447
}


def center_of_mass(coords: np.ndarray, elems: List[str]) -> np.ndarray:
    m = np.array([ATOMIC_MASS.get(e, 12.0) for e in elems])
    total = m.sum()
    return (coords * m[:, None]).sum(axis=0) / total


def translate(coords: np.ndarray, vec: np.ndarray) -> np.ndarray:
    return coords + vec[None, :]

# --------------------------- Rotations -------------------------------------

@dataclass
class Rigid:
    coords: np.ndarray  # (N,3)
    elems: List[str]

    @property
    def com(self) -> np.ndarray:
        return center_of_mass(self.coords, self.elems)

    def centered(self) -> "Rigid":
        return Rigid(self.coords - self.com[None, :], self.elems)

    def rotate(self, R: np.ndarray) -> "Rigid":
        return Rigid(self.coords @ R.T, self.elems)

    def translate(self, v: np.ndarray) -> "Rigid":
        return Rigid(self.coords + v[None, :], self.elems)


def rot_from_axis_angle(axis: np.ndarray, theta: float) -> np.ndarray:
    """Rodrigues rotation matrix."""
    a = np.asarray(axis, dtype=float)
    n = np.linalg.norm(a)
    if n == 0:
        return np.eye(3)
    k = a / n
    K = np.array([[0, -k[2], k[1]], [k[2], 0, -k[0]], [-k[1], k[0], 0]])
    I = np.eye(3)
    return I + math.sin(theta) * K + (1 - math.cos(theta)) * (K @ K)


def fibonacci_sphere(n: int) -> np.ndarray:
    """Quasi-uniform points on the unit sphere (north-south symmetric)."""
    if n < 2:
        return np.array([[0, 0, 1]])
    # Use the method with i from 0..n-1, y in (-1,1)
    i = np.arange(n)
    y = 1 - (2*i + 1)/n  # avoid poles exactly
    r = np.sqrt(np.maximum(0.0, 1 - y*y))
    golden = (1 + 5 ** 0.5) / 2
    phi = 2 * math.pi * i / (golden)
    x = r * np.cos(phi)
    z = r * np.sin(phi)
    pts = np.stack([x, y, z], axis=1)
    # Normalize for safety
    pts /= np.linalg.norm(pts, axis=1, keepdims=True)
    return pts

# ------------------------- Geometry helpers --------------------------------

def min_interatomic_distance(A: Rigid, B: Rigid) -> float:
    # naive O(N*M); for dimers this is fine
    dmin = np.inf
    for a in A.coords:
        d = np.linalg.norm(B.coords - a[None, :], axis=1).min()
        if d < dmin:
            dmin = d
    return float(dmin)


def rmsd(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.sqrt(((a - b) ** 2).sum(axis=1).mean()))

# --------------------------- Enumeration -----------------------------------

def enumerate_orientations(
    A0: Rigid,
    B0: Rigid,
    r_values: Iterable[float],
    n_dirs: int,
    n_twists: int,
    rotate_A: bool = False,
    c2_A: bool = False,
    c2_B: bool = False,
    min_dist: float | None = None,
    dedup_rmsd: float | None = None,
) -> List[Tuple[float, int, int, Rigid, Rigid]]:
    """Generate placements.

    Returns a list of tuples: (r, dir_idx, twist_idx, A, B)
    where A and B are placed in space (absolute coordinates).
    """
    A = A0.centered()
    B = B0.centered()

    # Pre-rotate A if requested: enumerate same directions/twists for A and B
    dirs = fibonacci_sphere(n_dirs)

    results = []
    used = []  # for deduplication of B coordinates only (translated frame)

    for r in r_values:
        for k, u in enumerate(dirs):
            # place COM of B at vector r*u relative to A COM (A COM at origin)
            for t in range(n_twists):
                # twist angle about axis u
                theta = 2 * math.pi * t / n_twists

                # If C2 symmetry: half the twists are redundant
                if c2_B and t >= n_twists // 2:
                    continue

                R_twist = rot_from_axis_angle(u, theta)

                A_rot = A
                if rotate_A:
                    # rotate A by the same twist to balance sampling
                    # (you could also sample a second independent twist)
                    A_rot = A.rotate(R_twist)
                    if c2_A and t >= n_twists // 2:
                        continue

                B_rot = B.rotate(R_twist)
                B_placed = B_rot.translate(u * r)

                # clash filter
                if min_dist is not None:
                    if min_interatomic_distance(A_rot, B_placed) < min_dist:
                        continue

                # deduplicate by RMSD on B in its absolute placement (up to translation)
                if dedup_rmsd is not None:
                    # Use coordinates of B relative to A's COM at origin
                    coords_B = B_placed.coords
                    is_dup = False
                    for prev in used:
                        if rmsd(coords_B, prev) < dedup_rmsd:
                            is_dup = True
                            break
                    if is_dup:
                        continue
                    used.append(coords_B.copy())

                results.append((r, k, t, A_rot, B_placed))

    return results

# --------------------------- CLI -------------------------------------------

def frange(start: float, stop: float, step: float) -> List[float]:
    if step <= 0:
        return [start]
    out = []
    x = start
    # include stop if nearly integer number of steps
    while x <= stop + 1e-9:
        out.append(round(x, 6))
        x += step
    return out


def main(argv: List[str] | None = None) -> int:
    p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument('A', help='XYZ file for dimer A')
    p.add_argument('B', help='XYZ file for dimer B')
    p.add_argument('--out', required=True, help='Output directory for enumerated XYZ files')
    p.add_argument('--r', nargs='+', type=float, required=True,
                   help='Separation(s) in Å. Provide 1 value, or 2 values with optional --r_step.')
    p.add_argument('--r_step', type=float, default=0.0, help='Step for scanning radii when two values are given to --r')
    p.add_argument('--n_dirs', type=int, default=200, help='Number of directions on the sphere (Fibonacci)')
    p.add_argument('--n_twists', type=int, default=24, help='Number of twist angles around the axis (0..2π)')
    p.add_argument('--rotate_A', action='store_true', help='Also rotate dimer A by the twist')
    p.add_argument('--c2_A', action='store_true', help='Assume C2 symmetry for A to prune redundant twists')
    p.add_argument('--c2_B', action='store_true', help='Assume C2 symmetry for B to prune redundant twists')
    p.add_argument('--min_dist', type=float, default=None, help='Reject placements with any interatomic distance below this Å')
    p.add_argument('--dedup_rmsd', type=float, default=None, help='RMSD (Å) threshold to deduplicate similar B placements')

    args = p.parse_args(argv)

    # radii
    if len(args.r) == 1:
        r_values = [args.r[0]]
    elif len(args.r) >= 2:
        r0, r1 = args.r[0], args.r[1]
        step = args.r_step if args.r_step > 0 else (r1 - r0)
        if step <= 0:
            step = abs(r1 - r0)
        r_values = frange(min(r0, r1), max(r0, r1), step)
    else:
        p.error('--r requires at least one value')
        return 2

    # load systems
    A_coords, A_elems = read_xyz(args.A)
    B_coords, B_elems = read_xyz(args.B)
    A = Rigid(A_coords, A_elems)
    B = Rigid(B_coords, B_elems)

    os.makedirs(args.out, exist_ok=True)

    placements = enumerate_orientations(
        A, B,
        r_values=r_values,
        n_dirs=args.n_dirs,
        n_twists=args.n_twists,
        rotate_A=args.rotate_A,
        c2_A=args.c2_A,
        c2_B=args.c2_B,
        min_dist=args.min_dist,
        dedup_rmsd=args.dedup_rmsd,
    )

    # write outputs
    count = 0
    for (r, k, t, A_out, B_out) in placements:
        fname = f"orient_r{r:.2f}_k{k:04d}_t{t:03d}.xyz"
        path = os.path.join(args.out, fname)
        elems = A_out.elems + B_out.elems
        coords = np.vstack([A_out.coords, B_out.coords])
        comment = f"r={r:.3f} dir={k} twist={t}"
        write_xyz(path, elems, coords, comment)
        count += 1

    print(f"Wrote {count} placements to {args.out}")
    return 0


# if __name__ == '__main__':
#     sys.exit(main())


In [10]:
class ARG:
    def __init__(self):
        self.r = 4
        self.out = 

usage: ipykernel_launcher.py [-h] --out OUT --r R [R ...] [--r_step R_STEP]
                             [--n_dirs N_DIRS] [--n_twists N_TWISTS]
                             [--rotate_A] [--c2_A] [--c2_B]
                             [--min_dist MIN_DIST] [--dedup_rmsd DEDUP_RMSD]
                             A B
ipykernel_launcher.py: error: the following arguments are required: A, B


SystemExit: 2

In [None]:
if len(args.r) == 1:
    r_values = [args.r[0]]
elif len(args.r) >= 2:
    r0, r1 = args.r[0], args.r[1]
    step = args.r_step if args.r_step > 0 else (r1 - r0)
    if step <= 0:
        step = abs(r1 - r0)
    r_values = frange(min(r0, r1), max(r0, r1), step)
else:
    p.error('--r requires at least one value')
    return 2

# load systems
A_coords, A_elems = read_xyz(args.A)
B_coords, B_elems = read_xyz(args.B)
A = Rigid(A_coords, A_elems)
B = Rigid(B_coords, B_elems)

os.makedirs(args.out, exist_ok=True)

placements = enumerate_orientations(
    A, B,
    r_values=r_values,
    n_dirs=args.n_dirs,
    n_twists=args.n_twists,
    rotate_A=args.rotate_A,
    c2_A=args.c2_A,
    c2_B=args.c2_B,
    min_dist=args.min_dist,
    dedup_rmsd=args.dedup_rmsd,
)

# write outputs
count = 0
for (r, k, t, A_out, B_out) in placements:
    fname = f"orient_r{r:.2f}_k{k:04d}_t{t:03d}.xyz"
    path = os.path.join(args.out, fname)
    elems = A_out.elems + B_out.elems
    coords = np.vstack([A_out.coords, B_out.coords])
    comment = f"r={r:.3f} dir={k} twist={t}"
    write_xyz(path, elems, coords, comment)
    count += 1
