In [1]:
from collections import defaultdict
from functools import partial
import json
import multiprocessing as mp
from pathlib import Path
import random
import re
import sys
import time
from typing import Any

from google.cloud import storage
import numpy as np
import pandas as pd
from tqdm import tqdm
from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors
from rdkit import RDLogger
import vertexai
from vertexai.generative_models import GenerativeModel, GenerationConfig
from vertexai.tuning import sft

sys.path.insert(0, '../agenticadmet')
from eval import extract_preds, extract_refs, eval_admet
from utils import ECFP_from_smiles, tanimoto_similarity, standardize, standardize_cxsmiles

[05:10:44] Initializing Normalizer


In [2]:
# Raise the threshold of RDKit's logger so that only CRITICAL-level messages are emitted
# (everything below that severity is suppressed for the rest of the notebook).
logger = RDLogger.logger()
logger.setLevel(RDLogger.CRITICAL)

In [3]:
# Google Cloud project / region used for Vertex AI fine-tuning and inference.
PROJECT_ID = "bioptic-io"
LOCATION = "us-central1"
vertexai.init(project=PROJECT_ID, location=LOCATION)

# Seed the stdlib and NumPy global RNGs so any local sampling in this notebook is reproducible.
# (This does not control randomness inside the remote Vertex AI tuning jobs.)
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

In [4]:
# Assay endpoint names as they appear in the target columns.
TARGET_COLUMNS = [
    "HLM",
    "MLM",
    "LogD",
    "KSOL",
    "MDR1-MDCKII",
]
# Property names used in prompts and accepted when parsing model answers (see get_answer),
# positionally paired with TARGET_COLUMNS. Presumably log-scale versions of the endpoints
# (suggested by the names) -- confirm against the dataset preparation code.
PROPERTIES = [
    'LogHLM',
    'LogMLM',
    'LogD',
    'LogKSOL',
    'LogMDR1-MDCKII',
]
# Default property to model; re-assigned (to the same value) later in the notebook.
PROPERTY = 'LogD'

In [5]:
def save_and_upload_training_data(data, path, gs_uri):
    """Write ``data`` as JSON Lines to ``path`` and upload that file to Cloud Storage.

    Args:
        data: Iterable of JSON-serialisable records; each record becomes one line.
        path: Local destination file (str or ``Path``); overwritten if it exists.
        gs_uri: Destination object URI of the form ``gs://<bucket>/<object-key>``.

    Raises:
        ValueError: If ``gs_uri`` is not a ``gs://bucket/key`` URI. This is checked before
            anything is written, so a malformed URI fails fast instead of producing a
            local file and then a bogus bucket name.
    """
    prefix = "gs://"
    if not gs_uri.startswith(prefix):
        raise ValueError(f"Expected a gs://bucket/key URI, got {gs_uri!r}")
    bucket_name, sep, key = gs_uri[len(prefix):].partition("/")
    if not bucket_name or not sep or not key:
        raise ValueError(f"Expected a gs://bucket/key URI, got {gs_uri!r}")

    # One JSON object per line (the JSONL layout written for the tuning dataset).
    with open(path, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item) + '\n')

    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    bucket.blob(key).upload_from_filename(path)

In [6]:
def start_finetuning_sft(
    train_dataset_gs_uri: str,
    val_dataset_gs_uri: str | None = None,
    source_model: str = "gemini-1.5-pro-002",
    adapter_size: int = 4,
    epochs: int = 10,
    tuned_model_display_name: str | None = None,
):
    """Launch a Vertex AI supervised fine-tuning job via ``sft.train`` and return it.

    Args:
        train_dataset_gs_uri: Cloud Storage URI of the JSONL training dataset.
        val_dataset_gs_uri: Optional Cloud Storage URI of a validation dataset.
        source_model: Base model to tune.
        adapter_size: Adapter size forwarded to ``sft.train``.
        epochs: Number of tuning epochs.
        tuned_model_display_name: Optional display name for the tuned model.

    Returns:
        The tuning-job object returned by ``sft.train``. It was previously discarded;
        returning it lets callers record the job (e.g. its resource name) programmatically
        instead of copying IDs from the log output. Existing callers that ignore the
        return value are unaffected.
    """
    return sft.train(
        source_model=source_model,
        train_dataset=train_dataset_gs_uri,
        validation_dataset=val_dataset_gs_uri,
        adapter_size=adapter_size,
        epochs=epochs,
        tuned_model_display_name=tuned_model_display_name,
    )

