### dcd to xtc

#### one file

In [3]:
import os
import glob
import MDAnalysis as mda

def dcds_to_xtc(pdb, dcds, out_xtc, stride=1, selection="all", convert_units=True):
    """
    把同一個系統的一個 PDB + 多個 DCD（依序串接）寫成單一 XTC。

    dcds 可以是：
      - 單一 dcd 檔案路徑字串："/path/a.dcd"
      - 多個 dcd 的 list：["/path/1.dcd", "/path/2.dcd"]
      - glob pattern："/path/*.dcd"
    """
    if not os.path.isfile(pdb):
        raise FileNotFoundError(f"PDB not found: {pdb}")

    # --- normalize dcd list ---
    if isinstance(dcds, str):
        # 可能是單一檔案，也可能是 glob pattern
        if any(ch in dcds for ch in ["*", "?", "["]):
            dcd_list = sorted(glob.glob(dcds))
        else:
            dcd_list = [dcds]
    else:
        dcd_list = sorted(list(dcds))

    if not dcd_list:
        raise FileNotFoundError(f"No DCD files found: {dcds}")

    for f in dcd_list:
        if not os.path.isfile(f):
            raise FileNotFoundError(f"DCD not found: {f}")

    out_dir = os.path.dirname(out_xtc)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    u = mda.Universe(pdb, dcd_list)
    ag = u.select_atoms(selection)

    with mda.Writer(out_xtc, n_atoms=ag.n_atoms, convert_units=convert_units) as W:
        for ts in u.trajectory[::stride]:
            W.write(ag)

    print(f"[OK] wrote XTC: {out_xtc}  (frames={len(u.trajectory[::stride])}, dcds={len(dcd_list)})")


In [4]:
pdb = "/mnt/hdd/jeff/dataset/output/collagen/zh-all/wt/collagen_wt.pdb"
dcd = "/mnt/hdd/jeff/dataset/output/collagen/zh-all/wt/collagen_wt_nvt.dcd"
out_xtc = "/mnt/hdd/jeff/dataset/output/collagen/zh-all/pdbxtc/wt.xtc"

dcds_to_xtc(pdb, dcd, out_xtc)




[OK] wrote XTC: /mnt/hdd/jeff/dataset/output/collagen/zh-all/pdbxtc/wt.xtc  (frames=100, dcds=1)


#### a folder

In [None]:
import os
import glob
import shutil
import MDAnalysis as mda

def dcds_to_xtc(pdb, dcd_list, out_xtc):
    """
    把同一個系統的一個 PDB + 多個 DCD（依序串接）寫成單一 XTC。
    """
    if not os.path.isfile(pdb):
        raise FileNotFoundError(f"PDB not found: {pdb}")
    if not dcd_list:
        raise FileNotFoundError(f"No DCD files found for pdb: {pdb}")

    dcd_list = sorted(dcd_list)
    u = mda.Universe(pdb, dcd_list)

    os.makedirs(os.path.dirname(out_xtc), exist_ok=True)

    with mda.Writer(out_xtc, n_atoms=u.atoms.n_atoms) as W:
        for ts in u.trajectory:
            W.write(u.atoms)

download = "/mnt/hdd/download/a1"
pdbxtc = "/mnt/hdd/jeff/dataset/output/collagen/zh-all-pdbxtc"
os.makedirs(pdbxtc, exist_ok=True)

for name in sorted(os.listdir(download)):
    folder = os.path.join(download, name)
    if not os.path.isdir(folder):
        continue

    # 找 pdb
    pdb_list = sorted(glob.glob(os.path.join(folder, "*.pdb")))
    if len(pdb_list) == 0:
        print(f"[SKIP] no pdb: {folder}")
        continue
    if len(pdb_list) > 1:
        print(f"[WARN] multiple pdb found, use first: {pdb_list[0]} (folder={folder})")
    pdb = pdb_list[0]

    # 找 dcd
    dcd_list = sorted(glob.glob(os.path.join(folder, "*.dcd")))
    if len(dcd_list) == 0:
        print(f"[SKIP] no dcd: {folder}")
        continue

    out_xtc = os.path.join(pdbxtc, f"{name}-a1.xtc")
    out_pdb = os.path.join(pdbxtc, f"{name}-a1.pdb")

    try:
        dcds_to_xtc(pdb, dcd_list, out_xtc)
        shutil.copy2(pdb, out_pdb)

        print(f"[OK] {name}: {len(dcd_list)} dcd -> {out_xtc}, pdb -> {out_pdb}")
    except Exception as e:
        print(f"[FAIL] {name}: {e}")


### frame-1 to pdb

In [5]:
def save_last_frame_pdb(in_pdb, in_xtc, out_pdb):
    """
    Robustly save the last frame of an XTC trajectory as PDB.
    Works even when random seek is not supported.
    """
    import os
    import mdtraj as md

    if not os.path.isfile(in_pdb):
        raise FileNotFoundError(f"PDB not found: {in_pdb}")
    if not os.path.isfile(in_xtc):
        raise FileNotFoundError(f"XTC not found: {in_xtc}")

    # Load trajectory header to get n_frames
    traj = md.load(in_xtc, top=in_pdb)

    if traj.n_frames == 0:
        raise RuntimeError("Trajectory has zero frames")

    last = traj[-1]
    last.save_pdb(out_pdb)

    print(f"✅ Saved last frame ({traj.n_frames-1}) to {out_pdb}")


In [7]:
mut_pdb = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/pdbxtc/135_ARG_1-a1.pdb'
mut_xtc = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/pdbxtc/135_ARG_1-a1.xtc'
frame_last = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/wt/pdbxtc/135_ARG-last.pdb'
save_last_frame_pdb(mut_pdb,mut_xtc,frame_last)

✅ Saved last frame (301) to /mnt/hdd/jeff/dataset/output/collagen/zh-all/wt/pdbxtc/135_ARG-last.pdb


### mutation list

#### folder

In [107]:
def plot_mutation_scatter_from_folder(npy_dir, out_png):
    """
    讀資料夾內的 .npy 檔名（例如：5_ALA_1-a1.npy）
    畫散點圖：
      - x = mutation unit（5）
      - y = mutation type（ALA）
      - color：a1 = 藍色、a2 = 灰色
    只要求 x/y 軸對得上，不畫 title。
    """

    import os
    import re
    import glob
    import pandas as pd
    import matplotlib.pyplot as plt

    if not os.path.isdir(npy_dir):
        raise NotADirectoryError(f"Folder not found: {npy_dir}")

    npy_files = sorted(glob.glob(os.path.join(npy_dir, "*.npy")))
    if not npy_files:
        raise RuntimeError(f"No .npy files found in: {npy_dir}")

    # 支援：5_ALA_1-a1.npy、130_ARG_1-a2.npy
    pat = re.compile(r"^(?P<x>\d+)_(?P<aa>[A-Za-z]{3})_.*-(?P<chain>a1|a2)\.npy$", re.IGNORECASE)

    rows = []
    for fp in npy_files:
        base = os.path.basename(fp).strip()
        m = pat.match(base)
        if not m:
            continue
        x = int(m.group("x"))
        aa = m.group("aa").upper()
        chain = m.group("chain").lower()
        rows.append((x, aa, chain, base))

    if not rows:
        raise RuntimeError(
            "No valid .npy filenames matched pattern like '5_ALA_1-a1.npy'. "
            "Please check your filenames."
        )

    d = pd.DataFrame(rows, columns=["x", "aa", "chain", "fname"])

    # y 軸類別順序：依字母排序（若要固定順序可自行改）
    y_order = sorted(d["aa"].unique().tolist())
    y_map = {aa: i for i, aa in enumerate(y_order)}
    d["y"] = d["aa"].map(y_map)

    # 顏色：a1 藍、a2 灰
    color_map = {"a1": "tab:blue", "a2": "0.6"}  # 0.6 = 灰
    d["c"] = d["chain"].map(lambda z: color_map.get(z, "0.6"))

    # 畫圖
    plt.figure(figsize=(6.5, 3.8), dpi=150)
    plt.scatter(d["x"], d["y"], s=18, c=d["c"], linewidths=0)

    plt.xlabel("Mutation unit")
    plt.ylabel("Mutation Type")
    plt.yticks(range(len(y_order)), y_order)
    plt.ylim(-0.5, len(y_order) - 0.5)

    plt.tight_layout()
    plt.savefig(out_png, bbox_inches="tight")
    plt.close()

    return d  # 回傳解析後結果，方便你檢查


In [None]:
npy = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/npy' 
png = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/mutationlist/mutation_scatter.png'
plot_mutation_scatter_from_folder(npy,png)

#### csv 

In [127]:
def plot_mutation_scatter_from_csv(csv_path, out_png, x_col="npy_file"):
    """
    讀 csv 裡的 npy_file（例如: 54_SER_1-a1.npy）
    畫散點圖：
      - x = mutation unit (54)
      - y = mutation type (SER)
      - color: a1 = 藍色, a2 = 灰色
    只要求 x/y 軸對得上，不畫 title。
    """

    import os
    import re
    import pandas as pd
    import matplotlib.pyplot as plt

    if not os.path.isfile(csv_path):
        raise FileNotFoundError(f"CSV not found: {csv_path}")

    df = pd.read_csv(csv_path)
    if x_col not in df.columns:
        raise ValueError(f"Column '{x_col}' not found in CSV. Columns: {list(df.columns)}")

    # 允許像：54_SER_1-a1.npy、12_ARG_1-a1.npy、130_ARG_1-a1.npy
    pat = re.compile(r"(?P<x>\d+)_(?P<aa>[A-Za-z]{3})_.*-(?P<chain>a1|a2)\.npy$", re.IGNORECASE)

    rows = []
    for s in df[x_col].astype(str).tolist():
        m = pat.search(s.strip())
        if not m:
            # 跳過不符合格式的列
            continue
        x = int(m.group("x"))
        aa = m.group("aa").upper()
        chain = m.group("chain").lower()
        rows.append((x, aa, chain, s))

    if not rows:
        raise RuntimeError(
            "No valid npy_file entries matched pattern like '54_SER_1-a1.npy'. "
            "Please check your CSV content / naming."
        )

    d = pd.DataFrame(rows, columns=["x", "aa", "chain", "raw"])

    # y 軸類別順序：用你資料中出現的 AA，依字母排序（想固定順序可改這段）
    y_order = sorted(d["aa"].unique().tolist())
    y_map = {aa: i for i, aa in enumerate(y_order)}
    d["y"] = d["aa"].map(y_map)

    # 顏色：a1 藍、a2 灰
    color_map = {"a1": "tab:blue", "a2": "0.6"}  # 0.6 = 灰
    d["c"] = d["chain"].map(lambda z: color_map.get(z, "0.6"))

    # 畫圖
    plt.figure(figsize=(6.5, 3.8), dpi=150)
    plt.scatter(d["x"], d["y"], s=18, c=d["c"], linewidths=0)

    plt.xlabel("Mutation unit")
    plt.ylabel("Mutation Type")
    plt.yticks(range(len(y_order)), y_order)

    # 讓邊界好看一點
    plt.ylim(-0.5, len(y_order) - 0.5)

    # 不要 title（依你的要求）
    plt.tight_layout()
    plt.savefig(out_png, bbox_inches="tight")
    plt.close()

    return d  # 回傳解析後資料，方便你檢查


In [128]:
csv = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/pca-overlapping/select_protein.csv'
png = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/mutationlist/select.png'
plot_mutation_scatter_from_csv(csv,png)

Unnamed: 0,x,aa,chain,raw,y,c
0,36,ASP,a2,36_ASP_2-a2.npy,2,0.6
1,52,ALA,a1,52_ALA_1-a1.npy,0,tab:blue
2,66,VAL,a2,66_VAL_2-a2.npy,6,0.6
3,184,GLU,a2,184_GLU_2-a2.npy,4,0.6
4,252,ASP,a1,252_ASP_1-a1.npy,2,tab:blue
...,...,...,...,...,...,...
134,15,SER,a2,15_SER_2-a2.npy,5,0.6
135,290,CYS,a1,290_CYS_1-a1.npy,3,tab:blue
136,132,GLU,a2,132_GLU_2-a2.npy,4,0.6
137,323,ASP,a1,323_ASP_1-a1.npy,2,tab:blue


### pdbxtc last 10 frames->PCA

#### making npy

In [116]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import MDAnalysis as mda
from sklearn.decomposition import PCA
from scipy.spatial import ConvexHull

