In [263]:
import os
import numpy as np
import torch
import copy
import time

In [264]:
from utils import get_pdb_dataset
from protein_mpnn.protein_mpnn_utils import ProteinMPNN

In [265]:
# --- Model architecture: passed to ProteinMPNN below and must agree with the
# checkpoint's weights for load_state_dict to succeed.
hidden_dim = 128
num_layers = 3

# --- Paths (relative to the current working directory).
file_path = os.getcwd()
model_folder_path = os.path.join(file_path, 'model_weights')
checkpoint_path = os.path.join(model_folder_path, 'v_48_002.pt')
folder_for_outputs = './tmp/'
pdb_path = './pdbs/'

# --- Sampling / featurization options.
temperatures = list(map(float, '0.1'.split()))  # space-separated list of sampling temperatures
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'              # residue alphabet; index into this string gives S
omit_AAs_list = 'X'                              # letters that should never be sampled
backbone_noise = 0.0                             # passed to ProteinMPNN as augment_eps

In [266]:
# Number of batches to sample, and how many copies of each protein go in one batch.
NUM_BATCHES, BATCH_COPIES = 2, 1

In [267]:
# Per-letter flags over `alphabet`: 1.0 for letters listed in omit_AAs_list, else 0.0.
omit_AAs_np = np.array([aa in omit_AAs_list for aa in alphabet], dtype=np.float32)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Optional design constraints, all left unset here.
chain_id_dict = fixed_positions_dict = pssm_dict = omit_AA_dict = None
bias_AA_dict = tied_positions_dict = bias_by_res_dict = None
bias_AAs_np = np.zeros(len(alphabet))  # no per-letter bias

In [268]:
# Load the structures found under pdb_path. Judging by the protein.keys() output
# further down, each entry is a dict with 'name', 'seq' and per-chain
# 'seq_chain_<X>' / 'coords_chain_<X>' keys — TODO confirm against utils.get_pdb_dataset.
dataset_valid = get_pdb_dataset(pdb_path)

## TODO: figure out how to specify designable chains in the input JSON

In [269]:
# Space-separated chain letters to design (e.g. 'A B'); an empty string means
# every chain of each structure is designable (see the next cell).
pdb_path_chains = 'A'

In [270]:
# Map each structure name -> (designable_chains, fixed_chains).
# The requested chain letters are parsed once, outside the loop; an empty
# (or whitespace-only) pdb_path_chains means "design every chain".
requested_chains = pdb_path_chains.split()
chain_id_dict = {}
for pdb in dataset_valid:
    # Same 'seq_chain_' prefix test that get_model_inputs uses for its default.
    all_chains = [key[-1:] for key in pdb if key.startswith('seq_chain_')]
    # Copy so each structure owns its own list.
    designable_chains = list(requested_chains) if requested_chains else all_chains
    fixed_chains = [letter for letter in all_chains if letter not in designable_chains]
    chain_id_dict[pdb['name']] = (designable_chains, fixed_chains)

In [271]:
# Load the pretrained weights and report how they were trained.
# NOTE(review): torch.load unpickles the file and can execute arbitrary code;
# only load trusted checkpoints (recent PyTorch versions offer weights_only=True).
checkpoint = torch.load(checkpoint_path, map_location=device)
print('Number of edges:', checkpoint['num_edges'])
print('Training noise level:', checkpoint['noise_level'])

Number of edges: 48
Training noise level: 0.02


In [272]:
# Build the network with the architecture hyperparameters from the config cell;
# k_neighbors is read from the checkpoint so the graph size matches training.
model = ProteinMPNN(
    num_letters=21,  # len(alphabet): 20 amino acids plus 'X'
    node_features=hidden_dim,
    edge_features=hidden_dim,
    hidden_dim=hidden_dim,
    num_encoder_layers=num_layers,
    num_decoder_layers=num_layers,
    augment_eps=backbone_noise,
    k_neighbors=checkpoint['num_edges'],
)
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

ProteinMPNN(
  (features): ProteinFeatures(
    (embeddings): PositionalEncodings(
      (linear): Linear(in_features=66, out_features=16, bias=True)
    )
    (edge_embedding): Linear(in_features=416, out_features=128, bias=False)
    (norm_edges): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (W_e): Linear(in_features=128, out_features=128, bias=True)
  (W_s): Embedding(21, 128)
  (encoder_layers): ModuleList(
    (0): EncLayer(
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
      (dropout3): Dropout(p=0.1, inplace=False)
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (W1): Linear(in_features=384, out_features=128, bias=True)
      (W2): Linear(in_features=128, out_features=128, bias=True)
      (W3): Linear(in_features=128, out_features=128, bias=True)
     

In [273]:
# Run-level bookkeeping: wall-clock start plus running counters/accumulators
# (not yet used by the cells shown here).
start_time = time.time()
protein_list = []
total_residues = total_step = 0

In [290]:
# Pick the first structure to step through by hand. `protein` is defined here so
# the cell works on a fresh kernel (nothing earlier in the notebook defines it).
ix, protein = 0, dataset_valid[0]
score_list = []
all_probs_list = []
all_log_probs_list = []
S_sample_list = []
# Independent deep copies, so edits to a clone never touch `protein` or the dataset.
batch_clones = [copy.deepcopy(protein) for _ in range(BATCH_COPIES)]

In [289]:
protein.keys()  # which top-level and per-chain fields the parsed structure provides

dict_keys(['seq_chain_A', 'coords_chain_A', 'name', 'num_of_chains', 'seq'])

In [288]:
# NOTE(review): manual in-place override on the parsed dict (and on the dataset entry
# itself if dataset_valid[0] returns a stored object). get_model_inputs does not read this key.
protein['num_of_chains'] = 1

Goal: trace the code paths that produce `residue_idx` and `randn_1`.

In [291]:
def get_model_inputs(batch, chain_dict=None, alphabet=None, device=None):
    """Featurize a batch of parsed structures into padded ProteinMPNN input tensors.

    Args:
        batch: list of protein dicts. Each needs 'name', 'seq' (only its length is
            used, to size the padding) and, for every chain letter X featurized,
            'seq_chain_X' (str; '-' is mapped to 'X') and 'coords_chain_X'
            (mapping with keys 'N_chain_X', 'CA_chain_X', 'C_chain_X', 'O_chain_X',
            each array-like of shape [chain_len, 3]; NaN entries count as missing).
        chain_dict: optional mapping name -> (designable_chains, fixed_chains).
            If None, every 'seq_chain_*' chain in the dict is designable.
        alphabet: residue alphabet used to index S. Defaults to the module-level
            ``alphabet``, looked up at call time.
        device: torch device for the outputs. Defaults to the module-level
            ``device``, looked up at call time.

    Returns:
        (X, S, mask, residue_idx, chain_M, chain_M_pos, chain_encoding_all):
        X [B, L_max, 4, 3] float32 backbone N/CA/C/O coordinates, NaN/padding -> 0;
        S [B, L_max] long indices into ``alphabet``, 0-padded;
        mask [B, L_max] float32, 1.0 where all four backbone atoms are finite;
        residue_idx [B, L_max] long, chain c (1-based) offset by 100*(c-1), -100 padding;
        chain_M [B, L_max] float32, 1.0 on residues of designable chains;
        chain_M_pos [B, L_max] float32, 1.0 on every real residue (no fixed positions yet);
        chain_encoding_all [B, L_max] long chain number starting at 1, 0 padding.

    Raises:
        ValueError: if a chain is listed as both designable and fixed, or a
            residue letter is not in ``alphabet``.
    """
    # Fall back to the notebook-level globals the original implementation used.
    if alphabet is None:
        alphabet = globals()['alphabet']
    if device is None:
        device = globals()['device']

    B = len(batch)
    L_max = max(len(b['seq']) for b in batch)
    X = np.zeros([B, L_max, 4, 3])
    S = np.zeros([B, L_max], dtype=np.int32)
    chain_M = np.zeros([B, L_max], dtype=np.int32)
    chain_M_pos = np.zeros([B, L_max], dtype=np.int32)
    chain_encoding_all = np.zeros([B, L_max], dtype=np.int32)
    residue_idx = -100 * np.ones([B, L_max], dtype=np.int32)

    for i, b in enumerate(batch):
        if chain_dict is not None:
            # Masked chains are designable chains; visible chains are fixed chains.
            masked_chains, visible_chains = chain_dict[b['name']]
        else:
            masked_chains = [key[-1:] for key in b if key.startswith('seq_chain_')]
            visible_chains = []

        overlap = set(masked_chains) & set(visible_chains)
        if overlap:
            # Such a chain would otherwise get two mask entries per residue, so the
            # masks would no longer line up with the coordinates.
            raise ValueError(
                f"chains {sorted(overlap)} of {b['name']!r} are listed as both designable and fixed")
        all_chains = list(masked_chains) + list(visible_chains)

        x_chain_list, chain_seq_list = [], []
        chain_mask_list, fixed_position_mask_list, chain_encoding_list = [], [], []
        l0 = 0

        for c, letter in enumerate(all_chains, start=1):
            chain_coords = b[f'coords_chain_{letter}']
            atom_keys = [f'{atom}_chain_{letter}' for atom in ('N', 'CA', 'C', 'O')]
            x_chain_list.append(np.stack([chain_coords[key] for key in atom_keys], 1))

            chain_seq = b[f'seq_chain_{letter}'].replace('-', 'X')
            chain_length = len(chain_seq)
            chain_seq_list.append(chain_seq)

            # Consecutive indices across chains, plus an extra 100 per chain boundary.
            l1 = l0 + chain_length
            residue_idx[i, l0:l1] = 100 * (c - 1) + np.arange(l0, l1)
            l0 = l1
            chain_encoding_list.append(c * np.ones(chain_length))

            # Exactly one mask entry per residue: 1.0 for designable, 0.0 for fixed chains.
            if letter in masked_chains:
                chain_mask_list.append(np.ones(chain_length))
            else:
                chain_mask_list.append(np.zeros(chain_length))
            # Fixed positions on a designable chain would be set to 0.0 here (not implemented).
            fixed_position_mask_list.append(np.ones(chain_length))

        x = np.concatenate(x_chain_list, 0)
        all_sequence = ''.join(chain_seq_list)
        l = len(all_sequence)
        # Pad with NaN so padded rows come out as 0.0 in `mask` below.
        x_pad = np.pad(x, [[0, L_max - l], [0, 0], [0, 0]], 'constant', constant_values=(np.nan, ))
        X[i, :, :, :] = x_pad

        indices = np.asarray([alphabet.index(a) for a in all_sequence], dtype=np.int32)
        S[i, :l] = indices

        m = np.concatenate(chain_mask_list, 0)
        m_pos = np.concatenate(fixed_position_mask_list, 0)
        chain_M[i, :] = np.pad(m, [[0, L_max - l]], 'constant', constant_values=(0.0, ))
        chain_M_pos[i, :] = np.pad(m_pos, [[0, L_max - l]], 'constant', constant_values=(0.0, ))

        chain_encoding = np.concatenate(chain_encoding_list, 0)
        chain_encoding_all[i, :] = np.pad(chain_encoding, [[0, L_max - l]], 'constant', constant_values=(0.0, ))

    # A residue is present only if all four backbone atoms have finite coordinates.
    isnan = np.isnan(X)
    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
    X[isnan] = 0.

    S = torch.from_numpy(S).to(dtype=torch.long, device=device)
    X = torch.from_numpy(X).to(dtype=torch.float32, device=device)
    mask = torch.from_numpy(mask).to(dtype=torch.float32, device=device)
    residue_idx = torch.from_numpy(residue_idx).to(dtype=torch.long, device=device)
    chain_M = torch.from_numpy(chain_M).to(dtype=torch.float32, device=device)
    chain_M_pos = torch.from_numpy(chain_M_pos).to(dtype=torch.float32, device=device)
    chain_encoding_all = torch.from_numpy(chain_encoding_all).to(dtype=torch.long, device=device)

    return X, S, mask, residue_idx, chain_M, chain_M_pos, chain_encoding_all

X is a torch.Tensor of size [B, L_max, 4, 3] where B is the batch size, L_max is the maximum length of a protein in the batch, 4 is the number of backbone atoms (N, CA, C, O), and 3 is the x, y, z coordinates of each atom. Padding positions (up to L_max for any protein shorter than L_max) and missing atoms are both filled with 0.

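S is a torch.Tensor of size [B, L_max] holding each residue's index in `alphabet`. It is padded with 0 up to L_max for any protein shorter than L_max (not explicitly, but through the `np.zeros` initialization). Because 0 is also the index of 'A', use `mask` to tell padding apart from real alanines.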

mask is a torch.Tensor of size [B, L_max] representing a residue-level mask. It is 1.0 when all four backbone atoms of a residue are present (finite coordinates), and 0.0 when any of its atoms is missing or the position is padding.

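residue_idx is a torch.Tensor of size [B, L_max] giving the residue index of each position. Indices start from 0 and run consecutively across the concatenated chains. Chain c (counting from 1) is offset by 100·(c−1), so each new chain introduces an extra jump of 100 in residue index. It is padded to L_max with -100 for all proteins shorter than L_max.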

chain_M is a torch.Tensor of size [B, L_max] representing a residue-level mask where all residues in a designable chain are 1.0 and all residues in a fixed chain (and all padding positions) are 0.0.

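chain_M_pos is a torch.Tensor of size [B, L_max] representing a residue-level mask where residues that may be designed are 1.0 and residues fixed at specific positions are 0.0. No fixed positions are applied yet, so it is currently 1.0 for every real residue and 0.0 on padding. It is multiplied into chain_M before the forward pass.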

chain_encoding_all is a torch.Tensor of size [B, L_max] with a unique integer value (starting at 1) for each different chain. It is padded with 0 for any protein shorter than L_max.

In [373]:
# Pass chain_id_dict so the designable/fixed split chosen via pdb_path_chains is
# honoured; without it get_model_inputs treats every chain as designable.
X, S, mask, residue_idx, chain_M, chain_M_pos, chain_encoding_all = get_model_inputs(batch_clones, chain_id_dict)

In [374]:
# One standard-normal draw per residue position; |randn| later sets the decoding order.
randn_1 = torch.randn(tuple(chain_M.shape), device=X.device)
randn_1.size()

torch.Size([1, 143])

#### Entering forward() of ProteinMPNN

In [365]:
# Bind the names used inside ProteinMPNN.forward(). Only chain_M (now combined with
# the fixed-position mask) and randn change; X, S, mask, residue_idx and
# chain_encoding_all keep their current values.
chain_M = chain_M * chain_M_pos
randn = randn_1

##### Entering forward() of ProteinFeatures

In [295]:
# Backbone atoms, each [B, L_max, 3].
N = X[:, :, 0, :]
Ca = X[:, :, 1, :]
C = X[:, :, 2, :]
O = X[:, :, 3, :]

b = Ca - N  # N -> CA bond vector (a vector, not a distance)
c = C - Ca  # CA -> C bond vector
a = torch.cross(b, c, dim=-1)
# Imputed (virtual) Cb positions from a fixed ideal-geometry combination of the backbone vectors.
Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca

In [296]:
# Entering _dist(X, mask, eps=1e-6)
mask_2D = torch.unsqueeze(mask, 1) * torch.unsqueeze(mask, 2)  # [B, L, L]; 1.0 iff both residues are fully present
dX = torch.unsqueeze(Ca, 1) - torch.unsqueeze(Ca, 2)  # dX[b, i, j] = Ca[b, j] - Ca[b, i]
D = mask_2D * torch.sqrt(torch.sum(dX**2, 3) + 1E-6)  # masked CA-CA distances; 0.0 if either residue is missing
D_max, _ = torch.max(D, -1, keepdim=True)  # largest distance in each row
D_adjust = D + (1. - mask_2D) * D_max  # move masked pairs from 0.0 to the row maximum so they are not nearest
# Neighbour count: reuse the value the model was built with (k_neighbors=checkpoint['num_edges'])
# instead of a second hard-coded 48; capped at the sequence length.
sampled_top_k = checkpoint['num_edges']
D_neighbors, E_idx = torch.topk(D_adjust, min(sampled_top_k, X.shape[1]), dim=-1, largest=False)

In [297]:
# Entering _rbf(D): expand each neighbour distance onto 16 Gaussian bumps between 2 and 22.
# Note: this rebinds D_max, replacing the per-row maximum from the _dist cell.
D_min, D_max, D_count = 2., 22., 16
D_mu = torch.linspace(D_min, D_max, D_count, device=device).view(1, 1, 1, -1)
D_sigma = (D_max - D_min) / D_count
D_expand = D_neighbors.unsqueeze(-1)
RBF = torch.exp(-(((D_expand - D_mu) / D_sigma) ** 2))  # [B, L, k, 16]

In [375]:
# Decoding-order masks as set up in ProteinMPNN.forward().
chain_M = chain_M * mask  # designable AND backbone fully present
# argsort is ascending. Keys are (chain_M + 0.0001) * |randn|, so residues with chain_M == 0 are scaled by 0.0001
# instead of 1.0001 and tend to be decoded first.
decoding_order = torch.argsort((chain_M+0.0001)*torch.abs(randn)) # [B, L_max] -> meaning: decoding_order[0,0]=130 means that the first decoded res is res 130, not that res 0 is decoded as the 130th res
mask_size = E_idx.shape[1]  # = L_max (E_idx is [B, L_max, k])
# permutation_matrix_reverse[b, t, r] = 1.0 iff residue r is decoded at step t.
permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float() # [B, L_max, L_max] one hot version of the decoding
# Strictly-lower-triangular step matrix mapped back to residue space:
# order_mask_backward[b, q, p] = 1.0 iff residue p is decoded strictly before residue q.
order_mask_backward = torch.einsum('ij, biq, bjp->bqp',(1-torch.triu(torch.ones(mask_size,mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
# Restrict to each residue's k neighbours: [B, L_max, k, 1], 1.0 iff that neighbour is decoded earlier.
mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
mask_bw = mask_1D * mask_attend  # neighbours already decoded (only for present residues)
mask_fw = mask_1D * (1. - mask_attend)  # neighbours not yet decoded, including the residue itself

In [405]:
# Residue index decoded at each step (row 0 = the only batch element).
decoding_order

tensor([[130,   3,   1,  90,  46,  98,  25,  22,  13, 139,   4, 134,  50, 127,
         124, 123,  54, 109, 103, 122,  35,  61,  84,  31,  82,  59,  29,  24,
          16,  74,  72,  83,  44,   6, 137,  58,  33, 117,  20,  53,   0,  62,
          64,  92, 138,   8,  65,  40, 118,  43, 115,  26, 121,  93,  49, 104,
          80,  23,  48, 101, 128, 131,  27,  34, 111,  11, 126,  85,  86,  36,
         135,  73,  69,  75,  28, 108,   9,  88,  30,  78,  47,  95, 116,  17,
          94,  19,  87, 113,  51,  37,  70,  45,  89,  99,   7,  56, 114, 102,
          67,  14, 141,  21,  60,  41, 119, 132,  63,  77, 106,  68, 136,  38,
          79,  10,   5,  39,  81, 107, 133,  52,  71, 140,  97, 112,  91, 105,
          12,   2,  76,  15,  96, 125,  57, 120,  32,  42, 129, 142,  66, 100,
          18,  55, 110]], device='cuda:0')

In [419]:
# Forward mask over residue 1's neighbours: 0.0 marks a neighbour that was decoded before residue 1.
mask_fw[0, 1, ..., 0]

tensor([1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0')

In [401]:
# Number of residues decoded before residue 2, i.e. residue 2's decoding step.
# Tensor .sum() reduces in one op instead of Python's builtin sum() iterating element by element.
order_mask_backward[0, 2, :].sum()

tensor(127., device='cuda:0')

In [367]:
mask_bw.size()  # [B, L_max, k, 1]

torch.Size([1, 143, 48, 1])

In [319]:
decoding_order[0][0]  # residue decoded at the very first step

tensor(130, device='cuda:0')

In [322]:
# Strictly lower-triangular ones (zero diagonal): entry [i, j] is 1.0 iff j < i.
low_tri = torch.tril(torch.ones(mask_size, mask_size), diagonal=-1)
low_tri

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [1., 1., 1.,  ..., 0., 0., 0.],
        [1., 1., 1.,  ..., 1., 0., 0.],
        [1., 1., 1.,  ..., 1., 1., 0.]])

In [321]:
order_mask_backward[0, 0]  # 1.0 at every residue decoded before residue 0

tensor([0., 1., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0.,
        0., 0., 1., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 0., 1., 0., 1.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 1.,
        1., 0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 1., 1., 0.,
        0., 1., 0., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0.],
       device='cuda:0')

In [368]:
# Toy 2-residue example of the decoding-order masks. It deliberately reuses (and so
# overwrites) the names computed above; re-run the earlier cells to get the real
# tensors back. The chain_M masking step is skipped in this example.
# Index tensors are built explicitly as int64: one_hot requires int64, whereas
# np.array's default int dtype is platform-dependent (int32 on Windows with NumPy < 2).
decoding_order = torch.tensor([[1, 0]], dtype=torch.long, device=device)  # residue 1 is decoded first, then residue 0
mask_size = 2
permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()  # [B, L, L]
order_mask_backward = torch.einsum('ij, biq, bjp->bqp', (1 - torch.triu(torch.ones(mask_size, mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
# Toy neighbour lists: each residue's neighbours are itself, then the other residue.
toy_E_idx = torch.tensor([[[0, 1], [1, 0]]], dtype=torch.long, device=device)
mask_attend = torch.gather(order_mask_backward, 2, toy_E_idx).unsqueeze(-1)
mask = torch.tensor([[1.0, 1.0]], device=device)  # both toy residues present (float, like the real mask)
mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
mask_bw = mask_1D * mask_attend
mask_fw = mask_1D * (1. - mask_attend)

In [371]:
# Toy: [b, q, k] = 1.0 iff neighbour k of residue q was decoded before q.
mask_bw

tensor([[[[0.],
          [1.]],

         [[0.],
          [0.]]]], device='cuda:0')

In [330]:
# Toy: row t is the one-hot of the residue decoded at step t.
permutation_matrix_reverse

tensor([[[0., 1.],
         [1., 0.]]], device='cuda:0')

In [331]:
# Toy: [q, p] = 1.0 iff residue p is decoded before residue q (here residue 1 precedes residue 0).
order_mask_backward

tensor([[[0., 1.],
         [0., 0.]]], device='cuda:0')