In [7]:
def get_single_property_answer(model, content):
    """Query ``model`` with a formatted prompt and parse a single numeric prediction.

    Args:
        model: Object exposing ``generate_content(contents)``; a Vertex AI
            ``GenerativeModel`` in this notebook.
        content: Prompt dict such as the one built by ``format_data``; only its
            ``'contents'`` entry is sent to the model.

    Returns:
        The predicted value as a float. Returns ``None`` when the response carries no
        text or no parsable number; in that case the raw response is printed for
        inspection.
    """
    result = model.generate_content(content['contents'])
    try:
        answer = result.candidates[0].content.parts[0].text
    except (IndexError, AttributeError, ValueError):
        # e.g. a response without candidates/parts: treat it like an unparsable answer
        # rather than aborting the whole evaluation loop.
        print(result)
        return None

    text = answer.strip()
    # Prefer the value after "=" (the "PROP=1.23" format of the training targets). This also
    # accepts integers without picking up digits inside names such as "LogMDR1-MDCKII".
    # Otherwise fall back to the first decimal number anywhere in the text.
    match = re.search(r'=\s*(-?\d+(?:\.\d+)?)', text) or re.search(r'(-?\d+\.\d+)', text)
    if match is None:
        print(answer)
        return None
    return float(match.group(1))

In [8]:
def get_answer(model, content, properties=None):
    """Query ``model`` and parse every ``NAME=value`` pair from its answer.

    Args:
        model: Object exposing ``generate_content(contents)``; a Vertex AI
            ``GenerativeModel`` in this notebook.
        content: Prompt dict such as the one built by ``format_data``; only its
            ``'contents'`` entry is sent to the model.
        properties: Property names to keep, in output order. Defaults to the
            notebook-level ``PROPERTIES`` list, looked up at call time.

    Returns:
        A dict ``{property: value}`` restricted to ``properties`` and ordered like it.
        The dict may be empty if the answer names none of them. Returns ``None`` if the
        response carries no text or contains no ``NAME=value`` pair; in that case the
        raw response is printed.
    """
    result = model.generate_content(content['contents'])
    try:
        answer = result.candidates[0].content.parts[0].text
    except (IndexError, AttributeError, ValueError):
        # e.g. a response without candidates/parts: report it and return None instead of
        # raising out of the caller's loop.
        print(result)
        return None

    # NAME may contain hyphen-separated parts (e.g. "LogMDR1-MDCKII"); value may be int or decimal.
    matches = re.findall(r'(\w+(?:-\w+)*?)=(-?\d+(?:\.\d+)?)', answer.strip())
    if not matches:
        print(answer)
        return None

    parsed = {prop: float(val) for prop, val in matches}
    wanted = PROPERTIES if properties is None else properties
    # Keep only known properties, reordered to match `wanted`.
    return {prop: parsed[prop] for prop in wanted if prop in parsed}

## Hypothesis 1 - prompt with the top-k most similar training molecules and their known property values

In [9]:
# Hypothesis 1 settings. PROPERTY is re-assigned here (same value as the earlier constants cell).
PROPERTY = 'LogD'        # property the model is fine-tuned to predict
TARGET_COLUMN = 'LogD'   # target column name passed to eval_admet when scoring
TOPK = 5                 # number of reference molecules listed in each prompt
EPOCHS = 10              # fine-tuning epochs per split

In [10]:
def get_all_topk_smiles_with_properties(ref_data, query_data, query2ref_dist, topk=TOPK, property=PROPERTY):
    """Yield nearest-reference context for every query molecule that has a ``property`` value.

    Row ``i`` of ``query2ref_dist`` holds scores between query row ``i`` and every reference
    row, where HIGHER means closer (it is sorted descending). Despite the name it is treated
    as a similarity, e.g. the output of ``tanimoto_similarity``.

    Yields:
        ``(neighbour_smiles, neighbour_values, query_smiles, query_value)``:
        - ``neighbour_smiles``, ``neighbour_values``: the ``cxsmiles_std`` strings and
          ``property`` values of the ``topk`` highest-scoring reference molecules, most
          similar first.
        - ``query_smiles``, ``query_value``: the query's own SMILES and value.

        Queries whose ``property`` is NaN are skipped. References scoring ~1.0 (the query
        itself or fingerprint duplicates) are excluded, so fewer than ``topk`` neighbours
        may be returned.
    """
    n_queries = query2ref_dist.shape[0]
    for row_idx in range(n_queries):
        query_row = query_data.iloc[row_idx]
        query_smiles = query_row['cxsmiles_std']
        query_value = query_row[property]
        if np.isnan(query_value):
            continue

        sims = query2ref_dist[row_idx]
        by_similarity = np.argsort(sims)[::-1]  # most similar first
        not_identical = ~np.isclose(sims[by_similarity], 1.0)
        nearest = ref_data.iloc[by_similarity[not_identical][:topk]]
        yield (
            nearest['cxsmiles_std'].tolist(),
            nearest[property].tolist(),
            query_smiles,
            query_value,
        )

