## download and process raw files

In [1]:
import numpy as np
import torch
from tqdm import tqdm
from proteinshake.datasets import EnzymeCommissionDataset, ProteinLigandInterfaceDataset, TMAlignDataset
from proteinshake.datasets import GeneOntologyDataset, ProteinFamilyDataset, ProteinProteinInterfaceDataset
from proteinshake.datasets import ProteinLigandDecoysDataset, SCOPDataset

def get_dataset(id):
    """Instantiate the ProteinShake dataset registered under ``id``.

    Args:
        id: Dataset identifier. Known values are "enzyme_class",
            "gene_ontology", "protein_family", "tm_align", "scop",
            "ligand_binding", "ppis" and "ligand_decoys".

    Returns:
        The matching dataset object, built with ``root="../data/<id>/"``
        and ``use_precomputed=True``.

    Raises:
        ValueError: If ``id`` matches none of the known identifiers.
    """
    # (identifier, class getter) pairs. The getters defer the class lookup
    # until an identifier has actually matched.
    registry = (
        ("enzyme_class", lambda: EnzymeCommissionDataset),
        ("gene_ontology", lambda: GeneOntologyDataset),
        ("protein_family", lambda: ProteinFamilyDataset),
        ("tm_align", lambda: TMAlignDataset),
        ("scop", lambda: SCOPDataset),
        ("ligand_binding", lambda: ProteinLigandInterfaceDataset),
        ("ppis", lambda: ProteinProteinInterfaceDataset),
        ("ligand_decoys", lambda: ProteinLigandDecoysDataset),
    )
    for name, get_class in registry:
        if id == name:
            return get_class()(root=f"../data/{name}/", use_precomputed=True)
    raise ValueError(f"Unknown dataset ID: {id}")

In [2]:
# Choose which dataset the exploratory cells below operate on and load it.
# NOTE(review): the contact-label cells at the end of this notebook read
# `dataset._interfaces` and write under ../data/ppis/, so they were presumably
# run with dataset_id = "ppis" — confirm before a Restart & Run All.
dataset_id = "ligand_binding"
dataset = get_dataset(dataset_id)
#proteins_atom = dataset.proteins(resolution='atom')
# Residue-level protein records. Later cells call next() on this object and
# also loop over it, so it is shared iteration state.
proteins_res = dataset.proteins(resolution='residue')
print(dataset.root)

../data/ligand_binding/


In [None]:
def save_pdb_names(proteins=None, out_path=None):
    """Write the ID of every protein to a comma-terminated text file.

    Anything from the first underscore onward is stripped from each ID
    (``"1abc_A"`` becomes ``"1abc"``). Every ID is followed by a comma, so a
    non-empty file ends with a trailing comma.

    Args:
        proteins: Iterable of records shaped like ``{'protein': {'ID': ...}}``.
            Defaults to a fresh ``dataset.proteins(resolution='residue')``.
        out_path: Destination file. Defaults to ``proteins.txt`` inside
            ``dataset.root``.
    """
    import os  # local import keeps this cell self-contained

    if proteins is None:
        # Ask for a fresh iterator instead of reusing the notebook-level
        # `proteins_res`: other cells call next() on it and loop over it, so
        # reusing it here could silently skip proteins.
        proteins = dataset.proteins(resolution='residue')
    if out_path is None:
        # os.path.join works whether or not dataset.root ends with a separator.
        out_path = os.path.join(dataset.root, 'proteins.txt')
    with open(out_path, 'w') as f:
        for protein_res in proteins:
            # split('_')[0] returns the whole ID when there is no underscore.
            pid = protein_res['protein']['ID'].split('_')[0]
            f.write(pid + ',')
#save_pdb_names()

In [3]:
# Peek at the first record to see which protein-level and residue-level fields
# exist. A fresh iterator is requested so the shared `proteins_res` is not
# advanced, and the two sub-records are picked by key rather than by the order
# of the record's values.
first_record = next(dataset.proteins(resolution='residue'))
protein_dict, residue_dict = first_record['protein'], first_record['residue']
print(protein_dict.keys())
print(residue_dict.keys())

