In [1]:
import sys
sys.path.append("..")
from dataset import load_polaris_dataset, validate_dataset
from train import get_dataset
import numpy as np
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    PreTrainedTokenizer
)
import torch
from peft import PeftModel
from peft import prepare_model_for_kbit_training
from trl import ModelConfig
from munch import Munch
import json
from pathlib import Path
from functools import partial
import hashlib
from collections import defaultdict
from tqdm import tqdm
from train import GRPOTrainer2
import os
from trl import (
    GRPOConfig, 
    GRPOTrainer,
    get_peft_config
)
from dataclasses import field, dataclass

  from .autonotebook import tqdm as notebook_tqdm


INFO 03-02 00:05:16 __init__.py:190] Automatically detected platform cuda.


2025-03-02 00:05:17,175	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"

In [42]:
def compute_mae(contents, smiles, solutions):
    if solutions is None:
        return [0.5] * len(contents) # Return neutral reward if no solution
    smiles2conts = defaultdict(list)
    for content, gold_val, smiles_i in zip(contents, solutions, smiles):
        answer_val = None
        if gold_val is not None:
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed="all",
                            units=True,
                        ),
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )

            if len(answer_parsed) > 0 and not isinstance(answer_parsed[0], str):
                answer_val = float(answer_parsed[0])
        
        smiles_hash = hashlib.blake2b(smiles_i.encode('utf-8'), digest_size=4).hexdigest()
        smiles2conts[smiles_hash].append({
                       "answer_val": answer_val
                       }) 
    median_maes = []
    for k, v in smiles2conts.items():
        answers_g = [v_i["answer_val"] for v_i in v]
        answers_g = [float(v_i) for v_i in answers_g if v_i is not None]
        answer_median = np.median(answers_g)
        mae_median = np.median(np.abs(float(v[0]["gold_val"]) - answer_median))
        median_maes.append(mae_median)
    return median_maes

def get_tokenizer(
    model_args: ModelConfig, training_args, auto_set_chat_template: bool = True
) -> PreTrainedTokenizer:
    """Get the tokenizer for the model."""
    # https://github.com/huggingface/open-r1/blob/eeca246b078457bc0f69ba2e8297b799df0e2bda/src/open_r1/utils/model_utils.py#L11
    print("loading tokenizer")
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=False, # model_args.trust_remote_code
    )
    print("tokenizer loaded")

    if training_args.chat_template is not None:
        tokenizer.chat_template = training_args.chat_template
    elif auto_set_chat_template and tokenizer.get_chat_template() is None:
        tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
    print("chat template")
    # if processing_class is None:
    #     processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
    return tokenizer

def get_model(model_name, attn_implementation="flash_attention_2"):
    # Initialize base model
    if attn_implementation is not None:
        kwargs_dict = {"attn_implementation": attn_implementation}
    else:
        kwargs_dict = {}
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="cuda:0", #TODO: how it affects the ddp https://huggingface.co/openai/whisper-large-v3/discussions/63
        low_cpu_mem_usage=True, #TODO: ??
        # use_safetensors=True, #TODO: ??
        **kwargs_dict
    )

    print(f"Model parameters: {model.num_parameters():,}")

    # Check CUDA availability
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = "cpu"
    print(f"Using device: {device}")

    # Move model to the appropriate device
    model.to(device)

    return model

def get_reward_functions(script_args, model_name):
    """
    Returns a list of reward functions based on the script arguments.
    """
    reward_funcs_list = []

    fnc = partial(compute_mae, model_name=model_name)
    fnc.__name__ = compute_mae.__name__
    reward_funcs_registry = {
        "mae": fnc,  # Assuming accuracy_reward is defined in previous steps
    }

    for func_name in script_args.reward_funcs:
        if func_name not in reward_funcs_registry:
            raise ValueError(f"Reward function '{func_name}' not found in registry.")
        reward_funcs_list.append(reward_funcs_registry[func_name])

    return reward_funcs_list

@dataclass
class GRPOScriptArguments:
    """
    Script arguments for GRPO training, specifically related to reward functions.
    """

    reward_funcs: list[str] = field(
        default_factory=lambda: ["mae"], 
        metadata={
            "help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'repetition_penalty'"        },
    )

    repetition_n_grams: int = field(
        default=3,
        metadata={"help": "Number of n-grams for repetition penalty reward"},
    )
    repetition_max_penalty: float = field(
        default=-0.1,
        metadata={"help": "Maximum (negative) penalty for for repetition penalty reward"},
    )