In [11]:
def format_data(
    input_smiles: list[str],
    input_properties: list[float],
    query_smiles: str,
    query_property: float | None = None,
    system_instruction: str | None = None
):
    if system_instruction is not None:
        output = {
            "systemInstruction": {
            "role": "Ignored",
            "parts": [
            {
                "text": system_instruction
            }
            ]
        }
    }
    else:
        output = {}
    
    output.update({
        "contents": [
            {
            "role": "user",
            "parts": [{
                "text": "; ".join([
                    f"{smi}, {PROPERTY}={prop:.2f}" 
                    for smi, prop in zip(input_smiles, input_properties)
                ]) + f"; Determine {PROPERTY} of {query_smiles}"
            }]
            }
        ]
    })

    if query_property is not None:
        output['contents'].append({
            "role": "model",
            "parts": [
                {
                    "text": f"{PROPERTY}={query_property:.2f}"
                }
            ]
        })
    
    return output

In [12]:
# System prompt per property, used both when building the fine-tuning examples and when
# querying the tuned models.
# NOTE(review): the prompt asks for "a single floating point number", but the training
# targets produced by format_data look like "<PROPERTY>=<value>". Also, "determine the
# {PROPERTY} the molecule" is missing an "of". The text is left unchanged on purpose: the
# already-tuned models were trained on this exact prompt, and evaluation must match training.
SYSTEM_INSTRUCTIONS = {
    PROPERTY: (
        f"You are an experienced medicinal chemist who worked determining the {PROPERTY} of molecules for years. "
        f"You can determine the {PROPERTY} of a molecule based on its structure and properties of similar molecules from a reference set. "
        f"You are given a list of {TOPK} reference molecules represented in SMILES paired with their {PROPERTY} values. "
        f"Some values might be incorrect due to assay errors. If you see extended notation like |&1:3|, "
        f"it means that the molecule has mixed stereochemistry. If you see notation like |o1:4|, it means "
        f"that the molecule has undefined stereochemistry (either R or S isomer). Take these compounds with care. "
        f"If a molecule doesn't have a notation or notation is like |a:16|, it means there's absolute (known) stereochemistry of the compound. "
        f"Your task is to determine the {PROPERTY} the molecule represented in SMILES. "
        f"Answer only with a single floating point number."
    ),
}

In [13]:
# Launch one fine-tuning job per random split. Each training molecule with a known PROPERTY
# becomes an example whose prompt lists its TOPK most similar reference molecules (training
# rows with a measured PROPERTY) together with their values.
N_SPLITS = 5
SPLITS_DIR = Path('../data/asap/datasets/rnd_splits')


def _ecfp_matrix(smiles):
    """Return chirality-aware ECFP fingerprints for a pandas Series of SMILES, stacked into one NumPy array."""
    return np.array(smiles.apply(partial(ECFP_from_smiles, use_chirality=True)).tolist())


