In [2]:
import sys, re, os, json
sys.path.append("..")
from rxnutils import read_json, is_valid_smiles
MODEL_NAME = "ether0"

# Utils

In [3]:
def parse_response(response: str):
    # Attempt to extract a JSON code block and parse it
    # str format: "<|think_start|>...<|think_end|><|answer_start|>...<|answer_end|>"
    # we extract the answer part
    answer_match = re.search(r'<\|answer_start\|>(.*?)<\|answer_end\|>', response, re.DOTALL)
    if answer_match:
        answer = answer_match.group(1).strip()
        return answer
    return ""

# RCR

In [4]:
def evaluate_RCR(model_name:str, log_dir:str="logs/RCR"):
    """
    Evaluate the reaction condition recommendation task
    Metric: SMILES similarity

    Args:
        model_name (str): the name of the model
        log_dir (str): the directory of the logs

    Returns:
        None
    """
    # make sure the log_dir is correct
    if not os.path.exists(log_dir):
        raise ValueError(f"logs_dir {log_dir} is not correct")
    samples = read_json(f"{log_dir}/{model_name}.json")
    preds = []
    gts = []
    for sample in samples:
        gts.append(sample['gt'])

        pred_smiles = parse_response(sample['json_response'])
        # further parsing to remove the answer tag
        if pred_smiles.endswith("</answer>"):
            pred_smiles = pred_smiles[:-len("</answer>")]
        
        # further parsing to extract condition smiles
        if ">>" in pred_smiles:
            pred_smiles = pred_smiles.split(">>")[0]
        elif '>' in pred_smiles:
            pred_smiles = pred_smiles.split('>')[1]
        preds.append(pred_smiles)

    from evaluator import MoleculeSMILESEvaluator
    evaluator = MoleculeSMILESEvaluator()
    res = evaluator.evaluate(preds, gts)
    # pretty print the res
    print("exact_match: ", round(res['exact_match'], 2))
    print("bleu: ", round(res['bleu'], 2))
    print("levenshtein: ", round(res['levenshtein'], 2))
    print("rdk_sims: ", round(res['rdk_sims'], 2))
    print("maccs_sims: ", round(res['maccs_sims'], 2))
    print("morgan_sims: ", round(res['morgan_sims'], 2))
    print("validity: ", round(res['validity'], 2))
    fts = (res['rdk_sims'] + res['maccs_sims'] + res['morgan_sims']) / 3
    print("fts: ", round(fts * res['validity'], 2))

In [None]:
evaluate_RCR("ether0")

# Mechanism: NEPP

In [6]:
def evaluate_NEPP(model_name: str, log_dir: str):
    """
    Evaluate the next element-step product prediction task
    Metric: SMILES similarity

    Args:
        model_name (str): the name of the model
        log_dir (str): the directory of the logs

    Returns:
        None
    """
    if not os.path.exists(log_dir):
        raise ValueError(f"logs_dir {log_dir} is not correct")
    samples = read_json(f"{log_dir}/{model_name}.json")
    preds = []
    gts = []
    for sample in samples:
        gts.append(sample['gt'])
        json_response = sample['json_response']
        pred_smiles = parse_response(json_response)
        preds.append(pred_smiles)

    from evaluator import MoleculeSMILESEvaluator
    evaluator = MoleculeSMILESEvaluator()
    res = evaluator.evaluate(preds, gts)

    # pretty print the res
    print("exact_match: ", round(res['exact_match'], 2))
    print("bleu: ", round(res['bleu'], 2))
    print("levenshtein: ", round(res['levenshtein'], 2))
    print("rdk_sims: ", round(res['rdk_sims'], 2))
    print("maccs_sims: ", round(res['maccs_sims'], 2))
    print("morgan_sims: ", round(res['morgan_sims'], 2))
    print("validity: ", round(res['validity'], 2))

In [8]:
evaluate_NEPP("ether0", log_dir="logs/mech_task1")

exact_match:  0.02
bleu:  0.52
levenshtein:  24.98
rdk_sims:  0.62
maccs_sims:  0.61
morgan_sims:  0.53
validity:  0.91


# Mechanism: MechSel

