In [7]:
import sys
sys.path.append("..")
from train import get_dataset
from dataset.polaris_admet_dataset import load_polaris_dataset, SYSTEM_PROMPT, problem_template
from dataset import validate_dataset
import numpy as np
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizer
)
import torch
from trl import ModelConfig
from munch import Munch
from functools import partial
import hashlib
from collections import defaultdict
from tqdm import tqdm
from torch.utils.data import DataLoader
from transformers import default_data_collator
from pathlib import Path
import json

In [8]:
# Fallback Jinja chat template. get_tokenizer() installs it when no template is
# configured in the training args and the tokenizer does not ship one. Each
# message is rendered as "<|role|>\n" + content + eos_token, and a bare
# "<|assistant|>" marker is appended when add_generation_prompt is set.
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"

In [23]:
def _parse_boxed_answer(content):
    """Parse ``content`` with math_verify's LaTeX extraction and return a float.

    Returns ``None`` when nothing is extracted, when the first extracted item is
    a plain string, or when the extracted expression cannot be converted to a
    float (for example a symbolic, non-numeric expression).
    """
    answer_parsed = parse(
        content,
        extraction_config=[
            LatexExtractionConfig(
                normalization_config=NormalizationConfig(
                    nits=False,
                    malformed_operators=False,
                    basic_latex=True,
                    equations=True,
                    boxed="all",
                    units=True,
                ),
                boxed_match_priority=0,
                try_extract_without_anchor=False,
            )
        ],
        extraction_mode="first_match",
    )
    if len(answer_parsed) == 0 or isinstance(answer_parsed[0], str):
        return None
    try:
        return float(answer_parsed[0])
    except (TypeError, ValueError):
        # Extracted something, but it is not a real number.
        return None


def compute_mae(contents, smiles, solutions):
    """Absolute error of the per-molecule median prediction.

    Completions are grouped by their SMILES string. For every molecule the
    median of all parseable predictions is compared with that molecule's gold
    value.

    Args:
        contents: Model completions (strings), one per sample.
        smiles: SMILES string for each completion; used as the grouping key.
        solutions: Gold values aligned with ``contents``, or ``None``.

    Returns:
        ``[0.5] * len(contents)`` when ``solutions`` is ``None``. Otherwise a
        list with one float per distinct SMILES, in first-seen order. The entry
        is ``nan`` when the molecule has no gold value or no parseable
        prediction.
    """
    if solutions is None:
        return [0.5] * len(contents)  # Return neutral reward if no solution

    # Group on the SMILES string itself: a truncated 4-byte digest could merge
    # two different molecules into one group. Plain dicts keep insertion order.
    groups = {}
    for content, gold_val, smiles_i in zip(contents, solutions, smiles):
        # Completions without a gold value are not parsed.
        answer_val = _parse_boxed_answer(content) if gold_val is not None else None

        group = groups.setdefault(smiles_i, {"gold_val": None, "answers": []})
        # The gold value must be stored with the group; the final loop needs it.
        if group["gold_val"] is None:
            group["gold_val"] = gold_val
        if answer_val is not None:
            group["answers"].append(answer_val)

    median_maes = []
    for group in groups.values():
        if group["gold_val"] is None or not group["answers"]:
            # Nothing to score; avoid np.median([]) and float(None).
            median_maes.append(float("nan"))
            continue
        answer_median = np.median(group["answers"])
        median_maes.append(float(np.abs(float(group["gold_val"]) - answer_median)))
    return median_maes

