In [1]:
import random
import torch
import torch_geometric

import numpy as np

from torch_geometric.data import Data, DataLoader, Batch

import torch.nn.functional as F
import sys
import os

import pandas as pd

from rdkit import Chem
from tqdm import tqdm
from copy import deepcopy

from concurrent.futures import ThreadPoolExecutor, as_completed
import concurrent
from concurrent.futures import ProcessPoolExecutor, as_completed

from pathlib import Path

# Make the directory above this notebook importable so the DataPipeline /
# Model / FinalPipeline packages imported below can be resolved.
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
# Guard so re-running this cell in the same kernel does not keep appending
# duplicate entries to sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from torch_scatter import scatter_max, scatter_add

from DataPipeline.preprocessing import process_encode_graph, get_subgraph_with_terminal_nodes_step
from DataPipeline.preprocessing import node_encoder
from Model.GNN1 import ModelWithEdgeFeatures as GNN1
from Model.GNN1 import ModelWithNodeConcat as GNN1_node_concat
from Model.GNN2 import ModelWithEdgeFeatures as GNN2
from Model.GNN2 import ModelWithNodeConcat as GNN2_node_concat
from Model.GNN3 import ModelWithEdgeFeatures as GNN3
from Model.GNN3 import ModelWithgraph_embedding_modif as GNN3_embedding

In [2]:
# Scored + filtered ZINC molecules (relative to the kernel's working directory);
# this table is the pool that starting scaffolds are sampled from below.
SCORED_SAVING_DIR = Path('.').joinpath('generated_mols', 'scored')
ZINC_DATA_PATH = SCORED_SAVING_DIR.joinpath('zinc_scored_filtered.csv')
df_zinc = pd.read_csv(ZINC_DATA_PATH)

In [3]:
def tensor_to_smiles(node_features, edge_index, edge_attr, edge_mapping='aromatic', encoding_type='charged'):
    """Decode a one-hot encoded molecular graph back into a SMILES string.

    NOTE(review): the cell that defines ``convert_to_smiles`` re-imports
    ``tensor_to_smiles`` from ``FinalPipeline.utils``, which shadows this
    definition for every later cell — confirm which implementation is intended.

    Args:
        node_features: iterable of per-atom feature rows. Only the leading
            atom-type slots are read (12 for 'charged', 7 for 'polymer'); the
            argmax over those slots selects the element and formal charge.
        edge_index: indexable as ``edge_index[0]`` (start atoms) and
            ``edge_index[1]`` (end atoms). An empty tensor of any shape, or
            None, is treated as "no bonds".
        edge_attr: per-edge feature rows aligned with ``edge_index``. Only the
            leading bond-type slots are read (4 for 'aromatic', 3 for
            'kekulized').
        edge_mapping: 'aromatic' or 'kekulized'; selects the bond-type table.
        encoding_type: 'charged' or 'polymer'; selects the atom-type table.

    Returns:
        The string produced by ``Chem.MolToSmiles``; no explicit sanitization
        is performed on the molecule before conversion.

    Raises:
        ValueError: if ``encoding_type`` or ``edge_mapping`` is not recognized.
    """
    # Validate the options before touching RDKit so a typo fails loudly
    # instead of surfacing as an UnboundLocalError further down.
    if encoding_type == 'charged':
        atom_mapping = {
            0: ('C', 0),
            1: ('N', 0),
            2: ('N', 1),
            3: ('N', -1),
            4: ('O', 0),
            5: ('O', -1),
            6: ('F', 0),
            7: ('S', 0),
            8: ('S', -1),
            9: ('Cl', 0),
            10: ('Br', 0),
            11: ('I', 0)
        }
    elif encoding_type == 'polymer':
        atom_mapping = {
            0: ('C', 0),
            1: ('N', 0),
            2: ('O', 0),
            3: ('F', 0),
            4: ('Si', 0),
            5: ('P', 0),
            6: ('S', 0)}
    else:
        raise ValueError(f"Unknown encoding_type: {encoding_type!r}")

    if edge_mapping not in ('aromatic', 'kekulized'):
        raise ValueError(f"Unknown edge_mapping: {edge_mapping!r}")

    # Create an empty editable molecule
    mol = Chem.RWMol()

    # Add atoms. Restrict the argmax to the atom-type slots so trailing
    # (non atom-type) features can never be picked as an out-of-table index.
    n_atom_types = len(atom_mapping)
    for atom_feature in node_features:
        atom_idx = atom_feature[:n_atom_types].argmax().item()
        atom_symbol, charge = atom_mapping[atom_idx]
        atom = Chem.Atom(atom_symbol)
        atom.SetFormalCharge(charge)
        mol.AddAtom(atom)

    # Define bond type mapping
    if edge_mapping == 'aromatic':
        bond_mapping = {
            0: Chem.rdchem.BondType.AROMATIC,
            1: Chem.rdchem.BondType.SINGLE,
            2: Chem.rdchem.BondType.DOUBLE,
            3: Chem.rdchem.BondType.TRIPLE,
        }
    else:
        bond_mapping = {
            0: Chem.rdchem.BondType.SINGLE,
            1: Chem.rdchem.BondType.DOUBLE,
            2: Chem.rdchem.BondType.TRIPLE,
        }

    # Graphs without edges may carry a 1-D empty edge_index; indexing
    # edge_index[0] on that would raise, so skip bond creation entirely.
    has_edges = edge_index is not None and len(edge_index) > 0 and len(edge_index[0]) > 0
    if has_edges:
        n_bond_types = len(bond_mapping)
        for start, end, bond_attr in zip(edge_index[0].tolist(), edge_index[1].tolist(), edge_attr):
            bond_type = bond_mapping[bond_attr[:n_bond_types].argmax().item()]

            # edge_index may list the same bond in both directions, and RDKit
            # raises when adding a bond that already exists, so only the
            # first occurrence is added.
            if mol.GetBondBetweenAtoms(start, end) is None:
                mol.AddBond(start, end, bond_type)

    # Convert the molecule to SMILES
    return Chem.MolToSmiles(mol)