In [9]:
def evaluate_mechsel(model_name: str, logs_dir: str):
    """
    Evaluate the reaction mechanism selection prediction.

    Args:
        logs_dir (str): The directory where the logs are stored.
        model_name (str): The name of the model.

    Returns:
        None
    """
    if not os.path.exists(logs_dir):
        raise ValueError(f"logs_dir {logs_dir} is not correct")
    samples = read_json(f"{logs_dir}/{model_name}.json")

    preds = []
    gts = []
    for sample in samples:
        pred_choice = parse_response(sample['json_response'])
        if len(pred_choice) > 1:
            # if multiple chars, we take the first one
            pred_choice = pred_choice[0]
        # if pred_choice is not a valid choice, we treat it as empty
        if pred_choice.lower() not in ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']:
            pred_choice = ""

        pred_choice = pred_choice.lower()
        gt = sample['gt'].lower()
        preds.append(pred_choice)
        gts.append(gt)

    accuracy = sum(1 for pred, gt in zip(preds, gts) if pred == gt) / len(gts)
    print(f"MCQ Accuracy (mean): {accuracy:.2f}")

# Example usage
evaluate_mechsel(MODEL_NAME, "logs/mech_task2")

MCQ Accuracy (mean): 0.27


# Reaction

## major product

In [None]:
def evaluate_fs(model_name: str, log_dir: str):
    """
    Evaluate SMILES predictions against ground truth.

    Args:
        model_name (str): The name of the model.
        log_dir (str): The directory where the logs are stored.

    Returns:
        None
    """
    if not os.path.exists(log_dir):
        raise ValueError(f"logs_dir {log_dir} is not correct")
    samples = read_json(f"{log_dir}/{model_name}.json")
    preds = []
    gts = []
    for sample in samples:
        ## parse the gt
        try:
            gt = sample['products']
        except:
            gt = json.loads(sample['gt']).get("Major Product")
        if isinstance(gt, list):
            gt = '.'.join(gt)
        gts.append(gt)

        ## parse the pred
        pred_smiles = parse_response(sample['json_response'])
        preds.append(pred_smiles)

    from evaluator import MoleculeSMILESEvaluator
    evaluator = MoleculeSMILESEvaluator()
    res = evaluator.evaluate(preds, gts)

    print("exact_match: ", round(res['exact_match'], 2))
    print("bleu: ", round(res['bleu'], 2))
    print("levenshtein: ", round(res['levenshtein'], 2))
    print("rdk_sims: ", round(res['rdk_sims'], 2))
    print("maccs_sims: ", round(res['maccs_sims'], 2))
    print("morgan_sims: ", round(res['morgan_sims'], 2))
    print("validity: ", round(res['validity'], 2))
    fts = (res['rdk_sims'] + res['maccs_sims'] + res['morgan_sims']) / 3
    print("fts: ", round(fts * res['validity'], 2))

# Example usage
evaluate_fs(MODEL_NAME, "logs/fs")

exact_match:  0.74
bleu:  0.77
levenshtein:  12.56
rdk_sims:  0.83
maccs_sims:  0.86
morgan_sims:  0.82
validity:  0.92


## by product

In [None]:
def parse_raw_response(raw_response: str, field: str) -> str:
    """从原始响应字符串中提取指定字段的值
    
    Args:
        raw_response: 可能包含 JSON 块或键值对的原始字符串
        field: 需要提取的字段名（如 "Byproduct(s)"）
    
    Returns:
        提取到的字段值字符串，未找到则返回空字符串
    """
    
    # 改进点 1: 更健壮的 JSON 块检测（支持不同格式的代码块）
    json_block_match = re.search(
        r'```(?:json)?\s*({.*?})\s*```',  # 匹配可选的 json 标记和任意空格
        raw_response, 
        flags=re.DOTALL | re.IGNORECASE
    )
    
    if json_block_match:
        try:
            # 改进点 2: 更严格的 JSON 解析（处理嵌套结构）
            json_str = json_block_match.group(1).strip()
            data = json.loads(json_str)
            
            # 改进点 3: 处理数组型结果（如 ["HCl", "H2O"]）
            if value := data.get(field):
                if isinstance(value, list):
                    return ", ".join(map(str, value))
                return str(value)
        except (json.JSONDecodeError, AttributeError):
            pass  # 解析失败则继续尝试其他方法

    # 改进点 4: 更灵活的正则匹配（处理多种引号和转义）
    escaped_field = re.escape(field)  # 转义特殊字符如括号
    
    # 模式 1: 匹配双引号字符串（允许转义双引号）
    pattern = fr'"{escaped_field}":\s*"((?:\\"|[^"])*)"'
    if match := re.search(pattern, raw_response):
        try:
            # 使用 JSON 解析处理转义字符
            return json.loads(f'"{match.group(1)}"')
        except json.JSONDecodeError:
            return match.group(1).replace(r'\"', '"')
    
    # 模式 2: 匹配单引号字符串（允许转义单引号）
    pattern = fr'"{escaped_field}":\s*\'((?:\\\'|[^\'])*)\''
    if match := re.search(pattern, raw_response):
        return match.group(1).replace(r"\'", "'")
    
    # 改进点 5: 匹配无引号的值（如数值或布尔值）
    pattern = fr'"{escaped_field}":\s*([^,\s}}]+)'
    if match := re.search(pattern, raw_response):
        return match.group(1).strip()

    return ""  # 所有模式均未匹配