for split_idx in range(N_SPLITS):
    data = pd.read_csv(SPLITS_DIR / f'split_{split_idx}.csv')
    train = data[data['split'] == 'train'].reset_index(drop=True)
    ref_data = train[train[PROPERTY].notna()].reset_index(drop=True)
    # The validation split is not featurized here: no validation dataset is passed to the tuning job.
    train_ecfp = _ecfp_matrix(train['smiles_std'])
    ref_ecfp = _ecfp_matrix(ref_data['smiles_std'])
    # Higher = more similar; consumed as a similarity by get_all_topk_smiles_with_properties.
    train2ref_dist = tanimoto_similarity(train_ecfp, ref_ecfp)

    train_dataset = [
        format_data(
            input_smiles=input_smiles,
            input_properties=input_properties,
            query_smiles=query_smiles,
            query_property=query_property,
            system_instruction=SYSTEM_INSTRUCTIONS[PROPERTY]
        )
        for input_smiles, input_properties, query_smiles, query_property in get_all_topk_smiles_with_properties(
            ref_data=ref_data,
            query_data=train,
            query2ref_dist=train2ref_dist,
            topk=TOPK,
            property=PROPERTY
        )
    ]

    file_name = f"gemini_{PROPERTY}_top{TOPK}_train.jsonl"
    data_dir = SPLITS_DIR / f"split_{split_idx}"
    data_dir.mkdir(parents=True, exist_ok=True)
    train_data_path = data_dir / file_name
    train_gs_uri = f"<gs_bucket>/data/asap/datasets/rnd_splits/split_{split_idx}/{file_name}"

    save_and_upload_training_data(train_dataset, train_data_path, train_gs_uri)
    start_finetuning_sft(
        train_dataset_gs_uri=train_gs_uri,
        source_model="gemini-1.5-flash-002",
        epochs=EPOCHS,
        tuned_model_display_name=f"{PROPERTY}_top{TOPK}_split_{split_idx}"
    )

Creating SupervisedTuningJob
SupervisedTuningJob created. Resource name: projects/199759238457/locations/us-central1/tuningJobs/8436637399500455936
To use this SupervisedTuningJob in another session:
tuning_job = sft.SupervisedTuningJob('projects/199759238457/locations/us-central1/tuningJobs/8436637399500455936')
View Tuning Job:
https://console.cloud.google.com/vertex-ai/generative/language/locations/us-central1/tuning/tuningJob/8436637399500455936?project=199759238457


Creating SupervisedTuningJob
SupervisedTuningJob created. Resource name: projects/199759238457/locations/us-central1/tuningJobs/1809590547824771072
To use this SupervisedTuningJob in another session:
tuning_job = sft.SupervisedTuningJob('projects/199759238457/locations/us-central1/tuningJobs/1809590547824771072')
View Tuning Job:
https://console.cloud.google.com/vertex-ai/generative/language/locations/us-central1/tuning/tuningJob/1809590547824771072?project=199759238457


Creating SupervisedTuningJob
SupervisedTuningJob created. Resource name: projects/199759238457/locations/us-central1/tuningJobs/8315040209561452544
To use this SupervisedTuningJob in another session:
tuning_job = sft.SupervisedTuningJob('projects/199759238457/locations/us-central1/tuningJobs/8315040209561452544')
View Tuning Job:
https://console.cloud.google.com/vertex-ai/generative/language/locations/us-central1/tuning/tuningJob/8315040209561452544?project=199759238457


Creating SupervisedTuningJob
SupervisedTuningJob created. Resource name: projects/199759238457/locations/us-central1/tuningJobs/2503144890439827456
To use this SupervisedTuningJob in another session:
tuning_job = sft.SupervisedTuningJob('projects/199759238457/locations/us-central1/tuningJobs/2503144890439827456')
View Tuning Job:
https://console.cloud.google.com/vertex-ai/generative/language/locations/us-central1/tuning/tuningJob/2503144890439827456?project=199759238457


Creating SupervisedTuningJob
SupervisedTuningJob created. Resource name: projects/199759238457/locations/us-central1/tuningJobs/82460090728185856
To use this SupervisedTuningJob in another session:
tuning_job = sft.SupervisedTuningJob('projects/199759238457/locations/us-central1/tuningJobs/82460090728185856')
View Tuning Job:
https://console.cloud.google.com/vertex-ai/generative/language/locations/us-central1/tuning/tuningJob/82460090728185856?project=199759238457


In [14]:
# Tuning-job IDs printed by the fine-tuning cell above, in split order (list index == split_idx).
tuned_job_ids = [
    "8436637399500455936",
    "1809590547824771072",
    "8315040209561452544",
    "2503144890439827456",
    "82460090728185856"
]

