In [1]:
import networkx as nx
import glob
import MDAnalysis as mda
import pandas as pd
import numpy as np
import pickle
import torch
import shutil
import os
import warnings

# Suppress element warning - pdb_share files do not have element information, that's ok!
warnings.filterwarnings("ignore", message="Element information is missing, elements attribute will not be populated.")

from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from torch_geometric.data import Data
from sklearn.model_selection import train_test_split

ModuleNotFoundError: No module named 'h5py'

In [None]:
# Load data
# `df` is read from cleaned_data.csv; later cells use its 'cath_id', 'sequences', 'label' and 'superfamily' columns.
# `old_df` is read from the shared CATH table, with the first CSV column used as the index;
# later cells use its 'cath_id' and 'sequences' columns to detect unchanged sequences.
df = pd.read_csv('cleaned_data.csv')
old_df = pd.read_csv('../data/cath_w_seqs_share.csv', index_col=0)

### Atomic Resolution Molecular Graph Features

**Node Features**
- x, y, z atomic coordinates
- atom type, one-hot encoded over [N, CA, C, O, C_alt], where C_alt is a sidechain carbon

**Edge Features**
- mean bond angle over all bonds in the local proximity of each atom
- bond distance

In [None]:
# Distance between two bonded atoms
def bond_length(atom1, atom2):
    """Return the Euclidean distance between the positions of two atoms.

    Parameters
    ----------
    atom1, atom2 : objects exposing a ``position`` array
        Only ``.position`` is read; the two positions are subtracted.

    Returns
    -------
    float
        Norm of the displacement vector between the two positions.
    """
    displacement = atom1.position - atom2.position
    return np.linalg.norm(displacement)

# Create a function to calculate the bond angle
def bond_angle(atom1, atom2, atom3):
    """Return the angle atom1-atom2-atom3 in degrees, with atom2 at the vertex.

    Parameters
    ----------
    atom1, atom2, atom3 : objects exposing a ``position`` array
        Only ``.position`` is read.

    Returns
    -------
    float
        Angle in degrees in [0, 180]. NaN if either arm has zero length
        (the cosine is then 0/0).
    """
    vec1 = atom1.position - atom2.position
    vec2 = atom3.position - atom2.position
    cosine_angle = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
    # Floating-point rounding can push the cosine marginally outside [-1, 1]
    # for (anti)parallel arms, which would make arccos return NaN.
    cosine_angle = np.clip(cosine_angle, -1.0, 1.0)
    return np.degrees(np.arccos(cosine_angle))

def recalculate_distance(x1, y1, z1, x2, y2, z2):
    """Return the Euclidean distance between (x1, y1, z1) and (x2, y2, z2)."""
    start = np.array([x1, y1, z1])
    end = np.array([x2, y2, z2])
    return np.linalg.norm(end - start)

def remove_hydrogen_atoms(input_file, output_file):
    """Copy a PDB file, dropping every ATOM record that describes a hydrogen.

    Lines that do not start with 'ATOM' (HETATM, TER, REMARK, END, ...) are
    copied unchanged.

    Parameters
    ----------
    input_file : str
        Path of the PDB file to read.
    output_file : str
        Path of the PDB file to write (overwritten if it exists).
    """
    with open(input_file, 'r') as infile, open(output_file, 'w') as outfile:
        for line in infile:
            if line.startswith('ATOM') and _is_hydrogen_record(line):
                continue
            outfile.write(line)


def _is_hydrogen_record(line):
    """Return True when a PDB ATOM line describes a hydrogen atom.

    PDB is a fixed-column format, so fields are sliced by column rather than
    split on whitespace: adjacent fields may touch (e.g. '-100.000-100.000')
    or be blank (chain ID), which shifts whitespace-split token positions.
    """
    # Element symbol lives in columns 77-78 (right-justified).
    element = line[76:78].strip().upper()
    if element:
        return element == 'H'
    # No element column: fall back to the atom name (columns 13-16). Names may
    # carry a leading digit (e.g. '1HB'); a name that then starts with 'H' is
    # taken to be a hydrogen. NOTE(review): this assumes protein ATOM records.
    atom_name = line[12:16].strip().upper()
    return atom_name.lstrip('0123456789').startswith('H')

### Select all modified_pdbs