dict_keys(['ID', 'sequence', 'kd', 'neglog_aff', 'resolution', 'year', 'ligand_id', 'ligand_smiles', 'fp_maccs', 'fp_morgan_r2', 'random_split', 'sequence_split_0.5', 'sequence_split_0.6', 'sequence_split_0.7', 'sequence_split_0.8', 'sequence_split_0.9', 'structure_split_0.3', 'structure_split_0.4', 'structure_split_0.5', 'structure_split_0.6', 'structure_split_0.7', 'structure_split_0.8', 'structure_split_0.9'])
dict_keys(['residue_number', 'residue_type', 'x', 'y', 'z', 'SASA', 'RSA', 'chain_id', 'binding_site'])


In [70]:
# Inspect the peeked record: its ID, its residue numbering (with total and
# distinct counts), and the residue types / binding-site flags split at
# position 208.
# NOTE(review): 208 is an unexplained split index, presumably a chain boundary
# of some previously inspected protein. The recorded output shows only 130
# residues for 6s56, so the second slice is empty there.
print(
    protein_dict['ID'],
    residue_dict['residue_number'],
    len(residue_dict['residue_number']),
    len(set(residue_dict['residue_number'])),
    residue_dict['residue_type'][:208],
    residue_dict['residue_type'][208:],
    residue_dict['binding_site'][:208],
    residue_dict['binding_site'][208:],
    sep='\n',
)

