In [1]:
import os
import json
import numpy as np

In [2]:
class DomainDatasetMixer:
    """Blend the stage1/stage2/stage3 JSON datasets into one train/valid pair.

    Each stage contributes a randomly sampled share of records given by
    ``mix_ratio``. The mixed sets are shuffled, summarised on stdout and
    written to ``<output_dir>/train.json`` and ``<output_dir>/valid.json``.
    """

    # Keys that `process` reads from `mix_ratio`, in the order the stages are mixed.
    _STAGES = ("stage1", "stage2", "stage3")

    # Ratio used when the caller does not pass `mix_ratio`.
    _DEFAULT_MIX_RATIO = {"stage1": 0.6, "stage2": 0.2, "stage3": 0.2}

    def __init__(
        self,
        stage1_train_path: str,
        stage1_valid_path: str,
        stage2_train_path: str,
        stage2_valid_path: str,
        stage3_train_path: str,
        stage3_valid_path: str,
        mix_ratio: dict = None,
        target_train_size: int = 100000,  # desired total number of training records
        target_valid_size: int = 1000,    # desired total number of validation records
        output_dir: str = "stage3/data_mixed",
        seed: int = None,
    ):
        """Initialise the domain dataset mixer.

        Args:
            stage1_train_path: Path to the stage-1 training JSON file.
            stage1_valid_path: Path to the stage-1 validation JSON file.
            stage2_train_path: Path to the stage-2 training JSON file.
            stage2_valid_path: Path to the stage-2 validation JSON file.
            stage3_train_path: Path to the stage-3 training JSON file.
            stage3_valid_path: Path to the stage-3 validation JSON file.
            mix_ratio: Mapping with the keys "stage1", "stage2" and "stage3"
                whose values sum to 1. Defaults to 0.6 / 0.2 / 0.2.
            target_train_size: Desired size of the mixed training set.
            target_valid_size: Desired size of the mixed validation set.
            output_dir: Directory the mixed files are written to; created
                here if it does not exist.
            seed: Optional seed for reproducible sampling and shuffling.
                When omitted, the global ``np.random`` state is used.

        Raises:
            ValueError: If the ratios do not sum to 1, a stage key is
                missing, or a ratio is negative.
        """
        # Copy so that neither the class-level default nor the caller's dict
        # is aliased by this instance (mutating one must not affect others).
        mix_ratio = dict(self._DEFAULT_MIX_RATIO if mix_ratio is None else mix_ratio)

        # The ratios must sum to 1 (within floating point tolerance).
        if abs(sum(mix_ratio.values()) - 1) > 1e-6:
            raise ValueError("混合比例之和必须为1")

        # Fail fast here rather than after the (potentially large) files
        # have been loaded in `process`.
        missing = [stage for stage in self._STAGES if stage not in mix_ratio]
        if missing:
            raise ValueError(f"混合比例缺少必要的键: {missing}")
        if any(value < 0 for value in mix_ratio.values()):
            raise ValueError("混合比例不能为负数")

        self.stage1_train_path = stage1_train_path
        self.stage1_valid_path = stage1_valid_path
        self.stage2_train_path = stage2_train_path
        self.stage2_valid_path = stage2_valid_path
        self.stage3_train_path = stage3_train_path
        self.stage3_valid_path = stage3_valid_path
        self.mix_ratio = mix_ratio
        self.target_train_size = target_train_size
        self.target_valid_size = target_valid_size
        self.output_dir = output_dir
        self.seed = seed

        # A private RandomState keeps seeded runs independent of other users
        # of the global generator; without a seed, fall back to the global
        # `np.random` functions so existing `np.random.seed(...)` callers
        # see unchanged behaviour.
        self._rng = np.random.RandomState(seed) if seed is not None else np.random

        # Make sure the output directory exists.
        os.makedirs(output_dir, exist_ok=True)

    @staticmethod
    def _load_json(path):
        """Read and return the parsed content of one UTF-8 JSON file."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def load_data(self):
        """Load all six dataset files into ``self.stage{1,2,3}_{train,valid}``."""
        print("正在加载数据集...")

        self.stage1_train = self._load_json(self.stage1_train_path)
        self.stage1_valid = self._load_json(self.stage1_valid_path)
        self.stage2_train = self._load_json(self.stage2_train_path)
        self.stage2_valid = self._load_json(self.stage2_valid_path)
        self.stage3_train = self._load_json(self.stage3_train_path)
        self.stage3_valid = self._load_json(self.stage3_valid_path)

        print(f"✓ Stage1 训练集: {len(self.stage1_train)} 条数据")
        print(f"✓ Stage1 验证集: {len(self.stage1_valid)} 条数据")
        print(f"✓ Stage2 训练集: {len(self.stage2_train)} 条数据")
        print(f"✓ Stage2 验证集: {len(self.stage2_valid)} 条数据")
        print(f"✓ Stage3 训练集: {len(self.stage3_train)} 条数据")
        print(f"✓ Stage3 验证集: {len(self.stage3_valid)} 条数据")

    @staticmethod
    def _print_distribution(title, counts, total):
        """Print one "name: count (percentage)" line per entry of ``counts``."""
        print(title)
        for name, count in counts.items():
            percentage = (count / total) * 100
            print(f"- {name}: {count} 条 ({percentage:.2f}%)")

    def analyze_types(self, data):
        """Count and print how records are distributed over type and domain.

        Args:
            data: Iterable of dict records. A missing 'type' is counted as
                'unknown' and a missing 'domain' as 'general'.

        Returns:
            Tuple ``(type_counts, domain_counts)`` of plain dicts mapping
            each value to its number of occurrences.
        """
        type_counts = {}
        domain_counts = {}

        for item in data:
            type_name = item.get('type', 'unknown')
            type_counts[type_name] = type_counts.get(type_name, 0) + 1

            domain = item.get('domain', 'general')
            domain_counts[domain] = domain_counts.get(domain, 0) + 1

        total = len(data)
        self._print_distribution("\n数据类型分布:", type_counts, total)
        self._print_distribution("\n领域分布:", domain_counts, total)

        return type_counts, domain_counts

    def mix_data(self, data_list, ratios, target_size):
        """Sample from several datasets and merge them into one shuffled list.

        Args:
            data_list: List of datasets (each a list of records).
            ratios: Share of ``target_size`` to take from each dataset, in
                the same order as ``data_list``.
            target_size: Desired size of the mixed dataset.

        Returns:
            A shuffled list of sampled records. It is shorter than
            ``target_size`` when a dataset holds fewer records than its
            share; a warning is printed and the shortfall is not
            redistributed to the other datasets.

        Raises:
            ValueError: If ``data_list`` and ``ratios`` differ in length.
        """
        if len(data_list) != len(ratios):
            raise ValueError("数据集数量与比例数量不一致")

        # Per-dataset quota: truncate each share, and let the last dataset
        # absorb the remainder so the quotas add up to `target_size`.
        quotas = []
        last_index = len(ratios) - 1
        for i, ratio in enumerate(ratios):
            if i == last_index:
                quota = target_size - sum(quotas)
            else:
                quota = int(target_size * ratio)
            quotas.append(quota)

        # Sampling is without replacement, so a quota cannot exceed the
        # number of available records; clamp it and warn.
        for i, (data, quota) in enumerate(zip(data_list, quotas)):
            if quota > len(data):
                print(f"警告：数据集{i+1}目标数量({quota})超过可用数据量({len(data)})")
                quotas[i] = len(data)

        # Getattr keeps the method usable on instances built without __init__.
        rng = getattr(self, "_rng", np.random)

        mixed_data = []
        for data, size in zip(data_list, quotas):
            print(f"数据集大小: {len(data)}, 采样数量: {size}")
            if size == 0:
                # Nothing to draw; also avoids sampling from an empty
                # population, which some NumPy versions reject.
                continue
            selected = rng.choice(len(data), size, replace=False)
            mixed_data.extend(data[idx] for idx in selected)

        # Shuffle so records from different stages are interleaved.
        rng.shuffle(mixed_data)

        print(f"混合后数据集大小: {len(mixed_data)}")
        return mixed_data

    def process(self):
        """Load, mix, analyse and save the mixed training and validation sets."""
        # 1. Load the data.
        self.load_data()

        ratios = [self.mix_ratio[stage] for stage in self._STAGES]

        # 2. Mix the training sets.
        print("\n混合训练集...")
        mixed_train = self.mix_data(
            [self.stage1_train, self.stage2_train, self.stage3_train],
            ratios,
            self.target_train_size
        )

        # 3. Mix the validation sets.
        print("\n混合验证集...")
        mixed_valid = self.mix_data(
            [self.stage1_valid, self.stage2_valid, self.stage3_valid],
            ratios,
            self.target_valid_size
        )

        # 4. Report the resulting distributions.
        print("\n训练集分布:")
        self.analyze_types(mixed_train)

        print("\n验证集分布:")
        self.analyze_types(mixed_valid)

        # 5. Save the mixed data.
        train_path = os.path.join(self.output_dir, 'train.json')
        valid_path = os.path.join(self.output_dir, 'valid.json')

        with open(train_path, 'w', encoding='utf-8') as f:
            json.dump(mixed_train, f, ensure_ascii=False, indent=4)

        with open(valid_path, 'w', encoding='utf-8') as f:
            json.dump(mixed_valid, f, ensure_ascii=False, indent=4)

        print("\n数据保存完成：")
        print(f"- 混合训练集: {train_path} ({len(mixed_train)} 条数据)")
        print(f"- 混合验证集: {valid_path} ({len(mixed_valid)} 条数据)")

In [3]:
def main():
    """Build the stage-3 mixed dataset from the three stages' final splits."""
    # Every stage reads its own `data_final` files. Per the original author's
    # note, the "mix" variant of a stage is the one that already has stage1
    # data blended in, so it is deliberately not used as an input here.
    stage_files = dict(
        stage1_train_path='stage1/data_final/train.json',
        stage1_valid_path='stage1/data_final/valid.json',
        stage2_train_path='stage2/data_final/train.json',
        stage2_valid_path='stage2/data_final/valid.json',
        stage3_train_path='stage3/data_final/train.json',
        stage3_valid_path='stage3/data_final/valid.json',
    )

    # How much to draw from each stage and how large the results should be.
    sampling = dict(
        mix_ratio={"stage1": 0.5, "stage2": 0.3, "stage3": 0.2},
        target_train_size=80000,
        target_valid_size=1000,
    )

    mixer = DomainDatasetMixer(
        output_dir='stage3/data_mixed',
        **stage_files,
        **sampling
    )
    mixer.process()