def test_trained_model_inference(sample: str):
    """Test inference with the loaded trained model and tokenizer."""

    # Apply chat template using our tokenizer
    text = tokenizer.apply_chat_template(
        sample,
        tokenize=False,
        add_generation_prompt=True
    )

    # Tokenize the input text
    inputs = tokenizer(text, return_tensors="pt").to(device)

    # Generate output using our *trained_model*
    outputs = model.generate(
        **inputs,
        max_new_tokens=1024, # Maybe generate a bit longer now
        do_sample=False,
        temperature=0.0 #0.7
    )

    # Decode the generated tokens back to text
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

def evaluate(dataset, model_name):
    responses = []
    solution = []
    smiles = []
    for batch in tqdm(dataset, total=len(dataset)):
        response = test_trained_model_inference(batch)
        responses.append(response)
        solution.append(batch["solution"])
        smiles.append(batch["smiles"])

    print("compute mae")
    mae = compute_mae_v2(responses, smiles, solution, model_name=model_name)
    
    return mae

In [43]:
dataset = get_dataset(params=["LogD"], subset_train=50)

Map: 100%|██████████| 221/221 [00:00<00:00, 12781.69 examples/s]
Map: 100%|██████████| 49/49 [00:00<00:00, 7507.89 examples/s]
Map: 100%|██████████| 52/52 [00:00<00:00, 7723.50 examples/s]

Train set size: 221
Test set size: 49

Validating train split:
✓ All required fields present
✓ Prompt format is correct

Validating test split:
✓ All required fields present
✓ Prompt format is correct





In [44]:
MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"

model = get_model(MODEL_NAME, attn_implementation="flash_attention_2")

model = PeftModel.from_pretrained(model, "/home/alisavin/AgenticADMET/outputs/2025-02-26/22-18-57/checkpoint-60/")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

model_args_i = Munch.fromDict({
        "model_name_or_path": MODEL_NAME,
        "model_revision": "main",
        "trust_remote_code": False # TODO: everyboudy sets to True and default is True
        })
training_args_i = Munch.fromDict({"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<｜User｜>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<｜Assistant｜><｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{{'<｜tool▁calls▁end｜><｜end▁of▁sentence｜>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<｜tool▁outputs▁end｜>' + message['content'] + '<｜end▁of▁sentence｜>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<｜Assistant｜>' + content + '<｜end▁of▁sentence｜>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<｜tool▁outputs▁begin｜><｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<｜tool▁outputs▁end｜>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<｜Assistant｜>'}}{% endif %}"})

tokenizer = get_tokenizer(model_args_i, training_args_i)

Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.51s/it]


Model parameters: 7,615,616,512
Using device: cuda
loading tokenizer
tokenizer loaded
chat template


