In [1]:
import numpy as np
import random

In [2]:
class GoalSampler:
    def __init__(self, candidate_goals, candidate_goal_scores):
        """
        初始化目标采样器
        
        Args:
            candidate_goals: 期望目标列表，每个目标是由2或3个浮点数组成的tuple
            candidate_goal_scores: 每个目标对应的权重列表
        """
        self.candidate_goals = np.array(candidate_goals)
        self.candidate_goal_scores = np.array(candidate_goal_scores)
        
        # 验证输入
        assert len(self.candidate_goals) == len(self.candidate_goal_scores), \
            "目标数量和权重数量不匹配"
    
    def _filter_valid_goals(self):
        """
        过滤掉权重为inf的目标，返回有效的目标和权重
        
        Returns:
            tuple: (valid_goals, valid_scores, valid_indices)
        """
        # 找到权重不是inf的索引
        valid_mask = ~np.isinf(self.candidate_goal_scores)
        valid_indices = np.where(valid_mask)[0]
        
        # 获取有效的目标和权重
        valid_goals = self.candidate_goals[valid_mask]
        valid_scores = self.candidate_goal_scores[valid_mask]
        
        return valid_goals, valid_scores, valid_indices
    
    def uniform_sampling(self):
        """
        方法1：对于权重不是inf的所有期望目标，按均匀分布采样一个出来
        
        Returns:
            tuple: 采样的目标
        """
        valid_goals, valid_scores, valid_indices = self._filter_valid_goals()
        
        if len(valid_goals) == 0:
            raise ValueError("没有有效的目标可供采样（所有权重都是inf）")
        
        # 均匀分布采样
        sampled_idx = np.random.choice(len(valid_goals))
        sampled_goal = tuple(valid_goals[sampled_idx])
        original_idx = valid_indices[sampled_idx]
        
        print(f"均匀采样结果:")
        print(f"  采样目标: {sampled_goal}")
        print(f"  目标权重: {valid_scores[sampled_idx]}")
        print(f"  原始索引: {original_idx}")
        
        return sampled_goal
    
    def weighted_sampling(self):
        """
        方法2：对于权重不是inf的所有期望目标，按权重采样一个出来
        
        Returns:
            tuple: 采样的目标
        """
        valid_goals, valid_scores, valid_indices = self._filter_valid_goals()
        
        if len(valid_goals) == 0:
            raise ValueError("没有有效的目标可供采样（所有权重都是inf）")
        
        # 检查权重是否都为0
        if np.all(valid_scores == 0):
            raise ValueError("所有有效目标的权重都为0，无法进行加权采样")
        
        # 确保权重为正数（如果有权重为负，取绝对值）
        positive_scores = np.abs(valid_scores)
        
        # 归一化权重
        probabilities = positive_scores / np.sum(positive_scores)
        
        # 按权重采样
        sampled_idx = np.random.choice(len(valid_goals), p=probabilities)
        sampled_goal = tuple(valid_goals[sampled_idx])
        original_idx = valid_indices[sampled_idx]
        
        print(f"加权采样结果:")
        print(f"  采样目标: {sampled_goal}")
        print(f"  目标权重: {valid_scores[sampled_idx]}")
        print(f"  采样概率: {probabilities[sampled_idx]:.4f}")
        print(f"  原始索引: {original_idx}")
        
        return sampled_goal
    
    def get_sampling_info(self):
        """
        获取采样信息统计
        
        Returns:
            dict: 包含采样统计信息的字典
        """
        valid_goals, valid_scores, valid_indices = self._filter_valid_goals()
        
        info = {
            'total_goals': len(self.candidate_goals),
            'valid_goals': len(valid_goals),
            'inf_weight_goals': len(self.candidate_goals) - len(valid_goals),
            'valid_goal_indices': valid_indices.tolist(),
            'valid_weights': valid_scores.tolist()
        }
        
        if len(valid_scores) > 0:
            info.update({
                'min_weight': float(np.min(valid_scores)),
                'max_weight': float(np.max(valid_scores)),
                'mean_weight': float(np.mean(valid_scores)),
                'std_weight': float(np.std(valid_scores))
            })
        
        return info


