# MultiStateMPNN

In [1]:
### Design settings
design_name = 'NeoCas_Trial3'
pdb_path = '/work/lpdi/users/mpacesa/Projects/NeoCas/MultiStateMPNN/InputStates/'
folder_for_outputs = '/work/lpdi/users/ymiao/ProteinMPNN/test_myy/output3'
pdb_path_chains = 'A' #which chains to design, if none specified then all will be designed, space separated "A B"
num_of_sequences = 50 #number of sequences to design
sampling_temp = '1.0' #sampling temperature for amino acids, T=0.0 means taking argmax, T>>1.0 means sampling randomly.")
backbone_noise = 0.00 #backbone noise during sampling, 0.00-0.02 are good values
omit_AAs = 'XC' #which amino acids to avoid, X is unkonwn, example 'XC' avoids unknown amino acids and cysteine
positions_to_fix = '10 66 70 74 75 77 78 115 116 165 328 329 364 403 407 459 762 839 840 863 866 983 986 1122 1333 1335 1349' #which positions should remain fixed in the chains to be designed, separte positions by space and chains by comma (1 2 3, 21 24 25, ...), or enter "none" (PDB gets renumbered, starts from 1 always)
invert_sel = None # If set to True, it will only design the positions selected above, basically inverting your selection

### Pipeline settings
# model settings
main_path = '/work/lpdi/users/mpacesa/Pipelines/MultiStateMPNN' # path to your proteinMPNN installation
model_name = 'v_48_020'
solublempnn = False
ca_only = False # Parse CA-only structures and use CA-only models
num_seq_per_target = 1 # keep at 1 for multi-state
batch_size = 1 # keep at 1 for multi-state
seed_set = None # hard set the seed, set to number for testing purposes
pssm_threshold = 0.00
pssm_multi = 0.00

# log settings
save_score = True
save_probs = True
score_only = False
conditional_probs_only = False
unconditional_probs_only = False
pssm_log_odds_flag = False
pssm_bias_flag = False

# custom JSON paths
fixed_positions_jsonl = ''
pssm_jsonl = ''
omit_AA_jsonl = ''
bias_AA_jsonl = ''
tied_positions_jsonl = ''
bias_by_res_jsonl = ''

### Print time    
from datetime import datetime
time_date = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
print("Settings updated: "+time_date)

Settings updated: 01/05/2023 14:00:15


# Initialise functions

In [2]:
### Import dependencies
import json, time, os, sys, glob
import shutil
import warnings
import copy
import csv
import random
import itertools
import subprocess
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from IPython.display import Markdown, display
os.chdir(main_path)
%matplotlib inline

# torch
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split, Subset
import torch.nn as nn
import torch.nn.functional as F

# proteinmpnn scripts
from helper_scripts.parse_multiple_chains_multistate import ms_parse_chains
from helper_scripts.assign_fixed_chains_multistate import ms_assign_chains
from protein_mpnn_utils import loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB
from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN

def printmd(string, color=None):
    colorstr = "<span style='color:{}'>{}</span>".format(color, string)
    display(Markdown(colorstr))

printmd("Dependencies loaded", color="green")

<span style='color:green'>Dependencies loaded</span>

In [3]:
### Build paths for experiment
base_folder = folder_for_outputs
if base_folder[-1] != '/':
    base_folder = base_folder + '/'
    
if not os.path.exists(base_folder):
    os.makedirs(base_folder)
    
if not os.path.exists(base_folder + 'seqs'):
    os.makedirs(base_folder + 'seqs')
    
if save_score:
    if not os.path.exists(base_folder + 'scores'):
        os.makedirs(base_folder + 'scores')

if score_only:
    if not os.path.exists(base_folder + 'score_only'):
        os.makedirs(base_folder + 'score_only')

if conditional_probs_only:
    if not os.path.exists(base_folder + 'conditional_probs_only'):
        os.makedirs(base_folder + 'conditional_probs_only')

if unconditional_probs_only:
    if not os.path.exists(base_folder + 'unconditional_probs_only'):
        os.makedirs(base_folder + 'unconditional_probs_only')

if save_probs:
    if not os.path.exists(base_folder + 'probs'):
        os.makedirs(base_folder + 'probs') 

printmd("Paths have been created", color="green")

<span style='color:green'>Paths have been created</span>

In [4]:
### Define accessory functions
def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

def sample_sequence(master_pssm):
    des_seq = []
    for i in master_pssm[:92]:
        sampled_aa = np.random.choice(np.arange(0, 21), p=i[0:21])
        #print(alphabet[sampled_aa])
        des_seq.append(alphabet[sampled_aa])
        sampled_seq = ''.join(des_seq)
        
    return sampled_seq

def max_sequence(master_pssm):
    l = []
    for res in np.argmax(master_pssm, axis=-1):
        l.append(alphabet[res])
    master_seq = ''.join(l)
    
    return master_seq

def make_fixed_positions_dict(input_path, output_path, chain_list, position_list, invert_sel=None):
    with open(input_path, 'r') as json_file:
        json_list = list(json_file)
    
    fixed_list = [[int(item) for item in one.split()] for one in position_list.split(",")]
    global_designed_chain_list = [str(item) for item in chain_list.split()]
    my_dict = {}
    
    if not invert_sel:
        for json_str in json_list:
            result = json.loads(json_str)
            all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain']
            fixed_position_dict = {}
            for i, chain in enumerate(global_designed_chain_list):
                fixed_position_dict[chain] = fixed_list[i]
            for chain in all_chain_list:
                if chain not in global_designed_chain_list:       
                    fixed_position_dict[chain] = []
            my_dict[result['name']] = fixed_position_dict
    else:
        for json_str in json_list:
            result = json.loads(json_str)
            all_chain_list = [item[-1:] for item in list(result) if item[:9]=='seq_chain']
            fixed_position_dict = {}   
            for chain in all_chain_list:
                seq_length = len(result[f'seq_chain_{chain}'])
                all_residue_list = (np.arange(seq_length)+1).tolist()
                if chain not in global_designed_chain_list:
                    fixed_position_dict[chain] = all_residue_list
                else:
                    idx = np.argwhere(np.array(global_designed_chain_list) == chain)[0][0]
                    fixed_position_dict[chain] = list(set(all_residue_list)-set(fixed_list[idx]))
            my_dict[result['name']] = fixed_position_dict

    with open(output_path, 'w') as f:
        f.write(json.dumps(my_dict) + '\n')


printmd("Accessory functions defined", color="green")

<span style='color:green'>Accessory functions defined</span>

In [5]:
### Set important paths and variables
# JSON paths
jsonl_path = folder_for_outputs + '/parsed_pdbs.jsonl'
chain_id_jsonl = folder_for_outputs + '/assigned_pdbs.jsonl'
fixed_positions_jsonl = folder_for_outputs + '/fixed_positions.jsonl'

# model paths
if ca_only:
    model_folder_path = main_path + '/ca_model_weights/'
else:
    if solublempnn:
        model_folder_path = main_path + '/soluble_model_weights/'
    else:
        model_folder_path = main_path + '/vanilla_model_weights/'
checkpoint_path = model_folder_path + f'{model_name}.pt'

# AA alphabet
omit_AAs_list = omit_AAs
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
chain_list = pdb_path_chains
omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)
bias_AAs_np = np.zeros(len(alphabet))

# NN settings
max_length = 200000
hidden_dim = 128
num_layers = 3
NUM_BATCHES = num_seq_per_target//batch_size
BATCH_COPIES = batch_size
temperatures = [float(item) for item in sampling_temp.split()]