def save_lastframes_ca_to_npy(pdbxtc_dir, out_npy_dir, last_n=10, stride=1, max_proteins=None):
    """
    讀 pdbxtc_dir 內同名 aaa.pdb/aaa.xtc
    每個 protein 取最後 last_n frames 的 CA 座標，存成 out_npy_dir/aaa.npy
    aaa.npy shape: (n_frames_kept, n_ca, 3) dtype=float32
    """
    os.makedirs(out_npy_dir, exist_ok=True)

    pdbs, xtcs = {}, {}
    for fn in os.listdir(pdbxtc_dir):
        path = os.path.join(pdbxtc_dir, fn)
        if fn.lower().endswith(".pdb"):
            pdbs[os.path.splitext(fn)[0]] = path
        elif fn.lower().endswith(".xtc"):
            xtcs[os.path.splitext(fn)[0]] = path

    names = sorted(set(pdbs.keys()) & set(xtcs.keys()))
    if max_proteins is not None:
        names = names[:max_proteins]
    if len(names) == 0:
        raise RuntimeError(f"No matched *.pdb/*.xtc pairs found in: {pdbxtc_dir}")

    for name in names:
        pdb = pdbs[name]
        xtc = xtcs[name]
        out_npy = os.path.join(out_npy_dir, f"{name}.npy")

        try:
            u = mda.Universe(pdb, xtc)
            ca = u.select_atoms("name CA")
            if ca.n_atoms == 0:
                print(f"[SKIP] {name}: no CA")
                continue

            n_frames = len(u.trajectory)
            start = max(0, n_frames - last_n)
            frame_indices = list(range(start, n_frames, stride))

            arr = np.empty((len(frame_indices), ca.n_atoms, 3), dtype=np.float32)
            for i, fi in enumerate(frame_indices):
                u.trajectory[fi]
                arr[i] = ca.positions.astype(np.float32, copy=False)

            np.save(out_npy, arr)
            print(f"[OK] saved {out_npy}  shape={arr.shape}")

        except Exception as e:
            print(f"[FAIL] {name}: {e}")


In [None]:
pdbxtc = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/pdbxtc'
npy = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/npy'
save_lastframes_ca_to_npy(pdbxtc,npy)

#### wt align, all pca

