In [1]:
from Bio.PDB import MMCIFParser
from Bio.PDB.Polypeptide import is_aa
from Bio.SeqUtils import seq1
from transformers import BertTokenizer, BertModel
import torch 
import os 
import numpy as np 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Initialise the ProtBERT-BFD tokenizer and encoder once; seq2protbert reuses them.
# do_lower_case=False keeps the upper-case one-letter residue codes unchanged.
tokenizer = BertTokenizer.from_pretrained('Rostlab/prot_bert_bfd', do_lower_case=False)
model = BertModel.from_pretrained('Rostlab/prot_bert_bfd')
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# Move to the chosen device and switch to inference mode; the returned module is the cell output.
model.to(device).eval()



BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30, 1024, padding_idx=0)
    (position_embeddings): Embedding(40000, 1024)
    (token_type_embeddings): Embedding(2, 1024)
    (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-29): 30 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
            (value): Linear(in_features=1024, out_features=1024, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=1024, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.0, i

In [3]:
def seq2protbert(seq):
    """Embed one amino-acid sequence with ProtBERT, one vector per residue.

    Uses the module-level ``tokenizer``, ``model`` and ``device`` created in the
    initialisation cell.

    Args:
        seq: one-letter amino-acid string, e.g. ``'MVLSPADKTN...'``.

    Returns:
        A CPU tensor with one row per residue (``[CLS]`` and ``[SEP]`` rows
        removed) taken from ``last_hidden_state``, or ``None`` when the tokenised
        input contains no residue tokens (e.g. an empty ``seq``).

    Raises:
        ValueError: if the number of residue embeddings differs from
            ``len(seq)`` (e.g. because the tokenizer truncated a long chain).
            Such rows would no longer line up with the residues / graph nodes.
    """
    # ProtBERT tokenises per residue, so residues must be space-separated.
    spaced_seq = ' '.join(seq)
    inputs = tokenizer(spaced_seq, return_tensors='pt', add_special_tokens=True,
                       padding=True, truncation=True)
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    embeddings = outputs.last_hidden_state

    features = []
    for seq_emb, mask in zip(embeddings, attention_mask):
        seq_len = int((mask == 1).sum().item())
        if seq_len > 2:
            # Drop the leading [CLS] and trailing [SEP] positions.
            features.append(seq_emb[1:seq_len - 1])

    if not features:
        return None

    features_2d = torch.vstack(features).to('cpu')
    if features_2d.shape[0] != len(seq):
        raise ValueError(
            f'ProtBERT returned {features_2d.shape[0]} residue embeddings for a '
            f'{len(seq)}-residue sequence (input truncated?)'
        )
    return features_2d

In [4]:
# Global accumulators filled by extract_ContactMap_SeqEmbedds, one entry per chain:
# contact_maps maps '<pdb_id>_<model_id>_<chain_id>' -> (2, E) edge index, and
# init_node_features holds the per-chain ProtBERT feature tensors in the same order.
# Reset them (re-run this cell) before re-running the extraction, or results accumulate.
contact_maps = dict()
init_node_features = list()

In [4]:
# C-alpha/C-alpha contact cutoff, in the structure's coordinate units (Å for mmCIF files).
distance_threshold = 10.0

def contact_map_to_edge_index(contact_map):
    """Convert a dense contact matrix into a ``(2, E)`` edge index.

    Every non-zero entry ``contact_map[i, j]`` becomes one column ``[i, j]``.
    Columns are in row-major order, and the indices are int64.
    """
    # nonzero() lists the (row, col) pairs as an (E, 2) tensor; transpose to (2, E).
    return contact_map.nonzero().t().contiguous()

In [5]:
# Create C-alpha contact maps and ProtBERT residue embeddings for every chain of one file
def extract_ContactMap_SeqEmbedds(file_name, threshold=None, min_length=60):
    """Parse one mmCIF file and record a contact graph plus node features per chain.

    Every model in the file is visited. For each chain, residues accepted by
    ``is_aa`` that have a ``'CA'`` atom are collected. A chain is kept if it has
    more than ``min_length`` such residues. For each kept chain this:

    * builds a symmetric contact map with an edge between two residues whose
      C-alpha atoms are at distance ``0 < d <= threshold``;
    * embeds the chain sequence with ``seq2protbert``;
    * stores the edge index in the global ``contact_maps`` under
      ``'<pdb_id>_<model_id>_<chain_id>'`` and appends the embedding to the
      global ``init_node_features``.

    Both globals are updated together, and only after the contact map and the
    embedding have been computed. An exception part-way through a chain
    therefore cannot leave them misaligned. A chain whose embedding count does
    not match its residue count, or whose key is already present, is skipped
    with a message.

    Args:
        file_name: path to an mmCIF file. The text before the first '.' of its
            base name becomes the PDB id; '/' and '\\' are both treated as
            separators.
        threshold: contact cutoff in coordinate units. Defaults to the
            module-level ``distance_threshold``.
        min_length: chains with this many residues or fewer are skipped.

    Exceptions from parsing or embedding propagate to the caller.
    """
    if threshold is None:
        threshold = distance_threshold
    pdb_id = file_name.replace('\\', '/').rsplit('/', 1)[-1].split('.')[0]
    structure = MMCIFParser(QUIET=True).get_structure(pdb_id, file_name)

    # Named struct_model so it is not confused with the global ProtBERT `model`.
    for struct_model in structure:
        model_id = struct_model.get_id()
        for chain in struct_model:
            chain_id = chain.get_id()
            ca_atoms = []
            residue_codes = []
            for residue in chain:
                if is_aa(residue) and 'CA' in residue:
                    residue_codes.append(seq1(residue.resname))
                    ca_atoms.append(residue['CA'].get_coord())
            chain_seq = ''.join(residue_codes)

            num_atoms = len(ca_atoms)
            print('ca atom count ', num_atoms)

            if len(chain_seq) <= min_length:
                continue

            map_name = f'{pdb_id}_{model_id}_{chain_id}'
            if map_name in contact_maps:
                # Overwriting the edge index while appending features again would misalign nodes.
                print(f'{map_name} already processed; skipping duplicate')
                continue

            # All pairwise C-alpha distances at once instead of a Python double loop.
            # The non-mm mode computes exact differences, so the diagonal is exactly 0.
            coords = torch.as_tensor(np.asarray(ca_atoms))
            distances = torch.cdist(coords, coords, compute_mode='donot_use_mm_for_euclid_dist')
            # d > 0 excludes self-pairs (and coincident atoms), matching the original rule.
            contact_map = ((distances > 0) & (distances <= threshold)).to(torch.float32)
            edge_index = contact_map_to_edge_index(contact_map)

            node_features = seq2protbert(chain_seq)
            if node_features is None or node_features.shape[0] != num_atoms:
                got = None if node_features is None else node_features.shape[0]
                print(f'{map_name}: {got} embeddings for {num_atoms} residues; skipping chain')
                continue

            print(map_name)
            init_node_features.append(node_features)
            contact_maps[map_name] = edge_index

In [7]:
## Merge the separate per-chain edge indices into a single combined (disjoint-union) edge index
def combine_contact_maps(contact_maps, num_nodes=None):
    """Concatenate per-graph edge indices, shifting node ids so graphs do not overlap.

    Graph k's node ids are offset by the total node count of graphs 0..k-1. The
    graphs therefore occupy consecutive id ranges, in the order given.

    Args:
        contact_maps: iterable of ``(2, E_k)`` integer edge-index tensors.
        num_nodes: optional iterable giving each graph's node count, in the
            same order. Recommended: without it the count is inferred as
            ``edge_index.max() + 1``. That undercounts a graph whose
            highest-numbered nodes have no edges and misaligns every later
            graph with its node features.

    Returns:
        A ``(2, sum(E_k))`` tensor, or an empty ``(2, 0)`` long tensor when
        there are no graphs.

    Raises:
        ValueError: if ``num_nodes`` has a different length than
            ``contact_maps``, or if a graph has no edges and ``num_nodes`` is
            not given (its size cannot be inferred).
    """
    edge_indices = list(contact_maps)

    if num_nodes is None:
        node_counts = []
        for edge_index in edge_indices:
            if edge_index.numel() == 0:
                raise ValueError('cannot infer node count of a graph with no edges; pass num_nodes')
            node_counts.append(int(edge_index.max().item()) + 1)
    else:
        node_counts = [int(n) for n in num_nodes]
        if len(node_counts) != len(edge_indices):
            raise ValueError(
                f'num_nodes has {len(node_counts)} entries for {len(edge_indices)} graphs'
            )

    cumulative_offset = 0
    shifted = []
    for edge_index, count in zip(edge_indices, node_counts):
        shifted.append(edge_index + cumulative_offset)
        cumulative_offset += count

    if not shifted:
        return torch.empty((2, 0), dtype=torch.long)
    return torch.cat(shifted, dim=1)

In [8]:
def return_init_node_features(node_features_list):
    """Stack per-chain feature matrices row-wise into one node-feature matrix.

    Row order follows the order of ``node_features_list``; all tensors must
    share every dimension except the first.
    """
    stacked_features = torch.cat(list(node_features_list), dim=0)
    return stacked_features

In [9]:
# Folder holding the input mmCIF files (edit for your machine). The trailing
# separator is kept because paths may be built by string concatenation.
folder_path = 'D:\\year 4\\semester 1\\BT\\BT 4033\\structure\\temp\\'

# Regular files in the folder, sorted: os.listdir order is arbitrary, and the
# processing order decides the node ids of the combined graph.
file_list = sorted(
    f for f in os.listdir(folder_path)
    if os.path.isfile(os.path.join(folder_path, f))
)

In [None]:
# Empty the global accumulators first so re-running this cell does not duplicate chains.
contact_maps.clear()
init_node_features.clear()

for file in file_list:
    file_path = os.path.join(folder_path, file)
    try:
        extract_ContactMap_SeqEmbedds(file_name=file_path)
        print(file, ' done ************************************\n')
    except Exception as exc:
        # Best-effort over the folder: skip files that fail, but say why
        # (a bare `except:` also hid KeyboardInterrupt and real bugs).
        print(f'{file} skipped: {type(exc).__name__}: {exc}')
        continue

# Each chain must contribute exactly one edge index and one feature matrix, in
# the same order, or node ids in the combined graph point at the wrong rows.
assert len(contact_maps) == len(init_node_features), (
    f'{len(contact_maps)} contact maps vs {len(init_node_features)} feature matrices'
)

final_AA_list = combine_contact_maps(contact_maps.values())
node_features = return_init_node_features(init_node_features)
if final_AA_list.numel() > 0:
    assert int(final_AA_list.max()) < node_features.shape[0], \
        'edge index refers to a node without features'
print(final_AA_list.shape)
print(node_features.shape)

In [55]:
# Persist the combined graph: the edge index first, then the node-feature matrix.
saved_outputs = {
    'aa_edge_indices.pt': final_AA_list,
    'init_node_features.pt': node_features,
}
for out_path, out_tensor in saved_outputs.items():
    torch.save(out_tensor, out_path)