In [4]:
# Entry point: run the mixing pipeline when this code is executed directly
# (script run or notebook cell), but not when it is imported as a module.
if __name__ == "__main__":
    main()

正在加载数据集...
✓ Stage1 训练集: 80000 条数据
✓ Stage1 验证集: 1000 条数据
✓ Stage2 训练集: 41520 条数据
✓ Stage2 验证集: 1000 条数据
✓ Stage3 训练集: 30000 条数据
✓ Stage3 验证集: 1000 条数据

混合训练集...
数据集大小: 80000, 采样数量: 40000
数据集大小: 41520, 采样数量: 24000
数据集大小: 30000, 采样数量: 16000
混合后数据集大小: 80000

混合验证集...
数据集大小: 1000, 采样数量: 500
数据集大小: 1000, 采样数量: 300
数据集大小: 1000, 采样数量: 200
混合后数据集大小: 1000

训练集分布:

数据类型分布:
- instruction: 40000 条 (50.00%)
- ancient_poetry_creation: 5348 条 (6.69%)
- translation: 9219 条 (11.52%)
- classical_translation: 5342 条 (6.68%)
- story_generation: 8861 条 (11.08%)
- modern_poetry_creation: 5310 条 (6.64%)
- dialogue: 5920 条 (7.40%)

领域分布:
- general: 80000 条 (100.00%)

验证集分布:

数据类型分布:
- ancient_poetry_creation: 70 条 (7.00%)
- story_generation: 110 条 (11.00%)
- translation: 125 条 (12.50%)
- instruction: 500 条 (50.00%)
- classical_translation: 56 条 (5.60%)
- dialogue: 65 条 (6.50%)
- modern_poetry_creation: 74 条 (7.40%)

领域分布:
- general: 1000 条 (100.00%)

数据保存完成：
- 混合训练集: stage3/data_mixed/train.json (80000 条数据)
