In [106]:
import torch_geometric
from SLAM_combine import *

In [185]:
def parse_pdb_chain(pdb_file, chain='A',pos=None, atom_type='CA', nneighbor=32, cal_cb=True):
    """
    ########## Process PDB file ##########
    """
    current_pos = -1000
    X = []
    current_aa = {} # N, CA, C, O, R
    with open(pdb_file, 'r') as pdb_f:
        for line in pdb_f:
            if line[21] == chain:
                if (line[0:4].strip() == "ATOM" and int(line[22:26].strip()) != current_pos) or line[0:4].strip() == "TER":
                    if current_aa != {}:
                        R_group = []
                        for atom in current_aa:
                            if atom not in ["N", "CA", "C", "O"]:
                                R_group.append(current_aa[atom])
                        if R_group == []:
                            R_group = [current_aa["CA"]]
                        R_group = np.array(R_group).mean(0)
                        X.append([current_aa["N"], current_aa["CA"], current_aa["C"], current_aa["O"], R_group])
                        current_aa = {}
                    if line[0:4].strip() != "TER":
                        current_pos = int(line[22:26].strip())

                if line[0:4].strip() == "ATOM":
                    atom = line[13:16].strip()
                    if atom != "H":
                        xyz = np.array([line[30:38].strip(), line[38:46].strip(), line[46:54].strip()]).astype(np.float32)
                        current_aa[atom] = xyz
    X = np.array(X)
    if cal_cb:
        X = np.concatenate([X, get_cb(X[:,0], X[:,1], X[:,2])[:, None]], 1)
    if pos is not None:
        atom_ind = atom_idx[atom_type] # CA atom
        if pos >= X.shape[0]:
            pos = X.shape[0] - 1
        query_coord = X[pos,atom_ind]
        distances = calculate_distances(X[:,atom_ind,:], query_coord)
        closest_indices = sorted(np.argsort(distances)[:nneighbor])
        X = X[(closest_indices)]
        # print(closest_indices)
    return X, closest_indices # array shape: [Length, 6, 3] N, CA, C, O, R, CB

def get_graph_fea_chain(pdb_path, pos, chain='A', nneighbor=32, radius=10, atom_type='CA', cal_cb=True):
    X, closest_indices = parse_pdb_chain(pdb_path, chain=chain,pos=pos, atom_type=atom_type, nneighbor=nneighbor, cal_cb=cal_cb)
    X = torch.tensor(X).float()
    query_atom = X[:, atom_idx[atom_type]]
    edge_index = radius_graph(query_atom, r=radius, loop=False, max_num_neighbors=nneighbor, num_workers = 4)
    node, edge = get_geo_feat(X, edge_index)
    return Data(x=node, edge_index=edge_index, edge_attr=edge, name=os.path.basename(pdb_path).split('.')[0], near=closest_indices)

def _get_encoding(seq, feature=[BLOSUM62, BINA]):
    alphabet = 'ARNDCQEGHILKMFPSTWYVX'
    char_to_int = dict((c, i) for i, c in enumerate(alphabet))
    int_to_char = dict((i, c) for i, c in enumerate(alphabet))
    sample = ''.join([re.sub(r"[UZOB*]", "X", token) for token in seq])
    # seq = [char_to_int[char] for char in sample]
    max_len = len(sample)
    all_fea = []
    for encoder in feature:
        fea = encoder([sample])
        assert fea.shape[0] == max_len
        all_fea.append(fea)
    return np.hstack(all_fea)
    
def get_all_inputs(seq, pos, tokenizer, pdb_path):
    data = get_graph_fea_chain(pdb_path, pos, nneighbor=32, atom_type='CA', cal_cb=True)
    fea = _get_encoding(seq)
    s = ''.join([token for token in re.sub(r"[UZOB*]", "X", seq.rstrip('*'))])
    max_len = len(s)
    encoded = tokenizer.encode_plus(seq, add_special_tokens=True, padding='max_length', return_token_type_ids=False, pad_to_max_length=True,truncation=True, max_length=max_len, return_tensors='pt')
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']
    return data, input_ids, attention_mask, torch.tensor(fea, dtype=torch.float)
    

