### preprocess

#### dcd to pdb and xtc

In [2]:
def dcd_to_frame0_pdb_and_xtc(
    dcd_path,
    top_path,
    out_pdb,
    out_xtc,
    stride=1,
):
    """
    Convert a DCD trajectory to:
      (1) a PDB holding the first loaded frame (out_pdb)
      (2) an XTC trajectory (out_xtc)

    Requirements:
      pip install mdtraj

    Args:
      dcd_path: input .dcd
      top_path: topology file path (typically .pdb / .prmtop / .psf / etc.)
      out_pdb: output PDB filename for the first frame
      out_xtc: output XTC filename for the trajectory
      stride: keep every `stride` frames (1 = keep all)

    Returns:
      (out_pdb, out_xtc)

    Raises:
      FileNotFoundError: dcd_path or top_path is not an existing file.
      ValueError: stride is smaller than 1.
      RuntimeError: mdtraj cannot be imported, or the loaded trajectory
        has no frames.
    """
    import os

    if not os.path.isfile(dcd_path):
        raise FileNotFoundError(f"DCD not found: {dcd_path}")
    if not os.path.isfile(top_path):
        raise FileNotFoundError(f"Topology not found: {top_path}")
    if stride < 1:
        raise ValueError("stride must be >= 1")

    # Imported lazily so the cheap argument checks above run first.
    try:
        import mdtraj as md
    except Exception as e:
        raise RuntimeError("mdtraj is required. Install with: pip install mdtraj") from e

    traj = md.load_dcd(dcd_path, top=top_path, stride=stride)
    if traj.n_frames == 0:
        raise RuntimeError("Loaded trajectory has 0 frames. Check your DCD/topology.")

    # Make sure the destination folders exist; otherwise the save calls fail
    # only after the whole trajectory has been loaded.
    for out_path in (out_pdb, out_xtc):
        os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)

    # Save frame 0 as PDB
    traj[0].save_pdb(out_pdb)

    # Save full (strided) trajectory as XTC
    traj.save_xtc(out_xtc)

    return out_pdb, out_xtc
# Run the conversion: `pdb` is passed as the topology for `dcd`; the results
# are written as raw.pdb / raw.xtc in the wh-wt-nc1atm folder.
pdb = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-raw/initial.pdb'
dcd = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-raw/nc_1atm_last5ns.dcd'
out_pdb = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/raw.pdb'
out_xtc = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/raw.xtc'
dcd_to_frame0_pdb_and_xtc(dcd,pdb,out_pdb,out_xtc)

('/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/raw.pdb',
 '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/raw.xtc')

#### remove water

In [7]:
def remove_water_from_pdb_xtc(
    in_pdb,
    in_xtc,
    out_pdb,
    out_xtc,
    remove_resnames=("CLA", "SOD", "HOH", "WAT", "SOL", "TIP3", "TIP3P", "SPC", "SPCE"),
    stride=1,
):
    """
    Drop all atoms whose residue name is listed in `remove_resnames`.

    Despite the function name, the default list also contains the ion
    residue names "CLA" and "SOD" next to the water names.

    Args:
      in_pdb: topology PDB for in_xtc
      in_xtc: input XTC trajectory
      out_pdb: output PDB (first loaded frame of the reduced system)
      out_xtc: output XTC (reduced system)
      remove_resnames: residue names to remove; an empty sequence keeps
        every atom
      stride: passed to mdtraj.load (1 = keep all frames)

    Returns:
      (out_pdb, out_xtc)

    Raises:
      FileNotFoundError: in_pdb or in_xtc is not an existing file.
      RuntimeError: no atom is left after the removal.
    """
    import os

    if not os.path.isfile(in_pdb):
        raise FileNotFoundError(f"PDB not found: {in_pdb}")
    if not os.path.isfile(in_xtc):
        raise FileNotFoundError(f"XTC not found: {in_xtc}")

    import mdtraj as md

    traj = md.load(in_xtc, top=in_pdb, stride=stride)

    # Filter on the residue name directly instead of building a selection
    # string: an empty `remove_resnames` would otherwise produce the invalid
    # expression "not ()".
    drop_names = set(remove_resnames)
    keep_idx = [
        atom.index
        for atom in traj.topology.atoms
        if atom.residue.name not in drop_names
    ]

    if not keep_idx:
        raise RuntimeError("Selection resulted in 0 atoms. Check resnames / topology.")

    traj2 = traj.atom_slice(keep_idx)

    for out_path in (out_pdb, out_xtc):
        os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)

    traj2[0].save_pdb(out_pdb)
    traj2.save_xtc(out_xtc)

    return out_pdb, out_xtc



# =====================
# Remove the default `remove_resnames` residues from raw.pdb / raw.xtc and
# write the result as no-water.pdb / no-water.xtc.
in_pdb = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/raw.pdb'
in_xtc =  '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/raw.xtc'
out_pdb = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/no-water.pdb'
out_xtc = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/no-water.xtc'
remove_water_from_pdb_xtc(
    in_pdb,
    in_xtc,
    out_pdb,
    out_xtc,
)

('/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/no-water.pdb',
 '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/no-water.xtc')

#### residuetype - atom14

