In [None]:
import sys
import os
sys.path.append(os.path.abspath('../Graph_Framework/'))

import numpy as np
import pandas as pd
import torch
import os
from tqdm import tqdm
from rdkit.Chem import SDMolSupplier

from metrics import *

def get_sdf_molsuppliers(paths):
    """Open every file of every folder in ``paths`` as an RDKit ``SDMolSupplier``.

    Parsing is permissive (``sanitize=False``, ``removeHs=False``,
    ``strictParsing=False``); callers decide how to treat malformed entries.

    Args:
        paths: iterable of directory paths, one per model / experiment.

    Returns:
        A list with one inner list per folder. Each inner list holds one
        supplier per file, in ``os.listdir`` order.
        NOTE(review): that order is arbitrary and filesystem-dependent, and
        the hard-coded file indices used for visualisation depend on it.
    """
    def _open_folder(folder):
        # One supplier per directory entry, same order as os.listdir.
        return [
            SDMolSupplier(os.path.join(folder, file), sanitize=False, removeHs=False, strictParsing=False)
            for file in os.listdir(folder)
        ]

    return [_open_folder(folder) for folder in paths]

def get_file_content(paths):
    """Deserialise every file of every folder in ``paths`` with ``torch.load``.

    Args:
        paths: iterable of directory paths.

    Returns:
        A list with one inner list per folder, holding the loaded object of
        each file in ``os.listdir`` order.

    NOTE(review): ``torch.load`` unpickles its input (unless the installed
    torch defaults to ``weights_only``); only point this at trusted folders.
    """
    loaded = []
    for folder in paths:
        entries = os.listdir(folder)
        loaded.append([torch.load(os.path.join(folder, entry)) for entry in entries])
    return loaded

def latex_table_line(line):
    r"""Join the cells of one row with `` & `` and end it with ``\\`` and a newline."""
    return " & ".join(line) + r"\\" + "\n"

def latex_table(names, columns, headers, index):
    r"""Render a LaTeX tabular body whose cells read ``mean $\pm$ std``.

    Args:
        names: row labels, one per model; ``names[i]`` pairs with ``columns[i]``.
        columns: per-model arrays indexed as ``columns[i][index]``. ``index``
            is a list, so these are expected to be numpy arrays (fancy indexing).
            Each selected entry is reduced with ``.mean()`` and ``.std()``.
        headers: header cells, emitted verbatim as the first row.
        index: list of metric positions to include, in column order.

    Returns:
        The header row followed by one row per model, each ending in ``\\``.
    """
    rows = [latex_table_line(headers)]
    for i, name in enumerate(names):
        # rf-string: in a plain f-string "\p" is an invalid escape sequence
        # (DeprecationWarning, SyntaxWarning on Python 3.12+).
        cells = [rf"{x.mean():.2f} $\pm$ {x.std():.2f}" for x in columns[i][index]]
        rows.append(latex_table_line([name] + cells))
    return "".join(rows)

def plot_counts(columns, names, index, headers, xlabel, ylabel):
    """Grouped bar chart of how each model's counts split across categories.

    For every model ``i``, each selected row of ``columns[i][index]`` is
    averaged with ``.mean()``. The resulting values are normalised so that
    each model's categories sum to 1. Bars are grouped by category, one bar
    per model.

    Args:
        columns: per-model arrays, fancy-indexed with the list ``index``.
        names: model labels (legend entries), one per model in ``columns`` order.
        index: list of row positions to plot, one per category.
        headers: category labels; ``headers[k]`` labels ``index[k]``.
        xlabel: name of the category column, also used as the x-axis label.
        ylabel: y-axis label.

    Returns:
        The matplotlib Axes produced by ``DataFrame.plot``, so callers can set a
        title or save the figure.
    """
    # (n_models, n_categories) matrix of mean counts.
    means = np.array(
        [[x.mean() for x in columns[i][index]] for i in range(len(names))],
        dtype=float,
    )
    # Row-normalise so each model's categories sum to 1.
    fractions = means / means.sum(axis=1, keepdims=True)
    # One DataFrame row per category: [label, fraction for model 0, model 1, ...].
    rows = [[headers[k]] + fractions[:, k].tolist() for k in range(fractions.shape[1])]
    df = pd.DataFrame(rows, columns=[xlabel] + list(names))
    return df.plot(x=xlabel, ylabel=ylabel, kind="bar", stacked=False)
    

# Paths

In [None]:
# One (sample folder, display label) pair per trained model. Keeping them
# side by side means `paths` and `names` cannot drift out of sync.
runs = [
    ('../experiments/cat_cat_H_cosine/samples', 'Categorical'),
    ('../experiments/cat_con_H_cosine/samples', 'Continuous'),
    ('../experiments/cat_cat_noH_cosine/samples', 'Categorical (No H)'),
    ('../experiments/cat_con_noH_cosine/samples', 'Continuous (No H)'),
]
paths = [folder for folder, _ in runs]
names = [label for _, label in runs]

# Validity, Uniqueness and Novelty

