In [1]:
import json
import numpy as np

## 1. Load model

In [2]:
import argparse


def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` calls ``bool()`` on the raw string, so every non-empty
    value (including ``'False'``) becomes True. This converter accepts the
    usual spellings instead and rejects anything else.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


parser = argparse.ArgumentParser()
# Set-up parameters
parser.add_argument('--device', default='cuda', type=str, help='Name of device to use for tensor computations (cuda/cpu)')
parser.add_argument('--display_step', default=10, type=int, help='Interval in batches between display of training metrics')
parser.add_argument('--res_dir', default='./results', type=str)
parser.add_argument('--ex_name', default='debug', type=str)
parser.add_argument('--use_gpu', default=True, type=_str2bool)
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--seed', default=111, type=int)

# CATH
# dataset parameters
parser.add_argument('--data_name', default='CATH', choices=['CATH', 'TS50'])
parser.add_argument('--data_root', default='./data/')
parser.add_argument('--batch_size', default=8, type=int)
parser.add_argument('--num_workers', default=8, type=int)

# method parameters
parser.add_argument('--method', default='ProDesign', choices=['ProDesign'])
parser.add_argument('--config_file', '-c', default=None, type=str)
parser.add_argument('--hidden_dim',  default=128, type=int)
parser.add_argument('--node_features',  default=128, type=int)
parser.add_argument('--edge_features',  default=128, type=int)
parser.add_argument('--k_neighbors',  default=30, type=int)
# The default is fractional, so the option must be parsed as float
# (type=int rejected any value such as "0.2" given on the command line).
parser.add_argument('--dropout',  default=0.1, type=float)
parser.add_argument('--num_encoder_layers', default=10, type=int)

# Training parameters
parser.add_argument('--epoch', default=100, type=int, help='end epoch')
parser.add_argument('--log_step', default=1, type=int)
parser.add_argument('--lr', default=0.001, type=float, help='Learning rate')
parser.add_argument('--patience', default=100, type=int)

# ProDesign parameters
parser.add_argument('--updating_edges', default=4, type=int)
parser.add_argument('--node_dist', default=1, type=int)
parser.add_argument('--node_angle', default=1, type=int)
parser.add_argument('--node_direct', default=1, type=int)
parser.add_argument('--edge_dist', default=1, type=int)
parser.add_argument('--edge_angle', default=1, type=int)
parser.add_argument('--edge_direct', default=1, type=int)
parser.add_argument('--virtual_num', default=3, type=int)
# An empty argv is passed so that the notebook always runs with the defaults
# above instead of picking up the kernel's own command-line arguments.
args = parser.parse_args([])

import torch
from main import Exp
from parser import create_parser

# Build the experiment object from the default arguments parsed above.
exp = Exp(args)
svpath = './results/ProDesign/'
# map_location remaps the stored tensors onto the experiment's device, so a
# checkpoint saved on a GPU can still be loaded when running on another device.
# load_state_dict stays the last expression so its result is shown as the cell output.
exp.method.model.load_state_dict(torch.load(svpath+'checkpoint.pth', map_location=exp.device))

Use GPU: cuda:0

device: 	cuda	
display_step: 	10	
res_dir: 	./results	
ex_name: 	debug	
use_gpu: 	True	
gpu: 	0	
seed: 	111	
data_name: 	CATH	
data_root: 	./data/	
batch_size: 	8	
num_workers: 	8	
method: 	ProDesign	
config_file: 	None	
hidden_dim: 	128	
node_features: 	128	
edge_features: 	128	
k_neighbors: 	30	
dropout: 	0.1	
num_encoder_layers: 	10	
epoch: 	100	
log_step: 	1	
lr: 	0.001	
patience: 	100	
updating_edges: 	4	
node_dist: 	1	
node_angle: 	1	
node_direct: 	1	
edge_dist: 	1	
edge_angle: 	1	
edge_direct: 	1	
virtual_num: 	3	


<All keys matched successfully>

## 2. Results on CATH4.2

In [None]:
# Run the evaluation, then report the recovery statistics stored on exp.method.
exp.test()
print(
    "median: {:.4f}\t mean: {:.4f}\t std: {:.4f}\t min: {:.4f}\t max: {:.4f}".format(
        *(getattr(exp.method, stat + '_recovery') for stat in ('median', 'mean', 'std', 'min', 'max'))
    )
)

## 3. Results on TS50

In [15]:
from API.dataloader import make_cath_loader
from API.cath_dataset import CATH

# Read the TS50 entries from disk.
with open('./data/ts/ts50.json','r') as f:
    ts50 = json.load(f)

# Convert each entry into the per-atom dict handed to CATH(data=...).
# Axis 1 of an entry's 'coords' array is split by index: 0 -> N, 1 -> CA, 2 -> C, 3 -> O.
ts50_list = []
for entry in ts50:
    coords = np.array(entry['coords'])
    record = {'title': entry['name'], 'seq': entry['seq']}
    for atom, column in (('CA', 1), ('C', 2), ('O', 3), ('N', 0)):
        record[atom] = coords[:, column, :]
    ts50_list.append(record)

# Replace the experiment's test loader with one built from TS50
# (presumably what the next exp.test() call evaluates on).
exp.test_loader = make_cath_loader(CATH(data=ts50_list), 'SimDesign', 8)

In [16]:
# Run the evaluation, then report the recovery statistics stored on exp.method.
exp.test()
print(
    "median: {:.4f}\t mean: {:.4f}\t std: {:.4f}\t min: {:.4f}\t max: {:.4f}".format(
        *(getattr(exp.method, stat + '_recovery') for stat in ('median', 'mean', 'std', 'min', 'max'))
    )
)

test loss: 1.2509: 100%|██████████| 7/7 [00:01<00:00,  5.54it/s]
100%|██████████| 50/50 [00:01<00:00, 29.85it/s]


Test Perp: 3.8553, Test Rec: 0.5872

Category Unknown Rec: 0.5872

median: 0.5872	 mean: 0.5654	 std: 0.0887	 min: 0.3724	 max: 0.7530


## 4. Results on TS500

In [17]:
# Read the TS500 entries from disk.
with open('./data/ts/ts500.json','r') as f:
    ts500 = json.load(f)

# Convert each entry into the per-atom dict handed to CATH(data=...).
# Axis 1 of an entry's 'coords' array is split by index: 0 -> N, 1 -> CA, 2 -> C, 3 -> O.
ts500_list = []
for entry in ts500:
    coords = np.array(entry['coords'])
    record = {'title': entry['name'], 'seq': entry['seq']}
    for atom, column in (('CA', 1), ('C', 2), ('O', 3), ('N', 0)):
        record[atom] = coords[:, column, :]
    ts500_list.append(record)

# Replace the experiment's test loader with one built from TS500
# (presumably what the next exp.test() call evaluates on).
exp.test_loader = make_cath_loader(CATH(data=ts500_list), 'SimDesign', 8)

In [18]:
# Run the evaluation, then report the recovery statistics stored on exp.method.
exp.test()
print(
    "median: {:.4f}\t mean: {:.4f}\t std: {:.4f}\t min: {:.4f}\t max: {:.4f}".format(
        *(getattr(exp.method, stat + '_recovery') for stat in ('median', 'mean', 'std', 'min', 'max'))
    )
)

test loss: 1.6132: 100%|██████████| 63/63 [00:06<00:00,  9.68it/s]
100%|██████████| 500/500 [00:17<00:00, 28.06it/s]

Test Perp: 3.4403, Test Rec: 0.6046

Category Unknown Rec: 0.6046

median: 0.6046	 mean: 0.5932	 std: 0.1077	 min: 0.0296	 max: 0.8680





## 5. Sequence design for a single PDB structure

In [3]:
import os
import gzip
import numpy as np
from collections import defaultdict

# Lookup table from three-letter residue names to one-letter codes,
# used by getSequence.
AAMAP = {
    # The first twenty entries.
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D',
    'CYS': 'C', 'GLN': 'Q', 'GLU': 'E', 'GLY': 'G',
    'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K',
    'MET': 'M', 'PHE': 'F', 'PRO': 'P', 'SER': 'S',
    'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V',
    # Five additional codes.
    'ASX': 'B', 'GLX': 'Z', 'SEC': 'U', 'PYL': 'O', 'XLE': 'J',
    # An empty residue name maps to the character '-'.
    '': '-',
}

# def get_pdb(pdb_code=""):
#   if pdb_code is None or pdb_code == "":
#     upload_dict = files.upload()
#     pdb_string = upload_dict[list(upload_dict.keys())[0]]
#     with open("tmp.pdb","wb") as out: out.write(pdb_string)
#     return "tmp.pdb"
#   else:
#     os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
#     return f"{pdb_code}.pdb"

def getSequence(resnames):
    """Translate an iterable of *resnames* (residue name abbreviations)
    into a one-letter sequence string.

    Every name is looked up in ``AAMAP``; a name that is not a key of that
    table contributes ``'X'``."""

    letters = []
    for name in resnames:
        letters.append(AAMAP.get(name, 'X'))
    return ''.join(letters)

def gzip_open(filename, *args, **kwargs):
    """Open a gzip-compressed source in binary mode.

    *filename* may be a path (``str`` or another path-like object) or an
    already open file object. Any ``"t"`` in a positionally passed mode is
    removed, so the returned object yields ``bytes`` even when text mode was
    requested; callers decode the lines themselves. A mode passed as the
    ``mode=`` keyword is forwarded unchanged.

    Remaining positional and keyword arguments are forwarded to
    ``gzip.open`` (for ``str`` paths) or ``gzip.GzipFile`` (otherwise).
    """
    if args and "t" in args[0]:
        args = (args[0].replace("t", ""), ) + args[1:]
    if isinstance(filename, str):
        return gzip.open(filename, *args, **kwargs)
    if hasattr(filename, "read") or hasattr(filename, "write"):
        # A file object must be handed over as ``fileobj``; passing it in
        # the ``filename`` slot makes GzipFile try to open() it and fail.
        return gzip.GzipFile(None, *args, fileobj=filename, **kwargs)
    # Non-str paths (bytes, os.PathLike) keep the original GzipFile route.
    return gzip.GzipFile(filename, *args, **kwargs)

def parsePDB(pdb, chain=('A',)):
    """Extract backbone coordinates and sequence from a gzip-compressed PDB file.

    Parameters
    ----------
    pdb : str
        Path of the gzip-compressed file. Up to two trailing extensions are
        stripped from the file name to form the title
        (``'1o91.pdb1.gz'`` -> ``'1o91'``).
    chain : container of str
        Chain identifiers to keep; an ``ATOM`` record is used only when the
        stripped character at ``line[21:22]`` is in this container.

    Returns
    -------
    dict
        ``'title'``; ``'seq'`` (one letter per CA atom, translated with
        ``getSequence``); ``'CA'``, ``'C'``, ``'O'`` and ``'N'`` as numpy
        arrays of the collected ``[x, y, z]`` rows in file order; and
        ``'score'``, fixed at ``100.0``.

    Raises
    ------
    UnicodeDecodeError
        If a line cannot be decoded with the ASCII codec.
    ValueError
        If a coordinate field of a selected atom is not a valid float.
    """
    title, _ = os.path.splitext(os.path.split(pdb)[1])
    title, _ = os.path.splitext(title)

    backbone = {'CA': [], 'C': [], 'O': [], 'N': []}
    resnames = []

    handle = gzip_open(pdb, 'rt')
    try:
        for raw in handle:
            # gzip_open always yields bytes, so every line is decoded here.
            line = raw.decode('ANSI_X3.4-1968')
            if line[0:6] != 'ATOM  ':
                continue
            if line[21:22].strip() not in chain:
                continue
            atom = line[13:16].strip()
            if atom not in backbone:
                continue
            # NOTE(review): every matching record is kept; nothing here
            # filters on other columns (e.g. alternate locations), so the
            # four arrays can end up with different lengths — confirm the
            # inputs are clean.
            backbone[atom].append(
                [float(line[30:38]), float(line[38:46]), float(line[46:54])]
            )
            if atom == 'CA':
                # One sequence letter per CA atom.
                resnames.append(line[17:20])
    finally:
        # Close the file even when decoding or float parsing fails.
        handle.close()

    return {'title': title,
            'seq': getSequence(resnames),
            'CA': np.array(backbone['CA']),
            'C': np.array(backbone['C']),
            'O': np.array(backbone['O']),
            'N': np.array(backbone['N']),
            'score' : 100.0}

In [4]:
# Parse chains A, B and C of the example structure into the dict returned by parsePDB.
# NOTE(review): this is an absolute, machine-specific path — point it at a local
# copy of the file before running.
data = parsePDB('/gaozhangyang/experiments/ProDesign/example/1o91.pdb1.gz', ['A', 'B', 'C'])

In [13]:
import torch.nn.functional as F
from API.dataloader_gtrans import featurize_GTrans
from methods.utils import cuda

# Letter emitted for each predicted class index.
alphabet='ACDEFGHIKLMNPQRSTVWY'
# Logits are divided by this value before the softmax.
temperature = 0.1

# Featurize the single parsed structure and move the batch to the experiment device.
batch = featurize_GTrans([data])
X, S, score, mask, lengths = cuda(batch, device = exp.device)

# Inference only: switch the model to eval mode (so layers such as dropout are
# inactive even if no earlier cell did this) and skip building an autograd graph.
exp.method.model.eval()
with torch.no_grad():
    X, S, score, h_V, h_E, E_idx, batch_id, mask_bw, mask_fw, decoding_order = exp.method.model._get_features(S, score, X=X, mask=mask)
    log_probs, logits = exp.method.model(h_V, h_E, E_idx, batch_id, return_logit = True)

    # Sample one class per position from the temperature-scaled distribution.
    # NOTE(review): sampling is stochastic and no seed is set in this cell,
    # so the result can differ between runs.
    probs = F.softmax(logits/temperature, dim=-1)
    S_pred = torch.multinomial(probs, 1).view(-1)

    # Fraction of positions where the sampled class equals the native one.
    recovery = torch.mean((S==S_pred).float())

S_design = ''.join([alphabet[i] for i in S_pred.tolist()])

print(recovery)
print(S_design)

tensor(0.6361, device='cuda:0')
MMPAFTACLTTGYPPVGEPVKFDKILYNGRATYDPETGIWTCVVPGTYYFAWVVHCYGGDVLIALYKNDTPMMWVYLEYVDGKLDQASGSAVLRLEPGDEVYLEIPGESADGLYAGEFVHSLFSGFLLHPTEAPAFTALLTTPYPPVGEPIKFDKLLYNGLNVYDPETGIYTCQVPGIYYFAWTVHCLGGDVLVSLYKNDEPMMWTYMEHVEGRLSQASGDAVLELKPGDKVYLEQPTKLANGLAAGDDDHSYFSGFLLHPTEEPAFTALLTVGYPPVGEPIKFDKLLYNGRDVYDPETGIWTCKVPGIYYFAFVVHTKGNDVLVQLYKNDTPMMRVYLEHIDGKLSQASGSGVLRLEKGDKVYIQQPYESANGLAAGARIHSWLSGFLLHPL


In [19]:
# Sequence that parsePDB extracted for the selected chains.
S_true = data['seq']
# Hard-coded sequence stored under the name ProteinMPNN — presumably a ProteinMPNN
# design for the same structure, pasted here for comparison (TODO confirm provenance).
ProteinMPNN = "EVEAFTALLTTPHPPVGEPIKFNKLVYNGRNVYDPATGIFTVKTPGVYFFTFVLYVYGADLHAELMKNDTPVIKVYLQTVNGKINQVSGAAVLELEEGDKVYVKIPSASANGLWASADAHSYFSGYLLTEKEVEAFTALLTTPHPPVGEPIKFNKLVYNGRNVYDPATGIFTVKTPGVYFFTFVLYVYGADLHAELMKNDTPVIKVYLQTVNGKINQVSGAAVLELEEGDKVYVKIPSASANGLWASADAHSYFSGYLLTEKEVEAFTALLTTPHPPVGEPIKFNKLVYNGRNVYDPATGIFTVKTPGVYFFTFVLYVYGADLHAELMKNDTPVIKVYLQTVNGKINQVSGAAVLELEEGDKVYVKIPSASANGLWASADAHSYFSGYLLTEK"

In [20]:
# Display the stored sequence as the cell output.
# NOTE(review): the saved output of this cell contains '/' separators that the
# assignment in the previous cell does not, so the output looks stale — re-run to confirm.
ProteinMPNN

'EVEAFTALLTTPHPPVGEPIKFNKLVYNGRNVYDPATGIFTVKTPGVYFFTFVLYVYGADLHAELMKNDTPVIKVYLQTVNGKINQVSGAAVLELEEGDKVYVKIPSASANGLWASADAHSYFSGYLLTEK/EVEAFTALLTTPHPPVGEPIKFNKLVYNGRNVYDPATGIFTVKTPGVYFFTFVLYVYGADLHAELMKNDTPVIKVYLQTVNGKINQVSGAAVLELEEGDKVYVKIPSASANGLWASADAHSYFSGYLLTEK/EVEAFTALLTTPHPPVGEPIKFNKLVYNGRNVYDPATGIFTVKTPGVYFFTFVLYVYGADLHAELMKNDTPVIKVYLQTVNGKINQVSGAAVLELEEGDKVYVKIPSASANGLWASADAHSYFSGYLLTEK'