从数据集中提取 Scrambled expression 并转为 前缀表达式

In [1]:
# prefix_converter.py
import re
import json
from sympy import sympify
from sympy.parsing.sympy_parser import parse_expr

def extract_scrambled_expressions(file_path):
    expressions = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.startswith("Scrambled expression :"):
                expr = line.strip().split(":", 1)[1].strip()
                expressions.append(expr)
    return expressions

def extract_polylog_terms(expr_str):
    """
    提取 polylog(2, ...) 项，并带上前面的系数。
    返回结构为：[(系数, 参数表达式), ...]
    """
    pattern = r'([+\-]?\s*\d*\.?\d*)\*?polylog\(2,\s*(.*?)\)'
    matches = re.findall(pattern, expr_str)
    terms = []
    for coeff, arg in matches:
        coeff = coeff.replace(" ", "")
        coeff = coeff if coeff not in ['', '+', '-'] else coeff + '1'
        terms.append((coeff, arg))
    return terms

def build_prefix(tokens):
    return ['add'] + tokens

def generate_prefix(expr_str):
    terms = extract_polylog_terms(expr_str)
    prefix_expr = []
    for coeff, arg in terms:
        # 前缀形式为：mul, 系数, polylog, 2, 参数
        prefix_expr.extend(['mul', coeff, 'polylog', '2', arg])
    return build_prefix(prefix_expr)

if __name__ == '__main__':
    expressions = extract_scrambled_expressions("rl_data_dilogs/test_data.txt")
    all_prefixes = [generate_prefix(expr) for expr in expressions]

    # 输出前 3 个示例
    for i, prefix in enumerate(all_prefixes[:3]):
        print(f"Example {i+1}:")
        print(prefix)
        print()

    # 保存为 JSON 文件
    with open("prefix_expressions.json", "w", encoding="utf-8") as f:
        json.dump(all_prefixes, f, indent=2, ensure_ascii=False)


Example 1:
['add', 'mul', '-6', 'polylog', '2', 'x', 'mul', '-2', 'polylog', '2', '(-2*x**2 - 2*x + 1', 'mul', '-2', 'polylog', '2', '(2*x**2 + 2*x - 1', 'mul', '+6', 'polylog', '2', 'x - 1', 'mul', '-3', 'polylog', '2', 'x**2 - 2*x + 1']

Example 2:
['add', 'mul', '-3', 'polylog', '2', '2*x', 'mul', '-3', 'polylog', '2', '1/(2*x - 1', 'mul', '-7', 'polylog', '2', '-2/(x**2 - 2', 'mul', '-7', 'polylog', '2', '2/(x**2 - 2', 'mul', '+3', 'polylog', '2', '1/(4*x**2 - 4*x + 1', 'mul', '-7', 'polylog', '2', 'x**4/4 - x**2 + 1']

Example 3:
['add', 'mul', '-1', 'polylog', '2', '(x**4 + 4*x**2 + 4', 'mul', '-1', 'polylog', '2', '-x/(x**2 + 2', 'mul', '-1', 'polylog', '2', 'x/(x**2 + 2', 'mul', '-2', 'polylog', '2', '-2*x**2 - 2*x + 1', 'mul', '+2', 'polylog', '2', '1/(2*x**2 + 2*x - 1', 'mul', '+1', 'polylog', '2', '4*x**4 + 8*x**3 - 4*x + 1']



将前缀表达式转为 One-Hot 编码

In [2]:
import numpy as np
import json
from sympy import sympify
from sympy.core import Symbol, Function, Add, Mul, Pow

# 词表：包括负常数 -10 到 9
# 定义词表
vocab = ['add', 'mul', 'polylog', '+', '*', '**', 'x'] + [str(i) for i in range(-10, 10)]

# token → id 映射
token_to_id = {token: idx for idx, token in enumerate(vocab)}

# id → token 映射（反向映射）
id_to_token = {idx: token for token, idx in token_to_id.items()}

vocab_size = len(vocab)
L_max = 64  # 最大 token 长度

def flatten_expr(expr):
    """递归地将 sympy 表达式转换为 prefix token 列表"""
    if expr.is_Symbol:
        return [str(expr)]
    elif expr.is_Number:
        return [str(expr)]
    elif isinstance(expr, Pow):
        base, exp = expr.args
        return ['**'] + flatten_expr(base) + flatten_expr(exp)
    elif isinstance(expr, Mul):
        def binary_mul(args):
            if len(args) == 1:
                return flatten_expr(args[0])
            else:
                return ['*'] + flatten_expr(args[0]) + binary_mul(args[1:])
        return binary_mul(list(expr.args))
    elif isinstance(expr, Add):
        tokens = ['+']
        for arg in expr.args:
            tokens += flatten_expr(arg)
        return tokens
    elif expr.func.__name__ == 'polylog':
        n, arg = expr.args
        return ['polylog', str(n)] + flatten_expr(arg)
    else:
        return [str(expr)]

