In [1]:
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import random
from collections import Counter

# --- 开始: 您提供的 BalancedIdSampler 代码 ---
# (为了可运行性，我们假设 IdIncludedDataset 在此定义，或者先定义一个mock)

# 首先，我们需要一个 IdIncludedDataset 的 mock/定义
# 根据 BalancedIdSampler，它需要有 flat_sample_map 属性
# flat_sample_map 是一个列表，每个元素是一个字典，包含 'id' 键

class IdIncludedDataset(Dataset):
    """Map-style dataset whose samples remember which source ID they came from.

    A sample's position in the input list is its global index. The per-sample
    ID records are exposed as ``flat_sample_map``, which is what
    ``BalancedIdSampler`` reads to group global indices by ID.
    """

    def __init__(self, flat_sample_map_with_data):
        """
        Args:
            flat_sample_map_with_data (list):
                One dict per sample, e.g.
                ``{'id': 'dataset_A', 'data_index': original_local_index, 'value': ...}``.
                Only the ``'id'`` and ``'value'`` keys are read here; any other
                keys (such as ``'data_index'``) are ignored.
        """
        # Per-sample ID records consumed by the sampler (it groups on 'id').
        self.flat_sample_map = [
            dict(id=entry['id'], original_id=entry['id'])
            for entry in flat_sample_map_with_data
        ]
        # Payloads, aligned index-for-index with flat_sample_map.
        self.data_values = [entry['value'] for entry in flat_sample_map_with_data]

    def __getitem__(self, global_idx):
        """Return the sample at ``global_idx`` plus its ID and index, for verification."""
        payload = self.data_values[global_idx]
        source_id = self.flat_sample_map[global_idx]['id']
        return {"data": payload, "id": source_id, "global_idx": global_idx}

    def __len__(self):
        """Number of samples (one per entry of ``flat_sample_map``)."""
        return len(self.flat_sample_map)

class BalancedIdSampler(Sampler):
    """Sampler that emits the same number of indices for every source ID per epoch.

    Global indices of ``data_source`` are grouped by the ``'id'`` key of each
    entry in ``data_source.flat_sample_map``. Each epoch yields exactly
    ``target_samples_per_id`` indices per ID, so ``len(self)`` is
    ``target_samples_per_id * len(id_list)``. IDs with fewer samples than the
    target are oversampled, and IDs with more are undersampled without
    replacement.
    """

    def __init__(self, data_source: "IdIncludedDataset", common_samples_per_id=None,
                 shuffle_within_id=True, shuffle_all=True):
        """
        Args:
            data_source (IdIncludedDataset): The wrapped dataset. Any object with a
                ``flat_sample_map`` attribute is accepted; that attribute is a
                sequence of dicts that each carry an ``'id'`` key.
            common_samples_per_id (int, optional): Number of indices emitted per
                ID in one epoch. ``None`` means "size of the largest ID", so
                every smaller ID is oversampled up to it. ``0`` yields nothing.
            shuffle_within_id (bool): Pick each ID's indices randomly. When
                False, selection is deterministic: the ID's indices are taken in
                order, wrapping around to the start when oversampling.
            shuffle_all (bool): Shuffle the concatenated indices of all IDs at
                the end of every epoch.

        Raises:
            ValueError: If ``data_source`` has no ``flat_sample_map`` attribute,
                or ``common_samples_per_id`` is negative.
        """
        # Newer PyTorch releases no longer use ``data_source`` in
        # ``Sampler.__init__`` and warn when it is passed; older releases
        # require it positionally, hence the fallback.
        try:
            super().__init__()
        except TypeError:
            super().__init__(data_source)
        self.data_source = data_source
        self.shuffle_within_id = shuffle_within_id
        self.shuffle_all = shuffle_all

        if not hasattr(self.data_source, 'flat_sample_map'):
            raise ValueError("data_source 必须是 IdIncludedDataset 的实例，或具有 'flat_sample_map' 属性。")
        if common_samples_per_id is not None and common_samples_per_id < 0:
            raise ValueError(
                f"common_samples_per_id must be None or >= 0, got {common_samples_per_id!r}"
            )

        # 1. Group global indices by ID. Dict insertion order makes id_list
        #    follow each ID's first appearance in flat_sample_map.
        self.indices_per_id = {}
        for global_idx, sample_info in enumerate(self.data_source.flat_sample_map):
            self.indices_per_id.setdefault(sample_info['id'], []).append(global_idx)

        self.id_list = list(self.indices_per_id)
        if not self.id_list:
            print("警告: 没有有效的ID，Sampler将不会产生任何索引。")
            self._num_samples_epoch = 0
            self.target_samples_per_id = 0
            return

        # 2. Per-ID target. Every ID has at least one index here, so max() is safe.
        if common_samples_per_id is None:
            self.target_samples_per_id = max(len(indices) for indices in self.indices_per_id.values())
        else:
            self.target_samples_per_id = common_samples_per_id

        # 3. Epoch length (id_list is non-empty here, so 0 means a zero target).
        self._num_samples_epoch = self.target_samples_per_id * len(self.id_list)
        if self._num_samples_epoch == 0:
            print(f"警告: target_samples_per_id ({self.target_samples_per_id}) 或 id_list ({self.id_list}) 为空或0, epoch 将没有样本。")

    def _select_indices_for_id(self, pool):
        """Return exactly ``target_samples_per_id`` global indices drawn from ``pool``.

        Oversampling does not draw with replacement, which could skip some of an
        ID's samples entirely. Instead it repeats the whole pool
        ``target // len(pool)`` times and fills the remainder without
        replacement, so every sample appears at least once per epoch.
        """
        target = self.target_samples_per_id
        pool_size = len(pool)
        if pool_size == 0 or target == 0:
            return []

        if not self.shuffle_within_id:
            # Deterministic: in order, wrapping around when target > pool_size.
            return [pool[i % pool_size] for i in range(target)]

        full_passes, remainder = divmod(target, pool_size)
        chosen = pool * full_passes + random.sample(pool, remainder)
        # Randomise order within the ID. This matters for the repeated full
        # passes; random.sample output is already in random order.
        random.shuffle(chosen)
        return chosen

    def __iter__(self):
        """Return an iterator over one epoch of global indices (``len(self)`` of them)."""
        if not self.id_list or self.target_samples_per_id == 0:
            return iter([])

        all_epoch_indices = []
        for id_str in self.id_list:
            all_epoch_indices.extend(self._select_indices_for_id(self.indices_per_id[id_str]))

        if self.shuffle_all:
            random.shuffle(all_epoch_indices)
        return iter(all_epoch_indices)

    def __len__(self):
        """Number of indices produced per epoch."""
        return self._num_samples_epoch

