In [1]:
# AutoDL官方学术资源加速
import subprocess
import os

# Source AutoDL's /etc/network_turbo in a bash subshell and capture the
# proxy-related variables it exports (every `env` line containing "proxy").
result = subprocess.run(
    'bash -c "source /etc/network_turbo && env | grep proxy"',
    shell=True,
    capture_output=True,
    text=True,
)
output = result.stdout

# Copy each captured NAME=value pair into this process's environment.
# Lines without an '=' separator are ignored.
for env_line in output.splitlines():
    env_name, separator, env_value = env_line.partition('=')
    if separator:
        os.environ[env_name] = env_value

In [2]:
import sys
import os

# Add the project root to the Python import path so the `src.*` modules
# imported below can be resolved.
project_root = "/home/cuipeng/Gemma"
# Guarded so that re-running this cell does not pile up duplicate entries.
if project_root not in sys.path:
    sys.path.append(project_root)

# 现在可以正常导入src下的模块
from src.core.model.model_initializer import initialize_model_and_tokenizer
from src.core.utils.model_utils import generate_response, apply_chat_template

In [3]:
from datasets import load_dataset # type: ignore
import pandas as pd # type: ignore
import numpy as np # type: ignore
from typing import List, Dict # type: ignore
import json # type: ignore
import os # type: ignore
from tqdm import tqdm # type: ignore
import random
from openai import OpenAI # type: ignore

# config.py
from dotenv import load_dotenv # type: ignore

# Load the .env file.
load_dotenv()

# API key and base URL are read from environment variables
# (both are None when the variable is not set).
ZetaTechs_api_key = os.environ.get('ZETATECHS_API_KEY')
ZetaTechs_api_base = os.environ.get('ZETATECHS_API_BASE')

# Initialise the OpenAI client with those settings.
client = OpenAI(
    base_url=ZetaTechs_api_base,
    api_key=ZetaTechs_api_key,
)

