In [1]:
#import rdkit

In [2]:
# Import necessary libraries
import os
import csv
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import BRICS
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import NNConv, GlobalAttention, GCNConv
from torch_geometric.utils import to_networkx
from torch.distributions import Bernoulli, Categorical
import networkx as nx
from tqdm import tqdm

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [3]:
# Load the dataset
data_df = pd.read_csv('BRD4_mini_sampled_data_POS.csv')  # Replace with your dataset path

# Filter out invalid SMILES
valid_smiles = []
for idx, row in data_df.iterrows():
    smiles = row['molecule_smiles']
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        valid_smiles.append(smiles)

print(f"Total valid molecules: {len(valid_smiles)}")

Total valid molecules: 12500


In [4]:
def atom_features(atom):
    # Atomic number
    atomic_number = atom.GetAtomicNum()
    # Atom degree
    degree = atom.GetDegree()
    # Formal charge
    formal_charge = atom.GetFormalCharge()
    # Number of hydrogens
    num_hs = atom.GetTotalNumHs()
    # Aromaticity
    is_aromatic = atom.GetIsAromatic()
    # Hybridization
    hybridization = atom.GetHybridization()

    # One-hot encoding for hybridization
    hybridization_encoding = [0]*6
    hybridization_types = [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
        Chem.rdchem.HybridizationType.UNSPECIFIED
    ]
    if hybridization in hybridization_types:
        idx = hybridization_types.index(hybridization)
        hybridization_encoding[idx] = 1

    # Combine all features into a list
    features = [
        atomic_number,
        degree,
        formal_charge,
        num_hs,
        int(is_aromatic)
    ] + hybridization_encoding
    return features

def bond_features(bond):
    bond_type = bond.GetBondType()

    # Bond type as one-hot encoding
    bond_type_feats = [
        int(bond_type == Chem.rdchem.BondType.SINGLE),
        int(bond_type == Chem.rdchem.BondType.DOUBLE),
        int(bond_type == Chem.rdchem.BondType.TRIPLE),
        int(bond_type == Chem.rdchem.BondType.AROMATIC)
    ]

    # Conjugation
    conjugation_feat = [int(bond.GetIsConjugated())]

    # Ring membership
    ring_feat = [int(bond.IsInRing())]

    # Stereo configuration as one-hot encoding
    stereo = bond.GetStereo()
    stereo_feats = [
        int(stereo == Chem.rdchem.BondStereo.STEREONONE),
        int(stereo == Chem.rdchem.BondStereo.STEREOANY),
        int(stereo == Chem.rdchem.BondStereo.STEREOZ),
        int(stereo == Chem.rdchem.BondStereo.STEREOE),
        int(stereo == Chem.rdchem.BondStereo.STEREOCIS),
        int(stereo == Chem.rdchem.BondStereo.STEREOTRANS)
    ]

    # Combine all features
    bond_feats = bond_type_feats + conjugation_feat + ring_feat + stereo_feats
    return bond_feats


In [5]:
def fragment_features(frag):
    # Compute atom features
    atom_feats = []
    for atom in frag.GetAtoms():
        feat = atom_features(atom)
        atom_feats.append(feat)
    # Aggregate atom features (sum)
    atom_feats = np.array(atom_feats)
    atom_feat_vector = np.sum(atom_feats, axis=0)

    # Compute bond features
    bond_feats = []
    bonds = frag.GetBonds()
    if bonds:
        for bond in bonds:
            feat = bond_features(bond)
            bond_feats.append(feat)
        bond_feats = np.array(bond_feats)
        bond_feat_vector = np.sum(bond_feats, axis=0)
    else:
        # Handle fragments with no bonds (e.g., single atoms)
        bond_feat_vector = np.zeros(13)

    # Concatenate atom and bond feature vectors
    frag_feat_vector = np.concatenate([atom_feat_vector, bond_feat_vector])
    return frag_feat_vector