# --- 结束: 您提供的 BalancedIdSampler 代码 ---

# 1. Build mock data and wrap it in an IdIncludedDataset.
#    Global indices follow list order: A -> 0..2, B -> 3..7, C -> 8..9.
mock_flat_data = [
    {'id': source_id, 'value': f'{source_id}{local_idx}'}
    for source_id, sample_count in (('A', 3), ('B', 5), ('C', 2))
    for local_idx in range(sample_count)
]

mock_dataset = IdIncludedDataset(mock_flat_data)

print(f"Mock Dataset size: {len(mock_dataset)}")
print("Original samples per ID:")
# A throwaway sampler is built only to reuse its ID -> global-index grouping.
for id_key, indices in BalancedIdSampler(mock_dataset, shuffle_all=False).indices_per_id.items():
    print(f"  ID {id_key}: {len(indices)} samples, global indices: {indices}")
print("-" * 30)

# 2. Test scenarios: each builds a fresh sampler and checks that a DataLoader
#    driven by it yields exactly `target_samples_per_id` samples for every ID.
test_scenarios = [
    {"name": "Oversample to max (B=5), no common_samples_per_id", "common_samples": None, "shuffle_all": False},
    {"name": "Oversample to max (B=5), no common_samples_per_id, shuffle_all", "common_samples": None, "shuffle_all": True},
    {"name": "Set common_samples_per_id = 4 (A,B over/undersample, C oversample)", "common_samples": 4, "shuffle_all": False},
    {"name": "Set common_samples_per_id = 2 (A,B undersample, C exact)", "common_samples": 2, "shuffle_all": True},
    {"name": "Set common_samples_per_id = 0 (should yield nothing)", "common_samples": 0, "shuffle_all": False},
]

BATCH_SIZE = 3  # DataLoader batch size

