In [None]:
from pscdb_dataset import ProteinPairGraphBuilder
from tqdm import tqdm
import pandas as pd
import torch

import warnings
from Bio import BiopythonWarning
warnings.filterwarnings("ignore", category=BiopythonWarning)


In [None]:
# Load the PSCDB structural-rearrangement table and preview its first rows.
df = pd.read_csv('data/PSCDB/structural_rearrangement_data.csv')
df.head()

In [None]:
# (rows, columns) of the raw table.
df.shape

In [None]:
# Keep only the columns used downstream: the 'Free PDB' / 'Bound PDB' columns
# and the motion-type label.
# .copy() makes df_parsed an independent frame, so adding or overwriting
# columns on it later never writes into a selection of `df` (which is what
# triggers pandas' SettingWithCopyWarning).
df_parsed = df[['Free PDB', 'Bound PDB', 'motion_type']].copy()
df_parsed

In [None]:
# Encode motion_type as a categorical and add its integer code as a label column.
# .assign() returns a new frame instead of writing into df_parsed in place, so
# this cell does not raise SettingWithCopyWarning when df_parsed is a column
# selection of another frame, and re-running it gives the same result.
df_parsed = df_parsed.assign(
    motion_type=pd.Categorical(df_parsed['motion_type']),
    # Evaluated after motion_type has been replaced above, so .cat is available.
    motion_type_int=lambda frame: frame['motion_type'].cat.codes,
)

# Mappings from int label to string label, and the reverse lookup.
int_to_name = dict(enumerate(df_parsed['motion_type'].cat.categories))
name_to_int = {name: code for code, name in int_to_name.items()}

df_parsed.head()


In [None]:
# Three-letter -> one-letter codes for the 20 standard amino acids.
three_to_one = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
    'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
    'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V',
}

# One-letter codes in alphabetical order; the position in this list is the
# index used by idx_to_aa below.
amino_acids = list(three_to_one.values())
amino_acids.sort()

# index -> one-letter code
idx_to_aa = dict(enumerate(amino_acids))
idx_to_aa

In [None]:
# Build one graph object per (Free PDB, Bound PDB) row of df_parsed, labelled
# with the row's integer motion type.
# NOTE(review): threshold=5.0 is presumably a distance cutoff for edges;
# confirm units/meaning against ProteinPairGraphBuilder.
builder = ProteinPairGraphBuilder(
    threshold=5.0,
    data_path='data/PSCDB/PDB_structures',
)
my_dataset = []

progress = tqdm(df_parsed.iterrows(), total=len(df_parsed))
for _, row in progress:
    free_pdb, bound_pdb, label = row['Free PDB'], row['Bound PDB'], row['motion_type_int']
    data = builder.build_graph_pair(free_pdb, bound_pdb, label)
    if not data:
        # Falsy result: nothing usable was returned for this pair, so skip it.
        continue
    my_dataset.append(data)

print(f"\n\nBuilt {len(my_dataset)} valid graph pairs")

In [None]:
# Inspect the feature vector of the first node of the first graph.
# Column ranges below follow the labels printed by this cell (0-19 one-hot
# residue type, 20-22 free coords, 23-25 bound coords, 26-28 displacement);
# the layout itself is defined by the graph builder.
first_node = my_dataset[0].x[0]
one_hot = first_node[:20]

print(f"One-hot Amino Acid: \n{one_hot}")
print(f"Amino Acid Label: {idx_to_aa[one_hot.argmax().item()]}")
print(f"Free Structure Amino Acid Coords (x, y, z):   {first_node[20:23]}")
print(f"Bound Structure Amino Acid Coords (x, y, z):  {first_node[23:26]}")
print(f"Displacement Coords (Bound - Free) (x, y, z): {first_node[26:29]}")


In [None]:
# Sanity check: every node must have exactly one amino-acid type set in the
# first 20 feature columns (the one-hot block).
# The check is done per node. Comparing only the total of the block with the
# node count would miss a graph in which one node has no residue set and
# another has two, because the two errors cancel out in the sum.
for i, graph in enumerate(my_dataset):
    residues_per_node = graph.x[:, :20].sum(1)  # one count per node (row)
    if not (residues_per_node == 1).all():
        print(f"Missing amino acid(s) in data point {i}")

In [None]:
# Preview the first five graph objects.
my_dataset[:5]

In [None]:
# Serialize the list of graph objects to disk.
torch.save(my_dataset, 'hom_pscdb_graphs.pt')

In [None]:
# Round-trip check: reload the file written by torch.save above and show how
# many graphs it contains.
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python
# objects, which can execute code. That is acceptable here only because the
# file was produced by this notebook; do not reuse this call on files from
# untrusted sources.
dataset = torch.load('hom_pscdb_graphs.pt', weights_only=False)
len(dataset)