In [63]:
import os
import warnings

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from rdkit import Chem
from rdkit.Chem import Descriptors, Crippen, rdMolDescriptors
from tqdm import tqdm

In [64]:
# Figure styling shared by every plot in this notebook.
sns.set_context("paper", font_scale=1.5)
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['figure.autolayout'] = True

# Seed NumPy's global RNG so the random training-set subsample (drawn with
# np.random.choice in the sampling cell) is reproducible across kernel restarts.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

In [65]:
def load_smiles(filename):
    """Read one SMILES string per line from a text file.

    Surrounding whitespace is stripped and blank lines are skipped. RDKit's
    ``Chem.MolFromSmiles('')`` returns an empty (0-atom) molecule rather than
    ``None``, so a stray blank line would otherwise be counted as a molecule
    with all-zero properties.

    Parameters
    ----------
    filename : str
        Path to the SMILES file.

    Returns
    -------
    list of str
        The non-blank, stripped lines in file order.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        stripped = (line.strip() for line in f)
        return [smi for smi in stripped if smi]

In [66]:
def calculate_properties(smiles_list):
    """Compute RDKit descriptors for every parseable SMILES string.

    Parameters
    ----------
    smiles_list : iterable of str
        SMILES strings (a list, NumPy array or pandas Series all work).

    Returns
    -------
    pandas.DataFrame
        One row per entry that RDKit parsed into a non-empty molecule, with
        columns MW, LogP, HBA, HBD, TPSA, RotBonds, AromaticRings, QED.
        Skipped entries are dropped (not stored as NaN). The frame can
        therefore have fewer rows than the input, and its rows are not
        index-aligned with it.

    Warns
    -----
    UserWarning
        If any entry was skipped (not a string, unparseable, or an empty
        molecule). The warning reports the count, so low validity of generated
        samples is not hidden.
    """
    columns = ['MW', 'LogP', 'HBA', 'HBD', 'TPSA', 'RotBonds', 'AromaticRings', 'QED']
    rows = []
    n_skipped = 0

    for smi in tqdm(smiles_list, desc="Calculating properties"):
        # Non-strings (e.g. NaN from a pandas column) would make MolFromSmiles raise.
        mol = Chem.MolFromSmiles(smi) if isinstance(smi, str) else None
        # MolFromSmiles returns None for invalid SMILES but a 0-atom molecule
        # for an empty string; neither belongs in the distributions.
        if mol is None or mol.GetNumAtoms() == 0:
            n_skipped += 1
            continue
        rows.append((
            # NOTE: ExactMolWt is the monoisotopic mass; Descriptors.MolWt
            # would give the average molecular weight instead.
            Descriptors.ExactMolWt(mol),
            Crippen.MolLogP(mol),
            rdMolDescriptors.CalcNumHBA(mol),
            rdMolDescriptors.CalcNumHBD(mol),
            Descriptors.TPSA(mol),
            rdMolDescriptors.CalcNumRotatableBonds(mol),
            rdMolDescriptors.CalcNumAromaticRings(mol),
            Descriptors.qed(mol),
        ))

    if n_skipped:
        warnings.warn(
            f"calculate_properties: skipped {n_skipped} of {n_skipped + len(rows)} "
            "entries that RDKit could not parse into a non-empty molecule",
            stacklevel=2,
        )

    # Build the frame once from row tuples; explicit columns keep the schema
    # even when every entry was skipped.
    return pd.DataFrame(rows, columns=columns)

In [67]:
# Generated samples to compare, keyed by the label used in the plots (the labels
# must also appear in MODEL_ORDER / COLOR_MAP). Earlier small-model runs used:
#   Mamba_Hybrid_Small -> molecules/HYBRID_20M_dropout_little_10000_samples.txt
#   Safe_Small         -> molecules/SAFE_GPT_20M_10000_samples.txt
#   Mamba_Small        -> molecules/SSM_20M_little_dropout_10000_samples.txt
MODEL_SAMPLE_FILES = {
    "Mamba_Large": "molecules/SSM_100M_10000_samples.txt",
    "Safe_Large": "molecules/SAFE_GPT_large_generated_smiles_10k.txt",
}
models = {label: load_smiles(path) for label, path in MODEL_SAMPLE_FILES.items()}

In [68]:
# Featurize each model's generated SMILES; keys stay the model labels used for plotting.
properties = dict(zip(models, map(calculate_properties, models.values())))

Calculating properties: 100%|██████████| 10000/10000 [00:13<00:00, 717.38it/s]
Calculating properties: 100%|██████████| 9805/9805 [00:15<00:00, 637.78it/s]


In [69]:
def load_and_sample_smiles(filename, sample_size=3000000, seed=None):
    """Read a one-SMILES-per-line file and draw a random subsample without replacement.

    Parameters
    ----------
    filename : str
        Path to the file. Each non-blank line is stripped and taken verbatim
        as a SMILES string. NOTE(review): the notebook passes 'zinc_train.csv';
        if that file has a header row or extra comma-separated columns, those
        lines pass through unchanged and will fail to parse downstream.
        Confirm the file layout.
    sample_size : int
        Number of SMILES to draw. If the file holds fewer non-blank lines, all
        of them are returned in random order instead of raising.
    seed : int, optional
        Seed for a dedicated ``numpy.random.Generator``. When None, NumPy's
        global RNG is used exactly as before, so ``np.random.seed`` still
        controls the draw.

    Returns
    -------
    numpy.ndarray
        1-D object array of the sampled SMILES strings.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        smiles_list = [smi for smi in (line.strip() for line in f) if smi]

    n_draw = min(sample_size, len(smiles_list))
    rng = np.random if seed is None else np.random.default_rng(seed)
    # Sample integer indices rather than the strings: passing the list to
    # choice() would first copy millions of SMILES into a fixed-width unicode array.
    picked = rng.choice(len(smiles_list), size=n_draw, replace=False)
    return np.array([smiles_list[i] for i in picked], dtype=object)