def compute_mae_v2(content, gold_val):
    """Absolute error between the answer parsed from ``content`` and ``gold_val``.

    Args:
        content: Model completion text that should contain a boxed LaTeX answer.
        gold_val: Reference value; anything accepted by ``float()``, or ``None``.

    Returns:
        The absolute error as a NumPy float, or ``None`` when ``gold_val`` is
        ``None`` or no finite numeric answer could be extracted.
    """
    if gold_val is None:
        return None

    answer_parsed = parse(
        content,
        extraction_config=[
            LatexExtractionConfig(
                normalization_config=NormalizationConfig(
                    nits=False,
                    malformed_operators=False,
                    basic_latex=True,
                    equations=True,
                    boxed="all",
                    units=True,
                ),
                boxed_match_priority=0,
                try_extract_without_anchor=False,
            )
        ],
        extraction_mode="first_match",
    )

    if len(answer_parsed) == 0 or isinstance(answer_parsed[0], str):
        return None

    try:
        answer_val = float(answer_parsed[0])
    except (TypeError, ValueError):
        # A symbolic / non-real expression was extracted; treat it as unparsed
        # instead of aborting the caller's evaluation loop.
        return None

    if not np.isfinite(answer_val):
        # nan/inf would poison any mean taken over the returned errors.
        return None

    return np.abs(float(gold_val) - answer_val)

def get_tokenizer(
    model_args: ModelConfig, training_args, auto_set_chat_template: bool = True
) -> PreTrainedTokenizer:
    """Load the tokenizer for ``model_args`` and make sure it has a chat template.

    Args:
        model_args: Must expose ``model_name_or_path`` and ``model_revision``.
        training_args: Must expose ``chat_template`` (a template string or ``None``).
        auto_set_chat_template: When true and the tokenizer has no template of
            its own, install ``DEFAULT_CHAT_TEMPLATE``.

    Returns:
        The loaded tokenizer. ``training_args.chat_template`` takes precedence
        over any template shipped with the tokenizer.
    """
    # https://github.com/huggingface/open-r1/blob/eeca246b078457bc0f69ba2e8297b799df0e2bda/src/open_r1/utils/model_utils.py#L11
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=False,  # model_args.trust_remote_code is deliberately not forwarded
    )

    if training_args.chat_template is not None:
        tokenizer.chat_template = training_args.chat_template
    elif auto_set_chat_template and getattr(tokenizer, "chat_template", None) is None:
        # Inspect the attribute directly: on recent transformers releases
        # get_chat_template() raises when no template is set instead of
        # returning None, so it cannot be used as a "has template?" probe.
        tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
    return tokenizer

def get_model(model_name, attn_implementation="flash_attention_2"):
    """Load a causal LM in bfloat16 and place it on the GPU when one is available.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.
        attn_implementation: Attention backend forwarded to ``from_pretrained``;
            ``None`` leaves the library default in place.

    Returns:
        The loaded model, already resident on ``cuda:0`` when CUDA is available
        and on the CPU otherwise.
    """
    # Decide the target device before loading. The previous hard-coded
    # device_map="cuda:0" made loading fail on hosts without CUDA even though
    # the function went on to check for CUDA availability.
    cuda_available = torch.cuda.is_available()
    device = torch.device("cuda" if cuda_available else "cpu")
    device_map = "cuda:0" if cuda_available else "cpu"

    kwargs_dict = {}
    if attn_implementation is not None:
        kwargs_dict["attn_implementation"] = attn_implementation

    # Initialize base model
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map=device_map,  # TODO: how it affects the ddp https://huggingface.co/openai/whisper-large-v3/discussions/63
        low_cpu_mem_usage=True,  # TODO: ??
        # use_safetensors=True, #TODO: ??
        **kwargs_dict,
    )

    print(f"Model parameters: {model.num_parameters():,}")
    print(f"Using device: {device}")

    # No extra model.to(device): device_map already placed the weights there.
    return model

def test_trained_model_inference(sample: str, model, tokenizer):
    """Test inference with the loaded trained model and tokenizer."""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        # Apply chat template using our tokenizer
        text = tokenizer.apply_chat_template(
            sample,
            tokenize=False,
            add_generation_prompt=True
        )

        # Tokenize the input text
        inputs = tokenizer(text, return_tensors="pt").to(device)

        # Generate output using our *trained_model*
        outputs = model.generate(
            **inputs,
            max_new_tokens=1024, # Maybe generate a bit longer now
            do_sample=False,
            temperature=0.0 #0.7
        )

        # Decode the generated tokens back to text
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

