In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json, time, os, sys, glob
import shutil
import warnings
import numpy as np
import pandas as pd
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split, Subset
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
import os.path
import subprocess
from tqdm import tqdm
from omegaconf import OmegaConf

from protein_mpnn_utils import loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB
from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, ProteinMPNN
from kaggle_dataset import KaggleTrainDataset

In [34]:
cfg = OmegaConf.load("config.yaml")
dataset = KaggleTrainDataset(cfg, "train")
wt_feat, mut_feat, out, position = dataset[0]

NameError: name 'cfg' is not defined

In [3]:
pdb = parse_PDB("data/wildtype_structure_prediction_af2.pdb")
device='cuda:0'
# dataset = StructureDatasetPDB(pdb, max_length=500)

hidden_dim = 128
num_layers = 3 

checkpoint_path = "vanilla_model_weights/v_48_020.pt"
checkpoint = torch.load(checkpoint_path, map_location=device) 
model = ProteinMPNN(ca_only=False, num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=0.0, k_neighbors=checkpoint['num_edges'])
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# freeze these weights for transfer learning
for param in model.parameters():
    param.requires_grad = False

In [10]:
decoding_order = torch.tensor([list(range(len(pdb[0]['seq'])))])
X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize([pdb[0]], device, None, None, None, None, None, None, ca_only=False)
randn_1 = torch.randn(chain_M.shape, device=X.device)
model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1)[0]

tensor([[ 0.1011,  0.0839, -0.1516,  ..., -0.0291,  0.0226, -0.0672],
        [ 0.0814,  0.0568, -0.0762,  ..., -0.0162,  0.0897, -0.0294],
        [ 0.0838,  0.0809, -0.0517,  ..., -0.0077, -0.0020, -0.1106],
        ...,
        [-0.1671,  0.3427,  0.0820,  ...,  0.1484,  0.1602,  0.1255],
        [ 0.0211,  0.0867, -0.0278,  ...,  0.1656,  0.0881, -0.0983],
        [-0.0292, -0.2701, -0.0051,  ..., -0.0196, -0.1481, -0.2710]],
       device='cuda:0')

In [8]:
df = pd.read_csv("data/test.csv")
pdb = parse_PDB("data/wildtype_structure_prediction_af2.pdb")
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
wt = pdb[0]
wt_seq = wt['seq']
scores = []
for i, row in tqdm(df.iterrows(), total=len(df)):
    if len(row.protein_sequence) < len(wt_seq):
        scores.append(-100)
        continue # ignore deletions for now
    eq = [ c1 != c2 for c1, c2 in zip(row.protein_sequence, wt_seq)]
    if sum(eq) == 0:
        scores.append(0)
        continue # we found the wt sequence
    assert sum(eq) == 1

    idx = eq.index(True)
    other_indexes = list(range(len(wt_seq)))
    other_indexes.remove(idx)
    # random.shuffle(other_indexes)
    decoding_order = other_indexes + [ idx ]
    decoding_order = torch.tensor([decoding_order], device=device)

    aa_idx = alphabet.index(row.protein_sequence[idx])

    wt['seq'] = row.protein_sequence
    X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list, masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, bias_by_res_all, tied_beta = tied_featurize([pdb[0]], device, None, None, None, None, None, None, ca_only=False)
    with torch.no_grad():
        log_probs = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, None, True, decoding_order)
        score = float(log_probs[0][idx][aa_idx])
    scores.append(score)
    

100%|██████████| 2413/2413 [00:43<00:00, 55.06it/s]


In [38]:
f"{3333.233333333:.3f}"

'3333.233'

In [35]:
dataset = KaggleTrainDataset()
dataset.df.query("WT == 'M' and MUT == 'L'")
# loader = torch.utils.data.DataLoader(dataset, batch_size=2)
# next(iter(loader))

