In [1]:
from collections import defaultdict
from functools import partial
import json
from pathlib import Path
import re
import sys
import os
from dotenv import load_dotenv

import numpy as np
import pandas as pd
from tqdm import tqdm
from openai import OpenAI
from rdkit import RDLogger

sys.path.insert(0, '../agenticadmet')
from eval import extract_preds, extract_refs, eval_admet
from utils import ECFP_from_smiles, tanimoto_similarity

[03:06:20] Initializing Normalizer


In [2]:
logger = RDLogger.logger()
logger.setLevel(RDLogger.CRITICAL)

In [3]:
load_dotenv()
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
RANDOM_SEED = 42
SPLIT = 0

client = OpenAI(api_key=OPENAI_API_KEY)

In [4]:
TARGET_COLUMNS = ["HLM", "MLM", "LogD", "KSOL", "MDR1-MDCKII"]
PROPERTIES = ['LogHLM', 'LogMLM', 'LogD', 'LogKSOL', 'LogMDR1-MDCKII']
PROPERTY = 'LogD'

In [5]:
data = pd.read_csv(f'../data/asap/datasets/rnd_splits/split_{SPLIT}.csv')
data

Unnamed: 0,smiles,HLM,KSOL,LogD,MLM,MDR1-MDCKII,smiles_std,cxsmiles_std,mol_idx,smiles_ext,LogHLM,LogMLM,LogKSOL,LogMDR1-MDCKII,split
0,COC1=CC=CC(Cl)=C1NC(=O)N1CCC[C@H](C(N)=O)C1 |a...,,,0.3,,2.0,COc1cccc(Cl)c1NC(=O)N1CCC[C@H](C(N)=O)C1,COc1cccc(Cl)c1NC(=O)N1CCC[C@H](C(N)=O)C1 |a:16|,191,|a:16|,,,,0.477121,val
1,O=C(NCC(F)F)[C@H](NC1=CC2=C(C=C1Br)CNC2)C1=CC(...,,333.0,2.9,,0.2,O=C(NCC(F)F)[C@H](Nc1cc2c(cc1Br)CNC2)c1cc(Cl)c...,O=C(NCC(F)F)[C@H](Nc1cc2c(cc1Br)CNC2)c1cc(Cl)c...,335,|&1:7|,,,2.523746,0.079181,train
2,O=C(NCC(F)F)[C@H](NC1=CC=C2CNCC2=C1)C1=CC(Br)=...,,,0.4,,0.5,O=C(NCC(F)F)[C@H](Nc1ccc2c(c1)CNC2)c1cc(Br)cc2...,O=C(NCC(F)F)[C@H](Nc1ccc2c(c1)CNC2)c1cc(Br)cc2...,336,|&1:7|,,,,0.176091,train
3,NC(=O)[C@H]1CCCN(C(=O)CC2=CC=CC3=C2C=CO3)C1 |&...,,376.0,1.0,,8.5,NC(=O)[C@H]1CCCN(C(=O)Cc2cccc3occc23)C1,NC(=O)[C@H]1CCCN(C(=O)Cc2cccc3occc23)C1 |&1:3|,300,|&1:3|,,,2.576341,0.977724,train
4,CC1=CC(CC(=O)N2CCC[C@H](C(N)=O)C2)=CC=N1 |&1:11|,,375.0,-0.3,,0.9,Cc1cc(CC(=O)N2CCC[C@H](C(N)=O)C2)ccn1,Cc1cc(CC(=O)N2CCC[C@H](C(N)=O)C2)ccn1 |&1:11|,249,|&1:11|,,,2.575188,0.278754,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
399,CC(C)NC[C@H](O)COC1=CC=CC2=CC=CC=C12 |&1:5|,25.5,,,63.0,,CC(C)NC[C@H](O)COc1cccc2ccccc12,CC(C)NC[C@H](O)COc1cccc2ccccc12 |&1:5|,22,|&1:5|,1.423246,1.806180,,,val
400,O=C(O)CC1=CC=CC=C1NC1=C(Cl)C=CC=C1Cl,216.0,,,386.0,,O=C(O)Cc1ccccc1Nc1c(Cl)cccc1Cl,O=C(O)Cc1ccccc1Nc1c(Cl)cccc1Cl,380,,2.336460,2.587711,,,val
401,NCC1=CC(Cl)=CC(C(=O)NC2=CC=C3CNCC3=C2)=C1,,,2.0,,,NCc1cc(Cl)cc(C(=O)Nc2ccc3c(c2)CNC3)c1,NCc1cc(Cl)cc(C(=O)Nc2ccc3c(c2)CNC3)c1,303,,,,,,train
402,COC(=O)NC1=NC2=CC=C(C(=O)C3=CC=CC=C3)C=C2N1,,,2.9,,,COC(=O)Nc1nc2ccc(C(=O)c3ccccc3)cc2[nH]1,COC(=O)Nc1nc2ccc(C(=O)c3ccccc3)cc2[nH]1,166,,,,,,train


In [6]:
def save_and_upload_training_data(data, path: Path):
    path.parent.mkdir(parents=True, exist_ok=True)

    with open(path, 'w') as f:
        for item in data:
            f.write(json.dumps(item) + '\n')

    with open(path, 'rb') as f:
        file = client.files.create(
            file=f,
            purpose="fine-tune"
        )
    
    return file.id

In [7]:
def complete_chat(model_id, messages):
    response = client.chat.completions.create(
        model=model_id,
        messages=messages
    )
    
    return response.choices[0].message.content

In [8]:
def start_finetuning_sft(
        training_file_id: str, validation_file_id: str, model: str = "gpt-4o-mini-2024-07-18",
        batch_size: str | int = "auto", learning_rate_multiplier: str | float = "auto", epochs: str | int = "auto"
    ):
    job = client.fine_tuning.jobs.create(
        training_file=training_file_id,
        validation_file=validation_file_id,
        model=model,
        method={
            "type": "supervised",
            "supervised": {
                "hyperparameters": {
                    "batch_size": batch_size,
                    "learning_rate_multiplier": learning_rate_multiplier,
                    "n_epochs": epochs
                }
            }
        },
        seed=RANDOM_SEED
    )

    return job.id

In [9]:
def get_answer(model, content):
    result = model.generate_content(content['contents'])
    answer = result.candidates[0].content.parts[0].text
    try:
        answer = float(re.search(r'(-?\d+\.\d+)', answer.strip()).group(1))
    except Exception as e:
        print(answer)
        return None
    
    return answer

## Hypothesis 1 - top-k closest molecules from train with given properties (for all properties)

In [10]:
TOPK = 5

In [11]:
train = ref_data = data[data['split'] == 'train'].reset_index(drop=True)
val = data[data['split'] == 'val'].reset_index(drop=True)
train_ecfp = np.array(train['smiles_std'].apply(partial(ECFP_from_smiles, use_chirality=True)).tolist())
val_ecfp = np.array(val['smiles_std'].apply(partial(ECFP_from_smiles, use_chirality=True)).tolist())
ref_ecfp = np.array(ref_data['smiles_std'].apply(partial(ECFP_from_smiles, use_chirality=True)).tolist())
train2ref_dist = tanimoto_similarity(train_ecfp, ref_ecfp)
val2ref_dist = tanimoto_similarity(val_ecfp, ref_ecfp)
train2ref_dist.shape, val2ref_dist.shape

((323, 323), (81, 323))

In [12]:
ref_data

Unnamed: 0,smiles,HLM,KSOL,LogD,MLM,MDR1-MDCKII,smiles_std,cxsmiles_std,mol_idx,smiles_ext,LogHLM,LogMLM,LogKSOL,LogMDR1-MDCKII,split
0,O=C(NCC(F)F)[C@H](NC1=CC2=C(C=C1Br)CNC2)C1=CC(...,,333.0,2.9,,0.2,O=C(NCC(F)F)[C@H](Nc1cc2c(cc1Br)CNC2)c1cc(Cl)c...,O=C(NCC(F)F)[C@H](Nc1cc2c(cc1Br)CNC2)c1cc(Cl)c...,335,|&1:7|,,,2.523746,0.079181,train
1,O=C(NCC(F)F)[C@H](NC1=CC=C2CNCC2=C1)C1=CC(Br)=...,,,0.4,,0.5,O=C(NCC(F)F)[C@H](Nc1ccc2c(c1)CNC2)c1cc(Br)cc2...,O=C(NCC(F)F)[C@H](Nc1ccc2c(c1)CNC2)c1cc(Br)cc2...,336,|&1:7|,,,,0.176091,train
2,NC(=O)[C@H]1CCCN(C(=O)CC2=CC=CC3=C2C=CO3)C1 |&...,,376.0,1.0,,8.5,NC(=O)[C@H]1CCCN(C(=O)Cc2cccc3occc23)C1,NC(=O)[C@H]1CCCN(C(=O)Cc2cccc3occc23)C1 |&1:3|,300,|&1:3|,,,2.576341,0.977724,train
3,CC1=CC(CC(=O)N2CCC[C@H](C(N)=O)C2)=CC=N1 |&1:11|,,375.0,-0.3,,0.9,Cc1cc(CC(=O)N2CCC[C@H](C(N)=O)C2)ccn1,Cc1cc(CC(=O)N2CCC[C@H](C(N)=O)C2)ccn1 |&1:11|,249,|&1:11|,,,2.575188,0.278754,train
4,COC1=CC(Cl)=CC([C@H](NC2=CC=C3CNCC3=C2)C(=O)NC...,,,0.5,,0.7,COc1cc(Cl)cc([C@H](Nc2ccc3c(c2)CNC3)C(=O)NCC(F...,COc1cc(Cl)cc([C@H](Nc2ccc3c(c2)CNC3)C(=O)NCC(F...,178,|&1:8|,,,,0.230449,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
318,C=CC(=O)N1CCCC2=CC=C(N(C(=O)CC3=CN=CC4=CC=CC=C...,1380.0,,,1760.0,,C=CC(=O)N1CCCc2ccc(N(C(=O)Cc3cncc4ccccc34)N(C)...,C=CC(=O)N1CCCc2ccc(N(C(=O)Cc3cncc4ccccc34)N(C)...,3,,3.140194,3.245759,,,train
319,CN(C)CCCN1C2=CC=CC=C2CCC2=CC=CC=C21,,,,169.0,,CN(C)CCCN1c2ccccc2CCc2ccccc21,CN(C)CCCN1c2ccccc2CCc2ccccc21,103,,,2.230449,,,train
320,NCC1=CC(Cl)=CC(C(=O)NC2=CC=C3CNCC3=C2)=C1,,,2.0,,,NCc1cc(Cl)cc(C(=O)Nc2ccc3c(c2)CNC3)c1,NCc1cc(Cl)cc(C(=O)Nc2ccc3c(c2)CNC3)c1,303,,,,,,train
321,COC(=O)NC1=NC2=CC=C(C(=O)C3=CC=CC=C3)C=C2N1,,,2.9,,,COC(=O)Nc1nc2ccc(C(=O)c3ccccc3)cc2[nH]1,COC(=O)Nc1nc2ccc(C(=O)c3ccccc3)cc2[nH]1,166,,,,,,train


In [13]:
def get_all_topk_smiles_with_properties(ref_data, query_data, query2ref_dist, topk=TOPK):
    for i in range(query2ref_dist.shape[0]):
        query_smiles = query_data.iloc[i]['cxsmiles_std']
        query_properties = {}
        topk_smiles2properties = defaultdict(dict)
        for property in PROPERTIES:
            query_property = query_data.iloc[i][property]
            if not np.isnan(query_property):
                query_properties[property] = query_property

            ref_data_with_property = ref_data[ref_data[property].notna()]
            dist = query2ref_dist[i, ref_data_with_property.index]
            order = np.argsort(dist)[::-1]
            ordered_dist = dist[order]
            order = order[~np.isclose(ordered_dist, 1.0)]  # remove self-similarity including duplicates
            topk_idx = order[:topk]
            topk_smiles = ref_data_with_property.iloc[topk_idx]['cxsmiles_std'].tolist()
            topk_properties = ref_data_with_property.iloc[topk_idx][property].tolist()

            for smi, prop in zip(topk_smiles, topk_properties):
                topk_smiles2properties[smi][property] = prop
        
        yield topk_smiles2properties, query_smiles, query_properties

In [14]:
system_instruction = \
"You are an experienced medicinal chemist who worked many years determining ADME properties of drug-like molecules. " \
"You can determine properties like LogD (solubility), LogHLM (human liver microsome stability), LogMLM (mouse liver microsome stability), " \
"LogKSOL (kinetic solubility), and LogMDR1-MDCKII (MDR1-MDCKII permeability) based on molecular structure and properties " \
"of similar molecules from a reference set. You are given a list of reference molecules represented in SMILES " \
"paired with their property values. Some values might be incorrect due to assay errors. If you see extended notation like |&1:3|, " \
"it means that the molecule has mixed stereochemistry. If you see notation like |o1:4|, it means " \
"that the molecule has undefined stereochemistry (either R or S isomer). Take these compounds with care. If a molecule " \
"doesn't have a notation or notation is like |a:16|, it means there's absolute (known) stereochemistry of the compound. " \
"Only pay attention to the relevant or correlating properties, and ignore the rest. " \
"Your task is to determine the requested properties of the molecule represented in SMILES. " \
"Answer only with a single floating point number per property."

In [15]:
def format_data(
    topk_smiles2properties: dict[str, dict[str, float]],
    query_smiles: str,
    query_properties: dict[str, float] | None = None,
    system_instruction: str | None = None
):
    if system_instruction is not None:
        messages = [{
            "role": "system",
            "content": system_instruction
        }]
    else:
        messages = []

    messages.append({
        "role": "user",
        "content": "; ".join([
            f"SMILES: {smi}, " + ", ".join([
                    f"{prop_name}={prop_val:.2f}"
                    for prop_name, prop_val in prop_name2prop_val.items()
                ])
                for smi, prop_name2prop_val in topk_smiles2properties.items()
            ])
        })
    if query_properties is not None:
        messages.append({
            "role": "user",
            "content": f"SMILES: {query_smiles}; " + "Determine " + ", ".join([
                f"{prop_name}" for prop_name in query_properties.keys()
            ])
        })
    else:
        messages.append({
            "role": "user",
            "content": f"SMILES: {query_smiles}; " + "Determine " +", ".join([
                f"{prop_name}" for prop_name in PROPERTIES
            ])
        })

    if query_properties is not None:
        messages.append({
            "role": "assistant",
            "content": ", ".join([
                f"{prop_name}={prop_val:.2f}"
                for prop_name, prop_val in query_properties.items()
            ])
        })

    output = {"messages": messages}
    
    return output

In [16]:
train_dataset = [
    format_data(
        topk_smiles2properties=topk_smiles2properties,
        query_smiles=query_smiles,
        query_properties=query_properties,
        system_instruction=system_instruction
    )
    for topk_smiles2properties, query_smiles, query_properties in get_all_topk_smiles_with_properties(
        ref_data=ref_data,
        query_data=train,
        query2ref_dist=train2ref_dist,
        topk=TOPK
    )
]

In [17]:
train_dataset[0]

{'messages': [{'role': 'system',
   'content': "You are an experienced medicinal chemist who worked many years determining ADME properties of drug-like molecules. You can determine properties like LogD (solubility), LogHLM (human liver microsome stability), LogMLM (mouse liver microsome stability), LogKSOL (kinetic solubility), and LogMDR1-MDCKII (MDR1-MDCKII permeability) based on molecular structure and properties of similar molecules from a reference set. You are given a list of reference molecules represented in SMILES paired with their property values. Some values might be incorrect due to assay errors. If you see extended notation like |&1:3|, it means that the molecule has mixed stereochemistry. If you see notation like |o1:4|, it means that the molecule has undefined stereochemistry (either R or S isomer). Take these compounds with care. If a molecule doesn't have a notation or notation is like |a:16|, it means there's absolute (known) stereochemistry of the compound. Only pay 

In [18]:
val_dataset = [
    format_data(
        topk_smiles2properties=topk_smiles2properties,
        query_smiles=query_smiles,
        query_properties=query_properties,
        system_instruction=system_instruction
    )
    for topk_smiles2properties, query_smiles, query_properties in get_all_topk_smiles_with_properties(
        ref_data=ref_data,
        query_data=val,
        query2ref_dist=val2ref_dist,
        topk=TOPK
    )
]

In [19]:
val_dataset[0]

{'messages': [{'role': 'system',
   'content': "You are an experienced medicinal chemist who worked many years determining ADME properties of drug-like molecules. You can determine properties like LogD (solubility), LogHLM (human liver microsome stability), LogMLM (mouse liver microsome stability), LogKSOL (kinetic solubility), and LogMDR1-MDCKII (MDR1-MDCKII permeability) based on molecular structure and properties of similar molecules from a reference set. You are given a list of reference molecules represented in SMILES paired with their property values. Some values might be incorrect due to assay errors. If you see extended notation like |&1:3|, it means that the molecule has mixed stereochemistry. If you see notation like |o1:4|, it means that the molecule has undefined stereochemistry (either R or S isomer). Take these compounds with care. If a molecule doesn't have a notation or notation is like |a:16|, it means there's absolute (known) stereochemistry of the compound. Only pay 

In [20]:
train_data_path = Path(f"../data/asap/datasets/rnd_splits/split_{SPLIT}/gpt_all_prop_top{TOPK}_train.jsonl")
train_file_id = save_and_upload_training_data(train_dataset, train_data_path)
val_data_path = Path(f"../data/asap/datasets/rnd_splits/split_{SPLIT}/gpt_all_prop_top{TOPK}_val.jsonl")
val_file_id = save_and_upload_training_data(val_dataset, val_data_path)

In [21]:
job_id = start_finetuning_sft(train_file_id, val_file_id, model="gpt-4o-2024-08-06", epochs=3)

In [24]:
job = client.fine_tuning.jobs.retrieve(job_id)
print(job.id)
print(job.status)
print(job.model)
print(job.training_file)
print(job.validation_file)
print(job.hyperparameters)
print(job.created_at)
print(job.error)

ftjob-yBuLLbN462N282MHLCYE73mJ
succeeded
gpt-4o-2024-08-06
file-XErDcGqubqSm58ysX9c1mp
file-PcGVJ6hKZSkiJMrdM1A2nM
Hyperparameters(batch_size=1, learning_rate_multiplier=2.0, n_epochs=3)
1741921586
Error(code=None, message=None, param=None)


In [25]:
val_dataset = [
    format_data(
        topk_smiles2properties=topk_smiles2properties,
        query_smiles=query_smiles,
        system_instruction=system_instruction
    )
    for topk_smiles2properties, query_smiles, query_properties in get_all_topk_smiles_with_properties(
        ref_data=ref_data,
        query_data=val,
        query2ref_dist=val2ref_dist,
        topk=TOPK
    )
]

In [26]:
val_dataset[0]

{'messages': [{'role': 'system',
   'content': "You are an experienced medicinal chemist who worked many years determining ADME properties of drug-like molecules. You can determine properties like LogD (solubility), LogHLM (human liver microsome stability), LogMLM (mouse liver microsome stability), LogKSOL (kinetic solubility), and LogMDR1-MDCKII (MDR1-MDCKII permeability) based on molecular structure and properties of similar molecules from a reference set. You are given a list of reference molecules represented in SMILES paired with their property values. Some values might be incorrect due to assay errors. If you see extended notation like |&1:3|, it means that the molecule has mixed stereochemistry. If you see notation like |o1:4|, it means that the molecule has undefined stereochemistry (either R or S isomer). Take these compounds with care. If a molecule doesn't have a notation or notation is like |a:16|, it means there's absolute (known) stereochemistry of the compound. Only pay 

In [28]:
for content in val_dataset:
    result = complete_chat(job.fine_tuned_model, content['messages'])
    print(result)
    break

LogHLM=1.02, LogMLM=1.41, LogD=1.50, LogKSOL=2.58, LogMDR1-MDCKII=0.20


In [29]:
def get_answer(model, content):
    answer = complete_chat(model, content['messages'])
    try:
        # Extract all property=value pairs from the answer
        matches = re.findall(r'(\w+(?:-\w+)*?)=(-?\d+\.\d+)', answer.strip())
        if not matches:
            print(answer)
            return None
            
        # Convert to dictionary
        properties = {prop: float(val) for prop, val in matches}
        # reorder to match PROPERTIES
        properties = {prop: properties[prop] for prop in PROPERTIES if prop in properties}
        
        # Return the requested property value
        return properties
        
    except Exception as e:
        print(answer)
        return None

In [31]:
predictions = list(tqdm(map(partial(get_answer, job.fine_tuned_model), val_dataset), total=len(val_dataset)))
predictions = pd.DataFrame(predictions).rename(columns={prop: f'pred_{prop}' for prop in PROPERTIES})
val_ = pd.concat([val, predictions], axis=1)
# val_.loc[9, 'pred_LogKSOL'] = 2.5
val_preds = extract_preds(val_)
val_refs = extract_refs(val_)
metrics = eval_admet(val_preds, val_refs, target_columns=TARGET_COLUMNS)
print(json.dumps(metrics, indent=2))

100%|██████████| 81/81 [01:37<00:00,  1.20s/it]

{
  "HLM": {
    "mean_absolute_error": 0.4327145832230898,
    "r2": 0.036173730051123565
  },
  "MLM": {
    "mean_absolute_error": 0.48660588593784565,
    "r2": 0.13529233365353288
  },
  "LogD": {
    "mean_absolute_error": 0.519344262295082,
    "r2": 0.7372712316377459
  },
  "KSOL": {
    "mean_absolute_error": 0.31851278888613116,
    "r2": 0.36567336750911406
  },
  "MDR1-MDCKII": {
    "mean_absolute_error": 0.2562405718794265,
    "r2": 0.33449713264053915
  },
  "aggregated": {
    "macro_mean_absolute_error": 0.402683618444315,
    "macro_r2": 0.3217815590984111
  }
}





**gpt-4o-2024-08-06, epochs=3**

{
  "HLM": {
    "mean_absolute_error": 0.4327145832230898,
    "r2": 0.036173730051123565
  },
  "MLM": {
    "mean_absolute_error": 0.48660588593784565,
    "r2": 0.13529233365353288
  },
  "LogD": {
    "mean_absolute_error": 0.519344262295082,
    "r2": 0.7372712316377459
  },
  "KSOL": {
    "mean_absolute_error": 0.31851278888613116,
    "r2": 0.36567336750911406
  },
  "MDR1-MDCKII": {
    "mean_absolute_error": 0.2562405718794265,
    "r2": 0.33449713264053915
  },
  "aggregated": {
    "macro_mean_absolute_error": 0.402683618444315,
    "macro_r2": 0.3217815590984111
  }
}

**gpt-4o-mini-2024-07-18, epochs=1**

{
  "HLM": {
    "mean_absolute_error": 0.4657938786266147,
    "r2": -0.02197201727747755
  },
  "MLM": {
    "mean_absolute_error": 0.5541113327048163,
    "r2": -0.215385095794018
  },
  "LogD": {
    "mean_absolute_error": 0.8390163934426229,
    "r2": 0.25983516081479563
  },
  "KSOL": {
    "mean_absolute_error": 0.4504258160572402,
    "r2": -0.0008487104920886779
  },
  "MDR1-MDCKII": {
    "mean_absolute_error": 0.3318020325575335,
    "r2": -0.07659769073501965
  },
  "aggregated": {
    "macro_mean_absolute_error": 0.5282298906777655,
    "macro_r2": -0.010993670696761649
  }
}

**gpt-4o-mini-2024-07-18, epochs=10**

{
  "HLM": {
    "mean_absolute_error": 0.3842828382200825,
    "r2": 0.01455307919105353
  },
  "MLM": {
    "mean_absolute_error": 0.4244377994930234,
    "r2": 0.09937253828722759
  },
  "LogD": {
    "mean_absolute_error": 0.530655737704918,
    "r2": 0.6583112171362293
  },
  "KSOL": {
    "mean_absolute_error": 0.3223456655356202,
    "r2": 0.3652354032636217
  },
  "MDR1-MDCKII": {
    "mean_absolute_error": 0.24955787253975956,
    "r2": 0.2384280229486857
  },
  "aggregated": {
    "macro_mean_absolute_error": 0.3822559826986808,
    "macro_r2": 0.27518005216536356
  }
}