printmd("Paths and variable set", color="green")

<span style='color:green'>Paths and variable set</span>

In [6]:
### Parse JSONs
ms_parse_chains(pdb_path, jsonl_path, ca_only)
ms_assign_chains(jsonl_path, chain_list, chain_id_jsonl)

if os.path.isfile(chain_id_jsonl):
    with open(chain_id_jsonl, 'r') as json_file:
        json_list = list(json_file)
    for json_str in json_list:
        chain_id_dict = json.loads(json_str)
else:
    chain_id_dict = None
    print(40*'-')
    print('chain_id_jsonl is NOT loaded')
    
if positions_to_fix != '':
    make_fixed_positions_dict(input_path=jsonl_path, output_path=fixed_positions_jsonl, chain_list=chain_list, position_list=positions_to_fix, invert_sel=invert_sel)
    
if os.path.isfile(fixed_positions_jsonl):
    with open(fixed_positions_jsonl, 'r') as json_file:
        json_list = list(json_file)
    for json_str in json_list:
        fixed_positions_dict = json.loads(json_str)
else:
    print(40*'-')
    print('fixed_positions_jsonl is NOT loaded')
    fixed_positions_dict = None
    
if os.path.isfile(pssm_jsonl):
    with open(pssm_jsonl, 'r') as json_file:
        json_list = list(json_file)
    pssm_dict = {}
    for json_str in json_list:
        pssm_dict.update(json.loads(json_str))
else:
    print(40*'-')
    print('pssm_jsonl is NOT loaded')
    pssm_dict = None
    
if os.path.isfile(omit_AA_jsonl):
    with open(omit_AA_jsonl, 'r') as json_file:
        json_list = list(json_file)
    for json_str in json_list:
        omit_AA_dict = json.loads(json_str)
else:
    print(40*'-')
    print('omit_AA_jsonl is NOT loaded')
    omit_AA_dict = None

if os.path.isfile(bias_AA_jsonl):
    with open(bias_AA_jsonl, 'r') as json_file:
        json_list = list(json_file)
    for json_str in json_list:
        bias_AA_dict = json.loads(json_str)
else:
    print(40*'-')
    print('bias_AA_jsonl is NOT loaded')
    bias_AA_dict = None

if bias_AA_dict:
    for n, AA in enumerate(alphabet):
        if AA in list(bias_AA_dict.keys()):
            bias_AAs_np[n] = bias_AA_dict[AA]

if os.path.isfile(tied_positions_jsonl):
    with open(tied_positions_jsonl, 'r') as json_file:
        json_list = list(json_file)
    for json_str in json_list:
        tied_positions_dict = json.loads(json_str)
else:
    print(40*'-')
    print('tied_positions_jsonl is NOT loaded')
    tied_positions_dict = None
    
if os.path.isfile(bias_by_res_jsonl):
    with open(bias_by_res_jsonl, 'r') as json_file:
        json_list = list(json_file)
    
    for json_str in json_list:
        bias_by_res_dict = json.loads(json_str)
    print('bias by residue dictionary is loaded')
else:
    print(40*'-')
    print('bias by residue dictionary is NOT loaded, or not provided')
    bias_by_res_dict = None
print(40*'-')

if pdb_path.endswith('.pdb'):
    pdb_dict_list = parse_PDB(pdb_path, ca_only=ca_only)
    dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)
    all_chain_list = [item[-1:] for item in list(pdb_dict_list[0]) if item[:9]=='seq_chain'] #['A','B', 'C',...]
    if pdb_path_chains:
        designed_chain_list = [str(item) for item in pdb_path_chains.split()]
    else:
        designed_chain_list = all_chain_list
    fixed_chain_list = [letter for letter in all_chain_list if letter not in designed_chain_list]
    chain_id_dict = {}
    chain_id_dict[pdb_dict_list[0]['name']]= (designed_chain_list, fixed_chain_list)
else:
    dataset_valid = StructureDataset(jsonl_path, truncate=None, max_length=max_length)    
print(40*'-')

----------------------------------------
pssm_jsonl is NOT loaded
----------------------------------------
omit_AA_jsonl is NOT loaded
----------------------------------------
bias_AA_jsonl is NOT loaded
----------------------------------------
tied_positions_jsonl is NOT loaded
----------------------------------------
bias by residue dictionary is NOT loaded, or not provided
----------------------------------------
discarded {'bad_chars': 0, 'too_long': 0, 'bad_seq_length': 0}
----------------------------------------


In [None]:


# check device
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

# Model settings
checkpoint = torch.load(checkpoint_path, map_location=device) 
#print('Number of edges:', checkpoint['num_edges'])
noise_level_print = checkpoint['noise_level']
#print(f'Training noise level: {noise_level_print}A')

# Load model
model = ProteinMPNN(ca_only=ca_only, num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
model = ProteinMPNN(ca_only=ca_only, num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# reset Probabilities
all_probs_list_list = []
all_sample_list_list =[]
chain_M_pos_list =[]
# Timing
start_time = time.time()
total_residues = 0
protein_list = []
total_step = 0
sample_encode_list =[]
with torch.no_grad():
    test_sum, test_weights = 0., 0.
    #print('Generating sequences...')
    for ix, protein in enumerate(dataset_valid):
        score_list = []
        global_score_list = []
        all_probs_list = []
        all_log_probs_list = []
        S_sample_list = []
        batch_clones = [copy.deepcopy(protein) for i in range(BATCH_COPIES)]
        X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict, ca_only=ca_only)
        pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float() #1.0 for true, 0.0 for false
        name_ = batch_clones[0]['name']
        if score_only:
            structure_sequence_score_file = base_folder + '/score_only/' + batch_clones[0]['name'] + '.npz'
            native_score_list = []
            global_native_score_list = []
            for j in range(NUM_BATCHES):
                randn_1 = torch.randn(chain_M.shape, device=X.device)
                log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)
                mask_for_loss = mask*chain_M*chain_M_pos
                scores = _scores(S, log_probs, mask_for_loss)
                native_score = scores.cpu().data.numpy()
                native_score_list.append(native_score)
                global_scores = _scores(S, log_probs, mask)
                global_native_score = global_scores.cpu().data.numpy()
                global_native_score_list.append(global_native_score)
            native_score = np.concatenate(native_score_list, 0)
            global_native_score = np.concatenate(global_native_score_list, 0)
            ns_mean = native_score.mean()
            ns_mean_print = np.format_float_positional(np.float32(ns_mean), unique=False, precision=4)
            ns_std = native_score.std()
            ns_std_print = np.format_float_positional(np.float32(ns_std), unique=False, precision=4)

            global_ns_mean = global_native_score.mean()
            global_ns_mean_print = np.format_float_positional(np.float32(global_ns_mean), unique=False, precision=4)
            global_ns_std = global_native_score.std()
            global_ns_std_print = np.format_float_positional(np.float32(global_ns_std), unique=False, precision=4)
            ns_sample_size = native_score.shape[0]
            np.savez(structure_sequence_score_file, score=native_score, global_score=global_native_score)
            print(f'Score for {name_}, mean: {ns_mean_print}, std: {ns_std_print}, sample size: {ns_sample_size},  Global Score for {name_}, mean: {global_ns_mean_print}, std: {global_ns_std_print}, sample size: {ns_sample_size}')
        elif conditional_probs_only:
            print(f'Calculating conditional probabilities for {name_}')
            conditional_probs_only_file = base_folder + '/conditional_probs_only/' + batch_clones[0]['name']
            log_conditional_probs_list = []
            for j in range(NUM_BATCHES):
                randn_1 = torch.randn(chain_M.shape, device=X.device)
                log_conditional_probs = model.conditional_probs(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1, conditional_probs_only_backbone)
                log_conditional_probs_list.append(log_conditional_probs.cpu().numpy())
            concat_log_p = np.concatenate(log_conditional_probs_list, 0) #[B, L, 21]
            mask_out = (chain_M*chain_M_pos*mask)[0,].cpu().numpy()
            np.savez(conditional_probs_only_file, log_p=concat_log_p, S=S[0,].cpu().numpy(), mask=mask[0,].cpu().numpy(), design_mask=mask_out)
        elif unconditional_probs_only:
            print(f'Calculating sequence unconditional probabilities for {name_}')
            unconditional_probs_only_file = base_folder + '/unconditional_probs_only/' + batch_clones[0]['name']
            log_unconditional_probs_list = []
            for j in range(NUM_BATCHES):
                log_unconditional_probs = model.unconditional_probs(X, mask, residue_idx, chain_encoding_all)
                log_unconditional_probs_list.append(log_unconditional_probs.cpu().numpy())
            concat_log_p = np.concatenate(log_unconditional_probs_list, 0) #[B, L, 21]
            mask_out = (chain_M*chain_M_pos*mask)[0,].cpu().numpy()
            np.savez(unconditional_probs_only_file, log_p=concat_log_p, S=S[0,].cpu().numpy(), mask=mask[0,].cpu().numpy(), design_mask=mask_out)
        else:
            randn_1 = torch.randn(chain_M.shape, device=X.device)
            log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)
            mask_for_loss = mask*chain_M*chain_M_pos
            scores = _scores(S, log_probs, mask_for_loss) #score only the redesigned part
            native_score = scores.cpu().data.numpy()
            global_scores = _scores(S, log_probs, mask) #score the whole structure-sequence
            global_native_score = global_scores.cpu().data.numpy()
            # Generate some sequences
            ali_file = base_folder + '/seqs/' + batch_clones[0]['name'] + '.fa'
            score_file = base_folder + '/scores/' + batch_clones[0]['name'] + '.npz'
            probs_file = base_folder + '/probs/' + batch_clones[0]['name'] + '.npz'
            print(f'Generating probabilities for: {name_}')
            t0 = time.time()
            with open(ali_file, 'w') as f:
                for temp in temperatures:
                    for j in range(1):
                        # print(j)
                        randn_2 = torch.randn(chain_M.shape, device=X.device)
                        if tied_positions_dict == None:
                            sample_encode = model.sample_multistate_encode(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), bias_by_res=bias_by_res_all)
                            # S_sample = sample_dict["S"] 
                            sample_encode_list.append(sample_encode)
                            # break