In [None]:
def evaluate_fs_byproduct(model_name, log_dir:str="logs/fs"):
    samples = read_json(f"{log_dir}/{model_name}.json")
    preds = []
    gts = []
    
    for sample in samples:
        ## parse the gt
        try:
            gt = sample['byproducts']
        except:
            gt = json.loads(sample['gt'])["Byproduct(s)"]
        if len(gt) == 0:  # if no byproducts, skip the sample
            continue
        if isinstance(gt, list):
            gt = '.'.join(gt)

        ## parse the pred
        if isinstance(sample['json_response'], dict):
            if "Byproduct(s)" in sample['json_response']:
                pred_smiles = sample['json_response']["Byproduct(s)"]
            else:
                pred_smiles = parse_raw_response(sample['raw_response'], "Byproduct(s)")
        else:
            try:
                pred_smiles = parse_raw_response(sample['raw_response'], "Byproduct(s)")
            except:
                pred_smiles = ""

        gts.append(gt)
        preds.append(pred_smiles)

    from evaluator import MoleculeSMILESEvaluator
    evaluator = MoleculeSMILESEvaluator()
    res = evaluator.evaluate(preds, gts)
    print("exact_match: ", round(res['exact_match'], 2))
    print("bleu: ", round(res['bleu'], 2))
    print("levenshtein: ", round(res['levenshtein'], 2))
    print("rdk_sims: ", round(res['rdk_sims'], 2))
    print("maccs_sims: ", round(res['maccs_sims'], 2))
    print("morgan_sims: ", round(res['morgan_sims'], 2))
    print("validity: ", round(res['validity'], 2))
    fts = (res['rdk_sims'] + res['maccs_sims'] + res['morgan_sims']) / 3
    print("fts: ", round(fts * res['validity'], 2))

# Example usage
evaluate_fs_byproduct(MODEL_NAME, "logs/fs")

exact_match:  0.0
bleu:  0.0
levenshtein:  0.0
rdk_sims:  0.0
maccs_sims:  0.0
morgan_sims:  0.0
validity:  0.0


## Retro

In [13]:
def evaluate_retro(model_name: str, log_dir: str):
    if not os.path.exists(log_dir):
        raise ValueError(f"logs_dir {log_dir} is not correct")
    samples = read_json(f"{log_dir}/{model_name}.json")
    preds = []
    gts = []

    for sample in samples:
        ## parse the gt
        gt = sample['reactants']
        if len(gt) == 0:
            continue
        if isinstance(gt, list):
            gt = '.'.join(gt)
        gts.append(gt)

        ## parse the pred
        pred_smiles = parse_response(sample['json_response'])
        preds.append(pred_smiles)

    from evaluator import MoleculeSMILESEvaluator
    evaluator = MoleculeSMILESEvaluator()
    res = evaluator.evaluate(preds, gts)

    print("exact_match: ", round(res['exact_match'], 2))
    print("bleu: ", round(res['bleu'], 2))
    print("levenshtein: ", round(res['levenshtein'], 2))
    print("rdk_sims: ", round(res['rdk_sims'], 2))
    print("maccs_sims: ", round(res['maccs_sims'], 2))
    print("morgan_sims: ", round(res['morgan_sims'], 2))
    print("validity: ", round(res['validity'], 2))
    fts = (res['rdk_sims'] + res['maccs_sims'] + res['morgan_sims']) / 3
    print("fts: ", round(fts * res['validity'], 2))

