In [None]:
import os, sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

In [None]:
import re
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, r2_score
from scipy.stats import pearsonr, spearmanr
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from collections import defaultdict
from Chemprompt.data.data_loader import DataLoader

In [None]:
device = "cuda:0"

In [None]:
# load model
model_repo = "AI4Chem/"
model_name = "ChemLLM-7B-Chat"
model_path = model_repo + model_name
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
).to(device)

In [None]:
def normalize_dataset_name(name: str) -> str:
    if not isinstance(name, str):
        return ""
    key = name.strip().lower()
    # known aliases / typos
    if key == "freesolv":
        return "FreeSolv"
    if key == "esol":
        return "ESOL"
    if key == "hppb":
        return "HPPB"
    if key == "CaCo2_Wang":
        return "Caco2_Wang"
    return name

In [None]:
# Prompt: dataset-specific property
def make_prompt(smiles: str, dataset_name: str) -> str:
    dname = normalize_dataset_name(dataset_name)

    if dname == "FreeSolv":
        task = "the solvation free energy in water"
        extra = ""
        
    elif dname == "ESOL":
        task = "the aqueous solubility (log mol/L)"
        extra = ""
        
    elif dname == "Lipo":
        task = "the octanol-water partition coefficient (logP)"
        extra = ""
        
    elif dname == "HPPB":
        task = "the human plasma protein binding as percent unbound (fu, %)"
        extra = " Return a single number between 0 and 100."

    elif dname == "Caco2_Wang":
        task = "the apparent Caco-2 cell permeability (logPapp, cm/s)"
        extra = " Return a single numeric value in log scale."
    else:
        task = "the target property"
        extra = ""

    return f"""<|user|>
Please estimate {task} of the following molecule. Just return a number.
Do not explain. Do not include units or molecule name.{extra}

{smiles}
<|assistant|>"""

In [None]:
def extract_first_number(text: str):
    match = re.search(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?", text)
    if match:
        try:
            return float(match.group())
        except:
            return None
    return None

In [None]:
gen_config = GenerationConfig(
    do_sample=True,
    top_k=1,
    temperature=0.9,
    max_new_tokens=500,
    repetition_penalty=1.5,
    pad_token_id=tokenizer.eos_token_id
)

In [None]:
dataset_names = ["FreeSolv"]

In [None]:
loader = DataLoader()

In [None]:
base_output_dir = f"./result/QandA/{model_name}"
os.makedirs(base_output_dir, exist_ok=True)

In [1]:
# Inference per dataset
for dataset_name in tqdm(dataset_names, desc="Dataset"):
    print(f"\n[+] Processing: {dataset_name}")
    x, y = loader.load_dataset(dataset_name)
    x = np.array(x)
    y = np.array(y)

    kf = KFold(n_splits=5, shuffle=True, random_state=42)
    fold_results = []
    metrics_per_fold = defaultdict(list)

    # Output setup
    output_dir = os.path.join(base_output_dir, dataset_name)
    os.makedirs(output_dir, exist_ok=True)
    answer_txt_path = os.path.join(output_dir, f"{dataset_name}_answers.txt")
    with open(answer_txt_path, "w") as f:
        f.write("")

    for fold, (_, test_idx) in enumerate(tqdm(kf.split(x, y), total=5, desc=f"Fold {dataset_name}")):
        x_test = x[test_idx]
        y_test = y[test_idx]
        predictions = []

        for smiles in tqdm(x_test, desc=f"Predicting fold {fold+1}", leave=False):
            prompt = make_prompt(smiles, dataset_name)
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            outputs = model.generate(**inputs, generation_config=gen_config)
            result = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
            reply = result.split("<|assistant|>")[-1].strip()

            with open(answer_txt_path, "a") as f:
                f.write(f"{reply}\n")

            pred_val = extract_first_number(reply)
            predictions.append(pred_val)

        df_fold = pd.DataFrame({
            "smiles": list(x_test),
            "true_value": list(y_test),
            "predicted_value": list(predictions),
            "fold": [fold + 1] * len(x_test)
        })
        fold_results.append(df_fold)

        # Metrics for this fold
        df_valid = df_fold.dropna(subset=["predicted_value"])
        y_true = df_valid["true_value"].astype(float).values
        y_pred = df_valid["predicted_value"].astype(float).values

        mse = mean_squared_error(y_true, y_pred)
        rmse = np.sqrt(mse)
        r2 = r2_score(y_true, y_pred)
        pcc, _ = pearsonr(y_true, y_pred)
        spearman, _ = spearmanr(y_true, y_pred)

        metrics_per_fold["Fold"].append(fold + 1)
        metrics_per_fold["RMSE"].append(rmse)
        metrics_per_fold["R2"].append(r2)
        metrics_per_fold["PCC"].append(pcc)
        metrics_per_fold["SPEARMAN"].append(spearman)

    # Save predictions
    df_all = pd.concat(fold_results, ignore_index=True)
    df_all.to_csv(os.path.join(output_dir, f"{dataset_name}_predictions.csv"), index=False)

    # Save metrics
    df_metrics = pd.DataFrame(metrics_per_fold)
    mean_row = df_metrics.drop(columns=["Fold"]).mean()
    std_row = df_metrics.drop(columns=["Fold"]).std()

    df_metrics = pd.concat([
        df_metrics,
        pd.DataFrame([["mean"] + list(mean_row.values)], columns=df_metrics.columns),
        pd.DataFrame([["std"] + list(std_row.values)], columns=df_metrics.columns)
    ], ignore_index=True)

    df_metrics.to_csv(os.path.join(output_dir, "combined_metrics.csv"), index=False)

    # Save short summary
    with open(os.path.join(output_dir, "metrics.txt"), "w") as f:
        f.write(f"RMSE:     {mean_row['RMSE']:.3f}\n")
        f.write(f"R2:       {mean_row['R2']:.3f}\n")
        f.write(f"PCC:      {mean_row['PCC']:.3f}\n")
        f.write(f"SPEARMAN: {mean_row['SPEARMAN']:.3f}\n")

    print(f"    [✓] Saved to {output_dir}")
print("finall")

No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!
Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/DATA1/bbq9088/anaconda3/envs/ChEmPromptv2/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'lightning'
Skipped loading some Jax models, missing a dependency. No module named 'jax'
Skipped loading some PyTorch models, missing a dependency. No module named 'tensorflow'


The repository for AI4Chem/ChemLLM-7B-Chat contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/AI4Chem/ChemLLM-7B-Chat.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N]  t
The repository for AI4Chem/ChemLLM-7B-Chat contains custom code which must be executed to correctly load the model. You can inspect the repository content at https://hf.co/AI4Chem/ChemLLM-7B-Chat.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N]  y


