In [1]:
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import rdMolTransforms
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import os
import sys
import nbimporter

# Make the project root importable so that the `datapreparation` package
# imported below can be found. NOTE(review): this assumes the notebook is run
# with its working directory one level below the project root — confirm.
project_root = os.path.join(os.getcwd(), '..')
sys.path.append(project_root)

from datapreparation.Process_graph_2d_data import *

# Generate a 3D conformation (ETKDG embedding) and minimize its energy (MMFF)

In [2]:
def generate_3d_coordinates(smiles, random_seed=None, max_iters=10000):
    """Embed a molecule in 3D (ETKDG) and relax it with the MMFF force field.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule.
    random_seed : int or None, default None
        Seed for the ETKDG embedding. ``None`` leaves the seed of the
        ``AllChem.ETKDG()`` parameters untouched (RDKit's default). Pass an
        int to make the generated conformer reproducible across runs.
    max_iters : int, default 10000
        Maximum number of MMFF optimization iterations.

    Returns
    -------
    tuple
        ``(mol_with_h, True)`` on success, where ``mol_with_h`` has explicit
        hydrogens and one optimized conformer. ``(None, False)`` in any of
        these cases:

        - ``smiles`` is not a string;
        - ``smiles`` does not parse, or parses to a molecule with no atoms;
        - the embedding fails;
        - MMFF optimization does not return 0 (not converged within
          ``max_iters``, or the force field could not be set up).
    """
    # Guard against non-string entries (e.g. None or a float NaN), which
    # would otherwise make MolFromSmiles raise instead of being skipped.
    if not isinstance(smiles, str):
        return None, False

    mol = Chem.MolFromSmiles(smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        return None, False

    mol_with_h = Chem.AddHs(mol)

    params = AllChem.ETKDG()
    if random_seed is not None:
        params.randomSeed = int(random_seed)
    if AllChem.EmbedMolecule(mol_with_h, params) != 0:
        return None, False

    if AllChem.MMFFOptimizeMolecule(mol_with_h, maxIters=max_iters) != 0:
        return None, False

    return mol_with_h, True

# Calculate the bond length (distance between the two bonded atoms)

In [3]:
def calculate_distance_for_bond(mol, start_atom, end_atom):
    """Return the distance between two atoms in the molecule's conformer.

    Parameters
    ----------
    mol : molecule with an embedded conformer (``mol.GetConformer()``)
    start_atom, end_atom : int
        Atom indices whose positions are compared.

    Returns
    -------
    The value of ``Point3D.Distance`` between the two atom positions.
    """
    conformer = mol.GetConformer()
    return conformer.GetAtomPosition(start_atom).Distance(conformer.GetAtomPosition(end_atom))

# Calculate the bond angles adjacent to a bond

In [4]:
def calculate_adjacent_angles_for_bond(mol, atom_idx1, atom_idx2):
    """Return the bond angles (degrees) that share the bond ``atom_idx1``-``atom_idx2``.

    Each end of the bond is used in turn as the pivot. Every neighbour of the
    pivot other than the two bonded atoms contributes the angle
    neighbour-pivot-partner, where the partner is the opposite end of the
    bond. Angles around ``atom_idx1`` come first, then those around
    ``atom_idx2``.

    Returns an empty list when neither atom has any other neighbour.
    """
    bond_atoms = (atom_idx1, atom_idx2)
    angles = []
    for pivot, partner in ((atom_idx1, atom_idx2), (atom_idx2, atom_idx1)):
        for neighbour in mol.GetAtomWithIdx(pivot).GetNeighbors():
            neighbour_idx = neighbour.GetIdx()
            if neighbour_idx in bond_atoms:
                continue
            angles.append(
                rdMolTransforms.GetAngleDeg(mol.GetConformer(), neighbour_idx, pivot, partner)
            )
    return angles


# Calculate the dihedral (torsion) angles around a bond

In [5]:
def calculate_dihedral_angles_for_bond(mol, start_atom, end_atom):
    """Return the dihedral (torsion) angles, in degrees, around a bond.

    For every neighbour ``sn`` of ``start_atom`` (other than ``end_atom``)
    and every neighbour ``en`` of ``end_atom`` (other than ``start_atom``),
    the dihedral sn-start_atom-end_atom-en is measured on the molecule's
    conformer.

    Pairs where ``sn`` and ``en`` are the same atom are skipped. This happens
    when the bond is part of a three-membered ring, and such a "torsion" is
    geometrically degenerate.

    Returns
    -------
    list
        The dihedral angles, or ``[0]`` as a placeholder when no dihedral can
        be formed (e.g. when either bonded atom has no other neighbour).
    """
    start_neighbors = [n.GetIdx() for n in mol.GetAtomWithIdx(start_atom).GetNeighbors() if n.GetIdx() != end_atom]
    end_neighbors = [n.GetIdx() for n in mol.GetAtomWithIdx(end_atom).GetNeighbors() if n.GetIdx() != start_atom]
    dihedral_angles = [
        rdMolTransforms.GetDihedralDeg(mol.GetConformer(), sn, start_atom, end_atom, en)
        for sn in start_neighbors
        for en in end_neighbors
        if sn != en  # shared neighbour in a 3-membered ring -> degenerate torsion
    ]
    return dihedral_angles if dihedral_angles else [0]


# Compute 3D bond features

In [6]:
def get_bond_3d_feature(mol, start, end, include_angles=False):
    """Build the 3D feature vector for the bond between atoms ``start`` and ``end``.

    Parameters
    ----------
    mol : molecule with an embedded conformer
    start, end : int
        Atom indices of the bond's two endpoints.
    include_angles : bool, default False
        If False (the default), only the bond length is returned. If True,
        the (mean, max, min) of the adjacent bond angles and the (mean, max,
        min) of the dihedral angles around the bond are appended. An empty
        angle list contributes ``[0, 0, 0]``, so the vector length is fixed.

    Returns
    -------
    numpy.ndarray of float32
        ``[distance]`` (shape ``(1,)``) by default, or
        ``[distance, adj_mean, adj_max, adj_min, dih_mean, dih_max, dih_min]``
        (shape ``(7,)``) when ``include_angles`` is True.
    """
    distance = calculate_distance_for_bond(mol, start, end)
    if not include_angles:
        return np.array([distance], dtype=np.float32)

    def _summary(angles):
        # (mean, max, min) of the angles, or zeros for an empty list.
        if not angles:
            return [0, 0, 0]
        return [np.mean(angles), np.max(angles), np.min(angles)]

    adjacent_angles = calculate_adjacent_angles_for_bond(mol, start, end)
    dihedral_angles = calculate_dihedral_angles_for_bond(mol, start, end)
    return np.array(
        [distance] + _summary(adjacent_angles) + _summary(dihedral_angles),
        dtype=np.float32,
    )



# Voxelize atom coordinates into a 3D occupancy grid

In [7]:
def voxelization(coords, resolution, grid_size):
    """Rasterize 3D points into a cubic binary occupancy grid centred on the origin.

    Parameters
    ----------
    coords : iterable of (x, y, z) points
        Cartesian coordinates, e.g. an array of shape ``(n_points, 3)``.
    resolution : float
        Voxels per unit length; each voxel has edge length ``1 / resolution``.
    grid_size : int
        Number of voxels along each axis. Along each axis the grid covers
        ``[-grid_size / (2 * resolution), grid_size / (2 * resolution))``.

    Returns
    -------
    numpy.ndarray of shape (grid_size, grid_size, grid_size), float64
        1.0 in every voxel that contains at least one point, 0.0 elsewhere.
        Points outside the grid are ignored.

    Raises
    ------
    ValueError
        If the points are not all 3-dimensional.
    """
    voxels = np.zeros((grid_size, grid_size, grid_size))
    points = np.array([tuple(point) for point in coords], dtype=float)
    if points.size == 0:
        return voxels
    if points.ndim != 2 or points.shape[1] != 3:
        raise ValueError("coords must be a sequence of (x, y, z) points")

    voxel_size = 1.0 / resolution
    center = grid_size / 2
    # Use np.floor rather than int(), which truncates toward zero. With int(),
    # points whose scaled coordinate lies in (-1, 0) (just below the grid's
    # lower edge) would be folded into voxel 0 instead of being dropped.
    scaled = np.floor(points / voxel_size + center)
    inside = np.all((scaled >= 0) & (scaled < grid_size), axis=1)
    ix, iy, iz = scaled[inside].astype(int).T
    voxels[ix, iy, iz] = 1
    return voxels

# Assemble 3D graph data (node features, edges, voxels, label) for each molecule

In [8]:
def preprocess_3d_graph(smiles_list, labels, atom_numbers, resolution=10, grid_size=20):
    """Convert SMILES strings into 3D graph samples with voxel grids.

    Each SMILES is embedded and optimized via ``generate_3d_coordinates``.
    Molecules for which that fails are skipped silently, so the output can be
    shorter than the input.

    Parameters
    ----------
    smiles_list : sequence of str
    labels : sequence of str
        Class labels aligned with ``smiles_list``, each ``'Negative'`` or
        ``'Positive'``. Any other value raises ``KeyError`` for a molecule
        that is embedded successfully.
    atom_numbers :
        Passed through unchanged to ``get_atom_features``.
    resolution, grid_size : default 10, 20
        Forwarded to ``voxelization``.
        NOTE(review): with the defaults the grid spans only
        grid_size / resolution = 2 length units (Å for RDKit conformers) per
        axis around the origin, so most atoms fall outside it and are
        dropped. Confirm this is intended.

    Returns
    -------
    list of dict, one per successfully embedded molecule, with keys:
        'nodes_features' : array of per-atom features (hydrogens included).
        'edge_index'     : int array of shape (2, num_bonds); one column per
                           bond, begin -> end only (no reverse edges).
        'edge_attr'      : float32 array with one row per bond.
        'voxels'         : grid returned by ``voxelization``.
        'label'          : 0 for 'Negative', 1 for 'Positive'.
    """
    label_map = {'Negative': 0, 'Positive': 1}
    graph_data_3d = []

    for smiles, label in tqdm(zip(smiles_list, labels), total=len(smiles_list), desc="get graph_data"):
        mol_with_h, success = generate_3d_coordinates(smiles)
        if not success:
            continue

        atoms = list(mol_with_h.GetAtoms())
        atoms_features = [get_atom_features(atom, atom_numbers) for atom in atoms]

        conformer = mol_with_h.GetConformer()
        coords = np.array([conformer.GetAtomPosition(atom.GetIdx()) for atom in atoms])
        voxels = voxelization(coords, resolution=resolution, grid_size=grid_size)

        edge_index_list = []
        edge_attr_list = []
        for bond in mol_with_h.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edge_attr_list.append(get_bond_3d_feature(mol_with_h, start, end))
            # Only begin -> end is stored. For an undirected graph, also append
            # [end, start] together with a duplicated feature row.
            edge_index_list.append([start, end])

        # reshape(-1, 2) keeps the (2, num_bonds) layout even for a molecule
        # without bonds; np.array([]).T would collapse to shape (0,).
        edge_index_array = np.array(edge_index_list, dtype=int).reshape(-1, 2).T
        edge_attr_array = np.array(edge_attr_list, dtype=np.float32)

        graph_data_3d.append({
            'nodes_features': np.array(atoms_features),
            'edge_index': edge_index_array,
            'edge_attr': edge_attr_array,
            'voxels': voxels,
            'label': label_map[label]
        })

    return graph_data_3d


In [9]:
def get_voxels_labels(graph_data_3d):
    """Stack per-sample voxel grids and labels into PyTorch tensors.

    Parameters
    ----------
    graph_data_3d : list of dict
        Each item must carry a ``'voxels'`` array and an integer ``'label'``.

    Returns
    -------
    (voxel_tensor, label_tensor)
        ``voxel_tensor`` is float32 with a singleton channel axis inserted at
        dim 1, i.e. ``(num_samples, 1, *grid_shape)``. ``label_tensor`` is
        int64 with shape ``(num_samples,)``.
    """
    voxel_grids = [sample['voxels'] for sample in graph_data_3d]
    class_labels = [sample['label'] for sample in graph_data_3d]

    voxel_tensor = torch.tensor(np.array(voxel_grids), dtype=torch.float)
    # Channel axis after the sample axis.
    voxel_tensor = voxel_tensor.unsqueeze(1)

    label_tensor = torch.tensor(np.array(class_labels), dtype=torch.long)
    return voxel_tensor, label_tensor

In [10]:
def load_3d_voxels_data(graph_data_3d, batch_size, test_size=0.2, random_state=42, stratify=False):
    """Build train/test DataLoaders from a single hold-out split of the voxel data.

    Parameters
    ----------
    graph_data_3d : list of dict
        Samples carrying ``'voxels'`` and ``'label'`` (as read by
        ``get_voxels_labels``).
    batch_size : int
        Batch size for both loaders.
    test_size : float or int, default 0.2
        Passed to ``train_test_split``: a fraction of the samples, or an
        absolute count, to hold out for testing.
    random_state : int or None, default 42
        Seed for the split; the fixed default keeps the split reproducible.
    stratify : bool, default False
        If True, the split preserves the label proportions in both parts.

    Returns
    -------
    (train_loader, test_loader)
        The training loader reshuffles every epoch; the test loader keeps
        the split order.
    """
    voxels, labels = get_voxels_labels(graph_data_3d)
    voxels_train, voxels_test, labels_train, labels_test = train_test_split(
        voxels,
        labels,
        test_size=test_size,
        random_state=random_state,
        stratify=np.asarray(labels) if stratify else None,
    )

    train_dataset = TensorDataset(voxels_train, labels_train)
    test_dataset = TensorDataset(voxels_test, labels_test)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [11]:
def load_3d_voxels_data_10fold_cv(voxels, labels, train_idx, test_idx, batch_size):
    """Build train/test DataLoaders for one cross-validation fold.

    Parameters
    ----------
    voxels, labels : tensors indexable by the fold index arrays
        (e.g. the output of ``get_voxels_labels``).
    train_idx, test_idx : index arrays
        Sample indices for this fold's training and test parts.
    batch_size : int
        Batch size for both loaders.

    Returns
    -------
    (train_loader, test_loader)
        The training loader shuffles; the test loader keeps index order.
    """
    def _make_loader(indices, shuffle):
        fold_dataset = TensorDataset(voxels[indices], labels[indices])
        return DataLoader(fold_dataset, batch_size=batch_size, shuffle=shuffle)

    train_loader = _make_loader(train_idx, shuffle=True)
    test_loader = _make_loader(test_idx, shuffle=False)
    return train_loader, test_loader