In [27]:
import rdkit
import random
import time
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from random import randint
import matplotlib.pyplot as plt
from rdkit.Chem.Scaffolds import MurckoScaffold
import math
from itertools import combinations
import json

In [2]:
def get_mol(smiles):
    """Build an RDKit molecule from a SMILES string.

    Thin wrapper around ``Chem.MolFromSmiles``; its result is returned
    unchanged.

    NOTE(review): RDKit documents that ``MolFromSmiles`` returns ``None``
    for input it cannot parse -- callers here do not check for that.
    """
    parsed = Chem.MolFromSmiles(smiles)
    return parsed

In [3]:
def get_smiles(mol):
    """Serialise an RDKit molecule to a SMILES string.

    Thin wrapper around ``Chem.MolToSmiles`` called with its default
    options; the resulting string is returned unchanged.
    """
    smiles_text = Chem.MolToSmiles(mol)
    return smiles_text

In [4]:
def draw_mol(mol):
    """Render a molecule as an image.

    Thin wrapper around ``Draw.MolToImage`` with its default settings; the
    image object it produces is returned unchanged.
    """
    image = Draw.MolToImage(mol)
    return image

In [5]:
def draw_smiles(smi):
    """Render the molecule described by a SMILES string as an image.

    Parses ``smi`` with ``get_mol`` and passes the result to ``draw_mol``.
    """
    molecule = get_mol(smi)
    return draw_mol(molecule)

In [6]:
def get_scaffold(mol):
    """Return the Murcko scaffold of ``mol``.

    Thin wrapper around ``MurckoScaffold.GetScaffoldForMol``; its result is
    returned unchanged.
    """
    scaffold = MurckoScaffold.GetScaffoldForMol(mol)
    return scaffold

In [7]:
def plot_molecules(mol1, mol2):
    """Show two molecules side by side in a single matplotlib figure.

    ``mol1`` is drawn in the left panel under the title "Input" and ``mol2``
    in the right panel under the title "Output". Axes are hidden on both
    panels.

    Returns whatever ``plt.show()`` returns.
    """
    # One row, two columns: input on the left, output on the right.
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))

    panels = zip(axs, (mol1, mol2), ("Input", "Output"))
    for panel, molecule, title in panels:
        panel.imshow(draw_mol(molecule))
        panel.set_title(title)
        panel.axis("off")  # Hide axes

    return plt.show()

In [8]:
def get_idxs_of_carbon_for_new_bond(mol):
    """Return the indices of carbon atoms that can still take a new bond.

    An atom qualifies when its atomic number is 6 and the sum of the bond
    orders (``GetBondTypeAsDouble``) over the bonds returned by
    ``GetBonds`` is strictly below 4. Indices are returned in the order
    the atoms are yielded by ``mol.GetAtoms()``.
    """

    def _bond_order_sum(atom):
        # Total bond order over the bonds present on this atom.
        return sum([bond.GetBondTypeAsDouble() for bond in atom.GetBonds()])

    return [
        atom.GetIdx()
        for atom in mol.GetAtoms()
        if atom.GetAtomicNum() == 6 and _bond_order_sum(atom) < 4
    ]

In [9]:
def get_scaffold_and_attachment_points(mol):
    """Return the scaffold of ``mol`` and its candidate attachment indices.

    The scaffold comes from ``get_scaffold``. When that scaffold serialises
    to an empty SMILES string, a copy of the whole input molecule is used
    in its place.

    Returns:
        A ``(scaffold, indices)`` tuple, where ``indices`` is the result of
        ``get_idxs_of_carbon_for_new_bond`` on the chosen scaffold.
    """
    core = get_scaffold(mol)

    has_scaffold = Chem.MolToSmiles(core) != ""
    if not has_scaffold:
        # Fall back to a copy of the full molecule.
        core = Chem.Mol(mol)

    attachment_idxs = get_idxs_of_carbon_for_new_bond(core)
    return (core, attachment_idxs)

In [10]:
def get_mol_after_adding_attachment_points_at(mol, at):
    """Return a copy of ``mol`` with a ``*`` atom bonded to each index in ``at``.

    The input molecule is not modified: all edits happen on an ``RWMol``
    copy. For every index in ``at`` a ``*`` atom is added and joined to that
    atom with a single bond. 2D coordinates are computed on the result
    before it is returned.

    Args:
        mol: Molecule to decorate.
        at: Iterable of atom indices in ``mol`` that receive an attachment
            point.
    """
    dummy_atom = Chem.Atom("*")
    editable = Chem.RWMol(mol)

    for anchor_idx in at:
        # AddAtom returns the index of the newly added atom.
        dummy_idx = editable.AddAtom(dummy_atom)
        editable.AddBond(anchor_idx, dummy_idx, Chem.BondType.SINGLE)

    decorated = editable.GetMol()
    AllChem.Compute2DCoords(decorated)

    return decorated

