In [None]:
import glob
import os
import torch

In [None]:
# Directory path containing .pt files
directory_path = '../../models/problematic_batches/'

# Initialize an empty list to store data from .pt files
data_list = []

# Use glob to find all .pt files in the directory
pt_files = glob.glob(directory_path + '*.pt')

# Iterate through each .pt file found
for file_path in pt_files:
    # Load the data from the .pt file using torch.load() and append to data_list
    data = torch.load(file_path)
    data_list.append(data)

In [None]:
data_list

In [None]:
data_list[0].ptr

In [None]:
data_list[0].batch

In [None]:
for batch in data_list:
    if torch.isnan(batch.x).any():
        nan_indices = torch.isnan(batch.x)
        print(f'{batch.cath_id} has NaN node feats:')
        print(batch.x[nan_indices])

    if torch.isnan(batch.edge_attr).any():
        nan_indices = torch.isnan(batch.edge_attr)
        print(f'{batch.cath_id} has NaN edge feats:')
        print(batch.edge_attr[nan_indices])

    if torch.isnan(batch.edge_index).any():
        nan_indices = torch.isnan(batch.edge_index)
        print(f'{batch.cath_id} has NaN edge indices:')
        print(batch.edge_index[nan_indices])

In [None]:
data_list[0].edge_attr[0]

In [2]:
import MDAnalysis as mda
import numpy as np
import warnings
warnings.filterwarnings("ignore", message="Element information is missing, elements attribute will not be populated.")

# Create a function to calculate the bond length
def bond_length(atom1, atom2):
    return np.linalg.norm(atom1.position - atom2.position)

# Create a function to calculate the bond angle
def bond_angle(atom1, atom2, atom3):
    vec1 = atom1.position - atom2.position
    vec2 = atom3.position - atom2.position
    cosine_angle = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
    angle = np.arccos(cosine_angle)
    return np.degrees(angle)

for cath_id in ['1vd6A00', '2phcB02', '3n29B01', '5dl7A00']:
    struct = f'../../data/{cath_id}/pdb/{cath_id}.pdb'
    u = mda.Universe(struct, guess_bonds=True)
    for i, bond in enumerate(u.bonds):
        # Parse out interacting atoms in this bond
        atom1, atom2 = bond
        # print(i, atom1, atom2)
        
    
        # Determine edge features
        '''1) Distance'''
        dist = bond_length(atom1, atom2)

       
        '''2) Mean bond angle'''
        bond_angles = []
    
        neighbors_atom2 = atom2.bonded_atoms
        for atom3 in neighbors_atom2:
            if atom3 != atom1:
                angle_deg = bond_angle(atom1, atom2, atom3)
                bond_angles.append(angle_deg)
        # print(bond_angles)
    
        # switch roles to include all angles relevant to the connection
        temp = atom1
        atom1 = atom2
        atom2 = temp
    
        # and append new calculations...
        neighbors_atom2 = atom2.bonded_atoms
        for atom3 in neighbors_atom2:
            if atom3 != atom1:
                angle_deg = bond_angle(atom1, atom2, atom3)
                bond_angles.append(angle_deg)
     
        # Catch RuntimeWarnings during mean calculation
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always", RuntimeWarning)
            mean_bond_angle = np.mean(bond_angles)
            if len(w) > 0:
                print(f"An error occurred when calculating bond angles for [{cath_id}]:\n Atom1 - {atom1}\n Atom2 - {atom2}")
                print(f"Setting `mean_bond_angle` to 0 for this edge")
                print(f"Warnings:")
                mean_bond_angle = 0
                for warning in w:
                    print(f'{warning.message}')
                        
        edge_attributes = {
                    'distance' : str(dist),
                    'bond_angle' : str(mean_bond_angle)}

        if len(bond_angles) == 0:
            print(edge_attributes)
            print()
            continue


An error occurred when calculating bond angles for [1vd6A00]:
 Atom1 - <Atom 769: CZ of type C of resname PHE, resid 104 and segid A and altLoc A>
 Atom2 - <Atom 768: CE2 of type C of resname PHE, resid 104 and segid A and altLoc A>

Setting `mean_bond_angle` to 0

Mean of empty slice.

invalid value encountered in scalar divide

1vd6A00
785 <Atom 769: CZ of type C of resname PHE, resid 104 and segid A and altLoc A> <Atom 768: CE2 of type C of resname PHE, resid 104 and segid A and altLoc A>
1.3932604
[]
0
{'distance': '1.3932604', 'bond_angle': '0'}

An error occurred when calculating bond angles for [2phcB02]:
 Atom1 - <Atom 276: NZ of type N of resname LYS, resid 120 and segid B and altLoc >
 Atom2 - <Atom 275: CE of type C of resname LYS, resid 120 and segid B and altLoc >

Setting `mean_bond_angle` to 0

Mean of empty slice.

invalid value encountered in scalar divide

2phcB02
281 <Atom 276: NZ of type N of resname LYS, resid 120 and segid B and altLoc > <Atom 275: CE of type C of

In [None]:
def load_data_from_partitions(base_path):
    partition_files = glob.glob(os.path.join(base_path, '*.pt'))
    data_list = []
    for partition_file in partition_files:
        partition_data = torch.load(partition_file)
        data_list.extend(partition_data)
    
    print(f"Loaded {len(data_list)} PyG objects from {base_path}")
    return data_list

In [None]:
test_list = load_data_from_partitions('../../models/test_partitions')
test_list

In [None]:
data = test_list[0]
data

In [None]:
edge_feat=[616, 2]

In [None]:
from torch_geometric.utils import dense_to_sparse,  to_undirected
edge_index, _ = dense_to_sparse(data.adj_matrix)
edge_index

In [None]:
edge_index.shape

In [None]:
edge_index = to_undirected(edge_index)
edge_index

In [None]:
edge_index.shape