# Symbol definitions
# NOTE(review): x1 ... xn are symbolic placeholders that are not defined in the
# visible file; they presumably stand for the real fragment objects - confirm.
Fragments = {x1, x2, ..., xn}  # set of fragments
# α, β and γ weight the geometric, text and material potentials respectively
# when MRFBuilder.build_graph combines them into a single edge weight.
α, β, γ = 0.5, 0.3, 0.2       # multimodal weight coefficients
K = 10                         # number of Top-K candidates

class MRFBuilder:
    """Build the fragment graph and shortlist candidate matches.

    Both methods are static: they use no instance state, and they are
    invoked through an instance (``self.mrf_builder.build_graph(...)``),
    which only works when they are not bound to ``self``.
    """

    @staticmethod
    def build_graph(fragments):
        """Return a fully connected graph over ``fragments``.

        Every fragment becomes a node carrying its geometric, text and
        material features; every unordered pair of fragments becomes an
        edge whose ``weight`` is the α/β/γ-weighted sum of the three
        pairwise potentials.

        Args:
            fragments: Iterable of fragment objects.

        Returns:
            The populated ``Graph``.
        """
        from itertools import combinations

        # Materialise once so that a one-shot iterable is not exhausted by
        # the node loop before the edge loop runs.
        fragments = list(fragments)

        G = Graph()

        # Node initialisation. The text and material features are cached by
        # position because the edge loop needs them again; they are handed
        # to the graph, not stored on the fragment objects themselves.
        text_feats = []
        material_feats = []
        for xi in fragments:
            geo_feat = compute_geo_feature(xi)
            text_feat = BERT.encode(text_extract(xi))
            material_feat = compute_texture(xi)
            G.add_node(xi,
                       geo_feat=geo_feat,
                       text_feat=text_feat,
                       material_feat=material_feat)
            text_feats.append(text_feat)
            material_feats.append(material_feat)

        # Edge computation over every unordered pair of fragments.
        for (i, xi), (j, xj) in combinations(enumerate(fragments), 2):
            # Multimodal potentials. Each term must grow with similarity,
            # so the two distance-valued terms are mapped through exp(-d).
            phi_geo = exp(-ICP_registration(xi, xj))
            phi_text = cosine_sim(text_feats[i], text_feats[j])
            phi_material = exp(-chi2_distance(material_feats[i],
                                              material_feats[j]))

            weight = α * phi_geo + β * phi_text + γ * phi_material
            G.add_edge(xi, xj, weight=weight)

        return G

    @staticmethod
    def candidate_selection(G, k=None):
        """Return the edges of ``G`` with the highest marginals.

        Args:
            G: Graph produced by ``build_graph``.
            k: Number of edges to keep. ``None`` (the default) uses the
                module-level ``K``.

        Returns:
            A list of at most ``k`` edges, sorted by descending marginal.
        """
        limit = K if k is None else k

        # Belief propagation yields the marginals, looked up per edge below.
        bp = BeliefPropagation(G)
        marginals = bp.run()

        # Keep the Top-K edges.
        ranked = sorted(G.edges,
                        key=lambda e: marginals[e],
                        reverse=True)
        return ranked[:limit]

class RLAgent:
    """Actor-critic agent that picks an edge and a rotation adjustment.

    NOTE(review): ``self.optimize`` is called by ``update`` but is not
    defined in the visible part of this class - confirm where it lives.
    """

    def __init__(self, use_renderer=True):
        """Create the encoder, the actor/critic networks and the renderer.

        Args:
            use_renderer: When false, no image renderer is created and
                ``state_representation`` skips the rendered-image features.
                Defaults to true, i.e. the renderer is always created.
        """
        self.gnn = GNNEncoder()  # shared encoder
        self.actor = ActorNetwork()
        self.critic = CriticNetwork()
        self.renderer = ImageRenderer() if use_renderer else None

    def state_representation(self, G, t):
        """Encode the current graph into a global state vector.

        Args:
            G: Current fragment graph.
            t: Step index. Accepted for interface compatibility; it is not
                used by the current implementation.

        Returns:
            The GNN read-out, concatenated with CNN features of the
            rendered assembly when a renderer is configured.
        """
        # GNN encoding
        node_emb = self.gnn(G.nodes, G.edges)
        global_state = self.gnn.readout(node_emb)

        # Dynamic rendering features (optional)
        if self.renderer:
            canvas = self.renderer.render(G.assembled)
            img_feat = CNN(canvas)
            global_state = concat(global_state, img_feat)

        return global_state

    def get_action(self, state):
        """Sample a hybrid action ``(selected_edge, rotate_delta)``.

        NOTE(review): ``selected_edge`` is whatever ``Categorical.sample``
        returns, while callers index ``G.edges`` with it and read its
        ``.nodes`` - confirm that it is mapped to an edge somewhere.
        """
        # Discrete action selection
        match_probs = self.actor.discrete_head(state)
        selected_edge = Categorical(match_probs).sample()

        # Continuous rotation adjustment
        rotate_delta = self.actor.continuous_head(state)

        return (selected_edge, rotate_delta)

    def update(self, trajectory):
        """Update the policy and the critic from one trajectory with PPO.

        Args:
            trajectory: Sequence of transitions, passed to
                ``process_trajectory`` to obtain states, actions and rewards.

        NOTE(review): ``ppo_epochs`` is read as a module-level setting that
        is not defined in the visible file - confirm.
        """
        states, actions, rewards = process_trajectory(trajectory)

        values = self.critic(states)
        advantages = compute_gae(rewards, values)

        # Critic regression targets: since advantage = return - value, the
        # return estimate is advantage + value, taken before any update.
        # NOTE(review): in an autograd framework `returns`, `advantages` and
        # `old_probs` must be treated as constants (detached) - confirm.
        returns = advantages + values

        # Action probabilities under the policy before any gradient step;
        # the PPO ratio is measured against this snapshot.
        old_probs = evaluate_policy(states, actions)

        # Policy-gradient updates
        for _ in range(ppo_epochs):
            new_probs = evaluate_policy(states, actions)
            # Re-evaluate the critic every epoch so that its loss reflects
            # the parameters being optimised, not the stale pre-update values.
            new_values = self.critic(states)

            actor_loss = ppo_loss(old_probs, new_probs, advantages)
            critic_loss = mse_loss(new_values, returns)

            self.optimize(actor_loss + critic_loss)

