In [None]:
# TODO: calculate novelty

In [24]:
import torch
from rdkit import Chem
from rdkit.Chem import QED, AllChem, rdFingerprintGenerator
from rdkit import rdBase
rdBase.DisableLog('rdApp.warning')
import os
import numpy as np
import pandas as pd
from tqdm import tqdm
import pickle
import hashlib
from rdkit import DataStructs
from rdkit.Chem import RDConfig
import sys
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer
import matplotlib.pyplot as plt

In [25]:
def process_molecules(smiles_list, desc):
    """Process SMILES strings into RDKit molecules."""
    mols = []
    for smi in tqdm(smiles_list, desc=desc, unit="molecule"):
        mol = Chem.MolFromSmiles(smi)
        if mol is not None:
            mols.append(mol)
    return mols

In [26]:
def calculate_fingerprints(mols, radius=2, nBits=2048):
    """Calculate Morgan fingerprints for a list of molecules using MorganGenerator."""
    fingerprints = []
    morgan_gen = rdFingerprintGenerator.GetMorganGenerator(radius=radius)
    
    for mol in tqdm(mols, desc="Calculating fingerprints", unit="molecule"):
        fp = morgan_gen.GetFingerprint(mol)
        fingerprints.append(fp)
    
    return fingerprints

In [27]:
def calculate_diversity(fingerprints):
    print("Calculating pairwise diversities...")
    n = len(fingerprints)
    diversity = 0
    total_pairs = (n * (n - 1)) // 2
    
    with tqdm(total=total_pairs, desc="Calculating diversity", unit="pair") as pbar:
        for i in range(n):
            diversity += sum(1 - DataStructs.TanimotoSimilarity(fingerprints[i], fingerprints[j]) 
                             for j in range(i + 1, n))
            pbar.update(n - i - 1)
    
    return diversity / total_pairs if n > 1 else 0

In [28]:
def evaluate_generated_molecules(generated_smiles, specified_total):
    print("Starting molecule evaluation...")
    
    print("Processing generated molecules...")
    valid_mols = process_molecules(generated_smiles, "Processing generated molecules")
    gen_fps = calculate_fingerprints(valid_mols)
    
    # Calculate basic properties
    unique_smiles = set(Chem.MolToSmiles(mol) for mol in valid_mols)
    qed_scores = [QED.qed(mol) for mol in valid_mols]
    sas_scores = [sascorer.calculateScore(mol) for mol in valid_mols]
    
    validity = len(valid_mols) / specified_total
    uniqueness = len(unique_smiles) / len(generated_smiles) if generated_smiles else 0
    
    print("Calculating diversity...")
    diversity = calculate_diversity(gen_fps)
    
    qed_mean = sum(qed_scores) / len(qed_scores) if qed_scores else 0
    sas_mean = sum(sas_scores) / len(sas_scores) if sas_scores else 0
    
    return {
        "validity": validity,
        "uniqueness": uniqueness,
        "diversity": diversity,
        "qed_mean": qed_mean,
        "sas_mean": sas_mean,
    }

In [29]:
def load_smiles(filename):
    with open(filename, 'r') as f:
        return [line.strip() for line in f]

In [30]:
hybrid_smiles = load_smiles("Results/HYBRID_20M_dropout_little_10000_samples.txt")
safe_gpt_smiles = load_smiles("Results/SAFE_GPT_20M_10000_samples.txt")
ssm_smiles = load_smiles("Results/SSM_20M_little_dropout_10000_samples.txt")

In [31]:
models = {
    "MAMBA-HYBRID": hybrid_smiles,
    "SAFE-GPT": safe_gpt_smiles,
    "MAMBA": ssm_smiles
}

In [32]:
results = {}
for model_name, smiles_list in models.items():
    print(f"Evaluating {model_name}...")
    results[model_name] = evaluate_generated_molecules(smiles_list, 10000)

Evaluating MAMBA-HYBRID...
Starting molecule evaluation...
Processing generated molecules...


Processing generated molecules: 100%|██████████| 10000/10000 [00:00<00:00, 11918.59molecule/s]
Calculating fingerprints: 100%|██████████| 10000/10000 [00:00<00:00, 22482.16molecule/s]


Calculating diversity...
Calculating pairwise diversities...


Calculating diversity: 100%|██████████| 49995000/49995000 [01:35<00:00, 522834.66pair/s]


Evaluating SAFE-GPT...
Starting molecule evaluation...
Processing generated molecules...


Processing generated molecules: 100%|██████████| 9991/9991 [00:00<00:00, 12637.88molecule/s]
Calculating fingerprints: 100%|██████████| 9991/9991 [00:00<00:00, 23385.11molecule/s]


Calculating diversity...
Calculating pairwise diversities...


Calculating diversity: 100%|██████████| 49905045/49905045 [01:36<00:00, 517041.62pair/s]


Evaluating MAMBA...
Starting molecule evaluation...
Processing generated molecules...


Processing generated molecules: 100%|██████████| 10000/10000 [00:00<00:00, 12801.86molecule/s]
Calculating fingerprints: 100%|██████████| 10000/10000 [00:00<00:00, 12837.63molecule/s]


Calculating diversity...
Calculating pairwise diversities...


Calculating diversity: 100%|██████████| 49995000/49995000 [01:37<00:00, 513322.82pair/s]


In [33]:
df_results = pd.DataFrame(results).T
df_results = df_results.round(3)

In [34]:
df_results

Unnamed: 0,validity,uniqueness,diversity,qed_mean,sas_mean
MAMBA-HYBRID,1.0,0.999,0.86,0.813,2.417
SAFE-GPT,0.999,0.999,0.859,0.81,2.424
MAMBA,1.0,0.998,0.86,0.81,2.42


