In [1]:
import os
import esm
import time
import csv
import torch
from data import utils as du
from typing import Dict
from Bio import SeqIO


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Loaded ESMFold models keyed by device name. Re-instantiating the model for
# every sequence dominates runtime, so each device's model is built once and
# reused. Note this keeps the weights resident on the device between calls.
_ESMFOLD_MODELS = {}


def _get_esmfold_model(device):
    """Return an eval-mode ESMFold model on ``device``, loading it on first use."""
    key = str(device)
    if key not in _ESMFOLD_MODELS:
        model = esm.pretrained.esmfold_v1()
        _ESMFOLD_MODELS[key] = model.to(device).eval()
    return _ESMFOLD_MODELS[key]


def run_folding(sequence, save_path, target_gpu_index):
    """Run ESMFold on sequence.

    Args:
        sequence: Amino-acid sequence string passed to ``model.infer_pdb``.
        save_path: Path the predicted PDB text is written to (overwritten).
        target_gpu_index: CUDA device index to run on. Ignored when CUDA is
            not available, in which case the model runs on the CPU.

    Returns:
        Tuple ``(output, plddt)``: the PDB text returned by ``infer_pdb`` and
        the mean of the B-factor column of the structure re-read from
        ``save_path`` (presumably the per-atom pLDDT written by ESMFold --
        confirm against the ESM documentation).
    """
    # Imported first so a missing biotite install fails before the expensive fold.
    import biotite.structure.io as bsio

    # Fall back to CPU instead of leaving `model` undefined on hosts without CUDA.
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{target_gpu_index}")
    else:
        device = torch.device("cpu")
    model = _get_esmfold_model(device)

    with torch.no_grad():
        output = model.infer_pdb(sequence)

    with open(save_path, "w") as f:
        f.write(output)

    # The confidence score is read back from the file that was just written.
    struct = bsio.load_structure(save_path, extra_fields=["b_factor"])
    plddt = struct.b_factor.mean()
    return output, plddt


In [3]:
def run_esm_all(sample_sequences, decoy_pdb_dir, model_dataset, target_gpu_index=0):
    """Fold every sequence with ``run_folding`` and log the scores to a CSV.

    PDB files are written to ``<decoy_pdb_dir>/<model_dataset>/<sample>.pdb`` and
    one ``[sample, plddt]`` row per folded sample is appended to
    ``plddt_values.csv`` in the same directory. Samples whose PDB file already
    exists are skipped, so an interrupted run can be resumed by calling again.

    Args:
        sample_sequences: Mapping of sample name -> sequence string.
        decoy_pdb_dir: Root output directory (created if missing).
        model_dataset: Sub-directory name for this model/dataset combination.
        target_gpu_index: GPU index forwarded to ``run_folding`` (default 0,
            the value that was previously hard-coded).

    Returns:
        None. Results are written to disk and progress is printed.
    """
    total_samples = len(sample_sequences)
    start_time = time.time()

    # Create the per-dataset directory *before* opening the CSV inside it;
    # creating only ``decoy_pdb_dir`` made the first run for a new dataset fail.
    esmf_dir = os.path.join(decoy_pdb_dir, model_dataset)
    os.makedirs(esmf_dir, exist_ok=True)
    plddt_csv_path = os.path.join(esmf_dir, "plddt_values.csv")

    # A missing or empty file (e.g. left by an interrupted run) still needs a header.
    write_header = (
        not os.path.exists(plddt_csv_path) or os.path.getsize(plddt_csv_path) == 0
    )

    # Append mode so rows from earlier (possibly interrupted) runs are kept.
    with open(plddt_csv_path, "a", newline="") as csv_file:
        csv_writer = csv.writer(csv_file)
        if write_header:
            csv_writer.writerow(["Sample", "PLDDT"])
            csv_file.flush()

        for idx, (sample, sequence) in enumerate(sample_sequences.items(), 1):
            esmf_sample_path = os.path.join(esmf_dir, f'{sample}.pdb')

            # Skip samples that already have a PDB file on disk.
            if os.path.exists(esmf_sample_path):
                print(f"File {esmf_sample_path} already exists. Skipping...")
                continue

            # Fold the sequence; returns the PDB text and its score.
            output, plddt_value = run_folding(sequence, esmf_sample_path, target_gpu_index)
            print(plddt_value)

            # Flush each row immediately: the PDB file already exists at this
            # point, so a row lost in the buffer on interrupt would never be
            # regenerated (the sample is skipped on the next run).
            csv_writer.writerow([sample, plddt_value])
            csv_file.flush()

            # Report progress as a share of all samples, including skipped ones.
            elapsed_time = time.time() - start_time
            progress_percentage = (idx / total_samples) * 100
            print(f"Progress: {progress_percentage:.2f}% | Elapsed Time: {elapsed_time:.2f}s | ESM results saved for {sample} in {esmf_sample_path}")

    print(f"ESM calculations for all samples completed and saved to CSV at {plddt_csv_path}.")