In [4]:
class TaskSpecificBaseDatasetBuilder:
    """Build a chat-formatted dataset from translation, dialogue and story data.

    Typical flow: ``load_task_datasets`` -> ``preprocess_data`` ->
    ``format_for_training`` -> ``save_dataset``.

    Attributes are documented in ``__init__``.
    """

    def __init__(self, seed=None):
        """Initialise an empty task-specific dataset builder.

        Args:
            seed: Optional integer seed. When given, every sampling/shuffling
                step of this builder draws from a private
                ``np.random.RandomState(seed)`` so results are reproducible.
                When ``None`` (the default) the global ``np.random`` state is
                used, as before.
        """
        # Maps a task key ('translation', 'dialogue', 'story') to its raw dataset.
        self.datasets = {}
        # Accumulated records of the form {'messages': [...], 'type': str}.
        self.processed_data = []
        # The np.random module exposes the same choice/shuffle API as RandomState.
        self._rng = np.random.RandomState(seed) if seed is not None else np.random

    def load_task_datasets(self):
        """Load the datasets required by each task into ``self.datasets``.

        Raises:
            Exception: whatever ``load_dataset`` raises; the error is printed
                and then re-raised unchanged.
        """
        print("正在加载任务特定数据集...")

        try:
            # Translation task dataset
            self.datasets['translation'] = load_dataset('Helsinki-NLP/news_commentary', "en-zh", split='train')
            print("✓ 翻译数据集加载完成") # https://huggingface.co/datasets/Helsinki-NLP/news_commentary

            # Dialogue task dataset
            self.datasets['dialogue'] = load_dataset('Hello-SimpleAI/HC3-Chinese', "all", split='train', trust_remote_code=True)
            print("✓ 对话数据集加载完成") # https://huggingface.co/datasets/Hello-SimpleAI/HC3-Chinese

            # Story-writing dataset
            self.datasets['story'] = load_dataset('YeungNLP/firefly-train-1.1M', split='train')
            print("✓ 故事创作数据集加载完成") # https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M

        except Exception as e:
            print(f"加载数据集时出错: {e}")
            raise

    def preprocess_data(self, max_samples_per_task=20000):
        """Convert every loaded dataset into chat records.

        Results are appended to ``self.processed_data``. Dataset keys without
        a matching processor are reported and skipped instead of failing with
        an unbound variable or re-adding the previous task's records.

        Args:
            max_samples_per_task: Upper bound on records kept per task.
        """
        processors = {
            'translation': self._process_translation,
            'dialogue': self._process_dialogue,
            'story': self._process_story,
        }

        for dataset_name, dataset in self.datasets.items():
            print(f"处理 {dataset_name} 数据集...")

            processor = processors.get(dataset_name)
            if processor is None:
                print(f"跳过未知数据集: {dataset_name}")
                continue

            processed = processor(dataset, max_samples_per_task)
            self.processed_data.extend(processed)
            print(f"✓ {dataset_name} 处理完成，添加 {len(processed)} 条数据")

    def _subsample(self, records, max_samples):
        """Return at most ``max_samples`` records, sampled without replacement.

        The input list is returned untouched when it is already small enough.
        Indices are sampled (rather than the records themselves) so the
        records never have to be converted into a NumPy array.
        """
        if len(records) <= max_samples:
            return records

        keep = self._rng.choice(len(records), max_samples, replace=False)
        sampled = [records[i] for i in keep]
        print(f"采样后数据量: {len(sampled)} 条")
        return sampled

    def _process_translation(self, dataset, max_samples):
        """Build zh->en and en->zh records from the translation dataset.

        Args:
            dataset: Object whose slice ``dataset[:n]`` yields a mapping with a
                'translation' list of ``{'zh': ..., 'en': ...}`` items.
            max_samples: Maximum number of records to return.

        Returns:
            list of ``{'messages': [...], 'type': 'translation'}`` records.
        """
        # Each row yields two records (one per direction). Reading more rows
        # than needed and sampling afterwards makes it less likely that both
        # directions of the same sentence pair are kept.
        pairs = dataset[:5 * max_samples]['translation']
        processed = []

        for pair in pairs:
            zh_text, en_text = pair['zh'], pair['en']

            # Chinese -> English
            processed.append({
                'messages': [
                    {"role": "user", "content": f"请将以下中文翻译成英文：\n\n{zh_text}"},
                    {"role": "model", "content": en_text}
                ],
                'type': 'translation'
            })

            # English -> Chinese
            processed.append({
                'messages': [
                    {"role": "user", "content": f"请将以下英文翻译成中文：\n\n{en_text}"},
                    {"role": "model", "content": zh_text}
                ],
                'type': 'translation'
            })

        print(f"翻译数据总量: {len(processed)} 条")

        # Randomly subsample when there are more than max_samples records.
        return self._subsample(processed, max_samples)

    def _process_dialogue(self, dataset, max_samples):
        """Build question/answer records from the dialogue dataset.

        Rows with an empty 'chatgpt_answers' list are skipped; otherwise the
        first answer is used as the model turn.

        Args:
            dataset: Object whose slice ``dataset[:n]`` yields a mapping with
                'question' and 'chatgpt_answers' lists.
            max_samples: Maximum number of records to return.

        Returns:
            list of ``{'messages': [...], 'type': 'dialogue'}`` records.
        """
        batch = dataset[:5 * max_samples]
        processed = []

        for question, answers in zip(batch['question'], batch['chatgpt_answers']):
            if answers:  # keep only rows that actually have a ChatGPT answer
                processed.append({
                    'messages': [
                        {"role": "user", "content": question},
                        {"role": "model", "content": answers[0]}
                    ],
                    'type': 'dialogue'
                })

        print(f"对话数据总量: {len(processed)} 条")

        # Randomly subsample when there are more than max_samples records.
        return self._subsample(processed, max_samples)

    def _process_story(self, dataset, max_samples):
        """Build story-generation records from the story dataset.

        Only rows whose 'kind' column equals 'StoryGeneration' are used.

        Args:
            dataset: Object exposing ``to_pandas()`` that returns a DataFrame
                with 'kind', 'input' and 'target' columns.
            max_samples: Maximum number of records to return.

        Returns:
            list of ``{'messages': [...], 'type': 'story_generation'}`` records.
        """
        all_data = dataset.to_pandas()
        story_data = all_data[all_data['kind'] == 'StoryGeneration']
        story_data = story_data.iloc[:5 * max_samples]

        processed = []
        # Iterate the two needed columns directly instead of building a
        # Series per row with iterrows().
        for prompt, story in zip(story_data['input'], story_data['target']):
            processed.append({
                'messages': [
                    {"role": "user", "content": f"请根据以下提示生成一个故事：\n\n{prompt}"},
                    {"role": "model", "content": story}
                ],
                'type': 'story_generation'
            })

        print(f"故事生成数据总量: {len(processed)} 条")

        # Randomly subsample when there are more than max_samples records.
        return self._subsample(processed, max_samples)

    def format_for_training(self):
        """Render every processed record into a training text.

        Returns:
            list of ``{'text': <apply_chat_template output>, 'type': <task type>}``
            dicts, in the order of ``self.processed_data``.
        """
        return [
            {
                'text': apply_chat_template(item['messages']),
                'type': item['type'],  # keep the original task type
            }
            for item in self.processed_data
        ]

    def save_dataset(self, formatted_data, output_path='./stage2/data_raw', train_ratio=0.8, valid_size=1000):
        """Shuffle, split and write the dataset as ``train.json`` / ``valid.json``.

        The first ``train_ratio`` share of the shuffled data becomes the
        training set. The validation set is drawn from the remainder and is
        capped at ``valid_size`` items; remainder items beyond that cap are
        written to neither file.

        Args:
            formatted_data: List of JSON-serialisable records, each with a
                'type' key. The list itself is not modified.
            output_path: Directory to write into (created if missing).
            train_ratio: Fraction of the shuffled data used for training.
            valid_size: Maximum number of validation records.
        """
        os.makedirs(output_path, exist_ok=True)

        # Shuffle a copy so the caller's list keeps its order.
        shuffled = list(formatted_data)
        self._rng.shuffle(shuffled)

        # Split into training and validation portions.
        split_idx = int(len(shuffled) * train_ratio)
        train_data = shuffled[:split_idx]
        valid_data = shuffled[split_idx:]

        # Cap the validation set at valid_size randomly chosen items.
        if len(valid_data) > valid_size:
            valid_indices = self._rng.choice(len(valid_data), valid_size, replace=False)
            valid_data = [valid_data[i] for i in valid_indices]

        train_path = os.path.join(output_path, 'train.json')
        valid_path = os.path.join(output_path, 'valid.json')

        with open(train_path, 'w', encoding='utf-8') as f:
            json.dump(train_data, f, ensure_ascii=False, indent=2)

        with open(valid_path, 'w', encoding='utf-8') as f:
            json.dump(valid_data, f, ensure_ascii=False, indent=2)

        print("\n数据集已保存：")
        print(f"- 训练集: {train_path} ({len(train_data)} 条数据)")
        print(f"- 验证集: {valid_path} ({len(valid_data)} 条数据)")

        # Statistics cover the full input, not only the records written above.
        self._print_dataset_stats(shuffled)

    def _print_dataset_stats(self, data):
        """Print how many records of each 'type' are in ``data``, with percentages."""
        print("\n数据集统计信息:")

        # Count records per task type.
        type_counts = {}
        for item in data:
            type_counts[item['type']] = type_counts.get(item['type'], 0) + 1

        print("\n数据类型分布:")
        for data_type, count in type_counts.items():
            percentage = (count / len(data)) * 100
            print(f"- {data_type}: {count} ({percentage:.2f}%)")

