In [1]:
import os, sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

In [2]:
import math
import re
from collections import defaultdict
from typing import Optional

import numpy as np
import pandas as pd
import torch
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

from Chemprompt.data.data_loader import DataLoader

No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!
Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/HDD1/bbq9088/miniconda3/envs/ChEmPrompt1027/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'lightning'
Skipped loading some Jax models, missing a dependency. No module named 'jax'
Skipped loading some PyTorch models, missing a dependency. No module named 'tensorflow'


In [3]:
# Silence RDKit log output on every rdApp.* channel (parse/sanitize chatter).
from rdkit import Chem, RDLogger

RDLogger.DisableLog("rdApp.*")

In [4]:
# Use the first GPU when CUDA is available; otherwise fall back to CPU so the
# notebook still runs (slowly) on machines without a GPU instead of crashing.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [5]:
# Load ChemDFM-v1.5-8B: tokenizer first, then the bfloat16 weights, moved onto `device`.
# trust_remote_code=True permits running code from the model repository, if it ships any.
model_name = "OpenDFM/ChemDFM-v1.5-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, trust_remote_code=True
)
model = model.to(device)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [6]:
# SMILES canonicalization
def preprocess_smiles(smiles: str) -> Optional[str]:
    """Return the RDKit canonical SMILES for ``smiles``, or None if it is unusable.

    Args:
        smiles: Input SMILES string.

    Returns:
        The canonical SMILES, or None when the input is not a non-empty string
        or RDKit cannot parse it.
    """
    # RDKit raises (instead of returning None) on non-string input such as a NaN
    # from a missing cell, and parses "" into an empty molecule whose SMILES is ""
    # -- which would otherwise yield a prompt containing no molecule at all.
    if not isinstance(smiles, str) or not smiles.strip():
        return None
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, canonical=True)

In [7]:
def format_prompt(smiles: str, dataset_name: str, idx) -> str:
    smiles = preprocess_smiles(smiles)
    if smiles is None:
        return ""  # Invalid SMILES

    if dataset_name == "FreeSolv":
        query = (
            "Estimate the solvation free energy in water for the following molecule. "
            "Just return a number. No explanation. No units.\n"
        )
        
    elif dataset_name == "ESOL":
        query = (
            "Estimate the aqueous solubility (logS) of the following molecule. "
            "Just return a number. No explanation. No units.\n"
        )
    elif dataset_name == "Lipo":
        query = (
            "Estimate the lipophilicity (logP) of the following molecule. "
            "Just return a number. No explanation. No units.\n"
        )
    elif dataset_name == "HPPB":
        query = (
            "Estimate the human plasma protein binding as percent unbound (fu, %) "
            "for the following molecule. "
            "Just return a number between 0 and 100. No explanation. No units.\n"
        )
        
    elif dataset_name == "Caco2_Wang":
        query = (
            "Estimate the intestinal epithelial permeability (Caco-2) "
            "for the following molecule based on the Caco-2 definition. "
            "Just return a number. No explanation. No units.\n"
        )
    
    else:
        return ""
    input_text = f"[Round {idx}]\nHuman: {query}\n{smiles}\nAssistant:"
    # print(input_text)
    return input_text