A new version of the following files was downloaded from https://huggingface.co/AI4Chem/ChemLLM-7B-Chat:
- tokenization_internlm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/AI4Chem/ChemLLM-7B-Chat:
- configuration_internlm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/AI4Chem/ChemLLM-7B-Chat:
- modeling_internlm2.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Dataset:   0%|          | 0/2 [00:00<?, ?it/s]Found local copy...
Loading...
Done!



[+] Processing: Caco2_Wang
('Caco2_Wang', (910, 2))



Fold Caco2_Wang:   0%|          | 0/5 [00:00<?, ?it/s][A

Predicting fold 1:   0%|          | 0/182 [00:00<?, ?it/s][A[A

Predicting fold 1:   1%|          | 1/182 [00:00<01:01,  2.94it/s][A[A

Predicting fold 1:   1%|          | 2/182 [00:00<00:53,  3.39it/s][A[A

Predicting fold 1:   2%|▏         | 3/182 [00:00<00:50,  3.58it/s][A[A

Predicting fold 1:   2%|▏         | 4/182 [00:01<00:44,  4.01it/s][A[A

Predicting fold 1:   3%|▎         | 6/182 [00:01<00:33,  5.32it/s][A[A

Predicting fold 1:   4%|▍         | 8/182 [00:01<00:25,  6.82it/s][A[A

Predicting fold 1:   5%|▌         | 10/182 [00:01<00:20,  8.45it/s][A[A

Predicting fold 1:   7%|▋         | 12/182 [00:01<00:18,  9.16it/s][A[A

Predicting fold 1:   8%|▊         | 14/182 [00:02<00:17,  9.65it/s][A[A

Predicting fold 1:   9%|▉         | 16/182 [00:02<00:16, 10.00it/s][A[A

Predicting fold 1:  10%|▉         | 18/182 [00:02<00:15, 10.60it/s][A[A

Predicting fold 1:  11%|█         | 20/182 [00:02<00:20

    [✓] Saved to ./result/QandA/ChemLLM-7B-Chat/Caco2_Wang

[+] Processing: logBB
('bbb_martins', (2030, 2))



Fold logBB:   0%|          | 0/5 [00:00<?, ?it/s][A

Predicting fold 1:   0%|          | 0/406 [00:00<?, ?it/s][A[A

Predicting fold 1:   0%|          | 1/406 [00:00<01:05,  6.15it/s][A[A

Predicting fold 1:   0%|          | 2/406 [00:00<01:06,  6.11it/s][A[A

Predicting fold 1:   1%|          | 3/406 [00:00<01:05,  6.13it/s][A[A

Predicting fold 1:   1%|          | 4/406 [00:00<01:05,  6.14it/s][A[A

Predicting fold 1:   1%|          | 5/406 [00:00<01:05,  6.15it/s][A[A

Predicting fold 1:   2%|▏         | 7/406 [00:01<00:54,  7.27it/s][A[A

Predicting fold 1:   2%|▏         | 8/406 [00:01<00:57,  6.94it/s][A[A

Predicting fold 1:   2%|▏         | 9/406 [00:01<00:59,  6.70it/s][A[A

Predicting fold 1:   2%|▏         | 10/406 [00:01<01:00,  6.53it/s][A[A

Predicting fold 1:   3%|▎         | 11/406 [00:01<01:01,  6.41it/s][A[A

Predicting fold 1:   3%|▎         | 12/406 [00:01<01:02,  6.33it/s][A[A

Predicting fold 1:   3%|▎         | 13/406 [00:02<01:32,  4.26

    [✓] Saved to ./result/QandA/ChemLLM-7B-Chat/logBB
finall