In [4]:
# 示例数据
candidate_goals = [
    (1.0, 2.0),           # 2D目标
    (3.5, 4.2),      # 3D目标
    (0.5, 1.5),           # 2D目标
    (2.1, 3.3),      # 3D目标
    (4.0, 5.0),           # 2D目标
    (1.2, 2.4),      # 3D目标
]

candidate_goal_scores = [
    0.5,    # 正常权重
    np.inf, # 无限权重（将被过滤）
    1.5,    # 正常权重
    0.8,    # 正常权重
    np.inf, # 无限权重（将被过滤）
    2.0,    # 正常权重
]

# 创建采样器
sampler = GoalSampler(candidate_goals, candidate_goal_scores)

# 显示采样信息
print("=" * 50)
print("采样信息统计:")
info = sampler.get_sampling_info()
for key, value in info.items():
    print(f"  {key}: {value}")
print("=" * 50)

# 测试均匀采样
print("\n1. 均匀分布采样测试:")
for i in range(3):
    print(f"\n第{i+1}次采样:")
    goal = sampler.uniform_sampling()

print("\n" + "=" * 50)

# 测试加权采样
print("\n2. 按权重采样测试:")
for i in range(3):
    print(f"\n第{i+1}次采样:")
    goal = sampler.weighted_sampling()

print("\n" + "=" * 50)

# 测试多次采样的分布
print("\n3. 采样分布测试（各采样100次）:")

# 均匀采样分布
uniform_counts = {}
for _ in range(100):
    goal = sampler.uniform_sampling()
    uniform_counts[goal] = uniform_counts.get(goal, 0) + 1

print("\n均匀采样分布（100次）:")
for goal, count in sorted(uniform_counts.items()):
    print(f"  {goal}: {count}次 ({count}%)")

# 加权采样分布
weighted_counts = {}
for _ in range(100):
    goal = sampler.weighted_sampling()
    weighted_counts[goal] = weighted_counts.get(goal, 0) + 1

print("\n加权采样分布（100次）:")
for goal, count in sorted(weighted_counts.items()):
    print(f"  {goal}: {count}次 ({count}%)")

采样信息统计:
  total_goals: 6
  valid_goals: 4
  inf_weight_goals: 2
  valid_goal_indices: [0, 2, 3, 5]
  valid_weights: [0.5, 1.5, 0.8, 2.0]
  min_weight: 0.5
  max_weight: 2.0
  mean_weight: 1.2
  std_weight: 0.5873670062235365

1. 均匀分布采样测试:

第1次采样:
均匀采样结果:
  采样目标: (np.float64(1.0), np.float64(2.0))
  目标权重: 0.5
  原始索引: 0

第2次采样:
均匀采样结果:
  采样目标: (np.float64(1.0), np.float64(2.0))
  目标权重: 0.5
  原始索引: 0

第3次采样:
均匀采样结果:
  采样目标: (np.float64(1.2), np.float64(2.4))
  目标权重: 2.0
  原始索引: 5


2. 按权重采样测试:

第1次采样:
加权采样结果:
  采样目标: (np.float64(0.5), np.float64(1.5))
  目标权重: 1.5
  采样概率: 0.3125
  原始索引: 2

第2次采样:
加权采样结果:
  采样目标: (np.float64(0.5), np.float64(1.5))
  目标权重: 1.5
  采样概率: 0.3125
  原始索引: 2

第3次采样:
加权采样结果:
  采样目标: (np.float64(0.5), np.float64(1.5))
  目标权重: 1.5
  采样概率: 0.3125
  原始索引: 2


3. 采样分布测试（各采样100次）:
均匀采样结果:
  采样目标: (np.float64(0.5), np.float64(1.5))
  目标权重: 1.5
  原始索引: 2
均匀采样结果:
  采样目标: (np.float64(1.0), np.float64(2.0))
  目标权重: 0.5
  原始索引: 0
均匀采样结果:
  采样目标: (np.float64(1.0), np.float64(2.0