In [4]:

def read_predicted_sequences(fasta_file_path):
    """Collect the predicted sequences from a FASTA file.

    Only records whose description contains "Predicted sequence" are kept.
    Each is keyed by the first whitespace-delimited token of its description,
    truncated at the first '.'. If two records map to the same key, the later
    record wins.

    Args:
        fasta_file_path: Path to the FASTA file to parse.

    Returns:
        dict mapping the parsed ID to the sequence as a plain string.
    """
    return {
        rec.description.split()[0].split('.')[0]: str(rec.seq)
        for rec in SeqIO.parse(fasta_file_path, "fasta")
        if "Predicted sequence" in rec.description
    }


# CATH4.2 PiFold

In [5]:
# Load the sequences to fold from the CATH4.2 PiFold FASTA file: only records
# whose header contains "Predicted sequence" are kept, keyed by the ID parsed
# from the header.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
fasta_file_path = "/home/zhengsun/code/protein/ProteinInvBench/results/fasta_files/CATH4.2_PiFold.fasta"
sample_sequences = read_predicted_sequences(fasta_file_path)

In [7]:
# Fold every loaded sequence. PDB files and plddt_values.csv are written under
# <decoy_pdb_dir>/cath42pifold; samples whose PDB file already exists are skipped.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
decoy_pdb_dir = "/home/zhengsun/code/protein/ProteinInvBench/results/esm_fold_pdb"
run_esm_all(sample_sequences, decoy_pdb_dir, 'cath42pifold')

File /home/zhengsun/code/protein/ProteinInvBench/results/esm_fold_pdb/cath42pifold/1a1x.pdb already exists. Skipping...
File /home/zhengsun/code/protein/ProteinInvBench/results/esm_fold_pdb/cath42pifold/1a2p.pdb already exists. Skipping...
File /home/zhengsun/code/protein/ProteinInvBench/results/esm_fold_pdb/cath42pifold/1a32.pdb already exists. Skipping...
62.47464489795919
Progress: 0.37% | Elapsed Time: 28.23s | ESM results saved for 1a73 in /home/zhengsun/code/protein/ProteinInvBench/results/esm_fold_pdb/cath42pifold/1a73.pdb
88.64814387699066
Progress: 0.46% | Elapsed Time: 56.94s | ESM results saved for 1a8l in /home/zhengsun/code/protein/ProteinInvBench/results/esm_fold_pdb/cath42pifold/1a8l.pdb


KeyboardInterrupt: 

# CATH4.2 ProteinMPNN

In [None]:
# Load the sequences to fold from the CATH4.2 ProteinMPNN FASTA file. This
# rebinds `sample_sequences`, replacing the set loaded in the PiFold section.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
fasta_file_path = "/home/zhengsun/code/protein/ProteinInvBench/results/fasta_files/CATH4.2_ProteinMPNN.fasta"
sample_sequences = read_predicted_sequences(fasta_file_path)

In [None]:
# Fold every loaded sequence. PDB files and plddt_values.csv are written under
# <decoy_pdb_dir>/cath42mpnn; samples whose PDB file already exists are skipped.
# NOTE(review): absolute, machine-specific path -- adjust before running elsewhere.
decoy_pdb_dir = "/home/zhengsun/code/protein/ProteinInvBench/results/esm_fold_pdb"
run_esm_all(sample_sequences, decoy_pdb_dir, 'cath42mpnn')