evaluate_retro(MODEL_NAME, "logs/rs")

exact_match:  0.0
bleu:  0.5
levenshtein:  25.0
rdk_sims:  0.51
maccs_sims:  0.57
morgan_sims:  0.43
validity:  0.87
fts:  0.44


# Mol-und

In [None]:
task_dict = dict(
    fg_samples="fg_samples", 
    murcko='Murcko_scaffold', 
    ring_count='ring_count',
    ring_system='ring_system_scaffold', 
    mutated='mutated_list', 
    permutated='permutated_list'
)
pred_key_dict = dict(
    fg_samples="count", murcko='Output Scaffold', ring_count='count',
    ring_system='output', mutated='output', permutated='output'
)
gt_key_dict = dict(
    fg_samples="fg_num", murcko='largest_scaffold', ring_count='count',
    ring_system='', mutated='', permutated=''
)

from rdkit import DataStructs, Chem
from rdkit.Chem.Scaffolds.MurckoScaffold import GetScaffoldForMol, MurckoScaffoldSmiles # type: ignore
from rdkit.Chem import rdFMCS, AllChem

def scaffold_consistency(src_mol_list, tgt_mol_list)->tuple[float, float]:
    """
    Evaluate the scaffold consistency before&after mol-opt, consistency includes: same or contain
    Metric: Tanimoto molecule similarity

    Args:
        src_mol_list (list): list of source molecules
        tgt_mol_list (list): list of target molecules

    Returns:
        tuple[float, float]: tuple of count of same scaffold and mean of scaffold similarity
    """
    assert len(src_mol_list) == len(tgt_mol_list)

    count_same = 0
    scaffold_score = list()
    for i in range(len(tgt_mol_list)):
        src_smiles, tgt_smiles = src_mol_list[i], tgt_mol_list[i]
        src_mol, tgt_mol = Chem.MolFromSmiles(src_smiles), Chem.MolFromSmiles(tgt_smiles)
        if (src_mol == None or tgt_mol == None) or (src_smiles == "" or tgt_smiles == ""):
            scaffold_score.append(0.0)
            continue

        murcko_scaffold_list = [
            MurckoScaffoldSmiles(smiles) for smiles in [src_smiles, tgt_smiles]
        ]

        if len(set(murcko_scaffold_list)) == 1:
            # if identical, score 1.0
            scaffold_score.append(1.0)
            count_same += 1
        else:
            ## Morgan Fingerprint for scaffold similarity
            murcko_scaffold_mol_list = [
                Chem.MolFromSmiles(murcko_scaffold_list[0]),
                Chem.MolFromSmiles(murcko_scaffold_list[1]),
            ]
            mcs = rdFMCS.FindMCS(murcko_scaffold_mol_list)
            mcs_mol = (
                Chem.MolFromSmarts(mcs.smartsString) if mcs.numAtoms > 0 else None
            )

            if mcs_mol:
                # 计算基于指纹的Tanimoto相似度
                fp1 = AllChem.GetMorganFingerprintAsBitVect(
                    murcko_scaffold_mol_list[0], 2, nBits=1024
                )
                fp2 = AllChem.GetMorganFingerprintAsBitVect(
                    murcko_scaffold_mol_list[1], 2, nBits=1024
                )
                similarity = DataStructs.TanimotoSimilarity(fp1, fp2)
            else:
                similarity = 0.0

            scaffold_score.append(similarity)

    if len(tgt_mol_list) == 0:
        return 0.0, 0.0
    
    mean_scaffold_score = sum(scaffold_score) / len(scaffold_score)
    return count_same, mean_scaffold_score


def evaluate_fg_sample(model_name, log_dir="logs/fg_samples"):
    """
    Evaluate the function-group (fg) count task
    Metric: MAE
    """
    if not os.path.exists(log_dir):
        raise ValueError(f"logs_dir {log_dir} is not correct")
    gt_list = []
    pred_list = []
    pred_results = read_json(f"{log_dir}/{model_name}.json")
    for pred in pred_results:
        gt = pred[gt_key_dict['fg_samples']]
        pred = parse_response(pred['json_results'])
        ## Parsing ##
        # if pred like "{count:xx}", we need to extract the count
        if pred.startswith("{count:"):
            pred = pred.split(":")[1].split("}")[0]
        
        # pred should be a number. if failed, we treat the pred as 0
        if not pred.isdigit():
            pred = 0
        pred_list.append(pred)
        gt_list.append(gt)
    
    assert len(gt_list) == len(pred_list), f"len(gt_list): {len(gt_list)}, len(pred_list): {len(pred_list)}"
    score = sum([abs(int(pred_list[i])-int(gt_list[i])) for i in range(len(pred_list))]) / len(gt_list)
    return score


