In [17]:
import MDAnalysis as mda
import pandas as pd
from biopandas.pdb import PandasPdb
import os
import glob
import re
import math
import numpy as np
from rdkit import Chem
from scipy.spatial.transform import Rotation as R

def grid_list(atom_df):
    return list(zip(atom_df['x_coord'], atom_df['y_coord'], atom_df['z_coord']))

def filtering_proteins(atom_df, grid_list, radius=5.0):
    atom_coords = atom_df[['x_coord', 'y_coord', 'z_coord']].values
    filtered_atoms = set()

    for x, y, z in grid_list:
        distances_sq = (atom_coords[:, 0] - x)**2 + (atom_coords[:, 1] - y)**2 + (atom_coords[:, 2] - z)**2
        mask = distances_sq <= radius**2
        filtered_atoms.update(atom_df.index[mask])

    print(f"Total atoms within {radius} Ã… cutoff: {len(filtered_atoms)}")
    return atom_df.loc[list(filtered_atoms)]


In [18]:
def get_protein_name(filename):
    basename = os.path.basename(filename)  # Get file name without path
    match = re.match(r'([a-zA-Z0-9]{4})', basename)  # Match the first 4-character PDB ID
    if match:
        return match.group(1).upper()
    else:
        return None
def get_mode_index(filename):
    basename = os.path.basename(filename)
    match = re.search(r'mode_(\d+)', basename)
    if match:
        return int(match.group(1))
    else:
        return None  # or raise ValueError("No mode index found.")

def natural_sort_key(s):
    """Function to sort strings in a natural alphanumeric order."""
    return [int(text) if text.isdigit() else text.lower() for text in re.split(r'(\d+)', s)]


In [19]:
def create_grid(size=20, resolution=1):
    num_cells = int(size * resolution)
    grid = np.zeros((num_cells, num_cells, num_cells, 23))  # 23 features per grid point
    return grid

# Function to apply 3D rotation to atomic coordinates
def rotate_molecule(mol_to_rot, rotation_matrix):
    
    conf = mol_to_rot.GetConformer()
    for atom_idx in range(mol_to_rot.GetNumAtoms()):
        pos = conf.GetAtomPosition(atom_idx)
        new_pos = np.dot(rotation_matrix, np.array([pos.x, pos.y, pos.z]))
        conf.SetAtomPosition(atom_idx, new_pos)
    return mol_to_rot

# Generate a random rotation matrix
def generate_random_rotation_matrix():
    # Generate a random 3D rotation using Euler angles
    rotation = R.from_euler('xyz', np.random.uniform(0, 360, size=3), degrees=True)
    return rotation.as_matrix()

# Function to encode atomic features (same as before)
atom_types = {'C': 0, 'N': 1, 'O': 2, 'S': 3, 'other': 4}

def encode_atom_features(atom):
    features = np.zeros(14)
   # One-hot encoding for atom types
    atom_symbol = atom.GetSymbol()
    # if atom_symbol == 'H':
    #     print(atom_symbol, " is atom symbol")
    if atom_symbol in atom_types:
        features[atom_types[atom_symbol]] = 1
    else:
        features[atom_types['other']] = 1
    
    hybridization = atom.GetHybridization()
    print(hybridization, "is atom hybridization type")
    if hybridization == Chem.HybridizationType.SP:
        features[5] = 1
    elif hybridization == Chem.HybridizationType.SP2:
        features[6] = 1
    elif hybridization == Chem.HybridizationType.SP3:
        features[7] = 1

    for neighbor in atom.GetNeighbors():
        print("Neighbor:", neighbor.GetSymbol(), "AtomicNum:", neighbor.GetAtomicNum())

    num_heavy_atoms = sum(1 for neighbor in atom.GetNeighbors() if neighbor.GetAtomicNum() > 1)
    print(num_heavy_atoms, "is number of heavy atoms")
    features[8] = num_heavy_atoms    
   
    # Number of bonded hetero atoms (atoms other than carbon and hydrogen)
    num_hetero_atoms = sum(1 for neighbor in atom.GetNeighbors() if neighbor.GetAtomicNum() not in {1, 6})
    print(num_hetero_atoms, "is number of hetero atoms")
    features[9] = num_hetero_atoms
    features[10] = 1 if atom.GetIsAromatic() else 0

    # formal charge, 0 is no charge, 1 is negative, and 2 is positive charge
    residue = atom.GetPDBResidueInfo().GetResidueName().strip() 
    atom_name = atom.GetPDBResidueInfo().GetName().strip()

    if atom.GetFormalCharge() == 1: # setting positive charge to 2
        atom.SetFormalCharge(2) 

    if residue == "ASP" and atom_name == "CG": 
        atom.SetFormalCharge(1) # setting to negative charge, CG has summed up charge of OD1 and OD2
    if residue == "GLU" and atom_name == "CD": 
        atom.SetFormalCharge(1) # setting to negative charge, CD has summed up charge of OE1 and OE2

    features[11] = 1 if atom.GetFormalCharge() != 0 else 0 # binary label, charge or no charge
   
    features[12] = atom.GetFormalCharge()

    print(features[11], "charge sign", features[12], "is charge present")
    # if atom.GetFormalCharge() != 0:
    #     print(atom.GetFormalCharge(), "is atom that has formal charge and", atom_name, "is atom name")
    
    features[13] = 1 if atom.IsInRing() else 0
    
    #print(features)
    return features

