In [None]:
import os
import random
import scipy.spatial # very important, does not work without it, i don't know why
import scipy.stats
import numpy as np
import torch
from tqdm import tqdm
# Seed Python's `random` module so the shuffles in create_cross_splits are repeatable.
random.seed(0)
import torch
import numpy as np
import random
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, roc_curve
import pandas as pd
from tokenizers import Tokenizer

os.environ['TOKENIZERS_PARALLELISM'] = 'false'  # to avoid warnings
1

## RDKIT


In [None]:


def get_fp(smiles: str):
    """Build the Morgan bit-vector fingerprint of a molecule given as SMILES.

    The fingerprint uses radius 2, 2048 bits and ignores chirality.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        The RDKit bit-vector object produced by
        ``AllChem.GetMorganFingerprintAsBitVect``.
    """
    molecule = Chem.MolFromSmiles(smiles)
    return AllChem.GetMorganFingerprintAsBitVect(
        molecule,
        radius=2,
        nBits=2048,
        useChirality=False,
    )

def get_fp_np(smiles: str):
    """Return the Morgan fingerprint of ``smiles`` as an ``int8`` NumPy array.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        The NumPy array that RDKit filled from the fingerprint of ``smiles``.
    """
    bit_vector = get_fp(smiles)
    # Start from an empty buffer; RDKit writes the bits into it in place
    # (presumably resizing it to the fingerprint length - confirm in RDKit docs).
    as_array = np.zeros((0,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(bit_vector, as_array)
    return as_array


def read_csv(path: str):
    """Read a split CSV and return six of its columns as plain Python lists.

    Columns are selected by position (1 through 6); column 0 is ignored.

    Args:
        path: Path of the CSV file.

    Returns:
        A 6-tuple ``(non_chiral_smiles, backbones, chains, assay_ids, types,
        labels)`` holding the values of columns 1..6 respectively.
    """
    table = pd.read_csv(path)
    # Positions 1..6 map, in order, to the six returned fields.
    return tuple(table.iloc[:, position].tolist() for position in range(1, 7))


def calc_tani_sim(mol1_smiles, mol2_smiles):
    """Tanimoto similarity between the Morgan fingerprints of two SMILES strings.

    Both fingerprints use radius 2, 2048 bits and ignore chirality.

    Args:
        mol1_smiles: SMILES of the first molecule.
        mol2_smiles: SMILES of the second molecule.

    Returns:
        The value returned by ``DataStructs.FingerprintSimilarity`` with the
        Tanimoto metric.
    """
    # Parse both molecules first, then fingerprint them in the same order.
    molecules = [Chem.MolFromSmiles(smi) for smi in (mol1_smiles, mol2_smiles)]
    first_fp, second_fp = (
        AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048, useChirality=False)
        for mol in molecules
    )
    return DataStructs.FingerprintSimilarity(
        first_fp, second_fp, metric=DataStructs.TanimotoSimilarity
    )


def get_canonical_smiles(smiles, chirality=True):
    """Canonicalise a SMILES string with RDKit.

    Args:
        smiles: SMILES string to canonicalise.
        chirality: Truthy to pass ``useChiral=1`` to ``Chem.CanonSmiles``,
            falsy to pass ``useChiral=0``.

    Returns:
        The result of ``Chem.CanonSmiles``.
    """
    use_chiral_flag = int(bool(chirality))
    return Chem.CanonSmiles(smiles, useChiral=use_chiral_flag)


def load_tokenizer_from_file(file_path: str):
    """Load a ``tokenizers.Tokenizer`` from a serialized tokenizer file.

    Args:
        file_path: Path handed to ``Tokenizer.from_file``.

    Returns:
        The loaded ``Tokenizer`` instance.
    """
    tokenizer = Tokenizer.from_file(file_path)
    return tokenizer


def smiles_valid(smiles, verbose=False):
    """Tell whether ``smiles`` can be parsed by RDKit.

    Args:
        smiles: Candidate SMILES. Anything that is not a ``str`` (``None``,
            a pandas NaN float, ...) is reported as invalid without being
            handed to RDKit.
        verbose: When true, print every string that RDKit fails to parse.

    Returns:
        ``True`` if ``Chem.MolFromSmiles`` returns a molecule, else ``False``.
    """
    # Non-string values (None, NaN from a CSV column, ...) cannot be SMILES.
    if not isinstance(smiles, str):
        return False
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        return True
    # Only report the offending string when the caller asked for it.
    if verbose:
        print(smiles)
    return False


In [None]:
def split_into_chunks(lst, n):
    """Cut ``lst`` into ``n`` consecutive slices whose sizes differ by at most one.

    The first ``len(lst) % n`` slices receive one extra element.

    Args:
        lst: Sliceable sequence to split.
        n: Number of slices to produce.

    Returns:
        A list of ``n`` slices of ``lst``, in order.
    """
    base_size, oversized = divmod(len(lst), n)
    pieces = []
    start = 0
    for position in range(n):
        # The leading `oversized` pieces absorb the remainder, one element each.
        stop = start + base_size + (1 if position < oversized else 0)
        pieces.append(lst[start:stop])
        start = stop
    return pieces

def create_cross_splits(lst, num_chunks=10):
    """Shuffle the items and build leave-one-chunk-out splits.

    Args:
        lst: Iterable of items; it is copied, so the caller's object is not
            shuffled.
        num_chunks: Number of chunks, and therefore of splits, to create.

    Returns:
        A list of ``(rest, held_out)`` pairs, one per chunk, where ``held_out``
        is that chunk and ``rest`` is every other chunk concatenated in order.
    """
    shuffled = list(lst)
    random.shuffle(shuffled)  # uses the global `random` state

    folds = split_into_chunks(shuffled, num_chunks)

    splits = []
    for held_idx, held_out in enumerate(folds):
        rest = []
        for other_idx, fold in enumerate(folds):
            if other_idx != held_idx:
                rest.extend(fold)
        splits.append((rest, held_out))
    return splits

In [None]:
# Number of worker processes for the multiprocessing pools created below,
# taken from torch's configured CPU thread count.
cpus = torch.get_num_threads()

def get_many_fps(smiles):
    """Compute NumPy Morgan fingerprints for many SMILES in a process pool.

    Args:
        smiles: Iterable of SMILES strings.

    Returns:
        A list with ``get_fp_np(s)`` for every ``s`` in ``smiles``, in order.
    """
    worker_pool = torch.multiprocessing.Pool(cpus)
    with worker_pool:
        # map() blocks until every fingerprint has been computed.
        return worker_pool.map(get_fp_np, smiles)

def get_sims(args):
    """Tanimoto similarity for one ``(original, optimized)`` SMILES pair.

    Takes a single tuple argument so it can be used directly with ``Pool.map``.

    Args:
        args: Two-element sequence ``(org_smiles, opt_smiles)``.

    Returns:
        Tanimoto similarity of the two ``get_fp`` fingerprints.
    """
    source_smiles, optimized_smiles = args
    source_fp = get_fp(source_smiles)
    optimized_fp = get_fp(optimized_smiles)
    return DataStructs.FingerprintSimilarity(
        source_fp, optimized_fp, metric=DataStructs.TanimotoSimilarity
    )

def get_many_sims(smiles_lists):
    """Compute pairwise Tanimoto similarities in a process pool.

    Args:
        smiles_lists: Iterable of ``(org_smiles, opt_smiles)`` pairs.

    Returns:
        A list with ``get_sims(pair)`` for every pair, in order.
    """
    worker_pool = torch.multiprocessing.Pool(cpus)
    with worker_pool:
        # map() blocks until every similarity has been computed.
        return worker_pool.map(get_sims, smiles_lists)

def get_many_canons(smiles_lists):
    """Canonicalise several lists of SMILES in one process-pool pass.

    All lists are flattened, canonicalised with ``get_canonical_smiles`` (using
    its default arguments), and then regrouped to mirror the input structure.

    Args:
        smiles_lists: Iterable of iterables of SMILES strings.

    Returns:
        A list of lists with the same lengths and order as ``smiles_lists``,
        each entry replaced by its canonical SMILES.
    """
    groups = [list(group) for group in smiles_lists]
    # Linear-time flatten (``sum(lists, [])`` is quadratic in the total length).
    flat = [smi for group in groups for smi in group]
    if not flat:
        # Nothing to canonicalise: keep the grouping, skip spawning a pool.
        return [[] for _ in groups]

    with torch.multiprocessing.Pool(cpus) as pool:
        canons = pool.map(get_canonical_smiles, flat)

    # Cut the flat result back into groups of the original sizes.
    canons_lists = []
    start = 0
    for group in groups:
        stop = start + len(group)
        canons_lists.append(canons[start:stop])
        start = stop
    # Return the regrouped lists, not the flat list `canons`.
    return canons_lists


def get_clf(positive_smiles, negative_smiles, test_fraction=0.2, num_estimators=100, max_depth=2):
    """Train a class-weighted random forest on fingerprints and evaluate it.

    The leading elements of each class are held out as a test set whose class
    proportions follow the overall positive ratio; the rest is used for
    training. No shuffling happens here, so input order decides the split.

    Args:
        positive_smiles: SMILES of active molecules.
        negative_smiles: SMILES of inactive molecules.
        test_fraction: Fraction of all molecules used for the test set.
        num_estimators: Number of trees in the forest.
        max_depth: Maximum depth of each tree.

    Returns:
        ``(clf, roc_auc, fpr, tpr, thresholds, best_threshold)`` where
        ``best_threshold`` is the ROC threshold maximising ``tpr - fpr``.
    """
    positive_fp = get_many_fps(positive_smiles)
    negative_fp = get_many_fps(negative_smiles)

    total = len(negative_fp) + len(positive_fp)
    pos_ratio = len(positive_fp) / total
    num_test = int(test_fraction * total)
    n_pos_test = int(pos_ratio * num_test)
    n_neg_test = int((1 - pos_ratio) * num_test)

    # Head of each class is the test part, tail is the training part.
    pos_test_fp, pos_train_fp = positive_fp[:n_pos_test], positive_fp[n_pos_test:]
    neg_test_fp, neg_train_fp = negative_fp[:n_neg_test], negative_fp[n_neg_test:]

    X = np.stack(pos_train_fp + neg_train_fp, axis=0)
    y = np.concatenate([np.ones(len(pos_train_fp)), np.zeros(len(neg_train_fp))])

    # Positives are up-weighted by the negative/positive training ratio. The
    # ratio is evaluated per positive sample only, so it is never computed
    # when there are no positive training samples.
    sample_weights = []
    for cur_label in y:
        if cur_label == 0:
            sample_weights.append(1)
        else:
            sample_weights.append(len(neg_train_fp) / len(pos_train_fp))

    clf = RandomForestClassifier(max_depth=max_depth, random_state=0, n_estimators=num_estimators)
    clf.fit(X, y, sample_weight=sample_weights)

    labels = np.concatenate([np.ones(len(pos_test_fp)), np.zeros(len(neg_test_fp))], axis=0)
    test_samples = np.concatenate([pos_test_fp, neg_test_fp], axis=0)
    positive_probs = clf.predict_proba(test_samples)[:, 1]

    roc_auc = roc_auc_score(labels, positive_probs)
    fpr, tpr, thresholds = roc_curve(labels, positive_probs, pos_label=1)
    best_threshold = thresholds[np.argmax(tpr - fpr)]
    return clf, roc_auc, fpr, tpr, thresholds, best_threshold



## DATA

In [None]:
def load_data(path):
    """Group the SMILES of a split CSV by assay and activity label.

    Args:
        path: CSV path, read through ``read_csv``.

    Returns:
        ``{assay_id: {'active': [...], 'inactive': [...]}}``. A row goes to
        ``'active'`` when its label equals 1, otherwise to ``'inactive'``.
    """
    data = {}
    columns = read_csv(path)
    for smile, _, _, assay_id, _, label in zip(*columns):
        buckets = data.setdefault(assay_id, {'active': [], 'inactive': []})
        bucket_name = 'active' if label == 1 else 'inactive'
        buckets[bucket_name].append(smile)
    return data

def load_clfs(assay_ids, data):
    """Train ten cross-split classifiers for every assay.

    For each assay the actives and inactives are independently shuffled and
    cut into ten leave-one-chunk-out splits. A classifier is trained on the
    "rest" part of each split via ``get_clf``, and the held-out chunks are
    stored next to it. The mean ROC AUC reported by ``get_clf`` is printed per
    assay.

    Args:
        assay_ids: Assay identifiers to process.
        data: Mapping as produced by ``load_data``.

    Returns:
        ``{assay_id: [fold, ...]}`` where each fold is a dict with keys
        ``clf``, ``thresh``, ``unseen_active`` and ``unseen_inactive``.
    """
    clfs = {}
    for name in tqdm(assay_ids):
        print(f'-------{name}--------------')
        folds = []
        clfs[name] = folds
        roc_aucs = []
        # Actives are split before inactives; both consume the global RNG.
        active_splits = create_cross_splits(data[name]['active'], 10)
        inactive_splits = create_cross_splits(data[name]['inactive'], 10)
        for active_split, inactive_split in zip(active_splits, inactive_splits):
            active_train, active_test = active_split
            inactive_train, inactive_test = inactive_split
            clf, roc_auc, _fpr, _tpr, _thresholds, best_thresh = get_clf(active_train, inactive_train)
            folds.append({
                'clf': clf,
                'thresh': best_thresh,
                'unseen_active': set(active_test),
                'unseen_inactive': set(inactive_test),
            })
            roc_aucs.append(roc_auc)
        print('roc:', sum(roc_aucs)/len(roc_aucs))
    return clfs

In [None]:

from collections import defaultdict


def score_smiles(clf, smiles):
    """Score SMILES with a classifier, fingerprinting each distinct string once.

    Args:
        clf: Fitted classifier exposing ``predict_proba``; column 1 of its
            output is used as the score.
        smiles: Iterable of SMILES strings (duplicates allowed).

    Returns:
        A list of scores aligned with ``smiles``. Empty input yields ``[]``
        without touching the classifier or the process pool.
    """
    smiles = list(smiles)
    if not smiles:
        # np.stack([]) raises ValueError, so short-circuit the empty case.
        return []

    # Map each distinct SMILES to every position where it occurs.
    positions = defaultdict(list)
    for i, sml in enumerate(smiles):
        positions[sml].append(i)
    smiles_unq = list(positions)

    fps = np.stack(get_many_fps(smiles_unq), axis=0)
    scores = clf.predict_proba(fps)[:, 1]

    # Broadcast each unique score back to all of its original positions.
    res = [None] * len(smiles)
    for score, sml in zip(scores, smiles_unq):
        for i in positions[sml]:
            res[i] = score
    return res


def sample_from_file(path, max_per_source=20):
    """Read ``source -> generated`` SMILES pairs from a results file.

    The first line is treated as a header and skipped. Every other line is
    split on commas/spaces; the first two fields are taken as
    ``(source, generated)``. Lines containing the text ``None`` and lines with
    fewer than two fields (for example blank lines) are ignored.

    Args:
        path: Path of the results file.
        max_per_source: Maximum number of generated molecules kept per source
            molecule; the earliest ones in the file are kept.

    Returns:
        ``{source_smiles: [generated_smiles, ...]}`` in file order.
    """
    opt_samples = {}
    with open(path) as f:
        next(f, None)  # skip the header line (no-op on an empty file)
        for line in f:
            if "None" in line:
                continue
            parts = line.rstrip('\n').replace(',', ' ').split(' ')
            if len(parts) < 2:
                # Blank or malformed line: no (source, generated) pair on it.
                continue
            generated = opt_samples.setdefault(parts[0], [])
            if len(generated) < max_per_source:
                generated.append(parts[1])
    return opt_samples


In [None]:
# Assay identifiers of the tasks evaluated in this notebook; used below to
# filter the test split and as keys for the classifiers and generated samples.
tasks_names = [
    "CHEMBL1119333",
    "CHEMBL1614027",
    "CHEMBL1614423",
    "CHEMBL1738485",
    "CHEMBL1963715",
    "CHEMBL1963723",
    "CHEMBL1963731",
    "CHEMBL1963741",
    "CHEMBL1963756",
    "CHEMBL1963810",
    "CHEMBL1963818",
    "CHEMBL1963819",
    "CHEMBL1963824",
    "CHEMBL1963825",
    "CHEMBL1963827",
    "CHEMBL1963831",
    "CHEMBL1964101",
    "CHEMBL1964115",
    "CHEMBL3214944",
    "CHEMBL3431930",
    "CHEMBL3431932",
]
# Keep only test-split rows with active == 0 that belong to the selected assays,
# renamed to the ('Assay', 'Org') column names used by the result tables.
df = pd.read_csv('./data/splits/test.csv')
full_tasks = df[['assay_id','non_chiral_smiles']][df['active']==0]
full_tasks = full_tasks[full_tasks['assay_id'].isin(tasks_names)]
full_tasks = full_tasks.rename(columns={'assay_id': 'Assay', 'non_chiral_smiles': 'Org'})
# Unique (Assay, Org) pairs; prepare() reindexes onto this index so pairs
# without any scored row are filled with False.
index = full_tasks.groupby(['Assay','Org']).agg('max').index

# SMILES present in the training split - presumably the `seen_smls` argument
# for calc_div_nov_succ (novelty); no call is visible in these cells.
seen_smls = set(pd.read_csv('./data/splits/train.csv')['non_chiral_smiles'])


In [None]:
# Group the test split by assay/activity, then train the cross-split
# classifiers for every evaluated assay.
test_data= load_data('./data/splits/test.csv')
clfs = load_clfs(tasks_names, test_data)


In [None]:
# Sanity check: size of the held-out inactive set in each of the 10 folds of
# the first task.
for i in range(10):
    print(len(clfs[tasks_names[0]][i]['unseen_inactive']))

In [None]:
def get_generated(path):
    """Load the generated molecules of every task and canonicalise them.

    Args:
        path: ``%``-style template; ``path % task`` must give the results file
            of a task in ``tasks_names``.

    Returns:
        ``{task: {source_smiles: [generated_smiles, ...]}}``, where each list is
        replaced by the corresponding entry returned by ``get_many_canons``.
    """
    generated_smiles = {}
    for task in tasks_names:
        per_source = sample_from_file(path % task)
        generated_smiles[task] = per_source
        canon_lists = get_many_canons(list(per_source.values()))
        # Keys and value lists come from the same dict, so they line up.
        for org_mol, canons in zip(per_source, canon_lists):
            per_source[org_mol] = canons
    return generated_smiles

def score_mols_df(generated_smiles):
    """Score every (source, generated) pair with the fold that never saw the source.

    For each task and each cross-split classifier in the global ``clfs``, the
    generated molecules of that fold's held-out inactive sources are collected.
    Sources and generated molecules are both scored with the fold's classifier
    and their Tanimoto similarity is computed.

    Args:
        generated_smiles: ``{task: {source_smiles: [generated_smiles, ...]}}``.

    Returns:
        DataFrame with columns ``Assay, Org, Opt, Score_Org, Thresh_Org,
        Score_Opt, Thresh_Opt, Sim``; the ``Thresh_*`` columns tell whether
        the score is strictly above the fold's threshold.
    """
    columns = ['Assay', 'Org', 'Opt', 'Score_Org', 'Thresh_Org', 'Score_Opt', 'Thresh_Opt', 'Sim']
    rows = []
    for task in tqdm(tasks_names):
        for clf_dict in clfs[task]:
            optimized_mols = []
            origin_mols = []
            for org_mol in clf_dict['unseen_inactive']:
                if org_mol not in generated_smiles[task]:
                    continue
                for opt_mol in generated_smiles[task][org_mol]:
                    origin_mols.append(org_mol)
                    optimized_mols.append(opt_mol)

            if not origin_mols:
                # No generated molecule for this fold's held-out sources:
                # scoring an empty batch would fail while stacking fingerprints.
                continue

            clf = clf_dict['clf']
            thresh = clf_dict['thresh']
            # Score sources and generated molecules in one batch, then split.
            scores = score_smiles(clf, origin_mols + optimized_mols)
            origin_scores = scores[:len(origin_mols)]
            optimized_scores = scores[len(origin_mols):]

            sims = get_many_sims(list(zip(origin_mols, optimized_mols)))

            for org_mol, opt_mol, org_score, opt_score, sim in zip(
                origin_mols, optimized_mols, origin_scores, optimized_scores, sims
            ):
                rows.append([task, org_mol, opt_mol, org_score, org_score > thresh,
                             opt_score, opt_score > thresh, sim])
    df = pd.DataFrame(rows, columns=columns)
    return df

def print_df(df):
    """Print a numeric DataFrame as LaTeX-style table rows.

    The header line lists ``Index`` and the column names separated by spaces.
    Each data row is the index label followed by the values formatted with
    three decimals, joined by `` & `` and terminated by two backslashes.

    Args:
        df: DataFrame whose values can be formatted with ``'%.3f'``.
    """
    header_cells = ['Index', *df.columns]
    print(' '.join(str(cell) for cell in header_cells))
    for row_label, row_values in zip(df.index, df.values):
        cells = [str(row_label)]
        cells.extend('%.3f' % value for value in row_values)
        print(' & '.join(cells), end='\\\\\n')


In [None]:
from scipy.stats import binom, chi2

def mcnemar(l1, l2, continuity_correction: bool = False) -> float:
    """McNemar test p-value for two paired boolean outcome vectors.

    Only discordant pairs matter: ``b`` counts pairs that are positive in
    ``l1`` and negative in ``l2``; ``c`` counts the opposite. With fewer than
    25 discordant pairs the p-value is ``2 * cdf(min(b, c)) - pmf(min(b, c))``
    under ``Binomial(b + c, 0.5)``; otherwise the chi-square approximation
    with one degree of freedom is used.

    Args:
        l1: Outcomes of the first method (converted to ``bool``).
        l2: Outcomes of the second method, paired element-wise with ``l1``.
        continuity_correction: Subtract 1 from ``|b - c|`` in the chi-square
            statistic. Has no effect in the binomial branch.

    Returns:
        The p-value.

    Raises:
        ValueError: If ``l1`` and ``l2`` do not have the same shape.
    """
    l1 = np.asarray(l1, dtype=bool)
    l2 = np.asarray(l2, dtype=bool)
    if l1.shape != l2.shape:
        # Without this, a length-1 input would silently broadcast.
        raise ValueError(
            f"l1 and l2 must be paired and have the same shape, got {l1.shape} and {l2.shape}"
        )

    b = int(np.sum(l1 & ~l2))  # positive in l1, negative in l2
    c = int(np.sum(~l1 & l2))  # negative in l1, positive in l2
    n_min, n_max = sorted([b, c])
    n_discordant = n_min + n_max

    if n_discordant < 25:
        pvalue = 2 * binom.cdf(n_min, n_discordant, 0.5) - binom.pmf(n_min, n_discordant, 0.5)
    else:
        corr = int(continuity_correction)
        chi2_statistic = (abs(n_min - n_max) - corr) ** 2 / n_discordant
        pvalue = chi2.sf(chi2_statistic, 1)
    return float(pvalue)

In [None]:
# Aggregation that prepare() applies to 'Success' within each (Assay, Org) group.
subtask_agg = 'max' # 'max' or 'mean', depending on how you want to aggregate the success metric

def calc_pvalue(succ1, succ2, assay=''):
    """Per-assay McNemar p-values between two success tables.

    The tables are inner-joined on their index and compared assay by assay
    using ``mcnemar`` with the continuity correction enabled.

    Args:
        succ1: DataFrame with a ``Success`` column and an ``Assay`` index level.
        succ2: Second table with the same layout.
        assay: If non-empty, only this assay is evaluated.

    Returns:
        DataFrame indexed by assay with a single ``pvalue`` column.
    """
    paired = succ1.join(succ2, lsuffix='_1', rsuffix='_2', how='inner')
    pvalues = {}
    for cur_assay, group in paired.groupby(level='Assay'):
        selected = (not assay) or cur_assay == assay
        if selected:
            first = group['Success_1'].values
            second = group['Success_2'].values
            pvalues[cur_assay] = mcnemar(first, second, True)
    return pd.DataFrame.from_dict(pvalues, orient='index', columns=['pvalue'])

# Similarity cut-off read by calc_success(): 'Sim' must be strictly greater than this.
sim= 0.4
def calc_success(df):
    """Add a boolean ``Success`` column to ``df`` in place and return ``df``.

    A row succeeds when the generated molecule passes the classifier
    threshold (``Thresh_Opt``), its ``Sim`` is strictly above the global
    ``sim`` cut-off, and it differs from the source molecule.

    Args:
        df: Frame with ``Thresh_Opt``, ``Sim``, ``Org`` and ``Opt`` columns.

    Returns:
        The same ``df`` object, now containing ``Success``.
    """
    passes_threshold = df['Thresh_Opt']
    similar_enough = df['Sim'] > sim
    is_modified = df['Org'] != df['Opt']
    df['Success'] = passes_threshold & similar_enough & is_modified
    return df


In [None]:
def prepare(df):
    """Reduce a scored frame to one ``Success`` value per (Assay, Org) pair.

    ``calc_success`` first adds ``Success`` to ``df`` in place. The rows are
    then aggregated per pair with the global ``subtask_agg`` and reindexed
    onto the global ``index``, filling missing pairs with ``False``.

    Args:
        df: Frame accepted by ``calc_success`` that also has an ``Assay`` column.

    Returns:
        A new DataFrame indexed by ``index`` with a ``Success`` column.
    """
    calc_success(df)  # called for its side effect on `df`
    success_rows = df[['Assay', 'Org', 'Success']]
    per_source = success_rows.groupby(['Assay', 'Org']).agg(subtask_agg)
    return per_source.reindex(index=index, fill_value=False)

def calc_summary(df):
    """Average every column of ``df`` within each assay.

    Args:
        df: Frame with ``Assay`` available as a column or index level.

    Returns:
        DataFrame indexed by ``Assay`` holding the per-assay means.
    """
    by_assay = df.groupby(by=['Assay'])
    return by_assay.agg(func='mean')

def calc_all(main_df, **kwargs):
    """Compare per-assay success rates of several result frames.

    Every frame (``main_df`` under the name ``main`` plus the keyword frames)
    is passed through ``prepare`` and summarised per assay. For each assay the
    best-scoring frame is then tested against every frame with
    ``calc_pvalue``.

    Args:
        main_df: Scored frame registered under the name ``main``.
        **kwargs: Additional scored frames keyed by method name.

    Returns:
        ``(scores, pss)``: per-assay success rates (one column per frame) and
        the matching p-values against the best frame of each assay.
    """
    prepared = dict(main=main_df, **kwargs)
    sums = {}
    for name in list(prepared):
        prepared[name] = prepare(prepared[name])
        sums[name] = calc_summary(prepared[name])['Success']
    scores = pd.concat(sums.values(), axis=1, keys=sums.keys())

    ps = {}
    for assay, score_row in scores.iterrows():
        best_frame = prepared[score_row.idxmax()]
        ps[assay] = [
            calc_pvalue(best_frame, other, assay=assay)['pvalue'][assay]
            for other in prepared.values()
        ]

    pss = pd.DataFrame.from_dict(ps, orient='index', columns=prepared.keys())
    return scores, pss

def calc_div_nov_succ(seen_smls, **dfs):
    """Tabulate diversity, novelty and success for each result frame.

    * diversity - mean of ``1 - Sim``
    * novelty   - share of ``Opt`` values not contained in ``seen_smls``
    * success   - mean ``Success`` after ``prepare``

    Args:
        seen_smls: Collection of SMILES regarded as already known.
        **dfs: Scored frames keyed by method name.

    Returns:
        DataFrame indexed by ``metric`` with one column per frame.
    """
    data = {'metric': ['diversity', 'novelty', 'success']}
    for name, df in dfs.items():
        novelty = (~df['Opt'].isin(seen_smls)).mean()
        diversity = (1 - df['Sim']).mean()
        success = prepare(df)['Success'].mean()
        data[name] = [diversity, novelty, success]
    return pd.DataFrame(data).set_index('metric')

In [None]:
# Load the generated molecules (the %s in the path is filled with each assay
# name by get_generated) and score them with the fold classifiers.
v4_gen = get_generated('experiments/modolo_margin_v2/2025_06_24_23_34/test_results/ModoloLightning/%s')
v4_df = score_mols_df(v4_gen)


In [None]:
# calc_all returns (scores, p-values); keep only the per-assay success rates.
res = calc_all(v4_df)[0]