def evaluate_ring_count(model_name, log_dir="logs/ring_count"):
    """
    Evaluate the ring count task
    Metric: MAE
    """
    if not os.path.exists(log_dir):
        raise ValueError(f"logs_dir {log_dir} is not correct")
    gt_list = []
    pred_list = []
    pred_results = read_json(f"{log_dir}/{model_name}.json")
    for pred in pred_results:
        id_ = pred['id']
        gt = pred[gt_key_dict['ring_count']]
        pred = parse_response(pred['json_results'])
        # if pred like "{count:xx}", we need to extract the count
        if pred.startswith("{count:"):
            pred = pred.split(":")[1].split("}")[0]
        # pred should be a number. if failed, we treat the pred as 0
        if not pred.isdigit():
            pred = 0
        pred_list.append(pred)
        gt_list.append(gt)
    
    assert len(gt_list) == len(pred_list)
    score = sum([abs(int(pred_list[i])-int(gt_list[i])) for i in range(len(pred_list))]) / len(pred_results)
    return score


def evaluate_mutate(model_name, log_dir="logs/mutated_list"):
    """
    Evaluate the mutated task
    Metric: Accuracy
    """
    if not os.path.exists(log_dir):
        raise ValueError(f"logs_dir {log_dir} is not correct")
    gt_list = []
    pred_list = []
    pred_results = read_json(f"{log_dir}/{model_name}.json")
    for pred in pred_results:
        gt = False
        pred = parse_response(pred['json_results'])
        pred = pred.lower()

        # if failed to parse. we treat the pred as False and gt as True
        # this is a special case for the model that cannot parse the response
        if pred == "":
            pred_list.append(False)
            gt_list.append(True)
            continue
        
        pred = False if "unsame" in pred else True
        pred_list.append(pred)
        gt_list.append(gt)
    
    assert len(gt_list) == len(pred_list)
    score = sum(1 for x,y in zip(pred_list, gt_list) if x == y) / len(pred_list)
    return score


def evaluate_permutate(model_name, log_dir="logs/permutated_list"):
    """
    Evaluate the permutated task
    Metric: Accuracy
    """
    if not os.path.exists(log_dir):
        raise ValueError(f"logs_dir {log_dir} is not correct")
    gt_list = []
    pred_list = []
    pred_results = read_json(f"{log_dir}/{model_name}.json")
    for pred in pred_results:
        gt = True
        pred = parse_response(pred['json_results'])
        pred = pred.lower()

        # if failed to parse. we treat the pred as False and gt as True
        # this is a special case for the model that cannot parse the response
        if pred == "":
            pred_list.append(False)
            gt_list.append(True)
            continue

        pred = False if "unsame" in pred else True
        pred_list.append(pred)
        gt_list.append(gt)
    
    assert len(gt_list) == len(pred_list)
    score = sum(1 for x,y in zip(pred_list, gt_list) if x == y) / len(pred_list)
    return score


def evaluate_murcko(model_name, log_dir="logs/murcko_scaffold"):
    """
    Evaluate the murcko scaffold task
    Metric: Tanimoto molecule similarity
    """
    if not os.path.exists(log_dir):
        raise ValueError(f"logs_dir {log_dir} is not correct")
    gt_list = []
    pred_list = []
    pred_results = read_json(f"{log_dir}/{model_name}.json")
    for pred in pred_results:
        gt = pred[gt_key_dict['murcko']]
        pred = parse_response(pred['json_results'])
        pred_list.append(pred)
        gt_list.append(gt)
    
    same_count, mean_scaffold_score = scaffold_consistency(pred_list, gt_list)
    return mean_scaffold_score


