In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from safe.tokenizer import SAFETokenizer
from safe.trainer.model import SAFEDoubleHeadsModel
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
import os
from molfeat.trans.pretrained.hf_transformers import HFModel
import datamol as dm
import safe as sf
import numpy as np
import pandas as pd
from tqdm import tqdm

Failed to find the pandas get_adjustment() function to patch
Failed to patch pandas - PandasTools will have limited functionality


In [2]:
# Set TOKENIZERS_PARALLELISM=false for this process before any tokenizer is
# created; presumably this silences the Hugging Face tokenizers fork warning
# (TODO confirm that is the intent).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
# Load the trained model from a local checkpoint path.
# The tokenizer is loaded separately from its own file.
model_checkpoint = "./checkpoint-1712000"
safe_model = SAFEDoubleHeadsModel.from_pretrained(model_checkpoint)

In [4]:
# Load the SAFE tokenizer from a local tokenizer.json file.
model_tokenizer = "./tokenizer.json"
safe_tokenizer = SAFETokenizer.from_pretrained(model_tokenizer)

In [5]:
# Set the device to GPU if available, else CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# As the last expression of the cell, this call's return value is echoed as
# the cell output (the model repr).
safe_model.to(device)

SAFEDoubleHeadsModel(
  (transformer): GPT2Model(
    (wte): Embedding(1880, 512)
    (wpe): Embedding(1024, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-3): 4 x GPT2Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=512, out_features=1880, bias=False)
  (multiple_choice_head): PropertyHead(
    (summary): Linear(in_features=512, out_features=128, bias=True)
    (

In [6]:
# Wrap the loaded model and tokenizer in a SAFEDesign object; the
# (commented-out) generation cell calls designer.de_novo_generation on it.
designer = sf.SAFEDesign(
    model=safe_model,
    tokenizer=safe_tokenizer,
    verbose=True,
)

In [7]:
# generated_smiles = designer.de_novo_generation(sanitize=True, n_samples_per_trial=10000, early_stopping=False)

In [8]:
# type(generated_smiles)

In [9]:
# Save the generated SMILES
# with open("generated_smiles_10k.md", "w") as f:
#     for smiles in generated_smiles:
#         f.write(smiles + "\n")

In [10]:
# Load the generated SMILES (one per line).
# Blank lines are skipped so that a trailing newline or an empty row does not
# become an empty-string SMILES entry, which would otherwise be counted in the
# denominators of the metrics computed later.
with open("generated_smiles_10k.md", "r") as f:
    generated_smiles = [line.strip() for line in f if line.strip()]

In [11]:
from rdkit.Chem import QED, Crippen
from rdkit import DataStructs
from rdkit.Chem import RDConfig
import os
import sys
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
# The SA_Score contrib directory is now on sys.path, so `sascorer` can be imported.
import sascorer

In [12]:
def calculate_diversity(molecules, radius=2, nBits=2048):
    """Return the mean pairwise Tanimoto distance of a set of molecules.

    Every molecule is encoded as a Morgan bit-vector fingerprint and each
    unordered pair is compared exactly once; the distance of a pair is
    ``1 - TanimotoSimilarity``.

    Args:
        molecules: Iterable of molecules accepted by
            ``AllChem.GetMorganFingerprintAsBitVect``.
        radius: Morgan fingerprint radius.
        nBits: Length of the fingerprint bit vector.

    Returns:
        The average distance over all unordered pairs, or ``0`` when fewer
        than two molecules are supplied (there is no pair to compare).
    """
    print("Calculating fingerprints...")
    fingerprints = [
        AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        for mol in tqdm(molecules, desc="Generating fingerprints", unit="molecule")
    ]

    print("Calculating pairwise diversities...")
    n = len(fingerprints)
    total_pairs = (n * (n - 1)) // 2
    if total_pairs == 0:
        return 0

    diversity = 0.0
    with tqdm(total=total_pairs, desc="Calculating diversity", unit="pair") as pbar:
        for i in range(n - 1):
            # One bulk call per row replaces a Python-level similarity call
            # (and a progress-bar refresh) for every single pair.
            similarities = DataStructs.BulkTanimotoSimilarity(
                fingerprints[i], fingerprints[i + 1:]
            )
            diversity += sum(1 - similarity for similarity in similarities)
            pbar.update(len(similarities))

    # Normalize by the number of pairwise comparisons.
    return diversity / total_pairs

In [13]:
def calculate_novelty(generated_mols, reference_mols, radius=2, nBits=2048, threshold=0.9):
    """Return the fraction of generated molecules that are novel w.r.t. a reference set.

    A generated molecule counts as novel when its Tanimoto similarity (on
    Morgan bit-vector fingerprints) to *every* reference molecule is strictly
    below ``threshold``.

    Args:
        generated_mols: Iterable of generated molecules.
        reference_mols: Iterable of reference molecules.
        radius: Morgan fingerprint radius.
        nBits: Length of the fingerprint bit vector.
        threshold: Similarity at or above which a generated molecule is
            considered a near-duplicate of a reference molecule.

    Returns:
        ``novel / total`` over the generated molecules, or ``0`` when no
        generated molecule is supplied. With an empty reference set every
        generated molecule is novel.
    """
    ref_fps = [AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits) for mol in reference_mols]
    gen_fps = [AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits) for mol in generated_mols]

    if not gen_fps:
        return 0

    novel_count = 0
    for gen_fp in tqdm(gen_fps, desc="Calculating novelty", unit="molecule"):
        # One bulk call per generated molecule instead of one Python-level
        # similarity call per (generated, reference) pair.
        similarities = DataStructs.BulkTanimotoSimilarity(gen_fp, ref_fps) if ref_fps else []
        if all(similarity < threshold for similarity in similarities):
            novel_count += 1

    return novel_count / len(gen_fps)

In [14]:
# Build the reference set from the MOSES train and test CSV files: the two
# frames are concatenated and the distinct values of their "SMILES" column
# are kept.
train_set = pd.read_csv("../../../train_from_scratch/Datasets/MOSES/train.csv")
test_set = pd.read_csv("../../../train_from_scratch/Datasets/MOSES/test.csv")
all_smiles = pd.concat([train_set, test_set])["SMILES"].unique()

# Plain Python list of the unique reference SMILES.
moses_smiles = all_smiles.tolist()

In [15]:
def evaluate_generated_molecules(generated_smiles, reference_smiles):
    """Compute summary metrics for generated SMILES against a reference set.

    SMILES for which ``Chem.MolFromSmiles`` returns ``None`` are dropped from
    both inputs before the molecule-level metrics are computed.

    Args:
        generated_smiles: Sized collection of generated SMILES strings.
        reference_smiles: Iterable of reference SMILES strings.

    Returns:
        dict with the keys
            ``validity``   - parsable generated SMILES / all generated SMILES,
            ``uniqueness`` - distinct ``Chem.MolToSmiles`` strings of the
                             parsable molecules / all generated SMILES,
            ``diversity``  - ``calculate_diversity`` of the parsable molecules,
            ``novelty``    - ``calculate_novelty`` against the reference molecules,
            ``qed_mean``   - mean ``QED.qed`` of the parsable molecules,
            ``sas_mean``   - mean ``sascorer.calculateScore`` of the parsable molecules.
        Ratios and means are 0 when there is nothing to divide by.
    """
    print("Starting molecule evaluation...")

    print("Converting reference SMILES to molecules...")
    reference_progress = tqdm(reference_smiles, desc="Processing reference molecules", unit="molecule")
    reference_mols = [
        ref_mol
        for ref_mol in (Chem.MolFromSmiles(ref_smi) for ref_smi in reference_progress)
        if ref_mol is not None
    ]

    print("Processing generated molecules...")
    valid_mols, qed_scores, sas_scores = [], [], []
    unique_smiles = set()
    for smi in tqdm(generated_smiles, desc="Processing generated molecules", unit="molecule"):
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            continue  # unparsable SMILES contribute only to the denominators
        valid_mols.append(mol)
        unique_smiles.add(Chem.MolToSmiles(mol))
        qed_scores.append(QED.qed(mol))
        sas_scores.append(sascorer.calculateScore(mol))

    # Both ratios share the same denominator: the number of generated SMILES.
    if generated_smiles:
        n_generated = len(generated_smiles)
        validity = len(valid_mols) / n_generated
        uniqueness = len(unique_smiles) / n_generated
    else:
        validity = uniqueness = 0

    print("Calculating diversity...")
    diversity = calculate_diversity(valid_mols)

    print("Calculating novelty...")
    novelty = calculate_novelty(valid_mols, reference_mols)

    def _average(scores):
        # Mean of a score list, 0 for an empty list.
        return sum(scores) / len(scores) if scores else 0

    return {
        "validity": validity,
        "uniqueness": uniqueness,
        "diversity": diversity,
        "novelty": novelty,
        "qed_mean": _average(qed_scores),
        "sas_mean": _average(sas_scores),
    }

In [16]:
# Run the full evaluation. This parses every reference SMILES and compares
# each generated molecule with the reference set, so it is the slow cell of
# the notebook.
evaluation_results = evaluate_generated_molecules(generated_smiles, moses_smiles)

Starting molecule evaluation...
Converting reference SMILES to molecules...


Processing reference molecules:  27%|██▋       | 477295/1760737 [00:37<01:34, 13570.69molecule/s]

In [None]:
# Report every metric to three decimals; a single print call with sep="\n"
# puts each metric on its own line.
print(
    f"Validity of paper is 1, computed is {evaluation_results['validity']:.3f}",
    f"Uniqueness of paper is 0.999, computed is {evaluation_results['uniqueness']:.3f}",
    f"Diversity of paper is 0.864, computed is {evaluation_results['diversity']:.3f}",
    f"Novelty: {evaluation_results['novelty']:.3f}",
    f"QED mean is {evaluation_results['qed_mean']:.3f}",
    f"SAS mean is {evaluation_results['sas_mean']:.3f}",
    sep="\n",
)

Validity of paper is 1, computed is 1.0
Uniqueness of paper is 0.999, computed is 0.9990945674044266
Diversity of paper is 0.864, computed is 0.8649861835608464
QED mean is 0.8026260562951251
SAS mean is 2.47734016850271


In [None]:
# Persist the six report lines to disk with one writelines call.
with open("safe_small_results.md", "w") as f:
    f.writelines([
        f"Validity of paper is 1, computed is {evaluation_results['validity']:.3f}\n",
        f"Uniqueness of paper is 0.999, computed is {evaluation_results['uniqueness']:.3f}\n",
        f"Diversity of paper is 0.864, computed is {evaluation_results['diversity']:.3f}\n",
        f"Novelty: {evaluation_results['novelty']:.3f}\n",
        f"QED mean is {evaluation_results['qed_mean']:.3f}\n",
        f"SAS mean is {evaluation_results['sas_mean']:.3f}\n",
    ])

More nuanced evaluation

In [None]:
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Descriptors, Crippen, rdMolDescriptors, Lipinski
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
def smiles_to_mol(smiles_list):
    """Parse SMILES strings into molecules, dropping the ones that fail to parse.

    Each SMILES is passed to ``Chem.MolFromSmiles`` exactly once; parsing
    it a second time for the filter would double the work.

    Args:
        smiles_list: Iterable of SMILES strings.

    Returns:
        List of the molecules for which ``Chem.MolFromSmiles`` did not return
        ``None``, in input order.
    """
    parsed = (Chem.MolFromSmiles(smi) for smi in smiles_list)
    return [mol for mol in parsed if mol is not None]

In [None]:
# Parse both SMILES collections into molecules; smiles_to_mol drops entries
# that fail to parse, so these lists can be shorter than their inputs.
generated_mols = smiles_to_mol(generated_smiles)
moses_mols = smiles_to_mol(moses_smiles)

In [None]:
def calculate_properties(mol_list):
    """Compute per-molecule descriptors and return them column-wise.

    Args:
        mol_list: Iterable of RDKit molecules.

    Returns:
        dict mapping 'MW', 'LogP', 'HBD', 'HBA', 'TPSA', 'RotBonds' and 'QED'
        to lists with one value per molecule, in input order.
        NOTE(review): 'MW' is ``Descriptors.ExactMolWt``; confirm that the
        exact mass (rather than ``MolWt``) is what the plots should show.
    """
    columns = ('MW', 'LogP', 'HBD', 'HBA', 'TPSA', 'RotBonds', 'QED')

    # One tuple per molecule, with fields in the same order as `columns`.
    per_molecule = [
        (
            Descriptors.ExactMolWt(mol),
            Crippen.MolLogP(mol),
            Lipinski.NumHDonors(mol),
            Lipinski.NumHAcceptors(mol),
            Descriptors.TPSA(mol),
            rdMolDescriptors.CalcNumRotatableBonds(mol),
            Descriptors.qed(mol),
        )
        for mol in mol_list
    ]

    # Transpose the rows into one list per descriptor.
    return {
        column: [row[position] for row in per_molecule]
        for position, column in enumerate(columns)
    }

In [None]:
# Descriptor tables (dict of lists) for the generated and the MOSES molecules.
generated_properties = calculate_properties(generated_mols)
moses_properties = calculate_properties(moses_mols)

In [None]:
def plot_property_distributions(gen_props, ref_props, property_name, xlabel, ylabel="Density", plot_type='line'):
    """Plot one property's distribution for generated vs. MOSES molecules and save it.

    The figure is written to ``<property_name lower-cased, spaces -> '_'>_distribution.png``
    in the current working directory and then closed.

    Args:
        gen_props: Sequence of property values for the generated molecules.
        ref_props: Sequence of property values for the reference molecules.
        property_name: Human-readable name, used in the title and the file name.
        xlabel: Label of the x axis.
        ylabel: Label of the y axis. In ``'bar'`` mode the plotted heights are
            per-bin fractions of each data set.
        plot_type: ``'line'`` for kernel density estimates, ``'bar'`` for
            side-by-side histograms.

    Raises:
        ValueError: If ``plot_type`` is neither ``'line'`` nor ``'bar'``
            (previously an empty figure was saved silently).
    """
    if plot_type not in ('line', 'bar'):
        raise ValueError(f"Unsupported plot_type {plot_type!r}; expected 'line' or 'bar'")

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        if plot_type == 'line':
            sns.kdeplot(gen_props, label='Generated', color='blue', ax=ax)
            sns.kdeplot(ref_props, label='MOSES', color='red', ax=ax)
        else:
            gen_values = np.asarray(gen_props, dtype=float)
            ref_values = np.asarray(ref_props, dtype=float)

            # Shared bin edges spanning BOTH data sets, so no reference value
            # falls outside the histogram range.
            bin_edges = np.histogram_bin_edges(np.concatenate([gen_values, ref_values]), bins=20)
            gen_counts, _ = np.histogram(gen_values, bins=bin_edges)
            ref_counts, _ = np.histogram(ref_values, bins=bin_edges)

            # Fractions per bin; max(..., 1) avoids dividing by zero on empty input.
            gen_fraction = gen_counts / max(len(gen_values), 1)
            ref_fraction = ref_counts / max(len(ref_values), 1)

            # Each bin is split in two halves (generated | reference); deriving
            # the bar width from the bin width keeps neighbouring bars from
            # overlapping whatever the data range is.
            bar_width = np.diff(bin_edges) / 2
            ax.bar(bin_edges[:-1], gen_fraction, bar_width, align='edge',
                   label='Generated', alpha=0.7, color='blue')
            ax.bar(bin_edges[:-1] + bar_width, ref_fraction, bar_width, align='edge',
                   label='MOSES', alpha=0.7, color='red')

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(f'Distribution of {property_name}')
        ax.legend()
        fig.savefig(f'{property_name.lower().replace(" ", "_")}_distribution.png')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

In [None]:
def plot_property_boxplots(gen_props, ref_props, properties=None):
    """Save grouped box plots of every property for generated vs. MOSES molecules.

    The data are reshaped into a long-form frame with one row per individual
    value, which ``sns.boxplot`` needs for ``x``/``y``/``hue`` column names.
    Storing a whole list per row does not give it numeric values to plot.
    The figure is written to ``molecular_properties_boxplots.png``.

    Args:
        gen_props: dict mapping property key -> sequence of values (generated set).
        ref_props: dict mapping property key -> sequence of values (reference set).
        properties: Optional iterable of tuples whose first two items are
            ``(property key, display name)``; extra items are ignored.
            Defaults to the notebook-level ``properties_to_plot``.

    NOTE: all properties share a single y axis, so properties with large
    values dominate the scale of the figure.
    """
    if properties is None:
        # Looked up at call time: the notebook assigns this list in a cell
        # that comes after this definition.
        properties = properties_to_plot

    frames = [
        pd.DataFrame({'Value': list(values[prop])}).assign(Property=name, Dataset=dataset)
        for prop, name, *_ in properties
        for dataset, values in (('Generated', gen_props), ('MOSES', ref_props))
    ]
    if frames:
        df = pd.concat(frames, ignore_index=True)
    else:
        df = pd.DataFrame(columns=['Value', 'Property', 'Dataset'])

    fig, ax = plt.subplots(figsize=(15, 10))
    try:
        sns.boxplot(x='Property', y='Value', hue='Dataset', data=df, ax=ax)
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        ax.set_title('Box Plots of Molecular Properties')
        fig.tight_layout()
        fig.savefig('molecular_properties_boxplots.png')
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

In [None]:
# Plot configuration, one tuple per property:
#   (key in the property dicts, display name, x-axis label, plot type)
# The plot type is 'line' (density curve) or 'bar' (histogram), as handled by
# plot_property_distributions.
properties_to_plot = [
    ('MW', 'Molecular Weight', 'Molecular Weight (Da)', 'line'),
    ('LogP', 'LogP', 'LogP', 'line'),
    ('HBD', 'H-Bond Donors', 'Number of H-Bond Donors', 'bar'),
    ('HBA', 'H-Bond Acceptors', 'Number of H-Bond Acceptors', 'bar'),
    ('TPSA', 'Topological Polar Surface Area', 'TPSA (Å²)', 'line'),
    ('RotBonds', 'Rotatable Bonds', 'Number of Rotatable Bonds', 'bar'),
    ('QED', 'QED', 'Quantitative Estimate of Drug-likeness', 'line')
]

# One saved figure per configured property, generated vs. MOSES.
for prop, name, xlabel, plot_type in properties_to_plot:
    plot_property_distributions(generated_properties[prop], moses_properties[prop], name, xlabel, plot_type=plot_type)


`shade` is now deprecated in favor of `fill`; setting `fill=True`.
This will become an error in seaborn v0.14.0; please update your code.

  sns.kdeplot(gen_props, shade=True, label='Generated', color='blue')

`shade` is now deprecated in favor of `fill`; setting `fill=True`.
This will become an error in seaborn v0.14.0; please update your code.

  sns.kdeplot(ref_props, shade=True, label='MOSES', color='red')

`shade` is now deprecated in favor of `fill`; setting `fill=True`.
This will become an error in seaborn v0.14.0; please update your code.

  sns.kdeplot(gen_props, shade=True, label='Generated', color='blue')

`shade` is now deprecated in favor of `fill`; setting `fill=True`.
This will become an error in seaborn v0.14.0; please update your code.

  sns.kdeplot(ref_props, shade=True, label='MOSES', color='red')

`shade` is now deprecated in favor of `fill`; setting `fill=True`.
This will become an error in seaborn v0.14.0; please update your code.

  sns.kdeplot(gen_props, shade

In [None]:
# Save the combined box-plot figure; the function reads properties_to_plot,
# so the cell that defines that list must have been run first.
plot_property_boxplots(generated_properties, moses_properties)

In [None]:
from tabulate import tabulate

def print_summary_stats(gen_props, ref_props, properties=None):
    """Print a grid table with mean ± standard deviation of every property.

    Args:
        gen_props: dict mapping property key -> sequence of values (generated set).
        ref_props: dict mapping property key -> sequence of values (reference set).
        properties: Optional iterable of tuples whose first two items are
            ``(property key, display name)``; extra items are ignored.
            Defaults to the notebook-level ``properties_to_plot``.
    """
    if properties is None:
        properties = properties_to_plot

    def _mean_std(values):
        # "mean ± std" with two decimals; np.std is the population standard deviation.
        return f"{np.mean(values):.2f} ± {np.std(values):.2f}"

    headers = ["Property", "Generated (mean ± std)", "MOSES (mean ± std)"]
    # `*_` swallows any trailing items: properties_to_plot holds 4-tuples, so
    # unpacking exactly three names raised ValueError on the first row.
    table_rows = [
        [name, _mean_std(gen_props[prop]), _mean_std(ref_props[prop])]
        for prop, name, *_ in properties
    ]

    print(tabulate(table_rows, headers=headers, tablefmt="grid"))

In [None]:
# Tabulate mean ± std of each configured property for both data sets.
print_summary_stats(generated_properties, moses_properties)

+--------------------------------+--------------------------+----------------------+
| Property                       | Generated (mean ± std)   | MOSES (mean ± std)   |
| Molecular Weight               | 310.36 ± 29.27           | 306.92 ± 28.05       |
+--------------------------------+--------------------------+----------------------+
| LogP                           | 2.47 ± 0.96              | 2.44 ± 0.93          |
+--------------------------------+--------------------------+----------------------+
| H-Bond Donors                  | 1.12 ± 0.84              | 1.12 ± 0.83          |
+--------------------------------+--------------------------+----------------------+
| H-Bond Acceptors               | 4.32 ± 1.43              | 4.22 ± 1.40          |
+--------------------------------+--------------------------+----------------------+
| Topological Polar Surface Area | 66.89 ± 18.70            | 65.83 ± 18.10        |
+--------------------------------+--------------------------+----

Scaffold analysis

In [None]:
from rdkit.Chem.Scaffolds import MurckoScaffold

In [None]:
def analyze_scaffolds(mol_list):
    """Count how many molecules share each Murcko scaffold.

    Args:
        mol_list: Iterable of RDKit molecules.

    Returns:
        dict mapping the scaffold's SMILES string to the number of molecules
        with that scaffold, in order of first appearance.
    """
    scaffold_keys = (
        Chem.MolToSmiles(MurckoScaffold.GetScaffoldForMol(mol))
        for mol in mol_list
    )

    counts = {}
    for key in scaffold_keys:
        if key in counts:
            counts[key] += 1
        else:
            counts[key] = 1
    return counts

In [None]:
# Scaffold -> molecule count for the generated and the MOSES molecules.
generated_scaffolds = analyze_scaffolds(generated_mols)
moses_scaffolds = analyze_scaffolds(moses_mols)

In [None]:
# Number of distinct scaffolds per data set, one line each.
print(
    f"Unique scaffolds in generated set: {len(generated_scaffolds)}",
    f"Unique scaffolds in MOSES set: {len(moses_scaffolds)}",
    sep="\n",
)

Unique scaffolds in generated set: 7299
Unique scaffolds in MOSES set: 77215


In [None]:
# Scaffold diversity: distinct scaffolds divided by the number of molecules,
# computed the same way for both data sets.
gen_scaffold_diversity, moses_scaffold_diversity = (
    len(scaffold_counts) / len(molecules)
    for scaffold_counts, molecules in (
        (generated_scaffolds, generated_mols),
        (moses_scaffolds, moses_mols),
    )
)

print(
    f"Scaffold diversity in generated set: {gen_scaffold_diversity:.4f}",
    f"Scaffold diversity in MOSES set: {moses_scaffold_diversity:.4f}",
    sep="\n",
)

Scaffold diversity in generated set: 0.7343
Scaffold diversity in MOSES set: 0.4385