| Model        | Validity | Uniqueness | Diversity | QED Mean | SAS Mean |
|--------------|----------|------------|-----------|----------|----------|
| MAMBA-HYBRID | 1.000    | 0.997      | 0.856     | 0.816    | 2.357    |
| SAFE-GPT     | 0.994    | 1.000      | 0.866     | 0.801    | 2.500    |
| MAMBA        | 1.000    | 0.996      | 0.855     | 0.820    | 2.357    |

MAMBA seems to perform just as well as safe-gpt in most cases.

In [36]:
import random

In [37]:
def run_evaluation_with_seeds(model_smiles, num_seeds=5, num_samples=10000):
    results = []
    for seed in tqdm(range(num_seeds), desc="Running seeds"):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        
        # Sample from the full set of SMILES strings
        sampled_smiles = random.sample(model_smiles, num_samples)
        
        # Run your existing evaluation
        result = evaluate_generated_molecules(sampled_smiles, num_samples)
        results.append(result)
    
    # Calculate mean and standard deviation across seeds
    mean_results = {k: np.mean([r[k] for r in results]) for k in results[0]}
    std_results = {k: np.std([r[k] for r in results]) for k in results[0]}
    
    return mean_results, std_results

In [38]:
multi_seed_results = {}
for model_name, smiles_list in models.items():
    print(f"Evaluating {model_name} with multiple seeds...")
    mean_results, std_results = run_evaluation_with_seeds(smiles_list)
    multi_seed_results[model_name] = {
        "mean": mean_results,
        "std": std_results
    }

Evaluating MAMBA-HYBRID with multiple seeds...


Running seeds:   0%|          | 0/5 [00:00<?, ?it/s]

Starting molecule evaluation...
Processing generated molecules...


Processing generated molecules: 100%|██████████| 10000/10000 [00:01<00:00, 8804.74molecule/s]
Calculating fingerprints: 100%|██████████| 10000/10000 [00:00<00:00, 20117.08molecule/s]


Calculating diversity...
Calculating pairwise diversities...


Calculating diversity: 100%|██████████| 49995000/49995000 [01:43<00:00, 482369.76pair/s]
Running seeds:  20%|██        | 1/5 [02:10<08:42, 130.63s/it]

Starting molecule evaluation...
Processing generated molecules...


Processing generated molecules: 100%|██████████| 10000/10000 [00:00<00:00, 12539.41molecule/s]
Calculating fingerprints: 100%|██████████| 10000/10000 [00:00<00:00, 13490.75molecule/s]


Calculating diversity...
Calculating pairwise diversities...


Calculating diversity: 100%|██████████| 49995000/49995000 [01:38<00:00, 506534.11pair/s]
Running seeds:  40%|████      | 2/5 [04:15<06:22, 127.50s/it]

Starting molecule evaluation...
Processing generated molecules...


Processing generated molecules: 100%|██████████| 10000/10000 [00:00<00:00, 11316.29molecule/s]
Calculating fingerprints: 100%|██████████| 10000/10000 [00:00<00:00, 22083.69molecule/s]


Calculating diversity...
Calculating pairwise diversities...


Calculating diversity: 100%|██████████| 49995000/49995000 [01:38<00:00, 506223.80pair/s]
Running seeds:  60%|██████    | 3/5 [06:21<04:13, 126.68s/it]

Starting molecule evaluation...
Processing generated molecules...


Processing generated molecules: 100%|██████████| 10000/10000 [00:00<00:00, 12381.28molecule/s]
Calculating fingerprints: 100%|██████████| 10000/10000 [00:00<00:00, 22465.80molecule/s]


Calculating diversity...
Calculating pairwise diversities...


Calculating diversity: 100%|██████████| 49995000/49995000 [01:40<00:00, 496391.11pair/s]
Running seeds:  80%|████████  | 4/5 [08:28<02:06, 126.90s/it]

Starting molecule evaluation...
Processing generated molecules...


Processing generated molecules: 100%|██████████| 10000/10000 [00:00<00:00, 10809.50molecule/s]
Calculating fingerprints: 100%|██████████| 10000/10000 [00:00<00:00, 23760.53molecule/s]


Calculating diversity...
Calculating pairwise diversities...


Calculating diversity: 100%|██████████| 49995000/49995000 [01:39<00:00, 503392.62pair/s]
Running seeds: 100%|██████████| 5/5 [10:34<00:00, 126.94s/it]


Evaluating SAFE-GPT with multiple seeds...


Running seeds:   0%|          | 0/5 [00:00<?, ?it/s]


ValueError: Sample larger than population or is negative

In [None]:
# Create a DataFrame for mean results
df_mean = pd.DataFrame({model: results["mean"] for model, results in multi_seed_results.items()}).T
df_mean = df_mean.round(3)

# Create a DataFrame for standard deviations
df_std = pd.DataFrame({model: results["std"] for model, results in multi_seed_results.items()}).T
df_std = df_std.round(3)

In [None]:
print("Mean Results:")
print(df_mean)
print("\nStandard Deviations:")
print(df_std)

# Visualize results with error bars
metrics = list(df_mean.columns)
num_metrics = len(metrics)
num_models = len(df_mean)

In [None]:
fig, axes = plt.subplots(nrows=(num_metrics + 1) // 2, ncols=2, figsize=(15, 5 * ((num_metrics + 1) // 2)))
axes = axes.flatten()

for i, metric in enumerate(metrics):
    ax = axes[i]
    means = df_mean[metric]
    stds = df_std[metric]
    
    ax.bar(range(num_models), means, yerr=stds, capsize=5)
    ax.set_title(metric)
    ax.set_xticks(range(num_models))
    ax.set_xticklabels(df_mean.index, rotation=45)
    ax.set_ylabel('Value')

plt.tight_layout()
plt.show()