In [1]:
import pandas as pd
import numpy as np
import os
from tqdm.auto import tqdm

In [5]:
from datasets import load_dataset

# Login using e.g. `huggingface-cli login` to access this dataset
df = load_dataset("wanglab/cafa5", "cafa5_reasoning")['train'].to_pandas()
df = df.reset_index()  # make sure indexes pair with number of rows
# ds = load_dataset("wanglab/cafa5", "interpro_metadata")

In [7]:
from tqdm.auto import tqdm
import gzip

In [8]:
from Bio.PDB.MMCIFParser import FastMMCIFParser
from Bio.PDB import PDBParser
from Bio.SeqUtils import seq1

def extract_sequence_and_ca_coords(pdb_file, chain_id=None, af = True):
    parser = FastMMCIFParser(QUIET=True) if af else PDBParser(QUIET=True)
    if pdb_file.endswith('.gz'):
        with gzip.open(pdb_file, 'rt') as gz_file:
            temp_file = pdb_file.replace('.gz', '_temp')
            try:
                with open(temp_file, 'w') as temp:
                    temp.write(gz_file.read())
                
                structure = parser.get_structure('protein', temp_file)
            finally:
                if os.path.exists(temp_file):
                    os.remove(temp_file)
    else:
        structure = parser.get_structure('protein', pdb_file)
    
    results = {}
    
    for model in structure:
        for chain in model:
            if chain_id is None or chain.id == chain_id:
                ca_coords = []
                
                for residue in chain:
                    if residue.id[0] == ' ':
                        try:                            
                            if 'CA' in residue:
                                ca_atom = residue['CA']
                                coord = ca_atom.get_coord()
                                ca_coords.append((float(coord[0]), float(coord[1]), float(coord[2])))
                            else:
                                print(f"Warning: No CA atom found in residue {residue.resname}{residue.id[1]} of chain {chain.id}")
                                ca_coords.append(None)
                                
                        except KeyError:
                            print(f'Non natural residue in {pdb_file}')
                
                results[chain.id] = {
                    'ca_coords': ca_coords
                }
    
    return results[chain_id]['ca_coords'] if chain_id in results else []

In [10]:
import tarfile
from io import BytesIO
import fsspec

fs = fsspec.filesystem("file")

# we don't know the shard so we'll just regex match based on structure in hf
shard_paths = fs.glob("/Users/arnavshah/Code/DPFunc/cafa5/structures/*/*/*")  # Lists all tar.gz shard paths
shard_paths 

