In [1]:
from unsloth import FastLanguageModel, PatchFastRL
PatchFastRL("GRPO", FastLanguageModel)

import sys
sys.path.append("..")
from dataset.polaris_admet_dataset import load_polaris_dataset, problem_template
from dataset import validate_dataset
import numpy as np
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedTokenizer
)
import torch
from trl import ModelConfig
from munch import Munch
from functools import partial
import hashlib
from collections import defaultdict

import wandb
run = wandb.init(project="asap-admet-eval")

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 03-03 20:10:30 __init__.py:190] Automatically detected platform cuda.


2025-03-03 20:10:31,112	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mgazizullina2010[0m ([33mvladvin-org[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"

In [3]:
def compute_mae(completions, ground_truth=None, log_normalize=False, model_name="pred", table=None, **kwargs):
    # global RUN
    RUN = wandb.run
    print(completions)

    smiles = kwargs.get("smiles")
    num_generations = len(completions) / len(set(smiles))
    
    contents = [completion[0]["content"] for completion in completions]

    solutions = kwargs.get("solution") # Get solutions from kwargs
    
    if solutions is None:
        return [0.5] * len(completions) # Return neutral reward if no solution
    smiles2conts = defaultdict(list)

    # print(len(kwargs["prompts"]), kwargs["prompts"][0][0])

    for content, gold_val, smiles_i, prompt_dict in zip(contents, solutions, smiles, kwargs["prompts"]):
        answer_val = None
        mae = 2.0
        if gold_val is not None:  # Check if parsing was successful
            # Parse the model's answer with relaxed normalization
            answer_parsed = parse(
                content,
                extraction_config=[
                    LatexExtractionConfig(
                        normalization_config=NormalizationConfig(
                            nits=False,
                            malformed_operators=False,
                            basic_latex=True,
                            equations=True,
                            boxed="all",
                            units=True,
                        ),
                        boxed_match_priority=0,
                        try_extract_without_anchor=False,
                    )
                ],
                extraction_mode="first_match",
            )

            if len(answer_parsed) > 0 and not isinstance(answer_parsed[0], str):
                try:
                    answer_val = float(answer_parsed[0])
                    mae = np.abs(answer_val - float(gold_val))
                    print("parsed correctly", answer_val, gold_val)
                except Exception as e:
                    print(e)
                    answer_val = None
                    mae = None
            

        smiles_hash = hashlib.blake2b(smiles_i.encode('utf-8'), digest_size=4).hexdigest()
        smiles2conts[smiles_hash].append({"completion": content, 
                       "gold_val": str(gold_val), 
                       "answer_parsed": str(answer_parsed), 
                       "smiles": smiles_i,
                       "answer_val": answer_val,
                       "system_input": prompt_dict[0]["content"],
                       "user_prompt": prompt_dict[1]["content"],
                       "mae": mae,
                       }) 
    median_maes = []
    for k, v in smiles2conts.items():
        answers_g = [v_i["answer_val"] for v_i in v]
        answers_g = [float(v_i) for v_i in answers_g if v_i is not None and not np.isnan(float(v_i))]
        if len(answers_g) > 0:
            answer_median = np.median(answers_g)
            mae_median = np.abs(float(v[0]["gold_val"]) - answer_median)
        else:
            mae_median = 2.0
        median_maes.extend([mae_median]*int(num_generations))
        if table is not None:
            for v_i in v:
                table.add_data(k, mae_median, v_i["mae"], v_i["completion"], v_i["system_input"], v_i["user_prompt"], v_i["answer_parsed"], v_i["answer_val"], v_i["gold_val"]) 
    return median_maes

def get_reward_functions(script_args, table):
    """
    Returns a list of reward functions based on the script arguments.
    """
    reward_funcs_list = []

    fnc = partial(compute_mae, table=table)
    fnc.__name__ = compute_mae.__name__
    reward_funcs_registry = {
        "mae": fnc,  # Assuming accuracy_reward is defined in previous steps
    }

    for func_name in script_args.reward_funcs:
        if func_name not in reward_funcs_registry:
            raise ValueError(f"Reward function '{func_name}' not found in registry.")
        reward_funcs_list.append(reward_funcs_registry[func_name])

    return reward_funcs_list

def get_tokenizer(
    model_args: ModelConfig, training_args, auto_set_chat_template: bool = True
) -> PreTrainedTokenizer:
    """Get the tokenizer for the model."""
    # https://github.com/huggingface/open-r1/blob/eeca246b078457bc0f69ba2e8297b799df0e2bda/src/open_r1/utils/model_utils.py#L11
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=False, # model_args.trust_remote_code
    )

    if training_args.chat_template is not None:
        tokenizer.chat_template = training_args.chat_template
    elif auto_set_chat_template and tokenizer.get_chat_template() is None:
        tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
    return tokenizer

def get_dataset(params=["MLM", "HLM", "KSOL", "LogD", "MDR1-MDCKII"], subset_train=None, subset_valid=None, subset_test=None, rules_prompt_name="rules_v4", rewrite=False, system_prompt_fn=SYSTEM_PROMPT_2, prompt_template_fn=problem_template, properties=False):
    dataset = load_polaris_dataset(params=params, rules_prompt_name=rules_prompt_name, system_prompt_fn=system_prompt_fn, problem_template_fn=prompt_template_fn, properties=properties, rewrite=rewrite)

    print(f"Train set size: {len(dataset['train'])}")
    print(f"Test set size: {len(dataset['test'])}")

    if subset_train is not None:
        dataset["train"] = dataset["train"].select(range(subset_train))
    if subset_valid is not None:
        dataset["validation"] = dataset["validation"].select(range(subset_valid))
    if subset_test is not None:
        dataset["test"] = dataset["test"].select(range(subset_test))

    validate_dataset(dataset)
    return dataset

In [4]:
dct = {"MLM": "is Mouse Liver Microsomal stability measured in uL/min/mg.",
    "HLM": "is Human Liver Microsomal stability measured in uL/min/mg.", 
    "KSOL": "is Solubility measured in uM.",
    "LogD": "is Lipophilicity, like solubility but then in fatty tissue. LogD is a measure of a molecule's lipophilicity.",
    "MDR1-MDCKII": "is Cell permeation measured in 10^-6 cm/s."
    }



SYSTEM_PROMPT_2 = lambda x, rules_prompt_name: f"""You are an experienced Chemist that provides well-reasoned and detailed responses and excells at extimating ADME properties of molecules, especially {x}. {x} {dct[x]}
User asks you to estimate and predict {x} for a small molecule, you first think about the reasoning process as an internal monologue, perform professional necessary calculations, never overly rely on values from user provided metadata as answers, keep it professional as a chemist, and then provide the user with the answer. Respond in the following format:  <think>\n...\n</think>\n<answer>\n...\n</answer>. Where RESULT is just the final number in float or expression that solves the problem.

User-provided SMILES may contain enhanced stereochemistry notation, such as |&1:3| or |o1:4|, which indicates stereochemical information beyond standard SMILES representation.
- & denotes a 50/50 R/S mixture (racemate) for a given atom index.
- o indicates that the molecule exists in only one stereochemical form, but the specific configuration is unknown.
- abs confirms 100% of the given chirality.
While Chem.MolFromSmiles preserves explicit stereochemistry, Chem.MolToSmiles does not retain enhanced stereochemical annotations when generating new SMILES.
For example, in CC(C)(C)[C@H](NC1=NC=NC2=C1C=C(C1=CN(CC(N)=O)N=C1)N2)C1=CC=C2OCCOC2=C1 |o1:4|, the atom at index 4 ([C@H]) is a stereocenter. The o1:4 notation means it is present in a single stereochemical form, but the exact configuration is unspecified. In contrast, |&1:4| would indicate a racemic mixture at the same stereocenter.
These annotations impact binding-related properties but are often ignored by standard molecular embeddings like fingerprinting and GCNs (e.g., ChemProp), meaning embeddings for different enhanced notations may be identical. Additionally, practical reliability of & vs. o is uncertain, as many annotated chemicals turn out to be racemates regardless of notation.

Use all of the following rules to predict LogD:
- Large, nonpolar substituents (e.g. tert‐butyl, isopropyl) increase lipophilicity.
- Extended hydrocarbon chains enhance nonpolar character and raise LogD.
- More aromatic rings typically provide significant hydrophobic surface area, increasing LogD.
- Rigid, fused rings tend to be more lipophilic than isolated aromatic rings.
- Incorporation of Cl, F, or Br generally increases lipophilicity due to their electron-withdrawing yet lipophilic nature.
- The cumulative effect of several halogen atoms can further raise LogD.
- Oxygen in ethers adds some polarity, often reducing LogD slightly compared to pure hydrocarbons.
- Esters introduce polarity through the carbonyl and adjacent oxygen, decreasing lipophilicity relative to nonpolar analogues.
- Amides are polar (and can hydrogen-bond), thus they tend to lower LogD.
- When deprotonated (or even in their neutral form), carboxyl groups strongly favor water, reducing LogD.
- Basic amine groups (if protonated at physiological pH) increase water solubility and lower LogD.
- Nitrogen-containing heterocycles that can be protonated tend to lower lipophilicity.
- Their strong electron-withdrawing nature and hydrogen-bonding capacity drive LogD down.
- Nitro substituents are strongly polar and tend to lower LogD.
- Each additional polar (or ionizable) group further shifts the balance toward hydrophilicity.
- When polar groups internally hydrogen bond, their “external” polarity is masked—potentially raising LogD.
- If bulky groups “shield” polar sites, the molecule may behave more lipophilically than expected.
- Extended conjugation (especially in fused aromatic systems) typically increases nonpolar character.
- Less-flexible, planar systems often pack well into lipophilic environments.
- Rigid, compact, polycyclic frameworks generally exhibit higher lipophilicity.
- Ketones and aldehydes add polarity (via the carbonyl), lowering LogD moderately compared to nonpolar groups.
- Molecules with a higher proportion of carbons versus polar heteroatoms tend to be more lipophilic.
- Fewer nitrogen atoms usually mean fewer opportunities for hydrogen bonding, favoring lipophilicity.
- Fixed positive or negative charges strongly favor aqueous solubility.
- Larger rings or macrocycles can be more lipophilic if they reduce the exposure of polar groups.
- Saturated (non-aromatic) rings tend to be less lipophilic than aromatic rings, all else being equal.
- Ortho, meta, or para placements affect steric hindrance and exposure of polar sites, thereby modulating lipophilicity.
- Delocalized electron density over a conjugated system can reduce the effective polarity, increasing LogD.
- More symmetric molecules may pack more efficiently in lipophilic environments, raising LogD.
- A higher proportion of sp²-hybridized carbons (e.g. in aromatic systems) usually boosts lipophilicity compared to sp³ centers that might bear polar substituents.
- If the SMILES contains “CC(C)(C)” (a tert‐butyl group) attached at a chiral center (e.g. “CC(C)(C)C@H”), then expect an increase in LogD by roughly 1–1.5 units compared to a similar scaffold lacking this group.
- If the molecule shows a fragment like “NC1=NC=NC2=C1C=C(C1=CN(…))N2” appended to a tert‐butyl bearing center, then LogD values typically fall in the ~3.1–3.2 range.
- If a structure features two or more fused aromatic rings (e.g. “C1=CC=C2…C2=C1”), then the overall lipophilicity is enhanced and LogD tends to be above 2.5.
- If an aromatic ring is directly substituted with halogens such as “Cl” or “F” (for instance, “C1=CC(Cl)=CC=C1”), then expect a boost in LogD by about 0.5 units relative to an unsubstituted ring.
- If multiple halogens are present on separate aromatic rings (e.g. both “Cl” and “F” on different rings), then the cumulative effect can raise LogD by around 1.0 unit or more.
- If the SMILES shows an “OCCOC” or “OCCOC2=” fragment (typical of ether linkers between aromatic rings), then these oxygenated linkers slightly reduce lipophilicity—often lowering LogD by 0.2–0.5 units compared to a direct aryl–aryl bond.
- If a sulfonyl fragment appears (i.e. “S(=O)(=O)”), especially appended to an aromatic or aliphatic segment, then LogD is reduced by roughly 0.5–1.0 units because of the high polarity of sulfonyl groups.
- If the structure contains an amide linkage (–C(=O)N–) not shielded by bulky groups, then expect a drop in LogD on the order of 0.3–0.7 units relative to analogous non‐amide linkers.
- If a free carboxylic acid group (“C(=O)O”) is present and not sterically hindered, then the compound tends to have a LogD that is 1–2 units lower than a similar ester or amide analogue.
- If a primary or secondary amine appears (e.g. “NC” or “N(C)C”) that is likely protonated at pH 7.4, then the molecule’s LogD will be lowered by approximately 1 unit relative to a neutral analogue.
- If a quaternary ammonium group (e.g. “N+(CH3)3”) is evident, then expect a dramatic drop in LogD – often resulting in near‐zero or negative values.
- If a heterocyclic aromatic ring contains a non‐protonated nitrogen (e.g. pyridine-like fragments such as “c1ccncc1”), then this feature tends to reduce LogD by 0.2–0.5 units relative to pure carbocyclic aromatics.
- If an extended conjugated system is present (for example, several aromatic rings linked by conjugated bonds), then the molecule often exhibits LogD values above 3 due to the cumulative hydrophobic surface.
- If polar groups (e.g. –OH, –NH2, –C(=O)O) are positioned so that intramolecular hydrogen bonding is likely, then their effective polarity is “masked” and LogD may be higher than predicted by counting polar groups alone.
- If a nitro group (–NO2) is present, then its strong electron‐withdrawing nature usually reduces LogD by about 1 unit or more.
- If an aromatic ether (Ar–O–Ar) is present rather than an aliphatic ether (R–O–R), then the effect on LogD is less pronounced—predict a LogD that is ~0.2–0.3 units higher than for an aliphatic ether analogue.
- If a carbonyl (C=O) is directly attached to an aromatic ring (as in an aromatic ketone), then expect a modest LogD reduction (around 0.3 units) versus a fully hydrocarbon substituted ring.
- If an alkyne is present in an aliphatic chain (e.g. “C#C” in “C#CCCC…”), then its low polarity means it contributes little to water solubility, keeping LogD relatively high.
- If the SMILES indicates a branched alkyl chain (such as an isopropyl group “CC(C)”) rather than a straight chain, then the branching typically increases LogD by enhancing hydrophobicity.
- If the structure contains multiple amide bonds in a row (e.g. a di- or tri-amide linker), then their combined polar effect can lower LogD by up to 1.5–2 units unless balanced by large lipophilic groups.
- If the molecule features a rigid bicyclic or polycyclic aromatic system (e.g. fused rings like “C1=CC2=CC=CC=C2C=C1”), then expect LogD values in the upper range (typically 2.5–4.0) because of the efficient stacking in lipophilic environments.
- If a spirocyclic motif is present, then the compact, three-dimensional arrangement often “hides” polar groups, leading to an increase in LogD by around 0.5 units relative to a more extended analogue.
- If the heterocycle is larger (e.g. a quinoline or isoquinoline instead of a pyridine), then the additional aromatic carbon atoms generally boost LogD by ~0.5–1.0 units.
- If bridging –CH2– groups are present between rings (e.g. “–CH2–” linking two aromatics), then these bridges increase hydrophobicity by reducing the effective polar surface area, raising LogD slightly.
- If the SMILES includes electron-withdrawing substituents (e.g. “CF3” or “Cl”) on an aromatic ring, then these groups reduce hydrogen-bonding capacity and typically increase LogD by about 0.3–0.7 units.
- If electron-donating groups (e.g. “OCH3”) are attached to an aromatic ring, then the increased polarity may lower LogD by ~0.2–0.5 units relative to a halogenated or unsubstituted ring.
- If a cyclic ether (e.g. a tetrahydrofuran ring represented as “C1CCOC1”) is incorporated, then its moderate polarity can reduce LogD by approximately 0.3–0.7 units compared to an all‐carbon ring of similar size.
- If the polar groups (such as amides or carboxyls) appear on the periphery of a large, lipophilic scaffold, then their impact on LogD is partially offset—resulting in LogD values that are about 0.5–1 unit higher than if the same polar groups were isolated.
- If steric hindrance is evident around a normally polar group (for example, an amide adjacent to a bulky tert‐butyl group), then the effective exposure of that polar functionality is reduced, leading to a higher LogD than predicted by polarity alone.
- If the overall structure can be roughly “fragmented” into lipophilic and hydrophilic pieces, then the net LogD is roughly additive. For example, two strongly lipophilic aromatic rings plus one polar amide might yield a predicted LogD in the range of 2.5–3.0, whereas replacing the amide with a carboxylic acid may drop the value by 1–1.5 units.
"""

problem_template = lambda v_name, k, properties: f"What is the numerical value of {v_name} of the '{k}'? You may need some properties for this molecule from RDKiT for your calculations: {properties}." # Also, note that notation after a blank space in smiles string, like '|o1:4|', marks enantiomers. 
# Put the LogD answer in \\boxed{{RESULT}}. 
# problem_template = lambda v_name, k, properties: f"What is the numerical value of {v_name} of the '{k}'? Put the LogD answer in \\boxed{{RESULT}}." # Also, note that notation after a blank space in smiles string, like '|o1:4|', marks enantiomers. 

In [5]:
len(SYSTEM_PROMPT_2("LogD", ""))

6851

In [7]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    PreTrainedTokenizer
)
from trl import (
    GRPOConfig, 
    GRPOTrainer
)
from trl import ModelConfig, get_peft_config
import sys
sys.path.append("..")
from train import GRPOTrainer2, ComputeMetricsCallback, get_model, GRPOScriptArguments

MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
system_prompt_ = SYSTEM_PROMPT_2("LogD", "")
problem_template_ = problem_template("LogD", "<smiles>", "<properties>")
system_prompt_hash = hashlib.blake2b(system_prompt_.encode('utf-8'), digest_size=4).hexdigest()
problem_template_hash = hashlib.blake2b(problem_template_.encode('utf-8'), digest_size=4).hexdigest()

training_args = TrainingArguments(
    logging_dir="./logs/wandb/",
    num_train_epochs=0,             # Total number of training epochs
    per_device_train_batch_size=24,  # Batch size per device during training TODO: change to 16
    per_device_eval_batch_size=24,   # Batch size for evaluation TODO: why it says this   File "/home/alisavin/AgenticADMET/train.py", line 534, in <module>
    logging_steps=1,              # Log every X updates steps TODO: change based on number of steps
    logging_strategy="steps",
    logging_first_step=True,
    evaluation_strategy="epoch",    # Evaluate every `eval_steps`
    save_strategy="epoch",      # Disables regular checkpoints
    save_total_limit=0,      # Makes sure no checkpoints are kept
    load_best_model_at_end=False,  # Disables saving the best model
    save_steps=0,            # No saving at specific steps
    dataloader_num_workers=8,      # Number of subprocesses to use for data loading
    seed=42,                       # Random seed for reproducibility
    bf16=True,                     # Use mixed precision BFP16 training #TODO: ??????
    push_to_hub=False,             # Whether to push the final model to Hugging Face Hub
    report_to=["wandb"],              # Reporting to no one
    run_name=f"{MODEL_NAME.split('/')[-1]}-{system_prompt_hash}-{problem_template_hash}",
    disable_tqdm=False,
    gradient_checkpointing=True,   # Enable gradient checkpointing        
    remove_unused_columns=False,
    do_eval=True, #TODO: use
    gradient_checkpointing_kwargs={"use_reentrant": False}, # TODO: use
    max_steps=-1, #TODO: change to -1
    eval_steps=-1, #TODO: change to -1
    learning_rate=1e-6,            # Initial learning rate for AdamW optimizer
    warmup_ratio=0.1,              # Linear warmup over warmup_ratio fraction of training steps
    weight_decay=0.01,             # Apply weight decay to all layers except bias and LayerNorm weights
    # optim = "adamw_8bit",
)

#TODO: reward, for each property set threashold for MAE to set to range 0 to 1
#TODO: loss always 0

grpo_config = GRPOConfig(
    **training_args.to_dict(), # Convert TrainingArguments to dictionary and unpack
    **{ 
    # REMOVED model_init_kwargs here 
    # We are passing the instantiated 'model' object, so GRPOTrainer doesn't need model_init_kwargs
    },
    num_generations=3, #TODO: 16
    use_vllm=True, #TODO: use True
    vllm_device="cuda:0",
    vllm_gpu_memory_utilization=0.25, # TODO: 0.25 0.7
    vllm_max_model_len=11103+800+82+2048, #14000+1000+82+2048, #TODO: 2048
    max_prompt_length=11103+800+82, #3024, #TODO: 800+
    max_completion_length=2048, #TODO: 1024+ (better 2048/4048 and more)
    temperature=0.01, # TODO: temperature for math task
    log_completions=False
    )

model = get_model(MODEL_NAME, attn_implementation="flash_attention_2") #TODO: change to "flash_attention_2"
model_args_i = Munch.fromDict({
        "model_name_or_path": MODEL_NAME,
        "model_revision": "main",
        "trust_remote_code": False # TODO: everyboudy sets to True and default is True
        })
training_args_i = Munch.fromDict({"chat_template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<｜User｜>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<｜Assistant｜><｜tool▁calls▁begin｜><｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<｜tool▁call▁begin｜>' + tool['type'] + '<｜tool▁sep｜>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<｜tool▁call▁end｜>'}}{{'<｜tool▁calls▁end｜><｜end▁of▁sentence｜>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<｜tool▁outputs▁end｜>' + message['content'] + '<｜end▁of▁sentence｜>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{{'<｜Assistant｜>' + content + '<｜end▁of▁sentence｜>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<｜tool▁outputs▁begin｜><｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<｜tool▁output▁begin｜>' + message['content'] + '<｜tool▁output▁end｜>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<｜tool▁outputs▁end｜>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<｜Assistant｜>'}}{% endif %}"})

tokenizer = get_tokenizer(model_args_i, training_args_i)
model_args = ModelConfig(model_name_or_path=MODEL_NAME)

dataset = get_dataset(params=["LogD"], subset_train=50, system_prompt_fn=SYSTEM_PROMPT_2, prompt_template_fn=problem_template, properties=True)

script_args = GRPOScriptArguments(reward_funcs=["mae"])
text_table = wandb.Table(columns=["smiles_hash", "mae_median", "mae", "completion", "system_input", "user_prompt", "answer_parsed", "asnwer_val", "gold_val"])
reward_functions = get_reward_functions(script_args, text_table) #TODO: check trl they had someshere gpro example and used different rewards including lenght reward

grpo_trainer = GRPOTrainer2(
    model=model,                      # Our initialized Qwen model
    args=grpo_config,                # GRPOConfig (created from TrainingArguments)
    train_dataset=dataset['train'],   # Training dataset
    eval_dataset=dataset['validation'],    # Evaluation dataset
    processing_class=tokenizer, #TODO: check callback from config
    peft_config=get_peft_config(model_args), #TODO: check # label_names
    reward_funcs=reward_functions,
)

149 11086


Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  3.00s/it]


Model parameters: 7,615,616,512
Using device: cuda


Map: 100%|██████████| 221/221 [00:00<00:00, 8426.13 examples/s]
Map: 100%|██████████| 49/49 [00:00<00:00, 6025.77 examples/s]
Map: 100%|██████████| 52/52 [00:00<00:00, 6654.78 examples/s]


Train set size: 221
Test set size: 49

Validating train split:
✓ All required fields present
✓ Prompt format is correct

Validating test split:
✓ All required fields present
✓ Prompt format is correct


AttributeError: 'Qwen2ForCausalLM' object has no attribute 'vllm_engine'

In [8]:
train_result = grpo_trainer.evaluate()
run.log({"training_samples" : text_table})

[[{'role': 'assistant', 'content': "<think>\nAlright, so I need to estimate the LogD of this molecule: CC1=CC(C2=NOC(C(F)(F)F)=N2)=CC=C1OCCCC1=CC(C(=O)N2CC[C@](C)(O)C2)=NO1 |&1:28|\n\nFirst, I'll look at the structure. It seems to have several rings, some with substituents. Let me break it down.\n\nI see a lot of fluorine atoms. The molecular formula has C(F)(F)F, which is CF3. Fluorine is highly electronegative and can have a significant impact on LogD. Fluorinated groups are generally less polar than, say, oxygen or nitrogen groups, but they can also act as electron-withdrawing groups, which might lower LogD.\n\nThere are also amide groups. The molecule has an amide (C(=O)N) and a tertiary amine (CC[C@](C)(O)C2). Amines can be polar and hydrogen bond donors, which might lower LogD. However, the tertiary amine is quaternary in this case because of the @ symbol, which indicates a specific stereochemistry. Quaternary amines are less polar than primary or secondary amines because they ca

[[{'role': 'assistant', 'content': "<think>\nAlright, so I need to estimate the LogD of the molecule given by the SMILES string 'COC1=CC=C(C2=NN=C(SC3=CC(C4=CC=CC=C4)=NC4=NC=NN34)N2C)C=C1'. The user has provided some properties from RDKiT, which include molecular weight, TPSA, number of rotatable bonds, H-bond donors and acceptors, aromatic rings, and some counts of functional groups. \n\nFirst, I'll try to visualize the molecule based on the SMILES. It starts with a CO group attached to a benzene ring (COC1=CC=C...). Then there's a substituent on the benzene ring, which is a group in parentheses: (C2=NN=C(SC3=CC(C4=CC=CC=C4)=NC4=NC=NN34)N2C). This seems to be a complex substituent with multiple rings and possibly some nitrogen-containing heterocycles.\n\nBreaking it down, the benzene ring (C1) has a substituent that's another ring system. The substituent starts with C2=NN=C, which suggests a pyrazine or similar heterocycle. Then there's a SC3 group, which is a sulfur connected to anot