In [None]:
# Gather AlphaFold2 models - Current state is relaxed. AMBER is used to perform energy minimization
# glob returns matches in arbitrary order; sorting makes every list derived from it
# (and therefore the order of the generated graphs) reproducible across runs.
alpha_fold_models = sorted(glob.glob('../data/*/pdb/*_alphafold_remodeled_relaxed.pdb'))

# Deprotonate AlphaFold2 models for consistency. In the long run, this is not ideal, and protonated structures would allow hbond modeling
deprotonated_models = []
for model in alpha_fold_models:
    # Write '<name>_deprotonated<ext>' next to the source file
    dir_name, file_name = os.path.split(model)
    base_name, ext = os.path.splitext(file_name)
    output_file = os.path.join(dir_name, f"{base_name}_deprotonated{ext}")

    # Process the file
    remove_hydrogen_atoms(model, output_file)
    deprotonated_models.append(output_file)

# Show a count and a short preview instead of dumping the whole list
print(f"Deprotonated {len(deprotonated_models)} AlphaFold2 models")
print(deprotonated_models[:5])

### Select all original_pdbs where the new_seq is identical to the old_seq
- These pdbs do not contain gap residues and did not have to be remodeled

In [None]:
# Collect structures whose sequence is unchanged between the original and cleaned tables.
no_seq_change_models = []
for cath_id, new_seq in zip(df['cath_id'], df['sequences']):
    old_seq = old_df.loc[old_df['cath_id'] == cath_id, 'sequences'].values[0]
    if new_seq != old_seq:
        continue

    # The original structure file is stored without an extension
    original_file_path = glob.glob(f'../data/{cath_id}/pdb/{cath_id}')[0]

    # Copy it alongside itself with a .pdb extension (overwrites any existing copy)
    new_file_path = os.path.join(os.path.dirname(original_file_path), f"{cath_id}.pdb")
    shutil.copy2(original_file_path, new_file_path)

    no_seq_change_models.append(new_file_path)

In [None]:
# Spot check: is the sequence for domain 2w3sB01 the same in `df` and `old_df`?
# `.item()` raises unless exactly one row matches in each table.
df[df['cath_id'] == '2w3sB01']['sequences'].item() == old_df[old_df['cath_id'] == '2w3sB01']['sequences'].item()

In [None]:
# Preview the first 15 copied .pdb paths
no_seq_change_models[:15]

In [None]:
# Build graphs from the hydrogen-stripped AlphaFold2 models, not the protonated
# originals: the deprotonated copies were written for exactly this purpose, and
# the node encoding has no hydrogen atom type.
all_files = deprotonated_models + no_seq_change_models

### Production

In [None]:
# Atom types used for the one-hot node encoding
atom_types = ['N', 'CA', 'C', 'O', 'C_alt'] # 'C_alt' covers sidechain Carbon atoms

data_list = [] # This is where the final PyTorch Geometric graphs will be stored


def _atom_label(atom):
    """Return the graph node label '<resname><resid>.<segid>.<atomname>' for an atom."""
    return f'{atom.resname}{atom.resid}.{atom.segid}.{atom.name}'


def _atom_type(atom_name):
    """Map an atom name to the category that is one-hot encoded against `atom_types`.

    The first character of the name is used, except that carbons are split
    into 'CA', backbone 'C', and 'C_alt' (any other name starting with 'C').
    Categories absent from `atom_types` encode as an all-zero vector.
    """
    element = atom_name[0]
    if element == 'C':
        if atom_name == 'CA':
            return 'CA'
        if atom_name != 'C':
            return 'C_alt'
    return element