def extract_all_graphs(batch):
    """Split a batched graph back into one ``Data`` object per graph.

    Only ``x``, ``edge_index`` and ``edge_attr`` are carried over; any other
    attributes present on ``batch`` are not copied.

    Args:
        batch: object with ``batch`` (graph id per node, 0..num_graphs-1),
            ``x``, ``edge_index`` (2 x num_edges, global node indices) and
            ``edge_attr`` (or None), e.g. the output of ``Batch.from_data_list``.

    Returns:
        list of ``Data``, in graph-id order. Each ``edge_index`` is re-indexed
        to local node positions (original node order preserved) and always has
        shape ``(2, num_edges)`` — ``(2, 0)`` for graphs without edges — on the
        same device as ``batch.batch``. An empty batch yields ``[]``.
    """
    all_graphs = []
    if batch.batch.numel() == 0:
        return all_graphs

    nb_graphs = batch.batch.max().item() + 1
    num_nodes = batch.batch.size(0)
    device = batch.batch.device

    for i in range(nb_graphs):
        # Boolean mask selecting the nodes of the i-th graph
        mask = batch.batch == i
        subgraph_x = batch.x[mask]

        # Global -> local node index lookup table (-1 for nodes of other
        # graphs). Built with tensors so it stays on the input's device.
        local_index = torch.full((num_nodes,), -1, dtype=torch.long, device=device)
        local_index[mask] = torch.arange(int(mask.sum()), device=device)

        # An edge belongs to graph i when both of its endpoints do.
        edge_mask = mask[batch.edge_index[0]] & mask[batch.edge_index[1]]
        # Indexing with a (2, E) tensor returns a (2, E) tensor, so edgeless
        # graphs get a well-formed (2, 0) edge_index rather than a 1-D one.
        subgraph_edge_index = local_index[batch.edge_index[:, edge_mask]]

        if batch.edge_attr is not None:
            subgraph_edge_attr = batch.edge_attr[edge_mask]
        else:
            subgraph_edge_attr = None

        all_graphs.append(Data(x=subgraph_x, edge_index=subgraph_edge_index, edge_attr=subgraph_edge_attr))

    return all_graphs

In [4]:
def sample_random_subgraph_ZINC(pd_dataframe, start_size, encoding_option='charged', kekulize=True):
    """Pick one random molecule from ``pd_dataframe`` and cut a starting subgraph from it.

    Sampling uses the module-level ``random`` generator, so ``random.seed``
    controls reproducibility.

    Args:
        pd_dataframe: DataFrame with a ``'smiles'`` column.
        start_size: size argument forwarded to
            ``get_subgraph_with_terminal_nodes_step`` (presumably the number of
            atoms in the starting subgraph — TODO confirm).
        encoding_option: node-encoding scheme forwarded to ``process_encode_graph``.
        kekulize: forwarded to ``process_encode_graph``.

    Returns:
        ``(subgraph_data, terminal_node_info, id_map)`` exactly as returned by
        ``get_subgraph_with_terminal_nodes_step``.

    Raises:
        IndexError: if ``pd_dataframe`` is empty.
    """
    # Choose a row *position* rather than an index label: with duplicate
    # labels, ``.loc[label, 'smiles']`` would return a Series instead of a
    # string. For a unique index this draws the same row for the same seed.
    position = random.choice(range(len(pd_dataframe)))
    smiles_str = pd_dataframe['smiles'].iloc[position]

    torch_graph = process_encode_graph(smiles_str, encoding_option=encoding_option, kekulize=kekulize)
    subgraph_data, terminal_node_info, id_map = get_subgraph_with_terminal_nodes_step(torch_graph, start_size, impose_edges=True)

    return subgraph_data, terminal_node_info, id_map