In [None]:
def pca_from_npy_folder(npy_dir, out_png, wt_npy, wt_ref_frame=-1,
                        include_wt_in_plot=True, certain_protein_list=None,
                        out_csv=None):
    """
    Global PCA：用「全部軌跡的全部 frames」一起 fit（PC1/PC2 由全體決定）
    Align：所有軌跡用 wt_npy 的 wt_ref_frame 當 reference 做 Kabsch 對齊
    Plot：
      - 全部點先畫成灰色背景
      - certain_protein_list 內的「編號」才上色（點 + hull + 圈內編號）

    編號規則（非常重要）：
      - 以「成功載入且通過 shape/CA 檢查」的軌跡順序編號
      - 順序 = (wt 若 include_wt_in_plot=True) 然後 npy_dir 裡其餘 *.npy (sorted)
      - CSV 會輸出 label_id -> npy_file（含 .npy），圖上的標註也是 label_id

    certain_protein_list：
      - None：全部上色
      - 例如 [1, 3, 7]：只 highlight 編號 1/3/7（其他都灰色背景）
      - 建議只放 int（可接受 "1" 這種字串，會嘗試轉 int）

    out_csv：
      - 若 None，預設用 out_png 同名加上 "_labels.csv"
    """
    import os, glob, csv
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA
    from scipy.spatial import ConvexHull
    import matplotlib.patheffects as pe

    def kabsch(P, Q):
        Pc = P - P.mean(axis=0, keepdims=True)
        Qc = Q - Q.mean(axis=0, keepdims=True)
        C = Pc.T @ Qc
        V, _, Wt = np.linalg.svd(C)
        d = np.sign(np.linalg.det(V @ Wt))
        D = np.diag([1.0, 1.0, d])
        R = V @ D @ Wt
        return Pc @ R + Q.mean(axis=0, keepdims=True)

    def base_to_npy_filename(path_or_name):
        b = os.path.basename(str(path_or_name))
        if b.lower().endswith(".npy"):
            return b
        return b + ".npy"

    def to_int_list(x):
        if x is None:
            return None
        out = []
        for v in x:
            try:
                out.append(int(v))
            except Exception:
                print(f"[WARN] certain_protein_list contains non-int value: {v} (skip)")
        return out

    certain_protein_list = to_int_list(certain_protein_list)

    # ===== 1) load wt and define reference frame =====
    wt_arr = np.load(wt_npy)
    if wt_arr.ndim != 3 or wt_arr.shape[-1] != 3:
        raise RuntimeError(f"wt_npy bad shape: {wt_arr.shape}")

    T_wt, N_ref, _ = wt_arr.shape
    if not (-T_wt <= wt_ref_frame < T_wt):
        raise RuntimeError(f"wt_ref_frame out of range: {wt_ref_frame} (T_wt={T_wt})")

    ref = wt_arr[wt_ref_frame].astype(np.float64)  # (N,3)

    # ===== 2) collect ALL trajectories (aligned) to fit PCA =====
    npy_files = sorted(glob.glob(os.path.join(npy_dir, "*.npy")))
    if not npy_files:
        raise RuntimeError(f"No .npy files found in {npy_dir}")

    wt_abs = os.path.abspath(wt_npy)

    files_for_fit = []
    filenames_for_fit = []  # for CSV mapping

    if include_wt_in_plot:
        files_for_fit.append(wt_npy)
        filenames_for_fit.append(base_to_npy_filename(wt_npy))

    for f in npy_files:
        if os.path.abspath(f) == wt_abs:
            continue
        files_for_fit.append(f)
        filenames_for_fit.append(base_to_npy_filename(f))

    X_all = []
    protein_ids = []           # per-frame pid
    pid_to_fname = {}          # pid -> file name (.npy)

    pid = 0
    for f, fname in zip(files_for_fit, filenames_for_fit):
        arr = np.load(f)
        if arr.ndim != 3 or arr.shape[-1] != 3:
            print(f"[SKIP] {fname}: bad shape {arr.shape}")
            continue

        T, N, _ = arr.shape
        if N != N_ref:
            raise RuntimeError(f"CA count mismatch: {fname} has N={N}, but wt has N={N_ref}")

        rows = []
        for t in range(T):
            P = arr[t].astype(np.float64)
            P = kabsch(P, ref)
            rows.append(P.reshape(-1).astype(np.float32))
        rows = np.vstack(rows)

        X_all.append(rows)
        protein_ids.extend([pid] * T)
        pid_to_fname[pid] = fname

        print(f"[OK] load+align {fname}: frames={T}, CA={N}")
        pid += 1

    if len(X_all) == 0:
        raise RuntimeError("No valid trajectories for PCA fit.")

    X_all = np.vstack(X_all)
    protein_ids = np.array(protein_ids)

    unique_pids = sorted(set(protein_ids.tolist()))

    # ===== 3) fit PCA on ALL frames =====
    pca = PCA(n_components=2, svd_solver="randomized", random_state=0)
    Z_all = pca.fit_transform(X_all)
    print("[PCA basis] fitted on ALL trajectories")
    print("[PCA] explained variance ratio:", pca.explained_variance_ratio_)

    # ===== 4) assign global labels (1..n) for ALL valid trajectories =====
    # label_id is defined over valid trajectories (unique_pids) in order.
    pid_to_label = {pid_: (i + 1) for i, pid_ in enumerate(unique_pids)}
    label_to_pid = {v: k for k, v in pid_to_label.items()}

    # ===== 4.5) decide which to highlight (BY LABEL) =====
    if certain_protein_list is None:
        highlight_pids = unique_pids[:]  # 全部 highlight
    else:
        highlight_pids = []
        for lab in certain_protein_list:
            if lab in label_to_pid:
                highlight_pids.append(label_to_pid[lab])
            else:
                print(f"[WARN] label {lab} out of range (valid: 1..{len(unique_pids)})")

        # 去重但保留順序
        seen = set()
        highlight_pids = [p for p in highlight_pids if not (p in seen or seen.add(p))]

    # ===== 4.6) write CSV mapping (label_id -> npy_file) =====
    if out_csv is None:
        base, _ = os.path.splitext(out_png)
        out_csv = base + "_labels.csv"

    os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)
    with open(out_csv, "w", newline="", encoding="utf-8") as fp:
        w = csv.writer(fp)
        w.writerow(["label_id", "npy_file"])
        for pid_ in unique_pids:
            w.writerow([pid_to_label[pid_], pid_to_fname[pid_]])
    print(f"[OK] saved: {out_csv}")

    # 色盤只給被 highlight 的那幾個
    cmap = plt.cm.get_cmap("hsv", max(1, len(highlight_pids)))
    pid_to_color = {pid_: cmap(i) for i, pid_ in enumerate(highlight_pids)}

    # ===== 5) plot =====
    plt.figure(figsize=(10, 8))

    # 背景灰色點：全部 frames
    plt.scatter(
        Z_all[:, 0], Z_all[:, 1],
        s=6,
        alpha=0.03,
        color="gray",
        zorder=1
    )

    spread_eps = 1e-4

    # 只畫被選中的（上色 + hull + label number）
    for pid_ in highlight_pids:
        mask = protein_ids == pid_
        pts = Z_all[mask]
        color = pid_to_color[pid_]
        lab = pid_to_label[pid_]

        plt.scatter(
            pts[:, 0], pts[:, 1],
            s=10,
            alpha=0.25,
            color=color,
            zorder=3
        )

        if pts.shape[0] >= 3 and np.linalg.norm(pts.std(axis=0)) >= spread_eps:
            try:
                hull = ConvexHull(pts, qhull_options="QJ")
                hull_pts = pts[hull.vertices]
                closed = np.vstack([hull_pts, hull_pts[0]])

                plt.plot(
                    closed[:, 0], closed[:, 1],
                    color=color,
                    linewidth=0.8,
                    alpha=0.95,
                    zorder=6
                )
                plt.fill(
                    closed[:, 0], closed[:, 1],
                    facecolor=color,
                    alpha=0.05,
                    linewidth=0,
                    zorder=2
                )
            except Exception as e:
                print(f"[HULL FAIL] label={lab} file={pid_to_fname[pid_]}: {e}")

        # 標註：放在中心，顯示編號（label_id）
        center = pts.mean(axis=0)
        txt = plt.text(
            center[0], center[1],
            str(lab),
            fontsize=9,
            color="black",
            ha="center",
            va="center",
            alpha=0.9,
            zorder=10
        )
        txt.set_path_effects([pe.withStroke(linewidth=2.5, foreground="white")])

    plt.xlabel("PC1 (fit on all)")
    plt.ylabel("PC2 (fit on all)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=300)
    plt.close()
    print(f"[OK] saved: {out_png}")


In [None]:
npy = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/npy'
png = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/pca-overlapping/wt-align-all-pca-list.png'
wt = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/npy/wt.npy'
csv = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/pca-overlapping/number.csv'
certain_protein_list = [235, 21, 23, 55, 215, 139, 
182, 267, 164, 118, 149, 145, 167, 132, 266, 107, 
169, 176, 80, 70, 179, 222, 72, 205, 186, 263, 143, 
199, 229, 296, 123, 52, 59, 219, 60, 156, 241, 192, 
109, 144, 138, 84, 61, 131, 248, 193, 168, 66, 286, 
299, 96, 126, 129, 12, 89, 253, 24, 288, 237, 191, 
257, 71, 243, 37, 264, 212, 147, 194, 137, 22, 254, 
68, 175, 244, 289, 220, 225, 31, 188, 239, 275, 295, 
13, 102, 114, 280, 26, 97, 67, 224, 273, 291, 255, 
300, 103, 122, 146, 135, 261, 200, 162, 196, 290, 
54, 207, 153, 183, 46, 47, 301, 160, 172, 163, 181, 
27, 201, 16, 262, 121, 165, 281, 230, 258, 88, 86, 
297, 108, 40, 292, 285, 76, 87, 245, 101, 283, 83, 77, 174, 64, 294]
pca_from_npy_folder(npy,png,wt,certain_protein_list=certain_protein_list,out_csv=csv)

# pca [0.09918377 0.0941855 ]

#### labe, wt align, all pca

In [15]:
def pca_from_npy_folder(npy_dir, out_png, wt_npy, wt_ref_frame=-1,
                        include_wt_in_plot=True, certain_protein_list=None):
    """
    Global PCA：用「全部軌跡的全部 frames」一起 fit（PC1/PC2 由全體決定）
    Align：所有軌跡用 wt_npy 的 wt_ref_frame 當 reference 做 Kabsch 對齊
    Plot：
      - 全部點先畫成灰色背景
      - certain_protein_list 內的 protein 才上色（點 + hull + 圈內檔名）
    certain_protein_list：
      - 可放 ["xxx", "yyy.npy"] 兩種都可（會自動去掉副檔名）
    """
    import os, glob
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA
    from scipy.spatial import ConvexHull
    import matplotlib.patheffects as pe

    def kabsch(P, Q):
        Pc = P - P.mean(axis=0, keepdims=True)
        Qc = Q - Q.mean(axis=0, keepdims=True)
        C = Pc.T @ Qc
        V, _, Wt = np.linalg.svd(C)
        d = np.sign(np.linalg.det(V @ Wt))
        D = np.diag([1.0, 1.0, d])
        R = V @ D @ Wt
        return Pc @ R + Q.mean(axis=0, keepdims=True)

    def norm_name(x):
        x = os.path.basename(str(x))
        if x.lower().endswith(".npy"):
            x = x[:-4]
        return x

    # ===== 0) parse selected list =====
    if certain_protein_list is None:
        selected_name_set = None  # 表示全部都上色
    else:
        selected_name_set = set(norm_name(x) for x in certain_protein_list)

    # ===== 1) load wt and define reference frame =====
    wt_arr = np.load(wt_npy)
    if wt_arr.ndim != 3 or wt_arr.shape[-1] != 3:
        raise RuntimeError(f"wt_npy bad shape: {wt_arr.shape}")

    T_wt, N_ref, _ = wt_arr.shape
    if not (-T_wt <= wt_ref_frame < T_wt):
        raise RuntimeError(f"wt_ref_frame out of range: {wt_ref_frame} (T_wt={T_wt})")

    ref = wt_arr[wt_ref_frame].astype(np.float64)  # (N,3)

    # ===== 2) collect ALL trajectories (aligned) to fit PCA =====
    npy_files = sorted(glob.glob(os.path.join(npy_dir, "*.npy")))
    if not npy_files:
        raise RuntimeError(f"No .npy files found in {npy_dir}")

    wt_abs = os.path.abspath(wt_npy)

    files_for_fit = []
    names_for_fit = []

    if include_wt_in_plot:
        files_for_fit.append(wt_npy)
        names_for_fit.append(norm_name(wt_npy))

    for f in npy_files:
        if os.path.abspath(f) == wt_abs:
            continue
        files_for_fit.append(f)
        names_for_fit.append(norm_name(f))

    X_all = []
    protein_ids = []
    protein_names = []

    pid = 0
    for f, name in zip(files_for_fit, names_for_fit):
        arr = np.load(f)
        if arr.ndim != 3 or arr.shape[-1] != 3:
            print(f"[SKIP] {name}: bad shape {arr.shape}")
            continue

        T, N, _ = arr.shape
        if N != N_ref:
            raise RuntimeError(f"CA count mismatch: {name} has N={N}, but wt has N={N_ref}")

        rows = []
        for t in range(T):
            P = arr[t].astype(np.float64)
            P = kabsch(P, ref)
            rows.append(P.reshape(-1).astype(np.float32))
        rows = np.vstack(rows)

        X_all.append(rows)
        protein_ids.extend([pid] * T)
        protein_names.extend([name] * T)

        print(f"[OK] load+align {name}: frames={T}, CA={N}")
        pid += 1

    if len(X_all) == 0:
        raise RuntimeError("No valid trajectories for PCA fit.")

    X_all = np.vstack(X_all)
    protein_ids = np.array(protein_ids)

    # ===== 3) fit PCA on ALL frames =====
    pca = PCA(n_components=2, svd_solver="randomized", random_state=0)
    Z_all = pca.fit_transform(X_all)
    print("[PCA basis] fitted on ALL trajectories")
    print("[PCA] explained variance ratio:", pca.explained_variance_ratio_)

    # ===== 4) decide which proteins to highlight =====
    unique_pids = sorted(set(protein_ids.tolist()))

    pid_to_name = {}
    for pid in unique_pids:
        idx0 = np.where(protein_ids == pid)[0][0]
        pid_to_name[pid] = protein_names[idx0]

    if selected_name_set is None:
        highlight_pids = unique_pids[:]  # 全部上色
    else:
        highlight_pids = [pid for pid in unique_pids if pid_to_name[pid] in selected_name_set]

        # 提醒沒找到的名字
        missing = sorted(list(selected_name_set - set(pid_to_name.values())))
        if len(missing) > 0:
            print(f"[WARN] not found in folder: {missing[:10]}" + (" ..." if len(missing) > 10 else ""))

    # 色盤只給被 highlight 的那幾個
    cmap = plt.cm.get_cmap("hsv", max(1, len(highlight_pids)))
    pid_to_color = {pid: cmap(i) for i, pid in enumerate(highlight_pids)}

    # ===== 5) plot =====
    plt.figure(figsize=(10, 8))

    # 背景灰色點：全部 frames
    plt.scatter(
        Z_all[:, 0], Z_all[:, 1],
        s=6,
        alpha=0.03,
        color="gray",
        zorder=1
    )

    spread_eps = 1e-4

    # 只畫被選中的（上色 + hull + label）
    for pid in highlight_pids:
        mask = protein_ids == pid
        pts = Z_all[mask]
        name = pid_to_name[pid]
        color = pid_to_color[pid]

        plt.scatter(
            pts[:, 0], pts[:, 1],
            s=10,
            alpha=0.25,
            color=color,
            zorder=3
        )

        if pts.shape[0] >= 3 and np.linalg.norm(pts.std(axis=0)) >= spread_eps:
            try:
                hull = ConvexHull(pts, qhull_options="QJ")
                hull_pts = pts[hull.vertices]
                closed = np.vstack([hull_pts, hull_pts[0]])

                plt.plot(
                    closed[:, 0], closed[:, 1],
                    color=color,
                    linewidth=0.8,
                    alpha=0.95,
                    zorder=6
                )
                plt.fill(
                    closed[:, 0], closed[:, 1],
                    facecolor=color,
                    alpha=0.05,
                    linewidth=0,
                    zorder=2
                )

            except Exception as e:
                print(f"[HULL FAIL] {name}: {e}")

        # 標註：放在該 protein 的中心
        center = pts.mean(axis=0)
        txt = plt.text(
            center[0], center[1],
            name,
            fontsize=7,
            color="black",
            ha="center",
            va="center",
            alpha=0.85,
            zorder=10
        )
        txt.set_path_effects([pe.withStroke(linewidth=2, foreground="white")])

    plt.xlabel("PC1 (fit on all)")
    plt.ylabel("PC2 (fit on all)")
    plt.title("PCA (highlight selected proteins)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=300)
    plt.close()
    print(f"[OK] saved: {out_png}")


In [None]:
npy = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/npy'
png = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/pca/wt-align-all-pca-label.png'
wt = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/npy/wt.npy'
certain_protein_list = ['wt', '295_SER_1-a1','54_SER_1-a1','']
pca_from_npy_folder(npy,png,wt, certain_protein_list=certain_protein_list)

[OK] load+align wt: frames=10, CA=3134
[OK] load+align 100_ARG_1-a1: frames=10, CA=3134
[OK] load+align 102_ALA_1-a1: frames=10, CA=3134
[OK] load+align 105_CYS_1-a1: frames=10, CA=3134
[OK] load+align 105_SER_1-a1: frames=10, CA=3134
[OK] load+align 109_CYS_1-a1: frames=10, CA=3134
[OK] load+align 111_VAL_1-a1: frames=10, CA=3134
[OK] load+align 117_ALA_1-a1: frames=10, CA=3134
[OK] load+align 117_ASP_1-a1: frames=10, CA=3134
[OK] load+align 117_CYS_1-a1: frames=10, CA=3134
[OK] load+align 118_CYS_1-a1: frames=10, CA=3134
[OK] load+align 118_SER_1-a1: frames=10, CA=3134
[OK] load+align 119_ASP_1-a1: frames=10, CA=3134
[OK] load+align 119_CYS_1-a1: frames=10, CA=3134
[OK] load+align 120_ARG_1-a1: frames=10, CA=3134
[OK] load+align 124_ALA_1-a1: frames=10, CA=3134
[OK] load+align 126_ALA_1-a1: frames=10, CA=3134
[OK] load+align 128_ARG_1-a1: frames=10, CA=3134
[OK] load+align 128_CYS_1-a1: frames=10, CA=3134
[OK] load+align 128_SER_1-a1: frames=10, CA=3134
[OK] load+align 12_ARG_1-a1: f

  cmap = plt.cm.get_cmap("hsv", max(1, len(highlight_pids)))


[OK] saved: /mnt/hdd/jeff/dataset/output/collagen/zh-all/wt-align-all-pca-label.png


#### special label

In [None]:
def pca_from_npy_folder(
    npy_dir,
    out_png,
    label_list=None,
    align_to_ref=True
):
    """
    PCA + per-protein convex hull
    只有出現在 label_list 的 npy 會上色、畫框、標名字
    其他全部用淺灰色當背景
    """
    import os, glob
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA
    from scipy.spatial import ConvexHull
    import matplotlib.patheffects as pe

    if label_list is None:
        label_list = []

    # 統一成「不含副檔名」的 set
    label_set = set(
        os.path.splitext(x)[0] for x in label_list
    )

    def kabsch(P, Q):
        Pc = P - P.mean(axis=0, keepdims=True)
        Qc = Q - Q.mean(axis=0, keepdims=True)
        C = Pc.T @ Qc
        V, _, Wt = np.linalg.svd(C)
        d = np.sign(np.linalg.det(V @ Wt))
        D = np.diag([1.0, 1.0, d])
        R = V @ D @ Wt
        return Pc @ R + Q.mean(axis=0, keepdims=True)

    npy_files = sorted(glob.glob(os.path.join(npy_dir, "*.npy")))
    if not npy_files:
        raise RuntimeError(f"No .npy files found in {npy_dir}")

    # reference frame
    ref = None
    ref_n = None
    if align_to_ref:
        arr0 = np.load(npy_files[0])
        ref = arr0[-1].astype(np.float64)
        ref_n = ref.shape[0]

    X_rows = []
    protein_ids = []
    protein_names = []

    for pid, f in enumerate(npy_files):
        name = os.path.splitext(os.path.basename(f))[0]
        arr = np.load(f)

        if arr.ndim != 3 or arr.shape[-1] != 3:
            print(f"[SKIP] {name}: bad shape {arr.shape}")
            continue

        T, N, _ = arr.shape
        if align_to_ref and N != ref_n:
            raise RuntimeError(f"CA count mismatch: {name}")

        for t in range(T):
            P = arr[t].astype(np.float64)
            if align_to_ref:
                P = kabsch(P, ref)
            X_rows.append(P.reshape(-1).astype(np.float32))
            protein_ids.append(pid)
            protein_names.append(name)

        print(f"[OK] loaded {name}: frames={T}, CA={N}")

    X = np.vstack(X_rows)
    pca = PCA(n_components=2, svd_solver="randomized", random_state=0)
    Z = pca.fit_transform(X)

    print("[PCA] explained variance ratio:", pca.explained_variance_ratio_)

    unique_pids = sorted(set(protein_ids))
    protein_ids = np.array(protein_ids)

    # 只對 label protein 配色
    label_pids = []
    for pid in unique_pids:
        name = protein_names[np.where(protein_ids == pid)[0][0]]
        if name in label_set:
            label_pids.append(pid)

    cmap = plt.cm.get_cmap("hsv", max(len(label_pids), 1))
    label_pid_to_color = {
        pid: cmap(i) for i, pid in enumerate(label_pids)
    }

    plt.figure(figsize=(10, 8))
    spread_eps = 1e-4

    for pid in unique_pids:
        mask = protein_ids == pid
        pts = Z[mask]
        name = protein_names[np.where(mask)[0][0]]

        # ===== 背景（不在 label_list）=====
        if name not in label_set:
            plt.scatter(
                pts[:, 0], pts[:, 1],
                s=8,
                alpha=0.05,
                color="lightgray",
                zorder=1
            )
            continue

        # ===== label protein =====
        color = label_pid_to_color[pid]

        # 點
        plt.scatter(
            pts[:, 0], pts[:, 1],
            s=12,
            alpha=0.35,
            color=color,
            zorder=3
        )

        if pts.shape[0] < 3:
            continue
        if np.linalg.norm(pts.std(axis=0)) < spread_eps:
            continue

        try:
            hull = ConvexHull(pts, qhull_options="QJ")
            hull_pts = pts[hull.vertices]
            closed = np.vstack([hull_pts, hull_pts[0]])

            # 外框
            plt.plot(
                closed[:, 0], closed[:, 1],
                color=color,
                linewidth=0.6,
                alpha=0.95,
                zorder=5
            )

            # 檔名標註（中心）
            center = pts.mean(axis=0)
            txt = plt.text(
                center[0], center[1],
                name,
                fontsize=7,
                color="black",
                ha="center",
                va="center",
                zorder=10
            )
            txt.set_path_effects([
                pe.withStroke(linewidth=2, foreground="white")
            ])

        except Exception as e:
            print(f"[HULL FAIL] {name}: {e}")

    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.title("PCA (highlight selected proteins)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=300)
    plt.close()
    print(f"[OK] saved: {out_png}")


In [None]:
pdbxtc_dir = "/mnt/hdd/jeff/dataset/output/collagen/zh-all-pdbxtc"
npy_dir = "/mnt/hdd/jeff/dataset/output/collagen/zh-all/npy"
out_png = "/mnt/hdd/jeff/dataset/output/collagen/zh-all/pca_last10_ca.png"

# save_lastframes_ca_to_npy(pdbxtc_dir, npy_dir, last_n=10, stride=1)
pca_from_npy_folder(npy_dir, out_png, align_to_ref=True)


### pca overlapping

#### making pca-npy

In [118]:
def save_pca_points_npy(npy_dir, wt_npy, out_points_npy,
                        wt_ref_frame=-1, include_wt_in_plot=True,
                        max_frames_per_traj=None, random_state=0):
    """
    產生一個 .npy（np.save 存 dict），內容包含：
      - labels: (M,) int，label_id = 1..M
      - fnames: (M,) str，每個 label 對應的 .npy 檔名（含副檔名）
      - points: list of np.ndarray，每個元素 shape=(T_i, 2) 的 PCA (PC1,PC2)
    注意：
      - PCA basis 是用「全部軌跡全部 frames」fit（跟你原本一致）
      - 每條軌跡先 Kabsch align 到 wt_ref_frame
      - 若 max_frames_per_traj 不為 None，會對每條軌跡隨機抽樣 frames 以減少資料量
    """
    import os, glob
    import numpy as np
    from sklearn.decomposition import PCA

    rng = np.random.RandomState(random_state)

    def kabsch(P, Q):
        Pc = P - P.mean(axis=0, keepdims=True)
        Qc = Q - Q.mean(axis=0, keepdims=True)
        C = Pc.T @ Qc
        V, _, Wt = np.linalg.svd(C)
        d = np.sign(np.linalg.det(V @ Wt))
        D = np.diag([1.0, 1.0, d])
        R = V @ D @ Wt
        return Pc @ R + Q.mean(axis=0, keepdims=True)

    def base_to_npy_filename(path_or_name):
        b = os.path.basename(str(path_or_name))
        if b.lower().endswith(".npy"):
            return b
        return b + ".npy"

    # --- load wt + reference ---
    wt_arr = np.load(wt_npy)
    if wt_arr.ndim != 3 or wt_arr.shape[-1] != 3:
        raise RuntimeError(f"wt_npy bad shape: {wt_arr.shape}")

    T_wt, N_ref, _ = wt_arr.shape
    if not (-T_wt <= wt_ref_frame < T_wt):
        raise RuntimeError(f"wt_ref_frame out of range: {wt_ref_frame} (T_wt={T_wt})")
    ref = wt_arr[wt_ref_frame].astype(np.float64)

    # --- collect files in the exact same order rule as your function ---
    npy_files = sorted(glob.glob(os.path.join(npy_dir, "*.npy")))
    if not npy_files:
        raise RuntimeError(f"No .npy files found in {npy_dir}")

    wt_abs = os.path.abspath(wt_npy)

    files_for_fit = []
    filenames_for_fit = []

    if include_wt_in_plot:
        files_for_fit.append(wt_npy)
        filenames_for_fit.append(base_to_npy_filename(wt_npy))

    for f in npy_files:
        if os.path.abspath(f) == wt_abs:
            continue
        files_for_fit.append(f)
        filenames_for_fit.append(base_to_npy_filename(f))

    # --- build X_all and also keep per-traj aligned rows indices ---
    X_all_list = []
    traj_slices = []   # list of (start, end) in stacked X_all
    valid_fnames = []

    start = 0
    for f, fname in zip(files_for_fit, filenames_for_fit):
        arr = np.load(f)
        if arr.ndim != 3 or arr.shape[-1] != 3:
            print(f"[SKIP] {fname}: bad shape {arr.shape}")
            continue

        T, N, _ = arr.shape
        if N != N_ref:
            raise RuntimeError(f"CA count mismatch: {fname} has N={N}, but wt has N={N_ref}")

        frame_idx = np.arange(T)
        if (max_frames_per_traj is not None) and (T > max_frames_per_traj):
            frame_idx = rng.choice(T, size=max_frames_per_traj, replace=False)
            frame_idx = np.sort(frame_idx)

        rows = []
        for t in frame_idx:
            P = arr[t].astype(np.float64)
            P = kabsch(P, ref)
            rows.append(P.reshape(-1).astype(np.float32))

        rows = np.vstack(rows)
        X_all_list.append(rows)

        end = start + rows.shape[0]
        traj_slices.append((start, end))
        valid_fnames.append(fname)
        start = end

        print(f"[OK] load+align {fname}: frames_used={rows.shape[0]}/{T}, CA={N}")

    if len(X_all_list) == 0:
        raise RuntimeError("No valid trajectories to build PCA points.")

    X_all = np.vstack(X_all_list)

    # --- PCA fit on all frames ---
    pca = PCA(n_components=2, svd_solver="randomized", random_state=0)
    Z_all = pca.fit_transform(X_all).astype(np.float32)
    print("[PCA] explained variance ratio:", pca.explained_variance_ratio_)

    # --- pack per-protein points ---
    M = len(traj_slices)
    labels = np.arange(1, M + 1, dtype=np.int32)
    points = []
    for (s, e) in traj_slices:
        points.append(Z_all[s:e])

    out = {
        "labels": labels,
        "fnames": np.array(valid_fnames, dtype=object),
        "points": np.array(points, dtype=object),
        "explained_variance_ratio": np.array(pca.explained_variance_ratio_, dtype=np.float64),
    }

    os.makedirs(os.path.dirname(out_points_npy) or ".", exist_ok=True)
    np.save(out_points_npy, out, allow_pickle=True)
    print(f"[OK] saved: {out_points_npy}")


In [None]:
npy = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/npy'
wt = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/npy/wt.npy'
out_npy = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/pca-overlapping/pca.npy'
save_pca_points_npy(npy,wt,out_npy)

#### pca-npy->no overlapping list

In [120]:
def select_low_overlap_labels_from_pca_npy(points_npy, k=140,
                                          grid_n=80,
                                          bw_method="scott",
                                          max_points_per_label=5000,
                                          random_state=0,
                                          seed_label=None,
                                          out_selected_npy=None):
    """
    input: points_npy (save_pca_points_npy 的輸出)
    output: selected_labels（長度 k 的 list[int]）

    重疊度定義：
      - 每個 label 的點雲 -> 2D KDE -> grid 上機率分布 P
      - 兩者距離用 Jensen–Shannon distance（越大越不重疊）
      - 選 k 個用 farthest-point sampling，讓彼此盡量遠

    若 out_selected_npy != None，會另外把選到的 labels/fnames 存成 npy。
    """
    import os
    import numpy as np
    from scipy.stats import gaussian_kde

    rng = np.random.RandomState(random_state)

    data = np.load(points_npy, allow_pickle=True).item()
    labels = data["labels"].astype(int).tolist()
    fnames = list(data["fnames"])
    points = list(data["points"])

    if k > len(labels):
        raise RuntimeError(f"k={k} > total labels={len(labels)}")

    # --- downsample points per label (optional) ---
    pts_list = []
    kept_labels = []
    kept_fnames = []
    for lab, fn, pts in zip(labels, fnames, points):
        pts = np.asarray(pts, dtype=np.float64)
        if pts.ndim != 2 or pts.shape[1] != 2:
            continue
        if (max_points_per_label is not None) and (pts.shape[0] > max_points_per_label):
            idx = rng.choice(pts.shape[0], size=max_points_per_label, replace=False)
            pts = pts[idx]
        pts_list.append(pts)
        kept_labels.append(int(lab))
        kept_fnames.append(fn)

    if k > len(kept_labels):
        raise RuntimeError(f"k={k} > valid labels={len(kept_labels)} (some labels had bad pts?)")

    # --- build grid from all points ---
    all_pts = np.vstack(pts_list)
    x_min, y_min = all_pts.min(axis=0)
    x_max, y_max = all_pts.max(axis=0)
    pad_x = 0.05 * (x_max - x_min + 1e-12)
    pad_y = 0.05 * (y_max - y_min + 1e-12)
    x_min -= pad_x; x_max += pad_x
    y_min -= pad_y; y_max += pad_y

    xs = np.linspace(x_min, x_max, grid_n)
    ys = np.linspace(y_min, y_max, grid_n)
    Xg, Yg = np.meshgrid(xs, ys, indexing="xy")
    grid = np.vstack([Xg.ravel(), Yg.ravel()])  # (2, G)

    # --- KDE -> probability vector P ---
    def kde_to_prob(pts):
        kde = gaussian_kde(pts.T, bw_method=bw_method)
        dens = kde(grid)
        dens = np.maximum(dens, 0.0)
        s = dens.sum()
        if s <= 0:
            dens = np.ones_like(dens) / dens.size
        else:
            dens = dens / s
        return dens.astype(np.float64)

    P = np.vstack([kde_to_prob(pts) for pts in pts_list])  # (M, G)
    M = P.shape[0]

    # --- Jensen–Shannon distance ---
    eps = 1e-12

    def js_dist(p, q):
        p = np.clip(p, eps, 1.0)
        q = np.clip(q, eps, 1.0)
        m = 0.5 * (p + q)
        js = 0.5 * (np.sum(p * (np.log(p) - np.log(m))) + np.sum(q * (np.log(q) - np.log(m))))
        return float(np.sqrt(max(js, 0.0)))

    # --- pick start ---
    avgP = P.mean(axis=0)

    if seed_label is not None:
        seed_label = int(seed_label)
        if seed_label not in kept_labels:
            raise RuntimeError(f"seed_label={seed_label} not found in points_npy.")
        start_i = kept_labels.index(seed_label)
    else:
        d0 = np.array([js_dist(P[i], avgP) for i in range(M)])
        start_i = int(np.argmax(d0))

    selected = [start_i]
    min_dist = np.full(M, np.inf, dtype=np.float64)
    for i in range(M):
        min_dist[i] = js_dist(P[i], P[start_i]) if i != start_i else 0.0

    # --- farthest-point sampling ---
    selected_set = set(selected)
    while len(selected) < k:
        next_i = int(np.argmax(min_dist))
        selected.append(next_i)
        selected_set.add(next_i)

        for i in range(M):
            if i in selected_set:
                min_dist[i] = 0.0
                continue
            d = js_dist(P[i], P[next_i])
            if d < min_dist[i]:
                min_dist[i] = d

    selected_labels = [kept_labels[i] for i in selected]
    selected_fnames = [kept_fnames[i] for i in selected]

    # optional save
    if out_selected_npy is not None:
        out = {
            "selected_labels": np.array(selected_labels, dtype=np.int32),
            "selected_fnames": np.array(selected_fnames, dtype=object),
            "source_points_npy": os.path.abspath(points_npy),
            "k": int(k),
            "grid_n": int(grid_n),
            "bw_method": str(bw_method),
            "max_points_per_label": (None if max_points_per_label is None else int(max_points_per_label)),
            "random_state": int(random_state),
            "seed_label": (None if seed_label is None else int(seed_label)),
        }
        os.makedirs(os.path.dirname(out_selected_npy) or ".", exist_ok=True)
        np.save(out_selected_npy, out, allow_pickle=True)
        print(f"[OK] saved: {out_selected_npy}")

    return selected_labels


In [121]:
out_npy = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/pca-overlapping/pca.npy'

out = select_low_overlap_labels_from_pca_npy(out_npy)
print(out)

[427, 452, 490, 141, 275, 33, 359, 46, 424, 279, 538, 9, 227, 111, 328, 230, 270, 11, 476, 401, 445, 352, 7, 457, 277, 502, 462, 215, 140, 144, 394, 246, 337, 243, 199, 193, 335, 544, 362, 351, 311, 421, 58, 332, 211, 406, 37, 338, 475, 113, 440, 220, 39, 569, 75, 342, 183, 190, 165, 298, 274, 580, 519, 209, 514, 492, 482, 76, 169, 218, 429, 572, 471, 320, 182, 384, 558, 485, 425, 525, 223, 152, 423, 300, 49, 307, 414, 184, 546, 586, 291, 478, 349, 224, 557, 110, 466, 120, 541, 501, 276, 51, 522, 123, 521, 285, 402, 304, 18, 185, 438, 428, 93, 565, 326, 45, 583, 322, 470, 380, 273, 115, 302, 54, 175, 548, 234, 1, 210, 409, 267, 129, 346, 396, 535, 107, 325, 53, 381, 444]


#### draw

In [None]:
def pca_from_npy_folder(npy_dir, out_png, wt_npy, wt_ref_frame=-1,
                        include_wt_in_plot=True, certain_protein_list=None,
                        out_csv=None):
    """
    Global PCA：用「全部軌跡的全部 frames」一起 fit（PC1/PC2 由全體決定）
    Align：所有軌跡用 wt_npy 的 wt_ref_frame 當 reference 做 Kabsch 對齊
    Plot：
      - 全部點先畫成灰色背景
      - certain_protein_list 內的「編號」才上色（點 + hull + 圈內編號）

    編號規則（非常重要）：
      - 以「成功載入且通過 shape/CA 檢查」的軌跡順序編號
      - 順序 = (wt 若 include_wt_in_plot=True) 然後 npy_dir 裡其餘 *.npy (sorted)
      - CSV 會輸出 label_id -> npy_file（含 .npy），圖上的標註也是 label_id

    certain_protein_list：
      - None：全部上色
      - 例如 [1, 3, 7]：只 highlight 編號 1/3/7（其他都灰色背景）
      - 建議只放 int（可接受 "1" 這種字串，會嘗試轉 int）

    out_csv：
      - 若 None，預設用 out_png 同名加上 "_labels.csv"
    """
    import os, glob, csv
    import numpy as np
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA
    from scipy.spatial import ConvexHull
    import matplotlib.patheffects as pe

    def kabsch(P, Q):
        Pc = P - P.mean(axis=0, keepdims=True)
        Qc = Q - Q.mean(axis=0, keepdims=True)
        C = Pc.T @ Qc
        V, _, Wt = np.linalg.svd(C)
        d = np.sign(np.linalg.det(V @ Wt))
        D = np.diag([1.0, 1.0, d])
        R = V @ D @ Wt
        return Pc @ R + Q.mean(axis=0, keepdims=True)

    def base_to_npy_filename(path_or_name):
        b = os.path.basename(str(path_or_name))
        if b.lower().endswith(".npy"):
            return b
        return b + ".npy"

    def to_int_list(x):
        if x is None:
            return None
        out = []
        for v in x:
            try:
                out.append(int(v))
            except Exception:
                print(f"[WARN] certain_protein_list contains non-int value: {v} (skip)")
        return out

    certain_protein_list = to_int_list(certain_protein_list)

    # ===== 1) load wt and define reference frame =====
    wt_arr = np.load(wt_npy)
    if wt_arr.ndim != 3 or wt_arr.shape[-1] != 3:
        raise RuntimeError(f"wt_npy bad shape: {wt_arr.shape}")

    T_wt, N_ref, _ = wt_arr.shape
    if not (-T_wt <= wt_ref_frame < T_wt):
        raise RuntimeError(f"wt_ref_frame out of range: {wt_ref_frame} (T_wt={T_wt})")

    ref = wt_arr[wt_ref_frame].astype(np.float64)  # (N,3)

    # ===== 2) collect ALL trajectories (aligned) to fit PCA =====
    npy_files = sorted(glob.glob(os.path.join(npy_dir, "*.npy")))
    if not npy_files:
        raise RuntimeError(f"No .npy files found in {npy_dir}")

    wt_abs = os.path.abspath(wt_npy)

    files_for_fit = []
    filenames_for_fit = []  # for CSV mapping

    if include_wt_in_plot:
        files_for_fit.append(wt_npy)
        filenames_for_fit.append(base_to_npy_filename(wt_npy))

    for f in npy_files:
        if os.path.abspath(f) == wt_abs:
            continue
        files_for_fit.append(f)
        filenames_for_fit.append(base_to_npy_filename(f))

    X_all = []
    protein_ids = []           # per-frame pid
    pid_to_fname = {}          # pid -> file name (.npy)

    pid = 0
    for f, fname in zip(files_for_fit, filenames_for_fit):
        arr = np.load(f)
        if arr.ndim != 3 or arr.shape[-1] != 3:
            print(f"[SKIP] {fname}: bad shape {arr.shape}")
            continue

        T, N, _ = arr.shape
        if N != N_ref:
            raise RuntimeError(f"CA count mismatch: {fname} has N={N}, but wt has N={N_ref}")

        rows = []
        for t in range(T):
            P = arr[t].astype(np.float64)
            P = kabsch(P, ref)
            rows.append(P.reshape(-1).astype(np.float32))
        rows = np.vstack(rows)

        X_all.append(rows)
        protein_ids.extend([pid] * T)
        pid_to_fname[pid] = fname

        print(f"[OK] load+align {fname}: frames={T}, CA={N}")
        pid += 1

    if len(X_all) == 0:
        raise RuntimeError("No valid trajectories for PCA fit.")

    X_all = np.vstack(X_all)
    protein_ids = np.array(protein_ids)

    unique_pids = sorted(set(protein_ids.tolist()))

    # ===== 3) fit PCA on ALL frames =====
    pca = PCA(n_components=2, svd_solver="randomized", random_state=0)
    Z_all = pca.fit_transform(X_all)
    print("[PCA basis] fitted on ALL trajectories")
    print("[PCA] explained variance ratio:", pca.explained_variance_ratio_)

    # ===== 4) assign global labels (1..n) for ALL valid trajectories =====
    # label_id is defined over valid trajectories (unique_pids) in order.
    pid_to_label = {pid_: (i + 1) for i, pid_ in enumerate(unique_pids)}
    label_to_pid = {v: k for k, v in pid_to_label.items()}

    # ===== 4.5) decide which to highlight (BY LABEL) =====
    if certain_protein_list is None:
        highlight_pids = unique_pids[:]  # 全部 highlight
    else:
        highlight_pids = []
        for lab in certain_protein_list:
            if lab in label_to_pid:
                highlight_pids.append(label_to_pid[lab])
            else:
                print(f"[WARN] label {lab} out of range (valid: 1..{len(unique_pids)})")

        # 去重但保留順序
        seen = set()
        highlight_pids = [p for p in highlight_pids if not (p in seen or seen.add(p))]

    # ===== 4.6) write CSV mapping (label_id -> npy_file) =====
    if out_csv is None:
        base, _ = os.path.splitext(out_png)
        out_csv = base + "_labels.csv"

    os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)
    with open(out_csv, "w", newline="", encoding="utf-8") as fp:
        w = csv.writer(fp)
        w.writerow(["label_id", "npy_file"])
        for pid_ in unique_pids:
            w.writerow([pid_to_label[pid_], pid_to_fname[pid_]])
    print(f"[OK] saved: {out_csv}")

    # 色盤只給被 highlight 的那幾個
    cmap = plt.cm.get_cmap("hsv", max(1, len(highlight_pids)))
    pid_to_color = {pid_: cmap(i) for i, pid_ in enumerate(highlight_pids)}

    # ===== 5) plot =====
    plt.figure(figsize=(10, 8))

    # 背景灰色點：全部 frames
    plt.scatter(
        Z_all[:, 0], Z_all[:, 1],
        s=6,
        alpha=0.03,
        color="gray",
        zorder=1
    )

    spread_eps = 1e-4

    # 只畫被選中的（上色 + hull + label number）
    for pid_ in highlight_pids:
        mask = protein_ids == pid_
        pts = Z_all[mask]
        color = pid_to_color[pid_]
        lab = pid_to_label[pid_]

        plt.scatter(
            pts[:, 0], pts[:, 1],
            s=10,
            alpha=0.25,
            color=color,
            zorder=3
        )

        if pts.shape[0] >= 3 and np.linalg.norm(pts.std(axis=0)) >= spread_eps:
            try:
                hull = ConvexHull(pts, qhull_options="QJ")
                hull_pts = pts[hull.vertices]
                closed = np.vstack([hull_pts, hull_pts[0]])

                plt.plot(
                    closed[:, 0], closed[:, 1],
                    color=color,
                    linewidth=0.8,
                    alpha=0.95,
                    zorder=6
                )
                plt.fill(
                    closed[:, 0], closed[:, 1],
                    facecolor=color,
                    alpha=0.05,
                    linewidth=0,
                    zorder=2
                )
            except Exception as e:
                print(f"[HULL FAIL] label={lab} file={pid_to_fname[pid_]}: {e}")

        # 標註：放在中心，顯示編號（label_id）
        center = pts.mean(axis=0)
        txt = plt.text(
            center[0], center[1],
            str(lab),
            fontsize=9,
            color="black",
            ha="center",
            va="center",
            alpha=0.9,
            zorder=10
        )
        txt.set_path_effects([pe.withStroke(linewidth=2.5, foreground="white")])

    plt.xlabel("PC1 (9.9%)")
    plt.ylabel("PC2 (9.4%)")
    plt.tight_layout()
    plt.savefig(out_png, dpi=300)
    plt.close()
    print(f"[OK] saved: {out_png}")


In [None]:
npy = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/npy'
png = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/pca-overlapping/wt-align-all-pca-all.png'
wt = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/npy/wt.npy'
csv = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/pca-overlapping/number.csv'
certain_protein_list = [427, 452, 490, 141, 275,
 33, 359, 46, 424, 279, 538, 9, 227, 111, 328, 
 230, 270, 11, 476, 401, 445, 352, 7, 457, 277, 
 502, 462, 215, 140, 144, 394, 246, 337, 243, 199, 
 193, 335, 544, 362, 351, 311, 421, 58, 332, 211, 
 406, 37, 338, 475, 113, 440, 220, 39, 569, 75, 
 342, 183, 190, 165, 298, 274, 580, 519, 209, 514, 
 492, 482, 76, 169, 218, 429, 572, 471, 320, 182, 
 384, 558, 485, 425, 525, 223, 152, 423, 300, 49, 
 307, 414, 184, 546, 586, 291, 478, 349, 224, 557, 
 110, 466, 120, 541, 501, 276, 51, 522, 123, 521, 
 285, 402, 304, 18, 185, 438, 428, 93, 565, 326, 
 45, 583, 322, 470, 380, 273, 115, 302, 54, 175, 
 548, 234, 1, 210, 409, 267, 129, 346, 396, 535, 
 107, 325, 53, 381, 444]
pca_from_npy_folder(npy,png,wt,out_csv=csv)

# pca [0.09918377 0.0941855 ]

#### number->protein, csv

In [125]:
def label_numbers_to_protein_csv(number_list, label_csv_path, out_csv_path):
    """
    input:
        number_list     : list[int | str]
        label_csv_path  : str, CSV with columns: label_id,npy_file
        out_csv_path    : str, output CSV path

    output:
        write protein_list.csv with column: npy_file
    """
    import csv
    import os

    # --- 1) number_list -> unique label_id list (preserve order) ---
    seen = set()
    labels = []
    for x in number_list:
        try:
            xi = int(x)
            if xi not in seen:
                seen.add(xi)
                labels.append(xi)
        except Exception:
            print(f"[WARN] invalid label value skipped: {x}")

    # --- 2) read label_id -> npy_file mapping ---
    label_to_file = {}
    with open(label_csv_path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            label_to_file[int(row["label_id"])] = row["npy_file"]

    # --- 3) map labels -> protein list ---
    protein_list = []
    for lab in labels:
        if lab in label_to_file:
            protein_list.append(label_to_file[lab])
        else:
            print(f"[WARN] label_id {lab} not found in CSV")

    # --- 4) write output CSV ---
    os.makedirs(os.path.dirname(out_csv_path) or ".", exist_ok=True)
    with open(out_csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["npy_file"])
        for p in protein_list:
            writer.writerow([p])

    print(f"[OK] saved: {out_csv_path}")


In [126]:
csv = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/pca-overlapping/number.csv'
protein_csv = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/pca-overlapping/select_protein.csv'
certain_protein_list = [427, 452, 490, 141, 275,
 33, 359, 46, 424, 279, 538, 9, 227, 111, 328, 
 230, 270, 11, 476, 401, 445, 352, 7, 457, 277, 
 502, 462, 215, 140, 144, 394, 246, 337, 243, 199, 
 193, 335, 544, 362, 351, 311, 421, 58, 332, 211, 
 406, 37, 338, 475, 113, 440, 220, 39, 569, 75, 
 342, 183, 190, 165, 298, 274, 580, 519, 209, 514, 
 492, 482, 76, 169, 218, 429, 572, 471, 320, 182, 
 384, 558, 485, 425, 525, 223, 152, 423, 300, 49, 
 307, 414, 184, 546, 586, 291, 478, 349, 224, 557, 
 110, 466, 120, 541, 501, 276, 51, 522, 123, 521, 
 285, 402, 304, 18, 185, 438, 428, 93, 565, 326, 
 45, 583, 322, 470, 380, 273, 115, 302, 54, 175, 
 548, 234, 1, 210, 409, 267, 129, 346, 396, 535, 
 107, 325, 53, 381, 444]
label_numbers_to_protein_csv(certain_protein_list, csv, protein_csv)

[OK] saved: /mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/pca-overlapping/select_protein.csv


# analysis

## pca

In [1]:
def output_pca_kabsch_order(md_pdb, md_xtc, out_png,
                            out_npz=None,
                            align_sel="protein and name CA",
                            ref_frame=0,
                            max_frames=None,
                            random_state=0):
    """
    用「純幾何 + Kabsch 對齊」做 PCA（單一 trajectory）

    - 不分 fit / eval
    - 不用 stride
    - PCA 用全部 frame
    - 圖上顏色 = frame order（時間順序）

    Parameters
    ----------
    md_pdb : str
        Topology PDB
    md_xtc : str
        Trajectory XTC
    out_png : str
        Output figure path
    """

    import numpy as np
    import mdtraj as md
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA

    # ---------- Kabsch rotation ----------
    def kabsch_rot(P, Q):
        H = P.T @ Q
        U, S, Vt = np.linalg.svd(H)
        R = Vt.T @ U.T
        if np.linalg.det(R) < 0:
            Vt[-1, :] *= -1
            R = Vt.T @ U.T
        return R

    def align_to_ref(P, Qc, cQ):
        cP = P.mean(axis=0)
        Pc = P - cP
        R = kabsch_rot(Pc, Qc)
        return Pc @ R + cQ

    # ---------- Load trajectory ----------
    traj = md.load(md_xtc, top=md_pdb)

    # if max_frames is not None:
    traj = traj[-10:]

    atom_idx = traj.topology.select(align_sel)
    if atom_idx is None or atom_idx.size == 0:
        raise ValueError(f"align_sel returned 0 atoms: {align_sel}")

    xyz = traj.xyz[:, atom_idx, :].astype(np.float64)
    T, N, _ = xyz.shape

    if ref_frame < 0 or ref_frame >= T:
        raise ValueError(f"ref_frame out of range: {ref_frame}")

    # ---------- Reference ----------
    ref = xyz[ref_frame].copy()
    cQ = ref.mean(axis=0)
    Qc = ref - cQ

    # ---------- Align all frames ----------
    aligned = np.empty_like(xyz)
    for t in range(T):
        aligned[t] = align_to_ref(xyz[t], Qc, cQ)

    # ---------- Flatten ----------
    X = aligned.reshape(T, N * 3)

    # ---------- PCA ----------
    pca = PCA(n_components=2, random_state=random_state)
    Z = pca.fit_transform(X)

    # ---------- Plot: color = frame order ----------
    plt.figure(figsize=(6, 5))

    order = np.arange(T)

    sc = plt.scatter(
        Z[:, 0], Z[:, 1],
        c=order,
        cmap="viridis",
        s=12,
        alpha=0.8
    )

    plt.xlabel("PC1")
    plt.ylabel("PC2")

    cbar = plt.colorbar(sc)
    cbar.set_label("Frame order")

    plt.tight_layout()
    plt.savefig(out_png, dpi=300)
    plt.close()

    # ---------- Save ----------
    if out_npz is not None:
        np.savez(
            out_npz,
            Z=Z,
            explained_variance_ratio=pca.explained_variance_ratio_,
            components=pca.components_,
            mean=pca.mean_,
            atom_idx=atom_idx,
            align_sel=align_sel,
            ref_frame=ref_frame,
        )



In [2]:
pdb_out = '/mnt/hdd/jeff/dataset/output/collagen/zh-wt/raw-data/collagen_wt.pdb'
xtc_out = '/mnt/hdd/jeff/dataset/output/collagen/zh-wt/raw-xtc/wt.xtc'
png = '/mnt/hdd/jeff/dataset/output/collagen/zh-wt/analysis/pca/wt-pca.png'
output_pca_kabsch_order(pdb_out, xtc_out, png)

## rmsf

### wt v.s. mutation

In [3]:
def rmsf_geo_wt_mut(wt_pdb, wt_xtc, mut_pdb, mut_xtc, out_png,k,
                    stride=1,
                    align_sel="protein and name CA",
                    w=27):
    """
    純幾何 RMSF（WT vs MUT 對照版）

    流程：
      (1) 讀 WT / MUT 的 pdb+xtc，抓 CA 的 xyz
      (2) 以 WT 的 frame0 CA 作 reference
      (3) WT / MUT 的每個 frame 都用 Kabsch 對齊到 WT frame0
      (4) 各自計算 RMSF
      (5) WT / MUT 疊在同一張圖上

    注意：
      - WT / MUT 的 CA 拓撲順序必須一致
      - 若有 PBC / broken molecule，請先處理
      - x 軸使用 CA index（0-based）

    回傳：
      wt_rmsf_nm, mut_rmsf_nm
    """
    import numpy as np
    import mdtraj as md
    import matplotlib.pyplot as plt

    # ---------- Kabsch rotation ----------
    def kabsch_rot(P, Q):
        H = P.T @ Q
        U, S, Vt = np.linalg.svd(H)
        R = Vt.T @ U.T
        if np.linalg.det(R) < 0:
            Vt[-1, :] *= -1
            R = Vt.T @ U.T
        return R

    def align_traj(xyz, Qc, cQ):
        """
        xyz: (T,N,3)
        回傳 aligned: (T,N,3)
        """
        T = xyz.shape[0]
        aligned = np.empty_like(xyz)
        for t in range(T):
            P = xyz[t]
            cP = P.mean(axis=0)
            Pc = P - cP
            R = kabsch_rot(Pc, Qc)
            aligned[t] = Pc @ R + cQ
        return aligned

    # ---------- Load WT ----------
    wt_traj = md.load(wt_xtc, top=wt_pdb)
    wt_idx = wt_traj.topology.select(align_sel)
    if wt_idx.size == 0:
        raise ValueError("WT align_sel returned 0 atoms")

    wt_xyz = wt_traj.xyz[-10:, wt_idx, :].astype(np.float64)
    Tw, N, _ = wt_xyz.shape

    # ---------- Load MUT ----------
    mut_traj = md.load(mut_xtc, top=mut_pdb)
    mut_idx = mut_traj.topology.select(align_sel)
    if mut_idx.size == 0:
        raise ValueError("MUT align_sel returned 0 atoms")

    mut_xyz = mut_traj.xyz[-10:, mut_idx, :].astype(np.float64)
    Tm, Nm, _ = mut_xyz.shape

    if Nm != N:
        raise ValueError(f"WT and MUT CA counts differ: {N} vs {Nm}")

    # ---------- WT reference (frame0) ----------
    ref = wt_xyz[0].copy()        # (N,3)
    cQ = ref.mean(axis=0)
    Qc = ref - cQ

    # ---------- Align both to WT reference ----------
    wt_aligned = align_traj(wt_xyz, Qc, cQ)
    mut_aligned = align_traj(mut_xyz, Qc, cQ)

    # ---------- RMSF ----------
    def calc_rmsf(aligned):
        mean_pos = aligned.mean(axis=0)
        disp = aligned - mean_pos[None, :, :]
        return np.sqrt(np.mean(np.sum(disp**2, axis=2), axis=0))

    wt_rmsf_nm = calc_rmsf(wt_aligned)
    mut_rmsf_nm = calc_rmsf(mut_aligned)
    def moving_average(y, w):
        return np.convolve(y, np.ones(w) / w, mode="same")
    # ---------- Plot ----------
    x = np.arange(N)

    # smooth
    wt_smooth = moving_average(wt_rmsf_nm, w)
    mut_smooth = moving_average(mut_rmsf_nm, w)

    plt.figure(figsize=(10, 4))

    # raw
    plt.plot(x, wt_rmsf_nm, color="gray", alpha=0.3, linewidth=0.5, label="WT (raw)")
    plt.plot(x, mut_rmsf_nm, color="blue", alpha=0.3, linewidth=0.5, label="Mutant (raw)")

    # smooth
    plt.plot(x, wt_smooth, color="gray", alpha=0.9, linewidth=2, label="WT (smooth, w=10)")
    plt.plot(x, mut_smooth, color="blue", alpha=0.9, linewidth=2, label="Mutant (smooth, w=10)")

    plt.xlabel("CA index (0-based)")
    plt.ylabel("RMSF (nm)")
    plt.axvline(k, color="black", linestyle="--", linewidth=1, alpha=0.7, label="Mutation")
    plt.legend(frameon=False)
    plt.tight_layout()
    plt.savefig(out_png, dpi=300)
    plt.close()


    return wt_rmsf_nm, mut_rmsf_nm


In [4]:
wt_pdb = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/wt/pdbxtc/wt.pdb'
wt_xtc = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/wt/pdbxtc/wt.xtc'
mut_pdb = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/pdbxtc/135_ARG_1-a1.pdb'
mut_xtc = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/pdbxtc/135_ARG_1-a1.xtc'
out_png = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/rmsf/wt_vs_135_ARG_1-a1_rmsf.png'
rmsf_geo_wt_mut(wt_pdb, wt_xtc, mut_pdb, mut_xtc, out_png,135)

(array([0.140728  , 0.11954648, 0.11148779, ..., 0.04316649, 0.04127674,
        0.04580501]),
 array([0.10331045, 0.08309535, 0.06321198, ..., 0.08928415, 0.08736391,
        0.08278624]))

## dccm

In [5]:
def residue_correlation_map_geo(pdb_path, xtc_path, out_png,
                               stride=1,
                               align_sel="protein and name CA",
                               ref_frame=0,
                               show_chain_boundaries=False,
                               tick_mode="resSeq"):
    """
    用「純幾何 + Kabsch 對齊」自己算 DCCM (residue–residue correlation)，不呼叫任何套裝的 corr/superpose。

    Input
      - pdb_path, xtc_path
      - out_png: 輸出 heatmap
      - stride: 取樣步長
      - align_sel: 預設抓 CA
      - ref_frame: 用哪個 frame 當 reference (Kabsch 對齊目標)
      - show_chain_boundaries: True 會在圖上畫 chain 分界線（前提是 topology chain 資訊正確）
      - tick_mode: "resSeq" 或 "index"（x/y 軸 tick label）

    DCCM 定義
      C_ij = <Δr_i · Δr_j> / sqrt(<|Δr_i|^2><|Δr_j|^2>)
      Δr_i(t) = r_i(t) – <r_i>
      這裡的 r_i(t) 是「對齊到 reference 後」的座標。

    Return
      - C: (N,N) correlation matrix
      - labels: list of dict, 每個 CA 對應的 residue 資訊
    """
    import numpy as np
    import mdtraj as md
    import matplotlib.pyplot as plt

    # ---------- Kabsch rotation (pure geometry) ----------
    def kabsch_rot(P, Q):
        """
        P, Q: (N,3) 已扣掉 centroid
        回傳 R (3,3) 使得 P @ R ~ Q
        """
        H = P.T @ Q
        U, S, Vt = np.linalg.svd(H)
        R = Vt.T @ U.T
        if np.linalg.det(R) < 0:
            Vt[-1, :] *= -1
            R = Vt.T @ U.T
        return R

    def align_to_ref(P, Qc, cQ):
        """
        P: (N,3) 單一 frame 的座標
        Qc: reference 的 centered 座標 (Q - cQ)
        cQ: reference centroid
        回傳 aligned: (N,3)
        """
        cP = P.mean(axis=0)
        Pc = P - cP
        R = kabsch_rot(Pc, Qc)
        return Pc @ R + cQ

    # ---------- Load trajectory ----------
    traj = md.load(xtc_path, top=pdb_path)

    atom_idx = traj.topology.select(align_sel)
    if atom_idx.size == 0:
        raise ValueError(f"align_sel returned 0 atoms: {align_sel}")

    xyz = traj.xyz[-10:, atom_idx, :].astype(np.float64)  # (T,N,3) nm
    T, N, _ = xyz.shape
    if ref_frame < 0 or ref_frame >= T:
        raise ValueError(f"ref_frame out of range: {ref_frame}, valid: 0..{T-1}")

    # ---------- Build labels (residue info) ----------
    # 注意：這裡 assume align_sel 選到的是 CA 或至少每個 atom 對應一個 residue
    labels = []
    for a in traj.topology.atoms:
        # 只挑我們選到的 atom
        # 這樣做 O(n_atoms) 但比較穩；N 通常不大（CA 數量）
        pass
    idx_set = set(atom_idx.tolist())
    for a in traj.topology.atoms:
        if a.index in idx_set:
            res = a.residue
            chain = res.chain
            labels.append({
                "atom_index": a.index,
                "resSeq": res.resSeq,
                "resName": res.name,      # mdtraj 的 residue.name 是 resName
                "chain_index": chain.index,
                "chain_id": getattr(chain, "id", None),
            })
    if len(labels) != N:
        # 保守處理：如果順序不一致，就用 atom_idx 的順序重建
        atom_to_info = {}
        for a in traj.topology.atoms:
            if a.index in idx_set:
                res = a.residue
                chain = res.chain
                atom_to_info[a.index] = {
                    "atom_index": a.index,
                    "resSeq": res.resSeq,
                    "resName": res.name,
                    "chain_index": chain.index,
                    "chain_id": getattr(chain, "id", None),
                }
        labels = [atom_to_info[i] for i in atom_idx.tolist()]

    # ---------- Reference centered ----------
    ref = xyz[ref_frame].copy()
    cQ = ref.mean(axis=0)
    Qc = ref - cQ

    # ---------- Pass 1: compute mean position after alignment ----------
    sum_pos = np.zeros((N, 3), dtype=np.float64)
    for t in range(T):
        aligned = align_to_ref(xyz[t], Qc, cQ)
        sum_pos += aligned
    mean_pos = sum_pos / float(T)

    # ---------- Pass 2: accumulate numerator and msd ----------
    num_sum = np.zeros((N, N), dtype=np.float64)
    msd_sum = np.zeros((N,), dtype=np.float64)

    for t in range(T):
        aligned = align_to_ref(xyz[t], Qc, cQ)
        disp = aligned - mean_pos  # (N,3)

        # num_ij += disp_i · disp_j  ->  (N,3) @ (3,N) = (N,N)
        num_sum += disp @ disp.T

        # msd_i += |disp_i|^2
        msd_sum += np.sum(disp * disp, axis=1)

    num = num_sum / float(T)
    msd = msd_sum / float(T)

    denom = np.sqrt(np.outer(msd, msd))
    C = np.zeros_like(num)
    mask = denom > 0
    C[mask] = num[mask] / denom[mask]
    C = np.clip(C, -1.0, 1.0)

    # 對角線如果 denom>0 會接近 1；若某些點 msd=0 導致 0/0，這裡強制設 1
    np.fill_diagonal(C, 1.0)

    # ---------- Plot ----------
    plt.figure(figsize=(7.5, 6.5))
    im = plt.imshow(C,
                    origin="lower",
                    vmin=-1, vmax=1,
                    interpolation="nearest")

    plt.colorbar(im,
                fraction=0.046,
                pad=0.04,
                label="Correlation (DCCM)")

    plt.xlabel("Residue index")
    plt.ylabel("Residue index")
    plt.title("Residue–Residue Motion Correlation (Kabsch-aligned DCCM)")

    # === 關鍵：統一用 index 座標 ===
    n = N
    step = max(1, n // 6)   # 不要太密
    ticks = np.arange(0, n, step)

    plt.xticks(ticks, ticks)
    plt.yticks(ticks, ticks)

    # Optional: chain boundaries（這個 OK）
    if show_chain_boundaries:
        chain_idx = np.array([lab["chain_index"] for lab in labels], dtype=int)
        cuts = np.where(chain_idx[1:] != chain_idx[:-1])[0] + 1
        for c in cuts:
            plt.axvline(c - 0.5, lw=0.8)
            plt.axhline(c - 0.5, lw=0.8)

    plt.tight_layout()
    plt.savefig(out_png, dpi=300)
    plt.close()


    return C, labels


In [7]:
wt_pdb = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/wt/pdbxtc/wt.pdb'
wt_xtc = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/wt/pdbxtc/wt.xtc'
mut_pdb = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/pdbxtc/135_ARG_1-a1.pdb'
mut_xtc = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/pdbxtc/135_ARG_1-a1.xtc'
out_png = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/dccm/wt_dccm.png'
residue_correlation_map_geo(wt_pdb,wt_xtc,out_png)

(array([[ 1.        ,  0.90031602,  0.82076923, ..., -0.43879915,
         -0.44547181, -0.35486085],
        [ 0.90031602,  1.        ,  0.95917657, ..., -0.51340818,
         -0.52933817, -0.30248465],
        [ 0.82076923,  0.95917657,  1.        , ..., -0.53149405,
         -0.57715614, -0.32100196],
        ...,
        [-0.43879915, -0.51340818, -0.53149405, ...,  1.        ,
          0.91020428,  0.74343009],
        [-0.44547181, -0.52933817, -0.57715614, ...,  0.91020428,
          1.        ,  0.74764828],
        [-0.35486085, -0.30248465, -0.32100196, ...,  0.74343009,
          0.74764828,  1.        ]]),
 [{'atom_index': 4,
   'resSeq': 1,
   'resName': 'LEU',
   'chain_index': 0,
   'chain_id': None},
  {'atom_index': 23,
   'resSeq': 2,
   'resName': 'SER',
   'chain_index': 0,
   'chain_id': None},
  {'atom_index': 34,
   'resSeq': 3,
   'resName': 'TYR',
   'chain_index': 0,
   'chain_id': None},
  {'atom_index': 55,
   'resSeq': 4,
   'resName': 'GLY',
   'chain_ind

## pca

In [73]:
def pca_project_plot_joint(atlas_pdb, atlas_xtc, mdgen_pdb, mdgen_xtc, k, out_png,
                           ca_only=True, remove_h=True,
                           ref_frame=0,
                           s=6, alpha=0.8):
    """
    需求：
      - 全部都 align 到 atlas[ref_frame]
      - PCA 的 PC1/PC2 用 (atlas + mdgen) 對齊後的全部 frames 一起 fit（joint PCA）
      - atlas / mdgen 的所有 frame 都投影到 joint PC1/PC2
      - k: 'atlas' / 'mdgen' / 'both' 決定要畫誰

    依賴：mdtraj, numpy, sklearn, matplotlib
    """
    import numpy as np
    import mdtraj as md
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA

    if k not in ["atlas", "mdgen", "both"]:
        raise ValueError("k must be 'atlas', 'mdgen', or 'both'")

    def _atom_key(atom):
        return (atom.residue.chain.index, atom.residue.index, atom.name)

    def _common_atom_indices(top_a, top_b):
        keys_a = [_atom_key(a) for a in top_a.atoms]
        keys_b = [_atom_key(a) for a in top_b.atoms]
        set_b = set(keys_b)

        common_keys = [ka for ka in keys_a if ka in set_b]
        if len(common_keys) == 0:
            raise RuntimeError("No common atoms found between atlas and mdgen topologies.")

        # 用 dict 加速（避免 index() O(N^2)）
        key_to_idx_a = {ka: i for i, ka in enumerate(keys_a)}
        key_to_idx_b = {kb: i for i, kb in enumerate(keys_b)}

        idx_a = [key_to_idx_a[ck] for ck in common_keys]
        idx_b = [key_to_idx_b[ck] for ck in common_keys]
        return np.array(idx_a, dtype=int), np.array(idx_b, dtype=int)

    def _filter_indices_by_sel(traj, atom_indices, ca_only=True, remove_h=True):
        sel = []
        for i in atom_indices:
            a = traj.topology.atom(int(i))
            if remove_h and a.element is not None and a.element.symbol == "H":
                continue
            if ca_only and a.name != "CA":
                continue
            sel.append(int(i))
        if len(sel) == 0:
            raise RuntimeError("After filtering (CA/H), no atoms left to align/project.")
        return np.array(sel, dtype=int)

    # --- load ---
    atlas = md.load(atlas_xtc, top=atlas_pdb)
    mdgen = md.load(mdgen_xtc, top=mdgen_pdb)

    # --- get common atoms, then (optional) CA-only and/or remove H ---
    idx_a, idx_m = _common_atom_indices(atlas.topology, mdgen.topology)
    idx_a = _filter_indices_by_sel(atlas, idx_a, ca_only=ca_only, remove_h=remove_h)
    idx_m = _filter_indices_by_sel(mdgen, idx_m, ca_only=ca_only, remove_h=remove_h)

    atlas_c = atlas.atom_slice(idx_a, inplace=False)
    mdgen_c = mdgen.atom_slice(idx_m, inplace=False)

    # --- align everything to atlas[ref_frame] ---
    if ref_frame < 0:
        ref_frame = atlas_c.n_frames + ref_frame
    if ref_frame < 0 or ref_frame >= atlas_c.n_frames:
        raise ValueError(f"ref_frame out of range: {ref_frame}")

    ref0 = atlas_c[ref_frame]
    atlas_c.superpose(ref0)
    mdgen_c.superpose(ref0)

    # --- joint PCA: fit on ALL frames (atlas + mdgen) ---
    X_atlas = atlas_c.xyz.reshape(atlas_c.n_frames, -1)  # nm
    X_mdgen = mdgen_c.xyz.reshape(mdgen_c.n_frames, -1)

    X_all = np.vstack([X_atlas, X_mdgen])

    pca = PCA(n_components=2)
    pca.fit(X_all)

    atlas_pc = pca.transform(X_atlas)
    mdgen_pc = pca.transform(X_mdgen)

    pc1_var = pca.explained_variance_ratio_[0] * 100.0
    pc2_var = pca.explained_variance_ratio_[1] * 100.0

    # --- plot: color by dataset (atlas vs mdgen) ---
    plt.figure(figsize=(6.2, 5.2))

    if k in ["atlas", "both"]:
        plt.scatter(atlas_pc[:, 0], atlas_pc[:, 1],
                    s=s, alpha=0.25 if k == "both" else alpha,
                    color="gray", label="wt")

    if k in ["mdgen", "both"]:
        # 如果同圖疊兩者，Atlas 做淡背景會比較清楚
        plt.scatter(mdgen_pc[:, 0], mdgen_pc[:, 1],
                    s=s, alpha=alpha,
                    color="tab:blue", label="135_ARG_1-a1")

    plt.xlabel(f"PC1 ({pc1_var:.1f}%)")
    plt.ylabel(f"PC2 ({pc2_var:.1f}%)")
    plt.legend(frameon=False)

    plt.tight_layout()
    plt.savefig(out_png, dpi=300)
    plt.close()


In [74]:
wt_pdb = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/wt/pdbxtc/wt.pdb'
wt_xtc = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/wt/pdbxtc/wt.xtc'
mut_pdb = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/pdbxtc/135_ARG_1-a1.pdb'
mut_xtc = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/pdbxtc/135_ARG_1-a1.xtc'
out_png = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/pca/135_ARG_pca.png'
pca_project_plot_joint(wt_pdb, wt_xtc, mut_pdb, mut_xtc, 'both',out_png)

## neq

In [8]:
def save_neq_smooth_png_wt_mut(wt_pdb, wt_xtc, mut_pdb, mut_xtc,
                              k,
                              out_png="Neq_wt_vs_mut.png",
                              w1=27):
    """
    Compute Neq-like backbone conformational entropy from WT and MUT (pdb+xtc),
    plot raw Neq with smoothed curves, and draw a vertical line at mutation k.

    - Neq computed from discretized (phi, psi) states: 0–8
    - Smoothing: moving average window = w1
    - x-axis: residue resSeq (from mdtraj topology)

    Input
      - wt_pdb, wt_xtc
      - mut_pdb, mut_xtc
      - k: mutation position in resSeq (used for axvline)
    Return
      - wt_res, wt_neq, wt_neq_s
      - mut_res, mut_neq, mut_neq_s
    """
    import numpy as np
    import mdtraj as md
    import matplotlib.pyplot as plt

    def smooth(x, w):
        if w <= 1:
            return x.copy()
        kernel = np.ones(w) / w
        return np.convolve(x, kernel, mode="same")

    def compute_neq(pdb_path, xtc_path):
        traj = md.load(xtc_path, top=pdb_path)
        traj = traj[-10:]

        phi_idx, phi = md.compute_phi(traj)
        psi_idx, psi = md.compute_psi(traj)

        # discretize Ramachandran space into 3x3 = 9 states
        bins = np.deg2rad([-180, -60, 60, 180])
        phi_bin = np.digitize(phi, bins) - 1
        psi_bin = np.digitize(psi, bins) - 1
        states = phi_bin * 3 + psi_bin  # 0–8

        neq = np.zeros(states.shape[1], dtype=np.float64)
        for i in range(states.shape[1]):
            vals, counts = np.unique(states[:, i], return_counts=True)
            p = counts / counts.sum()
            neq[i] = np.exp(-np.sum(p * np.log(p)))

        # residue resSeq for each (phi/psi) position
        residues = []
        for atom_ids in phi_idx:
            atom = traj.topology.atom(int(atom_ids[0]))
            residues.append(atom.residue.resSeq)
        residues = np.asarray(residues, dtype=int)

        return residues, neq

    # ---- WT ----
    wt_res, wt_neq = compute_neq(wt_pdb, wt_xtc)
    wt_neq_s = smooth(wt_neq, w1)

    # ---- MUT ----
    mut_res, mut_neq = compute_neq(mut_pdb, mut_xtc)
    mut_neq_s = smooth(mut_neq, w1)

    # ---- Plot ----
    plt.figure(figsize=(10, 3.6))

    # WT
    plt.plot(wt_res, wt_neq, color="0.7", lw=0.4, alpha=0.5, label="WT (raw Neq)")
    plt.plot(wt_res, wt_neq_s, color="0.4", lw=2.0, alpha=0.9, label=f"WT (smooth w={w1})")

    # MUT
    plt.plot(mut_res, mut_neq, color="blue", lw=0.4, alpha=0.25, label="Mutant (raw Neq)")
    plt.plot(mut_res, mut_neq_s, color="blue", lw=2.0, alpha=0.9, label=f"Mutant (smooth w={w1})")

    # mutation marker (resSeq)
    plt.axvline(k, color="black", linestyle="--", linewidth=1, alpha=0.7, label="Mutation")

    plt.xlabel("Residue number (resSeq)")
    plt.ylabel("Neq")
    plt.title("Backbone Conformational Entropy (Neq-like): WT vs Mutant")
    plt.legend(frameon=False, ncol=2)
    plt.tight_layout()
    plt.savefig(out_png, dpi=300)
    plt.close()

    return wt_res, wt_neq, wt_neq_s, mut_res, mut_neq, mut_neq_s


In [9]:
wt_pdb = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/wt/pdbxtc/wt.pdb'
wt_xtc = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/wt/pdbxtc/wt.xtc'
mut_pdb = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/pdbxtc/135_ARG_1-a1.pdb'
mut_xtc = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/pdbxtc/135_ARG_1-a1.xtc'
out_png = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/neq/wt-v.s.-135_ARG_neq.png'
save_neq_smooth_png_wt_mut(wt_pdb, wt_xtc, mut_pdb, mut_xtc, 135, out_png)

(array([   1,    2,    3, ..., 3131, 3132, 3133]),
 array([1.84202278, 1.        , 1.        , ..., 1.84202278, 1.        ,
        1.64938489]),
 array([0.72745418, 0.76449122, 0.81575587, ..., 0.896358  , 0.85932097,
        0.80805632]),
 array([   1,    2,    3, ..., 3131, 3132, 3133]),
 array([1.64938489, 1.        , 1.        , ..., 1.9601317 , 2.80009407,
        1.9601317 ]),
 array([0.84583673, 0.88287377, 0.95109683, ..., 1.04184857, 1.00481153,
        0.9677745 ]))

## rmsd

In [11]:
def rmsd_geo(pdb_path, xtc_path, out_png,
             stride=1,
             align_sel="protein and name CA",
             ref_frame=0,
             y_range=None):
    """
    純幾何 RMSD（不呼叫 mdtraj/MDAnalysis 的 RMSD/superpose）：
      (1) 讀 pdb+xtc，抓出 align_sel 的 xyz
      (2) 以 ref_frame 的選取原子作 reference，用 Kabsch 對每個 frame 對齊到 reference
      (3) 對齊後，對每個 frame 計算 RMSD = sqrt( mean_i ||r_i(t) - r_i(ref)||^2 )
          並畫成 out_png（x= frame index, y= RMSD(nm)）

    參數：
      - stride: 取樣步長
      - align_sel: 選取原子（預設 CA）
      - ref_frame: reference frame index（注意會作用在 stride 後的 frame 序列上）
      - y_range: None（預設自動）或 (ymin, ymax) / [ymin, ymax]，例如 [0, 1]

    回傳：
      sel_indices (np.ndarray): selection 的 atom index（global）
      rmsd_nm (np.ndarray): 每個 frame 的 RMSD（nm）
    """
    import numpy as np
    import mdtraj as md
    import matplotlib.pyplot as plt

    if stride is None:
        stride = 1
    if stride <= 0:
        raise ValueError(f"stride must be >= 1, got {stride}")

    # --------- Kabsch rotation (pure geometry) ----------
    def kabsch_rot(P, Q):
        H = P.T @ Q
        U, S, Vt = np.linalg.svd(H)
        R = Vt.T @ U.T
        if np.linalg.det(R) < 0:
            Vt[-1, :] *= -1
            R = Vt.T @ U.T
        return R

    # --------- Load trajectory ----------
    traj = md.load(xtc_path, top=pdb_path)
    traj = traj[-10:]

    # 1) selection xyz (nm)
    sel_indices = traj.topology.select(align_sel)
    if sel_indices is None or sel_indices.size == 0:
        raise ValueError(f"align_sel returned 0 atoms: {align_sel}")

    xyz = traj.xyz[::stride, sel_indices, :].astype(np.float64)  # (T, N, 3)
    T, N, _ = xyz.shape

    if ref_frame < 0 or ref_frame >= T:
        raise ValueError(f"ref_frame out of range: ref_frame={ref_frame}, T={T}")

    # Reference coordinates
    ref = xyz[ref_frame].copy()  # (N,3)
    cQ = ref.mean(axis=0)
    Qc = ref - cQ

    # Per-frame RMSD
    rmsd_nm = np.empty(T, dtype=np.float64)

    # 2) Align each frame to ref, then compute RMSD
    for t in range(T):
        P = xyz[t]
        cP = P.mean(axis=0)
        Pc = P - cP
        R = kabsch_rot(Pc, Qc)
        aligned = Pc @ R + cQ

        d = aligned - ref
        rmsd_nm[t] = np.sqrt(np.mean(np.sum(d**2, axis=1)))

    # --------- Plot ----------
    fig, ax = plt.subplots(figsize=(10, 4))
    x = np.arange(T)

    ax.plot(x, rmsd_nm)
    ax.set_xlabel("Frame index (0-based)")
    ax.set_ylabel("RMSD (nm)")
    ax.set_ylim(0,1)

    # ---- y-axis range (optional) ----
    if y_range is not None:
        if (not hasattr(y_range, "__len__")) or len(y_range) != 2:
            raise ValueError("y_range must be None or a (ymin, ymax) / [ymin, ymax] pair, e.g. [0, 1]")
        ymin, ymax = float(y_range[0]), float(y_range[1])
        ax.set_ylim(ymin, ymax)

    fig.tight_layout()
    fig.savefig(out_png, dpi=300)
    plt.close(fig)


In [13]:
wt_pdb = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/wt/pdbxtc/wt.pdb'
wt_xtc = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/wt/pdbxtc/wt.xtc'
mut_pdb = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/pdbxtc/135_ARG_1-a1.pdb'
mut_xtc = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/pdbxtc/135_ARG_1-a1.xtc'
out_png = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/rmsd/mut-rmsd.png'
rmsd_geo(mut_pdb, mut_xtc, out_png)

## anm

In [10]:
def compute_anm_corr_matrix(pdb_path, out_npy,
                            chain_id=None,
                            cutoff=15.0,
                            gamma=1.0,
                            ignore_hetatm=True):
    """
    從 PDB 用 CA-based ANM 計算 residue–residue correlation matrix，
    並將結果存成 .npy（shape = [N, N]）。

    回傳：
      corr : np.ndarray (N, N)
    """

    import numpy as np

    # ---------- PDB parser (CA only) ----------
    def _parse_ca_coords(pdb_file, chain_id, ignore_hetatm):
        coords = []
        seen = set()

        with open(pdb_file, "r", encoding="utf-8", errors="ignore") as f:
            for line in f:
                rec = line[0:6].strip()
                if rec not in ("ATOM", "HETATM"):
                    continue
                if ignore_hetatm and rec == "HETATM":
                    continue
                if line[12:16].strip() != "CA":
                    continue

                ch = line[21].strip() or " "
                if chain_id is not None and ch != chain_id:
                    continue

                resseq_str = line[22:26].strip()
                if not resseq_str:
                    continue
                resseq = int(resseq_str)
                icode = line[26].strip() or " "

                key = (ch, resseq, icode)
                if key in seen:
                    continue
                seen.add(key)

                try:
                    x = float(line[30:38])
                    y = float(line[38:46])
                    z = float(line[46:54])
                except:
                    continue

                coords.append([x, y, z])

        if len(coords) < 3:
            raise ValueError("CA atoms too few (<3).")

        return np.asarray(coords, dtype=float)

    # ---------- auto-detect chain ----------
    if chain_id is None:
        with open(pdb_path, "r", encoding="utf-8", errors="ignore") as f:
            for line in f:
                if line.startswith("ATOM") and line[12:16].strip() == "CA":
                    chain_id = line[21].strip() or " "
                    break

    xyz = _parse_ca_coords(pdb_path, chain_id, ignore_hetatm)
    N = xyz.shape[0]

    # ---------- Build Hessian ----------
    H = np.zeros((3 * N, 3 * N), dtype=float)
    cutoff2 = cutoff * cutoff

    for i in range(N):
        for j in range(i + 1, N):
            d = xyz[j] - xyz[i]
            dist2 = float(d @ d)
            if dist2 <= 1e-12 or dist2 > cutoff2:
                continue

            kij = gamma * np.outer(d, d) / dist2
            si = slice(3 * i, 3 * i + 3)
            sj = slice(3 * j, 3 * j + 3)

            H[si, sj] -= kij
            H[sj, si] -= kij
            H[si, si] += kij
            H[sj, sj] += kij

    # ---------- Pseudo-inverse (remove rigid-body modes) ----------
    evals, evecs = np.linalg.eigh(H)
    idx = np.argsort(evals)
    evals = evals[idx]
    evecs = evecs[:, idx]

    eps = 1e-8
    n_zero = max(6, np.sum(evals < eps))
    inv_evals = np.zeros_like(evals)
    inv_evals[n_zero:] = 1.0 / evals[n_zero:]

    H_pinv = (evecs * inv_evals) @ evecs.T

    # ---------- Correlation matrix ----------
    C_trace = np.zeros((N, N))
    for i in range(N):
        si = slice(3 * i, 3 * i + 3)
        for j in range(N):
            sj = slice(3 * j, 3 * j + 3)
            C_trace[i, j] = np.trace(H_pinv[si, sj])

    diag = np.diag(C_trace).copy()
    diag[diag <= 0] = np.nan
    corr = C_trace / np.sqrt(np.outer(diag, diag))
    corr = np.nan_to_num(corr)
    corr = np.clip(corr, -1.0, 1.0)

    np.save(out_npy, corr)
    return corr


In [9]:
def plot_anm_corr_from_npy(npy_path, out_png,
                           title=None,
                           vmin=-1.0,
                           vmax=1.0,
                           dpi=300):
    """
    讀取 ANM correlation matrix (.npy)，畫 heatmap 並輸出 PNG。
    """

    import numpy as np
    import matplotlib.pyplot as plt

    corr = np.load(npy_path)
    N = corr.shape[0]

    plt.figure(figsize=(7.5, 6.5))
    im = plt.imshow(corr,
                    origin="lower",
                    vmin=vmin, vmax=vmax,
                    interpolation="nearest")

    plt.colorbar(im,
                fraction=0.046,
                pad=0.04,
                label="ANM Correlation")

    plt.xlabel("Residue index")
    plt.ylabel("Residue index")

    if title is None:
        title = f"ANM correlation (CA), N={N}"
    plt.title(title)

    plt.tight_layout()
    plt.savefig(out_png, dpi=dpi)
    plt.close()


In [14]:
pdb = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/wt/pdbxtc/135_ARG-last.pdb'
npy = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/anm/anm-135_ARG.npy'
png = '/mnt/hdd/jeff/dataset/output/collagen/zh-all/analysis/anm/anm-135_ARG.png'
compute_anm_corr_matrix(pdb,npy,chain_id='A')
plot_anm_corr_from_npy(npy,png)