class TrainingFramework:
    """Training loop tying the MRF builder to the RL agent."""

    def __init__(self):
        self.mrf_builder = MRFBuilder()
        self.agent = RLAgent()

    def train(self, epochs):
        """Run ``epochs`` assembly episodes, updating the agent after each.

        Each episode rebuilds the graph from the module-level ``Fragments``,
        rolls the agent out until ``G.is_complete()`` and passes the
        collected ``(state, action, reward)`` transitions to
        ``self.agent.update``.

        NOTE(review): ``curriculum_stage`` is read as a module-level setting
        that is not defined in the visible file - confirm.
        """
        for epoch in range(epochs):
            G = self.mrf_builder.build_graph(Fragments)
            # NOTE(review): the shortlist is computed but not yet consumed
            # by the agent's action selection.
            candidates = self.mrf_builder.candidate_selection(G)

            t = 0  # step counter within the episode
            state = self.agent.state_representation(G, t)
            trajectory = []

            while not G.is_complete():
                # Decision step
                action = self.agent.get_action(state)

                # Execute the assembly action
                execute_action(G, action)

                # Collect the reward
                reward = self.compute_reward(G, action)

                # Store the transition
                trajectory.append((state, action, reward))

                # Advance the state
                t += 1
                state = self.agent.state_representation(G, t)

            # Curriculum-learning bonus. It is folded into the last stored
            # transition: adding it to a local variable after the loop would
            # never reach the agent, and there is nothing to amend when the
            # episode produced no transitions.
            if trajectory and epoch < curriculum_stage:
                last_state, last_action, last_reward = trajectory[-1]
                trajectory[-1] = (last_state,
                                  last_action,
                                  last_reward + local_assembly_bonus(G))

            # Policy update
            self.agent.update(trajectory)

    def compute_reward(self, G, action):
        """Return the reward for having applied ``action`` to ``G``.

        The reward is the weight of the chosen edge plus ``λ`` times the
        semantic coherence of ``G``, with an extra 100 once the assembly is
        complete.

        Args:
            G: Graph after the action has been executed.
            action: ``(edge, rotation)`` tuple as returned by the agent.

        NOTE(review): ``λ`` is read as a module-level setting that is not
        defined in the visible file - confirm.
        """
        # The action is an (edge, rotation) pair; only the edge identifies
        # an entry of G.edges.
        edge, _rotation = action
        immediate_reward = G.edges[edge].weight
        semantic_score = compute_semantic_coherence(G)

        # Terminal bonus, granted only once the assembly is complete.
        completion_bonus = 100 if G.is_complete() else 0
        return completion_bonus + immediate_reward + λ * semantic_score

# Helper functions
def execute_action(G, action):
    """Apply one assembly action and refresh the graph connectivity.

    ``action`` is unpacked as ``(edge, rotation)``. The two endpoints read
    from ``edge.nodes`` are passed to ``assemble`` together with the
    rotation, and then to ``update_graph_connectivity`` together with ``G``.
    """
    chosen_edge, rotation_delta = action
    frag_a, frag_b = chosen_edge.nodes

    assemble(frag_a, frag_b, rotation_delta)
    update_graph_connectivity(G, frag_a, frag_b)