def create_batch_from_zinc(pd_dataframe, batch_size, start_size, encoding_option='reduced'):
    """Sample ``batch_size`` random starting subgraphs and collate them into one ``Batch``.

    Args:
        pd_dataframe: source molecules, passed to ``sample_random_subgraph_ZINC``.
        batch_size: number of subgraphs to draw.
        start_size: size argument passed to ``sample_random_subgraph_ZINC``.
        encoding_option: NOTE(review): currently unused — the subgraphs are
            encoded however ``sample_random_subgraph_ZINC`` does by default
            ('charged'). It is not forwarded because the default 'reduced'
            would change the graphs existing callers receive.

    Returns:
        ``Batch.from_data_list`` of the sampled subgraphs, in draw order.
    """
    # Only the subgraph itself is kept; terminal-node info and the id map
    # returned alongside it are not needed for batching.
    sampled_graphs = [
        sample_random_subgraph_ZINC(pd_dataframe, start_size)[0]
        for _ in range(batch_size)
    ]
    return Batch.from_data_list(sampled_graphs)

In [5]:
from FinalPipeline.utils import GenerationModule, tensor_to_smiles
def convert_to_smiles(graph, kekulize=True, encoding_type='charged'):
    """Decode each graph in ``graph`` into a SMILES string.

    Args:
        graph: iterable of graph objects exposing ``x``, ``edge_index`` and
            ``edge_attr``.
        kekulize: True when bond features use the kekulized
            (single/double/triple) table, False for the aromatic table; selects
            the ``edge_mapping`` passed to ``tensor_to_smiles``.
        encoding_type: atom-encoding scheme forwarded to ``tensor_to_smiles``.

    Returns:
        list of SMILES strings, one per input graph, in input order.
    """
    # The bond table must match how the graphs were encoded; honour the
    # caller's ``kekulize`` flag instead of always assuming 'kekulized'.
    edge_mapping = 'kekulized' if kekulize else 'aromatic'
    return [
        tensor_to_smiles(g.x, g.edge_index, g.edge_attr, edge_mapping=edge_mapping, encoding_type=encoding_type)
        for g in graph
    ]

In [6]:
from tqdm import tqdm
def generate_scaffolds(nb, batch_size, pd_dataframe=None, start_size=3):
    """Sample random starting scaffolds in ``nb`` batches and return their SMILES.

    Args:
        nb: number of batches to draw.
        batch_size: number of scaffolds sampled per batch.
        pd_dataframe: source molecules (needs a ``'smiles'`` column). When
            None, the notebook-level ``df_zinc`` is used.
        start_size: size argument passed to ``create_batch_from_zinc`` for
            every starting subgraph.

    Returns:
        list with one SMILES string per extracted graph, duplicates kept, so it
        can be fed to ``Counter`` to measure scaffold diversity.
    """
    # Resolve the default at call time so a reloaded ``df_zinc`` is picked up.
    source_df = df_zinc if pd_dataframe is None else pd_dataframe

    smiles_scaffolds = []
    for _ in tqdm(range(nb)):
        graph_batch = create_batch_from_zinc(source_df, batch_size, start_size)
        graph_list = extract_all_graphs(graph_batch)
        smiles_scaffolds.extend(convert_to_smiles(graph_list, kekulize=True, encoding_type='charged'))

    return smiles_scaffolds

In [8]:
# 1000 batches x 1000 draws = 1,000,000 sampled scaffolds (the recorded run took ~42 min).
scaffolds_1000000 = generate_scaffolds(nb=1000, batch_size=1000)

100%|██████████| 1000/1000 [41:46<00:00,  2.51s/it]


In [None]:
# 100 batches x 1000 draws = 100,000 sampled scaffolds (the recorded run took ~4 min).
scaffolds = generate_scaffolds(nb=100, batch_size=1000)

100%|██████████| 100/100 [04:16<00:00,  2.57s/it]


In [None]:
# Scaffold diversity of the 100 x 1000 sample: number of distinct SMILES drawn.
# The sample size is printed alongside so this output can be told apart from
# the 1,000,000-draw cell and the count can be interpreted.

from collections import Counter

scaffolds_counter = Counter(scaffolds)
print(f'Number of unique scaffolds: {len(scaffolds_counter)} (out of {len(scaffolds)} sampled)')

Number of unique scaffolds: 139


In [9]:
# Scaffold diversity of the 1000 x 1000 sample (scaffolds_1000000): number of
# distinct SMILES drawn, printed with the sample size so it is distinguishable
# from the 100,000-draw cell.

from collections import Counter

scaffolds_counter_1000000 = Counter(scaffolds_1000000)
print(f'Number of unique scaffolds: {len(scaffolds_counter_1000000)} (out of {len(scaffolds_1000000)} sampled)')

Number of unique scaffolds: 188