In [None]:
# Dataset SMILES: the reference set for the novelty metrics.
dset_smiles = load_smiles('./smiles.txt')


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when ``denominator`` is 0.

    A sample file with no (valid / unique) molecules would otherwise raise
    ZeroDivisionError and abort the whole evaluation loop.
    """
    return numerator / denominator if denominator else float('nan')


# content[m][f] collects nine ratios for sample file f of model m:
#   0 valid/samples   1 unique/samples   2 unique/valid
#   3 novel/samples   4 novel/valid      5 novel/unique
#   6-8 as 3-5, but novelty from compute_novelty(..., remove_h=True)
content = []
columns = get_sdf_molsuppliers(paths)
for row in columns:
    content.append([])
    for col in tqdm(row):
        mols = list(col)
        # `all_mols`, not `all`: binding `all` at notebook top level shadows
        # the builtin all() for every later cell.
        valid, all_mols = compute_validity(mols)
        unique = compute_uniqueness(valid)
        novel = compute_novelty(unique, dset_smiles)
        novel_h = compute_novelty(unique, dset_smiles, remove_h=True)

        n_samples, n_valid, n_unique = len(all_mols), len(valid), len(unique)
        n_novel, n_novel_h = len(novel), len(novel_h)
        content[-1].append([
            _safe_ratio(n_valid, n_samples),
            _safe_ratio(n_unique, n_samples),
            _safe_ratio(n_unique, n_valid),
            _safe_ratio(n_novel, n_samples),
            _safe_ratio(n_novel, n_valid),
            _safe_ratio(n_novel, n_unique),
            _safe_ratio(n_novel_h, n_samples),
            _safe_ratio(n_novel_h, n_valid),
            _safe_ratio(n_novel_h, n_unique),
        ])

# (n_models, n_files, 9) -> (n_models, 9 metrics, n_files), in percent.
content = np.array(content).transpose(0, 2, 1) * 100

In [None]:
# Each data row is one model-name cell followed by three "mean $\pm$ std" cells,
# so the header needs exactly one label cell. It previously had two
# ("Distribution", "Schedule"), shifting every metric header one column right.
# Rows of `content` (built in the validity cell):
#   0-2 valid/samples, unique/samples, unique/valid
#   3-5 novel/samples, novel/valid, novel/unique
#   6-8 the same novelty ratios computed with remove_h=True
label_header = r"\textbf{Model}"
validity_headers = [r"$\boldsymbol{V}$", r"$\boldsymbol{U_s}$", r"$\boldsymbol{U_v}$"]
novelty_headers = [r"$\boldsymbol{N_s}$", r"$\boldsymbol{N_v}$", r"$\boldsymbol{N_u}$"]

for index, metric_headers in [
    ([0, 1, 2], validity_headers),
    ([3, 4, 5], novelty_headers),
    # NOTE(review): same symbols as the previous table; consider marking the
    # remove_h=True variant so the two novelty tables are distinguishable.
    ([6, 7, 8], novelty_headers),
]:
    headers = [label_header] + metric_headers
    print(latex_table(names, content, headers, index))

# RMSD & Energy

In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem
import numpy as np

# For every generated molecule, record:
#   - MMFF94 energy before (E_0) and after (E_min) minimisation,
#   - per-atom strain (E_0 - E_min) / n_atoms,
#   - AllChem.GetBestRMS between the input and the minimised geometry.
raws = []      # one (n_kept, 5) array per sample file, across all models
results = []   # results[m][f]: column means of that file's array
suppliers = get_sdf_molsuppliers(paths)
for model in suppliers:
    results.append([])
    for supplier in tqdm(model):
        values = []
        for i, mol in enumerate(supplier):
            try:
                Chem.SanitizeMol(mol)
                mol_new = mol.__copy__()

                mp = AllChem.MMFFGetMoleculeProperties(mol_new, mmffVariant='MMFF94')
                ff = AllChem.MMFFGetMoleculeForceField(mol_new, mp)
                E_0 = ff.CalcEnergy()
                ff.Minimize()
                E_min = ff.CalcEnergy()

                rms = AllChem.GetBestRMS(mol, mol_new)
            except Exception:
                # Best-effort: skip any molecule whose sanitisation, MMFF setup
                # or RMSD raises. `except Exception` (not a bare `except:`) so
                # KeyboardInterrupt still stops this long loop.
                continue

            # Drop NaN energies and force-field blow-ups.
            if np.isnan(E_0) or np.isnan(E_min) or E_0 > 1e9:
                continue

            values.append([i, E_0, E_min, (E_0 - E_min) / mol.GetNumAtoms(), rms])

        # Columns: molecule index, E_0, E_min, strain, RMSD. reshape keeps the
        # array 2-D when nothing survived, so the column indexing below does
        # not raise (the file's means then come out as NaN).
        values = np.array(values, dtype=float).reshape(-1, 5)
        # Negative strain (minimised energy above the starting energy) -> 0.
        values[:, 3] = np.clip(values[:, 3], a_min=0, a_max=None)
        raws.append(values)
        results[-1].append(values.transpose(1, 0).mean(-1))

# (n_models, n_files, 5) -> (n_models, 5 columns, n_files)
results = np.array(results).transpose(0, 2, 1)

In [None]:
# results[m] has one row per column of the per-molecule arrays:
#   0 molecule index, 1 E_0, 2 E_min, 3 strain, 4 RMSD.
# Row 0 is not a metric and is skipped. Only rows 1-4 exist (index 5 raised
# IndexError), matching the four value headers after "Model".
headers = [r"\textbf{Model}", r"\textbf{Energy}", r"\textbf{Minimised}", r"\textbf{Strain}", r"\textbf{RMSD}"]
index = [1, 2, 3, 4]
print(latex_table(names, results, headers, index))

# Visualise Selected Generated Molecules

In [None]:
import py3Dmol
from rdkit import Chem
from rdkit.Chem import AllChem

suppliers = get_sdf_molsuppliers(paths)
# Hand-picked examples: nice examples, good structure but high energy, stretched bonds.
# indices[m] lists the picks for model m (paths[m] / names[m]). Each pick is
# [file index within that model's sample folder, molecule index within that file].
# NOTE(review): file indices follow os.listdir order, which is filesystem-dependent.
indices = [
    [[0, 5028], [0, 2169], [0, 217]],
    [[0, 186], [1, 5764], [0, 1596]],
    [[0, 155], [1, 5542], [0, 3]],
    [[0, 904], [0, 2644], [0, 5884]],
]

h = 300  # pixel size of one grid cell
n_rows = max(len(picks) for picks in indices)  # one row per pick
n_cols = len(indices)                          # one column per model
v = py3Dmol.view(width=n_cols * h, height=n_rows * h, viewergrid=(n_rows, n_cols), linked=False)
v.removeAllModels()

for i, picks in enumerate(indices):
    for j, (k, l) in enumerate(picks):
        mol = suppliers[i][k][l]

        # Minimise a copy with MMFF94 so the generated geometry (model 0) can
        # be overlaid with its relaxed counterpart (model 1, light green).
        Chem.SanitizeMol(mol)
        mol_new = mol.__copy__()
        mp = AllChem.MMFFGetMoleculeProperties(mol_new, mmffVariant='MMFF94')
        ff = AllChem.MMFFGetMoleculeForceField(mol_new, mp)
        E_0 = ff.CalcEnergy()
        ff.Minimize()
        E_min = ff.CalcEnergy()

        rms = AllChem.GetBestRMS(mol, mol_new)
        print(f"viewer ({j}, {i}) {names[i]}: RMSD={rms:.3f}  E_0={E_0:.2f}  E_min={E_min:.2f}")

        cell = (j, i)  # (row, column) in the viewer grid
        v.addModel(Chem.MolToMolBlock(mol), viewer=cell)
        v.setStyle({'model': 0}, {'stick': {}}, viewer=cell)
        v.addModel(Chem.MolToMolBlock(mol_new), viewer=cell)
        v.setStyle({'model': 1}, {'stick': {'color': 'lightgreen', 'opacity':0.65}}, viewer=cell)
        v.setBackgroundColor('white', viewer=cell)
        v.zoomTo(viewer=cell)
v.render()


# Dataset RMSD & Energy

In [None]:
from rdkit import Chem  # previously only in scope if an earlier cell had imported it
from rdkit.Chem import AllChem

# MMFF94 energy / strain / RMSD for the QM9 reference molecules, computed with
# the same procedure as for the generated samples.
supplier = Chem.SDMolSupplier('../Graph_Framework/data/qm9/raw/gdb9.sdf', sanitize=False, removeHs=False, strictParsing=False)
values = []
for i, mol in enumerate(tqdm(supplier)):
    try:
        Chem.SanitizeMol(mol)
        mol_new = mol.__copy__()

        mp = AllChem.MMFFGetMoleculeProperties(mol_new, mmffVariant='MMFF94')
        ff = AllChem.MMFFGetMoleculeForceField(mol_new, mp)
        E_0 = ff.CalcEnergy()
        ff.Minimize()
        E_min = ff.CalcEnergy()

        rms = AllChem.GetBestRMS(mol, mol_new)
    except Exception:
        # Best-effort skip; `except Exception` (not a bare `except:`) keeps
        # KeyboardInterrupt working during this long loop.
        continue

    # Drop NaN energies and force-field blow-ups.
    if np.isnan(E_0) or np.isnan(E_min) or E_0 > 1e9:
        continue

    values.append([i, E_0, E_min, (E_0 - E_min) / mol.GetNumAtoms(), rms])

# Columns: molecule index, E_0, E_min, strain, RMSD. reshape keeps the array
# 2-D even if no molecule survived.
values = np.array(values, dtype=float).reshape(-1, 5)
# Clip negative strain to 0, as the generated-sample cell does, so dataset and
# sample strain statistics are comparable.
values[:, 3] = np.clip(values[:, 3], a_min=0, a_max=None)
print(values.shape)
values = values.transpose(1, 0)
values.mean(-1), values.std(-1)