Unnamed: 0,PDB,WT,position,MUT,dTm,sequence,mutant_seq,CIF
211,GP02,M,128,L,-0.488626,MNQSVSSLPEKDIQYQLHPYTNARLHQELGPLIIERGEGIYVYDDQ...,MNQSVSSLPEKDIQYQLHPYTNARLHQELGPLIIERGEGIYVYDDQ...,AF-A0A2W0F5X5-F1
266,GP02,M,405,L,-0.488626,MNQSVSSLPEKDIQYQLHPYTNARLHQELGPLIIERGEGIYVYDDQ...,MNQSVSSLPEKDIQYQLHPYTNARLHQELGPLIIERGEGIYVYDDQ...,AF-A0A2W0F5X5-F1
272,GP02,M,419,L,1.511374,MNQSVSSLPEKDIQYQLHPYTNARLHQELGPLIIERGEGIYVYDDQ...,MNQSVSSLPEKDIQYQLHPYTNARLHQELGPLIIERGEGIYVYDDQ...,AF-A0A2W0F5X5-F1
686,GP10,M,210,L,-1.416667,MKFLQIIPVLLSLTSTTLAQSFCSSASHSGQSVKETGNKVGTIGGV...,MKFLQIIPVLLSLTSTTLAQSFCSSASHSGQSVKETGNKVGTIGGV...,AF-Q9UV68-F1
807,GP12,M,32,L,10.270513,MQFKVYTYKRESRYRLFVDVQSDIIDTPGRRMVIPLASARLLSDKV...,MQFKVYTYKRESRYRLFVDVQSDIIDTPGRRLVIPLASARLLSDKV...,AF-A0A142CN06-F1


In [13]:
hidden_dim = 128
num_layers = 3 

checkpoint_path = "vanilla_model_weights/v_48_020.pt"
checkpoint = torch.load(checkpoint_path, map_location=device) 
model = ProteinMPNN(ca_only=False, num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=0.0, k_neighbors=checkpoint['num_edges'])
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# freeze these weights for transfer learning
for param in model.parameters():
    param.requires_grad = False

NameError: name 'device' is not defined

In [19]:
pd.read_csv("data/all_train_data_v17.csv")#.query("source == 'kaggle.csv'").reset_index(drop=True).ddG[0]

Unnamed: 0,PDB,wildtype,position,mutation,ddG,sequence,mutant_seq,source,dTm,CIF
0,1A5E,L,121,R,0.66,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,jin_train.csv,,
1,1A5E,L,37,S,0.71,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGASPNAPNSYGR...,jin_train.csv,,
2,1A5E,W,15,D,0.17,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,MEPAAGSSMEPSADDLATAAARGRVEEVRALLEAGALPNAPNSYGR...,jin_train.csv,,
3,1A5E,D,74,N,-2.00,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,jin_train.csv,,
4,1A5E,P,81,L,0.00,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,MEPAAGSSMEPSADWLATAAARGRVEEVRALLEAGALPNAPNSYGR...,jin_train.csv,,
...,...,...,...,...,...,...,...,...,...,...
6804,GP77,R,32,A,,MALEKSLVRLLLLVLILLVLGWVQPSLGKESRAKKFQRQHMDSDSS...,MALEKSLVRLLLLVLILLVLGWVQPSLGKESAAKKFQRQHMDSDSS...,kaggle.csv,1.52,AF-P07998-F1
6805,GP77,K,34,A,,MALEKSLVRLLLLVLILLVLGWVQPSLGKESRAKKFQRQHMDSDSS...,MALEKSLVRLLLLVLILLVLGWVQPSLGKESRAAKFQRQHMDSDSS...,kaggle.csv,1.02,AF-P07998-F1
6806,GP77,Q,37,E,,MALEKSLVRLLLLVLILLVLGWVQPSLGKESRAKKFQRQHMDSDSS...,MALEKSLVRLLLLVLILLVLGWVQPSLGKESRAKKFERQHMDSDSS...,kaggle.csv,-0.38,AF-P07998-F1
6807,GP77,S,45,N,,MALEKSLVRLLLLVLILLVLGWVQPSLGKESRAKKFQRQHMDSDSS...,MALEKSLVRLLLLVLILLVLGWVQPSLGKESRAKKFQRQHMDSDNS...,kaggle.csv,-1.68,AF-P07998-F1
