# Model Evaluation Report

In [None]:
from pathlib import Path

import dask.dataframe as dd
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Markdown, display
from rdkit.Chem import MolFromSmiles
from rdkit.Chem.Draw import MolsToGridImage
from selfies import decoder, encoder

from mol_gen.preprocessing.filter import DESCRIPTOR_TO_FUNCTION
from mol_gen.training.molecule_generator import MoleculeGenerator

In [None]:
# Notebook parameters. They are left as None here; presumably they are
# injected before execution (e.g. a papermill parameters cell) — TODO confirm.
# Keep one `name = value` assignment per line so parameter tooling can parse them.
train_dir = None  # directory whose files ("*") are read below as SELFIES CSVs
checkpoint_dir = None  # directory searched for "model.*.h5" checkpoints
string_lookup_config_filepath = None  # passed to MoleculeGenerator.from_files

In [None]:
# Lazily load every file under `train_dir` as a single-column CSV. Because the
# column name is passed explicitly, pandas/dask treat the first line as data
# (header=None behaviour), so no row is lost to a header.
training_data = dd.read_csv(Path(train_dir) / "*", names=["SELFIES"])

In [None]:
# Evaluate every saved checkpoint: sample molecules from it and report
# uniqueness, novelty (vs. the training set), validity, descriptor
# distributions and a grid image of a random subset of valid molecules.
# Checkpoints are sorted so the report order is stable; `glob` alone yields
# files in arbitrary filesystem order.
for model_filepath in sorted(Path(checkpoint_dir).glob("model.*.h5")):
    mol_generator = MoleculeGenerator.from_files(
        model_filepath, string_lookup_config_filepath
    )

    display(Markdown(f"## {model_filepath.name}"))

    mols = mol_generator.generate_molecules(1024)
    n_mols = len(mols)
    if n_mols == 0:
        # Every percentage below divides by n_mols.
        print("No molecules generated; skipping evaluation.")
        continue

    display(Markdown("### Duplicates"))
    unique_mols = set(mols)
    percent_unique = round(100 * (len(unique_mols) / n_mols))
    print(f"{percent_unique}% of molecules generated unique.")

    display(Markdown("### Novelty"))
    # Collect the generated strings that also occur in the training data.
    # Matching *values*, not counting matching training rows, keeps duplicate
    # training rows from inflating the count, and the computed result stays
    # small (at most len(unique_mols) distinct strings).
    selfies_column = training_data["SELFIES"]
    seen_in_training = set(
        selfies_column[selfies_column.isin(list(unique_mols))].compute()
    )
    n_novel = sum(mol not in seen_in_training for mol in mols)
    percent_novel = round(100 * (n_novel / n_mols))
    print(f"{percent_novel}% of molecules generated novel.")

    display(Markdown("### Validity"))
    smiles = [decoder(selfies) for selfies in mols]
    # Skip failed (None/empty) decodings rather than handing them to RDKit,
    # and parse each SMILES only once; MolFromSmiles returns None when it
    # cannot parse its input.
    parsed = [MolFromSmiles(smi) for smi in smiles if smi]
    valid_mols = [mol for mol in parsed if mol is not None]
    percent_valid = round(100 * (len(valid_mols) / n_mols))
    print(f"{percent_valid}% of molecules generated valid.")

    display(Markdown("### Properties Distribution"))
    for descriptor, func in DESCRIPTOR_TO_FUNCTION.items():
        values = [func(mol) for mol in valid_mols]
        plt.hist(values, bins=10)
        plt.xlabel(descriptor)
        plt.ylabel("count")
        plt.show()

    display(Markdown("### Subset"))
    # The grid renders at most `max_shown` molecules, so sample exactly that
    # many (or all valid ones if fewer). Sampling a fixed 100 without
    # replacement raised ValueError whenever fewer than 100 were valid.
    # Indices are sampled rather than the Mol objects themselves to avoid
    # numpy trying to build an array out of RDKit molecules.
    max_shown = 50
    n_subset = min(max_shown, len(valid_mols))
    if n_subset:
        subset_idx = np.random.choice(len(valid_mols), n_subset, replace=False)
        mols_subset = [valid_mols[i] for i in subset_idx]
        display(
            MolsToGridImage(
                mols_subset, subImgSize=(500, 500), molsPerRow=5, maxMols=max_shown
            )
        )
    else:
        print("No valid molecules to display.")