def evaluate_ring_system(model_name, log_dir="logs/ring_system_scaffold"):
    """
    Evaluate the ring system task
    Metric: Accuracy
    """
    if not os.path.exists(log_dir):
        raise ValueError(f"logs_dir {log_dir} is not correct")
    gt_list = []
    pred_list = []
    pred_results = read_json(f"{log_dir}/{model_name}.json")
    for pred in pred_results:
        gt = pred['gt']
        gt = gt.lower() # yes/no
        pred = parse_response(pred['json_results'])
        
        # If failed to parse, we treat the pred as False and gt as True
        # this is a special case for the model that cannot parse the response
        if pred == "":
            pred_list.append(False)
            gt_list.append(True)
            continue
        
        # detailed parsing to determine the pred as yes/no
        pred = pred.lower()
        if "yes" in pred:
            pred = True
        elif "no" in pred:
            pred = False
        else:
            pred = False
        pred_list.append(pred)
        gt_list.append(gt)
    
    assert len(gt_list) == len(pred_list)
    score = sum(1 for x,y in zip(pred_list, gt_list) if x == y) / len(pred_list)
    return score


print('fg_samples: ', evaluate_fg_sample('ether0', '../mol_und_edit/logs/fg_samples'))
print('ring_count: ', evaluate_ring_count('ether0', '../mol_und_edit/logs/ring_count'))

mutated_score = evaluate_mutate('ether0', '../mol_und_edit/logs/mutated_list')
permutated_score = evaluate_permutate('ether0', '../mol_und_edit/logs/permutated_list')
eq_score = (mutated_score + permutated_score) / 2
print('eq_score: ', eq_score)

print('murcko: ', evaluate_murcko('ether0', '../mol_und_edit/logs/murcko_scaffold'))
print('ring_system: ', evaluate_ring_system('ether0', '../mol_und_edit/logs/ring_system_scaffold'))

fg_samples:  0.6
ring_count:  1.2
eq_score:  0.45
murcko:  0.16892521367521368
ring_system:  0.0


# Mol-Edit

In [23]:
from eval.eval_moledit import check_edit_add_valid, check_edit_del_valid, check_edit_sub_valid
def evaluate_moledit_score(task, model_name, log_dir="logs/mol_edit"):
    """
    Evaluate the moledit score
    metric: accuracy
    Args:
        task: the task name
        model_name: the model name
        log_dir: the directory of the logs
    Returns:
        the score
    """
    assert task in ['add', 'delete', 'sub'], "task must be one of add, delete, sub"
    if not os.path.exists(os.path.join(log_dir, task)):
        raise ValueError(f"logs_dir {log_dir} is not correct")
    samples = read_json(f"{log_dir}/{task}/{model_name}.json")

    invalid_number = 0
    pred_list, src_list = list(), list()
    group_a_list, group_b_list = list(), list()
    for sample in samples:
        if task == 'add':
            group_a = sample['added_group']
            group_b = None
        elif task == 'delete':
            group_a = sample['removed_group']
            group_b = None
        elif task == 'sub':
            group_a = sample['added_group']
            group_b = sample['removed_group']
        
        src = sample['molecule']
        # try to extract predicted-smiles
        json_results = sample['json_results']
        extracted = parse_response(json_results)
        if extracted == "":
            invalid_number += 1
            continue
        
        pred_list.append(extracted)
        src_list.append(src)
        group_a_list.append(group_a)
        if group_b is not None:
            group_b_list.append(group_b)
    
    assert len(src_list) == len(pred_list) == len(group_a_list)
    # calculate the score
    correct_num = 0
    for i in range(len(src_list)):
        if task in ['add']:
            if check_edit_add_valid(src=src_list[i], tgt=pred_list[i], group=group_a_list[i]):
                correct_num += 1
        if task in ['delete']:
            if check_edit_del_valid(src=src_list[i], tgt=pred_list[i], group=group_a_list[i]):
                correct_num += 1
        if task == 'sub':
            if check_edit_sub_valid(src=src_list[i], tgt=pred_list[i], remove_group=group_b_list[i], add_group=group_a_list[i]):
                correct_num += 1
    print(f"invalid number: {invalid_number}")
    return correct_num / len(src_list)

In [29]:
print("#########################")
print("evaluate add")
add_score = evaluate_moledit_score("add", MODEL_NAME, log_dir="../mol_und_edit/logs")
print("add_score: ", add_score)
print("#########################\n")

print("#########################")
print("evaluate delete")
delete_score = evaluate_moledit_score("delete", MODEL_NAME, log_dir="../mol_und_edit/logs")
print("delete_score: ", delete_score)
print("#########################\n")