Generating probabilities for: 1_binary_noMSA
Generating probabilities for: 2_6bp_noMSA
Generating probabilities for: 3_8bp_noMSA


In [9]:
### Define Multi-state MPNN function

np.random.seed(seed)
def run_multistate_mpnn_test(dataset_valid, seed,state_weights):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # check device
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

    # Model settings
    checkpoint = torch.load(checkpoint_path, map_location=device) 
    #print('Number of edges:', checkpoint['num_edges'])
    noise_level_print = checkpoint['noise_level']
    #print(f'Training noise level: {noise_level_print}A')

    # Load model
    model = ProteinMPNN(ca_only=ca_only, num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])
    model.to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    # reset Probabilities
    all_probs_list_list = []
    all_sample_list_list =[]
    chain_M_pos_list =[]
    # Timing
    start_time = time.time()
    total_residues = 0
    protein_list = []
    total_step = 0
    
    sample_encode_list=[]
    # Validation epoch
    with torch.no_grad():
        test_sum, test_weights = 0., 0.
        #print('Generating sequences...')
        for ix, protein in enumerate(dataset_valid):
            score_list = []
            global_score_list = []
            all_probs_list = []
            all_log_probs_list = []
            S_sample_list = []
            batch_clones = [copy.deepcopy(protein) for i in range(BATCH_COPIES)]
            X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict, ca_only=ca_only)
            pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float() #1.0 for true, 0.0 for false
            name_ = batch_clones[0]['name']
            # delet score only, conditional, unconditional prob mode
            randn_1 = torch.randn(chain_M.shape, device=X.device)
            log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)
            mask_for_loss = mask*chain_M*chain_M_pos
            scores = _scores(S, log_probs, mask_for_loss) #score only the redesigned part
            native_score = scores.cpu().data.numpy()
            global_scores = _scores(S, log_probs, mask) #score the whole structure-sequence
            global_native_score = global_scores.cpu().data.numpy()
            # Generate some sequences
            ali_file = base_folder + '/seqs/' + batch_clones[0]['name'] + '.fa'
            score_file = base_folder + '/scores/' + batch_clones[0]['name'] + '.npz'
            probs_file = base_folder + '/probs/' + batch_clones[0]['name'] + '.npz'
            print(f'Generating probabilities for: {name_}')
            t0 = time.time()
            for temp in temperatures:
                for j in range(NUM_BATCHES):
                    if tied_positions_dict == None:
                        sample_dict = model.sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), bias_by_res=bias_by_res_all)
                        sample_encode = model.sample_multistate_encode(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), bias_by_res=bias_by_res_all)
                        sample_encode_list.append(sample_encode)
        output_dict_list=[]
        decode_require_list=[]
        myVars = vars()
        for sample_encode in sample_encode_list:
            namelist = ['N_nodes', 'decoding_order', 'chain_mask', 'mask', 'bias_by_res', 'S_true', 'E_idx', 'h_E', 'h_S','S','h_V', 'h_EXV_encoder_fw', 'h_V_stack', 'mask_bw', 'temperature', 'constant', 'constant_bias', 'pssm_bias_flag', 'pssm_coef', 'pssm_bias', 'pssm_log_odds_flag', 'pssm_multi', 'pssm_log_odds_mask', 'omit_AA_mask_flag', 'omit_AA_mask', 'all_probs']
            for name in namelist:
                myVars.__setitem__(name, sample_encode[name]) 
            decode_require =(decoding_order,chain_mask,mask,bias_by_res,S_true,E_idx,h_E,h_S,S,h_V,h_EXV_encoder_fw,h_V_stack,mask_bw,temperature,constant,constant_bias,pssm_bias_flag,pssm_coef,pssm_bias,pssm_log_odds_flag,pssm_multi,pssm_log_odds_mask,omit_AA_mask_flag,omit_AA_mask,all_probs)
            decode_require_list.append(decode_require)
        
        sample_dict=model.sample_multistate_decode_N_nodes(N_nodes,decode_require_list,state_weights)
        S_sample = sample_dict["S"]
        log_probs = model(X, S_sample, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_2, use_input_decoding_order=True, decoding_order=sample_dict["decoding_order"])
        mask_for_loss = mask*chain_M*chain_M_pos
        scores = _scores(S_sample, log_probs, mask_for_loss)
        scores = scores.cpu().data.numpy()

        global_scores = _scores(S_sample, log_probs, mask) #score the whole structure-sequence
        global_scores = global_scores.cpu().data.numpy()

                
    t1 = time.time()
    dt = round(float(t1-start_time), 4)
    num_seqs = len(temperatures)*NUM_BATCHES*BATCH_COPIES
    total_length = X.shape[1]
    print(f'{num_seqs} sequence of length {total_length} generated in {dt} seconds with score of {scores} ')
    
    
    return sample_dict["probs"].cpu().data.numpy(), S_sample.cpu().data.numpy(),chain_M_pos,scores