with open('graph_dataset_output.txt', 'w') as output_file: # Open a file for all output to reduce I/O burden
    for struct in tqdm(all_files):
        # Paths look like ../data/<cath_id>/pdb/<file>; the CATH ID is the
        # directory two levels above the file.
        cath_id = os.path.basename(os.path.dirname(os.path.dirname(struct)))

        # NetworkX Graph
        G = nx.Graph()

        # Create a Universe Object with guess_bonds enabled
        universe = mda.Universe(struct, guess_bonds=True)

        print(f'[TOPOLOGY] Setting node features for CATH_ID {cath_id}', file=output_file)
        # One node per atom, visited residue by residue
        for residue in universe.residues:
            for atom in residue.atoms:
                atom_type = _atom_type(atom.name)
                one_hot_element = [1 if candidate == atom_type else 0 for candidate in atom_types]

                # Coordinates are stored as strings in the human-readable graph
                node_attributes = {
                    'element': one_hot_element,
                    'x': str(atom.position[0]),
                    'y': str(atom.position[1]),
                    'z': str(atom.position[2]),
                }
                G.add_node(_atom_label(atom), **node_attributes)

        print(f'[TOPOLOGY] Setting edge features for CATH_ID {cath_id}',file=output_file)
        for bond in universe.bonds:
            # Parse out interacting atoms in this bond
            atom1, atom2 = bond

            # 1) Distance
            dist = bond_length(atom1, atom2)

            # 2) Mean bond angle: every angle this bond forms with another bond,
            #    first with atom2 at the vertex, then with atom1 at the vertex.
            bond_angles = []
            for end_atom, center_atom in ((atom1, atom2), (atom2, atom1)):
                for third_atom in center_atom.bonded_atoms:
                    if third_atom != end_atom:
                        bond_angles.append(bond_angle(end_atom, center_atom, third_atom))

            # A bond whose atoms have no other neighbors has no angle; keep NaN
            # (what np.mean of an empty list yields) without the RuntimeWarning.
            mean_bond_angle = np.mean(bond_angles) if bond_angles else float('nan')

            edge_attributes = {
                'distance': str(dist),
                'bond_angle': str(mean_bond_angle),
            }

            # Edge endpoints use the same labels as the nodes created above
            G.add_edge(_atom_label(atom1), _atom_label(atom2), **edge_attributes)

        # Save NetworkX graph with untransformed features to ensure human interpretability
        print(f'[Checkpoint] Saving NetworkX graph for CATH_ID {cath_id}', file=output_file)
        os.makedirs(f'../data/{cath_id}/networkx', exist_ok=True) # The output directory may not exist yet
        gml_fileName = f'../data/{cath_id}/networkx/{cath_id}_graph.gml'
        pkl_fileName = f'../data/{cath_id}/networkx/{cath_id}_graph.pkl' # Faster I/O for .pkl
        nx.write_gml(G, gml_fileName)
        with open(pkl_fileName, 'wb') as f:
            pickle.dump(G, f)

        #####################GRAPH LEVEL OPERATIONS#############################

        print(f'[SE3] Transforming geometric feature values for CATH_ID {cath_id}', file=output_file)
        # Collect Coordinates (in node order)
        coordinates = [
            [float(attributes['x']), float(attributes['y']), float(attributes['z'])]
            for _, attributes in G.nodes(data=True)
        ]

        # Rotation Invariance of Atomic Coordinates
        #
        # Each molecular structure is aligned to its own principal axes through PCA,
        # generating an orthogonal matrix that aligns it to a standard reference frame.

        # Step 1: Compute the principal axes (eigenvectors) that represent the directions of maximum variance in the distribution of atoms.
        pca = PCA(n_components=3) # X, Y, Z dimensions
        pca.fit(coordinates)

        # Columns are the principal axes.
        # NOTE(review): the determinant of this matrix is not checked; if it is -1
        # the transform includes a reflection (mirror image) rather than a pure rotation.
        rotation_matrix = pca.components_.T

        # Step 2: Apply rotation matrix to align coordinates to a canonical orientation
        transformed_coordinates = np.dot(coordinates, rotation_matrix)

        # Translation Invariance of Atomic Coordinates
        #
        # Enhances the robustness of the model to arbitrary positional variations,
        # which otherwise describe the same structure.

        # Calculate centroid (mean) of coordinates
        centroid = np.mean(transformed_coordinates, axis=0)

        # Center coordinates about origin
        centered_coordinates = transformed_coordinates - centroid

        # Scale each axis independently to [-1, 1]
        scaler_coordinates = MinMaxScaler(feature_range=(-1, 1))
        normalized_coordinates = scaler_coordinates.fit_transform(centered_coordinates)

        # Min-max scale this structure's bond angles to [0, 1]
        # (scaled by the observed min/max of the structure, not by fixed 0-180 bounds)
        bond_angles = [[float(attributes['bond_angle'])] for _, _, attributes in G.edges.data()]
        scaler_bond_angles = MinMaxScaler()
        normalized_angles = scaler_bond_angles.fit_transform(bond_angles)

        # Convert to PyTorch Geometric Object
        print(f'[PyTorch] Converting NetworkX graph for CATH_ID {cath_id} to PyTorch Geometric\n', file=output_file)

        # `.item()` raises unless exactly one row of `df` matches this CATH ID
        record = df[df['cath_id'] == cath_id]
        sequence = record['sequences'].item()
        target = int(record['label'].item())
        node_labels = list(G.nodes)
        # Label -> row index lookup (replaces a linear list.index() search per edge)
        node_index = {label: position for position, label in enumerate(node_labels)}

        # Extract node features: one-hot atom type followed by the
        # rotated, translated, and [-1, 1]-scaled X, Y, Z coordinates
        node_features = []
        for j, (_, attributes) in enumerate(G.nodes(data=True)):
            feature = [
                *attributes['element'], # Unpack the 1-hot encoded list
                normalized_coordinates[j][0], # X
                normalized_coordinates[j][1], # Y
                normalized_coordinates[j][2]  # Z
            ]
            node_features.append(feature)

        # Extract edge features
        edge_features = []
        for k, (src, dst, _) in enumerate(G.edges.data()):
            # Recalculate bond distances for the [-1, 1] scaled coordinate space
            x1, y1, z1 = normalized_coordinates[node_index[src]]
            x2, y2, z2 = normalized_coordinates[node_index[dst]]

            feature = [
                recalculate_distance(x1, y1, z1, x2, y2, z2),
                normalized_angles[k][0]
            ]
            edge_features.append(feature)

        # Data Object Instantiation
        data = Data(
                    cath_id = [cath_id],
                    seq = [sequence],
                    node_labels = node_labels,
                    node_feat = torch.Tensor(node_features),
                    edge_feat = torch.Tensor(edge_features),
                    adj_matrix = torch.Tensor(nx.to_numpy_array(G)),   # Dense Adjacency Matrix
                    target = torch.tensor(target)                      # Target class for this cath_id
                    )

        # Append the completed PyTorch Geometric Object
        data_list.append(data)

