In [1]:
import os, sys
# Put the parent of the current working directory on the import path
# (presumably so the local `Chemprompt` package can be imported -- confirm).
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

In [2]:
import re
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr, spearmanr
from tqdm import tqdm
import torch
from transformers import pipeline, AutoTokenizer
from collections import defaultdict
from Chemprompt.data.data_loader import DataLoader

No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!
Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/HDD1/bbq9088/miniconda3/envs/ChEmPrompt1027/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'lightning'
Skipped loading some Jax models, missing a dependency. No module named 'jax'
Skipped loading some PyTorch models, missing a dependency. No module named 'tensorflow'


In [3]:
# Suppress RDKit warnings: turn off every RDKit logger whose name matches
# the "rdApp.*" pattern so per-molecule messages do not flood the output.
from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog("rdApp.*")

In [4]:
# Load the GPT-OSS text-generation pipeline in bfloat16.
# device_map="auto" leaves device placement to the library, so the model is
# NOT pinned to cuda:0.
# NOTE: the installed Transformers release must recognize the `gpt_oss` model
# type; older releases raise ValueError when loading this checkpoint.
model_id = "openai/gpt-oss-20b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
pipe = pipeline(
    "text-generation",
    model=model_id,
    tokenizer=tokenizer,  # reuse the tokenizer loaded above instead of loading it twice
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

ValueError: The checkpoint you are trying to load has model type `gpt_oss` but Transformers does not recognize this architecture. This could be because of an issue with the checkpoint, or because your version of Transformers is out of date.

You can update Transformers with the command `pip install --upgrade transformers`. If this does not work, and the checkpoint is very new, then there may not be a release version that supports this model yet. In this case, you can get the most up-to-date code by installing Transformers from source with the command `pip install git+https://github.com/huggingface/transformers.git`

In [None]:
def format_prompt(smiles: str, dataset_name: str) -> str:
    """Build the zero-shot regression prompt for a single molecule.

    Args:
        smiles: SMILES string of the molecule; ``None`` marks an invalid entry.
        dataset_name: One of "FreeSolv", "ESOL", "Lipo", "HPPB", "Caco2_Wang".

    Returns:
        The prompt text, or an empty string when ``smiles`` is ``None`` or
        ``dataset_name`` is not one of the supported names.
    """
    # Task sentence per supported dataset.
    task_by_dataset = {
        "FreeSolv": "Predict the solvation free energy in water for the molecule below.",
        "ESOL": "Predict the aqueous solubility (logS) for the molecule below.",
        "Lipo": "Predict the lipophilicity (logP) for the molecule below.",
        "HPPB": "Predict the human plasma protein binding (%PPB) for the molecule below.",
        "Caco2_Wang": "Predict the Caco-2 cell permeability (logPapp) for the molecule below.",
    }

    # Invalid molecule or unknown dataset: callers treat "" as "skip this sample".
    task = None if smiles is None else task_by_dataset.get(dataset_name)
    if task is None:
        return ""

    # Output-format constraints appended to every task sentence.
    rules = "".join([
        "Return a single numeric prediction (float).\n",
        "Output EXACTLY one number on its own line.\n",
        "No units. No words. No labels. No extra text.\n",
        "Use '.' as the decimal separator. Scientific notation allowed.\n",
    ])

    return "Human: {}\n{}Molecule (SMILES): {}\nAssistant:".format(task, rules, smiles)

In [None]:
def extract_number(text: str):
    """Parse a single numeric prediction out of a model reply.

    Lines that consist solely of a number are preferred (the first such line
    wins). If there is none, the last number appearing anywhere in the text
    is used.

    Args:
        text: Raw reply text; may be empty or ``None``.

    Returns:
        The parsed value as a finite ``float``, or ``None`` when the text is
        empty, contains no number, or the chosen number is not finite.
    """
    import math  # local import keeps this cell self-contained

    if not text:
        return None

    number_pattern = r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?"

    def _finite_float(token):
        # float("1e999") overflows to inf; report such tokens as unusable so
        # they are never handed to downstream metric code as predictions.
        value = float(token)
        return value if math.isfinite(value) else None

    for line in text.splitlines():
        candidate = line.strip()
        if re.fullmatch(number_pattern, candidate):
            value = _finite_float(candidate)
            if value is not None:
                return value

    # Fallback: the last number embedded anywhere in the reply.
    matches = re.findall(number_pattern, text)
    return _finite_float(matches[-1]) if matches else None

In [None]:
# Datasets evaluated in this run, and the loader used to fetch them.
dataset_names = ["Caco2_Wang"]
loader = DataLoader()

# Results are grouped under the model's short name (the text after the last '/').
base_output_dir = "./result/QandA/" + model_id.rsplit("/", 1)[-1]
os.makedirs(base_output_dir, exist_ok=True)

In [None]:
# Zero-shot evaluation: for every dataset, query the model on each KFold test
# split, log the raw replies, and save per-sample predictions and per-fold metrics.
for dataset_name in tqdm(dataset_names, desc="Dataset"):
    print(f"\n[+] Processing dataset: {dataset_name}")

    x, y = loader.load_dataset(dataset_name)
    x, y = np.array(x), np.array(y)

    # Only the test indices of each split are used below; nothing is fitted,
    # so the train indices are discarded.
    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    fold_results = []
    metrics_per_fold = defaultdict(list)
    skipped_records = []

    output_dir = os.path.join(base_output_dir, dataset_name)
    os.makedirs(output_dir, exist_ok=True)

    answer_txt_path = os.path.join(output_dir, f"{dataset_name}_answers.txt")
    skipped_csv_path = os.path.join(output_dir, f"{dataset_name}_skipped_samples.csv")

    # Truncate the raw-answer log; replies are appended one sample at a time.
    # UTF-8 is explicit because model replies may contain characters that the
    # platform default encoding cannot represent.
    with open(answer_txt_path, "w", encoding="utf-8"):
        pass

    # =========================
    # Cross-validation
    # =========================
    for fold, (_, test_idx) in enumerate(kf.split(x), start=1):
        x_test, y_test = x[test_idx], y[test_idx]
        predictions = []

        # `i` is the 1-based position inside this fold's test split, not the
        # row index in the full dataset.
        for i, smiles in enumerate(tqdm(x_test, desc=f"Fold {fold}", leave=False), start=1):
            prompt = format_prompt(smiles, dataset_name)
            if not prompt:
                # No prompt could be built (invalid molecule / unknown dataset).
                predictions.append(None)
                skipped_records.append({"fold": fold, "index": i, "smiles": str(smiles)})
                continue

            messages = [
                {"role": "system", "content": "You are a scientific regression model. Output exactly one float."},
                {"role": "user", "content": prompt}
            ]

            outputs = pipe(
                messages,
                max_new_tokens=32767,  # upper bound on generated tokens per sample
                do_sample=False,
            )

            # Chat-style output: the last message of the returned conversation
            # is the newly generated reply.
            reply = outputs[0]["generated_text"][-1]["content"].strip()

            # Append immediately so replies survive an interrupted run.
            with open(answer_txt_path, "a", encoding="utf-8") as f:
                f.write(reply + "\n")
                f.write("=" * 50 + "\n")

            pred = extract_number(reply)
            # A parsed token such as "1e999" becomes inf, which the metric
            # functions below reject; treat it like an unparseable reply.
            if pred is not None and not np.isfinite(pred):
                pred = None
            if pred is None:
                skipped_records.append({"fold": fold, "index": i, "smiles": str(smiles)})

            predictions.append(pred)

        # =========================
        # Fold results
        # =========================
        df_fold = pd.DataFrame({
            "smiles": list(x_test),
            "true_value": list(y_test),
            "predicted_value": list(predictions),
            "fold": fold
        })
        fold_results.append(df_fold)

        # Metrics are computed only on samples that produced a prediction.
        df_valid = df_fold.dropna(subset=["predicted_value"])
        y_true = df_valid["true_value"].astype(float).values
        y_pred = df_valid["predicted_value"].astype(float).values

        # =========================
        # Metrics (NaN-safe)
        # =========================
        rmse = np.sqrt(mean_squared_error(y_true, y_pred)) if len(y_pred) > 0 else np.nan
        r2 = r2_score(y_true, y_pred) if len(y_pred) > 0 else np.nan
        # Correlations need at least two points.
        pcc = pearsonr(y_true, y_pred)[0] if len(y_pred) >= 2 else np.nan
        spearman = spearmanr(y_true, y_pred)[0] if len(y_pred) >= 2 else np.nan

        metrics_per_fold["Fold"].append(fold)
        metrics_per_fold["RMSE"].append(rmse)
        metrics_per_fold["R2"].append(r2)
        metrics_per_fold["PCC"].append(pcc)
        metrics_per_fold["SPEARMAN"].append(spearman)

    # =========================
    # Save predictions
    # =========================
    df_all = pd.concat(fold_results, ignore_index=True)
    df_all.to_csv(os.path.join(output_dir, f"{dataset_name}_predictions.csv"), index=False)

    if skipped_records:
        pd.DataFrame(skipped_records).to_csv(skipped_csv_path, index=False)
    else:
        with open(skipped_csv_path, "w", encoding="utf-8") as f:
            f.write("No skipped samples\n")

    # =========================
    # Save metrics
    # =========================
    df_metrics = pd.DataFrame(metrics_per_fold)

    metric_cols = ["RMSE", "R2", "PCC", "SPEARMAN"]
    mean_row = df_metrics[metric_cols].mean()
    std_row = df_metrics[metric_cols].std()

    # The summary rows reuse the "Fold" column for their label; this relies on
    # the column order being Fold followed by metric_cols.
    mean_df = pd.DataFrame([["mean"] + mean_row.tolist()], columns=df_metrics.columns)
    std_df = pd.DataFrame([["std"] + std_row.tolist()], columns=df_metrics.columns)

    df_metrics = pd.concat([df_metrics, mean_df, std_df], ignore_index=True)
    df_metrics.to_csv(os.path.join(output_dir, "combined_metrics.csv"), index=False)

    # =========================
    # Save short summary
    # =========================
    with open(os.path.join(output_dir, "metrics.txt"), "w", encoding="utf-8") as f:
        for k in metric_cols:
            f.write(f"{k}: {mean_row[k]:.3f}\n")

    print(f"[✓] Saved results to {output_dir}")