print("#########################")
print("evaluate sub")
sub_score = evaluate_moledit_score("sub", MODEL_NAME, log_dir="../mol_und_edit/logs")
print("sub_score: ", sub_score)
print("#########################\n")

weighted_score = add_score * 0.2 + delete_score * 0.2 + sub_score * 0.6
print("weighted_score: ", weighted_score)

#########################
evaluate add
添加amide失败: 目标分子中amide数量为1, 源分子中amide数量为1
invalid number: 2
add_score:  0.9444444444444444
#########################

#########################
evaluate delete
invalid number: 3
delete_score:  0.7647058823529411
#########################

#########################
evaluate sub
invalid number: 2
sub_score:  0.7758620689655172
#########################

weighted_score:  0.8073473067387875


# Mol-opt

In [39]:
import re
from typing import Optional, Literal
import json

def tranform_str_to_json(str_input):
    ## 假如LLM输出的是类似json的字符串, 我需要设定一个逻辑, 把字符串重新转换成json
    ## o1-mini的感觉, 是要移除字符串里面的\n，并且把所有的\"都改成 "
    if "```json\n" in str_input:
        str_input = str_input.split("```json\n")[1]
        str_input = str_input.replace("\n```", '')
    
    unescaped_str = str_input.replace('\n    ', '').replace('\n', '').replace('\"', '"')
    try:
        json_obj = json.loads(unescaped_str)
        return json_obj
    except json.JSONDecodeError as e:
        return None
    

def _validate_format(value: str, format: str) -> Optional[str]:
    """
    验证值的类型是否符合指定的 format。

    Args:
        value (str): 提取的原始值（字符串形式）。
        format (str): 期望的类型（"str"、"int"、"float"、"bool"）。

    Returns:
        Optional[str]: 转换后的值（字符串形式），如果类型不匹配则返回 None。
    """
    try:
        if format == "int":
            int(value)
            return value
        elif format == "float":
            float(value)
            return value
        elif format == "bool":
            if value.lower() in ("true", "false"):
                return value.lower()
            return None
        elif format == "str":
            return value
        return None
    except (ValueError, TypeError):
        return None


def parse_raw_response(
    raw_response: str,
    field: str,
    format: Literal["str", "int", "float", "bool"] = "str"
) -> Optional[str]:
    """
    从 JSON 格式字符串中提取指定字段的值，忽略开头的 <think>...</think> 部分，
    并根据 format 参数验证值的类型。

    Args:
        raw_response (str): 包含 JSON 数据的字符串，可能以 <think>...</think> 开头。
        field (str): 要提取的字段名（如 "count"）。
        format (Literal["str", "int", "float", "bool"]): 期望的返回值类型，默认为 "str"。

    Returns:
        Optional[str]: 字段的值（字符串形式），如果未找到或类型不匹配则返回 None。
    """
    # 1. 移除 <think>...</think> 部分（如果有）
    cleaned_response = re.sub(r'<think>.*?</think>', '', raw_response, flags=re.DOTALL)

    # 2. 尝试匹配带引号的字符串值（如 "count": "2"）
    quoted_pattern = rf'"{field}":\s*"([^"]+)"'
    match = re.search(quoted_pattern, cleaned_response)
    if match:
        value = match.group(1)
        if format == "str":
            return value
        return _validate_format(value, format)

    # 3. 尝试匹配不带引号的值（如 "count": 2, "active": true）
    unquoted_pattern = rf'"{field}":\s*([^,}}\s]+)'
    match = re.search(unquoted_pattern, cleaned_response)
    if match:
        value = match.group(1).strip()
        return _validate_format(value, format)

    # 4. 未找到字段
    return None

In [40]:
import sys
sys.path.append("..")
from eval.eval_metric import mol_opt_evaluater
from eval.eval_metric import calculate_solubility, compute_statistics

prop_dict = dict(
    logp="logp", 
    solubility="solubility", 
    qed="qed", 
    drd="drd2", 
    gsk="gsk3b", 
    jnk="jnk3"
)