In [None]:
# Preview the first 10 PyTorch Geometric objects
data_list[:10]

## Stratified Data Split
- Group the data by superfamily, keeping all proteins from the same superfamily together to prevent splitting related proteins across the train and test sets.
- Perform a stratified split to maintain the proportion of each superfamily in both the training and testing sets, ensuring a representative distribution.
- Within each superfamily, randomly select proteins to assign to the training and testing sets.
- Use a common split ratio, such as 80% for training and 20% for testing
- Ensure that small superfamilies, with very few members, are represented in both sets if possible. We want both our training and testing sets to be representative of the overall dataset. Including small superfamilies in both sets helps maintain this representation. If we only include large superfamilies in the test set, we might overfit to common superfamilies and misjudge model performance on more rare superfamilies

In [None]:
# Number of rows (domains) per superfamily in `df`
superfamily_sizes = df['superfamily'].value_counts()
superfamily_sizes

In [None]:
def stratified_split_by_superfamily(df, test_size=0.2, small_family_threshold=4, random_state=42):
    """Split `df` into train and test sets by superfamily.

    Superfamilies with at least `small_family_threshold` rows ("large") are
    assigned whole to either train or test. Smaller superfamilies are split
    row-wise with `train_test_split`; a superfamily with a single row always
    goes to train.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a 'superfamily' column.
    test_size : float
        Fraction of large superfamilies (and of rows within each small
        superfamily) assigned to the test set.
    small_family_threshold : int
        Superfamilies with fewer rows than this are treated as small.
    random_state : int or None
        Seed for the shuffles and for `train_test_split`.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        The train and test frames. Summary statistics are printed as a side effect.
    """
    # Private generator: yields the same stream as np.random.seed(random_state)
    # followed by np.random.shuffle, without mutating NumPy's global RNG state.
    rng = np.random.RandomState(random_state)

    # Group by superfamily and get sizes
    superfamily_sizes = df['superfamily'].value_counts()

    # Separate small and large superfamilies
    small_families = superfamily_sizes[superfamily_sizes < small_family_threshold].index.tolist()
    large_families = superfamily_sizes[superfamily_sizes >= small_family_threshold].index.tolist()

    # Shuffle both lists (small first, to keep the random stream order stable)
    rng.shuffle(small_families)
    rng.shuffle(large_families)

    # Split large families
    split_point = int(len(large_families) * (1 - test_size))
    train_large_names = large_families[:split_point]
    test_large_names = large_families[split_point:]

    # Collect the pieces and concatenate once at the end (avoids quadratic
    # re-copying from calling pd.concat inside the loop).
    train_parts = [df[df['superfamily'].isin(train_large_names)]]
    test_parts = [df[df['superfamily'].isin(test_large_names)]]

    # Handle small families
    for family in small_families:
        family_data = df[df['superfamily'] == family]
        if len(family_data) == 1:
            # If only one instance, always put it in the training set
            train_parts.append(family_data)
        else:
            family_train, family_test = train_test_split(family_data, test_size=test_size, random_state=random_state)
            train_parts.append(family_train)
            test_parts.append(family_test)

    train_df = pd.concat(train_parts)
    test_df = pd.concat(test_parts)

    print(f"Train set size: {len(train_df)}")
    print(f"Test set size: {len(test_df)}")
    print(f"Actual test size: {len(test_df) / len(df):.2f}")
    print(f"Small families in train: {train_df['superfamily'].isin(small_families).sum()}")
    print(f"Small families in test: {test_df['superfamily'].isin(small_families).sum()}")

    return train_df, test_df