In [8]:
def in_pdb_residuetype_minus_atom14_residuetype(
    pdb_path,
    include_hetatm=True,
    ignore_resnames=("HOH", "WAT", "SOL", "TIP3", "TIP3P", "SPC", "SPCE"),
):
    """
    List the residue names in a PDB that are not one of the 20 standard
    amino acids.

    Args:
      pdb_path: PDB file to scan
      include_hetatm: also scan HETATM records (ATOM records are always
        scanned)
      ignore_resnames: residue names that are never reported

    Returns:
      Sorted list of residue names found in the file that are neither in
      the standard amino-acid set nor in `ignore_resnames`.

    Raises:
      FileNotFoundError: pdb_path is not an existing file.
    """
    import os

    if not os.path.isfile(pdb_path):
        raise FileNotFoundError(f"PDB not found: {pdb_path}")

    atom14_set = {
        "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
        "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
    }
    ignore_set = set(ignore_resnames)
    record_prefixes = ("ATOM  ", "HETATM") if include_hetatm else ("ATOM  ",)

    in_pdb_set = set()
    with open(pdb_path, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            if not line.startswith(record_prefixes) or len(line) < 20:
                continue

            # resName is columns 18-20. Column 21 is blank in the standard
            # format, but some MD tools put a fourth residue-name character
            # there (e.g. "TIP3"); reading it lets 4-letter entries of
            # `ignore_resnames` actually match.
            resname = line[17:21].strip()
            if not resname or resname in ignore_set:
                continue

            in_pdb_set.add(resname)

    return sorted(in_pdb_set - atom14_set)


# =====================
# Report residue names in no-water.pdb that are outside the 20 standard
# amino acids (the recorded output below is ['HYP']).
pdb_path = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/no-water.pdb'
in_pdb_residuetype_minus_atom14_residuetype(
    pdb_path)

['HYP']

#### HYP->PRO

In [9]:
def rename_hyp_to_pro_in_pdb_xtc(
    in_pdb,
    in_xtc,
    out_pdb,
    out_xtc,
):
    """
    Input:  pdb + xtc
    Output: pdb + xtc

    What it does:
    - In ATOM/HETATM records of the PDB, the residue name "HYP" is replaced
      by "PRO". Nothing else in the file is touched; in particular the atoms
      of those residues are kept exactly as they are.
    - An XTC stores no residue names, so it is copied unchanged to out_xtc
      (read it later with out_pdb as topology).
    - in_pdb == out_pdb and in_xtc == out_xtc are allowed (in-place use).

    Returns:
      (out_pdb, out_xtc, n_changed) where n_changed is the number of
      ATOM/HETATM lines that were rewritten.

    Raises:
      FileNotFoundError: in_pdb or in_xtc is not an existing file.
    """
    import os
    import shutil

    if not os.path.isfile(in_pdb):
        raise FileNotFoundError(f"PDB not found: {in_pdb}")
    if not os.path.isfile(in_xtc):
        raise FileNotFoundError(f"XTC not found: {in_xtc}")

    # Read the whole file before opening the output: if out_pdb is the same
    # file, opening it for writing first would truncate the input.
    with open(in_pdb, "r", encoding="utf-8", errors="ignore") as fin:
        lines = fin.readlines()

    n_changed = 0
    for i, line in enumerate(lines):
        if not line.startswith(("ATOM  ", "HETATM")) or len(line) < 20:
            continue
        # resName occupies columns 18-20 (0-based slice [17:20]).
        if line[17:20].strip() == "HYP":
            lines[i] = line[:17] + "PRO" + line[20:]
            n_changed += 1

    os.makedirs(os.path.dirname(out_pdb) or ".", exist_ok=True)
    with open(out_pdb, "w", encoding="utf-8") as fout:
        fout.writelines(lines)

    # copyfile raises SameFileError when source and destination coincide;
    # in that case there is nothing to copy.
    same_xtc = os.path.exists(out_xtc) and os.path.samefile(in_xtc, out_xtc)
    if not same_xtc:
        os.makedirs(os.path.dirname(out_xtc) or ".", exist_ok=True)
        shutil.copyfile(in_xtc, out_xtc)

    return out_pdb, out_xtc, n_changed

# =====================
# Rename HYP -> PRO in no-water.pdb and copy the trajectory alongside it
# (outputs: HYPPRO.pdb / HYPPRO.xtc).
in_pdb = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/no-water.pdb'
in_xtc = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/no-water.xtc'
out_pdb = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/HYPPRO.pdb'
out_xtc = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/HYPPRO.xtc'
rename_hyp_to_pro_in_pdb_xtc(
    in_pdb,
    in_xtc,
    out_pdb,
    out_xtc,
)

('/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/HYPPRO.pdb',
 '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/HYPPRO.xtc',
 4935)

#### reorder PRO atoms

In [31]:
import mdtraj as md
import os
from tqdm import tqdm
import shutil

PRO_TARGET_ATOMS = ['N', 'CA', 'C', 'CB', 'O', 'CG', 'CD']

def rebuild_traj_with_pro_order(traj):
    """
    Rebuild a trajectory so every PRO residue holds exactly the atoms in
    PRO_TARGET_ATOMS, in that order.

    All other residues keep their atoms (hydrogens included) in the
    original order. Coordinates are re-indexed to follow the new topology;
    time and unit-cell data are passed through.

    Returns:
      (new_traj, None) on success, or (None, message) when a PRO residue
      lacks one of the target atoms or the atom mapping is incomplete.
    """
    src_top = traj.topology
    dst_top = md.Topology()

    # source atom index -> atom index in the rebuilt topology
    index_map = {}

    for src_chain in src_top.chains:
        dst_chain = dst_top.add_chain()

        for src_res in src_chain.residues:
            dst_res = dst_top.add_residue(src_res.name, dst_chain, resSeq=src_res.resSeq)

            if src_res.name != 'PRO':
                # Non-PRO residues: every atom, original order.
                atoms_to_add = src_res.atoms
            else:
                by_name = {atom.name: atom for atom in src_res.atoms}
                missing = [name for name in PRO_TARGET_ATOMS if name not in by_name]
                if missing:
                    return None, f"PRO residue {src_res.index} missing atoms: {missing}"
                # PRO: only the target atoms, in the fixed target order.
                atoms_to_add = [by_name[name] for name in PRO_TARGET_ATOMS]

            for src_atom in atoms_to_add:
                dst_atom = dst_top.add_atom(src_atom.name, src_atom.element, dst_res)
                index_map[src_atom.index] = dst_atom.index

    # Carry over the bonds whose two atoms both survived.
    for first, second in src_top.bonds:
        if first.index in index_map and second.index in index_map:
            dst_top.add_bond(
                dst_top.atom(index_map[first.index]),
                dst_top.atom(index_map[second.index])
            )

    # Invert the mapping: position i holds the source index of new atom i,
    # which is the order needed to gather the coordinates.
    source_order = [None] * dst_top.n_atoms
    for src_index, dst_index in index_map.items():
        source_order[dst_index] = src_index

    if any(entry is None for entry in source_order):
        return None, "Atom mapping incomplete (unexpected)."

    reordered_xyz = traj.xyz[:, source_order, :]

    rebuilt = md.Trajectory(
        xyz=reordered_xyz,
        topology=dst_top,
        time=traj.time,
        unitcell_lengths=traj.unitcell_lengths,
        unitcell_angles=traj.unitcell_angles
    )
    return rebuilt, None


def process_and_copy_folders(input_dir, output_dir):
    """
    Apply rebuild_traj_with_pro_order to every subfolder of input_dir.

    For each direct subfolder that contains a .pdb and an .xtc file, the
    trajectory is loaded, rebuilt, and saved under the same file names in
    output_dir/<subfolder>. All other regular files of the subfolder are
    copied over unchanged. Subfolders without both files are skipped;
    failures are printed and do not stop the remaining folders.

    When a folder holds several .pdb or .xtc files, the alphabetically
    first one of each kind is used.
    """
    os.makedirs(output_dir, exist_ok=True)

    subfolders = sorted(f.name for f in os.scandir(input_dir) if f.is_dir())

    for folder_name in tqdm(subfolders, desc="Processing Folders"):
        in_subfolder = os.path.join(input_dir, folder_name)
        out_subfolder = os.path.join(output_dir, folder_name)

        # Sorted once so the file choice does not depend on directory order.
        entries = sorted(os.listdir(in_subfolder))
        pdb_file = next((f for f in entries if f.endswith(".pdb")), None)
        xtc_file = next((f for f in entries if f.endswith(".xtc")), None)
        if not pdb_file or not xtc_file:
            continue

        os.makedirs(out_subfolder, exist_ok=True)

        try:
            in_pdb_path = os.path.join(in_subfolder, pdb_file)
            in_xtc_path = os.path.join(in_subfolder, xtc_file)
            out_pdb_path = os.path.join(out_subfolder, pdb_file)
            out_xtc_path = os.path.join(out_subfolder, xtc_file)

            traj = md.load(in_xtc_path, top=in_pdb_path)

            new_traj, err = rebuild_traj_with_pro_order(traj)
            if err is not None:
                print(f"\n[Skip] {folder_name}: {err}")
                continue

            new_traj[0].save_pdb(out_pdb_path)
            new_traj.save_xtc(out_xtc_path)

            for file in entries:
                src = os.path.join(in_subfolder, file)
                # copy2 raises on directories, which would abort the copy of
                # every remaining file; only regular files are carried over.
                if file.endswith((".pdb", ".xtc")) or not os.path.isfile(src):
                    continue
                shutil.copy2(src, os.path.join(out_subfolder, file))

        except Exception as e:
            # Best effort per folder: report and move on to the next one.
            print(f"\nError processing {folder_name}: {e}")


# Rebuild every folder under segment/ with the fixed PRO atom order and write
# the results to segmentPRO/.
input_folder = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/segment'
output_folder = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/segmentPRO'
process_and_copy_folders(input_folder, output_folder)


Processing Folders: 100%|██████████| 80/80 [00:11<00:00,  6.86it/s]


#### residue id start from 1

In [32]:
import mdtraj as md
import os
from tqdm import tqdm
import shutil

def renumber_residues_from_1(input_dir, output_dir):
    """
    Renumber residues to start at 1 in every subfolder of input_dir.

    For each direct subfolder that contains a .pdb and an .xtc file, the
    topology is rebuilt with resSeq counting up from 1 within each chain;
    atom names, elements, bonds, coordinates, time and unit cell are
    carried over. Results are saved under the same file names in
    output_dir/<subfolder>, and all other regular files are copied.
    Subfolders without both files are skipped (no output folder is created
    for them); failures are printed and do not stop the remaining folders.

    When a folder holds several .pdb or .xtc files, the alphabetically
    first one of each kind is used.
    """
    os.makedirs(output_dir, exist_ok=True)

    subfolders = sorted(f.name for f in os.scandir(input_dir) if f.is_dir())

    for folder_name in tqdm(subfolders, desc="Renumber residues"):
        in_subfolder = os.path.join(input_dir, folder_name)
        out_subfolder = os.path.join(output_dir, folder_name)

        # Sorted once so the file choice does not depend on directory order.
        entries = sorted(os.listdir(in_subfolder))
        pdb_file = next((f for f in entries if f.endswith(".pdb")), None)
        xtc_file = next((f for f in entries if f.endswith(".xtc")), None)
        if not pdb_file or not xtc_file:
            continue

        os.makedirs(out_subfolder, exist_ok=True)

        try:
            in_pdb_path = os.path.join(in_subfolder, pdb_file)
            in_xtc_path = os.path.join(in_subfolder, xtc_file)
            out_pdb_path = os.path.join(out_subfolder, pdb_file)
            out_xtc_path = os.path.join(out_subfolder, xtc_file)

            traj = md.load(in_xtc_path, top=in_pdb_path)
            old_top = traj.topology

            # --- rebuild topology but renumber residues from 1 (per chain) ---
            new_top = md.Topology()
            old_to_new_atom = {}

            for old_chain in old_top.chains:
                new_chain = new_top.add_chain()

                for new_resSeq, old_res in enumerate(old_chain.residues, start=1):
                    new_res = new_top.add_residue(
                        old_res.name,
                        new_chain,
                        resSeq=new_resSeq
                    )

                    for old_atom in old_res.atoms:
                        new_atom = new_top.add_atom(old_atom.name, old_atom.element, new_res)
                        old_to_new_atom[old_atom.index] = new_atom.index

            # copy bonds
            for a1, a2 in old_top.bonds:
                if a1.index in old_to_new_atom and a2.index in old_to_new_atom:
                    new_top.add_bond(
                        new_top.atom(old_to_new_atom[a1.index]),
                        new_top.atom(old_to_new_atom[a2.index])
                    )

            # Coordinates are reused as-is: atoms were added in the same
            # chain -> residue -> atom traversal as the source topology.
            new_traj = md.Trajectory(
                xyz=traj.xyz,
                topology=new_top,
                time=traj.time,
                unitcell_lengths=traj.unitcell_lengths,
                unitcell_angles=traj.unitcell_angles
            )

            # save
            new_traj[0].save_pdb(out_pdb_path)
            new_traj.save_xtc(out_xtc_path)

            # copy other files
            for file in entries:
                src = os.path.join(in_subfolder, file)
                # copy2 raises on directories, which would abort the copy of
                # every remaining file; only regular files are carried over.
                if file.endswith((".pdb", ".xtc")) or not os.path.isfile(src):
                    continue
                shutil.copy2(src, os.path.join(out_subfolder, file))

        except Exception as e:
            # Best effort per folder: report and move on to the next one.
            print(f"\nError processing {folder_name}: {e}")

# example
# Renumber the segmentPRO/ folders and write the results to segment_order/.
input_folder = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/segmentPRO'
output_folder = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/segment_order'
renumber_residues_from_1(input_folder, output_folder)


Renumber residues: 100%|██████████| 80/80 [00:10<00:00,  7.34it/s]


#### check xtc frames

In [13]:
def count_frames_in_pdb_xtc(in_pdb, in_xtc):
    """
    Return the number of frames mdtraj loads from `in_xtc`, using `in_pdb`
    as topology.

    Raises:
      FileNotFoundError: in_pdb or in_xtc is not an existing file.
    """
    import os

    # Validate both inputs (PDB first) before the heavier mdtraj import.
    for label, path in (("PDB", in_pdb), ("XTC", in_xtc)):
        if not os.path.isfile(path):
            raise FileNotFoundError(f"{label} not found: {path}")

    import mdtraj as md

    return md.load(in_xtc, top=in_pdb).n_frames

# ====================
# Spot check: frame count of the 1-256 segment (recorded output below: 101).
in_pdb = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/segment/1-256/1-256.pdb'
in_xtc = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/segment/1-256/1-256.xtc'
count_frames_in_pdb_xtc(in_pdb,in_xtc)

101

### csv split

#### making csv

In [10]:
def split_pdb_to_seq_csv(
    in_pdb,
    out_csv,
    seg_len=256,
    n_segments=80,
    overlap=True,
    include_hetatm=False,
):
    """
    Input:  pdb
    Output: csv with columns: name, seqres

    - The sequence is cut in units of residues, in file order.
    - name format: "{start}-{end}" (1-based, inclusive), e.g. 1-256
    - seqres: one-letter amino-acid sequence of that window; residue names
      outside the 20 standard amino acids are written as "X".
    - overlap=True spreads the windows evenly over the sequence
      (step = (n_res - seg_len) // (n_segments - 1), at least 1);
      overlap=False uses step = seg_len.
    - A window that would run past the end is moved back so that it ends on
      the last residue. Windows that become identical this way are written
      only once, so the CSV may hold fewer than n_segments rows.

    Returns:
      out_csv

    Raises:
      FileNotFoundError: in_pdb is not an existing file.
      RuntimeError: no residue could be parsed.
      ValueError: the sequence is shorter than seg_len, or seg_len /
        n_segments is smaller than 1.
    """
    import os
    import csv

    if not os.path.isfile(in_pdb):
        raise FileNotFoundError(f"PDB not found: {in_pdb}")

    # 3-letter -> 1-letter
    restype_3to1 = {
        "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
        "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
        "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
        "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
    }

    record_prefixes = ("ATOM  ", "HETATM") if include_hetatm else ("ATOM  ",)

    # Parse the PDB: a new residue starts whenever
    # (chain_id, resSeq, iCode, resname) differs from the previous atom line.
    # Comparing with the previous line only (instead of a global "seen" set)
    # keeps residues whose numbering repeats later in the file, e.g. several
    # chains that share a chain ID or numbering that wraps around.
    residues = []
    prev_key = None

    with open(in_pdb, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            if line.startswith("TER"):
                # Whatever follows a TER record starts a new residue.
                prev_key = None
                continue
            if not line.startswith(record_prefixes) or len(line) < 54:
                continue
            resname = line[17:20].strip()
            chain_id = line[21].strip() or "_"
            resseq = line[22:26].strip()
            icode = line[26].strip() or "_"
            if not resseq:
                continue
            key = (chain_id, resseq, icode, resname)
            if key != prev_key:
                residues.append(resname)
                prev_key = key

    n_res = len(residues)
    if n_res == 0:
        raise RuntimeError("No residues parsed from PDB. Check file format.")
    if n_res < seg_len:
        raise ValueError(f"Sequence too short ({n_res}) for seg_len={seg_len}")

    if n_segments < 1:
        raise ValueError("n_segments must be >= 1")
    if seg_len < 1:
        raise ValueError("seg_len must be >= 1")

    if not overlap:
        step = seg_len
    elif n_segments == 1:
        step = 0
    else:
        step = max(1, (n_res - seg_len) // (n_segments - 1))

    os.makedirs(os.path.dirname(out_csv) or ".", exist_ok=True)

    last_start = n_res - seg_len
    written_names = set()

    with open(out_csv, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["name", "seqres"])

        for i in range(n_segments):
            # Clamp so the window never runs past the last residue.
            start0 = min(i * step, last_start)
            end0 = start0 + seg_len

            # 1-based inclusive range for name
            name = f"{start0 + 1}-{end0}"
            if name in written_names:
                # Clamped windows repeat; a duplicate row would only make
                # later steps overwrite the same segment.
                continue
            written_names.add(name)

            seq = "".join(restype_3to1.get(rn, "X") for rn in residues[start0:end0])
            writer.writerow([name, seq])

    return out_csv

# =====================
# Write the segment table for HYPPRO.pdb: overlapping windows of 256 residues,
# 80 requested, ATOM records only.
in_pdb = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/HYPPRO.pdb'
out_csv = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/HYPPRO_split.csv'
split_pdb_to_seq_csv(
    in_pdb,
    out_csv,
    seg_len=256,
    n_segments=80,
    overlap=True,
    include_hetatm=False,
)

'/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/HYPPRO_split.csv'

#### making segment pdb xtc

In [None]:
def cut_pdb_xtc_by_csv_ranges(
    in_pdb,
    in_xtc,
    in_csv,
    save_folder,
    stride=1,
    protein_only=True,
):
    """
    Cut a pdb/xtc pair into residue ranges taken from the `name` column of
    a CSV (format: start-end, e.g. 1-256) and write each range to:
      save_folder/start-end/start-end.pdb
      save_folder/start-end/start-end.xtc

    Assumptions:
    - The CSV has at least the column `name` (other columns such as
      `seqres` are not used here).
    - start/end are 1-based and end is inclusive.
    - With protein_only=True, start/end count along the protein residues
      only (the residues for which mdtraj's `r.is_protein` is true);
      otherwise they count along all residues of the topology.

    Returns:
      List of (name, out_pdb, out_xtc) tuples, one per written range.

    Raises:
      FileNotFoundError: in_pdb, in_xtc or in_csv is not an existing file.
      RuntimeError: the topology has no usable residues, or a range
        selects no atoms.
      ValueError: the CSV has no `name` column, or a name is not a valid
        range within the available residues.
    """
    import os
    import csv

    if not os.path.isfile(in_pdb):
        raise FileNotFoundError(f"PDB not found: {in_pdb}")
    if not os.path.isfile(in_xtc):
        raise FileNotFoundError(f"XTC not found: {in_xtc}")
    if not os.path.isfile(in_csv):
        raise FileNotFoundError(f"CSV not found: {in_csv}")

    # Imported after the cheap file checks, as in the other helpers.
    import mdtraj as md

    os.makedirs(save_folder, exist_ok=True)

    traj = md.load(in_xtc, top=in_pdb, stride=stride)
    topo = traj.topology

    if protein_only:
        res_list = [r.index for r in topo.residues if r.is_protein]
    else:
        res_list = [r.index for r in topo.residues]

    n_res = len(res_list)
    if n_res == 0:
        raise RuntimeError("No residues found (protein_only setting may be wrong).")

    # residue index -> atom indices, built once instead of rescanning every
    # atom of the topology for each CSV row.
    atoms_by_residue = {}
    for atom in topo.atoms:
        atoms_by_residue.setdefault(atom.residue.index, []).append(atom.index)

    written = []

    with open(in_csv, "r", encoding="utf-8", errors="ignore", newline="") as f:
        reader = csv.DictReader(f)
        # fieldnames is None for an empty file.
        if not reader.fieldnames or "name" not in reader.fieldnames:
            raise ValueError(f"CSV must contain a 'name' column, got: {reader.fieldnames}")

        for row in reader:
            name = (row.get("name") or "").strip()
            if not name:
                continue

            # parse "start-end"
            if "-" not in name:
                raise ValueError(f"Bad name format (expected start-end): {name}")
            s, e = name.split("-", 1)
            try:
                start = int(s)
                end = int(e)
            except ValueError as exc:
                raise ValueError(f"Bad name format (expected start-end): {name}") from exc

            if start < 1 or end < start:
                raise ValueError(f"Invalid range: {name}")
            if end > n_res:
                raise ValueError(f"Range {name} exceeds available residues (n_res={n_res})")

            # 1-based inclusive [start, end] -> 0-based slice [start-1:end]
            selected_residues = res_list[start - 1:end]

            # Sorted so atoms keep their topology order.
            atom_indices = sorted(
                atom_index
                for res_index in selected_residues
                for atom_index in atoms_by_residue.get(res_index, ())
            )
            if len(atom_indices) == 0:
                raise RuntimeError(f"Selection produced 0 atoms for range {name}")

            seg_traj = traj.atom_slice(atom_indices)

            seg_dir = os.path.join(save_folder, name)
            os.makedirs(seg_dir, exist_ok=True)

            out_pdb = os.path.join(seg_dir, f"{name}.pdb")
            out_xtc = os.path.join(seg_dir, f"{name}.xtc")

            seg_traj[0].save_pdb(out_pdb)
            seg_traj.save_xtc(out_xtc)

            written.append((name, out_pdb, out_xtc))

    return written


# =====================
# Cut HYPPRO.pdb / HYPPRO.xtc into one folder per CSV row under segment/.
in_pdb = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/HYPPRO.pdb'
in_xtc = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/HYPPRO.xtc'
in_csv = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/HYPPRO_split.csv'
save_folder = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/segment'
cut_pdb_xtc_by_csv_ranges(
    in_pdb,
    in_xtc,
    in_csv,
    save_folder,
    stride=1,
    protein_only=True,
)

### preprocess-npy

#### npy

python -m scripts.prep_sims#4 --split /mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/HYPPRO_split.csv --sim_dir /mnt/hdd/jeff/dataset/output/mdgen-collagen/segment --outdir /mnt/hdd/jeff/dataset/output/mdgen-collagen/npy --num_workers 4  --stride 40 --atlas

#### check npy

In [None]:
def print_npy_info(npy_path, max_items=20, max_width=120):
    """
    Print path, dtype, shape and a truncated preview of a .npy file.

    Args:
      npy_path: file to load
      max_items: numpy print `threshold` used for the preview
      max_width: numpy print `linewidth` used for the preview

    Returns:
      The loaded array.

    Raises:
      FileNotFoundError: npy_path is not an existing file.
    """
    import os

    if not os.path.isfile(npy_path):
        raise FileNotFoundError(f"NPY not found: {npy_path}")

    import numpy as np

    # NOTE(security): allow_pickle=True unpickles object arrays, which can run
    # arbitrary code; only use this on files from a trusted source.
    arr = np.load(npy_path, allow_pickle=True)
    print("path:", npy_path)
    print("dtype:", arr.dtype)
    print("shape:", arr.shape)

    # Content preview. The print options are applied only inside this block so
    # the caller's global numpy print settings are left untouched.
    with np.printoptions(threshold=max_items, linewidth=max_width):
        print("content preview:")
        print(arr)
    return arr

# ==================
# Inspect the array stored for segment 1-256.
npy_path = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/npy/1-256.npy'
print_npy_info(npy_path, max_items=20, max_width=120)

### inference model : mdgen-ckpt-from-paper

python sim_inference#4.py --sim_ckpt /mnt/hdd/jeff/dataset/output/mdgen-#1/ckpt/epoch=999-step=111000.ckpt --data_dir /mnt/hdd/jeff/dataset/output/mdgen-collagen/npy --num_frames 250 --num_rollouts 1 --split /mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/HYPPRO_split.csv --out_dir /mnt/hdd/jeff/dataset/output/mdgen-collagen/inference --xtc

### metrics

#### residue number+1

In [None]:
def shift_residue_number_plus_one_in_folder(folder_path, recursive=True):
    """
    Modify every .pdb in a folder in place: add 1 to the residue number
    (resSeq, columns 23-26) of all ATOM/HETATM records.

    - An .xtc stores coordinates only, so it is never modified; the function
      merely reports whether an .xtc with the same base name exists.
    - The shift is not idempotent: running the function again adds 1 again.
    - Records whose resSeq field is empty or not an integer are left as is.

    Args:
      folder_path: folder to process
      recursive: also process .pdb files in subfolders

    Returns:
      One dict per .pdb file with the keys "pdb", "xtc", "xtc_exists" and
      "pdb_lines_modified".

    Raises:
      NotADirectoryError: folder_path is not an existing directory.
      ValueError: a shifted residue number no longer fits the 4-column
        field. The offending file is not written, but files processed
        before it have already been modified.
    """
    import os

    if not os.path.isdir(folder_path):
        raise NotADirectoryError(f"Folder not found: {folder_path}")

    def iter_pdb_files(root):
        if recursive:
            for dp, _, fns in os.walk(root):
                for fn in fns:
                    if fn.lower().endswith(".pdb"):
                        yield os.path.join(dp, fn)
        else:
            for fn in os.listdir(root):
                if fn.lower().endswith(".pdb"):
                    yield os.path.join(root, fn)

    changed = []
    for pdb_path in iter_pdb_files(folder_path):
        with open(pdb_path, "r", encoding="utf-8", errors="ignore") as f:
            lines = f.readlines()

        n_changed = 0
        out_lines = []

        for line in lines:
            if line.startswith(("ATOM  ", "HETATM")) and len(line) >= 27:
                # PDB fixed columns: resSeq is columns 23-26 -> 0-based [22:26]
                resseq_str = line[22:26].strip()
                try:
                    resseq = int(resseq_str) + 1 if resseq_str else None
                except ValueError:
                    # Non-numeric field: keep the record untouched.
                    resseq = None

                if resseq is not None:
                    field = f"{resseq:4d}"
                    if len(field) > 4:
                        # A wider number would push every following column
                        # to the right and corrupt the record.
                        raise ValueError(
                            f"Residue number {resseq} does not fit the "
                            f"4-column resSeq field: {pdb_path}"
                        )
                    line = line[:22] + field + line[26:]
                    n_changed += 1
            out_lines.append(line)

        if n_changed > 0:
            with open(pdb_path, "w", encoding="utf-8") as f:
                f.writelines(out_lines)

        base = os.path.splitext(pdb_path)[0]
        xtc_path = base + ".xtc"
        has_xtc = os.path.isfile(xtc_path)

        changed.append({
            "pdb": pdb_path,
            "xtc": xtc_path,
            "xtc_exists": has_xtc,
            "pdb_lines_modified": n_changed,
        })

    return changed


# Shift the residue numbers of every PDB under inference/ by +1, in place.
# Running this cell again shifts the numbers once more.
folder_path = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/inference'
shift_residue_number_plus_one_in_folder(folder_path)

#### metrics

In [None]:
import os
import pandas as pd
import subprocess
from tqdm import tqdm

# Inputs and outputs of the per-target ensemble analysis.
inference_dir = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/inference'
csv = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/wh-wt-nc1atm/HYPPRO_split.csv'
log_dir = '/mnt/hdd/jeff/mdgen-#1/bin/log/collagen-0113'
atlas_dir = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/segment_order'
pkl_dir = "/mnt/hdd/jeff/dataset/output/mdgen-collagen/pkl"
os.makedirs(log_dir, exist_ok=True)
# The result folder is listed below to find finished targets, so it has to
# exist even on the very first run.
os.makedirs(pkl_dir, exist_ok=True)

## making list
df = pd.read_csv(csv)
names = [str(n) for n in df['name'].tolist()]
# Segments excluded from the analysis because they are broken.
broken_pdb = ["901-1156","937-1192"]
names = [n for n in names if n not in broken_pdb]
# Targets that already have a result file in pkl_dir are not run again.
done = {os.path.splitext(fn)[0] for fn in os.listdir(pkl_dir)}
names = [n for n in names if n not in done]

for name in tqdm(names, desc="Running Analysis"):
    print(name)
    command = [
        "python", "-m", "scripts.analyze_ensembles#4",
        "--atlas_dir", atlas_dir,
        "--pkl", pkl_dir,
        "--pdbdir", inference_dir,
        "--pdb_id", name,
        "--num_workers", "4"
    ]

    # One log file per target; stdout and stderr of the subprocess both go there.
    log_path = os.path.join(log_dir, f"{name}.log")
    try:
        with open(log_path, "w") as f:
            subprocess.run(command, check=True, stdout=f, stderr=subprocess.STDOUT, text=True)
    except subprocess.CalledProcessError:
        # Keep going with the remaining targets; details are in the log.
        print(f"Error analyzing {name}. See log: {log_path}")


#### 移除沒算成功的

In [None]:
import pickle
import os 
in_dir = "/mnt/hdd/jeff/dataset/output/mdgen-collagen/pkl"
out_dir = "/mnt/hdd/jeff/dataset/output/mdgen-collagen/clean-pkl"
# The cleaned copies are written here; create the folder on first use.
os.makedirs(out_dir, exist_ok=True)

for name in sorted(os.listdir(in_dir)):
    inp = os.path.join(in_dir,name)
    # Only regular .pkl files are result files; skip anything else.
    if not (name.endswith(".pkl") and os.path.isfile(inp)):
        continue
    outp = os.path.join(out_dir,name)

    # NOTE(security): pickle.load can execute arbitrary code; only run this
    # on result files produced by the analysis step itself.
    with open(inp, "rb") as f:
        data = pickle.load(f)

    # Keep only entries that are dicts and carry no "error" key.
    clean = {k: v for k, v in data.items() if isinstance(v, dict) and ("error" not in v)}

    print("before:", len(data), "after:", len(clean))

    with open(outp, "wb") as f:
        pickle.dump(clean, f)

#### 多個pkl合併為一個

In [16]:
import os
import glob
import pickle

def merge_pkl_folder(pkl_folder, out_pkl_path):
    """
    Merge every .pkl file of a folder into a single .pkl.

    Each input must hold a dict; the output is the flat union
    {name: out_dict, name2: out_dict2, ...}, the structure expected by
    analyze_data(data) in scripts.print_analysis. When the same key occurs
    in several files, the file that sorts last wins.

    The output file itself is never used as an input, so the function can
    be re-run with out_pkl_path located inside pkl_folder.

    Raises:
      FileNotFoundError: the folder holds no input .pkl file.
      ValueError: an input file does not contain a dict.
    """
    out_abs = os.path.abspath(out_pkl_path)
    pkl_files = sorted(
        fp
        for fp in glob.glob(os.path.join(pkl_folder, "*.pkl"))
        if os.path.abspath(fp) != out_abs
    )
    if not pkl_files:
        raise FileNotFoundError(f"No .pkl files found in: {pkl_folder}")

    merged = {}
    for fp in pkl_files:
        # NOTE(security): pickle.load can execute arbitrary code; only merge
        # files from a trusted source.
        with open(fp, "rb") as f:
            obj = pickle.load(f)

        if not isinstance(obj, dict):
            raise ValueError(f"{fp} is not a dict. Got type={type(obj)}")

        # Flatten: the (name -> out_dict) pairs go straight into the result,
        # without an extra nesting level per file.
        merged.update(obj)

    # dirname is "" for a bare file name; makedirs("") would fail.
    os.makedirs(os.path.dirname(out_pkl_path) or ".", exist_ok=True)
    with open(out_pkl_path, "wb") as f:
        pickle.dump(merged, f)

    print(f"Merged {len(pkl_files)} files -> {out_pkl_path}")
    print(f"Total targets in merged dict: {len(merged)}")

# =====================
# Merge the per-target pickles into print-pkl/merged.pkl.
# NOTE(review): this reads pkl/, not the filtered clean-pkl/ folder written by
# the previous cell — confirm that is intended.
pkl_folder = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/pkl'
out_pkl = '/mnt/hdd/jeff/dataset/output/mdgen-collagen/print-pkl/merged.pkl'
merge_pkl_folder(pkl_folder, out_pkl)

Merged 78 files -> /mnt/hdd/jeff/dataset/output/mdgen-collagen/print-pkl/merged.pkl
Total targets in merged dict: 78


python -m scripts.print_analysis /mnt/hdd/jeff/dataset/output/mdgen-collagen/print-pkl/merged.pkl

In [None]:
def check_abs_time_emb_from_ckpt(ckpt_path, trust_ckpt=True):
    """
    Inspect a checkpoint for signs of an absolute time embedding.

    The checkpoint is first loaded with weights_only=True. If that fails and
    `trust_ckpt` is true, it is loaded again with weights_only=False, which
    unpickles arbitrary objects — only do this for files you trust.

    The function then reports
    - whether any state-dict key ends with "time_embed" or contains
      ".time_embed",
    - the values of `abs_time_emb`, `no_rope` and `num_frames` found in
      ckpt["hyper_parameters"]["args"] (or in "hyper_parameters" itself),
      if present,
    - `inferred_abs_time_emb`: the hyper-parameter value when available,
      otherwise the result of the key check.

    Args:
      ckpt_path: checkpoint file
      trust_ckpt: allow the weights_only=False fallback

    Returns:
      Dict with the findings (also printed).

    Raises:
      FileNotFoundError: ckpt_path is not an existing file.
      RuntimeError: the checkpoint cannot be loaded in the permitted modes.
    """
    import os
    import torch
    import argparse

    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"ckpt not found: {ckpt_path}")

    def _extract_state_dict(obj):
        if isinstance(obj, dict):
            if "state_dict" in obj and isinstance(obj["state_dict"], dict):
                return obj["state_dict"]
            # some checkpoints may directly be a state_dict-like dict
            return obj
        return {}

    # --- 1) Try safest path: weights_only=True (no arbitrary code execution) ---
    ckpt = None
    state_dict = {}
    load_mode = None
    try:
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
        state_dict = _extract_state_dict(ckpt)
        load_mode = "weights_only=True"
    except Exception as e1:
        # --- 2) If user trusts the file, load it fully ---
        if not trust_ckpt:
            raise RuntimeError(
                "Failed to load with weights_only=True. "
                "Set trust_ckpt=True only if you trust the ckpt source.\n"
                f"Original error: {e1}"
            ) from e1
        try:
            # Allowlist the class that typically blocks safe unpickling.
            # The helper does not exist in older torch releases; its absence
            # must not prevent the full load below.
            add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
            if add_safe_globals is not None:
                add_safe_globals([argparse.Namespace])
            ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
            state_dict = _extract_state_dict(ckpt)
            load_mode = "weights_only=False (allowlisted argparse.Namespace)"
        except Exception as e2:
            raise RuntimeError(
                "Failed to load checkpoint in both safe and trusted modes.\n"
                f"Safe-mode error: {e1}\nTrusted-mode error: {e2}"
            ) from e2

    # --- check time embedding existence ---
    keys = list(state_dict.keys())
    has_time_embed_key = any(k.endswith("time_embed") or ".time_embed" in k for k in keys)

    # --- also try to read flags if present (best-effort) ---
    abs_time_emb_val = None
    no_rope_val = None
    num_frames_val = None

    if isinstance(ckpt, dict):
        hp = ckpt.get("hyper_parameters", {})
        # "hyper_parameters" may itself be a Namespace-like object rather
        # than a dict; only dicts can be searched for an "args" entry.
        args = hp.get("args", hp) if isinstance(hp, dict) else hp

        if isinstance(args, dict):
            abs_time_emb_val = args.get("abs_time_emb", None)
            no_rope_val = args.get("no_rope", None)
            num_frames_val = args.get("num_frames", None)
        else:
            # sometimes args is a Namespace or an object
            abs_time_emb_val = getattr(args, "abs_time_emb", None)
            no_rope_val = getattr(args, "no_rope", None)
            num_frames_val = getattr(args, "num_frames", None)

    # The explicit hyper-parameter wins; the key check is only a fallback.
    if abs_time_emb_val is not None:
        inferred_abs_time_emb = bool(abs_time_emb_val)
    else:
        inferred_abs_time_emb = bool(has_time_embed_key)

    result = {
        "ckpt_path": ckpt_path,
        "load_mode": load_mode,
        "has_time_embed_key_in_state_dict": has_time_embed_key,
        "abs_time_emb_in_hparams": abs_time_emb_val,   # may be None
        "inferred_abs_time_emb": inferred_abs_time_emb,
        "no_rope_in_hparams": no_rope_val,             # may be None
        "num_frames_in_hparams": num_frames_val,       # may be None
        "n_state_dict_keys": len(keys),
        "example_state_dict_keys": keys[:30],
    }

    print("=== CKPT CHECK ===")
    print("ckpt:", ckpt_path)
    print("loaded via:", load_mode)
    print("time_embed key in state_dict:", has_time_embed_key)
    print("abs_time_emb in hparams:", abs_time_emb_val)
    print("=> inferred abs_time_emb:", inferred_abs_time_emb)
    print("no_rope in hparams:", no_rope_val)
    print("num_frames in hparams:", num_frames_val)

    return result
# Inspect atlas.ckpt; trust_ckpt is left at its default (True), so the
# weights_only=False fallback is permitted for this file.
check_abs_time_emb_from_ckpt('/mnt/hdd/jeff/dataset/output/mdgen-#0/ckpt/atlas.ckpt')