def evaluate_molopt_score(task:str, model_name:str, log_dir:str="logs/mol_opt"):
    if not os.path.exists(os.path.join(log_dir, task)):
        raise ValueError(f"logs_dir {log_dir} is not correct")
    prop_evaluater = mol_opt_evaluater(prop=prop_dict[task])
    samples = read_json(f"{log_dir}/{task}/{model_name}.json")
    invalid_number = 0
    pred_list, src_list = list(), list()
    for sample in samples:
        src = sample['src_smiles']

        # try to extract predicted-smiles
        json_results = sample['json_results']
        extracted = parse_response(json_results)
        if extracted == "":
            # print(f"cannot parse from sample {sample['id']}")
            invalid_number += 1
            continue
        
        pred_list.append(extracted)
        src_list.append(src)
    
    assert len(src_list) == len(pred_list)
    # calculate the score
    improve_scores = prop_evaluater.property_improvement(
        src_mol_list=src_list, 
        tgt_mol_list=pred_list, 
        total_num=len(samples)
    )

    print(f"invalid number: {invalid_number}")
    print(f"improvement mean: {round(improve_scores.get('mean'), 2)}")
    print(f"improvement std: {improve_scores.get('variance')}")
    print(f"success_rate: {round(improve_scores.get('success_rate'), 2)}")

def evaluate_solubility_score(model_name, log_dir="logs/mol_opt/"):
    if not os.path.exists(os.path.join(log_dir, "solubility")):
        raise ValueError(f"logs_dir {log_dir} is not correct")
    samples = read_json(f"{log_dir}/solubility/{model_name}.json")
    src_list = []
    tgt_list = []
    key = 'Final Target Molecule'
    invalid_number = 0
    for x in samples:
        json_results = x['json_results']
        src_smiles = x['src_smiles']
        if type(json_results) is dict and key in json_results:
            extracted = json_results[key]
        elif type(json_results) is str:
            # 如果直接是个SMILES字符串, 则直接添加
            if is_valid_smiles(json_results, strict=False):
                extracted = json_results
            else:
                extracted = parse_raw_response(json_results, field=key)
                if extracted == "" or extracted is None:
                    invalid_number += 1
                    continue
        else:
            # print(f"cannot parse from sample {x['id']}")
            invalid_number += 1
            continue
        src_list.append(src_smiles)
        tgt_list.append(extracted)


    gains = []
    for s,t in zip(src_list, tgt_list):
        if not is_valid_smiles(s) or not is_valid_smiles(t):
            gains.append(0.)
        else:
            gains.append(calculate_solubility(t) - calculate_solubility(s))

    stats = compute_statistics(gains, 'solubility', skew=True)
    print(f"invalid number: {invalid_number}")
    print(f"improvement mean: {round(stats.get('mean'), 2)}")
    print(f"improvement std: {stats.get('variance')}")
    print(f"success_rate: {round(stats.get('success_rate'), 2)}")
    

In [44]:
print("\nEvaluate logp")
print(evaluate_molopt_score("logp", MODEL_NAME, log_dir="../mol_opt/logs"))

print("\nEvaluate solubility")
print(evaluate_solubility_score(MODEL_NAME, log_dir="../mol_opt/logs"))

print("\nEvaluate qed")
print(evaluate_molopt_score("qed", MODEL_NAME, log_dir="../mol_opt/logs"))

print("\nEvaluate drd")
print(evaluate_molopt_score("drd", MODEL_NAME, log_dir="../mol_opt/logs"))

print("\nEvaluate gsk")
print(evaluate_molopt_score("gsk", MODEL_NAME, log_dir="../mol_opt/logs"))


Evaluate logp


Found local copy...


invalid number: 1
improvement mean: 0.37
improvement std: 1.159692890137555
success_rate: 0.68
None

Evaluate solubility
invalid number: 93
improvement mean: 0.0
improvement std: 0.0
success_rate: 0.0
None

Evaluate qed
invalid number: 0
improvement mean: -0.37
improvement std: 0.027038276905880027
success_rate: 0.05
None

Evaluate drd


Found local copy...


invalid number: 0
improvement mean: -0.14
improvement std: 0.025330559277142978
success_rate: 0.04
None

Evaluate gsk
invalid number: 0
improvement mean: 0.0
improvement std: 0.0
success_rate: 0.0
None


In [None]:
# NOTE: for jnk evaluation, the oracle pkl is old. So please change to some old version.

print("\nEvaluate jnk")
print(evaluate_molopt_score("jnk", MODEL_NAME, log_dir="../mol_opt/logs"))