for scenario in test_scenarios:
    print(f"\n--- Testing Scenario: {scenario['name']} ---")

    balanced_sampler = BalancedIdSampler(
        data_source=mock_dataset,
        common_samples_per_id=scenario['common_samples'],
        shuffle_within_id=True,  # usually True
        shuffle_all=scenario['shuffle_all'],
    )

    print(f"Sampler target_samples_per_id: {balanced_sampler.target_samples_per_id}")
    print(f"Sampler total samples for epoch (__len__): {len(balanced_sampler)}")

    if len(balanced_sampler) == 0:
        print("Sampler yields 0 samples, skipping DataLoader iteration.")
        # Still check how DataLoader behaves with an empty sampler. Only the
        # DataLoader work is guarded; the assertion lives in `else` so a real
        # failure propagates instead of being caught and printed as an "error".
        try:
            empty_loader = DataLoader(
                dataset=mock_dataset,
                sampler=balanced_sampler,
                batch_size=BATCH_SIZE,
                num_workers=0,  # stay in the main process for debugging
            )
            count = sum(1 for _ in empty_loader)
        except Exception as e:
            print(f"Error with empty DataLoader: {e}")
        else:
            assert count == 0, "DataLoader should be empty if sampler is empty"
            print("DataLoader correctly yields 0 batches for an empty sampler.")
        continue

    # The sampler yields single indices, so pass it as `sampler`, not `batch_sampler`.
    try:
        data_loader = DataLoader(
            dataset=mock_dataset,
            sampler=balanced_sampler,
            batch_size=BATCH_SIZE,
            num_workers=0,  # single process for simple, reproducible debugging
        )
    except Exception as e:
        print(f"Error creating DataLoader: {e}")
        continue

    print(f"DataLoader with batch_size={BATCH_SIZE}:")
    epoch_sampled_ids = []
    epoch_sampled_global_indices = []

    for batch_num, batch in enumerate(data_loader):
        # Each batch is a dict of batched 'data', 'id', 'global_idx'; only
        # 'id' (a list of strings) and 'global_idx' (a tensor) are checked here.
        batch_global_indices = batch['global_idx'].tolist()
        print(f"  Batch {batch_num}:")
        print(f"    Global Indices: {batch_global_indices}")
        print(f"    Original IDs: {batch['id']}")
        epoch_sampled_ids.extend(batch['id'])
        epoch_sampled_global_indices.extend(batch_global_indices)

    if not epoch_sampled_global_indices:
        print("No samples were drawn by the DataLoader.")
    else:
        print("\n  Epoch Sampling Summary:")
        print(f"  Total samples drawn in epoch: {len(epoch_sampled_global_indices)}")
        id_counts = Counter(epoch_sampled_ids)
        print(f"  Samples per ID drawn: {dict(id_counts)}")

        # Every ID must be drawn exactly target_samples_per_id times. The
        # sampler is non-empty on this path, so the target is > 0.
        expected_samples_this_epoch = balanced_sampler.target_samples_per_id
        for id_key in balanced_sampler.id_list:
            assert id_counts[id_key] == expected_samples_this_epoch, \
                f"ID {id_key} expected {expected_samples_this_epoch} samples, got {id_counts[id_key]}"
        print("  ID balancing count per epoch: Correct.")

        # The DataLoader must consume exactly len(sampler) indices.
        assert len(epoch_sampled_global_indices) == len(balanced_sampler), \
            f"Total samples drawn ({len(epoch_sampled_global_indices)}) != sampler length ({len(balanced_sampler)})"
        print("  Total samples drawn vs sampler __len__(): Correct.")
    print("-" * 30)

Mock Dataset size: 10
Original samples per ID:
  ID A: 3 samples, global indices: [0, 1, 2]
  ID B: 5 samples, global indices: [3, 4, 5, 6, 7]
  ID C: 2 samples, global indices: [8, 9]
------------------------------

--- Testing Scenario: Oversample to max (B=5), no common_samples_per_id ---
Sampler target_samples_per_id: 5
Sampler total samples for epoch (__len__): 15
DataLoader with batch_size=3:
  Batch 0:
    Global Indices: [0, 1, 1]
    Original IDs: ['A', 'A', 'A']
  Batch 1:
    Global Indices: [2, 2, 7]
    Original IDs: ['A', 'A', 'B']
  Batch 2:
    Global Indices: [5, 3, 6]
    Original IDs: ['B', 'B', 'B']
  Batch 3:
    Global Indices: [4, 9, 8]
    Original IDs: ['B', 'C', 'C']
  Batch 4:
    Global Indices: [8, 9, 9]
    Original IDs: ['C', 'C', 'C']

  Epoch Sampling Summary:
  Total samples drawn in epoch: 15
  Samples per ID drawn: {'A': 5, 'B': 5, 'C': 5}
  ID balancing count per epoch: Correct.
  Total samples drawn vs sampler __len__(): Correct.
------------------

