In [1]:
import math
from copy import deepcopy
from itertools import product

from rdkit import Chem, DataStructs

import sqlite3
import pandas as pd

import numpy as np
import os
import glob

In [2]:
def read_db_data_in_folder(folder_path: str) -> pd.DataFrame:
    """
    Reads all data from the `results` table of all SQLite databases in a specified folder.

    Parameters
    ----------
    folder_path: str
        Path to the folder containing sqlite3 database files. Only files
        directly inside the folder that match ``*.db`` are read.

    Returns
    -------
    pd.DataFrame
        A combined dataframe containing rows from the `results` table
        from all found databases in the folder (files are visited in sorted
        path order, and the index is renumbered). Returns an empty dataframe
        if no database file is found or none of them has a `results` table.

    Raises
    ------
    sqlite3.OperationalError, pd.errors.DatabaseError
        Any database error other than a missing `results` table.
    """
    # Sort so the row order of the combined frame does not depend on the
    # (platform-specific) order in which glob lists the directory.
    db_files = sorted(glob.glob(os.path.join(folder_path, "*.db")))

    frames = []

    for db_path in db_files:
        conn = sqlite3.connect(db_path)

        try:
            frames.append(pd.read_sql("SELECT * FROM results", conn))
        # pandas wraps the driver's sqlite3.OperationalError in its own
        # DatabaseError, so both types must be caught for the "missing table"
        # case to be skipped instead of aborting the whole load.
        except (sqlite3.OperationalError, pd.errors.DatabaseError) as e:
            if "no such table: results" not in str(e):
                raise
            print(f"The table 'results' does not exist in the database at path: {db_path}")
        finally:
            conn.close()

    if not frames:
        return pd.DataFrame(columns=[])

    return pd.concat(frames, ignore_index=True)


In [3]:
# Gather the `results` rows of every *.db file in the run's output folder
# (the path is relative to the notebook's working directory).
df = read_db_data_in_folder(folder_path="logs/debug_run_seh_frag/final")

In [4]:
# Keep only the first row for each distinct SMILES string, then parse every
# SMILES into an RDKit Mol held in a new "mol" column. Fingerprints are not
# stored on the frame.
df = df.drop_duplicates(subset=["smi"]).assign(
    mol=lambda frame: frame["smi"].apply(Chem.MolFromSmiles)
)
df

Unnamed: 0,smi,r,fr_0,ci_beta,mol
0,NS(=O)(=O)NCC(=O)O,0.089591,0.089591,6.038920,<rdkit.Chem.rdchem.Mol object at 0x7fbb7651d5f0>
1,OC(F)(F)F,0.079602,0.079602,29.778170,<rdkit.Chem.rdchem.Mol object at 0x7fbb7651d970>
2,CC(C)(C)c1c(CS)nc(S(=O)(=O)O)nc1C1C(C(=N)N)CC(...,0.167655,0.167655,46.800671,<rdkit.Chem.rdchem.Mol object at 0x7fbb7651d900>
3,CC(O)C1CCN(C(=O)[O-])CC1C(=O)NO,0.293675,0.293675,46.848484,<rdkit.Chem.rdchem.Mol object at 0x7fbb76539510>
4,N=C(N)Br,0.000100,0.000100,62.910297,<rdkit.Chem.rdchem.Mol object at 0x7fbb76539580>
...,...,...,...,...,...
315,N#COCn1cnc2c(-c3cc(C4CCCO4)c(C(=O)[O-])c(-c4cc...,0.362217,0.362217,31.835855,<rdkit.Chem.rdchem.Mol object at 0x7fbb764e7d60>
316,NN1CCN(C2CC2)CC1,0.204442,0.204442,19.714422,<rdkit.Chem.rdchem.Mol object at 0x7fbb764e7dd0>
317,CN(C)C1CC(c2nccs2)C(C2CCCN2CNC=N)C(c2cc(C(F)(F...,0.376826,0.376826,57.953789,<rdkit.Chem.rdchem.Mol object at 0x7fbb764e7e40>
318,N=C(N)n1cc(CS)c(=O)[nH]c1=O,0.240088,0.240088,40.979397,<rdkit.Chem.rdchem.Mol object at 0x7fbb764e7eb0>


In [6]:
# Pull the three columns out as plain Python lists (same row order as `df`).
rewards, smiles, mols = (df[column].tolist() for column in ("r", "smi", "mol"))

# One (reward, smiles, mol) tuple per molecule, highest reward first.
candidates = list(zip(rewards, smiles, mols))
candidates.sort(key=lambda entry: entry[0], reverse=True)

In [7]:
def compute_diverse_top_k(candidates, k, thresh=0.7):
    """
    Mean reward of a greedily selected set of at most `k` mutually dissimilar candidates.

    Candidates are visited in the order given. The first one is always
    selected; each later one is selected only if its Tanimoto similarity
    (on RDKit fingerprints) to every already-selected candidate is below
    `thresh`. The walk stops as soon as `k` candidates have been selected.

    Parameters
    ----------
    candidates: sequence of tuples
        Each item is indexed as ``item[0]`` = reward and ``item[2]`` = RDKit
        molecule; ``item[1]`` is not read. The function does not sort, so pass
        the candidates ordered by descending reward to get a "top-k".
    k: int
        Maximum number of candidates to select.
    thresh: float
        Similarity threshold; a candidate is rejected when its similarity to
        any selected candidate is greater than or equal to this value.

    Returns
    -------
    float
        Mean reward of the selected candidates (fewer than `k` if the
        candidates run out first), or NaN when `candidates` is empty or
        `k` is not positive.
    """
    # Nothing can be selected: report NaN rather than failing on candidates[0].
    if k <= 0 or len(candidates) == 0:
        return np.nan

    modes = [candidates[0]]
    mode_fps = [Chem.RDKFingerprint(candidates[0][2])]
    for i in range(1, len(candidates)):
        # Check the budget before looking at another candidate, so that k=1
        # cannot end up with two selected modes.
        if len(modes) >= k:
            break
        fp = Chem.RDKFingerprint(candidates[i][2])
        sim = DataStructs.BulkTanimotoSimilarity(fp, mode_fps)
        # Accept only if the candidate is below the threshold against ALL modes.
        if max(sim) < thresh:
            modes.append(candidates[i])
            mode_fps.append(fp)
    return np.mean([mode[0] for mode in modes])


In [8]:
def get_topk(rewards,k):
    """
    Average of the `k` largest values in `rewards`.

    The input is copied before ordering, so the caller's sequence is left
    untouched. If fewer than `k` values are available, all of them are averaged.
    """
    ranked = list(rewards)
    ranked.sort(reverse=True)  # largest reward first
    best = ranked[:k]
    return np.mean(best)

In [9]:
# Mean of the 100 largest rewards among the de-duplicated molecules
# (no diversity filtering).
get_topk(rewards=rewards, k=100)

0.4532394626736641

In [10]:
# Mean reward of up to 100 candidates picked greedily from the reward-sorted
# list, keeping only those with Tanimoto similarity < 0.7 to every earlier pick.
compute_diverse_top_k(candidates=candidates, k=100, thresh=0.7)

0.4431504625082016