In [48]:
m = evaluate(dataset["validation"], "test/completions/init_v3_correct_format_16_v3")

  0%|          | 0/52 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
  2%|▏         | 1/52 [00:07<06:05,  7.17s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
  4%|▍         | 2/52 [00:14<05:57,  7.15s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
  6%|▌         | 3/52 [00:21<05:49,  7.14s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
  8%|▊         | 4/52 [00:28<05:42,  7.14s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
 10%|▉         | 5/52 [00:35<05:35,  7.15s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
 12%|█▏        | 6/52 [00:42<05:30,  7.19s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
 13%|█▎        | 7/52 [00:50<05:23,  7.20s/it]Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.
 15%|█▌        | 8/52 [00:57<05:17,  7.21s/it]Setting `p

In [19]:
dataset = get_dataset(params=["LogD"], rewrite=True)

Map: 100%|██████████| 221/221 [00:00<00:00, 13278.44 examples/s]
Map: 100%|██████████| 49/49 [00:00<00:00, 7124.27 examples/s]
Map: 100%|██████████| 52/52 [00:00<00:00, 8248.39 examples/s]

Train set size: 221
Test set size: 49

Validating train split:
✓ All required fields present
✓ Prompt format is correct

Validating test split:
✓ All required fields present
✓ Prompt format is correct





In [20]:
dict_i = {}
for b in dataset["train"]:
    dict_i[b["smiles"]] = b["solution"]

In [22]:
dict_i

{'CC(C)(C)[C@H](NC1=NC=NC2=C1C=C(C1=CN(CC(N)=O)N=C1)N2)C1=CC=C2OCCOC2=C1 |o1:4|': 3.1,
 'C[C@H]1CN(C2=CN=CC3=CC=CC=C23)C(=O)[C@@]12CN(CC1=NNC=N1)C(=O)C1=CC=C(F)C=C12': 1.3,
 'CC(C)(C)[C@H](NC1=NC=NC2=C1C=C(C1=CN(CC(N)=O)N=C1)N2)C1=CC=C2OCCOC2=N1 |&1:4|': 2.7,
 'CC(C)[C@H](CO)NC1=NC=NC2=C1C=CN2': 1.69,
 'NCC1=CC=CC(NC(=O)[C@@H](NC(=O)OCC2=CC=CC=C2)C2=CC=C(OCC3=CC=CC=C3)C=C2)=C1 |a:10|': 3.2,
 'CC(C)[C@H](NC1=NC=NC2=C1C=CN2)C1=CC=C2CCCS(=O)(=O)C2=C1 |&1:3|': 2.5,
 'C[C@H]1CN(C2=CN=CC3=CC=CC=C23)C(=O)[C@@]12CN(CC1=NC=CN1C)C(=O)C1=CC=C(F)C=C12 |a:1,16|': 2.0,
 'C#CCCC1=CC=C(OCCCC2=CC(C(=O)N(C)C)=NO2)C=C1': 3.5,
 'CNC(=O)C1=CC2=C(N[C@H](C3=CC=C4CCCS(=O)(=O)C4=C3)C(C)C)N=CN=C2N1 |&1:9|': 2.4,
 'CC1=NC2=NC=NN2C(SC2=NN=C(C)O2)=C1': 0.1,
 'O=C(NCC(F)F)[C@H](NC1=CC=C2CNCC2=C1)C1=CC(Br)=CC2=C1NC=N2 |&1:7|': 0.4,
 'CNC(=O)CN1C[C@@]2(C(=O)N(C3=CN=CC4=CC=CC=C34)C[C@@H]2CNC2=CC=C(Cl)N=N2)C2=CC(Cl)=CC=C2C1=O |a:7,22|': 2.1,
 'CNC(=O)C1=CC(Cl)=CC=C1NS(=O)(=O)C1=CC=C(OC2=CC=CC=C2Cl)C=C1': 2.4,
 'CN(C(=O

In [21]:
response_i = test_trained_model_inference(dataset["validation"][0]['prompt'])

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


In [22]:
response_i

"You are an experienced Chemist that provides well-reasoned and detailed responses and excells at predicting ADME properties of molecules. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>. Inside <answer>\n...\n</answer>, when you finished thinking and certian that it is the most accurate answer you can give, put the final answer in the following format: \\boxed{RESULT}, where RESULT is just the final number in float or expression that solves the problem.<｜User｜>The numerical value of LogD (Lipophilicity, like solubility - but then in fatty tissue - LogD is a measure of a molecule's lipophilicity) of the small molecule given it's SMILES 'C[C@H]1CN(C2=CN=CC3=CC=CC=C23)C(=O)[C@@]12CN(CCN1CCOCC1)C(=O)C1=CC=C(Cl)C=C12' is<｜Assistant｜><think>\nAlright, so I need to determine the LogD value of the given molecule based on its SMILES notation. LogD is

In [23]:
dataset["validation"][0]["solution"]

1.8

In [10]:
train_result = grpo_trainer.train(resume_from_checkpoint="/home/alisavin/AgenticADMET/outputs/2025-02-26/22-18-57/checkpoint-60/")

  torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
	logging_steps: 1 (from args) != 10 (from trainer_state.json)


  checkpoint_rng_state = torch.load(rng_file)


parsed correctly 3.8 2.1
parsed correctly 1.85 2.1
parsed correctly 2.5 2.1
parsed correctly 2.8 2.1
parsed correctly 3.5 2.1
parsed correctly 3.2 2.1
parsed correctly 3.5 2.1
parsed correctly 2.8 2.1
parsed correctly 2.5 1.9
parsed correctly 1.25 1.9
parsed correctly 3.5 1.9
parsed correctly 1.3 1.9
parsed correctly 3.8 1.9
parsed correctly 4.7 1.9
parsed correctly 2.5 1.9


RuntimeError: The expanded size of the tensor (16) must match the existing size (2) at non-singleton dimension 0.  Target sizes: [16].  Tensor sizes: [2]

In [11]:
# train_result = grpo_trainer.evaluate()

## Compare Tuned and Frozen Model Results

In [12]:
import glob
from pathlib import Path
import json
import hashlib
import numpy as np

pred_path = "./test/completions/tuned_v3_correct_format_16.0/*.json"
num_generations = Path(pred_path).parts[-2].split("_")[-1]
pths_1 = glob.glob(pred_path)
# pth_2 = glob.glob("test/not_tuned_pred/*json")

dict_all = {}
 
mean_mae_1 = []
mean_mae_2 = []

mean_mae_1_median = []
mean_mae_2_median = []


for pth_i in pths_1:
    with open(pth_i, "r") as f:
        dict_i = json.load(f)
    smiles = dict_i["smiles"]
    smiles_hash = hashlib.blake2b(smiles.encode('utf-8'), digest_size=4).hexdigest()

    pth = f"./test/completions/init_v3_correct_format_{num_generations}/parsed_{smiles_hash}.json" if Path(f"test/completions/init_v3_correct_format_{num_generations}/parsed_{smiles_hash}.json").exists() else f"test/completions/init_v3_correct_format_{num_generations}/{smiles_hash}.json"
    with open(pth, "r") as f:
        dict_i_2 = json.load(f)
    
    print(pth)
    print(pth_i)

    dict_all[smiles] = {
        "7BQwen": {
            "completion": dict_i_2["completion"], 
            "answer_val": dict_i_2["answer_val"],
            "mae": dict_i_2["mae"],
            "mae_median": dict_i_2["mae_median"]
        },
        "7BQwenTuned": {
            "completion": dict_i["completion"], 
            "answer_val": dict_i["answer_val"],
            "mae": dict_i["mae"],
            "mae_median": dict_i["mae_median"]
        },
        "gold_val": dict_i_2["gold_val"],
    }
    if dict_i_2["answer_val"] is None:
        dict_all[smiles]["7BQwen"]["answer_parsed"] = dict_i_2["answer_parsed"]
    if dict_i["answer_val"] is None:
        dict_all[smiles]["7BQwenTuned"]["answer_parsed"] = dict_i_2["answer_parsed"]

    if dict_i["mae"] is not None:
        mean_mae_1.extend([float(v_i) if v_i is not None else 10 for v_i in dict_i["mae"]])
    if dict_i_2["mae"] is not None:
        mean_mae_2.append([float(v_i) if v_i is not None else 10 for v_i in dict_i_2["mae"]])
    print([float(v_i) if v_i is not None else 10 for v_i in dict_i["mae"]][:10])
    print([float(v_i) if v_i is not None else 10 for v_i in dict_i_2["mae"]][:10])
    if dict_i["mae_median"] is not None:
        mean_mae_1_median.append(float(dict_i["mae_median"]))
    if dict_i_2["mae_median"] is not None:
        mean_mae_2_median.append(float(dict_i_2["mae_median"]))
print(f"mean mae tuned - {np.mean(mean_mae_1)}, mean mae - {np.mean(mean_mae_2)}")
print(f"median: mean mae tuned - {np.mean([v_i for v_i in mean_mae_1_median if not np.isnan(v_i)])}, mean mae - {np.mean([v_i for v_i in mean_mae_2_median if not np.isnan(v_i)])}")

with open(f"./test/completions/all_results_v3_{num_generations}.json", "w") as f:
    json.dump(dict_all, f, indent=2)
    


./test/completions/init_v3_correct_format_16.0/parsed_8c64959f.json
./test/completions/tuned_v3_correct_format_16.0/parsed_8c64959f.json
[10, 10, 1.7, 0.9000000000000001, 0.9999999999999998, 2.7, 10, 1.4000000000000001, 0.6000000000000001, 2.9000000000000004]
[10, 10, 1.7, 0.9000000000000001, 0.9999999999999998, 2.7, 10, 1.4000000000000001, 0.6000000000000001, 2.9000000000000004]
./test/completions/init_v3_correct_format_16.0/parsed_beff7dfa.json
./test/completions/tuned_v3_correct_format_16.0/parsed_beff7dfa.json
[2.0999999999999996, 0.5, 0.9000000000000001, 1.7000000000000002, 0.09999999999999964, 10, 10, 0.9000000000000001, 2.5, 3.8]
[2.0999999999999996, 0.5, 0.9000000000000001, 1.7000000000000002, 0.09999999999999964, 10, 10, 0.9000000000000001, 2.5, 3.8]
test/completions/init_v3_correct_format_16.0/829e5d68.json
./test/completions/tuned_v3_correct_format_16.0/829e5d68.json
[10, 0.8, 1.7999999999999998, 0.19999999999999996, 2.0, 1.7999999999999998, 3.8, 0.30000000000000004, 3.2, 10