# Chat template applied during evaluation (roles are rendered with the
# <｜User｜> / <｜Assistant｜> markers). Split across adjacent literals purely
# for readability; the concatenated text is what the tokenizer receives.
_EVAL_CHAT_TEMPLATE = (
    "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}"
    "{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}"
    "{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}"
    "{{bos_token}}{{ns.system_prompt}}"
    "{%- for message in messages %}"
    "{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<｜User｜>' + message['content']}}{%- endif %}"
    "{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<｜Assistant｜><｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{{'<｜tool▁calls▁end｜><｜end▁of▁sentence｜>'}}{%- endif %}{%- endfor %}{%- endif %}"
    "{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<｜tool▁outputs▁end｜>' + message['content'] + '<｜end▁of▁sentence｜>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<｜Assistant｜>' + content + '<｜end▁of▁sentence｜>'}}{%- endif %}{%- endif %}"
    "{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<｜tool▁outputs▁begin｜><｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- endif %}{%- endif %}"
    "{%- endfor -%}"
    "{% if ns.is_tool %}{{'<｜tool▁outputs▁end｜>'}}{% endif %}"
    "{% if add_generation_prompt and not ns.is_tool %}{{'<｜Assistant｜>'}}{% endif %}"
)


def evaluate(SYSTEM_PROMPT, problem_template, model_name, rules=""):
    """Run ``model_name`` over the LogD validation split and return its mean absolute error.

    Side effects: creates ``./benchmark/<model_name>/<prompt hashes>/`` and, if
    it does not exist yet, writes ``prompt.json`` there describing the prompts.

    Args:
        SYSTEM_PROMPT: Callable ``(property_name, rules) -> str`` building the
            system prompt.
        problem_template: Callable ``(property_name, smiles) -> str`` building
            the user question.
        model_name: Model id/path given to ``get_model`` and the tokenizer loader.
        rules: Second argument passed to ``SYSTEM_PROMPT``; also stored in
            ``prompt.json``.

    Returns:
        Mean of the per-sample absolute errors over samples whose answer could
        be parsed, or ``nan`` when no sample produced a parseable answer.
    """
    system_prompt_ = SYSTEM_PROMPT("LogD", rules)
    problem_template_ = problem_template("LogD", "<smiles>")
    system_prompt_hash = hashlib.blake2b(system_prompt_.encode('utf-8'), digest_size=4).hexdigest()
    problem_template_hash = hashlib.blake2b(problem_template_.encode('utf-8'), digest_size=4).hexdigest()

    # One output directory per (model, system prompt, question template) combination.
    dir_name = f"./benchmark/{model_name}/{system_prompt_hash}_{problem_template_hash}"
    Path(dir_name).mkdir(exist_ok=True, parents=True)

    prompt_pth = f"{dir_name}/prompt.json"

    if not Path(prompt_pth).exists():
        with open(prompt_pth, "w") as f:
            json.dump({
                "system_prompt": system_prompt_,
                "rules": rules,
                "problem_template": problem_template_
            }, f, indent=4)

    dataset = get_dataset(params=["LogD"], subset_train=50, system_prompt_fn=SYSTEM_PROMPT, prompt_template_fn=problem_template)["validation"]

    # get_model() already places the model on its device; no second .to() here.
    model = get_model(model_name, attn_implementation="flash_attention_2")
    model.eval()

    model_args_i = Munch.fromDict({
            "model_name_or_path": model_name,
            "model_revision": "main",
            "trust_remote_code": False # TODO: everyboudy sets to True and default is True
            })
    training_args_i = Munch.fromDict({"chat_template": _EVAL_CHAT_TEMPLATE})

    tokenizer = get_tokenizer(model_args_i, training_args_i)

    maes = []
    for batch in tqdm(dataset, total=len(dataset)):
        # Hand over the chat messages, not the whole row: the chat template
        # iterates over a list of {"role", "content"} messages.
        response = test_trained_model_inference(batch["prompt"], model, tokenizer)
        print(batch, response)
        mae = compute_mae_v2(response, batch["solution"])
        maes.append(mae)

    scored = [l for l in maes if l is not None]
    if not scored:
        # np.mean([]) would emit a RuntimeWarning; return nan explicitly.
        return float("nan")
    return np.mean(scored)