# Function to perform one-hot encoding for residue types
def encode_residue_type(residue):
    features = np.zeros(9)
    if residue in ['ASP', 'GLU']:
        features[0] = 1
    elif residue in ['LYS', 'ARG']:
        features[1] = 1
    elif residue == 'HIS':
        features[2] = 1
    elif residue == 'CYS':
        features[3] = 1
    elif residue in ['ASN', 'GLN', 'SER', 'THR']:
        features[4] = 1
    elif residue == 'GLY':
        features[5] = 1
    elif residue == 'PRO':
        features[6] = 1
    elif residue in ['PHE', 'TYR', 'TRP']:
        features[7] = 1
    elif residue in ['ALA', 'ILE', 'LEU', 'MET', 'VAL']:
        features[8] = 1
    return features

# Map atoms to the grid based on their 3D coordinates
def map_atoms_to_grid(mol, grid, grid_center, grid_size=20, resolution=1):
    conf = mol.GetConformer()

    # Compute bounds for min max normalization
    all_positions = np.array([[pos.x, pos.y, pos.z] for pos in [conf.GetAtomPosition(atom_idx) for atom_idx in range(mol.GetNumAtoms())]])
    min_coords = np.min(all_positions, axis=0)
    max_coords = np.max(all_positions, axis=0)
    scale = max_coords - min_coords
    print(scale, "is scale")

    # Apply min-max normalization to scale positions to [0, grid_size)
    def shift(pos, min_coords):
        return ((pos - min_coords))
    
    for atom in mol.GetAtoms(): 
        pos = conf.GetAtomPosition(atom.GetIdx())
        print(f"\n{atom.GetSymbol()} is atom symbol {atom.GetIdx()} is atom id")
        shifted_pos = shift(np.array([pos.x, pos.y, pos.z]), min_coords)

        # Map to grid coordinates
        grid_coord = np.rint(shifted_pos).astype(int)
        
        if np.all(grid_coord >= 0) and np.all(grid_coord < (grid_size * resolution)):
            atom_features = encode_atom_features(atom)

            residue = atom.GetPDBResidueInfo().GetResidueName()
            residue_features = encode_residue_type(residue)

            combined_features = np.concatenate((atom_features, residue_features))
            if np.any(grid[tuple(grid_coord)]):
                grid_coord = np.floor(shifted_pos).astype(int) # try flooring if rint doesn't work
                if np.any(grid[tuple(grid_coord)]):
                    grid_coord = np.ceil(shifted_pos).astype(int) # last ditch effort is to try ceiling if flooring fails
                    if np.any(grid[tuple(grid_coord)]):
                        print("Overwritten atoms")
                        raise Exception("Overwritten atoms!")
            grid[tuple(grid_coord)] = combined_features # print this part as well
        else:
            print("Atom didn't go in the grid")
            raise Exception("Atom out of bounds")

    return grid

def min_max_normalize(grid):
    min_val = np.min(grid)
    max_val = np.max(grid)
    
    if max_val - min_val == 0:
        return grid  # Avoid division by zero if all values are the same
    
    return (grid - min_val) / (max_val - min_val)

# Main function to generate multiple rotated grids
def generate_rotated_grids(grid_center, filtered_pdb_path, num_rotations=20, grid_size=30, resolution=1):
    mol = Chem.MolFromPDBFile(filtered_pdb_path, sanitize=True)
    
    if mol is None:
        return None
    
    grids = []
    
    for i in range(num_rotations):
        # Create a new grid
        grid = create_grid(size=grid_size, resolution=resolution)
        
        # Generate a random rotation matrix
        rotation_matrix = generate_random_rotation_matrix()
        
        # Rotate the molecule
        rotated_mol = rotate_molecule(mol, rotation_matrix)
        
        # Map rotated atoms to the grid
        grid = map_atoms_to_grid(rotated_mol, grid, grid_center, grid_size, resolution)

        # Apply Min-Max normalization
        grid = min_max_normalize(grid)
        
        # Store the rotated grid
        grids.append(grid)
    
    return grids
def saving_features(rotated_grids,output_path,protein_name_):
    os.makedirs(output_path, exist_ok=True)
    # Save each grid
    for idx, grid in enumerate(rotated_grids):
        np.save(f'{output_path}/{protein_name_}_grid_{idx}.npy', grid)
        print(f"Saved rotated grid {idx} successfully.")
    return

In [20]:
file = "filtered-pdbs-distinct-5A/positive/5AVU-filtered.pdb"

grid_center = np.array([0, 0, 0])  # Grid center at origin

# Generate rotated grids (5 rotations)
rotated_grids = generate_rotated_grids(grid_center, file, num_rotations=1)

[10.96713373 17.04131548 18.83188009] is scale

C is atom symbol 0 is atom id
SP3 is atom hybridization type
0 is number of heavy atoms
0 is number of hetero atoms
0.0 charge sign 0.0 is charge present

C is atom symbol 1 is atom id
SP3 is atom hybridization type
0 is number of heavy atoms
0 is number of hetero atoms
0.0 charge sign 0.0 is charge present

C is atom symbol 2 is atom id
SP3 is atom hybridization type
Neighbor: C AtomicNum: 6
Neighbor: C AtomicNum: 6
2 is number of heavy atoms
0 is number of hetero atoms
0.0 charge sign 0.0 is charge present

C is atom symbol 3 is atom id
SP2 is atom hybridization type
Neighbor: C AtomicNum: 6
Neighbor: O AtomicNum: 8
2 is number of heavy atoms
1 is number of hetero atoms
0.0 charge sign 0.0 is charge present

O is atom symbol 4 is atom id
SP2 is atom hybridization type
Neighbor: C AtomicNum: 6
1 is number of heavy atoms
0 is number of hetero atoms
0.0 charge sign 0.0 is charge present

C is atom symbol 5 is atom id
SP3 is atom hybridizat