In [1]:
import os
import pandas as pd
from pydantic import BaseModel, create_model
from openai import OpenAI

# Configure the proxy environment (only needed when traffic must go through a local proxy).
# These assignments unconditionally overwrite any http_proxy / https_proxy values already set.
os.environ["http_proxy"] = "127.0.0.1:7890"
os.environ["https_proxy"] = "127.0.0.1:7890"

In [2]:
# config.py
from dotenv import load_dotenv

# Load the .env file
load_dotenv()

# Read the API key and base URL; os.getenv returns None when a variable is not set
ZetaTechs_api_key = os.getenv('ZetaTechs_api_key')
ZetaTechs_api_base = os.getenv('ZetaTechs_api_base')

# Module-level client that the generation/extraction functions below refer to as `client`
client = OpenAI(api_key=ZetaTechs_api_key, base_url=ZetaTechs_api_base)

In [3]:
# 读取 CSV 文件并忽略空列，保留“采集时间”和“备注【疑问汇总】”
def load_csv(file_path):
    """Read a CSV file, drop fully empty columns and move the bookkeeping columns last.

    Columns that contain no values at all are discarded. The four bookkeeping
    columns ("采集来源", "来源链接", "采集时间", "备注【疑问汇总】") are always
    selected, even when they are empty, and are placed at the end in that order.

    :param file_path: Path of the CSV file to read.
    :return: DataFrame holding the non-empty data columns followed by the four
        bookkeeping columns.
    :raises KeyError: If one of the four bookkeeping columns is absent from the file.
    """
    trailing_columns = ["采集来源", "来源链接", "采集时间", "备注【疑问汇总】"]
    frame = pd.read_csv(file_path)
    # Per-column mask: True where at least one cell is non-null.
    has_values = frame.notna().any()
    leading_columns = [
        name for name in frame.columns[has_values] if name not in trailing_columns
    ]
    return frame[leading_columns + trailing_columns]

# 自动生成 column_mapping，忽略 "采集时间" 和 "备注【疑问汇总】"
def generate_column_mapping(df):
    """Build an identity mapping for every column except the last two.

    The last two columns are skipped purely by position.

    :param df: DataFrame whose columns are to be mapped.
    :return: Dict mapping each retained column name to itself.
    """
    retained = df.columns[:-2]  # positional slice: everything but the final two columns
    return dict(zip(retained, retained))

# 创建大模型的输入
def create_model_input(df, column_mapping):
    """Convert each DataFrame row into a dict keyed by the mapped column names.

    :param df: Source DataFrame.
    :param column_mapping: Dict of source column name -> output key; only the
        columns listed here are copied.
    :return: List with one dict per row, in row order.
    """
    return [
        {target: record[source] for source, target in column_mapping.items()}
        for _, record in df.iterrows()
    ]

In [4]:
# Smoke test for load_csv(), generate_column_mapping() and create_model_input()
file_path = "原始数据集/8-现代时尚 - 健身人数信息.csv"

# Step 1: load the CSV file
df = load_csv(file_path)

df_columns = df.columns[:-2]
print(df_columns)

# Step 2: build column_mapping automatically, skipping "采集时间" and "备注【疑问汇总】"
column_mapping = generate_column_mapping(df)
# print("###", column_mapping, "###")

# Step 3: build the model input
input_data = create_model_input(df, column_mapping)

print(type(input_data), input_data)
# print(len((input_data[0].keys())), input_data[0].keys())

FileNotFoundError: [Errno 2] No such file or directory: '原始数据集/8-现代时尚 - 健身人数信息.csv'

In [None]:
# Inspect the column mapping built in the previous cell
print(column_mapping)

In [9]:
from create_prompts import create_system_prompt_1, create_system_prompt_2

from create_prompts import create_user_prompt_1, create_user_prompt_2, create_user_prompt_3, create_user_prompt_4, create_user_prompt_5, create_user_prompt_6, create_user_prompt_7, create_user_prompt_8