def tokenize_prefix(prefix_expr):
    """使用 sympy 将 prefix 中的参数进一步拆分成基本操作 token"""
    tokens = []
    for token in prefix_expr:
        try:
            parsed = sympify(token)
            tokens += flatten_expr(parsed)
        except Exception:
            tokens.append(token)
    return tokens

def prefix_to_onehot(prefix_expr):
    tokens = tokenize_prefix(prefix_expr)
    onehots = np.zeros((L_max, vocab_size), dtype=np.float32)
    for i, token in enumerate(tokens[:L_max]):
        if token in token_to_id:
            onehots[i, token_to_id[token]] = 1.0
    return onehots

def encode_all_from_json(json_path, output_path):
    with open(json_path, 'r', encoding='utf-8') as f:
        prefix_data = json.load(f)
    all_onehots = np.stack([prefix_to_onehot(expr) for expr in prefix_data])
    np.save(output_path, all_onehots)
    print(f"Saved {len(all_onehots)} one-hot encoded expressions to {output_path}")

if __name__ == '__main__':
    # 示例用法：从 JSON 加载并编码所有表达式
    encode_all_from_json("prefix_expressions.json", "onehot_expressions.npy")


Saved 1481 one-hot encoded expressions to onehot_expressions.npy


环境PolylogSimplifyEnv

In [3]:
import gym
from gym import spaces
import numpy as np
from sympy import sympify, simplify, Function, Add
#from onehot_encoder import prefix_to_onehot, tokenize_prefix
#from prefix_converter import generate_prefix
import random

class PolylogSimplifyEnv(gym.Env):
    def __init__(self, onehot_data, token_to_id, id_to_token):
        super().__init__()
        self.data = onehot_data  # shape: [N, L_max, vocab_size]
        self.token_to_id = token_to_id
        self.id_to_token = id_to_token

        self.num_actions = 4
        self.max_steps = 50

        self.observation_space = spaces.Box(low=0.0, high=1.0, shape=self.data.shape[1:], dtype=np.float32)
        self.action_space = spaces.Discrete(self.num_actions)

    def reset(self):
        self.current_step = 0
        self.idx = np.random.randint(0, len(self.data))
        self.state = self.data[self.idx].copy()
        self.expr = self.token_to_expr(self.state)
        self.best_dilog_count = self.count_dilogs(self.expr)
        return self.state

    def step(self, action):
        self.current_step += 1

        new_expr = self.apply_action(self.expr, action)
        new_dilog_count = self.count_dilogs(new_expr)

        #reward = 0
        #if new_dilog_count < self.best_dilog_count:
        #    reward = 1
        #    self.best_dilog_count = new_dilog_count

        reward = 0
        if new_dilog_count < self.best_dilog_count:
            reward = 1
            self.best_dilog_count = new_dilog_count
        if action == self.prev_action and new_dilog_count >= self.prev_dilog_count:
            reward -= 0.1  # Cyclic penalty


        self.expr = new_expr
        self.state = self.expr_to_onehot(new_expr)

        done = self.current_step >= self.max_steps or self.best_dilog_count == 0
        return self.state, reward, done, {}

    def apply_action(self, expr, action):
        # 作用于第一项
        # 记得使用函数之后进行简化操作
        try:
            if action == 0:  # reflection
                return simplify(expr.replace(
                    lambda f: f.func.__name__ == 'polylog',
                    lambda f: -Function('polylog')(2, 1 - f.args[1]) 
                              + (np.pi**2)/6 
                              - sympify("log(%s)*log(1-%s)" % (f.args[1], f.args[1])),
                    1))
            elif action == 1:  # inversion
                return simplify(expr.replace(lambda f: f.func.__name__ == 'polylog',
                                    lambda f: -Function('polylog')(2, 1/f.args[1]) - (np.pi**2)/6 - 0.5 * (sympify("log(-%s)" % f.args[1]))**2, 1))
            elif action == 2:  # duplication
                return simplify(expr.replace(lambda f: f.func.__name__ == 'polylog' and f.args[1].is_Pow and f.args[1].exp == 2,
                                    lambda f: 2 * Function('polylog')(2, f.args[1].base) + 2 * Function('polylog')(2, -f.args[1].base), 1))
            elif action == 3:  # cyclic (polylog 项随机重排)
                return simplify(self.shuffle_polylog_terms(expr))
        except Exception:
            return simplify(expr)

    def shuffle_polylog_terms(self, expr):
        try:
            if not isinstance(expr, Add):
                return expr
            terms = list(expr.args)
            polylog_terms = [t for t in terms if t.has(Function('polylog'))]
            other_terms = [t for t in terms if not t.has(Function('polylog'))]
            random.shuffle(polylog_terms)
            new_expr = Add(*polylog_terms, *other_terms)
            return new_expr
        except Exception:
            return expr

    def token_to_expr(self, onehot):
        tokens = []
        for i in range(onehot.shape[0]):
            idx = np.argmax(onehot[i])
            token = self.id_to_token.get(idx)
            if token: tokens.append(token)
        try:
            expr = sympify(" ".join(tokens))
        except Exception:
            expr = sympify("0")
        return expr

    def count_dilogs(self, expr):
        try:
            return len(expr.atoms(Function).intersection({Function('polylog')}))
        except Exception:
            return 0

    def expr_to_onehot(self, expr):
        try:
            prefix = generate_prefix(str(expr))
            return prefix_to_onehot(prefix)
        except Exception:
            return self.state


