In [1]:
import argparse
import json, time, os
import numpy as np
import torch
import copy
import os.path
import subprocess
from vanilla_proteinmpnn.protein_mpnn_utils import _scores, _S_to_seq, tied_featurize, parse_PDB
from vanilla_proteinmpnn.protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN

  from .autonotebook import tqdm as notebook_tqdm


In [33]:
args = argparse.Namespace()
args.path_to_model_weights = 'run/model_weights/'
args.model_name = 'v_48_020'
args.save_score = 0
args.save_probs = 0
args.score_only = 1
args.conditional_probs_only = 0
args.conditional_probs_only_backbone = 0
args.unconditional_probs_only = 0
args.backbone_noise = 0.00
args.num_seq_per_target = 20
args.batch_size = 20
args.max_length = 20000
args.sampling_temp = '0.1'
args.out_folder = 'sandbox_outs/novozyme/'
args.pdb_path = 'sandbox_outs/novozyme/wildtype_structure_prediction_af2.pdb'
args.pdb_path_chains = ''
args.jsonl_path = ''
args.chain_id_jsonl = ''
args.fixed_positions_jsonl = ''
args.omit_AAs = 'X'
args.bias_AA_jsonl = ''
args.bias_by_res_jsonl = ''
args.omit_AA_jsonl = ''
args.pssm_jsonl = ''
args.pssm_multi = 0.0
args.pssm_threshold = 0.0
args.pssm_log_odds_flag = 0
args.pssm_bias_flag = 0
args.tied_positions_jsonl = ''

In [3]:
hidden_dim = 128
num_layers = 3

In [4]:
if args.path_to_model_weights:
    model_folder_path = args.path_to_model_weights
    if model_folder_path[-1] != '/':
        model_folder_path = model_folder_path + '/'
else: 
    file_path = os.path.realpath(__file__)
    k = file_path.rfind("/")
    model_folder_path = file_path[:k] + '/vanilla_model_weights/'

In [5]:
checkpoint_path = model_folder_path + f'{args.model_name}.pt'
folder_for_outputs = args.out_folder

In [34]:
NUM_BATCHES = args.num_seq_per_target // args.batch_size
BATCH_COPIES = args.batch_size
temperatures = [float(temp) for temp in args.sampling_temp.split()]
omit_AAs_list = args.omit_AAs
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'

In [7]:
omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [8]:
test_set_file = 'sandbox_outs/novozyme/test.csv'
with open(test_set_file, 'r') as f:
    test_set_list = f.readlines()
test_set = [mut.split(',')[:2] for mut in test_set_list[1:]]
test_set_seqs = [mut[1] for mut in test_set]

wt_seq = 'VPVNPEPDATSVENVALKTGSGDSQSDPIKADLEVKGQSALPFDVDCWAILCKGAPNVLQRVNEKTKNSNRDRSGANKGPF' \
         'KDPQKWGIKALPPKNPSWSAQDFKSPEEYAFASSLQGGTNAILAPVNLASQNSQGGVLNGFYSANKVAQFDPSKPQQTKGT' \
         'WFQITKFTGAAGPYCKALGSNDKSVCDKNKNIAGDWGFDPAKWAYQYDEKNNKFNYVGK'

for i, seq in enumerate(test_set_seqs):
    if len(seq) < len(wt_seq):
        k = next((idx for idx, res in enumerate(seq) if res != wt_seq[idx]), None)
        test_set_seqs[i] = seq[:k] + '-' + seq[k:]

In [9]:
if os.path.isfile(args.chain_id_jsonl):
    with open(args.chain_id_jsonl, 'r') as json_file:
        json_list = list(json_file)
    for json_str in json_list:
        chain_id_dict = json.loads(json_str)
else:
    chain_id_dict = None
    print(40 * '-')
    print('chain_id_jsonl is NOT loaded')

----------------------------------------
chain_id_jsonl is NOT loaded


In [10]:
if os.path.isfile(args.fixed_positions_jsonl):
    with open(args.fixed_positions_jsonl, 'r') as json_file:
        json_list = list(json_list)
    for json_str in json_list:
        fixed_positions_dict = json.loads(json_str)
else:
    print(40 * '-')
    print('fixed_positions_jsonl is NOT loaded')
    fixed_positions_dict = None

----------------------------------------
fixed_positions_jsonl is NOT loaded


In [11]:
if os.path.isfile(args.pssm_jsonl):
    with open(args.pssm_jsonl, 'r') as json_file:
        json_list = list(json_file)
    pssm_dict = {}
    for json_str in json_list:
        pssm_dict.update(json.loads(json_str))
else:
    print(40 * '-')
    print('pssm_jsonl is NOT loaded')
    pssm_dict = None

----------------------------------------
pssm_jsonl is NOT loaded


In [12]:
if os.path.isfile(args.omit_AA_jsonl):
    with open(args.omit_AA_jsonl, 'r') as json_file:
        json_list = list(json_file)
    for json_str in json_list:
        omit_AA_dict = json.loads(json_str)
else:
    print(40 * '-')
    print('omit_AA_jsonl is NOT loaded')
    omit_AA_dict = None

----------------------------------------
omit_AA_jsonl is NOT loaded


In [13]:
if os.path.isfile(args.bias_AA_jsonl):
    with open(args.bias_AA_jsonl, 'r') as json_file:
        json_list = list(json_file)
    for json_str in json_list:
        bias_AA_dict = json.loads(json_str)