def get_dataset(params=["MLM", "HLM", "KSOL", "LogD", "MDR1-MDCKII"], subset_train=None, subset_valid=None, subset_test=None, rules_prompt_name="rules_v4", rewrite=False, system_prompt_fn=SYSTEM_PROMPT, prompt_template_fn=problem_template):
    """Load the Polaris dataset, optionally truncate its splits, and validate it.

    Args:
        params: Property names forwarded to ``load_polaris_dataset``.
        subset_train / subset_valid / subset_test: When not ``None``, keep only
            the first N rows of the "train" / "validation" / "test" split. A
            value larger than the split keeps the whole split.
        rules_prompt_name: Forwarded to ``load_polaris_dataset``.
        rewrite: Forwarded to ``load_polaris_dataset``.
        system_prompt_fn, prompt_template_fn: NOTE(review): accepted but not
            used by this function and not passed to ``load_polaris_dataset``,
            so custom prompt builders supplied here have no effect on the
            returned prompts. Confirm whether the loader should receive them.

    Returns:
        The dataset mapping returned by ``load_polaris_dataset`` after optional
        truncation and a ``validate_dataset`` pass.
    """
    # Pass a copy so the loader can never mutate the shared default list.
    dataset = load_polaris_dataset(params=list(params), rules_prompt_name=rules_prompt_name, rewrite=rewrite)

    print(f"Train set size: {len(dataset['train'])}")
    print(f"Test set size: {len(dataset['test'])}")

    split_limits = (
        ("train", subset_train),
        ("validation", subset_valid),
        ("test", subset_test),
    )
    for split_name, limit in split_limits:
        if limit is None:
            continue
        # Clamp so a request for more rows than exist keeps the full split
        # instead of failing on out-of-range indices.
        keep = min(limit, len(dataset[split_name]))
        dataset[split_name] = dataset[split_name].select(range(keep))

    validate_dataset(dataset)
    return dataset

In [25]:
# Short description of each supported ADME endpoint, spliced into the system prompt.
dct = {
    "MLM": "is Mouse Liver Microsomal stability measured in uL/min/mg.",
    "HLM": "is Human Liver Microsomal stability measured in uL/min/mg.",
    "KSOL": "is Solubility measured in uM.",
    "LogD": "is Lipophilicity, like solubility but then in fatty tissue. LogD is a measure of a molecule's lipophilicity.",
    "MDR1-MDCKII": "is Cell permeation measured in 10^-6 cm/s.",
}


def SYSTEM_PROMPT_1(x, rules_prompt_name):
    """Build the system prompt for property ``x`` (a key of ``dct``).

    ``rules_prompt_name`` is accepted for signature compatibility and is not
    used in the generated text.
    """
    return f"""You are an experienced Chemist that provides well-reasoned and detailed responses and excells at extimating ADME properties of molecules, especially {x}. {x} {dct[x]}
User asks you to estimate and predict {x} for a small molecule, you first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>. Inside <answer>\n...\n</answer> put the final {x} prediction in the following format: \\boxed{{RESULT}}, where RESULT is just the final number in float or expression that solves the problem.
"""


def problem_template(v_name, k):
    """Build the user question asking for the value of ``v_name`` for molecule ``k``."""
    return f"What is the numerical value of {v_name} of the '{k}'?"

