In [None]:
from pathlib import Path
import json
import numpy as np
from collections import defaultdict
from typing import Dict, Tuple, Any, Optional, List
import time

# ===================== Basic utilities =====================
# Length conversion factor: multiply a length given in Angstrom by ANG2AU to
# express it in Bohr (atomic units of length). The divisor is the Bohr radius
# a0 expressed in Angstrom.
ANG2AU = 1.0 / 0.5291772083  # Angstrom -> Bohr (a0)

def load_json(p: Path):
    """Read the UTF-8 text file at `p` and return its decoded JSON content."""
    raw_text = p.read_text(encoding="utf-8")
    return json.loads(raw_text)

def load_charges(path: Path) -> np.ndarray:
    """Read a point-charge file and return it as a float array of shape (M, 4).

    Each usable line holds `x y z q` (Angstrom, e) as its first four
    whitespace-separated tokens; any further tokens on the line are ignored.
    Blank lines, lines with fewer than four tokens, and lines whose first four
    tokens are not all parseable as floats are skipped silently.

    Args:
        path: Location of the charge file. Undecodable bytes are ignored.

    Returns:
        Array with one row `[x, y, z, q]` per accepted line. The result is
        always two-dimensional: a file with no usable lines yields shape
        (0, 4), so callers can slice columns without special-casing.
    """
    rows = []
    with path.open("r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            parts = line.split()  # a blank line splits to [] and is rejected below
            if len(parts) < 4:
                continue
            try:
                rows.append([float(tok) for tok in parts[:4]])
            except ValueError:
                # Non-numeric line (e.g. a header or comment): skip it.
                continue
    # reshape keeps the (M, 4) contract even when `rows` is empty; a bare
    # np.array([]) would come back one-dimensional with shape (0,).
    return np.array(rows, dtype=float).reshape(-1, 4)

def parse_pdb_atoms(pdb_path: Path) -> List[dict]:
    """Parse atoms from a PDB-like file.

    Only lines starting with 'ATOM' are read, using whitespace splitting:
      0 ATOM | 1 serial | 2 atom | 3 resname | 4 resseq | 5 x | 6 y | 7 z

    NOTE: the token positions are fixed. A file that carries an extra column
    before the coordinates (for example a chain identifier between resname
    and resseq) shifts every later token by one and is not handled here.

    Lines with fewer than eight tokens, or whose coordinate tokens are not
    numeric, are skipped. A non-integer serial is stored as None; a
    non-integer resseq is kept as the raw string.

    Args:
        pdb_path: Location of the file. Undecodable bytes are ignored.

    Returns:
        A list of dicts, in file order, with keys: serial, atom, resname,
        resseq, xyz (xyz is a float array of length 3).
    """
    out = []
    with pdb_path.open("r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            if not line.startswith("ATOM"):
                continue
            toks = line.split()
            if len(toks) < 8:
                continue

            # Coordinates are mandatory: a record without them is useless.
            try:
                xyz = np.array([float(t) for t in toks[5:8]], dtype=float)
            except ValueError:
                continue

            # Serial is informational only, so a bad value is tolerated.
            try:
                serial = int(toks[1])
            except ValueError:
                serial = None

            # Keep a non-numeric residue id verbatim rather than dropping
            # the atom.
            try:
                resseq = int(toks[4])
            except ValueError:
                resseq = toks[4]

            out.append({
                "serial": serial,
                "atom": toks[2],
                "resname": toks[3],
                "resseq": resseq,
                "xyz": xyz,
            })
    return out

def group_by_residue(atoms: List[dict]) -> Dict[Tuple[str, Any], Dict[str, np.ndarray]]:
    """Bucket atom records by the residue they belong to.

    Each record must provide the keys "resname", "resseq", "atom" and "xyz".
    When the same atom name occurs more than once within a residue, the last
    record wins.

    Returns:
        A defaultdict mapping (resname, resseq) -> {atom_name: xyz}.
    """
    residues = defaultdict(dict)
    for record in atoms:
        residue_key = (record["resname"], record["resseq"])
        coords_by_name = residues[residue_key]
        coords_by_name[record["atom"]] = record["xyz"]
    return residues

def E_at(point_xyz: np.ndarray, charges_xyzq: np.ndarray) -> np.ndarray:
    """Return the electric-field vector at `point_xyz` due to point charges.

    Implements E = sum_i q_i * r_i / |r_i|^3 with r_i = point - charge
    position, after scaling the separations by ANG2AU (Angstrom -> Bohr).
    Charges located exactly at the evaluation point are left out of the sum.
    An empty charge array, or one whose charges all coincide with the point,
    gives a zero vector.

    Args:
        point_xyz: Evaluation point, length-3 array.
        charges_xyzq: Rows of `[x, y, z, q]`.

    Returns:
        Length-3 array with the field components.
    """
    if charges_xyzq.size == 0:
        return np.zeros(3)

    positions = charges_xyzq[:, :3]
    separation = (point_xyz - positions) * ANG2AU          # (M, 3)
    distance = np.linalg.norm(separation, axis=1)          # (M,)

    # Drop charges sitting on the evaluation point to avoid dividing by zero.
    keep = distance > 0.0
    if not keep.any():
        return np.zeros(3)

    q_col = charges_xyzq[:, 3][keep][:, np.newaxis]        # (K, 1)
    dist_cubed = (distance[keep] ** 3)[:, np.newaxis]      # (K, 1)
    contributions = q_col * separation[keep] / dist_cubed  # (K, 3)
    return np.sum(contributions, axis=0)                   # (3,)

def unit(v: np.ndarray) -> Optional[np.ndarray]:
    """Return `v` scaled to unit length, or None when its norm is not positive."""
    length = np.linalg.norm(v)
    if length > 0:
        return v / length
    return None

# ===================== Polarization-energy evaluation =====================
def compute_polarization_energy(
    pdb_path: Path,
    charge_path: Path,
    polar_json: Path,
    name_index_json: Path,
    connect_json: Path,
    only_resnames: Optional[set] = None,
):
    """Evaluate per-atom polarization energies and the field terms behind them.

    For every residue that has an entry in all three lookup tables, the
    connectivity pairs (i1, i2) are walked in order. The first usable pair for
    a given atom a1 defines a reference direction u12 (from a1 towards a2).
    The field E at a1 is split into a signed component along u12 (Epar) and
    the magnitude of the remainder (Eperp). The atom energy is the dot product
    of the five features
        [Epar^2, Eperp^2, |Epar*Eperp|, |Epar|, |Eperp|]
    with the atom's five coefficients from the polarizability table.

    A pair is skipped when either index has no atom name, either atom is
    missing from the structure, a1 has no coefficients (or not exactly five),
    the two atoms coincide, or a1 has already been evaluated.

    Args:
        pdb_path: Structure file read with `parse_pdb_atoms`.
        charge_path: Point-charge file read with `load_charges`.
        polar_json: JSON table {RES: {ATOM: [a1..a5]}}.
        name_index_json: JSON table {RES: {ATOM: idx}}.
        connect_json: JSON table {RES: [[i1, i2], ...]}.
        only_resnames: When non-empty, restricts the evaluation to these
            residue names.

    Returns:
        total_energy (float)
        by_residue: {(resname, resseq): {
            "energy": float,
            "by_atom": {
                atom: {
                    "energy": float,
                    "Epar": float,
                    "Eperp": float,
                    "pos": [x, y, z],
                    "dx": [dx, dy, dz],
                    "E": [Ex, Ey, Ez],
                    "E_norm": float,
                    "Epar_norm": float,
                    "Eperp_norm": float,
                }
            }
        }}
        Residues for which no atom could be evaluated are omitted.
    """
    charges = load_charges(charge_path)
    polar = load_json(polar_json)              # {RES: {ATOM: [a1..a5]}}
    NAME_INDEX = load_json(name_index_json)    # {RES: {ATOM: idx}}
    CONNECT = load_json(connect_json)          # {RES: [[i1, i2], ...]}

    residues = group_by_residue(parse_pdb_atoms(pdb_path))

    total_energy = 0.0
    by_residue = {}

    for (resname, resseq), coords in residues.items():
        if only_resnames and resname not in only_resnames:
            continue
        has_all_tables = (
            resname in CONNECT and resname in NAME_INDEX and resname in polar
        )
        if not has_all_tables:
            continue

        index_to_name = {idx: name for name, idx in NAME_INDEX[resname].items()}
        coefficients = polar[resname]

        residue_energy = 0.0
        atom_records = {}  # doubles as the "already evaluated" set for a1

        for i1, i2 in CONNECT[resname]:
            a1 = index_to_name.get(i1)
            a2 = index_to_name.get(i2)
            if not (a1 and a2) or a1 in atom_records:
                continue
            if a1 not in coords or a2 not in coords:
                continue
            if a1 not in coefficients:
                continue

            origin = coords[a1]
            bond = coords[a2] - origin
            direction = unit(bond)
            if direction is None:
                continue

            alpha = np.array(coefficients[a1], dtype=float)
            if alpha.size != 5:
                continue

            # Field at a1 and its decomposition relative to the a1 -> a2 axis.
            field = E_at(origin, charges)
            Epar = float(np.dot(field, direction))
            parallel_vec = Epar * direction
            perpendicular_vec = field - parallel_vec
            Eperp = float(np.linalg.norm(perpendicular_vec))

            features = np.array(
                [Epar**2, Eperp**2, abs(Epar * Eperp), abs(Epar), abs(Eperp)],
                dtype=float,
            )
            atom_energy = float(np.dot(features, alpha))
            residue_energy += atom_energy

            atom_records[a1] = {
                "energy": atom_energy,
                "Epar": Epar,
                "Eperp": Eperp,
                "pos": origin.tolist(),
                "dx": bond.tolist(),
                "E": field.tolist(),
                "E_norm": float(np.linalg.norm(field)),
                "Epar_norm": float(np.linalg.norm(parallel_vec)),
                "Eperp_norm": Eperp,
            }

        if not atom_records:
            continue
        total_energy += residue_energy
        by_residue[(resname, resseq)] = {
            "energy": residue_energy,
            "by_atom": atom_records,
        }

    return total_energy, by_residue

# ===================== Write report =====================
def write_energy_report(total_U: float, details: dict, out_txt: Path):
    """Write the total energy and one line of field descriptors per atom.

    The first line is `# total_U <value>`. Every following line describes one
    atom with these space-separated columns:
        atom  x y z  dx dy dz  Ex Ey Ez  E_norm  Epar_norm  Eperp_norm  energy

    Residues are emitted sorted by residue name, then by residue id. Residue
    ids may be numbers or strings (the id is kept as a raw string when it is
    not an integer); within one residue name, numeric ids come first in
    numeric order, followed by string ids in lexicographic order. Atoms keep
    the insertion order of each residue's "by_atom" mapping.

    Args:
        total_U: Total energy written to the header line.
        details: {(resname, resseq): {"by_atom": {atom: info}}}, where `info`
            provides "pos", "dx", "E" (three values each), "E_norm",
            "Epar_norm", "Eperp_norm" and "energy".
        out_txt: Destination file; missing parent directories are created.
    """
    def _residue_order(item):
        # Comparing an int id with a str id raises TypeError, so each kind
        # gets its own slot in the key and they are never compared directly.
        resname, resseq = item[0]
        if isinstance(resseq, (int, float)):
            return (str(resname), 0, resseq, "")
        return (str(resname), 1, 0, str(resseq))

    lines = [f"# total_U {total_U:.12f}\n"]
    for _, rec in sorted(details.items(), key=_residue_order):
        for atom, info in rec["by_atom"].items():
            # Explicit unpacking keeps the "exactly three components" check.
            x, y, z = info["pos"]
            dx, dy, dz = info["dx"]
            Ex, Ey, Ez = info["E"]
            values = (
                x, y, z,
                dx, dy, dz,
                Ex, Ey, Ez,
                info["E_norm"],
                info["Epar_norm"],
                info["Eperp_norm"],
                info["energy"],
            )
            columns = " ".join(f"{value:12.6f}" for value in values)
            lines.append(f"{atom:<4s} {columns}\n")

    out_txt.parent.mkdir(parents=True, exist_ok=True)
    out_txt.write_text("".join(lines), encoding="utf-8")

# ===================== Batch processing =====================
def process_one_frame(base_dir: Path, tools_dir: Path, frame_idx: int):
    """Compute and write the polarization-energy report for one frame.

    Inputs are `<base_dir>/pdb_charge/molXX.pdb` and
    `<base_dir>/pdb_charge/chargeXX`; the report goes to
    `<base_dir>/polar_results/polarization_energy_XX.txt`, where XX is the
    frame index zero-padded to two digits. The lookup tables `polar.json`,
    `name_index.json` and `connect.json` are taken from `tools_dir`.

    A missing structure or charge file only prints a [SKIP] message and
    returns without writing anything.

    Raises:
        FileNotFoundError: If one of the lookup tables is missing.
    """
    tag = f"{frame_idx:02d}"
    input_dir = base_dir / "pdb_charge"
    pdb_path = input_dir / f"mol{tag}.pdb"
    charge_path = input_dir / f"charge{tag}"
    out_txt = base_dir / "polar_results" / f"polarization_energy_{tag}.txt"

    # Per-frame inputs are optional: report and move on.
    for frame_input in (pdb_path, charge_path):
        if not frame_input.exists():
            print(f"[SKIP] Missing file: {frame_input}")
            return

    # The lookup tables are mandatory: fail on the first one that is absent.
    polar_json, name_index, connect = (
        tools_dir / filename
        for filename in ("polar.json", "name_index.json", "connect.json")
    )
    absent = next(
        (table for table in (polar_json, name_index, connect) if not table.exists()),
        None,
    )
    if absent is not None:
        raise FileNotFoundError(f"[ERROR] Missing mapping file: {absent}")

    total_U, details = compute_polarization_energy(
        pdb_path,
        charge_path,
        polar_json,
        name_index,
        connect,
        only_resnames=None,
    )
    write_energy_report(total_U, details, out_txt)
    print(f"[DONE] Frame {tag} -> {out_txt}")

def process_all_frames(base_dir: Path, tools_dir: Path, start: int = 0, end: int = 49):
    """Run `process_one_frame` for frames `start`..`end` (inclusive) and time each.

    The wall time of every frame is printed as it finishes; frames that were
    skipped for missing inputs are timed as well. When at least one frame was
    attempted, the per-frame times and their average are written to
    `<base_dir>/runtime.txt`. An empty range writes nothing.
    """
    timings = []  # (frame index, wall time in seconds)

    for frame in range(start, end + 1):
        began = time.perf_counter()
        process_one_frame(base_dir, tools_dir, frame)
        finished = time.perf_counter()

        elapsed = finished - began
        timings.append((frame, elapsed))
        print(f"[TIME] Frame {frame:02d}: {elapsed:.3f} s")

    if not timings:
        return

    # Runtime summary: header, one line per frame, then the mean.
    avg_dt = sum(elapsed for _, elapsed in timings) / len(timings)
    report = [f"# base_dir: {base_dir}\n"]
    report.extend(
        f"frame {frame:02d}: {elapsed:.6f} s\n" for frame, elapsed in timings
    )
    report.append(f"avg: {avg_dt:.6f} s\n")

    runtime_file = base_dir / "runtime.txt"
    runtime_file.write_text("".join(report), encoding="utf-8")
    print(f"[TIME] Average: {avg_dt:.3f} s (written to {runtime_file})")

# ===================== Example usage =====================
if __name__ == "__main__":
    # Script entry point: runs only when this file is executed directly,
    # not when it is imported.
    #
    # Use relative paths assuming the notebook/script is executed from the repository root.
    # - tools/   : contains polar.json, name_index.json, connect.json
    # - example/ : contains pdb_charge/ and outputs will be written under example/polar_results/
    TOOLS_DIR = Path("tools")
    BASE_DIR = Path("example")

    # Process mol00..mol01 / charge00..charge01 for a quick test.
    # `end` is inclusive, so this covers exactly two frames; the timing
    # summary is written to example/runtime.txt.
    process_all_frames(BASE_DIR, TOOLS_DIR, start=0, end=1)