In [None]:
# Usage
# Split with the function's defaults (test_size=0.2, small_family_threshold=4, random_state=42)
train_df, test_df = stratified_split_by_superfamily(df)

# Save the split datasets
# index=False leaves the DataFrame index out of the CSV files
train_df.to_csv('../models/train_data.csv', index=False)
test_df.to_csv('../models/test_data.csv', index=False)

## Saving the PyG Objects Allows Preprocessing to be Performed Only Once
- Afterwards, you simply reload data from disk

In [None]:
# Partition PyTorch Geometric Objects by train_df and test_df and Save
# Sets give O(1) membership checks. The test IDs must come from test_df:
# taking them from train_df would leave the test partition empty.
train_cath_ids = set(train_df['cath_id'])
test_cath_ids = set(test_df['cath_id'])

train_file = '../models/train_compressed.pt'
test_file = '../models/test_compressed.pt'

train_data = []
test_data = []

for pyg_obj in data_list:
    # Looks like: Data(cath_id=[1], seq=[1], node_labels=[3731], node_feat=[3731, 8], edge_feat=[3762, 2], adj_matrix=[3731, 3731], target=6)
    # Store plain dicts of the fields rather than the Data objects themselves
    obj_dict = {
            'cath_id': pyg_obj.cath_id,
            'seq': pyg_obj.seq,
            'node_labels': pyg_obj.node_labels,
            'node_feat': pyg_obj.node_feat,
            'edge_feat': pyg_obj.edge_feat,
            'adj_matrix': pyg_obj.adj_matrix,
            'target': pyg_obj.target
        }

    # cath_id is stored as a one-element list
    graph_cath_id = pyg_obj.cath_id[0]
    if graph_cath_id in train_cath_ids:
        train_data.append(obj_dict)

    elif graph_cath_id in test_cath_ids:
        test_data.append(obj_dict)

# Save
torch.save(train_data, train_file, _use_new_zipfile_serialization=True)
torch.save(test_data, test_file, _use_new_zipfile_serialization=True)

In [None]:
def load_pyg_list(file_path):
    """Load a saved list of field dicts and rebuild one `Data` object per entry.

    Parameters
    ----------
    file_path : str
        Path to a file written with `torch.save` containing a list of dicts
        with keys 'cath_id', 'seq', 'node_labels', 'node_feat', 'edge_feat',
        'adj_matrix' and 'target'.

    Returns
    -------
    list
        The rebuilt `Data` objects, in file order. The count is printed.
    """
    field_names = ('cath_id', 'seq', 'node_labels', 'node_feat', 'edge_feat', 'adj_matrix', 'target')
    data_list = [
        Data(**{name: record[name] for name in field_names})
        for record in torch.load(file_path)
    ]

    print(f"Loaded {len(data_list)} PyG objects from {file_path}")
    return data_list

# Round-trip check of the loader.
# NOTE(review): the variable is named `test_list` but this loads the *train* file —
# confirm whether '../models/test_compressed.pt' was intended.
test_list = load_pyg_list('../models/train_compressed.pt')
test_list