In [12]:
seq_list=[]
dataset_valid = dataset_valid[0:3]
state_weights = np.array([1.0, 1.0, 1.0])
state_weights = torch.from_numpy(state_weights/np.sum(state_weights))

for i in range(0,1):
    if seed_set is None:
        seed = int(np.random.randint(0, high=9999, size=1, dtype=int)[0])
    else:
        seed = seed_set

    sequence_name = design_name+'_'+str(i)
    
    printmd("Seed value for design "+sequence_name+" is: "+str(seed), color="blue")
    prob,sample,chain_M_pos, scores= run_multistate_mpnn_test(dataset_valid, seed,state_weights)
    l = []
    for res in sample[0]:
        l.append(alphabet[int(res)])
    master_seq = ''.join(l)    
    seq_list.append(master_seq)
    print(master_seq,scores)


<span style='color:blue'>Seed value for design NeoCas_Trial3_0 is: 775</span>

Generating probabilities for: 1_binary_noMSA
Generating probabilities for: 2_6bp_noMSA
Generating probabilities for: 3_8bp_noMSA
1 sequence of length 1368 generated in 31.6868 seconds with score of [1.8503491] 
REEPWSIGFDVGTDTVGVAVLGEKYKVPTRSFPVTGNSISHQRQTPLLGNLVIEPAQSREERIRKREKARIPKRRNNRMEELEKIFEPELKKKDFPFLLRIRLRDLPPHERSQSRHPLYGNDRLEVIYYKKFPTIQHRISTLANSSAKTQLIDIYVALKYLISNRGNFNLKGEIYPLFRDIQREANRLVLTFNHLFPQNPIDNSGVDFASILGADTSPEGRLAALQSTLPSRRPEDPFVRLVRLALGLQPDFAPSFQLANRALLNWEDEDYAERLERLLKEVGKRYRPLFEEALRLRDAILLGSWLTANDPAEISPLEEQEIMILNRHHKDRSHTKQLLKKEKPELVTRVFEDVPLNGYAGYEDGYATLEMFQEYIRPLLEDLPGTSFLLKQLGDRTLERELRSSTNKNIPTNVKHFTVDHILKKQQKYYPFIGEKKDLIVKIHEFVVPIDVGPLATDNTDNAVSRRKIDQPITSWNFTERIQYVMSRRKKTLNHTPRSPHNPSDEVLPKQNLTTMTYNVLNLLVRLRYLSQWLFAPRLLSATEIAALMMHLDLEKEEVTIDEIRTDYFGLILQYADVPIEGEKERLHAKLHMKLNLNDIIKPPLVLLDDSNVALIDEIVDTLVRYSTAIIIAEELRKWMNYLTKEQMITLVNMKLDGWSWKSKNEIDGFRDQTTQKSDMDYLADGGERQLNILEIYSDPDLSFGGIVAHIKIRAHGRNLEASIDYIPWLPAMSKVGLLMFRIYMENIDTHPSRMTRNIFMEIETVRTNGSKHTAEPEKILRGVKEGLA

In [23]:

print(master_seq)

HPELFSINFDVGNTSIGYAVLDDSFAIPSLTFEVTGNTNKTKHERPLIGVLLAPAAKDHVRYYYERRARRRRARRQNRQTLLAAYFAAIMDQEDPYFMERVRQMYLPYKDRNTHRHARYGNKEKEREHYEAYPNIWLLIQQLVNKPTKAPLRDGYTALSFLLSNRGNNEMKGIVNVLKADVQEEFLNSVSLYNKRFPSAPIDTTGIDASGILNAWTSADDRLAALMGQLPHIAADSYFGHLVALVPGLVPNQISAFGLREEALLDFVDPGFAHRLKQNLELVILENRELFDSALDLGWAIKLGSTLTSRDANSRSPESKQKVDDQMKHHEDEEKFYDLVTKQNPGEKEEIAENKESNGSAGFRDGFANWEIYKNKILPILQEMEGTDELISKIEKNEFLRPLRTDENRYMDRDLYLHDLGEILKQQQQFFPFLREKEEHIIQLARYRVPAEVGPLLKKNSEDEDRAFKIQLAISPSNYYGVVDWITSAVRKFHLLTPTSPNLPTFRVLALQSLRNQTFRTFNELATITYLSKQLAEPKKLSSSDALDLYETLFLINAYVTVNMVNTMFFKTIKKWANVTLSGHQNQFPTKLGMYHELLEITKNEPLLHDGSKWARIDQIIDILTRAKNKVFIANGKSQFAHDFTAAEMDKITQFVTEGWGDFSVQLITGLLDDVTGKSVMDFLKDDGTERRTFDQIINDPTLSQKRQIQWLSSRAQDWSLPYLVANLAAPPSKKKGISYTLNILGEIIELTGGENPAVVAIEFSVIRSSINTDLLTASQYLTQINRGAQTLATSILQDYPVKTVDLFSIKLYLWFRQQMMDAWNPETLDPRMLEQYDIDHIKSVRYRETNDIGNLALVKSRSNRGKNAAHPSEKVTARLNQDHTALLAAELIDARTYKALTMSYSAGLWLKTIYDYILESIELKEMISTRLAELIASMLNDQQDGNGKPVYKVKVIRLNSALIDIFRQNHKLYSVNRISPKIHAFDAYMTGLLGSRLISR

In [13]:
### Define Multi-state MPNN function