else:
    print(40 * '-')
    print('bias_AA_jsonl is NOT loaded')
    bias_AA_dict = None

----------------------------------------
bias_AA_jsonl is NOT loaded


In [14]:
if os.path.isfile(args.tied_positions_jsonl):
    with open(args.tied_positions_jsonl, 'r') as json_file:
        json_list = list(json_file)
    for json_str in json_list:
        tied_positions_dict = json.loads(json_str)
else:
    print(40 * '-')
    print('tied_positions_jsonl is NOT loaded')
    tied_positions_dict = None

----------------------------------------
tied_positions_jsonl is NOT loaded


In [15]:
if os.path.isfile(args.bias_by_res_jsonl):
    with open(args.bias_by_res_jsonl, 'r') as json_file:
        json_list = list(json_file)
    for json_str in json_list:
        bias_by_res_dict = json.loads(json_str)
else:
    print(40 * '-')
    print('bias_by_res_jsonl is NOT loaded')
    bias_by_res_dict = None

----------------------------------------
bias_by_res_jsonl is NOT loaded


In [16]:
print(40 * '-')
bias_AAs_np = np.zeros(len(alphabet))
if bias_AA_dict:
    for n, AA in enumerate(alphabet):
        if AA in list(bias_AA_dict.keys()):
            bias_AAs_np[n] = bias_AA_dict[AA]

----------------------------------------


In [17]:
if args.pdb_path:
    pdb_dict_list = parse_PDB(args.pdb_path)
    dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=args.max_length)
    all_chain_list = [item[-1:] for item in list(pdb_dict_list[0]) if item[:9] == 'seq_chain']
    if args.pdb_path_chains:
        designed_chain_list = [str(item) for item in args.pdb_path_chains.split()]
    else:
        designed_chain_list = all_chain_list
    fixed_chain_list = [letter for letter in all_chain_list if letter not in designed_chain_list]
    chain_id_dict = {}
    chain_id_dict[pdb_dict_list[0]['name']] = (designed_chain_list, fixed_chain_list)
else:
    dataset_valid = StructureDataset(args.jsonl_path, truncate=None, max_length=args.max_length)

In [18]:
print(40 * '-')
checkpoint = torch.load(checkpoint_path, map_location=device)
print(f"Number of edges: {checkpoint['num_edges']}")
noise_level_print = checkpoint['noise_level']
print(f"Training noise_level: {noise_level_print}A")
model = ProteinMPNN(num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=args.backbone_noise, k_neighbors=checkpoint['num_edges'])
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

----------------------------------------
Number of edges: 48
Training noise_level: 0.2A


ProteinMPNN(
  (features): ProteinFeatures(
    (embeddings): PositionalEncodings(
      (linear): Linear(in_features=66, out_features=16, bias=True)
    )
    (edge_embedding): Linear(in_features=416, out_features=128, bias=False)
    (norm_edges): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (W_e): Linear(in_features=128, out_features=128, bias=True)
  (W_s): Embedding(21, 128)
  (encoder_layers): ModuleList(
    (0): EncLayer(
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
      (dropout3): Dropout(p=0.1, inplace=False)
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (W1): Linear(in_features=384, out_features=128, bias=True)
      (W2): Linear(in_features=128, out_features=128, bias=True)
      (W3): Linear(in_features=128, out_features=128, bias=True)
     

In [19]:
base_folder = folder_for_outputs
if base_folder[-1] != '/':
    base_folder = base_folder + '/'

if not os.path.exists(base_folder):
    os.makedirs(base_folder)

if not os.path.exists(base_folder + 'seqs'):
    os.makedirs(base_folder + 'seqs')

if args.save_score:
    if not os.path.exists(base_folder + 'scores'):
        os.makedirs(base_folder + 'scores')

if args.score_only:
    if not os.path.exists(base_folder + 'score_only'):
        os.makedirs(base_folder + 'score_only')

if args.conditional_probs_only:
    if not os.path.exists(base_folder + 'conditional_probs_only'):
        os.makedirs(base_folder + 'conditional_probs_only')

if args.unconditional_probs_only:
    if not os.path.exists(base_folder + 'unconditional_probs_only'):
        os.makedirs(base_folder + 'unconditional_probs_only')

if args.save_probs:
    if not os.path.exists(base_folder + 'probs'):
        os.makedirs(base_folder + 'probs')

In [47]:
# Timing
start_time = time.time()
total_residues = 0
protein_list = []
total_step = 0

# Validation epoch
with torch.no_grad():

    test_sum, test_weights = 0., 0.
    print('Generating sequences...')
    for ix, protein in enumerate(dataset_valid):

        for i in range(len(test_set_seqs[:1])):
            # Update the sequence with the mutated sequence
            protein['seq_chain_A'] = test_set_seqs[i]
            protein['seq'] = test_set_seqs[i]

            # Form batch of clones
            batch_clones = [copy.deepcopy(protein) for i in range(BATCH_COPIES)]
            
            # Featurize
            X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict)
            pssm_log_odds_mask = (pssm_log_odds_all > args.pssm_threshold).float() #1.0 for true, 0.0 for false
            name_ = batch_clones[0]['name']
            for j in range(NUM_BATCHES):
                randn_1 = torch.randn(chain_M.shape, device=X.device)
                log_probs = model(X, S, mask, chain_M * chain_M_pos, residue_idx, chain_encoding_all, randn_1)
                mask_for_loss = mask * chain_M * chain_M_pos
                score = torch.mean(_scores(S, log_probs, mask_for_loss))




Generating sequences...