In [11]:
def attach_num_to_attachment_points(mol):
    """Number the ``*`` attachment points of ``mol`` and return a new molecule.

    The molecule is written to SMILES, every ``*`` character in that string
    is replaced by ``[*:k]`` with ``k`` counting up from 1 in order of
    appearance, and the resulting string is parsed back with ``get_mol``.
    """
    # Splitting on "*" leaves the text between consecutive attachment points;
    # re-joining with numbered labels substitutes each "*" in order.
    head, *tails = get_smiles(mol).split("*")
    labelled = head + "".join(
        f"[*:{number}]{tail}" for number, tail in enumerate(tails, start=1)
    )

    return get_mol(labelled)

In [12]:
def get_target_num_mols_for_given_mol(attachment_points_using, total_combinations, target=1000):
    """Return how many molecules to request for one decorated scaffold.

    The overall ``target`` is split evenly (integer division) across
    ``total_combinations`` and the share is then inflated by a multiplier
    that depends on the number of attachment points in use.

    Generation with LibInvent is lossy: asking for 100 new molecules may
    yield only around 60-70, and fewer still (around 30-40) when the input
    has more attachment points. Requesting more than the plain share
    compensates for that loss.

    Args:
        attachment_points_using: Number of attachment points on the input
            molecule. 2 uses a multiplier of 2 and 3 uses 2.5; any other
            value uses 1.5.
        total_combinations: Number of decorated scaffolds the overall
            target is shared between.
        target: Overall number of molecules wanted across all
            combinations. Defaults to 1000.

    Returns:
        The per-molecule request size, rounded up to an integer.

    Raises:
        ZeroDivisionError: If ``total_combinations`` is 0.
    """
    target_for_current_mol = target // total_combinations

    if attachment_points_using == 2:
        multiplier = 2
    elif attachment_points_using == 3:
        multiplier = 2.5
    else:
        multiplier = 1.5

    return math.ceil(target_for_current_mol * multiplier)

In [13]:
def filter_duplicate_molecules(molecules):
    """Return the molecules with duplicates removed, keeping first-seen order.

    Two molecules count as duplicates when ``get_smiles`` gives the same
    string for both. Each returned molecule is a fresh object parsed from
    that string with ``get_mol``, not the original input object.
    """
    # dict.fromkeys keeps one entry per distinct SMILES, in insertion order.
    unique_smiles = dict.fromkeys([get_smiles(mol) for mol in molecules])
    return [get_mol(smi) for smi in unique_smiles]
    

In [14]:
def get_all_comb_of_mol_with_attachment_points(mol):
    """Enumerate the scaffold of ``mol`` decorated with 1, 2 and 3 attachment points.

    The scaffold and its candidate atom indices come from
    ``get_scaffold_and_attachment_points``. For every combination of one,
    two and three of those indices, a copy of the scaffold gets numbered
    attachment points at the chosen atoms.

    Returns:
        A list of three lists, each passed through
        ``filter_duplicate_molecules``:
        index 0 holds the molecules with one attachment point,
        index 1 those with two, and index 2 those with three.
    """
    scaffold, at = get_scaffold_and_attachment_points(mol)

    def _decorate(points):
        # Work on a fresh copy so the shared scaffold is never touched.
        fresh = Chem.Mol(scaffold)
        with_points = get_mol_after_adding_attachment_points_at(fresh, points)
        return attach_num_to_attachment_points(with_points)

    groups_by_size = [
        [_decorate(points) for points in combinations(at, size)]
        for size in (1, 2, 3)
    ]

    return [filter_duplicate_molecules(group) for group in groups_by_size]

In [15]:
# Input molecules as SMILES strings. The position in this list matters:
# the driver cell below selects entries by index.
smiles = [
  "c1ccccc1", # Benzene
  "CC(C)(C)C1=CC(O)=CC=C1O", # presumably tert-butylhydroquinone -- TODO confirm
  "CC(C)(C)c1nc(c(s1)-c1ccnc(N)n1)-c1cccc(NS(=O)(=O)c2c(F)cccc2F)c1F", # looks like dabrafenib -- TODO confirm
  "CO", # Methanol
  "OC=O", # Formic acid
  "CCOCC", # Diethyl ether
  "COC1=CC23CCCN2CCC4=CC5=C(C=C4C3C1O)OCO5", # Cephalotaxin
  "CC1(OC2C(OC(C2O1)(C#N)C3=CC=C4N3N=CN=C4N)CO)C", # Remdesivir
  "CC(=O)OC1=CC=CC=C1C(=O)O", # Aspirin
  "CC(CN1C=NC2=C(N=CN=C21)N)OCP(=O)(O)O", # Tenofovir
  "C1=CN=CC=C1C(=O)NN", # Isoniazid
  "CC1C2C(CC3(C=CC(=O)C(=C3C2OC1=O)C)C)O", # Artemisin
  "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O", # Ibuprofen
]

