In [2]:
from datasets import load_dataset, DatasetDict
import json
import os


# Alpaca-format conversion
def convert_to_alpaca_format(dataset, field_map=None):
    """Convert dataset records into Alpaca-style ``instruction``/``input``/``output`` dicts.

    Args:
        dataset: Iterable of mapping-like records (e.g. rows of a Hugging Face
            ``Dataset``); each record must support ``.get``.
        field_map: Optional mapping from Alpaca key to the source-record key to
            read. A source key of ``None`` always yields ``""``. Defaults to the
            layout used for ChatMed_Consult_Dataset in this cell:
            ``instruction`` <- ``'instruction'``, ``input`` <- ``'query'``,
            ``output`` <- ``'response'``.

    Returns:
        list[dict]: One dict per record, keys in ``field_map`` order. A field that
        is missing from a record, or present with value ``None``, becomes ``""``
        so the exported JSON never contains ``null`` for it.
    """
    if field_map is None:
        field_map = {"instruction": "instruction", "input": "query", "output": "response"}

    def _field(record, source_key):
        # `.get(key, "")` alone is not enough: a record may carry the column with a
        # None value (e.g. datasets built from JSON with heterogeneous keys).
        if source_key is None:
            return ""
        value = record.get(source_key)
        return "" if value is None else value

    return [
        {alpaca_key: _field(record, source_key) for alpaca_key, source_key in field_map.items()}
        for record in dataset
    ]

# Save a dataset to a JSON file
def save_dataset(dataset, save_path):
    """Write ``dataset`` to ``save_path`` as pretty-printed UTF-8 JSON.

    Missing parent directories are created. Non-ASCII text (e.g. Chinese) is
    written as-is instead of being escaped to ASCII.

    Args:
        dataset: JSON-serialisable object (here, a list of Alpaca-format dicts).
        save_path: Destination file path; an existing file is overwritten.
    """
    parent_dir = os.path.dirname(save_path)
    # A bare filename has no parent component, and os.makedirs("") would raise;
    # exist_ok also avoids a check-then-create race.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(dataset, f, ensure_ascii=False, indent=4)

# Main program: split ChatMed_Consult_Dataset and export the test split.
save_path = "./data/llama-factory"
# Fixed seed so the exported test split is identical on every run (and never
# overlaps a train split exported in a later run).
SPLIT_SEED = 42

try:
    # Load the full dataset; only the hub's "train" split is requested.
    dataset = load_dataset(
        "michaelwzhu/ChatMed_Consult_Dataset",
        split="train",
    )

    # 90% train / 10% held out for test + validation.
    train_testvalid = dataset.train_test_split(test_size=0.1, seed=SPLIT_SEED)

    # Split the held-out 10% evenly into test and validation.
    test_valid = train_testvalid["test"].train_test_split(test_size=0.5, seed=SPLIT_SEED)

    # Collect the three splits in one DatasetDict.
    train_test_valid_dataset = DatasetDict({
        "train": train_testvalid["train"],
        "test": test_valid["test"],
        "valid": test_valid["train"],
    })

    # Convert every split to Alpaca format.
    train_alpaca_dataset = convert_to_alpaca_format(train_test_valid_dataset["train"])
    test_alpaca_dataset = convert_to_alpaca_format(train_test_valid_dataset["test"])
    valid_alpaca_dataset = convert_to_alpaca_format(train_test_valid_dataset["valid"])

    # Only the test split is written; uncomment to export the other splits too:
    # save_dataset(train_alpaca_dataset, f"{save_path}/ChatMed_Consult_Dataset_train.json")
    # save_dataset(valid_alpaca_dataset, f"{save_path}/ChatMed_Consult_Dataset_valid.json")
    save_dataset(test_alpaca_dataset, f"{save_path}/ChatMed_Consult_Dataset_test.json")

    print("Saved Medical benchmark ChatMed_Consult_Dataset test set to Alpaca format.")
except Exception as e:
    print(f"Failed to process Medical benchmark ChatMed_Consult_Dataset: {e}")


Saved Medical benchmark ChatMed_Consult_Dataset train, test, and valid sets to Alpaca format.


In [1]:
from datasets import load_dataset, DatasetDict
import json
import os


# Alpaca-format conversion
def convert_to_alpaca_format(dataset, field_map=None):
    """Convert dataset records into Alpaca-style ``instruction``/``input``/``output`` dicts.

    Args:
        dataset: Iterable of mapping-like records (e.g. rows of a Hugging Face
            ``Dataset``); each record must support ``.get``.
        field_map: Optional mapping from Alpaca key to the source-record key to
            read. A source key of ``None`` always yields ``""``. Defaults to the
            layout used for MedChatZH in this cell: ``instruction`` is left empty,
            the question (source ``'instruction'``) goes to ``input``, and the
            answer (source ``'output'``) goes to ``output``.

    Returns:
        list[dict]: One dict per record, keys in ``field_map`` order. A field that
        is missing from a record, or present with value ``None``, becomes ``""``
        so the exported JSON never contains ``null`` for it.
    """
    if field_map is None:
        # The empty instruction was previously obtained via a placeholder key 'a';
        # it is now explicit.
        field_map = {"instruction": None, "input": "instruction", "output": "output"}

    def _field(record, source_key):
        # `.get(key, "")` alone is not enough: a record may carry the column with a
        # None value (e.g. datasets built from JSON with heterogeneous keys).
        if source_key is None:
            return ""
        value = record.get(source_key)
        return "" if value is None else value

    return [
        {alpaca_key: _field(record, source_key) for alpaca_key, source_key in field_map.items()}
        for record in dataset
    ]

# Save a dataset to a JSON file
def save_dataset(dataset, save_path):
    """Write ``dataset`` to ``save_path`` as pretty-printed UTF-8 JSON.

    Missing parent directories are created. Non-ASCII text (e.g. Chinese) is
    written as-is instead of being escaped to ASCII.

    Args:
        dataset: JSON-serialisable object (here, a list of Alpaca-format dicts).
        save_path: Destination file path; an existing file is overwritten.
    """
    parent_dir = os.path.dirname(save_path)
    # A bare filename has no parent component, and os.makedirs("") would raise;
    # exist_ok also avoids a check-then-create race.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(dataset, f, ensure_ascii=False, indent=4)

# Main program: export MedChatZH's validation split as an Alpaca-format test file.
save_path = "./data/llama-factory"

try:
    # No `split=` argument: load_dataset returns a DatasetDict keyed by split name;
    # this cell reads its "train" and "validation" splits.
    dataset = load_dataset("tyang816/MedChatZH")

    # Convert both splits to Alpaca format.
    train_alpaca_dataset = convert_to_alpaca_format(dataset["train"])
    valid_alpaca_dataset = convert_to_alpaca_format(dataset["validation"])

    # The hub's "validation" split serves as this project's test set.
    save_dataset(valid_alpaca_dataset, f"{save_path}/MedChatZH_test.json")
    # Uncomment to also export the training split:
    # save_dataset(train_alpaca_dataset, f"{save_path}/MedChatZH_train.json")

    print("Saved Medical benchmark MedChatZH validation split to Alpaca format as the test set.")
except Exception as e:
    print(f"Failed to process Medical benchmark MedChatZH: {e}")


Saved Medical benchmark ChatMed_Consult_Dataset train, test, and valid sets to Alpaca format.