['/Users/arnavshah/Code/DPFunc/cafa5/structures/af_shards/shard_0/.DS_Store',
 '/Users/arnavshah/Code/DPFunc/cafa5/structures/af_shards/shard_0/AF-A0A021WW32-F1-model_v4.cif.gz',
 '/Users/arnavshah/Code/DPFunc/cafa5/structures/af_shards/shard_0/AF-A0A021WZA4-F1-model_v4.cif.gz',
 '/Users/arnavshah/Code/DPFunc/cafa5/structures/af_shards/shard_0/AF-A0A023FBW4-F1-model_v4.cif.gz',
 '/Users/arnavshah/Code/DPFunc/cafa5/structures/af_shards/shard_0/AF-A0A023FBW7-F1-model_v4.cif.gz',
 '/Users/arnavshah/Code/DPFunc/cafa5/structures/af_shards/shard_0/AF-A0A023FF81-F1-model_v4.cif.gz',
 '/Users/arnavshah/Code/DPFunc/cafa5/structures/af_shards/shard_0/AF-A0A023FFB5-F1-model_v4.cif.gz',
 '/Users/arnavshah/Code/DPFunc/cafa5/structures/af_shards/shard_0/AF-A0A023FT45-F1-model_v4.cif.gz',
 '/Users/arnavshah/Code/DPFunc/cafa5/structures/af_shards/shard_0/AF-A0A023G6B6-F1-model_v4.cif.gz',
 '/Users/arnavshah/Code/DPFunc/cafa5/structures/af_shards/shard_0/AF-A0A023G9N9-F1-model_v4.cif.gz',
 '/Users/arna

In [11]:
import re
af_shard_index = [re.search(r'AF-(.*?)-F1-', sorted(fs.glob(f"/Users/arnavshah/Code/DPFunc/cafa5/structures/af_shards/shard_{i}/*"))[-1].split("/")[-1]).group(1) for i in range(35)]
pdb_shard_index = [sorted(fs.glob(f"/Users/arnavshah/Code/DPFunc/cafa5/structures/pdb_shards/shard_{i}/*"))[-1].split("/")[-1].split(".")[0] for i in range(5)]

def find_protein_shard(entry: str, af: bool = True) -> int:
    code = re.search(r'AF-(.*?)-F1-', entry).group(1) if af else entry.split(".")[0]
    shard_index = af_shard_index if af else pdb_shard_index
    length = 35 if af else 5
    
    for i in range(length):
        if code <= shard_index[i]:
            return i

In [12]:
pdb_points_info = {}
pdb_seq_info = {}
unseen_proteins = set()

for index, row in tqdm(df.iloc[:50].iterrows()):
    uni_id, sequence, struct_entry = row['protein_id'], row['sequence'], row['structure_path']
    if struct_entry is None: # will be roughly 5%
        unseen_proteins.add(uni_id)
        continue

    database, entry = struct_entry.split("/")
    af = database == "af_db"
    assert database in ["af_db", "pdb_files"]
    pdb_file = f"../cafa5/structures/{"af_shards" if af else "pdb_shards"}/shard_{find_protein_shard(entry, af)}/{entry}"
        
    if not os.path.exists(pdb_file): # should never be triggered (@Purav)
        print(f"GUARD REACHED {struct_entry}. PDB file: {pdb_file}")
        unseen_proteins.add(uni_id)
        continue
    
    coords_list = extract_sequence_and_ca_coords(pdb_file, 'A', af)
    
    if coords_list: # guard
        valid_coords = [coord for coord in coords_list if coord is not None]
        pdb_points_info[uni_id] = valid_coords
        pdb_seq_info[uni_id] = sequence

0it [00:00, ?it/s]



In [13]:
pdb_points_info

{'A0A087X1C5': [(26.2549991607666, 39.17300033569336, -33.48899841308594),
  (23.683000564575195, 40.994998931884766, -35.6609992980957),
  (19.917999267578125, 40.2859992980957, -35.37099838256836),
  (20.229999542236328, 38.62200164794922, -38.85300064086914),
  (22.253999710083008, 35.65800094604492, -37.41699981689453),
  (19.636999130249023, 34.95199966430664, -34.6619987487793),
  (16.75200080871582, 34.231998443603516, -37.12699890136719),
  (18.351999282836914, 31.240999221801758, -39.025001525878906),
  (19.648000717163086, 29.812999725341797, -35.69499969482422),
  (16.14900016784668, 30.075000762939453, -34.119998931884766),
  (14.663999557495117, 28.385000228881836, -37.25400161743164),
  (17.249000549316406, 25.520000457763672, -37.13100051879883),
  (16.518999099731445, 25.031999588012695, -33.382999420166016),
  (12.729000091552734, 25.1200008392334, -34.04899978637695),
  (13.045999526977539, 22.55699920654297, -36.92300033569336),
  (15.340999603271484, 20.378999710083

In [None]:
save_pkl('./processed_file/pdb_points.pkl', pdb_points_info)
save_pkl('./processed_file/pdb_seqs.pkl', pdb_seq_info)
save_pkl('./processed_file/unseen_proteins.pkl', unseen_proteins)

## 4.generate ESM - PDB

In [None]:
'''
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Please first run "process_esm.py" file to generate the esm data.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
'''

In [4]:
all_protein_list = []
for tag in tags:
    for tp in types:
        pid_list = read_pkl(f"./processed_file/{tag}_{tp}_used_pid_list.pkl")
        all_protein_list+=pid_list
all_protein_list = list(set(all_protein_list))

In [5]:
pdb_points_info = read_pkl('./processed_file/pdb_points.pkl')
pdb_seqs = read_pkl('./processed_file/pdb_seqs.pkl')

In [6]:
assert len(set(pdb_points_info.keys())&set(all_protein_list))==len(all_protein_list)
assert len(set(pdb_seqs.keys())&set(all_protein_list))==len(all_protein_list)

In [7]:
import dgl

In [8]:
import math
import torch

In [9]:
def get_dis(point1, point2):
    dis_x = point1[0] - point2[0]
    dis_y = point1[1] - point2[1]
    dis_z = point1[2] - point2[2]
    return math.sqrt(dis_x*dis_x + dis_y*dis_y + dis_z*dis_z)

def process_input_pdb_file(tag, part, pid_list, pdb_points_info, pdb_seqs, thresholds=12):
    protein_map = read_pkl('./processed_file/protein_map.pkl')
    pdb_graphs = []
    p_cnt = 0
    file_idx = 0
    for pid in tqdm(pid_list):
        p_cnt += 1
        points = pdb_points_info[pid]
        
        u_list = []
        v_list = []
        dis_list = []
        for uid, amino_1 in enumerate(points):
            for vid, amino_2 in enumerate(points):
                if uid==vid:
                    continue
                dist = get_dis(amino_1, amino_2)
                if dist<=thresholds:
                    u_list.append(uid)
                    v_list.append(vid)
                    dis_list.append(dist)
        u_list, v_list = torch.tensor(u_list), torch.tensor(v_list)
        dis_list = torch.tensor(dis_list)

        graph = dgl.graph((u_list, v_list), num_nodes=len(points))
        graph.edata['dis'] = dis_list

        # graph node feature - esm
        esm_file_idx = protein_map[pid]
        esm_features = read_pkl(f"./processed_file/esm_emds/esm_part_{esm_file_idx}.pkl")
        node_features = esm_features[pid]
        assert node_features.shape[0]==graph.num_nodes()
        graph.ndata['x'] = torch.from_numpy(node_features)
        pdb_graphs.append(graph)

        if p_cnt%5000==0:
            save_pkl('./processed_file/graph_features/{}_{}_whole_pdb_part{}.pkl'.format(tag, part, file_idx), pdb_graphs)
            p_cnt = 0
            file_idx += 1
            pdb_graphs = []
    if len(pdb_graphs)>0:
        save_pkl('./processed_file/graph_features/{}_{}_whole_pdb_part{}.pkl'.format(tag, part, file_idx), pdb_graphs)
    return file_idx

In [10]:
for tag in tags:
    if tag=='mf':
        continue
    for tp in types:
        pid_list = read_pkl(f"./processed_file/{tag}_{tp}_used_pid_list.pkl")
        max_cnt = process_input_pdb_file(tag, tp, pid_list, pdb_points_info, pdb_seqs)
        if tp=='train':
            print(f"{tag}-{tp}-train_file_count-{max_cnt}")

  0%|          | 0/41119 [00:00<?, ?it/s]

cc-train-train_file_count-8


  0%|          | 0/618 [00:00<?, ?it/s]

  0%|          | 0/990 [00:00<?, ?it/s]

  0%|          | 0/46642 [00:00<?, ?it/s]

bp-train-train_file_count-9


  0%|          | 0/707 [00:00<?, ?it/s]

  0%|          | 0/1280 [00:00<?, ?it/s]

## 5. generate Interpro

In [4]:
interpro_list = read_pkl('./data_dpfunc/interpro_list_26203.pkl')

In [6]:
len(interpro_list)

26203

In [7]:
inter_idx = {}
for idx, ipr in enumerate(interpro_list):
    inter_idx[ipr] = idx
save_pkl('./processed_file/inter_idx.pkl', inter_idx)

In [8]:
all_protein_list = []
for tag in tags:
    for tp in types:
        pid_list = read_pkl(f"./processed_file/{tag}_{tp}_used_pid_list.pkl")
        all_protein_list+=pid_list
all_protein_list = list(set(all_protein_list))
len(all_protein_list)

59350

In [9]:
all_protein_interpro = read_pkl('./data_dpfunc/all_protein_interpros.pkl')

In [12]:
for pr in all_protein_list:
    inters = all_protein_interpro[pr]
    inter_matrix = np.zeros(len(interpro_list))
    for it in inters:
        inter_matrix[inter_idx[it]] += 1
    save_pkl(f"./processed_file/interpro/{pr}.pkl", inter_matrix)

## 6. Check configures

In [None]:
'''
name: mf
mlb: ./mlb/mf_go.mlb
results: ./results

base:
  interpro_whole: ./processed_file/interpro/{}.pkl

train:
  name: train
  pid_list_file: ./processed_file/mf_train_used_pid_list.pkl
  pid_go_file: ./processed_file/mf_train_go.txt
  pid_pdb_file: ./processed_file/graph_features/mf_train_whole_pdb_part{}.pkl
  train_file_count: 7
  interpro_file: ./processed_file/mf_train_interpro.pkl

valid:
  name: valid
  pid_list_file: ./processed_file/mf_test1_used_pid_list.pkl
  pid_go_file: ./processed_file/mf_test1_go.txt
  pid_pdb_file: ./processed_file/graph_features/mf_test1_whole_pdb_part0.pkl
  interpro_file: ./processed_file/mf_test1_interpro.pkl
  
test:
  name: test
  pid_list_file: ./processed_file/mf_test2_used_pid_list.pkl
  pid_go_file: ./processed_file/mf_test2_go.txt
  pid_pdb_file: ./processed_file/graph_features/mf_test2_whole_pdb_part0.pkl
  interpro_file: ./processed_file/mf_test2_interpro.pkl
'''

In [None]:
'''
Run DPFunc_main.py / DPFunc_pred.py if you need:
python DPFunc_main.py -d mf -n 0 -e 15 -p DPFunc
python DPFunc_pred.py -d mf -n 0 -p DPFunc
'''