In [186]:
def _get_encoding(seq, feature=[BLOSUM62, BINA]):
    alphabet = 'ARNDCQEGHILKMFPSTWYVX'
    char_to_int = dict((c, i) for i, c in enumerate(alphabet))
    int_to_char = dict((i, c) for i, c in enumerate(alphabet))
    sample = ''.join([re.sub(r"[UZOB*]", "X", token) for token in seq])
    # seq = [char_to_int[char] for char in sample]
    max_len = len(sample)
    all_fea = []
    for encoder in feature:
        fea = encoder([sample])
        assert fea.shape[0] == max_len
        all_fea.append(fea)
    return np.hstack(all_fea)

class SLAMPredDataset(object):
    def __init__(self, peplist, tokenizer, pdb_path):
        self.tokenizer = tokenizer
        self.feature_list = []
        self.graphlist = []
        self.seqlist = []
        self.label_list = []
        for record in tqdm(peplist):
            seq = record[-1]
            desc = record[0].split('|')
            name, pos, length = desc[0], int(desc[2]), int(desc[3])
            data = get_graph_fea(pdb_path, pos, nneighbor=32, atom_type='CA', cal_cb=True)
            self.graphlist.append(data)
            fea = _get_encoding(seq)
            self.feature_list.append(fea)
            self.seqlist.append(seq)
            self.label_list.append(desc)
    
    def __len__(self):
        return len(self.graphlist)
    
    def __getitem__(self, index):
        seq = self.seqlist[index]
        seq = [token for token in re.sub(r"[UZOB*]", "X", seq.rstrip('*'))]
        max_len = len(seq)
        encoded = self.tokenizer.encode_plus(' '.join(seq), add_special_tokens=True, padding='max_length', return_token_type_ids=False, pad_to_max_length=True,truncation=True, max_length=max_len, return_tensors='pt')
        input_ids = encoded['input_ids'].flatten()
        attention_mask = encoded['attention_mask'].flatten()
        self.graphlist[index], input_ids, attention_mask, torch.tensor(self.feature_list[index], dtype=torch.float), self.label_list[index]

In [187]:
def get_peptide(pos, window_size, seq, mirror=True):
    """Return peptide based on window_size. Missing residues are padded with X symbol (if mirror == False) or mirroring residues from the other side (if mirror == True)."""
    pos = pos-1
    half_window = int(window_size/2)
    start = pos - half_window
    left_padding = '' if start >= 0 else 'X' * abs(start)
    start = 0 if start < 0 else start
    end = pos + half_window + 1
    right_padding = 'X' * half_window
    end = len(seq) if end + 1 > len(seq) else end
    peptide_ = seq[start:end]
    if mirror:
        if left_padding == '' and right_padding == '':
            peptide = left_padding + peptide_ + right_padding
        elif left_padding == '' and right_padding != '': # mirror left
            peptide = left_padding + peptide_ + peptide_[:len(right_padding)][::-1]
        elif left_padding != '' and right_padding == '': # mirror right
            peptide = peptide_[::-1][:len(left_padding)] + peptide_ + right_padding
        else:
            peptide = None
    else:
        peptide = left_padding + peptide_ + right_padding
    if peptide is not None:
        peptide = peptide[:window_size]
        assert peptide[half_window] == 'K' and len(peptide) == window_size
        return peptide
    else:
        return None

In [188]:
def get_all_k(seqlist, window_size=51):
    peplist = []
    window_size = window_size
    half_window = window_size // 2
    for record in seqlist:
        seq = str(record.seq)
        for m in re.finditer('K', seq):
            pos = m.start() + 1
            pep = get_peptide(pos, window_size, seq, mirror=False)
            if pep is not None:
                peplist.append([f'{pdb_id}|Pred|{pos}|{len(seq)}', pep])
    return peplist

In [189]:
pdb_id = '5w49'
pdb_path = f'proteomes/{pdb_id}.pdb'
seq_path = f'proteomes/{pdb_id}.fa'

