In [23]:
import glob
import json
import os
import shutil
from utils import *
import numpy as np

In [None]:
# ---- Run configuration --------------------------------------------------------
# Prefix for every designed-sequence name, and the project root folder.
design_name = "test"
work_dir = "/home/ge/app/MSD_design/test"
pdbs_dir = os.path.join(work_dir, "pdbs")  # input state structures (*.pdb)

# LigandMPNN multi-state design settings.
seed = 111
fixed_pos = "A207 A208"    # passed to --fixed_residues (space separated)
designed_seqs_num = 100    # passed to --number_of_batches

MPNN_path = "/home/ge/app/MSD_design/LigandMPNN_MSD/run.py"
run_mpnn_msd_env = "/data/ge/conda/envs/mpnn_env"

# Rosetta params file for the RET ligand (loaded via -extra_res_fa).
ret_params = os.path.join(work_dir, "RET.params")

# Full single-sequence AlphaFold3 run: venv activate script + runner.
run_af3_env = "/venv/alphafold3_venv/bin/activate"
run_af3_py = "/home/ge/app/MSD_design/run_alphafold.py"

# Confidence-head-only AlphaFold3 run: conda env + runner.
run_only_con_env = "/data/ge/conda/envs/alphafold3"
run_only_con_py = "/home/ge/app/MSD_design/run_only_confidence1.py"

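Directory layout of `work_dir`:

```text
design_name/
|-- pdbs/                  # all state structures; sequences are designed against these states
|-- designed_sequences/    # where the designed sequences are saved
|   |-- packed/            # PyRosetta-packed models: each sequence packed onto each state structure
|   `-- seqs/              # the designed MSD sequences
|-- run_af3/               # results of full single-sequence AF3 runs
`-- run_only_con/          # results of running only the AF3 confidence head
```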

# Multi-state design (MSD) of sequences
---

In [None]:
# Multi-state design: every PDB in `pdbs_dir` is one state, and LigandMPNN
# designs sequences against all of them jointly (--multistate_design True).
# Sorted so the state order (multi.json, packing loop) is reproducible.
pdbs = sorted(glob.glob(f"{pdbs_dir}/*.pdb"))
if not pdbs:
    raise FileNotFoundError(f"No *.pdb state structures found in {pdbs_dir}")

# multi.json maps each state PDB path -> state name (file name up to the first '.').
state_names = {pdb: os.path.basename(pdb).split(".")[0] for pdb in pdbs}
multi_json = f"{pdbs_dir}/multi.json"
with open(multi_json, "w") as f:
    json.dump(state_names, f, indent=4)

MSD_bash_file = f"{work_dir}/MSD_design_sequences.sh"
MSD_out = f"{work_dir}/designed_sequences"
# Start from an empty output folder. Previously an existing folder was deleted
# but not recreated; now it exists afterwards in both cases.
if os.path.exists(MSD_out):
    shutil.rmtree(MSD_out)
os.makedirs(MSD_out)

# Uses the configured `seed` (it was hardcoded to 111 before).
# NOTE(review): --number_of_batches presumably yields one sequence per batch
# (LigandMPNN's default batch size) — confirm against the MSD fork.
# The script is executed with `bash`, so it does not need to be executable.
with open(MSD_bash_file, "w") as f:
    f.write(f"""#!/bin/bash
source /opt/conda/etc/profile.d/conda.sh
conda activate {run_mpnn_msd_env}
python {MPNN_path} \\
    --seed {seed} \\
    --pdb_path_multi "{multi_json}" \\
    --out_folder "{MSD_out}" \\
    --multistate_design True  \\
    --number_of_batches {designed_seqs_num} \\
    --fixed_residues "{fixed_pos}"
""")

In [None]:
! bash MSD_design_sequences.sh

In [24]:
def _mpnn_header_id(header):
    """Return the value of the `id=` field of a LigandMPNN FASTA header line."""
    for field in header.lstrip(">").split(","):
        field = field.strip()
        if field.startswith("id="):
            return field[len("id="):]
    raise ValueError(f"No 'id=' field in FASTA header: {header!r}")


# Re-derived (same value as in the design cell) so this cell can run after a
# kernel restart without re-running the design cell, which deletes its outputs.
MSD_out = f"{work_dir}/designed_sequences"
fa_files = sorted(glob.glob(f"{MSD_out}/seqs/*.fa"))
if not fa_files:
    raise FileNotFoundError(f"No LigandMPNN FASTA output (*.fa) found in {MSD_out}/seqs")
fa_file = fa_files[0]

# FASTA records are (header, sequence) line pairs. Record 0 (lines 0-1) is
# skipped; every later record is a designed sequence, named "<design_name>_<id>".
with open(fa_file, "r") as f:
    lines = [line.strip() for line in f]
designed_seqs = {}
for header, seq in zip(lines[2::2], lines[3::2]):
    designed_seqs[f"{design_name}_{_mpnn_header_id(header)}"] = seq


# Pack each designed sequence onto every state with PyRosetta
---

In [None]:
import pyrosetta
from pyrosetta import pose_from_pdb
from pyrosetta.rosetta.core.chemical import TypeSetMode
from pyrosetta.rosetta.core.pack.task import TaskFactory
from pyrosetta.rosetta.core.conformation import ResidueFactory
from pyrosetta.rosetta.core.pack.task.operation import InitializeFromCommandline, RestrictToRepacking
from pyrosetta.rosetta.protocols.minimization_packing import PackRotamersMover
# One-letter -> three-letter codes for the 20 canonical amino acids; used to look
# up full-atom residue types when threading a sequence onto a backbone.
aa1_to_aa3 = dict(
    zip(
        "ARNDCQEGHILKMFPSTWYV",
        (
            "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU", "GLY", "HIS", "ILE",
            "LEU", "LYS", "MET", "PHE", "PRO", "SER", "THR", "TRP", "TYR", "VAL",
        ),
    )
)
# Extra residue-type params (here the RET ligand) so PyRosetta can read the
# state PDBs; an empty list means no -extra_res_fa flag at all.
params = [ret_params]
extra_res_fa = "-extra_res_fa " + " ".join(params) if params else ""
pyrosetta.init(extra_res_fa)

def pack_sequence_on_backbone(pdb_file, sequence, scorefxn_name="ref2015", out_file="packed_output.pdb"):
    """Thread `sequence` onto the backbone in `pdb_file`, repack side chains and save the result.

    Letter i of `sequence` (1-based) replaces pose residue i. The sequence must
    therefore cover the first len(sequence) residues of the pose, and all of
    them must be protein residues. Residues after it (e.g. a ligand) are not
    mutated.

    Args:
        pdb_file: input structure (one MSD state).
        sequence: one-letter amino-acid sequence; every letter must be a key of `aa1_to_aa3`.
        scorefxn_name: name of the Rosetta score function used by the packer.
        out_file: path of the packed PDB to write; its parent folder is created if missing.

    Raises:
        KeyError: `sequence` contains a letter not in `aa1_to_aa3`.
        ValueError: `sequence` is longer than the pose, or would overwrite a non-protein residue.
        IOError: Rosetta reported that the packed pose could not be written.
    """
    pose = pose_from_pdb(pdb_file)
    if len(sequence) > pose.size():
        raise ValueError(
            f"Sequence length {len(sequence)} exceeds the {pose.size()} residues in {pdb_file}"
        )
    rts = pose.residue_type_set_for_pose(TypeSetMode.FULL_ATOM_t)

    for i, aa in enumerate(sequence, start=1):
        if not pose.residue(i).is_protein():
            raise ValueError(f"Residue {i} of {pdb_file} is not a protein residue; refusing to mutate it")
        new_res = ResidueFactory.create_residue(rts.name_map(aa1_to_aa3[aa]))
        # Third argument (orient_backbone=True) places the new residue onto the existing backbone.
        pose.replace_residue(i, new_res, True)
    # NOTE(review): residues are created from the plain residue type, so chain-terminal
    # residues may lose their terminus variants — confirm this is acceptable.

    tf = TaskFactory()
    tf.push_back(InitializeFromCommandline())
    tf.push_back(RestrictToRepacking())  # repack side chains only; no further sequence design

    packer = PackRotamersMover(pyrosetta.create_score_function(scorefxn_name))
    packer.task_factory(tf)
    packer.apply(pose)

    out_dir = os.path.dirname(out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # Rosetta's dump_pdb reports failure via its return value rather than raising.
    if pose.dump_pdb(out_file) is False:
        raise IOError(f"Failed to write packed structure to {out_file}")


In [None]:
# Thread every designed sequence onto every state backbone and repack.
# Output name: <seq_name>__<state_name>.pdb — later cells split on "__".
packed_dir = f"{MSD_out}/packed"
os.makedirs(packed_dir, exist_ok=True)  # don't rely on the design run having created it
for seqs_name, seqs in designed_seqs.items():
    for pdb in pdbs:
        pdb_name = os.path.basename(pdb).split(".")[0]
        packed_pdb_file = f"{packed_dir}/{seqs_name}__{pdb_name}.pdb"
        pack_sequence_on_backbone(pdb, seqs, out_file=packed_pdb_file)

# Run AF3 in single-sequence mode on each designed sequence (to get plddt_0)
---

In [None]:
def _reset_dir(path):
    """Delete `path` if it exists, then recreate it empty (parents included)."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


# Fresh AF3 workspace: <work_dir>/run_af3/{input_json, outputs}.
# (Previously an existing outputs folder was removed but not recreated.)
af3_dir = f"{work_dir}/run_af3"
_reset_dir(af3_dir)
af3_input_json_dir = f"{af3_dir}/input_json"
_reset_dir(af3_input_json_dir)
out_dir = f"{af3_dir}/outputs"
_reset_dir(out_dir)

# One AF3 input per designed sequence: protein chain A plus RET ligand chain B.
# af3_json comes from utils; single=True presumably builds a single-sequence
# (no-MSA) input, matching --run_data_pipeline=False below — confirm in utils.
for seqs_name, seq in designed_seqs.items():
    json_dic = af3_json(pdb_sequence=seq, chain_id=['A'], name=seqs_name, seed=[1], single=True, ligandccd=['RET'], ligand_id=['B'], modify=None)
    with open(f"{af3_input_json_dir}/{seqs_name}.json", "w") as f:
        json.dump(json_dic, f, indent=4)

af3_bash_file = f"{work_dir}/run_af3.sh"
with open(af3_bash_file, "w") as f:
    f.write(f"""#!/bin/bash
source "{run_af3_env}"
python {run_af3_py} \\
    --input_dir="{af3_input_json_dir}" \\
    --model_dir="/data/share/alphafold3"  \\
    --output_dir="{out_dir}" \\
    --max_template_date="9999-01-01" \\
    --run_data_pipeline=False
""")

In [None]:
! bash run_af3.sh

In [44]:
# plddt_0: mean per-atom pLDDT of each full single-sequence AF3 prediction,
# keyed by the AF3 output sub-folder (job) name. Paths are re-derived so the
# cell also works after a kernel restart.
af3_dir = f"{work_dir}/run_af3"
out_dir = os.path.join(af3_dir, "outputs")
af3_plddt_dic = {}
for conf_path in glob.glob(f"{out_dir}/*/*confidences.json"):
    if "summary" in conf_path:
        continue  # skip the *summary_confidences.json files
    job_name = conf_path.split("/")[-2]
    with open(conf_path, "r") as fh:
        af3_plddt_dic[job_name] = np.mean(json.load(fh)["atom_plddts"])

# Calculate the distogram cross-entropy loss
---

In [None]:
# Distogram cross-entropy loss of each packed model (one sequence on one state)
# against the distogram AF3 predicted for that sequence.
packed_pdbs = sorted(glob.glob(f"{packed_dir}/*.pdb"))

# AF3 job (output sub-folder) name -> distogram logits file. One dict lookup
# replaces the former nested search over all files.
# NOTE(review): AF3 may sanitize job names for folder names (e.g. lower-casing);
# keep design_name lowercase or confirm the names still match.
dist_logits_by_seq = {
    os.path.basename(os.path.dirname(path)): path
    for path in glob.glob(f"{out_dir}/*/distogram_logits.npz")
}

dist_loss = {}
for packed_pdb in packed_pdbs:
    packed_pdb_name = os.path.basename(packed_pdb).split(".")[0]
    seq_name = packed_pdb_name.split("__")[0]
    dist_logits_file = dist_logits_by_seq.get(seq_name)
    if dist_logits_file is None:
        # Previously `data` silently kept the logits of the previously matched sequence here.
        raise FileNotFoundError(f"No distogram_logits.npz for sequence {seq_name!r} under {out_dir}")
    with np.load(dist_logits_file) as data:  # closes the .npz file handle
        num_tokens = int(data['num_tokens'])
        # Drop padding beyond the real tokens.
        dist_logits = data['distogram_logits'][:num_tokens, :num_tokens, :]
    gt_pos = pdb_to_tensor(packed_pdb)
    mask = np.ones(gt_pos.shape[0])
    assert dist_logits.shape[0] == gt_pos.shape[0], "Distogram logits and ground truth positions must have the same number of residues."
    dist_loss[packed_pdb_name] = distogram_loss(dist_logits, gt_pos, mask)


# Run only the AF3 confidence head on the packed models (to get plddt_1, plddt_2)
---

In [None]:
# Confidence-head-only AF3 runs: one per packed model (designed sequence x state),
# with the packed structure passed via --structure_pdb_path.
only_con_dir = f"{work_dir}/run_only_con"
if os.path.exists(only_con_dir):
    shutil.rmtree(only_con_dir)
os.mkdir(only_con_dir)

only_con_sh = f"{work_dir}/run_only_con.sh"
only_input_dir = f"{only_con_dir}/input_json"
only_output_dir = f"{only_con_dir}/outputs"
if not os.path.exists(only_input_dir):
    os.mkdir(only_input_dir)

# Each input JSON is named after its packed PDB so the shell loop below can
# pair <name>.pdb with <name>.json.
packed_pdbs = glob.glob(f"{packed_dir}/*.pdb")
for packed_pdb in packed_pdbs:
    packed_name = packed_pdb.split("/")[-1].split(".")[0]
    seq = designed_seqs[packed_name.split("__")[0]]
    payload = af3_json(pdb_sequence=seq, chain_id=['A'], name=packed_name, seed=[1], single=True, ligandccd=['RET'], ligand_id=['B'], modify=None)
    with open(f"{only_input_dir}/{packed_name}.json", "w") as fh:
        json.dump(payload, fh, indent=4)

script = f"""#!/bin/bash
source /opt/conda/etc/profile.d/conda.sh
conda activate {run_only_con_env}
INPUT_JSON="{only_input_dir}"
INPUT_PDB="{packed_dir}"
for pdb_file in "$INPUT_PDB"/*.pdb; do
    filename=$(basename "$pdb_file" .pdb)
    json_file="$INPUT_JSON/$filename.json"
    if [[ -f "$json_file" ]]; then
        echo "Processing $filename"
        python {run_only_con_py} \\
        --json_path="$json_file" \\
        --model_dir="/data/share/alphafold3" \\
        --output_dir="{only_output_dir}" \\
        --max_template_date="9999-01-01" \\
        --run_data_pipeline=False \\
        --structure_pdb_path="$pdb_file"
    else
        echo "JSON file $json_file does not exist for $filename"
    fi
done
"""
with open(only_con_sh, "w") as fh:
    fh.write(script)

In [None]:
! bash run_only_con.sh

In [38]:
# plddt_1 / plddt_2: mean per-atom pLDDT from the confidence-only runs, keyed by
# output sub-folder name (expected "<seq_name>__<state_name>").
# Re-derived (same value as in the setup cell) so this cell runs after a kernel restart.
only_output_dir = f"{work_dir}/run_only_con/outputs"
# Sorted so downstream per-state lists come out in a consistent state order.
only_con_plddt = sorted(glob.glob(f"{only_output_dir}/*/*_atomplddt.json"))
only_con_plddt_dic = {}
for plddt_file in only_con_plddt:
    with open(plddt_file, "r") as f:
        plddt_data = json.load(f)
    name = os.path.basename(os.path.dirname(plddt_file))
    # NOTE(review): "atom_plddts" is indexed with [0] here (unlike the full AF3 output),
    # i.e. it appears to be nested one level deeper — confirm against run_only_confidence1.py.
    only_con_plddt_dic[name] = np.mean(plddt_data["atom_plddts"][0])

# Integrate all information
---

In [None]:
# Gather everything per designed sequence: plddt_0 from the full AF3 run, plus one
# (pLDDT, distogram loss) pair per state from the confidence-only runs on the
# packed models. With two states this is three pLDDTs and two losses per sequence.
total_info = {}
for seqs_name, seq in designed_seqs.items():
    info = {"seq": seq, "plddt_0": af3_plddt_dic.get(seqs_name)}
    # Keys look like "<seq_name>__<state_name>". Match the full "<seq_name>__" prefix:
    # a bare startswith(seqs_name) also matched e.g. "test_10__..." for "test_1".
    # Sorted so every sequence lists its states in the same order.
    state_keys = sorted(key for key in only_con_plddt_dic if key.startswith(f"{seqs_name}__"))
    if state_keys:
        info["plddt_1_2"] = [only_con_plddt_dic[key] for key in state_keys]
        # A missing loss is recorded as None instead of crashing in float(None).
        info["dist_loss_1_2"] = [
            float(dist_loss[key]) if dist_loss.get(key) is not None else None
            for key in state_keys
        ]
        # State name for each position in the two lists above.
        info["states_1_2"] = [key.split("__", 1)[1] for key in state_keys]
    total_info[seqs_name] = info


In [52]:
# Write the per-sequence summary to <work_dir>/total_info.json.
total_info_file = f"{work_dir}/total_info.json"
with open(total_info_file, mode="w") as out_fh:
    json.dump(total_info, out_fh, indent=4)