In [8]:
# Extract number from output
# Written-out scientific notation such as "2.5 x 10^-3" or "1.2 × 10^4". An
# explicit multiplication sign (*, x, ×, ·) or a caret is required: when both
# were optional, any digit run containing "10" (e.g. "-2.5105" or "1100") was
# misread as a power of ten.
_SCI_NOTATION_RE = re.compile(
    r"([-+]?\d*\.?\d+)\s*(?:[*x×·]\s*10\s*\^?|10\s*\^)\s*([-+]?\d+)",
    flags=re.IGNORECASE,
)
# Plain decimals, optionally with an e-notation exponent ("1.5e-3").
_PLAIN_NUMBER_RE = re.compile(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?")


def extract_number(text: str) -> Optional[float]:
    """Parse the numeric answer out of a model reply.

    The last written-out scientific-notation value wins; otherwise the last
    plain/e-notation number in the text is used.

    Args:
        text: Raw reply text.

    Returns:
        The parsed value, or None when the text is empty, contains no number,
        or the number is not finite (inf would break the downstream metrics).
    """
    if not text:
        return None
    # LLMs sometimes emit U+2212 MINUS SIGN; without this the sign is silently dropped.
    text = text.strip().replace("\u2212", "-")

    sci_matches = _SCI_NOTATION_RE.findall(text)
    if sci_matches:
        base, exp = sci_matches[-1]
        # Let float() apply the exponent: exact decimal parsing, no 2.5 * 10**-3
        # rounding drift, and overflow gives inf instead of raising.
        value = float(f"{base}e{exp}")
        return value if math.isfinite(value) else None

    plain_matches = _PLAIN_NUMBER_RE.findall(text)
    if plain_matches:
        # The pattern only matches valid float literals, so float() cannot raise here.
        value = float(plain_matches[-1])
        return value if math.isfinite(value) else None

    return None

In [9]:
# Generation configuration: greedy decoding for deterministic answers.
# Sampling-only knobs (temperature/top_k/top_p) are omitted: with do_sample=False
# they are ignored and transformers warns about them, and temperature=0.0 would
# be rejected outright if sampling were ever enabled. For sampled decoding, set
# do_sample=True with a strictly positive temperature.
# max_new_tokens is set explicitly rather than relying on the library's default
# length limit; the prompt asks for a bare number, so a short budget suffices.
gen_config = GenerationConfig(
    do_sample=False,
    max_new_tokens=32,
    repetition_penalty=1.0,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)

In [10]:
# Benchmarks to evaluate; each name needs a prompt template in format_prompt,
# otherwise every molecule is skipped.
dataset_names = [
    "FreeSolv",
    "ESOL",
    "Lipo",
]

In [11]:
# Project-local loader (Chemprompt.data.data_loader). The evaluation loop calls
# loader.load_dataset(name) and unpacks the result as (x, y) -- presumably SMILES
# strings and target values; TODO confirm against DataLoader.
loader = DataLoader()

In [12]:
# Results go under a folder named after the checkpoint (the part of model_name
# after the hub namespace), so switching models cannot overwrite another model's runs.
base_output_dir = os.path.join("./result/QandA", model_name.split("/")[-1])
os.makedirs(base_output_dir, exist_ok=True)

In [13]:
# Iterate over datasets
def _query_model(prompt: str) -> str:
    """Run greedy generation for one prompt and return only the model's reply text."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, generation_config=gen_config)
    # Decode only the newly generated tokens; splitting the full decode on
    # "Assistant:" misbehaves if the model itself emits that marker.
    new_tokens = outputs[0, inputs["input_ids"].shape[1]:]
    reply = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # If the model runs on into an invented next turn, drop it so digits there
    # (e.g. SMILES ring-closure numbers) are not parsed as the answer.
    reply = re.split(r"\n\s*(?:\[Round \d+\]|Human:)", reply, maxsplit=1)[0]
    return reply.strip()


def _predict_fold(x_test, dataset_name: str, fold: int, answers_file):
    """Predict every SMILES in one fold.

    Returns (predictions, replies), both aligned with ``x_test``; molecules whose
    prompt is empty (invalid SMILES / unsupported dataset) get None in both lists.
    Each reply is also appended as one line to ``answers_file``.
    """
    predictions, replies = [], []
    for smiles in tqdm(x_test, desc=f"Predicting fold {fold + 1}", leave=False):
        # Every query is an independent single-turn chat, so the round index is
        # always 0 (a running counter produced e.g. "[Round 57]" with no history).
        prompt = format_prompt(smiles, dataset_name, 0)
        if not prompt.strip():
            predictions.append(None)
            replies.append(None)
            continue
        reply = _query_model(prompt)
        answers_file.write(f"{reply}\n")
        predictions.append(extract_number(reply))
        replies.append(reply)
    return predictions, replies


def _fold_metrics(df_fold: pd.DataFrame) -> dict:
    """Regression metrics on rows with a parsed prediction, plus the parsed fraction."""
    df_valid = df_fold.dropna(subset=["predicted_value"])
    # The metrics ignore unparsed replies, so report how many were usable.
    coverage = len(df_valid) / len(df_fold) if len(df_fold) else 0.0
    if len(df_valid) < 2:
        # Too few parsed answers for R2 / correlations; record NaN instead of aborting the run.
        return {"RMSE": np.nan, "R2": np.nan, "PCC": np.nan, "SPEARMAN": np.nan, "COVERAGE": coverage}
    y_true = df_valid["true_value"].astype(float).values
    y_pred = df_valid["predicted_value"].astype(float).values
    pcc, _ = pearsonr(y_true, y_pred)
    spearman, _ = spearmanr(y_true, y_pred)
    return {
        "RMSE": np.sqrt(mean_squared_error(y_true, y_pred)),
        "R2": r2_score(y_true, y_pred),
        "PCC": pcc,
        "SPEARMAN": spearman,
        "COVERAGE": coverage,
    }


for dataset_name in tqdm(dataset_names, desc="Dataset"):
    print(f"\n[+] Processing dataset: {dataset_name}")
    x, y = loader.load_dataset(dataset_name)
    x = np.array(x)
    y = np.array(y)

    # No training happens here: the folds only partition the data so metrics can
    # be reported per fold (mean/std); the training indices are unused.
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    fold_results = []
    metrics_per_fold = defaultdict(list)

    output_dir = os.path.join(base_output_dir, dataset_name)
    os.makedirs(output_dir, exist_ok=True)
    answer_txt_path = os.path.join(output_dir, f"{dataset_name}_answers.txt")

    # One handle per dataset ("w" truncates previous runs); line buffering flushes
    # each reply as it is written, so partial output survives an interrupted run.
    with open(answer_txt_path, "w", buffering=1) as answers_file:
        fold_iter = tqdm(kf.split(x, y), total=kf.get_n_splits(), desc=f"Fold {dataset_name}")
        for fold, (_, test_idx) in enumerate(fold_iter):
            x_test = x[test_idx]
            y_test = y[test_idx]
            predictions, replies = _predict_fold(x_test, dataset_name, fold, answers_file)

            df_fold = pd.DataFrame({
                "smiles": list(x_test),
                "true_value": list(y_test),
                "predicted_value": list(predictions),
                "fold": [fold + 1] * len(x_test),
                # Raw text kept per molecule so parsing failures can be audited.
                "raw_reply": replies,
            })
            fold_results.append(df_fold)

            # Fold-wise metrics
            metrics_per_fold["Fold"].append(fold + 1)
            for metric_name, value in _fold_metrics(df_fold).items():
                metrics_per_fold[metric_name].append(value)

    # Save all results
    df_all = pd.concat(fold_results, ignore_index=True)
    df_all.to_csv(os.path.join(output_dir, f"{dataset_name}_predictions.csv"), index=False)

    # Save combined metrics (pandas mean/std skip NaN folds)
    df_metrics = pd.DataFrame(metrics_per_fold)
    mean_row = df_metrics.drop(columns=["Fold"]).mean()
    std_row = df_metrics.drop(columns=["Fold"]).std()

    df_metrics = pd.concat([
        df_metrics,
        pd.DataFrame([["mean"] + list(mean_row.values)], columns=df_metrics.columns),
        pd.DataFrame([["std"] + list(std_row.values)], columns=df_metrics.columns)
    ], ignore_index=True)

    df_metrics.to_csv(os.path.join(output_dir, "combined_metrics.csv"), index=False)

    # Save short summary
    with open(os.path.join(output_dir, "metrics.txt"), "w") as f:
        f.write(f"RMSE:     {mean_row['RMSE']:.3f}\n")
        f.write(f"R2:       {mean_row['R2']:.3f}\n")
        f.write(f"PCC:      {mean_row['PCC']:.3f}\n")
        f.write(f"SPEARMAN: {mean_row['SPEARMAN']:.3f}\n")
        f.write(f"COVERAGE: {mean_row['COVERAGE']:.3f}\n")

    print(f"    [✓] Saved results to: {output_dir}")
print("fin all")

Dataset:   0%|                                                                                          | 0/3 [00:00<?, ?it/s]


[+] Processing dataset: FreeSolv
(642, 2)



Fold FreeSolv:   0%|                                                                                    | 0/5 [00:00<?, ?it/s][A

Predicting fold 1:   0%|                                                                              | 0/129 [00:00<?, ?it/s][A[A

Predicting fold 1:   1%|▌                                                                     | 1/129 [00:00<01:09,  1.84it/s][A[A

Predicting fold 1:   2%|█                                                                     | 2/129 [00:00<00:47,  2.69it/s][A[A

Predicting fold 1:   2%|█▋                                                                    | 3/129 [00:01<00:39,  3.17it/s][A[A

Predicting fold 1:   3%|██▏                                                                   | 4/129 [00:01<00:36,  3.45it/s][A[A

Predicting fold 1:   4%|██▋                                                                   | 5/129 [00:01<00:34,  3.63it/s][A[A

Predicting fold 1:   5%|███▎                                    

    [✓] Saved results to: ./result/QandA/ChemDFM-v1.5-8B/FreeSolv

[+] Processing dataset: ESOL
(1128, 2)



Fold ESOL:   0%|                                                                                        | 0/5 [00:00<?, ?it/s][A

Predicting fold 1:   0%|                                                                              | 0/226 [00:00<?, ?it/s][A[A

Predicting fold 1:   0%|▎                                                                     | 1/226 [00:00<03:39,  1.02it/s][A[A

Predicting fold 1:   1%|▌                                                                     | 2/226 [00:01<02:03,  1.82it/s][A[A

Predicting fold 1:   1%|▉                                                                     | 3/226 [00:01<01:32,  2.42it/s][A[A

Predicting fold 1:   2%|█▏                                                                    | 4/226 [00:02<02:21,  1.57it/s][A[A

Predicting fold 1:   2%|█▌                                                                    | 5/226 [00:02<01:49,  2.01it/s][A[A

Predicting fold 1:   3%|█▊                                      

    [✓] Saved results to: ./result/QandA/ChemDFM-v1.5-8B/ESOL

[+] Processing dataset: Lipo
(1400, 2)



Fold Lipo:   0%|                                                                                        | 0/5 [00:00<?, ?it/s][A

Predicting fold 1:   0%|                                                                              | 0/280 [00:00<?, ?it/s][A[A

Predicting fold 1:   0%|▎                                                                     | 1/280 [00:00<01:09,  4.00it/s][A[A

Predicting fold 1:   1%|▌                                                                     | 2/280 [00:00<01:09,  3.98it/s][A[A

Predicting fold 1:   1%|▊                                                                     | 3/280 [00:01<02:41,  1.72it/s][A[A

Predicting fold 1:   1%|█                                                                     | 4/280 [00:01<02:04,  2.21it/s][A[A

Predicting fold 1:   2%|█▎                                                                    | 5/280 [00:02<02:56,  1.56it/s][A[A

Predicting fold 1:   2%|█▌                                      

    [✓] Saved results to: ./result/QandA/ChemDFM-v1.5-8B/Lipo
fin all