pretrained_model = '/mnt/data/zhqin/pretrain_LM/prot_bert/'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model, do_lower_case=False, use_fast=False)
window_size = 51
seqlist = [record for record in SeqIO.parse(seq_path, "fasta")]
peplist = get_all_k(seqlist, window_size=window_size)
# pep_ds = SLAMPredDataset(peplist, tokenizer, pdb_path)



In [190]:
str(seqlist[0].seq)

'KLPYKVADIGLAAWGRKALDIAENEMPGLMRMRERYSASKPLKGARIAGCLHMTVETAVLIETLVTLGAEVQWSSCNIFSTQDHAAAAIAKAGIPVYAWKGETDEEYLWCIEQTLYFKDGPLNMILDDGGDLTNLIHTKYPQLLPGIRGISEETTTGVHNLYKMMANGILKVPAINVNDSVTKSKFDNLYGCRESLIDGIKRATDVMIAGKVAVVAGYGDVGKGCAQALRGFGARVIITEIDPINALQAAMEGYEVTTMDEACQEGNIFVTTTGCIDIILGRHFEQMKDDAIVCNIGHFDVEIDVKWLNENAVEKVNIKPQVDRYRLKNGRRIILLAEGRLVNLGCAMGHPSFVMSNSFTNQVMAQIELWTHPDKYPVGVHFLPKKLDEAVAEAHLGKLNVKLTKLTEKQAQYLGMSCDGPFKPDHYRY'

In [228]:
predictions = []
model.eval()
for desc, seq in peplist:
    seq = str(seq)
    tmp = desc.split('|')
    pos = int(tmp[2])
    print(pos,seq)
    g, input_ids, attention_mask, feature = get_all_inputs(seq,pos,tokenizer,pdb_path)
    feature = feature.unsqueeze(0).to(device)
    g = g.to(device)
    g.batch = torch.zeros(g.x.shape[0],dtype=torch.int64).to(device)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    pred = model(input_ids=input_ids, attention_mask=attention_mask, feature=feature, g_data=g)
    result = [seq,]
    break

1 XXXXXXXXXXXXXXXXXXXXXXXXXKLPYKVADIGLAAWGRKALDIAENEM


In [229]:
pred, seq

(tensor([[0.0022]], device='cuda:0', grad_fn=<SigmoidBackward0>),
 'XXXXXXXXXXXXXXXXXXXXXXXXXKLPYKVADIGLAAWGRKALDIAENEM')

In [225]:
''.join(reversed('LPYKVADIGLAAWGRKALDIAENEM'))

'MENEAIDLAKRGWAALGIDAVKYPL'

In [7]:
loader = DataLoader(pep_ds,batch_size=len(pep_ds),shuffle=False,num_workers=8, collate_fn=graph_collate_fn, prefetch_factor=2)

In [194]:
gpu = 0
device = torch.device(f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu')
n_layers = 1
dropout = 0.5
embedding_dim = 32
hidden_dim = 64
out_dim = 32

node_dim = 267
edge_dim = 632
nneighbor = 32
atom_type = 'CA' # CB, R, C, N, O
gnn_layers = 5
encoder_list = ['cnn','lstm','fea', 'gnn', 'plm']
fea_dim = 41
PLM_dim = 1024
BERT_encoder = AutoModel.from_pretrained(pretrained_model, local_files_only=True, output_attentions=False).to(device)

Some weights of the model checkpoint at /mnt/data/zhqin/pretrain_LM/prot_bert/ were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [195]:
model_file = f'Models/SLAM_combine/general_struct_plm/best_general_struct_plm_model_epoch.pt'  
model = SLAMNet(BERT_encoder=BERT_encoder, vocab_size=tokenizer.vocab_size, encoder_list=encoder_list,PLM_dim=PLM_dim,win_size=window_size,embedding_dim=embedding_dim, fea_dim=fea_dim, hidden_dim=hidden_dim, out_dim=out_dim,node_dim=node_dim, edge_dim=edge_dim, gnn_layers=gnn_layers,n_layers=n_layers,dropout=dropout).to(device)
model.load_state_dict(torch.load(model_file))

<All keys matched successfully>