6s56
[979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 1034, 1035, 1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 1048, 1049, 1050, 1051, 1052, 1053, 1054, 1055, 1056, 1057, 1058, 1059, 1060, 1061, 1062, 1063, 1064, 1065, 1066, 1067, 1068, 1069, 1070, 1071, 1072, 1073, 1074, 1075, 1076, 1077, 1078, 1079, 1080, 1081, 1082, 1083, 1084, 1085, 1086, 1087, 1088, 1089, 1090, 1091, 1092, 1093, 1094, 1095, 1096, 1097, 1098, 1099, 1100, 1101, 1102, 1103, 1104, 1105, 1106, 1107, 1108]
130
130
['S', 'M', 'Q', 'E', 'E', 'D', 'T', 'F', 'R', 'E', 'L', 'R', 'I', 'F', 'L', 'R', 'N', 'V', 'T', 'H', 'R', 'L', 'A', 'I', 'D', 'K', 'R', 'F', 'R', 'V', 'F', 'T', 'K', 'P', 'V', 'D', 'P', 'D', 'E', 'V', 'P', 'D', 'Y', 'V', 'T', '

## select protein by id

In [23]:
# Scan a fresh pass over the residue-level records and stop at the wanted ID.
# After the loop `protein_res` / `protein_info` refer to that protein.
for protein_res in dataset.proteins(resolution='residue'):
    protein_info = protein_res['protein']
    pid = protein_info['ID']
    if pid == '2wos':
        break
else:
    # Without this guard a missing ID would silently leave the *last* record
    # selected, and the cells below would inspect the wrong protein.
    raise LookupError("Protein ID '2wos' not found in the dataset")

In [19]:
# Display the metadata record of the protein selected by the loop above.
# NOTE(review): the saved output reports ID '1q91' while the selection cell
# filters for '2wos', and this cell's execution count (19) is lower than the
# selection cell's (23) — the output looks stale; re-run to confirm.
protein_info

{'ID': '1q91',
 'sequence': 'RALRVLVDMDGVLADFEGGFLRKFRARFPDQPFIALEDRRGFWVSEQYGRLRPGLSEKAISIWESKNFFFELEPLPGAVEAVKEMASLQNTDVFICTSPIKMFKYCPYEKYAWVEKYFGPDFLEQIVLTRDKTVVSADLLIDDRPDITGAEPTPSWEHVLFTACHNQHLQLQPPRRRLHSWADDWKAILDSKRP',
 'kd': 70.0,
 'neglog_aff': 4.150000095367432,
 'resolution': 1.600000023841858,
 'year': 2004,
 'ligand_id': 'DPB)\n',
 'ligand_smiles': 'Cc1cn([C@H]2C[C@H]3O[C@](c4ccccc4)(P(=O)(O)O)OC[C@H]3O2)c(=O)[nH]c1=O',
 'fp_maccs': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  0,
  0,
  1,
  0,
  0,
  1,
  0,
  0,
  0,
  1,
  0,
  0,
  1,
  0,
  0,
  1,
  0,
  1,
  0,
  0,
  1,
  0,
  0,
  1,
  0,
  1,
  0,
  0,
  0,
  1,
  0,
  1,
  1,
  0,
  0,
  0,
  1,
  1,
  1,
  0,
  0,
  1,
  1,
  0,
  0,
  1,
  1,
  0,
  0,
  1,
 

In [7]:
# Show the names of the per-residue fields of the selected protein
# (iterating the residue record yields its keys, as the output confirms).
set(protein_res['residue'])

{'RSA',
 'SASA',
 'binding_site',
 'chain_id',
 'residue_number',
 'residue_type',
 'x',
 'y',
 'z'}

In [24]:
# Dump the residue types, residue numbers, binding-site flags and chain ids of
# the selected protein, one per line, followed by the residue count.
print(
    protein_res['residue']['residue_type'],
    protein_res['residue']['residue_number'],
    protein_res['residue']['binding_site'],
    protein_res['residue']['chain_id'],
    len(protein_res['residue']['residue_type']),
    sep='\n',
)


['S', 'N', 'T', 'Q', 'A', 'E', 'R', 'S', 'I', 'I', 'G', 'M', 'I', 'D', 'M', 'F', 'H', 'K', 'Y', 'T', 'R', 'R', 'D', 'D', 'K', 'I', 'D', 'K', 'P', 'S', 'L', 'L', 'T', 'M', 'M', 'K', 'E', 'N', 'F', 'P', 'N', 'F', 'L', 'S', 'A', 'C', 'D', 'K', 'K', 'G', 'T', 'N', 'Y', 'L', 'A', 'D', 'V', 'F', 'E', 'K', 'K', 'D', 'K', 'N', 'E', 'D', 'K', 'K', 'I', 'D', 'F', 'S', 'E', 'F', 'L', 'S', 'L', 'L', 'G', 'D', 'I', 'A', 'T', 'D', 'Y', 'H', 'K', 'Q', 'S', 'H', 'G', 'A', 'A', 'P', 'C', 'S', 'S', 'N', 'T', 'Q', 'A', 'E', 'R', 'S', 'I', 'I', 'G', 'M', 'I', 'D', 'M', 'F', 'H', 'K', 'Y', 'T', 'R', 'R', 'D', 'D', 'K', 'I', 'D', 'K', 'P', 'S', 'L', 'L', 'T', 'M', 'M', 'K', 'E', 'N', 'F', 'P', 'N', 'F', 'L', 'S', 'A', 'C', 'D', 'K', 'K', 'G', 'T', 'N', 'Y', 'L', 'A', 'D', 'V', 'F', 'E', 'K', 'K', 'D', 'K', 'N', 'E', 'D', 'K', 'K', 'I', 'D', 'F', 'S', 'E', 'F', 'L', 'S', 'L', 'L', 'G', 'D', 'I', 'A', 'T', 'D', 'Y', 'H', 'K', 'Q', 'S', 'H', 'G', 'A', 'A', 'P', 'C', 'S']
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,

In [25]:
# Residue numbers of every residue whose binding-site flag equals 1.
print([
    number
    for number, flag in zip(
        protein_res['residue']['residue_number'],
        protein_res['residue']['binding_site'],
    )
    if flag == 1
])

[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 34, 38, 39, 74, 75, 78, 81, 82, 85, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19, 34, 38, 39, 74, 75, 78, 81, 82, 85]


In [27]:
# Spot-check how list positions relate to residue numbers: in the recorded
# output positions 7 and 85 hold residue numbers 8 and 86.
# NOTE(review): 7 and 85 are hand-picked indices — presumably chosen to verify
# an index-to-number offset of one for this protein; confirm.
protein_res['residue']['residue_number'][7], protein_res['residue']['residue_number'][85]

(8, 86)

In [None]:
# Load one saved encoding file and summarize every entry it contains.
# NOTE(review): this is an absolute, machine-specific path; it will not resolve
# on another machine. torch.load unpickles the file, so only open trusted files.
a = torch.load("/root/autodl-tmp/unit-protein/tasks/proteinshake/data/ligand_binding/esm_encodings/6eqw.pth", map_location='cpu')
for k, v in a.items():
    # Tensors and arrays report their shape; anything else reports its length.
    print(k, v.shape if isinstance(v, (torch.Tensor, np.ndarray)) else len(v))

chain_id (475,)
residue_index (475,)
insertion_code (475,)
residue_type (475,)
coordinates (475, 37, 3)
structure_encodings torch.Size([475, 128])
sequence_tokens torch.Size([475])
structure_tokens torch.Size([475])


: 

: 

In [None]:
# Display the per-residue chain identifiers stored in the loaded encoding file.
a['chain_id']

array(['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A',
       'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A',
       'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A',
       'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A',
       'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A',
       'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A',
       'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A',
       'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A',
       'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A',
       'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A',
       'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A',
       'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A',
       'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A',
       'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A

: 

: 

## process the ppis dataset contact labels

In [33]:
# Map every protein ID to its list of residue numbers.
# A fresh iterator is requested here because the shared `proteins_res` is
# advanced by an earlier cell (via next()), which would leave the first
# protein out of this mapping.
res_ids = {}
for protein_res in tqdm(dataset.proteins(resolution='residue')):
    protein_info = protein_res['protein']
    residue_info = protein_res['residue']
    res_ids[protein_info['ID']] = residue_info['residue_number']

100%|██████████| 12530/12530 [00:26<00:00, 478.58it/s]


In [39]:
# Convert the interface contacts held in `dataset._interfaces` from positional
# indices into residue numbers. From how it is indexed below, the raw structure
# is {pid: {chain1: {chain2: [contact, ...]}}}, where contact[0] / contact[1]
# index into the residue-number lists of chain1 / chain2.
# Chains that have no entry in `res_ids` are skipped, and levels that end up
# empty are left out of the result.
interfaces = {}
for pid, chains in tqdm(dataset._interfaces.items()):
    per_chain = {}
    for chain1, partners in chains.items():
        first_key = f'{pid}_{chain1}'
        if first_key not in res_ids:
            continue
        per_partner = {}
        for chain2, contacts in partners.items():
            second_key = f'{pid}_{chain2}'
            if second_key not in res_ids:
                continue
            per_partner[chain2] = [
                (res_ids[first_key][contact[0]], res_ids[second_key][contact[1]])
                for contact in contacts
            ]
        if per_partner:
            per_chain[chain1] = per_partner
    if per_chain:
        interfaces[pid] = per_chain

  0%|          | 0/2839 [00:00<?, ?it/s]

100%|██████████| 2839/2839 [00:01<00:00, 2299.40it/s]


In [43]:
# Count the (chain1, chain2) pairs kept across all proteins: each chain1 entry
# contributes as many pairs as it has partner chains.
count = sum(
    len(partners)
    for _, chains in tqdm(interfaces.items())
    for partners in chains.values()
)
print(f"Total interfaces: {count}")

100%|██████████| 2800/2800 [00:00<00:00, 117008.75it/s]

Total interfaces: 24008





In [50]:
import json

# Persist the remapped interface contacts as JSON (tuples are written as lists).
with open("../data/ppis/interfaces.json", "w") as f:
    json.dump(interfaces, f)