def run_multistate_mpnn_test(dataset_valid, seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # check device
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

    # Model settings
    checkpoint = torch.load(checkpoint_path, map_location=device) 
    #print('Number of edges:', checkpoint['num_edges'])
    noise_level_print = checkpoint['noise_level']
    #print(f'Training noise level: {noise_level_print}A')

    # Load model
    model = ProteinMPNN(ca_only=ca_only, num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])
    model.to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    # reset Probabilities
    all_probs_list_list = []
    all_sample_list_list =[]
    chain_M_pos_list =[]
    # Timing
    start_time = time.time()
    total_residues = 0
    protein_list = []
    total_step = 0

    # Validation epoch
    with torch.no_grad():
        test_sum, test_weights = 0., 0.
        #print('Generating sequences...')
        for ix, protein in enumerate(dataset_valid):
            score_list = []
            global_score_list = []
            all_probs_list = []
            all_log_probs_list = []
            S_sample_list = []
            batch_clones = [copy.deepcopy(protein) for i in range(BATCH_COPIES)]
            X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict, ca_only=ca_only)
            pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float() #1.0 for true, 0.0 for false
            name_ = batch_clones[0]['name']
            if score_only:
                structure_sequence_score_file = base_folder + '/score_only/' + batch_clones[0]['name'] + '.npz'
                native_score_list = []
                global_native_score_list = []
                for j in range(NUM_BATCHES):
                    randn_1 = torch.randn(chain_M.shape, device=X.device)
                    log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)
                    mask_for_loss = mask*chain_M*chain_M_pos
                    scores = _scores(S, log_probs, mask_for_loss)
                    native_score = scores.cpu().data.numpy()
                    native_score_list.append(native_score)
                    global_scores = _scores(S, log_probs, mask)
                    global_native_score = global_scores.cpu().data.numpy()
                    global_native_score_list.append(global_native_score)
                native_score = np.concatenate(native_score_list, 0)
                global_native_score = np.concatenate(global_native_score_list, 0)
                ns_mean = native_score.mean()
                ns_mean_print = np.format_float_positional(np.float32(ns_mean), unique=False, precision=4)
                ns_std = native_score.std()
                ns_std_print = np.format_float_positional(np.float32(ns_std), unique=False, precision=4)

                global_ns_mean = global_native_score.mean()
                global_ns_mean_print = np.format_float_positional(np.float32(global_ns_mean), unique=False, precision=4)
                global_ns_std = global_native_score.std()
                global_ns_std_print = np.format_float_positional(np.float32(global_ns_std), unique=False, precision=4)
                ns_sample_size = native_score.shape[0]
                np.savez(structure_sequence_score_file, score=native_score, global_score=global_native_score)
                print(f'Score for {name_}, mean: {ns_mean_print}, std: {ns_std_print}, sample size: {ns_sample_size},  Global Score for {name_}, mean: {global_ns_mean_print}, std: {global_ns_std_print}, sample size: {ns_sample_size}')
            elif conditional_probs_only:
                print(f'Calculating conditional probabilities for {name_}')
                conditional_probs_only_file = base_folder + '/conditional_probs_only/' + batch_clones[0]['name']
                log_conditional_probs_list = []
                for j in range(NUM_BATCHES):
                    randn_1 = torch.randn(chain_M.shape, device=X.device)
                    log_conditional_probs = model.conditional_probs(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1, conditional_probs_only_backbone)
                    log_conditional_probs_list.append(log_conditional_probs.cpu().numpy())
                concat_log_p = np.concatenate(log_conditional_probs_list, 0) #[B, L, 21]
                mask_out = (chain_M*chain_M_pos*mask)[0,].cpu().numpy()
                np.savez(conditional_probs_only_file, log_p=concat_log_p, S=S[0,].cpu().numpy(), mask=mask[0,].cpu().numpy(), design_mask=mask_out)
            elif unconditional_probs_only:
                print(f'Calculating sequence unconditional probabilities for {name_}')
                unconditional_probs_only_file = base_folder + '/unconditional_probs_only/' + batch_clones[0]['name']
                log_unconditional_probs_list = []
                for j in range(NUM_BATCHES):
                    log_unconditional_probs = model.unconditional_probs(X, mask, residue_idx, chain_encoding_all)
                    log_unconditional_probs_list.append(log_unconditional_probs.cpu().numpy())
                concat_log_p = np.concatenate(log_unconditional_probs_list, 0) #[B, L, 21]
                mask_out = (chain_M*chain_M_pos*mask)[0,].cpu().numpy()
                np.savez(unconditional_probs_only_file, log_p=concat_log_p, S=S[0,].cpu().numpy(), mask=mask[0,].cpu().numpy(), design_mask=mask_out)
            else:
                randn_1 = torch.randn(chain_M.shape, device=X.device)
                log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)
                mask_for_loss = mask*chain_M*chain_M_pos
                scores = _scores(S, log_probs, mask_for_loss) #score only the redesigned part
                native_score = scores.cpu().data.numpy()
                global_scores = _scores(S, log_probs, mask) #score the whole structure-sequence
                global_native_score = global_scores.cpu().data.numpy()
                # Generate some sequences
                ali_file = base_folder + '/seqs/' + batch_clones[0]['name'] + '.fa'
                score_file = base_folder + '/scores/' + batch_clones[0]['name'] + '.npz'
                probs_file = base_folder + '/probs/' + batch_clones[0]['name'] + '.npz'
                print(f'Generating probabilities for: {name_}')
                t0 = time.time()
                with open(ali_file, 'w') as f:
                    for temp in temperatures:
                        for j in range(NUM_BATCHES):
                            # print(j)
                            randn_2 = torch.randn(chain_M.shape, device=X.device)
                            if tied_positions_dict == None:
                                sample_dict = model.sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), bias_by_res=bias_by_res_all)
                                S_sample = sample_dict["S"] 
                            else:
                                sample_dict = model.tied_sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), tied_pos=tied_pos_list_of_lists_list[0], tied_beta=tied_beta, bias_by_res=bias_by_res_all)
                                # Compute scores
                                S_sample = sample_dict["S"]
                            log_probs = model(X, S_sample, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_2, use_input_decoding_order=True, decoding_order=sample_dict["decoding_order"])
                            mask_for_loss = mask*chain_M*chain_M_pos
                            scores = _scores(S_sample, log_probs, mask_for_loss)
                            scores = scores.cpu().data.numpy()

                            global_scores = _scores(S_sample, log_probs, mask) #score the whole structure-sequence
                            global_scores = global_scores.cpu().data.numpy()

                            all_probs_list.append(sample_dict["probs"].cpu().data.numpy())
                            all_log_probs_list.append(log_probs.cpu().data.numpy())
                            S_sample_list.append(S_sample.cpu().data.numpy())
                            for b_ix in range(BATCH_COPIES):
                                masked_chain_length_list = masked_chain_length_list_list[b_ix]
                                masked_list = masked_list_list[b_ix]
                                seq_recovery_rate = torch.sum(torch.sum(torch.nn.functional.one_hot(S[b_ix], 21)*torch.nn.functional.one_hot(S_sample[b_ix], 21),axis=-1)*mask_for_loss[b_ix])/torch.sum(mask_for_loss[b_ix])
                                seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])
                                score = scores[b_ix]
                                score_list.append(score)
                                global_score = global_scores[b_ix]
                                global_score_list.append(global_score)
                                native_seq = _S_to_seq(S[b_ix], chain_M[b_ix])
                                if b_ix == 0 and j==0 and temp==temperatures[0]:
                                    start = 0
                                    end = 0
                                    list_of_AAs = []
                                    for mask_l in masked_chain_length_list:
                                        end += mask_l
                                        list_of_AAs.append(native_seq[start:end])
                                        start = end
                                    native_seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
                                    l0 = 0
                                    for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:
                                        l0 += mc_length
                                        native_seq = native_seq[:l0] + '/' + native_seq[l0:]
                                        l0 += 1

                                    sorted_masked_chain_letters = np.argsort(masked_list_list[0])
                                    print_masked_chains = [masked_list_list[0][i] for i in sorted_masked_chain_letters]
                                    sorted_visible_chain_letters = np.argsort(visible_list_list[0])
                                    print_visible_chains = [visible_list_list[0][i] for i in sorted_visible_chain_letters]
                                    native_score_print = np.format_float_positional(np.float32(native_score.mean()), unique=False, precision=4)
                                    global_native_score_print = np.format_float_positional(np.float32(global_native_score.mean()), unique=False, precision=4)
                                    # script_dir = os.path.dirname(os.path.realpath(__file__))
                                    if ca_only:
                                        print_model_name = 'CA_model_name'
                                    else:
                                        print_model_name = 'model_name'

                                    f.write('>{}, score={}, global_score={}, fixed_chains={}, designed_chains={}, {}={}, seed={}\n{}\n'.format(name_, native_score_print, global_native_score_print, print_visible_chains, print_masked_chains, print_model_name, model_name, seed, native_seq)) #write the native sequence
                                start = 0
                                end = 0
                                list_of_AAs = []
                                for mask_l in masked_chain_length_list:
                                    end += mask_l
                                    list_of_AAs.append(seq[start:end])
                                    start = end

                                seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
                                l0 = 0
                                for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:
                                    l0 += mc_length
                                    seq = seq[:l0] + '/' + seq[l0:]
                                    l0 += 1
                                score_print = np.format_float_positional(np.float32(score), unique=False, precision=4)
                                global_score_print = np.format_float_positional(np.float32(global_score), unique=False, precision=4)
                                seq_rec_print = np.format_float_positional(np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4)
                                sample_number = j*BATCH_COPIES+b_ix+1
                                f.write('>T={}, sample={}, score={}, global_score={}, seq_recovery={}\n{}\n'.format(temp,sample_number, score_print, global_score_print, seq_rec_print, seq)) #write generated sequence
                if save_score:
                    np.savez(score_file, score=np.array(score_list, np.float32), global_score=np.array(global_score_list, np.float32))
                if save_probs:
                    all_probs_concat = np.concatenate(all_probs_list)
                    all_log_probs_concat = np.concatenate(all_log_probs_list)
                    S_sample_concat = np.concatenate(S_sample_list)
                    np.savez(probs_file, probs=np.array(all_probs_concat, np.float32), log_probs=np.array(all_log_probs_concat, np.float32), S=np.array(S_sample_concat, np.int32), mask=mask_for_loss.cpu().data.numpy(), chain_order=chain_list_list)
                
                all_probs_list_list.append(all_probs_list)
                all_sample_list_list.append(S_sample_list)
                chain_M_pos_list.append(chain_M_pos)                
    t1 = time.time()
    dt = round(float(t1-start_time), 4)
    num_seqs = len(temperatures)*NUM_BATCHES*BATCH_COPIES
    total_length = X.shape[1]
    print(f'{num_seqs} sequence of length {total_length} generated in {dt} seconds with score of {score_print} and sequence recovery of {seq_rec_print}')
    
    
    return all_probs_list_list, all_sample_list_list,chain_M_pos_list,score_print, seq_rec_print