In [10]:
# Build the system and user prompts that are later passed to generate_data()
system_prompt = create_system_prompt_2()
print(system_prompt)
main_category = "8-现代时尚"
sub_category = "健身人数信息"
# NOTE(review): the variable is named user_prompt_5 but is built by create_user_prompt_8 — confirm which prompt variant is intended
user_prompt_5 = create_user_prompt_8(input_data, main_category, sub_category, num_entries=10) # num_entries reportedly defaults to 5
print("##############################\n", user_prompt_5)


    You are an expert in data augmentation with extensive knowledge in various fields. Your task is to augment datasets by generating new entries that are diverse, realistic, and consistent with the original format and meaning of the columns.
    
    Use your broad knowledge base to:
    1. Introduce realistic variations in the data while maintaining the overall structure and relationships between fields.
    2. Incorporate current trends, events, and factual information relevant to the dataset's topic.
    3. Ensure logical consistency between related fields (e.g., population sizes matching city sizes, dates of events aligning with historical facts).
    4. Provide plausible and varied sources for the data, reflecting real-world information channels.
    
    While generating data, maintain a balance between creativity and realism, ensuring that the augmented data could feasibly exist in the real world.
    
##############################
 You are augmenting a dataset about '健身人数信息'

In [10]:
def generate_data(system_prompt, user_prompt):
    """Send one system/user prompt pair to the chat model and return the reply text.

    Uses the module-level ``client`` with the "gpt-4o-mini" model. No explicit
    timeout is passed with the request.

    :param system_prompt: Content of the system message.
    :param user_prompt: Content of the user message.
    :return: The ``content`` of the first choice's message.
    """
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )
    first_choice = response.choices[0]
    return first_choice.message.content

In [None]:
# Run one generation round with the prompts built earlier and show the raw model reply
generated_data = generate_data(system_prompt, user_prompt_5)
print(generated_data)

In [None]:
import json
from OutputValidator import OutputValidator
# Assumes input_data, main_category and sub_category already exist

# Create the validator
expected_columns = list(input_data[0].keys())
validator = OutputValidator(expected_columns)

# Generate data
# system_prompt = create_system_prompt()
# user_prompt = create_user_prompt_3(input_data, main_category, sub_category)
# generated_data = generate_data(system_prompt, user_prompt)

# Validate the output
if validator.is_valid(generated_data):
    # Output is valid and can be processed further
    preprocessed_data = validator.preprocess_output(generated_data)
    parsed_data = json.loads(preprocessed_data)
    augmented_data = parsed_data["generated_data"]
    print("Generated valid data:", type(augmented_data), len(augmented_data), "\n\n", augmented_data)
else:
    # Output is invalid; save it for later manual handling
    validator.save_invalid_output(generated_data, "invalid_output.txt")
    print("Generated invalid data. Saved to 'invalid_output.txt' for manual processing.")

In [None]:
import json
from OutputValidator import OutputValidator

def process_generated_data(input_data, generated_data, save_invalid_path="invalid_output.txt"):
    """Validate raw model output and return the parsed records, or save it when invalid.

    :param input_data: List of input records; the keys of the first record are
        the expected columns handed to the validator.
    :param generated_data: Raw model output to validate.
    :param save_invalid_path: File that receives the output when validation
        fails. Defaults to "invalid_output.txt".
    :return: The value stored under "generated_data" in the parsed JSON when the
        output is valid; otherwise None.
    """
    checker = OutputValidator(list(input_data[0].keys()))

    # Guard clause: invalid output is written out for manual processing.
    if not checker.is_valid(generated_data):
        checker.save_invalid_output(generated_data, save_invalid_path)
        print("Generated invalid data. Saved to '{}' for manual processing.".format(save_invalid_path))
        return None

    # Valid output: preprocess, parse the JSON and pull out the generated records.
    payload = json.loads(checker.preprocess_output(generated_data))
    augmented_data = payload["generated_data"]
    print("Generated valid data:", type(augmented_data), len(augmented_data), "\n\n", augmented_data)
    return augmented_data

In [None]:
# Peek at the first augmented record
augmented_data[0]

In [45]:
# def transform_data(augmented_data):
#     # 初始化一个字典来存储转换后的数据
#     transformed = {}
    
#     # 获取所有的列名
#     columns = list(augmented_data[0].keys())
    
#     # 为每一列创建一个列表
#     for i, col in enumerate(columns):
#         transformed[f'column{i+1}'] = []
    
#     # 填充数据
#     for row in augmented_data:
#         for i, (col, value) in enumerate(row.items()):
#             transformed[f'column{i+1}'].append(str(value))
    
