In [28]:
import math
from copy import deepcopy
from itertools import product

from rdkit import Chem, DataStructs

import sqlite3
import pandas as pd

import numpy as np
import os
import glob

In [29]:
def read_db_data_in_folder(folder_path: str) -> pd.DataFrame:
    """
    Reads all data from the `results` table of all SQLite databases in a specified folder.

    Parameters
    ----------
    folder_path: str
        Path to the folder containing sqlite3 database files (matched as ``*.db``).

    Returns
    -------
    pd.DataFrame
        A combined dataframe containing rows from the `results` table
        from all found databases in the folder, concatenated in sorted
        file-path order with a fresh integer index. Returns an empty dataframe
        if no database in the folder provides a `results` table.

    Notes
    -----
    A database without a `results` table is reported with a printed message
    and skipped. Any other error raised by sqlite3 or pandas (for example an
    unreadable file) is not caught and propagates to the caller.
    """
    # Sorted so that the row order (and anything order-dependent downstream,
    # such as keep-first de-duplication) does not depend on the order in which
    # the filesystem happens to list the files.
    db_files = sorted(glob.glob(os.path.join(folder_path, "*.db")))

    combined_data = []

    for db_path in db_files:
        conn = sqlite3.connect(db_path)

        try:
            # Probe the schema instead of catching the failed query: pandas
            # wraps driver errors in its own DatabaseError, so an
            # `except sqlite3.OperationalError` around `pd.read_sql` never
            # fires and a single database without the table would abort the
            # whole load. NOCASE mirrors SQL's case-insensitive table lookup.
            has_results = conn.execute(
                "SELECT 1 FROM sqlite_master "
                "WHERE type IN ('table', 'view') AND name = 'results' COLLATE NOCASE"
            ).fetchone() is not None

            if not has_results:
                print(f"The table 'results' does not exist in the database at path: {db_path}")
                continue

            combined_data.append(pd.read_sql("SELECT * FROM results", conn))
        finally:
            # Runs on success, on `continue`, and on error alike.
            conn.close()

    if not combined_data:
        return pd.DataFrame(columns=[])

    return pd.concat(combined_data, ignore_index=True)


In [30]:
# Load the `results` rows of every database in the final-stage log folder
# and show how many rows came back.
df = read_db_data_in_folder(folder_path="logs/debug_seh_frag/final")
len(df)

21248

In [31]:
# Keep one row per unique SMILES string (the first occurrence wins); copy so
# the column assignment below writes to an independent frame.
df = df.drop_duplicates(subset=["smi"]).copy()

# Parse each SMILES into an RDKit molecule. MolFromSmiles returns None when a
# string cannot be parsed.
df["mol"] = df["smi"].apply(Chem.MolFromSmiles)

# Drop unparseable rows so that later fingerprinting never receives None.
df = df[df["mol"].notna()]
df