printmd("MPNN Function defined", color="green")

<span style='color:green'>MPNN Function defined</span>

In [14]:
### Define Multi-state MPNN function
def run_multistate_mpnn(dataset_valid, seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    # check device
    device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

    # Model settings
    checkpoint = torch.load(checkpoint_path, map_location=device) 
    #print('Number of edges:', checkpoint['num_edges'])
    noise_level_print = checkpoint['noise_level']
    #print(f'Training noise level: {noise_level_print}A')

    # Load model
    model = ProteinMPNN(ca_only=ca_only, num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])
    model.to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    # reset Probabilities
    all_probs_list_list = []
    all_sample_list_list =[]
    chain_M_pos_list =[]
    # Timing
    start_time = time.time()
    total_residues = 0
    protein_list = []
    total_step = 0

    # Validation epoch
    with torch.no_grad():
        test_sum, test_weights = 0., 0.
        #print('Generating sequences...')
        for ix, protein in enumerate(dataset_valid):
            score_list = []
            global_score_list = []
            all_probs_list = []
            all_log_probs_list = []
            S_sample_list = []
            batch_clones = [copy.deepcopy(protein) for i in range(BATCH_COPIES)]
            X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize(batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, tied_positions_dict, pssm_dict, bias_by_res_dict, ca_only=ca_only)
            pssm_log_odds_mask = (pssm_log_odds_all > pssm_threshold).float() #1.0 for true, 0.0 for false
            name_ = batch_clones[0]['name']
            if score_only:
                structure_sequence_score_file = base_folder + '/score_only/' + batch_clones[0]['name'] + '.npz'
                native_score_list = []
                global_native_score_list = []
                for j in range(NUM_BATCHES):
                    randn_1 = torch.randn(chain_M.shape, device=X.device)
                    log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)
                    mask_for_loss = mask*chain_M*chain_M_pos
                    scores = _scores(S, log_probs, mask_for_loss)
                    native_score = scores.cpu().data.numpy()
                    native_score_list.append(native_score)
                    global_scores = _scores(S, log_probs, mask)
                    global_native_score = global_scores.cpu().data.numpy()
                    global_native_score_list.append(global_native_score)
                native_score = np.concatenate(native_score_list, 0)
                global_native_score = np.concatenate(global_native_score_list, 0)
                ns_mean = native_score.mean()
                ns_mean_print = np.format_float_positional(np.float32(ns_mean), unique=False, precision=4)
                ns_std = native_score.std()
                ns_std_print = np.format_float_positional(np.float32(ns_std), unique=False, precision=4)

                global_ns_mean = global_native_score.mean()
                global_ns_mean_print = np.format_float_positional(np.float32(global_ns_mean), unique=False, precision=4)
                global_ns_std = global_native_score.std()
                global_ns_std_print = np.format_float_positional(np.float32(global_ns_std), unique=False, precision=4)
                ns_sample_size = native_score.shape[0]
                np.savez(structure_sequence_score_file, score=native_score, global_score=global_native_score)
                print(f'Score for {name_}, mean: {ns_mean_print}, std: {ns_std_print}, sample size: {ns_sample_size},  Global Score for {name_}, mean: {global_ns_mean_print}, std: {global_ns_std_print}, sample size: {ns_sample_size}')
            elif conditional_probs_only:
                print(f'Calculating conditional probabilities for {name_}')
                conditional_probs_only_file = base_folder + '/conditional_probs_only/' + batch_clones[0]['name']
                log_conditional_probs_list = []
                for j in range(NUM_BATCHES):
                    randn_1 = torch.randn(chain_M.shape, device=X.device)
                    log_conditional_probs = model.conditional_probs(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1, conditional_probs_only_backbone)
                    log_conditional_probs_list.append(log_conditional_probs.cpu().numpy())
                concat_log_p = np.concatenate(log_conditional_probs_list, 0) #[B, L, 21]
                mask_out = (chain_M*chain_M_pos*mask)[0,].cpu().numpy()
                np.savez(conditional_probs_only_file, log_p=concat_log_p, S=S[0,].cpu().numpy(), mask=mask[0,].cpu().numpy(), design_mask=mask_out)
            elif unconditional_probs_only:
                print(f'Calculating sequence unconditional probabilities for {name_}')
                unconditional_probs_only_file = base_folder + '/unconditional_probs_only/' + batch_clones[0]['name']
                log_unconditional_probs_list = []
                for j in range(NUM_BATCHES):
                    log_unconditional_probs = model.unconditional_probs(X, mask, residue_idx, chain_encoding_all)
                    log_unconditional_probs_list.append(log_unconditional_probs.cpu().numpy())
                concat_log_p = np.concatenate(log_unconditional_probs_list, 0) #[B, L, 21]
                mask_out = (chain_M*chain_M_pos*mask)[0,].cpu().numpy()
                np.savez(unconditional_probs_only_file, log_p=concat_log_p, S=S[0,].cpu().numpy(), mask=mask[0,].cpu().numpy(), design_mask=mask_out)
            else:
                randn_1 = torch.randn(chain_M.shape, device=X.device)
                log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)
                mask_for_loss = mask*chain_M*chain_M_pos
                scores = _scores(S, log_probs, mask_for_loss) #score only the redesigned part
                native_score = scores.cpu().data.numpy()
                global_scores = _scores(S, log_probs, mask) #score the whole structure-sequence
                global_native_score = global_scores.cpu().data.numpy()
                # Generate some sequences
                ali_file = base_folder + '/seqs/' + batch_clones[0]['name'] + '.fa'
                score_file = base_folder + '/scores/' + batch_clones[0]['name'] + '.npz'
                probs_file = base_folder + '/probs/' + batch_clones[0]['name'] + '.npz'
                print(f'Generating probabilities for: {name_}')
                t0 = time.time()
                with open(ali_file, 'w') as f:
                    for temp in temperatures:
                        for j in range(NUM_BATCHES):
                            # print(j)
                            randn_2 = torch.randn(chain_M.shape, device=X.device)
                            if tied_positions_dict == None:
                                sample_dict = model.sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), bias_by_res=bias_by_res_all)
                                S_sample = sample_dict["S"] 
                            else:
                                sample_dict = model.tied_sample(X, randn_2, S, chain_M, chain_encoding_all, residue_idx, mask=mask, temperature=temp, omit_AAs_np=omit_AAs_np, bias_AAs_np=bias_AAs_np, chain_M_pos=chain_M_pos, omit_AA_mask=omit_AA_mask, pssm_coef=pssm_coef, pssm_bias=pssm_bias, pssm_multi=pssm_multi, pssm_log_odds_flag=bool(pssm_log_odds_flag), pssm_log_odds_mask=pssm_log_odds_mask, pssm_bias_flag=bool(pssm_bias_flag), tied_pos=tied_pos_list_of_lists_list[0], tied_beta=tied_beta, bias_by_res=bias_by_res_all)
                                # Compute scores
                                S_sample = sample_dict["S"]
                            log_probs = model(X, S_sample, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_2, use_input_decoding_order=True, decoding_order=sample_dict["decoding_order"])
                            mask_for_loss = mask*chain_M*chain_M_pos
                            scores = _scores(S_sample, log_probs, mask_for_loss)
                            scores = scores.cpu().data.numpy()

                            global_scores = _scores(S_sample, log_probs, mask) #score the whole structure-sequence
                            global_scores = global_scores.cpu().data.numpy()

                            all_probs_list.append(sample_dict["probs"].cpu().data.numpy())
                            all_log_probs_list.append(log_probs.cpu().data.numpy())
                            S_sample_list.append(S_sample.cpu().data.numpy())
                            for b_ix in range(BATCH_COPIES):
                                masked_chain_length_list = masked_chain_length_list_list[b_ix]
                                masked_list = masked_list_list[b_ix]
                                seq_recovery_rate = torch.sum(torch.sum(torch.nn.functional.one_hot(S[b_ix], 21)*torch.nn.functional.one_hot(S_sample[b_ix], 21),axis=-1)*mask_for_loss[b_ix])/torch.sum(mask_for_loss[b_ix])
                                seq = _S_to_seq(S_sample[b_ix], chain_M[b_ix])
                                score = scores[b_ix]
                                score_list.append(score)
                                global_score = global_scores[b_ix]
                                global_score_list.append(global_score)
                                native_seq = _S_to_seq(S[b_ix], chain_M[b_ix])
                                if b_ix == 0 and j==0 and temp==temperatures[0]:
                                    start = 0
                                    end = 0
                                    list_of_AAs = []
                                    for mask_l in masked_chain_length_list:
                                        end += mask_l
                                        list_of_AAs.append(native_seq[start:end])
                                        start = end
                                    native_seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
                                    l0 = 0
                                    for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:
                                        l0 += mc_length
                                        native_seq = native_seq[:l0] + '/' + native_seq[l0:]
                                        l0 += 1

                                    sorted_masked_chain_letters = np.argsort(masked_list_list[0])
                                    print_masked_chains = [masked_list_list[0][i] for i in sorted_masked_chain_letters]
                                    sorted_visible_chain_letters = np.argsort(visible_list_list[0])
                                    print_visible_chains = [visible_list_list[0][i] for i in sorted_visible_chain_letters]
                                    native_score_print = np.format_float_positional(np.float32(native_score.mean()), unique=False, precision=4)
                                    global_native_score_print = np.format_float_positional(np.float32(global_native_score.mean()), unique=False, precision=4)
                                    # script_dir = os.path.dirname(os.path.realpath(__file__))
                                    if ca_only:
                                        print_model_name = 'CA_model_name'
                                    else:
                                        print_model_name = 'model_name'

                                    f.write('>{}, score={}, global_score={}, fixed_chains={}, designed_chains={}, {}={}, seed={}\n{}\n'.format(name_, native_score_print, global_native_score_print, print_visible_chains, print_masked_chains, print_model_name, model_name, seed, native_seq)) #write the native sequence
                                start = 0
                                end = 0
                                list_of_AAs = []
                                for mask_l in masked_chain_length_list:
                                    end += mask_l
                                    list_of_AAs.append(seq[start:end])
                                    start = end

                                seq = "".join(list(np.array(list_of_AAs)[np.argsort(masked_list)]))
                                l0 = 0
                                for mc_length in list(np.array(masked_chain_length_list)[np.argsort(masked_list)])[:-1]:
                                    l0 += mc_length
                                    seq = seq[:l0] + '/' + seq[l0:]
                                    l0 += 1
                                score_print = np.format_float_positional(np.float32(score), unique=False, precision=4)
                                global_score_print = np.format_float_positional(np.float32(global_score), unique=False, precision=4)
                                seq_rec_print = np.format_float_positional(np.float32(seq_recovery_rate.detach().cpu().numpy()), unique=False, precision=4)
                                sample_number = j*BATCH_COPIES+b_ix+1
                                f.write('>T={}, sample={}, score={}, global_score={}, seq_recovery={}\n{}\n'.format(temp,sample_number, score_print, global_score_print, seq_rec_print, seq)) #write generated sequence
                if save_score:
                    np.savez(score_file, score=np.array(score_list, np.float32), global_score=np.array(global_score_list, np.float32))
                if save_probs:
                    all_probs_concat = np.concatenate(all_probs_list)
                    all_log_probs_concat = np.concatenate(all_log_probs_list)
                    S_sample_concat = np.concatenate(S_sample_list)
                    np.savez(probs_file, probs=np.array(all_probs_concat, np.float32), log_probs=np.array(all_log_probs_concat, np.float32), S=np.array(S_sample_concat, np.int32), mask=mask_for_loss.cpu().data.numpy(), chain_order=chain_list_list)
                
                all_probs_list_list.append(all_probs_list)
                all_sample_list_list.append(S_sample_list)
                chain_M_pos_list.append(chain_M_pos)                
    t1 = time.time()
    dt = round(float(t1-start_time), 4)
    num_seqs = len(temperatures)*NUM_BATCHES*BATCH_COPIES
    total_length = X.shape[1]
    print(f'{num_seqs} sequence of length {total_length} generated in {dt} seconds with score of {score_print} and sequence recovery of {seq_rec_print}')
    
    
    return all_probs_list_list, all_sample_list_list,chain_M_pos_list,score_print, seq_rec_print