In [6]:
def decompose_molecule(mol):
    # Find BRICS bonds to break
    brics_bonds = BRICS.FindBRICSBonds(mol)
    bonds_to_break = []
    for (atom_idx1, atom_idx2), (label1, label2) in brics_bonds:
        bond = mol.GetBondBetweenAtoms(atom_idx1, atom_idx2)
        if bond is not None:
            bonds_to_break.append(bond.GetIdx())

    # Break the bonds
    fragment_mol = Chem.FragmentOnBonds(mol, bonds_to_break, addDummies=True)

    # Get the fragments
    fragments = Chem.GetMolFrags(fragment_mol, asMols=True, sanitizeFrags=True)

    # Build mapping from fragment to fragment ID
    frag_id_mapping = {}
    for idx, frag in enumerate(fragments):
        frag_id_mapping[idx] = frag

    # Build the tree by connecting fragments via attachment points
    # Attachment points are marked with dummy atoms (*)
    # We can find which fragments are connected by looking at the dummy atoms
    # Each dummy atom has an isotope number indicating the original bond

    # Build a mapping from dummy atom isotope labels to fragment IDs and atom indices
    dummy_atom_mapping = {}
    for idx, frag in frag_id_mapping.items():
        for atom in frag.GetAtoms():
            if atom.GetAtomicNum() == 0:  # Dummy atom (*)
                isotope = atom.GetIsotope()
                if isotope not in dummy_atom_mapping:
                    dummy_atom_mapping[isotope] = []
                dummy_atom_mapping[isotope].append((idx, atom.GetIdx(), atom.GetAtomMapNum()))

    # Build the tree edges with bond features
    tree_edges = []
    edge_features = []
    for isotope, connections in dummy_atom_mapping.items():
        if len(connections) == 2:
            frag1_id, atom1_idx, atom1_mapnum = connections[0]
            frag2_id, atom2_idx, atom2_mapnum = connections[1]

            # Since the original bond was broken, we might not have bond features
            bond_feat = np.zeros(13)  # Adjust the size based on your bond feature length

            tree_edges.append((frag1_id, frag2_id))
            edge_features.append(bond_feat)
        else:
            # Handle cases with more than two connections (e.g., branching)
            for i in range(len(connections)):
                for j in range(i+1, len(connections)):
                    frag1_id, atom1_idx, atom1_mapnum = connections[i]
                    frag2_id, atom2_idx, atom2_mapnum = connections[j]

                    bond_feat = np.zeros(13)  # Adjust the size based on your bond feature length

                    tree_edges.append((frag1_id, frag2_id))
                    edge_features.append(bond_feat)

    return frag_id_mapping, tree_edges, edge_features


In [7]:
substructure_library = {}
substructure_vocab = []
substructure_to_idx = {}
idx_to_substructure = {}
label_counter = 0

def mol_to_substructure_tree(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None  # Skip invalid molecules

    # Decompose molecule into fragments using BRICS
    frag_id_mapping, tree_edges, edge_features = decompose_molecule(mol)

    # For each fragment, compute features and assign labels
    fragment_features_list = []
    fragment_labels = []
    for frag_id in sorted(frag_id_mapping.keys()):
        frag = frag_id_mapping[frag_id]
        feat_vector = fragment_features(frag)
        # Use the canonical SMILES of the fragment
        smiles_frag = Chem.MolToSmiles(frag, isomericSmiles=True)
        # Add to substructure library
        if smiles_frag not in substructure_to_idx:
            substructure_to_idx[smiles_frag] = len(substructure_vocab)
            idx_to_substructure[len(substructure_vocab)] = smiles_frag
            substructure_vocab.append(smiles_frag)
        frag_label = substructure_to_idx[smiles_frag]
        fragment_features_list.append(feat_vector)
        fragment_labels.append(frag_label)

    # Prepare edge indices and edge attributes
    edge_index = []
    edge_attr = []
    for idx, (i, j) in enumerate(tree_edges):
        edge_index.append([i, j])
        edge_index.append([j, i])  # Undirected graph
        bond_feat = edge_features[idx]
        edge_attr.append(bond_feat)
        edge_attr.append(bond_feat)  # Duplicate for both directions

    # Convert to tensors
    x = torch.tensor(fragment_features_list, dtype=torch.float)
    x = x / x.norm(dim=1, keepdim=True)  # Normalize node features
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    y = torch.tensor(fragment_labels, dtype=torch.long)  # Substructure labels

    # Create Data object
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
    return data


In [8]:

data_list = []
for smiles in tqdm(valid_smiles, total=len(valid_smiles)):
    data = mol_to_substructure_tree(smiles)
    if data is not None:
        data_list.append(data)

print(f"Total data samples: {len(data_list)}")
print(f"Substructure vocabulary size: {len(substructure_vocab)}")

  x = torch.tensor(fragment_features_list, dtype=torch.float)
100%|██████████| 12500/12500 [04:46<00:00, 43.62it/s]

Total data samples: 12500
Substructure vocabulary size: 10106





In [9]:
def write_substructure_library(substructure_vocab, filename='substructureLibrary.csv'):
    with open(filename, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['Index', 'SMILES'])
        for idx, smiles in enumerate(substructure_vocab):
            writer.writerow([idx, smiles])

write_substructure_library(substructure_vocab)


In [10]:
class MPNNEncoder(nn.Module):
    def __init__(self, node_feature_dim, edge_feature_dim, hidden_dim, latent_dim):
        super(MPNNEncoder, self).__init__()

        # Define the network to compute edge weights
        self.nn_edge1 = nn.Sequential(
            nn.Linear(edge_feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, node_feature_dim * hidden_dim)
        )
        self.nn_edge2 = nn.Sequential(
            nn.Linear(edge_feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim * hidden_dim)
        )

        # Message Passing Layers
        self.conv1 = NNConv(node_feature_dim, hidden_dim, self.nn_edge1, aggr='mean')
        self.conv2 = NNConv(hidden_dim, hidden_dim, self.nn_edge2, aggr='mean')

        # Attention-based pooling using softmax
        self.attention = GlobalAttention(
            gate_nn=nn.Sequential(nn.Linear(hidden_dim, 1)),
            nn=nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.ReLU())
        )

        # Latent space projections
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, data):
        x, edge_index, edge_attr = data.x.to(device), data.edge_index.to(device), data.edge_attr.to(device)

        # Message Passing
        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x)
        x = self.conv2(x, edge_index, edge_attr)
        x = F.relu(x)

        # Attention-based Pooling
        x = self.attention(x, data.batch)

        # Latent Space
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