split_metrics = []  # per-split metrics for TARGET_COLUMN, summarised after the loop
for split_idx, tuned_job_id in enumerate(tuned_job_ids):
    data = pd.read_csv(f'../data/asap/datasets/rnd_splits/split_{split_idx}.csv')
    val = data[data['split'] == 'val'].reset_index(drop=True)
    # Same reference pool as the training prompts: train rows with a measured PROPERTY.
    ref_data = data[(data['split'] == 'train') & ~data[PROPERTY].isna()].reset_index(drop=True)
    val_ecfp = np.array(val['smiles_std'].apply(partial(ECFP_from_smiles, use_chirality=True)).tolist())
    ref_ecfp = np.array(ref_data['smiles_std'].apply(partial(ECFP_from_smiles, use_chirality=True)).tolist())
    val2ref_dist = tanimoto_similarity(val_ecfp, ref_ecfp)

    # Inference prompts: query_property is omitted, so no answer turn is included.
    val_dataset = [
        format_data(
            input_smiles=input_smiles,
            input_properties=input_properties,
            query_smiles=query_smiles,
            system_instruction=SYSTEM_INSTRUCTIONS[PROPERTY]
        )
        for input_smiles, input_properties, query_smiles, _ in get_all_topk_smiles_with_properties(
            ref_data=ref_data,
            query_data=val,
            query2ref_dist=val2ref_dist,
            topk=TOPK,
            property=PROPERTY
        )
    ]

    sft_tuning_job = sft.SupervisedTuningJob(f"projects/{PROJECT_ID}/locations/{LOCATION}/tuningJobs/{tuned_job_id}")
    tuned_model = GenerativeModel(
        sft_tuning_job.tuned_model_endpoint_name,
        generation_config=GenerationConfig(
            temperature=0.0
        ),
        system_instruction=SYSTEM_INSTRUCTIONS[PROPERTY]
    )

    predictions = list(tqdm(map(partial(get_single_property_answer, tuned_model), val_dataset), total=len(val_dataset)))
    # The generator skips rows whose PROPERTY is NaN, so predictions are positionally aligned
    # with the labelled rows of `val`. Verify that before writing them back.
    has_label = val[PROPERTY].notna()
    n_labelled = int(has_label.sum())
    assert len(predictions) == n_labelled, (
        f"split {split_idx}: {len(predictions)} predictions for {n_labelled} labelled rows"
    )
    val.loc[has_label, f'pred_{PROPERTY}'] = predictions
    val_preds = extract_preds(val, target_columns=[PROPERTY])
    val_refs = extract_refs(val, target_columns=[PROPERTY])
    metrics = eval_admet(val_preds, val_refs, target_columns=[TARGET_COLUMN])
    split_metrics.append(metrics[TARGET_COLUMN])
    print(f"Split {split_idx}")
    print(json.dumps(metrics, indent=2))

# Mean and standard deviation of each metric across the splits.
print(pd.DataFrame(split_metrics).agg(['mean', 'std']))

100%|██████████| 61/61 [00:48<00:00,  1.26it/s]


Split 0
{
  "LogD": {
    "mean_absolute_error": 0.41950819672131145,
    "r2": 0.8120427453513654
  },
  "aggregated": {
    "macro_mean_absolute_error": 0.41950819672131145,
    "macro_r2": 0.8120427453513654
  }
}


100%|██████████| 66/66 [00:52<00:00,  1.27it/s]


Split 1
{
  "LogD": {
    "mean_absolute_error": 0.5974242424242424,
    "r2": 0.5433327788077019
  },
  "aggregated": {
    "macro_mean_absolute_error": 0.5974242424242424,
    "macro_r2": 0.5433327788077019
  }
}


100%|██████████| 66/66 [00:52<00:00,  1.27it/s]


Split 2
{
  "LogD": {
    "mean_absolute_error": 0.5742424242424242,
    "r2": 0.6421093280841206
  },
  "aggregated": {
    "macro_mean_absolute_error": 0.5742424242424242,
    "macro_r2": 0.6421093280841206
  }
}


100%|██████████| 65/65 [00:51<00:00,  1.25it/s]


Split 3
{
  "LogD": {
    "mean_absolute_error": 0.4390769230769231,
    "r2": 0.7534894272776074
  },
  "aggregated": {
    "macro_mean_absolute_error": 0.4390769230769231,
    "macro_r2": 0.7534894272776074
  }
}


100%|██████████| 60/60 [00:52<00:00,  1.15it/s]

Split 4
{
  "LogD": {
    "mean_absolute_error": 0.5196666666666666,
    "r2": 0.6338386937862661
  },
  "aggregated": {
    "macro_mean_absolute_error": 0.5196666666666666,
    "macro_r2": 0.6338386937862661
  }
}



