In [1]:
from collections import defaultdict
from functools import partial
import json
import multiprocessing as mp
from pathlib import Path
import random
import re
import sys
import time
from typing import Any

from google.cloud import storage
import numpy as np
import pandas as pd
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors
from rdkit import RDLogger
import vertexai
from vertexai.generative_models import GenerativeModel, GenerationConfig
from vertexai.tuning import sft

sys.path.insert(0, '../agenticadmet')
from eval import extract_preds, extract_refs, eval_admet
from utils import ECFP_from_smiles, tanimoto_similarity, standardize, standardize_cxsmiles

[06:47:38] Initializing Normalizer


In [2]:
logger = RDLogger.logger()
logger.setLevel(RDLogger.CRITICAL)

In [3]:
TARGET_COLUMNS = ["HLM", "MLM", "LogD", "KSOL", "MDR1-MDCKII"]
PROPERTIES = ['LogHLM', 'LogMLM', 'LogD', 'LogKSOL', 'LogMDR1-MDCKII']
TOPK = 10

In [4]:
def get_all_topk_smiles_with_properties(ref_data, query_data, query2ref_dist, topk, property):
    for i in range(query2ref_dist.shape[0]):
        query_smiles = query_data.iloc[i]['cxsmiles_std']
        query_property = query_data.iloc[i][property]
        dist = query2ref_dist[i]
        order = np.argsort(dist)[::-1]
        ordered_dist = dist[order]
        order = order[~np.isclose(ordered_dist, 1.0)]  # remove self-similarity including duplicates
        topk_idx = order[:topk]
        topk_dist = ordered_dist[:topk]
        topk_smiles = ref_data.iloc[topk_idx]['cxsmiles_std'].tolist()
        topk_properties = ref_data.iloc[topk_idx][property].tolist()
        yield topk_smiles, topk_properties, topk_dist, query_smiles, query_property

In [5]:
for split_idx in range(5):
    data = pd.read_csv(f'../data/asap/datasets/rnd_splits/split_{split_idx}.csv')
    val = data[data['split'] == 'val'].reset_index(drop=True)
    val_ecfp = np.array(val['smiles_std'].apply(partial(ECFP_from_smiles, use_chirality=True)).tolist())
    
    for property in PROPERTIES:
        ref_data = data[(data['split'] == 'train') & ~data[property].isna()].reset_index(drop=True)
        ref_ecfp = np.array(ref_data['smiles_std'].apply(partial(ECFP_from_smiles, use_chirality=True)).tolist())
        val2ref_dist = tanimoto_similarity(val_ecfp, ref_ecfp)

        predictions = [
            (topk_properties @ np.array(topk_dist)) / np.sum(topk_dist)
            # np.mean(topk_properties)
            for topk_smiles, topk_properties, topk_dist, query_smiles, query_property in get_all_topk_smiles_with_properties(
                ref_data=ref_data,
                query_data=val,
                query2ref_dist=val2ref_dist,
                topk=TOPK,
                property=property
            )
        ]

        val[f'pred_{property}'] = predictions
    
    val_preds = extract_preds(val, target_columns=PROPERTIES)
    val_refs = extract_refs(val, target_columns=PROPERTIES)
    metrics = eval_admet(val_preds, val_refs, target_columns=TARGET_COLUMNS)
    print(f'Split {split_idx}:')
    print(json.dumps(metrics, indent=2))

Split 0:
{
  "HLM": {
    "mean_absolute_error": 0.3339532573813491,
    "r2": 0.34242360307323816
  },
  "MLM": {
    "mean_absolute_error": 0.3998483762839397,
    "r2": 0.2585010012488018
  },
  "LogD": {
    "mean_absolute_error": 0.6864164999666937,
    "r2": 0.5134110174767418
  },
  "KSOL": {
    "mean_absolute_error": 0.3612120850436849,
    "r2": 0.3463204859294108
  },
  "MDR1-MDCKII": {
    "mean_absolute_error": 0.25280873003265086,
    "r2": 0.32686510610674124
  },
  "aggregated": {
    "macro_mean_absolute_error": 0.4068477897416637,
    "macro_r2": 0.35750424276698684
  }
}
Split 1:
{
  "HLM": {
    "mean_absolute_error": 0.39612242437927675,
    "r2": 0.3889677977625817
  },
  "MLM": {
    "mean_absolute_error": 0.43791718928395623,
    "r2": 0.3915498949364441
  },
  "LogD": {
    "mean_absolute_error": 0.6243270226233567,
    "r2": 0.46012319742247576
  },
  "KSOL": {
    "mean_absolute_error": 0.34260903293878336,
    "r2": 0.38723269201586974
  },
  "MDR1-MDCKII": 