In [26]:
# Benchmark the distilled 7B model on the LogD validation split using the
# prompt builders defined above. evaluate() returns the mean absolute error
# over the samples whose boxed answer could be parsed.
m = evaluate(SYSTEM_PROMPT_1, problem_template, "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", rules="")
print(m)

Map: 100%|██████████| 221/221 [00:00<00:00, 5797.95 examples/s]
Map: 100%|██████████| 49/49 [00:00<00:00, 4930.45 examples/s]
Map: 100%|██████████| 52/52 [00:00<00:00, 4942.86 examples/s]




Train set size: 221
Test set size: 49

Validating train split:
✓ All required fields present
✓ Prompt format is correct

Validating test split:
✓ All required fields present
✓ Prompt format is correct


Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.51s/it]


Model parameters: 7,615,616,512
Using device: cuda


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
  2%|▏         | 1/52 [00:39<33:21, 39.24s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


{'solution': 2.9, 'problem': "What is the numerical value of LogD of the 'CC(C)(C)[C@H](NC1=NC=NC2=C1C=C(C1=CN(CC(N)=O)N=C1)N2)C1=CC=C2OCCOC2=N1 |o1:4|'?", 'property': 'LogD', 'smiles': 'CC(C)(C)[C@H](NC1=NC=NC2=C1C=C(C1=CN(CC(N)=O)N=C1)N2)C1=CC=C2OCCOC2=N1 |o1:4|', 'ground_truth': 2.9, 'prompt': [{'content': "You are an experienced Chemist that provides well-reasoned and detailed responses and excells at extimating ADME properties of molecules, especially LogD. LogD is Lipophilicity, like solubility but then in fatty tissue. LogD is a measure of a molecule's lipophilicity.\nUser asks you to estimate and predict LogD for a small molecule, you first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>. Inside <answer>\n...\n</answer> put the final LogD prediction in the following format: \\boxed{RESULT}, where RESULT is just the final number in float or expr

  4%|▍         | 2/52 [01:18<32:43, 39.26s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


{'solution': 4.0, 'problem': "What is the numerical value of LogD of the 'CC(C)[C@H](NC1=NC=NC2=C1C=C(C1=CN3C=NC=C3C=C1)N2)C1=CC=C2CCCS(=O)(=O)C2=C1 |&1:3|'?", 'property': 'LogD', 'smiles': 'CC(C)[C@H](NC1=NC=NC2=C1C=C(C1=CN3C=NC=C3C=C1)N2)C1=CC=C2CCCS(=O)(=O)C2=C1 |&1:3|', 'ground_truth': 4.0, 'prompt': [{'content': "You are an experienced Chemist that provides well-reasoned and detailed responses and excells at extimating ADME properties of molecules, especially LogD. LogD is Lipophilicity, like solubility but then in fatty tissue. LogD is a measure of a molecule's lipophilicity.\nUser asks you to estimate and predict LogD for a small molecule, you first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>. Inside <answer>\n...\n</answer> put the final LogD prediction in the following format: \\boxed{RESULT}, where RESULT is just the final number in float

  4%|▍         | 2/52 [01:18<32:49, 39.39s/it]


KeyboardInterrupt: 

In [17]:
# Load the LogD splits for inspection. With the get_dataset defined in this
# notebook, subset_train=50 keeps only the first 50 training rows.
# NOTE(review): that get_dataset does not use system_prompt_fn /
# prompt_template_fn, so these two arguments appear to have no effect — confirm.
dataset = get_dataset(params=["LogD"], subset_train=50, system_prompt_fn=SYSTEM_PROMPT, prompt_template_fn=problem_template)

Map: 100%|██████████| 221/221 [00:00<00:00, 5918.56 examples/s]
Map: 100%|██████████| 49/49 [00:00<00:00, 4842.96 examples/s]
Map: 100%|██████████| 52/52 [00:00<00:00, 4929.57 examples/s]

Train set size: 221
Test set size: 49

Validating train split:
✓ All required fields present
✓ Prompt format is correct

Validating test split:
✓ All required fields present
✓ Prompt format is correct



