In [1]:
import torch
import esm
from esm import pretrained
import math
import numpy as np
import matplotlib.pyplot as plt
import json, pickle
from collections import OrderedDict
import os
from tqdm import tqdm

import dgl
import torch

In [2]:
# data prepare
def _predict_contact_map(model, batch_converter, device, pro_id, seq):
    """Run the model on a single sequence and return its contact map as a CPU NumPy array."""
    _, _, batch_tokens = batch_converter([(pro_id, seq)])
    with torch.no_grad():
        # Tokens must live on the same device as the model.
        results = model(batch_tokens.to(device), repr_layers=[33], return_contacts=True)
    # Move to host memory first: .numpy() is not defined for CUDA tensors.
    return results["contacts"][0].cpu().numpy()


def protein_graph_construct(proteins, save_dir):
    """Predict a residue-residue contact map for every protein with ESM-1b.

    Args:
        proteins: mapping of protein id -> amino-acid sequence string.
        save_dir: path prefix; a protein is skipped when the file
            ``save_dir + pro_id + '.npy'`` already exists.

    Returns:
        dict mapping protein id -> 2-D NumPy array of contact values of shape
        ``(len(seq), len(seq))`` for sequences longer than 1000 residues, or
        whatever the model returns under ``results["contacts"][0]`` otherwise.

    Sequences of at most 1000 residues are predicted in one pass. Longer
    sequences are predicted in windows of up to 1000 residues that advance by
    500; cells already filled by the previous window are averaged with the new
    prediction.
    """
    # Load ESM-1b model
    model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
    batch_converter = alphabet.get_batch_converter()

    # Fall back to the CPU when no GPU is present instead of failing outright.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # disable dropout so predictions are deterministic

    target_graph = {}
    for pro_id in tqdm(list(proteins)):
        # NOTE(review): proteins with an existing .npy file are skipped and are
        # therefore absent from the returned dict; nothing in this function
        # writes or reads those files (the np.save call was disabled) - confirm
        # whether cached maps should be loaded here instead.
        if os.path.exists(save_dir + pro_id + '.npy'):
            continue

        seq = proteins[pro_id]
        if len(seq) <= 1000:
            target_graph[pro_id] = _predict_contact_map(model, batch_converter, device, pro_id, seq)
            continue

        contact_prob_map = np.zeros((len(seq), len(seq)))  # global contact map prediction
        interval = 500
        # ======================
        # =                    =
        # =                    =
        # =          ======================
        # =          =*********=          =
        # =          =*********=          =
        # ======================          =
        #            =                    =
        #            =                    =
        #            ======================
        # where * is the overlapping area
        # subsection seq contact map prediction
        for s in range(math.ceil(len(seq) / interval)):
            start = s * interval  # sub seq predict start
            end = min((s + 2) * interval, len(seq))  # sub seq predict end

            sub_map = _predict_contact_map(model, batch_converter, device, pro_id, seq[start:end])

            # Cells that already hold a value belong to the overlap with the
            # previous window; remember them so they can be averaged below.
            row, col = np.where(contact_prob_map[start:end, start:end] != 0)
            row = row + start
            col = col + start

            # insert into the global contact map
            contact_prob_map[start:end, start:end] += sub_map
            contact_prob_map[row, col] = contact_prob_map[row, col] / 2.0
            if end == len(seq):
                break
        target_graph[pro_id] = contact_prob_map

    return target_graph

In [25]:
def get_uv(adj):
    """Turn a 2-D contact map into an edge list.

    An edge (i, j) is emitted when j is a direct sequence neighbour of i
    (``j == i - 1`` or ``j == i + 1``), with weight 1.0, or otherwise when
    ``adj[i][j] > 0.5``, with weight ``adj[i][j]``.

    Args:
        adj: 2-D array-like (NumPy array or nested lists) of contact values.

    Returns:
        ``(u, v, weight)``: three equally long lists holding source indices,
        destination indices and edge weights, ordered row by row. All three
        are empty when ``adj`` has no elements.
    """
    adj = np.asarray(adj)
    if adj.size == 0:
        return [], [], []

    m, n = adj.shape
    # |i - j| == 1 marks residues that are adjacent in the sequence.
    sequential = np.abs(np.arange(m)[:, None] - np.arange(n)[None, :]) == 1
    keep = sequential | (adj > 0.5)

    # np.nonzero walks the mask in row-major order, i.e. the same order a
    # nested ``for i ... for j ...`` loop would produce.
    u, v = np.nonzero(keep)
    # Sequence neighbours always get weight 1.0, regardless of the map value.
    weight = np.where(sequential, 1.0, adj)[u, v]

    return u.tolist(), v.tolist(), weight.tolist()

def sequence_to_graph(data, distance):  #{id:seq}; {id:map}
    """Build one DGL graph per protein from its contact map.

    Args:
        data: mapping of protein id -> value stored as the node feature
            ``"feat"`` of that protein's graph.
            NOTE(review): presumably a per-residue feature tensor with one row
            per residue - confirm against the file this is loaded from.
        distance: mapping of protein id -> 2-D contact map; must contain every
            id present in ``data`` (a missing id raises ``KeyError``).

    Returns:
        dict mapping protein id -> ``dgl`` graph with node data ``"feat"`` and
        float32 edge data ``"weight"``.
    """
    Gs = dict()

    for prot_id, feats in tqdm(data.items()):
        contact_map = distance[prot_id]
        u, v, weight = get_uv(contact_map)
        # Fix the node count explicitly: when it is inferred from the edge ids,
        # a map that yields no edge for its last residue(s) produces a graph
        # with fewer nodes than the contact map has rows.
        g = dgl.graph((u, v), num_nodes=len(contact_map))
        g.ndata["feat"] = feats
        g.edata["weight"] = torch.tensor(weight, dtype=torch.float32)

        Gs[prot_id] = g

    return Gs

In [3]:
task = "BIOSNAP"
# Load the raw protein data for the selected task.
# NOTE: pickle.load can execute arbitrary code - only load files from a trusted source.
with open(f"{task}/prot_seq_raw.pkl", 'rb') as f:
    prot = pickle.load(f)
# Report how many entries were loaded (the loaded object is bound to `prot`).
print(len(prot))

4294


In [None]:
save_dir = "prot_graph/"
# Predict a contact map for every protein loaded into `prot`.
contace_map = protein_graph_construct(prot, save_dir)
# Peek at the first predicted contact map as a sanity check.
list(contace_map.values())[0]

In [None]:
# Load the pickled object stored at <task>/prot_seq.pkl and report its size.
# NOTE: pickle.load can execute arbitrary code - only load files from a trusted source.
with open(f"{task}/prot_seq.pkl", 'rb') as f:
    prot_seq = pickle.load(f)
print(len(prot_seq))

In [None]:
# Build one graph per entry of `prot_seq`, using the contact maps computed earlier.
Gs = sequence_to_graph(prot_seq, contace_map)

In [None]:
# Persist the id -> graph dictionary to <task>/prot_graph.pkl.
with open(f"{task}/prot_graph.pkl", 'wb') as f:
    pickle.dump(Gs, f)