In [70]:
# train_data = pd.read_csv('train.csv')
# train_properties = calculate_properties(train_data['SMILES'])

# Subsample the ZINC training set and featurize it as the reference distribution.
# Featurizing 3M molecules takes over an hour, so the sampled SMILES and their
# property table are cached together and reused on re-runs. The cache name
# encodes the sample size; delete the file to recompute (e.g. after changing
# the input file or the RNG seed).
# NOTE: the cache is a pickle written by this notebook; only load caches you produced.
TRAIN_SMILES_FILE = 'zinc_train.csv'
TRAIN_SAMPLE_SIZE = 3000000
TRAIN_CACHE_FILE = f'zinc_train_sample_{TRAIN_SAMPLE_SIZE}_properties.pkl'

if os.path.exists(TRAIN_CACHE_FILE):
    train_data_sampled, train_properties_sampled = pd.read_pickle(TRAIN_CACHE_FILE)
else:
    train_data_sampled = load_and_sample_smiles(TRAIN_SMILES_FILE, sample_size=TRAIN_SAMPLE_SIZE)
    train_properties_sampled = calculate_properties(train_data_sampled)
    pd.to_pickle((train_data_sampled, train_properties_sampled), TRAIN_CACHE_FILE)

Calculating properties: 100%|██████████| 3000000/3000000 [1:12:06<00:00, 693.37it/s]


In [71]:
# Register the sampled training-set properties as the reference entry for plotting.
# (For the MOSES variant: properties['MOSES (Train Dataset)'] = train_properties)
properties.update({'ZINC (Train Dataset)': train_properties_sampled})

In [76]:
# Models to plot, in legend/bar order, and their colours (matched by position).
# Earlier configurations:
#   ['SAFE-GPT', 'MAMBA', 'MAMBA-HYBRID'] -> '#ff7f0e', '#2ca02c', '#1f77b4'
#   ['Safe_Small', 'Mamba_Small', 'Mamba_Hybrid_Small', 'MOSES (Train Dataset)']
#       -> '#ff7f0e', '#2ca02c', '#1f77b4', '#d62728'
MODEL_ORDER = ['Safe_Large', 'Mamba_Large', 'ZINC (Train Dataset)']
COLOR_MAP = dict(zip(MODEL_ORDER, ['#1f77b4', '#9467bd', '#d62728']))

