In [1]:
import os
import json
import random
from sklearn.model_selection import train_test_split

In [2]:
def load_data_from_directory(directory):
    """
    Load every `.json` file in `directory` and map file name -> parsed content.

    Files are visited in sorted name order so that the key order of the
    returned dict is deterministic across platforms (`os.listdir` returns
    entries in arbitrary order). Entries ending in `.json` that are not regular
    files (e.g. sub-directories) are skipped.

    Args:
        directory: path to a directory expected to contain one JSON file per book.

    Returns:
        dict mapping each JSON file name (base name, not full path) to the
        object produced by `json.load` for that file.

    Raises:
        OSError: if `directory` cannot be listed or a file cannot be read.
        json.JSONDecodeError: if a `.json` file does not contain valid JSON.
    """
    book_data = {}
    for filename in sorted(os.listdir(directory)):
        if not filename.endswith('.json'):
            continue
        filepath = os.path.join(directory, filename)
        # Guard against directories (or other non-files) named "*.json",
        # which would otherwise make open() fail.
        if not os.path.isfile(filepath):
            continue
        with open(filepath, 'r', encoding='utf-8') as f:
            book_data[filename] = json.load(f)
    return book_data

In [3]:
def split_books_by_ratio(book_data, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, seed=None):
    """
    Partition whole books into train / validation / test sets by entry count.

    Each book is assigned to exactly one split, so no book's entries leak
    across splits. Books are visited in shuffled order and placed greedily:
    into train while train stays within its entry budget, otherwise into
    validation while validation stays within its budget, otherwise into test.
    Because books are indivisible, the resulting entry counts only approximate
    the requested ratios, and the test split absorbs all overflow.

    Args:
        book_data: mapping of book name -> collection of entries (only `len()`
            of each value is used here).
        train_ratio: target fraction of all entries for the training split.
        val_ratio: target fraction of all entries for the validation split.
        test_ratio: nominal test fraction. It does not drive the allocation:
            the test split receives every book not placed in train/validation.
            Kept for interface compatibility.
        seed: optional seed for a private `random.Random`, making the split
            reproducible. When None, the module-level `random` state is used.

    Returns:
        (train_books, val_books, test_books): three lists of book names.

    Raises:
        ValueError: if train_ratio or val_ratio is negative, or if
            train_ratio + val_ratio exceeds 1.
    """
    if train_ratio < 0 or val_ratio < 0:
        raise ValueError("train_ratio and val_ratio must be non-negative")
    # Small tolerance so float sums such as 0.9 + 0.1 are not rejected.
    if train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError("train_ratio + val_ratio must not exceed 1")

    # Start from a canonical order so that, for a fixed seed, the shuffle does
    # not depend on the insertion order of `book_data`.
    book_names = sorted(book_data, key=str)
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(book_names)

    # Entry budgets for train and validation; test is "everything else".
    total_entries = sum(len(entries) for entries in book_data.values())
    target_train_size = int(total_entries * train_ratio)
    target_val_size = int(total_entries * val_ratio)

    train_books, val_books, test_books = [], [], []
    train_count, val_count, test_count = 0, 0, 0

    # Greedy assignment of whole books to splits.
    for book in book_names:
        book_size = len(book_data[book])

        if train_count + book_size <= target_train_size:
            train_books.append(book)
            train_count += book_size
        elif val_count + book_size <= target_val_size:
            val_books.append(book)
            val_count += book_size
        else:
            test_books.append(book)
            test_count += book_size

    print(f"Train set: {train_count} entries")
    print(f"Validation set: {val_count} entries")
    print(f"Test set: {test_count} entries")

    return train_books, val_books, test_books

In [4]:
def merge_data(book_data, book_list):
    """
    Flatten the entries of the named books into a single list.

    Books are concatenated in the order given by `book_list`; each book's own
    entries keep their original order.

    Args:
        book_data: mapping of book name -> iterable of entries.
        book_list: names of the books to include.

    Returns:
        A new list containing every entry of every listed book.

    Raises:
        KeyError: if a name in `book_list` is not present in `book_data`.
    """
    return [entry for name in book_list for entry in book_data[name]]

def save_data(data, output_path):
    """
    Write `data` to `output_path` as pretty-printed UTF-8 JSON.

    Non-ASCII characters are written literally rather than ASCII-escaped, and
    the document is indented with 4 spaces. An existing file is overwritten.
    """
    with open(output_path, mode='w', encoding='utf-8') as out_file:
        json.dump(data, out_file, indent=4, ensure_ascii=False)


In [5]:
def main(input_dir, output_dir):
    """
    Run the dataset pipeline: load books, split them, and write the splits.

    Creates `output_dir` if needed, loads every JSON book from `input_dir`,
    splits whole books into train/val/test, flattens each split's entries, and
    writes them to train_raw.json, val_raw.json and test_raw.json inside
    `output_dir`.
    """
    os.makedirs(output_dir, exist_ok=True)

    print("Loading data...")
    book_data = load_data_from_directory(input_dir)

    print("Splitting books...")
    train_books, val_books, test_books = split_books_by_ratio(book_data)

    # Flatten every split before writing anything, in train/val/test order.
    pending_writes = [
        (file_name, merge_data(book_data, names))
        for file_name, names in (
            ('train_raw.json', train_books),
            ('val_raw.json', val_books),
            ('test_raw.json', test_books),
        )
    ]
    for file_name, records in pending_writes:
        save_data(records, os.path.join(output_dir, file_name))

    print("Data processing complete!")

# Example path configuration. Each path can be overridden via an environment
# variable so the notebook runs on machines other than the author's; the
# defaults are the original author-specific locations.
input_dir = os.environ.get(
    "LITERARY_TEXTS_INPUT_DIR",
    "C:/Users/Lenovo/OneDrive/NUS/CS-24fall/project/AudiobookGeneration_cs5647/LiteraryTextsDataset/target_extracted_data",
)
output_dir = os.environ.get(
    "LITERARY_TEXTS_OUTPUT_DIR",
    "C:/Users/Lenovo/OneDrive/NUS/CS-24fall/project/AudiobookGeneration_cs5647/LiteraryTextsDataset/dataset",
)

# Seed the global RNG that the book shuffle draws from, so re-running the
# notebook on the same input files reproduces the same split.
random.seed(42)

# Run the pipeline.
main(input_dir, output_dir)

Loading data...
Splitting books...
Train set: 21453 entries
Validation set: 1826 entries
Test set: 3734 entries
Data processing complete!