#     return transformed

from pydantic import create_model

# 生成 Extraction 模型的函数
def generate_extraction_model(num_columns_to_augment):
    """Create a model class named "Extraction" with one list field per column.

    Fields are named ``column1`` .. ``columnN`` and each is declared as
    ``(list[str], ...)``.

    :param num_columns_to_augment: Number of ``columnN`` fields to declare.
    :return: The class returned by ``create_model``.
    """
    field_specs = {}
    for position in range(1, num_columns_to_augment + 1):
        # The Ellipsis default marks the field as required for create_model.
        field_specs[f'column{position}'] = (list[str], ...)
    return create_model('Extraction', **field_specs)

# 修改后的 transform_data 函数
def transform_data(augmented_data):
    """Pivot row-oriented records into a column-oriented ``Extraction`` object.

    The column order is taken from the keys of the first record. Every record is
    read by key rather than by position, so records whose keys come in a
    different order still land in the right column. A key missing from a record
    is filled with an empty string and keys unknown to the first record are
    ignored, which keeps all column lists the same length.

    :param augmented_data: Non-empty list of dicts, one per generated row.
    :return: Instance of the dynamically generated Extraction model whose
        ``column1`` .. ``columnN`` fields hold the stringified values of each column.
    :raises IndexError: If ``augmented_data`` is empty.
    """
    # The first record defines which columns exist and in which order.
    columns = list(augmented_data[0].keys())

    # Build the Extraction model with one field per column.
    Extraction = generate_extraction_model(len(columns))

    # Fill column by column, looking each value up by name; pad gaps with "".
    transformed = {
        f'column{position}': [
            str(row[name]) if name in row else ""
            for row in augmented_data
        ]
        for position, name in enumerate(columns, start=1)
    }

    return Extraction(**transformed)

In [None]:
# 使用函数
transformed_data = transform_data(augmented_data)  # pivot the augmented records into an Extraction object
print(type(transformed_data))

# # Print the transformed data
# for key, value in transformed_data.items():
#     print(f"{key}={value}")

In [47]:
def convert_extracted_generated_data_to_df(extracted_generated_data):
    """Turn an iterable of (column name, column values) pairs into a DataFrame.

    :param extracted_generated_data: Iterable yielding ``(name, values)``
        tuples, one per column.
    :return: DataFrame with one column per pair.
    :raises ValueError: Raised by pandas ("All arrays must be of the same
        length") when the value lists differ in length.
    """
    columns_by_name = {name: values for name, values in extracted_generated_data}
    return pd.DataFrame(columns_by_name)

In [None]:
# Convert the column-oriented Extraction object into a DataFrame and preview it
transformed_data_df = convert_extracted_generated_data_to_df(transformed_data)
transformed_data_df.head()

In [31]:
def generate_extraction_model(num_columns_to_augment):
    """Create a model class named "Extraction" with fields ``column1`` .. ``columnN``.

    NOTE(review): an identical definition of this function exists in another
    cell of this notebook; whichever cell runs last wins. Consider keeping one.

    :param num_columns_to_augment: Number of ``columnN`` fields to declare.
    :return: The class returned by ``create_model``; every field is declared as
        ``(list[str], ...)``.
    """
    field_names = (f'column{index + 1}' for index in range(num_columns_to_augment))
    return create_model('Extraction', **{name: (list[str], ...) for name in field_names})

def extract_generated_data(generated_data, num_columns_to_augment):
    """Ask the model to re-emit generated data as a structured ``Extraction`` object.

    :param generated_data: The data to extract. A string is sent as-is; any
        other value (for example a list of record dicts) is serialised to JSON
        text first, because the message part must carry text.
    :param num_columns_to_augment: Number of ``columnN`` list fields the
        Extraction model should declare.
    :return: ``completion.choices[0].message.parsed`` — the parsed Extraction
        instance as returned by the client.
    """
    import json  # local import: only needed to serialise non-string payloads

    Extraction = generate_extraction_model(num_columns_to_augment)

    if isinstance(generated_data, str):
        payload_text = generated_data
    else:
        # ensure_ascii=False keeps non-ASCII (e.g. Chinese) values readable for the model;
        # default=str stringifies values that json cannot encode natively.
        payload_text = json.dumps(generated_data, ensure_ascii=False, default=str)

    completion = client.beta.chat.completions.parse(
        model="gpt-4o-2024-08-06",  # alternative noted by the author: gpt-4o-mini-2024-07-18
        messages=[
            {"role": "system", "content": "You are an expert at structured data extraction. Extract the data into the exact column structure provided."},
            # A text content part carries its payload under the "text" key (not "content").
            {"role": "user", "content": [{"type": "text", "text": payload_text}]},
        ],
        response_format=Extraction,
    )

    return completion.choices[0].message.parsed

