In [4]:
from pathlib import Path
import json
import numpy as np
from collections import defaultdict
from typing import Dict, Tuple, Any, Optional, List
import time

# ===================== 基础工具 =====================
# Length conversion factor: multiply a length in angstrom by ANG2AU to get
# the same length in Bohr radii (a0); 0.5291772083 is the a0 length in angstrom.
ANG2AU = 1.0 / 0.5291772083  # Å -> a0

def load_json(p: Path):
    """Read the UTF-8 encoded JSON file at ``p`` and return the decoded object."""
    return json.loads(p.read_text(encoding="utf-8"))

def load_charges(path: Path) -> np.ndarray:
    """Read a point-charge file and return it as a float array of shape (M, 4).

    Each usable line holds at least four whitespace-separated numbers
    ``x y z q`` (coordinates in angstrom, charge in units of e); any extra
    columns are ignored. Blank lines, lines with fewer than four fields and
    lines whose first four fields are not numeric are skipped silently.

    Args:
        path: Location of the charge file.

    Returns:
        Array of shape (M, 4) with columns ``x, y, z, q``. When no line can be
        parsed the result has shape (0, 4), so column slicing such as
        ``arr[:, :3]`` stays valid for callers.
    """
    rows = []
    with path.open("r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            parts = line.split()
            if len(parts) < 4:
                continue
            try:
                rows.append([float(tok) for tok in parts[:4]])
            except ValueError:
                # Header/comment or otherwise malformed line: best-effort skip.
                continue
    if not rows:
        # np.array([]) would be 1-D with shape (0,); keep the 2-D contract.
        return np.empty((0, 4), dtype=float)
    return np.array(rows, dtype=float)

def parse_pdb_atoms(pdb_path: Path) -> List[dict]:
    """Parse the ``ATOM`` records of a PDB-like file using whitespace splitting.

    Only lines starting with ``ATOM`` are read. The expected token layout is

        0 ATOM | 1 serial | 2 atom | 3 resname | 4 resseq | 5 x | 6 y | 7 z

    A chain-ID column between ``resname`` and ``resseq`` is also accepted:

        0 ATOM | 1 serial | 2 atom | 3 resname | 4 chain | 5 resseq | 6 x | 7 y | 8 z

    The chain column is recognised when token 4 is not an integer while
    token 5 is one (coordinates are expected to be written with a decimal
    point). Without this check the chain letter would be taken as the residue
    number and the residue number as the x coordinate.

    Lines with too few tokens or non-numeric coordinates are skipped. Records
    whose columns are fused together (no separating whitespace) are not
    handled by this tokenizer.

    Args:
        pdb_path: Location of the PDB file.

    Returns:
        One dict per parsed atom with keys ``serial`` (int, or None if not an
        integer), ``atom``, ``resname``, ``resseq`` (int when possible,
        otherwise the raw string), ``chain`` (str, or None when the file has
        no chain column) and ``xyz`` (float array of shape (3,)).
    """

    def _is_int(tok: str) -> bool:
        # True when the token parses as an integer.
        try:
            int(tok)
        except ValueError:
            return False
        return True

    out = []
    with pdb_path.open("r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            if not line.startswith("ATOM"):
                continue
            toks = line.split()
            if len(toks) < 8:
                continue

            # Detect an optional chain-ID column and shift later fields by one.
            has_chain = (
                len(toks) >= 9 and not _is_int(toks[4]) and _is_int(toks[5])
            )
            shift = 1 if has_chain else 0

            try:
                serial = int(toks[1])
            except ValueError:
                serial = None

            resseq_raw = toks[4 + shift]
            try:
                resseq = int(resseq_raw)
            except ValueError:
                # e.g. residue number with an insertion code such as "12A".
                resseq = resseq_raw

            try:
                x, y, z = (float(tok) for tok in toks[5 + shift:8 + shift])
            except ValueError:
                continue

            out.append({
                "serial": serial,
                "atom": toks[2],
                "resname": toks[3],
                "resseq": resseq,
                "chain": toks[4] if has_chain else None,
                "xyz": np.array([x, y, z], dtype=float),
            })
    return out

def group_by_residue(atoms: List[dict]) -> Dict[Tuple[str, Any], Dict[str, np.ndarray]]:
    """Collect atom coordinates per residue.

    Returns a ``defaultdict`` of the form
    ``{(resname, resseq): {atom_name: xyz}}``. If an atom name occurs more
    than once within the same residue key, the last occurrence wins.
    """
    residues = defaultdict(dict)
    for record in atoms:
        residue_id = record["resname"], record["resseq"]
        coords = record["xyz"]
        residues[residue_id][record["atom"]] = coords
    return residues

def E_at(point_xyz: np.ndarray, charges_xyzq: np.ndarray) -> np.ndarray:
    """Electric field vector at ``point_xyz`` produced by point charges.

    Evaluates E = sum_i q_i * r_i / |r_i|^3, where r_i is the vector from
    charge i to the point, scaled by ``ANG2AU`` before use. Charges located
    exactly at the point (|r| == 0) are left out of the sum.

    Args:
        point_xyz: Coordinates of the evaluation point, shape (3,).
        charges_xyzq: Array with rows ``x, y, z, q``, shape (M, 4); an empty
            array is allowed.

    Returns:
        Array of shape (3,); all zeros when there is no contributing charge.
    """
    if charges_xyzq.size == 0:
        return np.zeros(3)

    # Displacement from every charge to the point, converted with ANG2AU.
    displacement = (point_xyz - charges_xyzq[:, :3]) * ANG2AU  # (M, 3)
    distance = np.linalg.norm(displacement, axis=1)            # (M,)

    # Drop charges sitting exactly on the point to avoid dividing by zero.
    nonzero = distance > 0.0
    if not nonzero.any():
        return np.zeros(3)

    q_col = charges_xyzq[nonzero, 3][:, np.newaxis]            # (K, 1)
    dist3_col = (distance[nonzero] ** 3)[:, np.newaxis]        # (K, 1)
    contributions = q_col * displacement[nonzero] / dist3_col  # (K, 3)
    return contributions.sum(axis=0)

def unit(v: np.ndarray) -> Optional[np.ndarray]:
    """Return ``v`` scaled to unit length, or None when its norm is not positive."""
    length = np.linalg.norm(v)
    if length > 0:
        return v / length
    return None

# ===================== 计算核心（封装函数） =====================
# ...前面的代码保持不变...

# ===================== 计算极化能（输出包括电场信息） =====================
def compute_polarization_energy(
    pdb_path: Path,
    charge_path: Path,
    polar_json: Path,
    name_index_json: Path,
    connect_json: Path,
    only_resnames: Optional[set] = None,
):
    """Evaluate the per-atom polarization energy model for one structure.

    For every residue, each index pair ``[i1, i2]`` listed for that residue
    name in the connectivity table is visited in order. The first usable pair
    for a given first atom ``a1`` defines a reference direction ``a1 -> a2``.
    The electric field of the point charges at ``a1`` is split into the
    component along that direction (``Epar``, signed) and the magnitude of the
    remainder (``Eperp``). The atom energy is the dot product of
    ``[Epar^2, Eperp^2, |Epar*Eperp|, |Epar|, |Eperp|]`` with the atom's five
    coefficients from the polarizability table.

    Residues whose name is missing from any of the three tables are skipped,
    as are pairs with an unknown index, atoms absent from the structure, atoms
    without coefficients, coincident atom positions, and coefficient lists
    whose size is not 5.

    Args:
        pdb_path: Structure file read with ``parse_pdb_atoms``.
        charge_path: Point-charge file read with ``load_charges``.
        polar_json: JSON ``{RES: {ATOM: [a1..a5]}}``.
        name_index_json: JSON ``{RES: {ATOM: index}}``.
        connect_json: JSON ``{RES: [[i1, i2], ...]}``.
        only_resnames: If given and non-empty, only residues with these names
            are processed.

    Returns:
        ``(total_energy, by_residue)`` where ``by_residue`` maps
        ``(resname, resseq)`` to ``{"energy": float, "by_atom": {...}}`` and
        each ``by_atom`` entry maps an atom name to a dict with keys
        ``energy``, ``Epar``, ``Eperp``, ``pos`` (position of the atom),
        ``dx`` (vector from the atom to its partner), ``E`` (field vector),
        ``E_norm``, ``Epar_norm`` and ``Eperp_norm``. Residues that yield no
        atom entry are omitted.
    """
    point_charges = load_charges(charge_path)
    alpha_table = load_json(polar_json)          # {RES: {ATOM: [a1..a5]}}
    index_table = load_json(name_index_json)     # {RES: {ATOM: index}}
    bond_table = load_json(connect_json)         # {RES: [[i1, i2], ...]}

    residues = group_by_residue(parse_pdb_atoms(pdb_path))

    total_energy = 0.0
    by_residue = {}

    for res_key, coords in residues.items():
        resname = res_key[0]
        if only_resnames and resname not in only_resnames:
            continue
        if not (
            resname in bond_table
            and resname in index_table
            and resname in alpha_table
        ):
            continue

        # Invert the name -> index table so bond indices can be resolved.
        name_of = {idx: name for name, idx in index_table[resname].items()}
        res_alphas = alpha_table[resname]

        res_energy = 0.0
        # Doubles as the "already handled" set: each atom is evaluated once,
        # using the first usable pair in which it appears as the first index.
        atom_records = {}

        for i1, i2 in bond_table[resname]:
            a1, a2 = name_of.get(i1), name_of.get(i2)
            if not (a1 and a2) or a1 in atom_records:
                continue
            if not (a1 in coords and a2 in coords):
                continue
            if a1 not in res_alphas:
                continue

            origin = coords[a1]
            bond = coords[a2] - origin
            bond_dir = unit(bond)
            if bond_dir is None:
                continue

            # Field at the atom, decomposed relative to the bond direction.
            field = E_at(origin, point_charges)
            e_par = float(np.dot(field, bond_dir))
            par_vec = e_par * bond_dir
            perp_vec = field - par_vec
            e_perp = float(np.linalg.norm(perp_vec))

            feats = np.array(
                [
                    e_par**2,
                    e_perp**2,
                    abs(e_par * e_perp),
                    abs(e_par),
                    abs(e_perp),
                ],
                dtype=float,
            )

            alpha = np.array(res_alphas[a1], dtype=float)
            if alpha.size != 5:
                continue

            u_atom = float(np.dot(feats, alpha))
            res_energy += u_atom

            atom_records[a1] = {
                "energy": u_atom,
                "Epar": e_par,
                "Eperp": e_perp,
                "pos": origin.tolist(),
                "dx": bond.tolist(),
                "E": field.tolist(),
                "E_norm": float(np.linalg.norm(field)),
                "Epar_norm": float(np.linalg.norm(par_vec)),
                "Eperp_norm": e_perp,
            }

        if atom_records:
            total_energy += res_energy
            by_residue[res_key] = {
                "energy": res_energy,
                "by_atom": atom_records,
            }

    return total_energy, by_residue

# ===================== 写出极化能报告（包括电场信息） =====================
def _residue_sort_key(item):
    """Sort key for ``details`` items that tolerates mixed int/str residue ids.

    Orders by residue name first; within one name, numeric residue numbers
    come first in numeric order, followed by non-numeric ids (for example
    ``"12A"``) in string order.
    """
    resname, resseq = item[0]
    if isinstance(resseq, (int, float)):
        return (str(resname), 0, resseq, "")
    return (str(resname), 1, 0, str(resseq))


def write_energy_report(total_U: float, details: dict, out_txt: Path):
    """Write the total energy and one line per atom to a text report.

    The first line is ``# total_U <value>``. Every following line holds, for
    one atom: atom name, position (x y z), partner vector (dx dy dz), field
    vector (Ex Ey Ez), ``E_norm``, ``Epar_norm``, ``Eperp_norm`` and the atom
    energy. Residues are written sorted by name and residue id; atoms keep
    the order of their ``by_atom`` dict.

    Args:
        total_U: Total energy written in the header line.
        details: Mapping ``(resname, resseq) -> {"by_atom": {atom: info}}``
            where ``info`` has the keys ``pos``, ``dx``, ``E``, ``E_norm``,
            ``Epar_norm``, ``Eperp_norm`` and ``energy``.
        out_txt: Output file; missing parent directories are created.
    """
    lines = [f"# total_U {total_U:.12f}\n"]
    # A plain (name, resseq) key raises TypeError when resseq values mix int
    # and str under the same residue name, so a type-aware key is used.
    for _, rec in sorted(details.items(), key=_residue_sort_key):
        for atom, info in rec["by_atom"].items():
            x, y, z = info["pos"]
            dx, dy, dz = info["dx"]
            Ex, Ey, Ez = info["E"]
            E_norm = info["E_norm"]
            Epar_norm = info["Epar_norm"]
            Eperp_norm = info["Eperp_norm"]
            Utot = info["energy"]

            # All values of one atom go on a single line.
            lines.append(
                f"{atom:<4s} {x:12.6f} {y:12.6f} {z:12.6f} "
                f"{dx:12.6f} {dy:12.6f} {dz:12.6f} "
                f"{Ex:12.6f} {Ey:12.6f} {Ez:12.6f} "
                f"{E_norm:12.6f} "
                f"{Epar_norm:12.6f} {Eperp_norm:12.6f} "
                f"{Utot:12.6f}\n"
            )
    out_txt.parent.mkdir(parents=True, exist_ok=True)
    out_txt.write_text("".join(lines), encoding="utf-8")

# ===================== 批处理 =====================
def process_one_frame(base_dir: Path, tools_dir: Path, frame_idx: int,
                      result_dirname: str = "polar_results"):
    """Process a single frame and write its polarization-energy report.

    Reads ``<base_dir>/pdb_charge/molXX.pdb`` and
    ``<base_dir>/pdb_charge/chargeXX`` (``XX`` is the zero-padded frame index)
    and writes ``<base_dir>/<result_dirname>/polarization_energy_XX.txt``.
    A frame whose structure or charge file is missing is skipped with a
    message.

    Args:
        base_dir: Project directory containing the ``pdb_charge`` folder.
        tools_dir: Directory containing ``polar.json``, ``name_index.json``
            and ``connect.json``.
        frame_idx: Frame number.
        result_dirname: Name of the output folder created under ``base_dir``.

    Raises:
        FileNotFoundError: If one of the mapping files is missing (checked
            only after both frame input files were found).
    """
    sub = base_dir / "pdb_charge"
    pdb_path = sub / f"mol{frame_idx:02d}.pdb"
    charge_path = sub / f"charge{frame_idx:02d}"
    out_dir = base_dir / result_dirname
    out_txt = out_dir / f"polarization_energy_{frame_idx:02d}.txt"

    polar_json = tools_dir / "polar.json"
    name_index = tools_dir / "name_index.json"
    connect = tools_dir / "connect.json"

    if not pdb_path.exists():
        print(f"[跳过] 找不到 {pdb_path}")
        return
    if not charge_path.exists():
        print(f"[跳过] 找不到 {charge_path}")
        return
    for p in [polar_json, name_index, connect]:
        if not p.exists():
            raise FileNotFoundError(f"[错误] 缺少映射文件：{p}")

    total_U, details = compute_polarization_energy(
        pdb_path, charge_path, polar_json, name_index, connect, only_resnames=None
    )
    write_energy_report(total_U, details, out_txt)
    print(f"[完成] 帧 {frame_idx:02d}  ->  {out_txt}")


def process_all_frames(base_dir: Path, tools_dir: Path, start: int = 0, end: int = 49, result_dirname: str = "polar_results"):
    """Process frames ``start..end`` (inclusive) and record wall-clock timings.

    Each frame is handled by ``process_one_frame``; its elapsed time is
    printed and, together with the average over all frames, written to
    ``<base_dir>/runtime.txt``. Skipped frames are timed as well. Nothing is
    written when the range is empty.

    Args:
        base_dir: Project directory containing the ``pdb_charge`` folder.
        tools_dir: Directory containing the mapping JSON files.
        start: First frame index.
        end: Last frame index (inclusive).
        result_dirname: Name of the output folder under ``base_dir`` that
            receives the per-frame reports.
    """
    runtimes = []  # (frame_idx, elapsed seconds)

    for i in range(start, end + 1):
        t0 = time.perf_counter()
        # Forward the output folder name; previously it was accepted here but
        # never used, so reports always landed in "polar_results".
        process_one_frame(base_dir, tools_dir, i, result_dirname=result_dirname)
        t1 = time.perf_counter()

        dt = t1 - t0
        runtimes.append((i, dt))
        print(f"[计时] 帧 {i:02d} 用时 {dt:.3f} s")

    # Write runtime.txt
    if runtimes:
        avg_dt = sum(dt for _, dt in runtimes) / len(runtimes)
        runtime_file = base_dir / "runtime.txt"

        lines = [f"# base_dir: {base_dir}\n"]
        for idx, dt in runtimes:
            lines.append(f"frame {idx:02d}: {dt:.6f} s\n")
        lines.append(f"avg: {avg_dt:.6f} s\n")

        runtime_file.write_text("".join(lines), encoding="utf-8")
        print(f"[计时] 平均用时 {avg_dt:.3f} s, 已写入 {runtime_file}")

# ===================== 用法示例 =====================
if __name__ == "__main__":
    # Directory with the mapping files polar.json, name_index.json and connect.json.
    TOOLS = Path("/localhome2/wsren/proteinff/polar/test_project/alpha")
        
    # Project directory; frame inputs are read from its "pdb_charge" subfolder.
    BASE1  = Path("/localhome2/wsren/proteinff/polar/test_project/protein/1l2y")
    # Frames start..end are processed inclusively; this call covers only
    # mol00/charge00 and mol01/charge01.
    process_all_frames(BASE1, TOOLS, start=0, end=1)




[完成] 帧 00  ->  /localhome2/wsren/proteinff/polar/test_project/protein/1l2y/polar_results/polarization_energy_00.txt
[计时] 帧 00 用时 1.560 s
[完成] 帧 01  ->  /localhome2/wsren/proteinff/polar/test_project/protein/1l2y/polar_results/polarization_energy_01.txt
[计时] 帧 01 用时 1.018 s
[计时] 平均用时 1.289 s, 已写入 /localhome2/wsren/proteinff/polar/test_project/protein/1l2y/runtime.txt
