In [1]:
import torch
from torch_geometric.data import Data, Dataset, HeteroData
from torch_geometric.loader import DataLoader
from tools.utils import convert_hom_to_het

In [2]:
# Load the homogeneous PSCDB residue graphs (presumably a list of PyG `Data`
# objects -- confirm against the script that wrote the file).
# `weights_only=False` lets torch.load unpickle arbitrary Python objects, which
# can execute code embedded in the file: only load files from a trusted source.
hom_dataset = torch.load('data/PSCDB/hom_pscdb_graphs.pt', weights_only=False)
# Later cells index hom_dataset[0]; fail here with a clear message instead.
assert len(hom_dataset) > 0, 'data/PSCDB/hom_pscdb_graphs.pt contains no graphs'
len(hom_dataset)

856

In [3]:
# Three-letter residue code -> one-letter amino-acid code, for the 20 standard
# amino acids only.
three_to_one = dict(
    ALA='A', ARG='R', ASN='N', ASP='D',
    CYS='C', GLN='Q', GLU='E', GLY='G',
    HIS='H', ILE='I', LEU='L', LYS='K',
    MET='M', PHE='P'.replace('P', 'F'), PRO='P', SER='S',
    THR='T', TRP='W', TYR='Y', VAL='V',
)
# One-letter codes in alphabetical order; used below as the node-type names.
amino_acids = list(three_to_one.values())
amino_acids.sort()
amino_acids


['A',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'K',
 'L',
 'M',
 'N',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'V',
 'W',
 'Y']

In [4]:
# Names of the edge-set attributes stored on each homogeneous graph
# (presumably the ligand-free and ligand-bound conformations -- confirm).
# The misspelled variable name is kept because the conversion cell uses it.
edge_indicies = ['edge_index_free', 'edge_index_bound']
sample_graph = hom_dataset[0]
for edge_key in edge_indicies:
    # Each edge set prints as torch.Size([2, num_edges]).
    print(sample_graph[edge_key].shape)

torch.Size([2, 385])
torch.Size([2, 389])


In [5]:
# Split every homogeneous graph into a HeteroData graph, one node type per
# amino acid and one edge store per (src type, edge set, dst type) triple.
# NOTE(review): onehot_indices assumes the leading len(amino_acids) columns of
# each graph's node features are a one-hot residue encoding in the same sorted
# order as `amino_acids` -- confirm against the graph builder.
het_dataset = [
    convert_hom_to_het(
        hom_data,
        onehot_indices=range(len(amino_acids)),
        expected_node_types=amino_acids,
        expected_edge_types=edge_indicies,
        is_directed=False,
    )
    for hom_data in hom_dataset
]

len(het_dataset)

856

In [6]:
# Inspect the first converted graph: its graph-level label `y`, one node store
# per amino-acid type, and the typed edge stores.
first_het_graph = het_dataset[0]
first_het_graph

HeteroData(
  y=[1],
  A={ x=[32, 9] },
  C={ x=[3, 9] },
  D={ x=[22, 9] },
  E={ x=[20, 9] },
  F={ x=[9, 9] },
  G={ x=[29, 9] },
  H={ x=[4, 9] },
  I={ x=[29, 9] },
  K={ x=[13, 9] },
  L={ x=[23, 9] },
  M={ x=[0, 9] },
  N={ x=[15, 9] },
  P={ x=[13, 9] },
  Q={ x=[13, 9] },
  R={ x=[21, 9] },
  S={ x=[20, 9] },
  T={ x=[18, 9] },
  V={ x=[19, 9] },
  W={ x=[7, 9] },
  Y={ x=[7, 9] },
  (A, edge_index_free, A)={ edge_index=[2, 4] },
  (A, edge_index_bound, A)={ edge_index=[2, 4] },
  (A, edge_index_free, C)={ edge_index=[2, 1] },
  (A, edge_index_bound, C)={ edge_index=[2, 1] },
  (A, edge_index_free, D)={ edge_index=[2, 3] },
  (A, edge_index_bound, D)={ edge_index=[2, 3] },
  (A, edge_index_free, E)={ edge_index=[2, 5] },
  (A, edge_index_bound, E)={ edge_index=[2, 5] },
  (A, edge_index_free, F)={ edge_index=[2, 1] },
  (A, edge_index_bound, F)={ edge_index=[2, 1] },
  (A, edge_index_free, G)={ edge_index=[2, 7] },
  (A, edge_index_bound, G)={ edge_index=[2, 8] },
  (A, edge_

In [7]:
# Persist the converted heterogeneous graphs alongside the homogeneous source file.
torch.save(obj=het_dataset, f='data/PSCDB/het_pscdb_graphs.pt')

In [8]:
# Reload the file just written to confirm the heterogeneous dataset round-trips.
# `weights_only=False` lets torch.load unpickle arbitrary Python objects, which
# can execute code embedded in the file: only load files from a trusted source.
dataset = torch.load('data/PSCDB/het_pscdb_graphs.pt', weights_only=False)
# Make the round-trip check explicit instead of relying on eyeballing the count.
assert len(dataset) == len(het_dataset), (
    f'reloaded {len(dataset)} graphs, expected {len(het_dataset)}'
)
len(dataset)

856