printmd("MPNN Function defined", color="green")

<span style='color:green'>MPNN Function defined</span>

# Generate sequences

In [11]:
## Dictionary of states and associated weights
# match input PDB name and number between -1 and 1
# if state is omitted it gets assigned 1 by default
state_weights = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]

### TODO: sort the JSON entries based on PDB name

### Check
printmd("Check that your protein states are in the same order as your weights!", color="red")
i = 0
for ix, protein in enumerate(dataset_valid):
    j = i+1
    printmd(str(j)+' => '+protein['name']+' <= '+str(state_weights[i]), color="blue")
    i = i+1

<span style='color:red'>Check that your protein states are in the same order as your weights!</span>

<span style='color:blue'>1 => 1_binary_noMSA <= 1.0</span>

<span style='color:blue'>2 => 2_6bp_noMSA <= 1.0</span>

<span style='color:blue'>3 => 3_8bp_noMSA <= 1.0</span>

In [None]:
# Generate probabilities
final_sequences = []

for i in tqdm(range(1, num_of_sequences+1, 1), desc="Generating designs"):
    # set seed
    if seed_set is None:
        seed = int(np.random.randint(0, high=9999, size=1, dtype=int)[0])
    else:
        seed = seed_set

    sequence_name = design_name+'_'+str(i)
    printmd("Seed value for design "+sequence_name+" is: "+str(seed), color="blue")
    all_probs_list_list, all_sample_list_list,chain_M_pos_list,mpnn_score, sequence_recovery = run_multistate_mpnn(dataset_valid, seed)
    mean_pssm_list = []

    for j in range(len(all_probs_list_list)):
        design = all_probs_list_list[j]
        weight = state_weights[j]
        mean_design_pssm = design[0].mean(axis=0)
        mean_pssm_list.append(mean_design_pssm*weight)
        
    np.array(mean_pssm_list).shape
    master_pssm = np.array(mean_pssm_list).mean(axis=0)
    idx = np.argmax(master_pssm,axis=1)
    result= (all_sample_list_list[0]*(1-chain_M_pos_list[0].numpy()))[0][0]+idx*(chain_M_pos_list[0][0].numpy())

    # Sample most probable sequence
    print(40*'-')
    l = []
    for res in result:
        l.append(alphabet[int(res)])
    master_seq = ''.join(l)