神经网络

In [4]:
# ===== Cell 4 —— networks.py =====
import torch
import torch.nn as nn
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.policies import ActorCriticPolicy


class IdentityExtractor(BaseFeaturesExtractor):
    """
    观测已是展平的一维向量；这里只做透传。
    """
    def __init__(self, observation_space):
        super().__init__(observation_space,
                         features_dim=observation_space.shape[0])
        self.forward = nn.Identity()


class CustomActorCriticPolicy(ActorCriticPolicy):
    """
    简洁版 Actor–Critic：
      • 共享前端 → π / V 分支  
      • 网络宽度放到构造参数 `net_arch` 控制
    """
    def __init__(
        self,
        observation_space,
        action_space,
        lr_schedule,
        net_arch=(256, 128, 64),
        activation_fn=nn.ReLU,
        **kwargs,
    ):
        self._net_arch = net_arch
        self._activation_fn = activation_fn
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            features_extractor_class=IdentityExtractor,
            **kwargs,
        )

def _build_mlp_extractor(self):
    """
    自定义共享网络 + π/V 分支，并向父类暴露
    `mlp_extractor`, `latent_dim_pi`, `latent_dim_vf` 三个关键属性。
    """
    feat_dim = self.features_dim
    layers = []
    last = feat_dim
    # 共享前端
    for h in self._net_arch[:-1]:
        layers += [nn.Linear(last, h), self._activation_fn()]
        last = h
    self.shared_net = nn.Sequential(*layers)

    # π 网络
    self.policy_net = nn.Sequential(
        nn.Linear(last, self._net_arch[-1]),
        self._activation_fn(),
        nn.Linear(self._net_arch[-1], self.action_space.n),
    )
    # V 网络
    self.value_net = nn.Sequential(
        nn.Linear(last, self._net_arch[-1]),
        self._activation_fn(),
        nn.Linear(self._net_arch[-1], 1),
    )

    # ---------- 关键补丁：向 SB3 暴露 extractor ----------
    # 1) 让 TRPO 找到 `self.mlp_extractor`
    self.mlp_extractor = self.shared_net
    # 2) 告诉父类 latent 向量长度（π / V 通常一样）
    self.latent_dim_pi = last      # = 最后一层输出维度
    self.latent_dim_vf = last
    # ------------------------------------------------------

    # 无需 return，父类会在 __init__ 里调用本函数


TRPO的RL模型

In [5]:
# ===== Cell 5 —— train_trpo.py =====
import os
import torch
import numpy as np
from sb3_contrib import TRPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.callbacks import (
    ProgressBarCallback,
    EvalCallback,
)
# ------- 数据与环境 -------
ONEHOT_PATH = "onehot_expressions.npy"
N_ENVS = 4
TOTAL_STEPS = 500_000

onehot_data = np.load(ONEHOT_PATH)


def make_env():
    return PolylogSimplifyEnv(onehot_data, token_to_id, id_to_token)


train_env = DummyVecEnv([make_env for _ in range(N_ENVS)])
eval_env = DummyVecEnv([make_env])

# ------- Policy & 模型 -------
policy_kwargs = dict(
    net_arch=(256, 128, 64),
    # features_extractor_class 已在自定义策略中固定
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = TRPO(
    policy=CustomActorCriticPolicy,
    env=train_env,
    verbose=1,
    tensorboard_log="./trpo_logs",
    policy_kwargs=policy_kwargs,
    device=device,
)

# ------- 回调（进度条 + 定期评估） -------
callbacks = [
    ProgressBarCallback(),
    EvalCallback(
        eval_env,
        eval_freq=10_000,
        best_model_save_path="./best_model",
        deterministic=True,
        verbose=0,
    ),
]

# ------- 训练 -------
model.learn(total_timesteps=TOTAL_STEPS, callback=callbacks)

# ------- 保存 -------
os.makedirs("models", exist_ok=True)
model.save("models/trpo_polylog_model")




Using cpu device
Logging to ./trpo_logs\TRPO_9


NotImplementedError: Module [IdentityExtractor] is missing the required "forward" function