In [1]:
from functools import partial
import json
import multiprocessing as mp
from pathlib import Path
import re
import sys
import time

from google.cloud import storage
import numpy as np
import pandas as pd
from tqdm import tqdm
from rdkit import RDLogger
import vertexai
from vertexai.generative_models import GenerativeModel
from vertexai.tuning import sft

sys.path.insert(0, '../agenticadmet')
from eval import extract_preds, extract_refs, eval_admet
from utils import ECFP_from_smiles, tanimoto_similarity, standardize, standardize_cxsmiles

[19:31:10] Initializing Normalizer


In [2]:
logger = RDLogger.logger()
logger.setLevel(RDLogger.CRITICAL)

In [3]:
TOPK = 10
PROPERTY = 'LogD'

In [None]:
data = pd.read_csv('../data/asap/datasets/rnd_splits/split_0.csv')
data

Unnamed: 0,smiles,HLM,KSOL,LogD,MLM,MDR1-MDCKII,smiles_std,cxsmiles_std,mol_idx,smiles_ext,LogHLM,LogMLM,LogKSOL,LogMDR1-MDCKII,split
0,COC1=CC=CC(Cl)=C1NC(=O)N1CCC[C@H](C(N)=O)C1 |a...,,,0.3,,2.0,COc1cccc(Cl)c1NC(=O)N1CCC[C@H](C(N)=O)C1,COc1cccc(Cl)c1NC(=O)N1CCC[C@H](C(N)=O)C1 |a:16|,191,|a:16|,,,,0.477121,val
1,O=C(NCC(F)F)[C@H](NC1=CC2=C(C=C1Br)CNC2)C1=CC(...,,333.0,2.9,,0.2,O=C(NCC(F)F)[C@H](Nc1cc2c(cc1Br)CNC2)c1cc(Cl)c...,O=C(NCC(F)F)[C@H](Nc1cc2c(cc1Br)CNC2)c1cc(Cl)c...,335,|&1:7|,,,2.523746,0.079181,train
2,O=C(NCC(F)F)[C@H](NC1=CC=C2CNCC2=C1)C1=CC(Br)=...,,,0.4,,0.5,O=C(NCC(F)F)[C@H](Nc1ccc2c(c1)CNC2)c1cc(Br)cc2...,O=C(NCC(F)F)[C@H](Nc1ccc2c(c1)CNC2)c1cc(Br)cc2...,336,|&1:7|,,,,0.176091,train
3,NC(=O)[C@H]1CCCN(C(=O)CC2=CC=CC3=C2C=CO3)C1 |&...,,376.0,1.0,,8.5,NC(=O)[C@H]1CCCN(C(=O)Cc2cccc3occc23)C1,NC(=O)[C@H]1CCCN(C(=O)Cc2cccc3occc23)C1 |&1:3|,300,|&1:3|,,,2.576341,0.977724,train
4,CC1=CC(CC(=O)N2CCC[C@H](C(N)=O)C2)=CC=N1 |&1:11|,,375.0,-0.3,,0.9,Cc1cc(CC(=O)N2CCC[C@H](C(N)=O)C2)ccn1,Cc1cc(CC(=O)N2CCC[C@H](C(N)=O)C2)ccn1 |&1:11|,249,|&1:11|,,,2.575188,0.278754,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
429,CC(C)NC[C@H](O)COC1=CC=CC2=CC=CC=C12 |&1:5|,25.5,,,63.0,,CC(C)NC[C@H](O)COc1cccc2ccccc12,CC(C)NC[C@H](O)COc1cccc2ccccc12 |&1:5|,22,|&1:5|,1.423246,1.806180,,,val
430,O=C(O)CC1=CC=CC=C1NC1=C(Cl)C=CC=C1Cl,216.0,,,386.0,,O=C(O)Cc1ccccc1Nc1c(Cl)cccc1Cl,O=C(O)Cc1ccccc1Nc1c(Cl)cccc1Cl,380,,2.336460,2.587711,,,val
431,NCC1=CC(Cl)=CC(C(=O)NC2=CC=C3CNCC3=C2)=C1,,,2.0,,,NCc1cc(Cl)cc(C(=O)Nc2ccc3c(c2)CNC3)c1,NCc1cc(Cl)cc(C(=O)Nc2ccc3c(c2)CNC3)c1,303,,,,,,train
432,COC(=O)NC1=NC2=CC=C(C(=O)C3=CC=CC=C3)C=C2N1,,,2.9,,,COC(=O)Nc1nc2ccc(C(=O)c3ccccc3)cc2[nH]1,COC(=O)Nc1nc2ccc(C(=O)c3ccccc3)cc2[nH]1,166,,,,,,train


In [5]:
train = data[data['split'] == 'train'].reset_index(drop=True)
val = data[data['split'] == 'val'].reset_index(drop=True)
ref_data = data[(data['split'] == 'train') & ~data[PROPERTY].isna()].reset_index(drop=True)
ecfp = np.array(data['smiles_std'].apply(partial(ECFP_from_smiles, use_chirality=True)).tolist())
train_ecfp = ecfp[data['split'] == 'train']
val_ecfp = ecfp[data['split'] == 'val']
ref_ecfp = ecfp[ref_data.index]
train2ref_dist = tanimoto_similarity(train_ecfp, ref_ecfp)
val2ref_dist = tanimoto_similarity(val_ecfp, ref_ecfp)
train2ref_dist.shape, val2ref_dist.shape

((347, 277), (87, 277))

In [23]:
def get_all_topk_smiles_with_properties(ref_data, query_data, query2ref_dist, topk=TOPK, property=PROPERTY):
    for i in range(query2ref_dist.shape[0]):
        query_smiles = query_data.iloc[i]['cxsmiles_std']
        query_property = query_data.iloc[i][property]
        if np.isnan(query_property):
            continue

        dist = query2ref_dist[i]
        topk_idx = np.argsort(dist)[:topk][::-1]
        topk_smiles = ref_data.iloc[topk_idx]['cxsmiles_std'].tolist()
        topk_properties = ref_data.iloc[topk_idx][property].tolist()
        yield topk_smiles, topk_properties, query_smiles, query_property

In [7]:
def format_data(
    input_smiles: list[str],
    input_properties: list[float],
    query_smiles: str,
    query_property: float | None = None,
    system_instruction: str | None = None
):
    if system_instruction is not None:
        output = {
            "systemInstruction": {
            "role": "Ignored",
            "parts": [
            {
                "text": system_instruction
            }
            ]
        }
    }
    else:
        output = {}
    
    output.update({
        "contents": [
            {
            "role": "user",
            "parts": [
                {
                    "text": f"SMILES: {smi}, {PROPERTY}: {prop:.2f}",
                }
                for smi, prop in zip(input_smiles, input_properties)
            ] + [
                {
                    "text": f"Determine {PROPERTY} of {query_smiles}"
                }
            ]
            }
        ]
    })

    if query_property is not None:
        output['contents'].append({
            "role": "model",
            "parts": [
                {
                    "text": f"Answer: {query_property:.2f}"
                }
            ]
        })
    
    return output

In [8]:
SYSTEM_INSTRUCTIONS = {
    "LogD": \
        "You are an experienced medicinal chemist who worked many years with solubility testing of compounds. "
        "You can determine the LogD of a molecule based on its structure and properties of similar molecules from a reference set. "
        "You are given a list of ten reference molecules represented in SMILES paired with their LogD values. "
        "Some values might be incorrect due to assay errors. If you see extended notation like |&1:3|, "
        "it means that the molecule has mixed stereochemistry. If you see notation like |o1:4|, it means "
        "that the molecule has undefined stereochemistry (either R or S isomer). Take these compounds with care. "
        "Your task is to determine the LogD of the molecule represented in SMILES. "
        "Answer only with a single floating point number in the \"Answer: number\" format, e.g. \"Answer: 2.7\". "
        "Don't describe your solution and don't put any other text in your answer."
}

In [9]:
SYSTEM_INSTRUCTIONS['LogD']

'You are an experienced medicinal chemist who worked many years with solubility testing of compounds. You can determine the LogD of a molecule based on its structure and properties of similar molecules from a reference set. You are given a list of ten reference molecules represented in SMILES paired with their LogD values. Some values might be incorrect due to assay errors. If you see extended notation like |&1:3|, it means that the molecule has mixed stereochemistry. If you see notation like |o1:4|, it means that the molecule has undefined stereochemistry (either R or S isomer). Take these compounds with care. Your task is to determine the LogD of the molecule represented in SMILES. Answer only with a single floating point number in the "Answer: number" format, e.g. "Answer: 2.7". Don\'t describe your solution and don\'t put any other text in your answer.'

In [10]:
train_dataset = [
    format_data(
        input_smiles=input_smiles,
        input_properties=input_properties,
        query_smiles=query_smiles,
        query_property=query_property,
        system_instruction=SYSTEM_INSTRUCTIONS[PROPERTY]
    )
    for input_smiles, input_properties, query_smiles, query_property in get_all_topk_smiles_with_properties(
        ref_data=ref_data,
        query_data=train,
        query2ref_dist=train2ref_dist,
        topk=TOPK,
        property=PROPERTY
    )
]

In [11]:
val_dataset = [
    format_data(
        input_smiles=input_smiles,
        input_properties=input_properties,
        query_smiles=query_smiles,
    )
    for input_smiles, input_properties, query_smiles, _ in get_all_topk_smiles_with_properties(
        ref_data=ref_data,
        query_data=val,
        query2ref_dist=val2ref_dist,
        topk=TOPK,
        property=PROPERTY
    )
]

In [10]:
data_path = Path("../data/asap/datasets/rnd_splits/split_0/gemini_train.jsonl")
data_path.parent.mkdir(parents=True, exist_ok=True)
gs_uri = "<gs_bucket>/data/asap/datasets/rnd_splits/split_0/gemini_train.jsonl"

with open(data_path, 'w') as f:
    for item in train_dataset:
        f.write(json.dumps(item) + '\n')

storage_client = storage.Client()
bucket_name, key = gs_uri[5:].split("/", 1)
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(key)
blob.upload_from_filename(data_path)

In [12]:
PROJECT_ID = "bioptic-io"
LOCATION = "us-central1"
vertexai.init(project=PROJECT_ID, location=LOCATION)

In [11]:
sft_tuning_job = sft.train(
    source_model="gemini-1.5-pro-002",
    train_dataset=gs_uri,
)

# Polling for job completion
while not sft_tuning_job.has_ended:
    time.sleep(60)
    sft_tuning_job.refresh()

print(sft_tuning_job.tuned_model_name)
print(sft_tuning_job.tuned_model_endpoint_name)
print(sft_tuning_job.experiment)
# Example response:
# projects/123456789012/locations/us-central1/models/1234567890@1
# projects/123456789012/locations/us-central1/endpoints/123456789012345
# <google.cloud.aiplatform.metadata.experiment_resources.Experiment object at 0x7b5b4ae07af0>

Creating SupervisedTuningJob
SupervisedTuningJob created. Resource name: projects/199759238457/locations/us-central1/tuningJobs/7630894387945275392
To use this SupervisedTuningJob in another session:
tuning_job = sft.SupervisedTuningJob('projects/199759238457/locations/us-central1/tuningJobs/7630894387945275392')
View Tuning Job:
https://console.cloud.google.com/vertex-ai/generative/language/locations/us-central1/tuning/tuningJob/7630894387945275392?project=199759238457


KeyboardInterrupt: 

In [13]:
TUNING_JOB_ID = "7630894387945275392"
sft_tuning_job = sft.SupervisedTuningJob(f"projects/{PROJECT_ID}/locations/{LOCATION}/tuningJobs/{TUNING_JOB_ID}")
tuned_model = GenerativeModel(
    sft_tuning_job.tuned_model_endpoint_name,
    system_instruction=SYSTEM_INSTRUCTIONS[PROPERTY]
)

In [14]:
for content in val_dataset:
    result = tuned_model.generate_content(content['contents'])
    print(result)
    break

candidates {
  content {
    role: "model"
    parts {
      text: "Answer: -0.60\n\n"
    }
  }
  finish_reason: STOP
  avg_logprobs: -0.23872767388820648
}
usage_metadata {
  prompt_token_count: 818
  candidates_token_count: 8
  total_token_count: 826
  prompt_tokens_details {
    modality: TEXT
    token_count: 818
  }
  candidates_tokens_details {
    modality: TEXT
    token_count: 8
  }
}
create_time {
  seconds: 1741288408
  nanos: 360613000
}
response_id: "2PPJZ6WBFtyHm9IPh_jk8Q8"



In [26]:
def get_answer(model, content):
    result = model.generate_content(content['contents'])
    answer = result.candidates[0].content.parts[0].text
    try:
        answer = float(re.search(r'Answer: (\d+\.\d+)', answer).group(1))
    except Exception as e:
        print(answer)
        return None
    
    return answer

In [19]:
predictions = list(tqdm(map(partial(get_answer, tuned_model), val_dataset), total=len(val_dataset)))

100%|██████████| 67/67 [00:50<00:00,  1.33it/s]


In [20]:
val.loc[val[PROPERTY].notna(), f'pred_{PROPERTY}'] = predictions
val

Unnamed: 0,smiles,HLM,KSOL,LogD,MLM,MDR1-MDCKII,smiles_std,cxsmiles_std,mol_idx,smiles_ext,LogHLM,LogMLM,LogKSOL,LogMDR1-MDCKII,split,pred_LogD
0,COC1=CC=CC(Cl)=C1NC(=O)N1CCC[C@H](C(N)=O)C1 |a...,,,0.30,,2.0,COc1cccc(Cl)c1NC(=O)N1CCC[C@H](C(N)=O)C1,COc1cccc(Cl)c1NC(=O)N1CCC[C@H](C(N)=O)C1 |a:16|,191,|a:16|,,,,0.477121,val,0.06
1,CC(C)NC(=O)[C@H](NC1=CC=C2CNCC2=C1)C1=CC(Cl)=C...,,134.0,2.80,11.0,0.2,CC(C)NC(=O)[C@H](Nc1ccc2c(c1)CNC2)c1cc(Cl)cc2[...,CC(C)NC(=O)[C@H](Nc1ccc2c(c1)CNC2)c1cc(Cl)cc2[...,19,|o1:6|,,1.079181,2.130334,0.079181,val,3.20
2,O=C(NC1=CC=C2CNCC2=C1)C1=CC(F)=CC2=C1N=C(C1=CC...,,6.0,2.90,36.8,0.1,O=C(Nc1ccc2c(c1)CNC2)c1cc(F)cc2[nH]c(-c3ccc(F)...,O=C(Nc1ccc2c(c1)CNC2)c1cc(F)cc2[nH]c(-c3ccc(F)...,369,,,1.577492,0.845098,0.041393,val,1.80
3,N#CC1=CC2=C(C=C1NC(=O)C1=CC(F)=CC3=C1C=NN3)CNC2,16.7,193.0,0.60,288.0,3.9,N#Cc1cc2c(cc1NC(=O)c1cc(F)cc3[nH]ncc13)CNC2,N#Cc1cc2c(cc1NC(=O)c1cc(F)cc3[nH]ncc13)CNC2,298,,1.247973,2.460898,2.287802,0.690196,val,0.70
4,CO[C@H]1C[C@H](N2N=CC3=C(C(=O)NC4=CC=C5CNCC5=C...,,340.0,1.50,,1.6,CO[C@H]1C[C@H](n2ncc3c(C(=O)Nc4ccc5c(c4)CNC5)c...,CO[C@H]1C[C@H](n2ncc3c(C(=O)Nc4ccc5c(c4)CNC5)c...,173,,,,2.532754,0.414973,val,1.80
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
82,CC(C)[C@H](CO)NC1=NC=NC2=C1C=CN2,6.0,,1.69,10.0,13.1,CC(C)[C@H](CO)Nc1ncnc2[nH]ccc12,CC(C)[C@H](CO)Nc1ncnc2[nH]ccc12,24,,0.845098,1.041393,,1.149219,val,0.90
83,C=CC(=O)NC1=CC=CC(N(CC2=CC=CC(Cl)=C2)C(=O)CC2=...,1070.0,24.7,3.80,2380.0,8.0,C=CC(=O)Nc1cccc(N(Cc2cccc(Cl)c2)C(=O)Cc2cncc3c...,C=CC(=O)Nc1cccc(N(Cc2cccc(Cl)c2)C(=O)Cc2cncc3c...,6,,3.029789,3.376759,1.409933,0.954243,val,3.60
84,C=CC(=O)N1CCCC2=CC=C(N(CC3=CC=CC(C(F)(F)F)=C3)...,614.0,,,2150.0,,C=CC(=O)N1CCCc2ccc(N(Cc3cccc(C(F)(F)F)c3)C(=O)...,C=CC(=O)N1CCCc2ccc(N(Cc3cccc(C(F)(F)F)c3)C(=O)...,4,,2.788875,3.332640,,,val,
85,CC(C)NC[C@H](O)COC1=CC=CC2=CC=CC=C12 |&1:5|,25.5,,,63.0,,CC(C)NC[C@H](O)COc1cccc2ccccc12,CC(C)NC[C@H](O)COc1cccc2ccccc12 |&1:5|,22,|&1:5|,1.423246,1.806180,,,val,


In [12]:
save_path = Path("../output/asap/rnd_splits/gemini/run_0/split_0/val_predictions.csv")
save_path.parent.mkdir(parents=True, exist_ok=True)
# val.to_csv(save_path, index=False)
val = pd.read_csv(save_path)

In [20]:
val_preds = extract_preds(val, target_columns=[PROPERTY])
val_refs = extract_refs(val, target_columns=[PROPERTY])
metrics = eval_admet(val_preds, val_refs, target_columns=[PROPERTY])
print(json.dumps(metrics, indent=2))

{
  "LogD": {
    "mean_absolute_error": 0.49522388059701505,
    "r2": 0.6178674152052184
  },
  "aggregated": {
    "macro_mean_absolute_error": 0.49522388059701505,
    "macro_r2": 0.6178674152052184
  }
}


In [19]:
val_tmp = val.copy()
val_tmp['pred_LogD'] = val_tmp[val_tmp['LogD'].notna()]['LogD'].mean()
val_preds = extract_preds(val_tmp, target_columns=[PROPERTY])
val_refs = extract_refs(val_tmp, target_columns=[PROPERTY])
eval_admet(val_preds, val_refs, target_columns=[PROPERTY])

defaultdict(dict,
            {'LogD': {'mean_absolute_error': 0.9977946090443305, 'r2': 0.0},
             'aggregated': {'macro_mean_absolute_error': 0.9977946090443305,
              'macro_r2': 0.0}})

In [25]:
val_tmp = val.copy()
pred_tmp = [
    np.mean(topk_properties)
    for topk_smiles, topk_properties, query_smiles, query_property in get_all_topk_smiles_with_properties(
        ref_data=ref_data,
        query_data=val,
        query2ref_dist=val2ref_dist,
        topk=TOPK,
        property=PROPERTY
    )
]
val_tmp.loc[val_tmp['LogD'].notna(), 'pred_LogD'] = pred_tmp
val_preds = extract_preds(val_tmp, target_columns=[PROPERTY])
val_refs = extract_refs(val_tmp, target_columns=[PROPERTY])
metrics = eval_admet(val_preds, val_refs, target_columns=[PROPERTY])
print(json.dumps(metrics, indent=2))

{
  "LogD": {
    "mean_absolute_error": 1.0747761194029852,
    "r2": -0.14455899098513036
  },
  "aggregated": {
    "macro_mean_absolute_error": 1.0747761194029852,
    "macro_r2": -0.14455899098513036
  }
}


In [27]:
model = GenerativeModel(
    "gemini-1.5-pro-002",
    system_instruction=SYSTEM_INSTRUCTIONS[PROPERTY]
)

In [28]:
predictions = list(tqdm(map(partial(get_answer, model), val_dataset), total=len(val_dataset)))

100%|██████████| 67/67 [00:34<00:00,  1.96it/s]


In [29]:
val_tmp = val.copy()
val_tmp.loc[val_tmp[PROPERTY].notna(), f'pred_{PROPERTY}'] = predictions
val_preds = extract_preds(val_tmp, target_columns=[PROPERTY])
val_refs = extract_refs(val_tmp, target_columns=[PROPERTY])
metrics = eval_admet(val_preds, val_refs, target_columns=[PROPERTY])
print(json.dumps(metrics, indent=2))


{
  "LogD": {
    "mean_absolute_error": 0.7494029850746267,
    "r2": 0.31781726881996497
  },
  "aggregated": {
    "macro_mean_absolute_error": 0.7494029850746267,
    "macro_r2": 0.31781726881996497
  }
}


In [30]:
model = GenerativeModel(
    "gemini-2.0-pro-exp-02-05",
    system_instruction=SYSTEM_INSTRUCTIONS[PROPERTY]
)

In [32]:
predictions = list(tqdm(map(partial(get_answer, model), val_dataset), total=len(val_dataset)))

 12%|█▏        | 8/67 [00:05<00:38,  1.52it/s]


ResourceExhausted: 429 Quota exceeded for aiplatform.googleapis.com/generate_content_requests_per_minute_per_project_per_base_model with base model: gemini-experimental. Please submit a quota increase request. https://cloud.google.com/vertex-ai/docs/generative-ai/quotas-genai.