In [11]:
class TreeDecoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, substructure_vocab_size):
        super(TreeDecoder, self).__init__()
        self.hidden_dim = hidden_dim
        self.substructure_vocab_size = substructure_vocab_size

        # Initial hidden state projection from latent vector
        self.fc_hidden = nn.Linear(latent_dim, hidden_dim)
        self.fc_root = nn.Linear(latent_dim, substructure_vocab_size)

        # RNN cell for tree traversal
        self.rnn_cell = nn.GRUCell(hidden_dim, hidden_dim)

        # Child existence predictor
        self.fc_exist = nn.Linear(hidden_dim, 1)  # Outputs logit for sigmoid

        # Substructure selector
        self.fc_substruct = nn.Linear(hidden_dim, substructure_vocab_size)

    def forward(self, z):
        batch_size = z.size(0)
        hidden = self.fc_hidden(z)  # Initial hidden state
        root_logits = self.fc_root(z)  # Root substructure logits

        # Use Gumbel-Softmax to sample root substructure
        root_prob = F.gumbel_softmax(root_logits, tau=1, hard=True)
        root_substruct = root_prob.argmax(dim=1)

        # Initialize outputs
        trees = []  # List of generated trees for each sample
        decisions_list = []

        for i in range(batch_size):
            tree = []
            decisions = []

            # Initialize stack for DFS
            stack = []
            node = {
                'hidden': hidden[i],
                'substruct': root_substruct[i],
                'parent': None,
                'children': []
            }
            tree.append(node)
            stack.append(node)

            # Traverse the tree
            while stack:
                current_node = stack.pop()
                h = current_node['hidden']

                # Decide whether to create a child
                exist_logit = self.fc_exist(h)
                exist_prob = torch.sigmoid(exist_logit)
                # During training, use the probability; during inference, sample
                exist_decision = (exist_prob > 0.5).float()
                decisions.append({'exist_prob': exist_prob, 'exist_decision': exist_decision})

                if exist_decision.item() == 1:
                    # Decide which substructure to use
                    substruct_logits = self.fc_substruct(h)
                    substruct_prob = F.gumbel_softmax(substruct_logits, tau=1, hard=True)
                    substruct_idx = substruct_prob.argmax(dim=0)
                    decisions.append({'substruct_prob': substruct_prob, 'substruct_idx': substruct_idx})

                    # Create child node
                    h_child = self.rnn_cell(h, h)  # Update hidden state
                    child_node = {
                        'hidden': h_child,
                        'substruct': substruct_idx,
                        'parent': current_node,
                        'children': []
                    }
                    current_node['children'].append(child_node)
                    stack.append(child_node)

            trees.append(tree)
            decisions_list.append(decisions)

        return trees, decisions_list