# Parse every SMILES string into a molecule, in the same order as `smiles`.
molecules = [get_mol(smi) for smi in smiles]

In [30]:
# Print
def print_details(c, filename):
    """Group decorated-scaffold SMILES by request size, print them and save as JSON.

    Args:
        c: Sequence of molecule lists; the list at position ``i`` is treated
            as holding molecules with ``i + 1`` attachment points.
        filename: Path of the JSON file to write (opened in "w" mode, so an
            existing file is overwritten).

    The mapping that is printed and dumped goes from the value returned by
    ``get_target_num_mols_for_given_mol`` to the list of SMILES strings that
    share that value. Lists with different attachment-point counts end up
    under the same key whenever their request size is equal.
    """
    total = sum(len(group) for group in c)
    out = {}

    for attachment_points, group in enumerate(c, start=1):
        for mol in group:
            smi = get_smiles(mol)
            request_size = get_target_num_mols_for_given_mol(attachment_points, total)
            out.setdefault(request_size, []).append(smi)

    print(out)

    with open(filename, "w") as fp:
        json.dump(out, fp)

In [33]:
# Enumerate attachment-point variants for the selected input molecules and
# save the result. The slice [6:7] selects only the entry at index 6 of
# `molecules` (labelled "Cephalotaxin" in the SMILES list).
# NOTE(review): the output path is the fixed string "benzene.json" although
# the selected molecule is not benzene, and with a wider slice every
# iteration would overwrite the same file -- confirm the intended filename.
for mol in molecules[6:7]:
    print("Input", get_smiles(mol))
    c = get_all_comb_of_mol_with_attachment_points(mol)
    print_details(c, "benzene.json")
    # Two separator lines mark the end of this molecule's output.
    print("-------------------------------------------------------------------------------------------")
    print("-------------------------------------------------------------------------------------------")
    
    

Input COC1=CC23CCCN2CCc2cc4c(cc2C3C1O)OCO4
{5: ['C1=C([*:1])CC2c3cc4c(cc3CCN3CCCC123)OCO4', 'C1=C([*:1])C23CCCN2CCc2cc4c(cc2C3C1)OCO4', 'C1=CC23C(C1)c1cc4c(cc1CCN2CCC3[*:1])OCO4', 'C1=CC23CC([*:1])CN2CCc2cc4c(cc2C3C1)OCO4', 'C1=CC23CCC([*:1])N2CCc2cc4c(cc2C3C1)OCO4', 'C1=CC23CCCN2C([*:1])Cc2cc4c(cc2C3C1)OCO4', 'C1=CC23CCCN2CC([*:1])c2cc4c(cc2C3C1)OCO4', 'C1=CC23CCCN2CCc2c(cc4c(c2[*:1])OCO4)C3C1', 'C1=CC23CCCN2CCc2cc4c(c([*:1])c2C3C1)OCO4', 'C1=CC23CCCN2CCc2cc4c(cc2C3([*:1])C1)OCO4', 'C1=CC23CCCN2CCc2cc4c(cc2C3C1[*:1])OCO4', 'C1=CC23CCCN2CCc2cc4c(cc2C3C1)OC([*:1])O4'], 6: ['c1c2c(cc3c1OCO3)C1CC([*:1])=C([*:2])C13CCCN3CC2', 'C1=C([*:1])CC2c3cc4c(cc3CCN3CCC([*:2])C123)OCO4', 'C1=C([*:1])CC2c3cc4c(cc3CCN3CC([*:2])CC123)OCO4', 'C1=C([*:1])CC2c3cc4c(cc3CCN3C([*:2])CCC123)OCO4', 'C1=C([*:1])CC2c3cc4c(cc3CC([*:2])N3CCCC123)OCO4', 'C1=C([*:1])CC2c3cc4c(cc3C([*:2])CN3CCCC123)OCO4', 'C1=C([*:1])CC2c3cc4c(c([*:2])c3CCN3CCCC123)OCO4', 'C1=C([*:1])CC2c3c(cc4c(c3[*:2])OCO4)CCN3CCCC123', 'C1=C([*:1])C