In [5]:
def main():
    """Run the dataset pipeline end to end: load, preprocess, format, save.

    Any exception raised by a pipeline step is printed and then re-raised.
    """
    dataset_builder = TaskSpecificBaseDatasetBuilder()

    try:
        # Load the raw task datasets, then turn them into chat records.
        dataset_builder.load_task_datasets()
        dataset_builder.preprocess_data(max_samples_per_task=20000)

        # Render the records to training text and write them to disk.
        training_records = dataset_builder.format_for_training()
        dataset_builder.save_dataset(training_records)

        print("\n特定任务阶段基础数据集构建完成！")

    except Exception as err:
        print(f"构建数据集时出错: {err}")
        raise

In [6]:
# Entry point: run the whole dataset-building pipeline when this cell is
# executed (the recorded cell output below shows the guard passing here).
if __name__ == "__main__":
    main()

正在加载任务特定数据集...


✓ 翻译数据集加载完成
✓ 对话数据集加载完成


Repo card metadata block was not found. Setting CardData to empty.


✓ 故事创作数据集加载完成
处理 translation 数据集...
翻译数据总量: 138412 条
采样后数据量: 20000 条
✓ translation 处理完成，添加 20000 条数据
处理 dialogue 数据集...
对话数据总量: 12853 条
✓ dialogue 处理完成，添加 12853 条数据
处理 story 数据集...
故事生成数据总量: 19048 条
✓ story 处理完成，添加 19048 条数据

数据集已保存：
- 训练集: ./stage2/data_raw/train.json (41520 条数据)
- 验证集: ./stage2/data_raw/valid.json (1000 条数据)

数据集统计信息:

数据类型分布:
- story_generation: 19048 (36.70%)
- translation: 20000 (38.53%)
- dialogue: 12853 (24.76%)

特定任务阶段基础数据集构建完成！