In [12]:
class VAE(nn.Module):
    def __init__(self, node_feature_dim, edge_feature_dim, hidden_dim, latent_dim, substructure_vocab_size):
        super(VAE, self).__init__()
        self.encoder = MPNNEncoder(node_feature_dim, edge_feature_dim, hidden_dim, latent_dim)
        self.decoder = TreeDecoder(latent_dim, hidden_dim, substructure_vocab_size)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std).to(device)
        z = mu + eps * std
        return z

    def forward(self, data):
        mu, logvar = self.encoder(data)
        z = self.reparameterize(mu, logvar)
        trees, decisions = self.decoder(z)
        return trees, decisions, mu, logvar

In [13]:
def compute_loss(decisions_list, data, substructure_vocab_size):
    total_exist_loss = 0
    total_substruct_loss = 0

    # Since we have variable-length outputs, we need to handle each sample individually
    for decisions, sample_data in zip(decisions_list, data):
        # Ground truth tree from sample_data
        ground_truth_labels = sample_data.y  # Substructure labels
        # For simplicity, assume a linearized version of the tree (you may need to adjust this)
        gt_substruct_indices = ground_truth_labels.tolist()

        # Initialize pointers
        gt_idx = 0
        for decision in decisions:
            if 'exist_prob' in decision:
                # Binary cross-entropy loss for existence decision
                target = torch.tensor([1.0]).to(device) if gt_idx < len(gt_substruct_indices) - 1 else torch.tensor([0.0]).to(device)
                exist_loss = F.binary_cross_entropy(decision['exist_prob'], target)
                total_exist_loss += exist_loss

            if 'substruct_prob' in decision:
                if gt_idx < len(gt_substruct_indices):
                    target = torch.tensor([gt_substruct_indices[gt_idx]]).to(device)
                    substruct_loss = F.cross_entropy(decision['substruct_prob'].unsqueeze(0), target)
                    total_substruct_loss += substruct_loss
                    gt_idx += 1

    total_loss = total_exist_loss + total_substruct_loss
    return total_loss

: 

In [14]:
# Prepare DataLoader
batch_size = 16
loader = DataLoader(data_list, batch_size=batch_size, shuffle=True)

# Initialize the model
node_feature_dim = data_list[0].x.shape[1]
edge_feature_dim = data_list[0].edge_attr.shape[1]
hidden_dim = 64
latent_dim = 32
substructure_vocab_size = len(substructure_vocab)

model = VAE(node_feature_dim, edge_feature_dim, hidden_dim, latent_dim, substructure_vocab_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for data in loader:
        optimizer.zero_grad()
        data = data.to(device)
        trees, decisions_list, mu, logvar = model(data)
        loss_recon = compute_loss(decisions_list, data, substructure_vocab_size)
        loss_kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        loss = loss_recon + loss_kl
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    avg_loss = total_loss / len(loader.dataset)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")



In [None]:
def reconstruct_molecule_from_tree(tree, idx_to_substructure):
    try:
        # Initialize molecule
        mol = None
        atom_idx_map = {}
        for node in tree:
            substruct_idx = node['substruct'].item()
            smiles_frag = idx_to_substructure[substruct_idx]
            frag_mol = Chem.MolFromSmiles(smiles_frag)
            if mol is None:
                mol = frag_mol
                # Map atom indices
                atom_idx_map[node['substruct']] = list(range(mol.GetNumAtoms()))
            else:
                # Attach frag_mol to mol
                # This is a simplified example; proper attachment requires handling dummy atoms
                mol = Chem.CombineMols(mol, frag_mol)
                # Update atom index mapping
                atom_idx_map[node['substruct']] = list(range(mol.GetNumAtoms() - frag_mol.GetNumAtoms(), mol.GetNumAtoms()))
        # Sanitize molecule
        Chem.SanitizeMol(mol)
        smiles = Chem.MolToSmiles(mol)
        return smiles
    except:
        return None


In [None]:
model.eval()
generated_molecules = []
num_samples = 100

with torch.no_grad():
    for _ in range(num_samples):
        # Sample z from standard normal distribution
        z = torch.randn(1, latent_dim).to(device)
        # Condition on 'binds' label = 1 (binding)
        condition = torch.tensor([[1.0]]).to(device)
        trees, _ = model.decoder(z, condition)

        # Reconstruct molecule from generated tree
        generated_smiles = reconstruct_molecule_from_tree(trees[0], idx_to_substructure)
        if generated_smiles is not None:
            generated_molecules.append(generated_smiles)


In [None]:
valid_count = 0
for smiles in generated_molecules:
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        valid_count += 1

validity = valid_count / len(generated_molecules)
print(f"Validity of generated molecules: {validity * 100:.2f}%")
