In [1]:
import argparse
import random
from tqdm import tqdm
from datetime import datetime, timedelta
from collections import OrderedDict

import pandas as pd

import torch
import numpy as np
import pandas as pd

from configs import *
from inference import *
from sampling import *

from ofold.np import residue_constants

from flowmatch import flowmatcher

from model import main_network
from flowmatch.data import utils as du
from evaluation.metrics import *
from evaluation.loss import *
from data.utils import *
from data.loader import *
from data.data import *

from Bio.PDB import PDBParser, PDBIO

INFO: Using numpy backend
INFO: Enabling RDKit 2022.09.5 jupyter extensions


In [2]:
# Args
#
# Build the run configuration, then derive vocabulary sizes and mask-token
# indices from the selected discrete flow type.

args = Args()
args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Only the 'uniform' and 'masking' discrete flows are handled; reject anything else.
if args.discrete_flow_type not in ('uniform', 'masking'):
    raise ValueError(f'Unknown discrete flow type {args.discrete_flow_type}')

# With 'masking' every vocabulary gains one extra class and the mask index is
# that last class; with 'uniform' there is no mask token (index is None).
_masking = args.discrete_flow_type == 'masking'

# amino acids
args.num_aa_type = 21 if _masking else 20
args.masked_aa_token_idx = 20 if _masking else None
if _masking:
    args.aa_ot = False

# MSA tokens (only configured when the MSA flow is enabled)
if args.flow_msa:
    args.msa.num_msa_vocab = 65 if _masking else 64
    args.msa.masked_msa_token_idx = 64 if _masking else None
    if _masking:
        args.msa_ot = False

# EC classes (only configured when the EC flow is enabled)
if args.flow_ec:
    args.ec.num_ec_class = 7 if _masking else 6
    args.ec.masked_ec_token_idx = 6 if _masking else None

# Number of residues to design.
args.n_res_design = 32

In [3]:
# Loading Model
#
# Build the flow matcher and the network from `args`, then restore the
# pretrained weights from the checkpoint file.

flow_model = flowmatcher.SE3FlowMatcher(args)
model = main_network.ProteinLigandNetwork(args)
model = model.to(args.device)

ckpt_path = 'checkpoint/enzymeflow_mini.ckpt'
if ckpt_path:
    print(f'loading pretrained weights for enzymeflow {ckpt_path}')
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    model_state_dict = checkpoint["model_state_dict"]

    # Checkpoints saved from a DataParallel/DistributedDataParallel wrapper prefix
    # every key with `module.`. Strip it only when *all* keys carry the prefix, so
    # a plain checkpoint (like the one loaded here) passes through unchanged.
    _wrapper_prefix = 'module.'
    _strip_prefix = bool(model_state_dict) and all(
        k.startswith(_wrapper_prefix) for k in model_state_dict
    )

    new_state_dict = OrderedDict()
    for k, v in model_state_dict.items():
        name = k[len(_wrapper_prefix):] if _strip_prefix else k  # remove `module.`
        new_state_dict[name] = v

    # strict=True: fail loudly on any missing or unexpected parameter name.
    model.load_state_dict(new_state_dict, strict=True)

loading pretrained weights for enzymeflow checkpoint/enzymeflow_mini.ckpt


  checkpoint = torch.load(ckpt_path, map_location='cpu')


In [4]:
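# Not needed: optional PiFold inverse-folding model, kept commented out for reference.
# The "loading pretrained weights for pifold" output shown below this cell is stale
# (it comes from an earlier run in which these lines were active).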

# from pifold import folding_network
# inversefold_model = folding_network.ProDesign_Model(args.inverse_folding)

# inversefold_ckpt_path = 'pifold/weights/enzymefold.ckpt'
# if inversefold_ckpt_path:
#     print(f'loading pretrained weights for pifold {inversefold_ckpt_path}')
#     checkpoint = torch.load(inversefold_ckpt_path, map_location='cpu')
#     inversefold_model.load_state_dict(checkpoint, strict=False)

loading pretrained weights for pifold pifold/weights/enzymefold.ckpt


  checkpoint = torch.load(inversefold_ckpt_path, map_location='cpu')


In [4]:
import copy

def process_protein(args, chain_feats):
    """Add the sampling-time features to a chain feature dict and drop unused ones.

    The dict is modified in place and also returned.

    Args:
        args: configuration object; `n_res_design` and `num_aa_type` are read.
        chain_feats: feature dict that must contain "rigidgroups_1" plus the
            atom-level entries removed at the end of this function.

    Returns:
        The same `chain_feats` object with "res_mask", "flow_mask", "rigids_1",
        "sc_ca_t" and "sc_aa_t" set.

    Raises:
        KeyError: if "rigidgroups_1" or any of the entries to remove is missing.
    """
    # Index 0 along the second axis is kept as the frame of each residue
    # (presumably the backbone frame, going by the variable name — TODO confirm).
    gt_bb_rigid = ru.Rigid.from_tensor_4x4(chain_feats["rigidgroups_1"])[:, 0]

    # All designed positions are marked valid and flowed; one all-ones array
    # backs both masks.
    flow_mask = np.ones(args.n_res_design)
    chain_feats["res_mask"] = flow_mask
    chain_feats["flow_mask"] = flow_mask
    chain_feats["rigids_1"] = gt_bb_rigid.to_tensor_7()

    # Zero-initialised "sc_*" inputs (presumably self-conditioning — TODO confirm).
    chain_feats["sc_ca_t"] = torch.zeros(args.n_res_design, 3)
    chain_feats["sc_aa_t"] = torch.zeros(args.n_res_design, args.num_aa_type)

    # remove unused features
    for unused_key in (
        "residx_atom14_to_atom37",
        "atom37_pos",
        "atom37_mask",
        "atom14_pos",
        "atom37_pos_before_com",
        "torsion_angles_sin_cos",
    ):
        del chain_feats[unused_key]
    return chain_feats

def process_ligand(ligand_feats, guiding_mol):
    """Pack ligand and guiding-molecule features into a dict of tensors.

    Args:
        ligand_feats: mapping with "ligand_feat" (atom features) and
            "ligand_pos_after_com" (atom coordinates, used as given).
        guiding_mol: mapping with "molecule_atom_feat", "molecule_edge_feat"
            and "molecule_edge_idx".

    Returns:
        A new dict with the ligand atoms/positions/mask and the guiding
        molecule's atoms, edges, edge index and all-ones masks. Feature and
        index tensors are int64; positions are float64.
    """
    def _as_long(values):
        # Integer tensor built from array-like input.
        return torch.tensor(values).long()

    # ligand: atom features and the coordinates stored under "ligand_pos_after_com"
    atom_types = _as_long(ligand_feats["ligand_feat"])
    atom_coords = torch.tensor(ligand_feats["ligand_pos_after_com"]).double()

    # guiding molecule graph
    guide_atoms = _as_long(guiding_mol["molecule_atom_feat"])
    guide_edges = _as_long(guiding_mol["molecule_edge_feat"])
    guide_edge_index = _as_long(guiding_mol["molecule_edge_idx"])

    # Masks are all ones: every atom / edge is treated as present.
    return {
        "ligand_atom": atom_types,
        "ligand_pos": atom_coords,
        "ligand_mask": torch.ones_like(atom_types),
        "guide_ligand_atom": guide_atoms,
        "guide_ligand_edge": guide_edges,
        "guide_ligand_edge_index": guide_edge_index,
        "guide_ligand_atom_mask": torch.ones_like(guide_atoms),
        "guide_ligand_edge_mask": torch.ones_like(guide_edges),
    }

def gen_data(args, gen_model, protein, ligand):
    """Assemble the t=0 input features for one sampling run.

    Combines the processed protein features, a reference sample drawn from
    `gen_model` at t=0, and the ligand features into a single dict of tensors.

    The caller's `protein` dict is NOT modified: the function works on a
    shallow copy, so it can be called repeatedly (e.g. once per sample in a
    loop) without state accumulating between calls.

    Args:
        args: configuration; `n_res_design` and `masked_aa_token_idx` are read.
        gen_model: object providing `sample_ref(...)` and `forward_masking(...)`.
        protein: feature dict containing at least "rigids_1" and "aatype".
        ligand: dict of ligand features merged into the result last, so its
            entries win on key collisions.

    Returns:
        A new dict in which every protein entry is a tensor, with the
        atom-level coordinate entries left out, plus all entries of `ligand`.
    """
    # Shallow copy: the entries added below must not leak into the caller's dict.
    protein = dict(protein)

    gt_bb_rigid = ru.Rigid.from_tensor_7(protein["rigids_1"])
    gt_trans, gt_rot = extract_trans_rots_mat(gt_bb_rigid)
    protein["trans_1"] = gt_trans
    protein["rot_1"] = gt_rot

    # When the design length differs from the template length, use a fresh
    # 1-based index for the designed residues.
    if args.n_res_design != protein["aatype"].size(0):
        protein["seq_idx"] = torch.arange(args.n_res_design) + 1
        protein["residue_idx"] = torch.arange(args.n_res_design) + 1

    # Sampling starts at t = 0.
    t = 0.
    gen_feats_t = gen_model.sample_ref(
        n_samples=args.n_res_design,
        flow_mask=None,
        as_tensor_7=True,
        center_of_mass=None,
    )

    # Initial residue types at t = 0; what `forward_masking` does with the
    # random `feat_0` is defined by gen_model (not visible here).
    aatype_0 = torch.rand(args.n_res_design)
    aatype_t = gen_model.forward_masking(
        feat_0=aatype_0,
        feat_1=None,
        t=t,
        mask_token_idx=args.masked_aa_token_idx,
        flow_mask=None,
    )

    protein["aatype_t"] = aatype_t
    protein.update(gen_feats_t)
    protein["t"] = t

    # Atom-level entries are not passed on; skip them before any tensor
    # conversion so no work is spent on values that are thrown away.
    dropped_keys = {"residx_atom14_to_atom37", "atom37_pos", "atom14_pos", "atom37_mask"}
    final_feats = {
        k: (v if torch.is_tensor(v) else torch.tensor(v))
        for k, v in protein.items()
        if k not in dropped_keys
    }

    final_feats.update(ligand)
    return final_feats


def write_prot_to_pdb(
    prot_pos: np.ndarray,
    file_path: str,
    aatype: np.ndarray = None,
    overwrite=False,
    no_indexing=False,
    b_factors=None,
    residue_index=None
):
    """Write one structure, or a trajectory of structures, to a PDB file.

    Args:
        prot_pos: atom positions. A 3-D array is written as a single model; a
            4-D array is treated as a sequence of structures along the first
            axis and written as models 1..T.
        file_path: target path, normally ending in ".pdb".
        aatype: residue types, forwarded to `create_full_prot`.
        overwrite: if True the index suffix is always `_1`; otherwise the
            target directory is scanned and the next free `_<N>` is used.
        no_indexing: if True, write to `file_path` exactly, with no suffix.
        b_factors: forwarded to `create_full_prot`.
        residue_index: forwarded to `create_full_prot`.

    Returns:
        The path that was written.

    Raises:
        ValueError: if `prot_pos` is neither 3-D nor 4-D. The check runs
            before the file is opened, so no empty file is left behind.
    """
    pdb_ext = ".pdb"
    # Remove only a trailing ".pdb". `str.strip(".pdb")` would strip any of the
    # characters '.', 'p', 'd', 'b' from both ends and mangle names like "bd.pdb".
    if file_path.endswith(pdb_ext):
        stem_path = file_path[: -len(pdb_ext)]
    else:
        stem_path = file_path

    if no_indexing:
        # The index is never used, so the directory is not scanned.
        save_path = file_path
    else:
        if overwrite:
            max_existing_idx = 0
        else:
            # A bare file name has an empty dirname; scan the current directory.
            file_dir = os.path.dirname(file_path) or "."
            file_name = os.path.basename(stem_path)
            index_pattern = re.compile(r"_(\d+)\.pdb")
            found_per_file = (
                index_pattern.findall(x) for x in os.listdir(file_dir) if file_name in x
            )
            max_existing_idx = max([int(found[0]) for found in found_per_file if found] + [0])
        save_path = stem_path + f"_{max_existing_idx+1}.pdb"

    if prot_pos.ndim == 4:
        frames = list(prot_pos)
    elif prot_pos.ndim == 3:
        frames = [prot_pos]
    else:
        raise ValueError(f"Invalid positions shape {prot_pos.shape}")

    with open(save_path, "w") as f:
        for model_idx, frame_pos in enumerate(frames, start=1):
            # An atom counts as present when its coordinates are not all (near) zero.
            atom_mask = np.sum(np.abs(frame_pos), axis=-1) > 1e-7
            prot = create_full_prot(
                frame_pos, atom_mask, aatype=aatype, b_factors=b_factors, residue_index=residue_index
            )
            # Both branches now use the same `to_pdb` helper. The trajectory branch
            # used to call `protein.to_pdb`, but in this notebook `protein` is
            # bound to a feature dict and has no such attribute.
            f.write(to_pdb(prot, model=model_idx, add_end=False))
        f.write("END")
    return save_path


def write_pdb_traj(args, feats_0, feats_1, parent_dir, pdb_name, substrate_name, sample_id=0):
    """Write the final sampled structure of every batch element to a PDB file.

    Files go to `<parent_dir>/<pdb_name>/`. With a batch of one the file is
    `sample_<sample_id>.pdb`. With a larger batch each element gets
    `sample_<sample_id>_<batch index>.pdb`, so elements no longer overwrite
    one another.

    Args:
        args: unused; kept for interface compatibility.
        feats_0: input features; "res_mask", "flow_mask" and "residue_index"
            are read (tensors with a leading batch dimension).
        feats_1: sampling output; element 0 of "coord_traj" and of "aa_traj"
            is taken as the final coordinates / residue types.
        parent_dir: root output directory.
        pdb_name: name of the sub-directory for this protein.
        substrate_name: unused; kept for interface compatibility.
        sample_id: index used in the output file name.

    Returns:
        List of the paths written, one per batch element.
    """
    final_pos = feats_1["coord_traj"][0]
    final_aatype = feats_1["aa_traj"][0]

    res_mask = du.move_to_np(feats_0["res_mask"].bool())
    flow_mask = du.move_to_np(feats_0["flow_mask"].bool())
    res_index = du.move_to_np(feats_0["residue_index"])
    batch_size = res_mask.shape[0]

    # The output directory is the same for every batch element; create it once.
    prot_dir = os.path.join(parent_dir, pdb_name)
    os.makedirs(prot_dir, exist_ok=True)

    saved_paths = []
    for i in range(batch_size):
        # Keep only the valid (unpadded) residues of this element.
        valid = res_mask[i]
        pred_aatype = final_aatype[i][valid]
        pred_protein_pos = final_pos[i][valid]
        unpad_flow_mask = flow_mask[i][valid]

        if batch_size == 1:
            file_name = f"sample_{sample_id}.pdb"
        else:
            file_name = f"sample_{sample_id}_{i}.pdb"
        prot_path = os.path.join(prot_dir, file_name)

        saved_path = write_prot_to_pdb(
            prot_pos=pred_protein_pos,
            file_path=prot_path,
            aatype=pred_aatype,
            no_indexing=True,
            # B-factor is 100 for flowed residues and 0 otherwise, repeated over 37 atom slots.
            b_factors=np.tile(unpad_flow_mask[..., None], 37) * 100,
            # NOTE(review): residue_index is passed without applying res_mask,
            # unlike positions and residue types — confirm this is intended.
            residue_index=res_index[i],
        )
        saved_paths.append(saved_path)
    return saved_paths


# N_idx = residue_constants.atom_order["N"]
# Ca_idx = residue_constants.atom_order["CA"]
# C_idx = residue_constants.atom_order["C"]
# O_idx = residue_constants.atom_order["O"]

# def frames_to_inversefold(args, frames):
#     _, atom_mask, _, atom_pos = all_atom.to_atom37(frames)
#     batch_size = atom_pos.shape[0]

#     atom_pos = atom_pos[:, :, :4]
#     atom_mask = atom_mask[:, :, 0]
#     score = torch.zeros([batch_size, args.n_res_design]).to(args.device) + 100.0
#     atom_mask = atom_mask.to(dtype=torch.float32)
#     return atom_pos, score, atom_mask
    

# def inverse_fold(args, inversefold_model, feats, temperature=0.1):
#     batch_size, n_res = feats['aa_traj'][0].shape
#     rigids = torch.tensor(feats['rigid_traj'][0]).to(args.device)
#     frames = ru.Rigid.from_tensor_7(rigids)
#     pos, score, mask = frames_to_inversefold(args, frames)
#     X, score, h_V, h_E, E_idx, batch_id, mask_bw, mask_fw, decoding_order = inversefold_model._get_features(score, X=pos, mask=mask)
#     log_probs, logits = inversefold_model(h_V, h_E, E_idx, batch_id, return_logit = True)

#     probs = F.softmax(logits / temperature, dim=-1)
#     pred = torch.multinomial(probs, 1).view(batch_size, n_res)
#     return pred

In [5]:
# Inputs for this run: the pocket structure, the ligand (substrate) structure
# file and the SMILES string of the product used as the guiding molecule.
pdb_name = 'X8HS19'
mol_name = '8025'
ligand_path = f'data/molecule_structures/{mol_name}.mol2'
# NOTE: pocket_path and protein_path point to the same file, so it is parsed twice below.
pocket_path = f'data/pocket_fixed_residues/pdb_10A/{pdb_name}.pdb'
protein_path = f'data/pocket_fixed_residues/pdb_10A/{pdb_name}.pdb'
product_smiles = 'OC(=O)CC(C(=O)O)NC(=O)c1ncn(c1N)[C@@H]1O[C@@H]([C@H]([C@H]1O)O)COP(=O)(O)O'

# Parse the raw inputs into feature containers (helpers come from the project imports).
ligand = process_mol(ligand_path)
pocket = process_pdb(pocket_path)
protein = process_pdb(protein_path)
product = process_smiles(product_smiles)

# Shallow copies: process_protein adds/removes top-level entries on chain_feats,
# which must not touch `pocket` itself (nested values are still shared).
chain_feats, ligand_feats, guiding_mol = copy.copy(pocket), copy.copy(ligand), copy.copy(product)

# NOTE: this rebinds `protein` and `ligand`; the `process_pdb(protein_path)` result
# above is discarded, and from here on `protein` is the processed pocket feature dict.
protein, ligand = process_protein(args, chain_feats), process_ligand(ligand_feats, guiding_mol)

In [6]:
# Run EnzymeFlow Sampling

# Sampling-time overrides.
# NOTE(review): `model` and `flow_model` were built from `args` in an earlier cell,
# before these flags were changed — confirm they are only read at sampling time.
args.flow_msa = False
args.flow_ec = False
args.eval.aa_temp = 10.
args.eval.msa_temp = 10.
args.eval.ec_temp = 10.
args.eval.aa_noise = 20.
args.eval.msa_noise = 40.
args.eval.ec_noise = 0.

# Number of independent samples; each is written to <parent_dir>/<pdb_name>/sample_<idx>.pdb.
n_sample = 10
parent_dir = os.path.join('.')
os.makedirs(parent_dir, exist_ok=True)

for sample_idx in tqdm(range(0, n_sample)):
    # A fresh seed per sample for numpy and torch. The seed itself is drawn from
    # Python's `random`, which is not seeded in the visible cells, and it is not
    # recorded — so a run cannot be reproduced exactly as written.
    seed = random.randint(0, 10000)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    # Build the t=0 features, then add a leading batch dimension of size 1 to every tensor.
    feats_0 = gen_data(args, flow_model, protein, ligand)
    feats_0 = {key: value.unsqueeze(0) for key, value in feats_0.items()}

    # Run the sampler from t=0 to t=1 in 10 steps.
    feats_1 = sampling_inference(
                args,
                init_feats = feats_0,
                gen_model = flow_model,
                main_network = model,
                min_t = 0.0,
                max_t = 1.0,
                num_t = 10,
                self_condition = True,
                center = True,
                aa_do_purity = True,
                msa_do_purity = False,
                ec_do_purity = False,
                rot_sample_schedule = 'exp',
                trans_sample_schedule = 'linear',
            )

    # Optional inverse-folding step, disabled together with the pifold model.
    # aatype = inverse_fold(args, inversefold_model, feats_1)
    # feats_1['aa_traj'][0] = aatype

    # Save the final structure of this sample as a PDB file.
    write_pdb_traj(
        args,
        feats_0=feats_0, 
        feats_1=feats_1, 
        parent_dir=parent_dir, 
        pdb_name=pdb_name, 
        substrate_name=None, 
        sample_id=sample_idx,
    )    

100%|█████████████████████████████████████████████| 10/10 [00:14<00:00,  1.48s/it]


In [7]:
import py3Dmol

# Display the first generated sample (sample_0.pdb) as a rainbow-coloured cartoon.
generated_prot = pdb_name
pred_prot_path = os.path.join(parent_dir, generated_prot, "sample_0.pdb")
with open(pred_prot_path) as ifile:
    # Read the whole PDB text in one go.
    pred_system = ifile.read()

view = py3Dmol.view(width=300, height=300)
view.addModelsAsFrames(pred_system)
view.setStyle({'model': 0}, {"cartoon": {'color': 'spectrum'}})
view.zoomTo()
view.show()

In [8]:
# Display the input pocket structure with the same viewer settings, for comparison.
with open(pocket_path) as ifile:
    # Read the whole PDB text in one go.
    gt_system = ifile.read()

view = py3Dmol.view(width=300, height=300)
view.addModelsAsFrames(gt_system)
view.setStyle({'model': 0}, {"cartoon": {'color': 'spectrum'}})
view.zoomTo()
view.show()