def plot_property_distribution(property_name, xlabel, ylabel=None, is_discrete=False,
                               properties_by_model=None, model_order=None,
                               color_map=None, output_dir="large_plots",
                               filename_suffix="_large"):
    """Plot one molecular property across models and save the figure as a PNG.

    Discrete (integer-count) properties are drawn as grouped bars of per-model
    proportions. Continuous properties are drawn as one KDE curve per model.
    Training-set references (labels containing 'Train Dataset') are dashed.

    Parameters
    ----------
    property_name : str
        Column of each per-model properties DataFrame to plot (e.g. 'MW').
    xlabel : str
        X-axis label.
    ylabel : str, optional
        Y-axis label. Defaults to 'Proportion' for discrete properties and
        'Density' for continuous ones, because a KDE's y-axis is a probability
        density, not a proportion.
    is_discrete : bool
        Grouped bar chart of proportions if True, KDE curves otherwise.
    properties_by_model : dict[str, pandas.DataFrame], optional
        Per-model property tables; defaults to the notebook-level ``properties``.
    model_order : list[str], optional
        Models to plot, in legend order; defaults to ``MODEL_ORDER``.
    color_map : dict[str, str], optional
        Model label -> colour; defaults to ``COLOR_MAP``.
    output_dir : str
        Directory for the PNG; created if it does not exist.
    filename_suffix : str
        The figure is written to
        ``{output_dir}/{property_name.lower()}_distribution{filename_suffix}.png``.
    """
    if properties_by_model is None:
        properties_by_model = properties
    if model_order is None:
        model_order = MODEL_ORDER
    if color_map is None:
        color_map = COLOR_MAP
    if ylabel is None:
        ylabel = "Proportion" if is_discrete else "Density"

    fig, ax = plt.subplots(figsize=(14, 8))

    if is_discrete:
        # Normalise counts by each model's row count so models with different
        # numbers of valid molecules (and the much larger training sample) compare.
        per_model = []
        for model in model_order:
            df = properties_by_model[model]
            counts = df[property_name].value_counts().sort_index()
            proportions = counts / len(df)
            per_model.append(pd.DataFrame({
                'Model': model,
                'Value': proportions.index,
                'Proportion': proportions.values,
            }))
        combined_data = pd.concat(per_model, ignore_index=True)

        sns.barplot(
            x='Value',
            y='Proportion',
            hue='Model',
            data=combined_data,
            palette=color_map,
            hue_order=model_order,
            ax=ax,
        )
        ax.tick_params(axis='x', labelrotation=0)
        # Bar charts can span many integer values; keep the legend off the bars.
        ax.legend(title="Model", bbox_to_anchor=(1.05, 1), loc='upper left')
    else:
        for model in model_order:
            # Training-set references are dashed and slightly thinner so the
            # generated models stand out.
            is_reference = 'Train Dataset' in model
            sns.kdeplot(
                data=properties_by_model[model][property_name],
                label=model,
                color=color_map[model],
                linewidth=2 if is_reference else 2.5,
                linestyle='--' if is_reference else '-',
                ax=ax,
            )
        ax.legend(title="Model")

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(f"Distribution of {property_name}")
    ax.grid(True, linestyle='--', alpha=0.7)

    fig.tight_layout()
    # savefig does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, f"{property_name.lower()}_distribution{filename_suffix}.png")
    fig.savefig(out_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

In [77]:
# (column, x-axis label, is_discrete): integer counts become grouped bars,
# continuous descriptors become KDE curves.
PROPERTY_PLOT_SPECS = [
    ('MW', 'Molecular Weight (Da)', False),
    ('LogP', 'LogP', False),
    ('HBA', 'Number of Hydrogen Bond Acceptors', True),
    ('HBD', 'Number of Hydrogen Bond Donors', True),
    ('TPSA', 'Topological Polar Surface Area (Å²)', False),
    ('RotBonds', 'Number of Rotatable Bonds', True),
    ('AromaticRings', 'Number of Aromatic Rings', True),
    ('QED', 'Quantitative Estimate of Drug-likeness (QED)', False),
]
for prop_column, axis_label, discrete in PROPERTY_PLOT_SPECS:
    plot_property_distribution(prop_column, axis_label, is_discrete=discrete)