In [None]:
# Step 6: extract the data as structured output
num_columns_to_augment = len((input_data[0].keys()))
print(num_columns_to_augment, "\n")
# NOTE(review): augmented_data is the parsed "generated_data" value from the validation cell, not the raw model text — confirm this is the intended input
extracted_generated_data = extract_generated_data(augmented_data, num_columns_to_augment)
print(extracted_generated_data)

In [6]:
def main(file_path, save_path, target_rows=150, max_failures=5):
    """Augment the CSV at ``file_path`` until it has ``target_rows`` rows, then save it.

    Each round calls generate_data, validates and parses the reply with
    process_generated_data, and appends the validated records to the running
    frame. Only names that this notebook actually imports or defines are used.

    :param file_path: Source CSV. The file name is expected to look like
        "<main category> - <sub category>.csv"; both parts are passed to the
        user-prompt builder. Without the " - " separator the whole stem is
        used for both.
    :param save_path: Destination CSV. Rows are appended; the header is written
        only when the file does not exist yet.
    :param target_rows: Stop generating once the merged frame has at least this
        many rows. Defaults to 150.
    :param max_failures: Give up after this many consecutive failed rounds
        (exception or invalid output), so a persistent error cannot loop
        forever while issuing API calls. Whatever was collected is still saved.
    :return: None.
    """
    df = load_csv(file_path)
    column_mapping = generate_column_mapping(df)
    input_data = create_model_input(df, column_mapping)
    columns = list(input_data[0].keys())

    # Derive the categories from a file name such as "8-现代时尚 - 健身人数信息.csv".
    stem = os.path.splitext(os.path.basename(file_path))[0]
    main_category, _, sub_category = stem.partition(" - ")
    sub_category = sub_category or main_category

    system_prompt = create_system_prompt_2()
    user_prompt = create_user_prompt_8(input_data, main_category, sub_category)

    merged_df = df.copy()
    consecutive_failures = 0

    while len(merged_df) < target_rows:
        if consecutive_failures >= max_failures:
            print(f"连续 {consecutive_failures} 次生成失败，停止生成。")
            break

        print(f"当前总行数为 {len(merged_df)}，继续生成数据...")

        try:
            generated_data = generate_data(system_prompt, user_prompt)
            # Returns the parsed list of record dicts, or None when validation fails.
            augmented_rows = process_generated_data(input_data, generated_data)
            new_rows_df = (
                # reindex pins the frame to exactly the augmented columns, in order.
                pd.DataFrame(augmented_rows).reindex(columns=columns)
                if augmented_rows
                else None
            )
        except Exception as e:
            consecutive_failures += 1
            print(f"发生错误：{str(e)}。跳过本次生成，继续下一次。")
            continue

        if new_rows_df is None or new_rows_df.empty:
            consecutive_failures += 1
            print("生成的数据未通过验证，跳过本次生成。")
            continue

        merged_df = pd.concat([merged_df, new_rows_df], ignore_index=True)
        consecutive_failures = 0

    # NOTE(review): mode='a' appends the whole merged frame (original rows included),
    # so re-running against an existing save_path duplicates rows — confirm intent.
    merged_df.to_csv(save_path, mode='a', header=not os.path.exists(save_path), index=False)
    print(f"扩展后的数据已保存到 {save_path}")

In [None]:
# Example invocation
if __name__ == "__main__":
    # Path of the source CSV file
    file_path = "8-现代时尚 - 健身人数信息.csv"
    
    # Path where the result is saved
    # NOTE(review): "expadnded" looks like a typo for "expanded" — confirm before renaming, since it changes the output file name
    save_path = "8-现代时尚 - 健身人数信息_expadnded.csv"

    # Run the main pipeline
    main(file_path, save_path)