#     master_seq = max_sequence(master_pssm)
    print("Designed sequence:")
    print(master_seq)

    final_seq = [sequence_name, master_seq, str(mpnn_score), str(sequence_recovery), str(seed), str(sampling_temp), str(backbone_noise), " ".join([str(w) for w in state_weights]), positions_to_fix]
    final_sequences.append(final_seq)
    
# generate dataframe
final_seq_list = pd.DataFrame(final_sequences, columns=['Design Name', 'Sequence', 'MPNN Score', 'Sequence Recovery', 'Seed', 'Temperature', 'Backbone noise', 'State Weights', 'Fixed Positions'])
final_seq_list.to_csv(folder_for_outputs+design_name+"_designs.csv", index=False, header=False)

Generating designs:   0%|          | 0/50 [00:00<?, ?it/s]

<span style='color:blue'>Seed value for design NeoCas_Trial3_1 is: 5324</span>

Generating probabilities for: 1_binary_noMSA
Generating probabilities for: 2_6bp_noMSA
Generating probabilities for: 3_8bp_noMSA
Generating probabilities for: 4_10bp_noMSA
Generating probabilities for: 5_16bp_noMSA
Generating probabilities for: 6_18bpchck_noMSA
Generating probabilities for: 7_18bpcat_noMSA
1 sequence of length 1368 generated in 148.52 seconds with score of 1.8485 and sequence recovery of 0.3251
----------------------------------------
Designed sequence:
APEPWSIGLDIGTDSVGYAVIDENFEVPKKTFPVSGNTDVKEREKNLIGVLLFPPAKSREARRRRRRRRRERRRRQNRLELLEEIFAPELAKEDPNYLARLRERHLPPEDRKLSRHPLFGNEEKEKAYKEKYPTIEALILDLVESTEKQPLRLIYLALRYLIRNRGNFRIKGELDPSNNDIQALFRELVDTYNALFPENPIDVEGVDFESILTSDLSPEERLDELIAAIPGVTKDSFFGNLLALALGLTPNFSPNFGLDEPALLDLDDPDYEERLAKLLSEVGEEYKPLFDAARKLGDAILLSRILKVDPSTTKSPLAARKIEIYERHHEDLEKLKELIKKQAPELYDEIFEDKSGNGYAGYEDGTATYEEFYAYIRPILESLPGTEELLELLEAGTLLRKLRDPRNKAIPRDLRLHELSAILDNQEPYYPFLKENKEEILTILTFRVPEYVGPLSRGNSPDSHAVFKKNEPVTPWNFEEIVDFVASARNYVRRKTPRDPLLPGKPVLPKNSLTYQEFLVYNELN

<span style='color:blue'>Seed value for design NeoCas_Trial3_2 is: 5006</span>

Generating probabilities for: 1_binary_noMSA
Generating probabilities for: 2_6bp_noMSA
Generating probabilities for: 3_8bp_noMSA
Generating probabilities for: 4_10bp_noMSA
Generating probabilities for: 5_16bp_noMSA
Generating probabilities for: 6_18bpchck_noMSA
Generating probabilities for: 7_18bpcat_noMSA
1 sequence of length 1368 generated in 145.6462 seconds with score of 1.8014 and sequence recovery of 0.3199
----------------------------------------
Designed sequence:
APEPWSIGLDIGTDSVGFAVIDENFRVPTKTFPVSGNTDVKSRKKNLIGVLLFPPAKSDEEERRRRRARRRRRRRANRLELLEEIFAPELAKEDPNYLARLEERHLPPEDRKYSRHPLFGNEEKEEAFKKKFPTIEALILDLVESKEKQDLRDIYEALRYLIRNRGNFKIEGELDPSKTDVQALFRELVDTYNALFPENPIDTTGVDFEAILTSDKSKEERLDKLIAAIPGVTEDGFFGKLLALALGLTPNFSPSFNLPEPALLDLRDPDYEERLAELLSEIGPEYKPLFDAAKKLSDAIYLSRILKVDPSTTKAPLAARKIEILKQHHEQLEKLKELIKAQAPELYDEIFEDTSGNGYAGYEDGTATYEEFYAYIRPILESLEGTEELLEALDKGTLLRKIRDEDNKAIPRDLRLGTLSAILDNQQPYYPFLAENKEEILNILTFRVPEYIGPLSRGNSPDSHAVFKTNEPVTPWNFEEIVDYVASAERYVERKTPRDPLLPDEPVLPKNSLTMQEFLVYNE

<span style='color:blue'>Seed value for design NeoCas_Trial3_3 is: 7378</span>

Generating probabilities for: 1_binary_noMSA
Generating probabilities for: 2_6bp_noMSA
Generating probabilities for: 3_8bp_noMSA
Generating probabilities for: 4_10bp_noMSA
Generating probabilities for: 5_16bp_noMSA
Generating probabilities for: 6_18bpchck_noMSA
Generating probabilities for: 7_18bpcat_noMSA


In [None]:
### Display Designs
design_df = pd.read_csv(folder_for_outputs+design_name+"_designs.csv", names=['Design Name', 'Sequence', 'MPNN Score', 'Sequence Recovery', 'Seed', 'Temperature', 'Backbone noise', 'State Weights', 'Fixed Positions'])
design_df.sort_values(['MPNN Score', 'Sequence Recovery'], inplace=True, ascending=[True, False])
design_df.head(20)

In [None]:
### Save FASTA
save_sequences = []

for index, row in tqdm(design_df.iterrows(), desc="Saving FASTA"):
    print("Saving sequence: "+row['Design Name'])
    save_sequences.append('>%s\n%s\n'%(row['Design Name'],row['Sequence']))
    
with open(folder_for_outputs+design_name+"_designs.fasta", 'w+') as fm:
    for line in save_sequences:
        fm.write(line)
fm.close()
printmd("FASTA output generated", color="green")

In [None]:
### Make weblogo of sequences
import logomaker
def make_logoplot(seqs):
    arr = np.asarray([list(s) for s in seqs])
    dfdict = {}
    for pos in range(arr.shape[1]):
        dfdict[pos] = {n:v/arr.shape[0] for n,v in zip(*np.unique(arr[:,pos],return_counts=True))}
    df = pd.DataFrame.from_dict(dfdict,orient='index')
    df.fillna(0,inplace=True) 
    logo = logomaker.Logo(df,color_scheme='weblogo_protein', width=0.95,figsize=(600,3))
    logo.style_xticks(anchor=0, spacing=5)
    logo.ax.set_ylabel('Occurence')
    logo.ax.set_xlabel('Position')
    logo.fig.tight_layout()
    plt.savefig(folder_for_outputs+design_name+'_weblogo.png', dpi=100, bbox_inches="tight")

logo_seq = []
for index, row in design_df.iterrows():
    logo_seq.append(row['Sequence'])    

make_logoplot(logo_seq)