Unnamed: 0,smi,r,fr_0,ci_beta,mol
0,CC(=O)Nc1c[nH]c(S(=O)(=O)[O-])c1,0.284038,0.284038,1.0,<rdkit.Chem.rdchem.Mol object at 0x7f0425e959e0>
1,CC(=O)NC(=O)O,0.074949,0.074949,1.0,<rdkit.Chem.rdchem.Mol object at 0x7f0425bd3e40>
2,O=c1nc2n(Br)c3cc(CNCS)c([SH](=O)=O)cc3nc-2c(=O...,0.422705,0.422705,1.0,<rdkit.Chem.rdchem.Mol object at 0x7f0425bd3eb0>
3,CC(=O)NCc1csc(S(N)(=O)=O)n1,0.288808,0.288808,1.0,<rdkit.Chem.rdchem.Mol object at 0x7f0425bd3d60>
4,CC(C)(O)N1C(S)C(COC=O)CC1C(F)(F)F,0.230196,0.230196,1.0,<rdkit.Chem.rdchem.Mol object at 0x7f0425bd3f20>
...,...,...,...,...,...
21239,N#CC1=CC(C2=CN(C(=O)NCC=O)C=CC2)C(C2CNc3nc([NH...,0.558299,0.558299,1.0,<rdkit.Chem.rdchem.Mol object at 0x7f0425aa14a0>
21242,CC(O)C1C(N2CCN(c3cn(-c4ncnc5[nH]cnc45)c(=O)[nH...,0.500449,0.500449,1.0,<rdkit.Chem.rdchem.Mol object at 0x7f0425aa1510>
21243,C[SH+]C=C[n+]1cccc(C[NH3+])c1,0.134914,0.134914,1.0,<rdkit.Chem.rdchem.Mol object at 0x7f0425aa1580>
21244,O=C([O-])COc1ccc(-c2csc([SH](=O)=O)n2)s1,0.321832,0.321832,1.0,<rdkit.Chem.rdchem.Mol object at 0x7f0425aa15f0>


In [32]:
# Pull the three columns out as plain Python lists.
rewards, smiles, mols = (df[col].tolist() for col in ("r", "smi", "mol"))

# One (reward, smiles, mol) tuple per molecule, highest reward first.
# Sorting keys on the reward only, so the other tuple fields are never compared.
candidates = list(zip(rewards, smiles, mols))
candidates.sort(key=lambda cand: cand[0], reverse=True)

In [33]:
def compute_diverse_top_k(candidates, k, thresh=0.7):
    """
    Mean reward of up to `k` mutually dissimilar candidates, selected greedily.

    Candidates are visited in the order given. The first one is always kept;
    each later one is kept only if its Tanimoto similarity (computed on RDKit
    fingerprints from `Chem.RDKFingerprint`) to every candidate kept so far is
    below `thresh`. Selection stops as soon as `k` candidates have been kept.

    Parameters
    ----------
    candidates : sequence of tuple
        Tuples whose element 0 is the reward and element 2 is the RDKit
        molecule; element 1 is not read here. Because candidates are visited
        in order, pass them sorted by reward (descending) to obtain a
        "top-k" selection.
    k : int
        Maximum number of candidates to keep. Must be at least 1.
    thresh : float, optional
        Similarity cutoff; a candidate is kept only when its highest
        similarity to the kept set is strictly below this value.

    Returns
    -------
    float
        Mean reward of the kept candidates. Fewer than `k` are averaged when
        the input runs out first. NaN when `candidates` is empty.

    Raises
    ------
    ValueError
        If `k` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    if len(candidates) == 0:
        return float("nan")

    modes = []
    mode_fps = []
    for cand in candidates:
        # Checked before looking at the next candidate so that no more than k
        # modes are ever kept, including for k == 1.
        if len(modes) >= k:
            break
        fp = Chem.RDKFingerprint(cand[2])
        # Keep the candidate only if it is dissimilar to ALL current modes,
        # i.e. its highest similarity is under the threshold. The first
        # candidate has nothing to be compared against and is always kept.
        if not mode_fps or max(DataStructs.BulkTanimotoSimilarity(fp, mode_fps)) < thresh:
            modes.append(cand)
            mode_fps.append(fp)
    return np.mean([mode[0] for mode in modes])


In [34]:
def get_topk(rewards, k):
    """
    Return the mean of the `k` largest values in `rewards`.

    The input is copied before sorting, so the caller's sequence is left
    untouched. When `k` exceeds the number of rewards, all of them are averaged.
    """
    ranked = list(rewards)
    ranked.sort(reverse=True)  # largest first
    return np.mean(ranked[:k])

In [35]:
# Baseline: mean of the 100 highest rewards, with no diversity filter.
get_topk(rewards=rewards, k=100)

0.7189621144533157

In [36]:
# Mean reward of up to 100 candidates picked greedily with a 0.7 Tanimoto similarity cutoff.
compute_diverse_top_k(candidates=candidates, k=100, thresh=0.7)

0.7040923529863358