
## Take Home Test: Reformat a Public Dataset for LLM Training

### Objective

The goal of this task is to prepare public datasets for more effective use in training and fine-tuning Large Language Models (LLMs). You are required to reformat a specific subset of a public dataset into a structured, consistent format to facilitate its usability.

### Detailed Instructions

#### 1. Dataset Selection and Preparation

- **Dataset:** You are assigned the `Headline` subset of the [AdaptLLM/finance-tasks](https://huggingface.co/datasets/AdaptLLM/finance-tasks) dataset.

- **Task Description:** Each entry in the `input` column contains multiple "Yes" or "No" questions alongside their respective answers. Your task is to:

  - Develop a Python script to parse and separate each question and its answer from the entry.
  - Save each question-answer pair in a structured JSON format as follows:
    ```json
    {
      "id": "<unique_identifier>",
      "Question": "<question_text>",
      "Answer": "<answer_text>"
    }
    ```

  - You are encouraged to introduce additional attributes if needed to preserve the integrity and completeness of the information. Adding relevant tag information is strongly recommended.
- **Automation Requirement:** The task must be completed using Python. Manual editing or data manipulation is strictly prohibited. Your script should efficiently handle variations in data format within the column.
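The required schema has only three fields. The sketch below shows one way to carry the recommended tag information alongside them. The `QARecord` class and the extra fields (`headline`, `source_row`, `is_target`, `class_id`) are illustrative assumptions. They are not part of the assignment or of the notebook's final output.

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class QARecord:
    """One question-answer pair. Fields after `Answer` are hypothetical tag attributes."""
    id: str
    Question: str
    Answer: str
    headline: Optional[str] = None    # hypothetical: headline the question refers to
    source_row: Optional[int] = None  # hypothetical: row index in the original dataset
    is_target: bool = False           # hypothetical: True if the answer came from gold_index
    class_id: Optional[int] = None    # hypothetical: copied from the dataset's `class_id` column


record = QARecord(
    id="1",
    Question='Headline: "gold futures add to gains after adp data" '
             "Now answer this question: Does the news headline talk about price?",
    Answer="Yes",
    headline="gold futures add to gains after adp data",
    source_row=0,
)
print(json.dumps(asdict(record), ensure_ascii=False))
```

Serializing with `dataclasses.asdict` leaves the three required keys unchanged. Consumers that read only `id`, `Question`, and `Answer` are therefore unaffected by the extra tags.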

#### 2. Deliverables

- **Reformatted Dataset:** Provide the schema of the final format you adopted for saving the results.
- **Transformation Code:** Submit the complete code used for converting the dataset into the designated format.
- **Statistics:** Report the total number of question-answer pairs extracted from the dataset.
- **Performance Metrics:** Document the time taken to complete the dataset cleanup and transformation process.
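For the performance metric, the notebook later times the transformation with `time.time()`. A monotonic clock such as `time.perf_counter()` is generally preferred for measuring durations because system clock adjustments do not affect it. Below is a minimal sketch; the `timed` context manager is a hypothetical helper, not part of the notebook.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, results):
    """Store the elapsed seconds of the enclosed block in results[label]."""
    start = time.perf_counter()  # monotonic, high-resolution clock
    try:
        yield
    finally:
        results[label] = time.perf_counter() - start


timings = {}
with timed("transform", timings):
    total = sum(range(1_000_000))  # stand-in for the real cleanup/transformation step
print(f"transform: {timings['transform']:.4f} s")
```

Wrapping a call such as `extract_all_qa_pairs(df)` in `with timed("transform", timings):` would record the same metric without editing the function itself.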


In [6]:
# Cell 1: Install dependencies
# ----------------------------
# Run this cell to install the required libraries
!pip install datasets transformers



In [1]:
# Cell 2: Import the required libraries
# -------------------------------------
import json
import re
import time
from datasets import load_dataset
import pandas as pd
from typing import List, Dict, Any
from datetime import datetime

print("All libraries imported successfully!")


All libraries imported successfully!


In [2]:
# Cell 3: Load the dataset
# ------------------------
def load_finance_dataset():
    """Load the Headline subset of the AdaptLLM/finance-tasks dataset."""
    print("Loading the AdaptLLM/finance-tasks dataset...")
    try:
        dataset = load_dataset("AdaptLLM/finance-tasks", name="Headline", split="test")
        df = pd.DataFrame(dataset)
        print(f"Dataset loaded successfully! Total entries: {len(df)}")
        return df
    except Exception as e:
        print(f"Error while loading the dataset: {e}")
        raise

# Load the dataset
df = load_finance_dataset()

# Cell 4: Inspect the dataset structure
# -------------------------------------
print("Dataset overview:")
print(f"Shape: {df.shape}")
print(f"Columns: {list(df.columns)}")
print("\nFirst 3 rows:")
df.head(3)

Loading the AdaptLLM/finance-tasks dataset...



Dataset loaded successfully! Total entries: 20547
Dataset overview:
Shape: (20547, 5)
Columns: ['id', 'input', 'options', 'gold_index', 'class_id']

First 3 rows:


|   | id | input | options | gold_index | class_id |
|---|----|-------|---------|------------|----------|
| 0 | 0 | Headline: "Gold falls to Rs 30,800; silver dow... | [No, Yes] | 1 | 0 |
| 1 | 1 | Headline: february gold rallies to intraday hi... | [No, Yes] | 0 | 7 |
| 2 | 2 | Please answer a question about the following h... | [No, Yes] | 0 | 5 |
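Each row stores the answer to its final question indirectly, as an index (`gold_index`) into `options`. For the rows above, `options` is `[No, Yes]`, so index 1 selects `Yes`. The small helper below makes the lookup explicit and fails loudly on an out-of-range index. The name `gold_label` is hypothetical.

```python
def gold_label(options, gold_index):
    """Return options[gold_index], refusing to guess when the index is invalid."""
    index = int(gold_index)  # also accepts NumPy integer types read from a DataFrame
    if not 0 <= index < len(options):
        raise IndexError(f"gold_index {index} out of range for options {list(options)}")
    return options[index]


print(gold_label(["No", "Yes"], 1))  # Yes
print(gold_label(["No", "Yes"], 0))  # No
```

Cell 8 takes a different approach. When `options` is unusable, it falls back to `"Yes" if gold_index == 1 else "No"`. Raising instead surfaces malformed rows rather than silently assigning a label.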


In [3]:
# Cell 5: Inspect a concrete example
# ----------------------------------
print("Full contents of the first record:")
print(f"ID: {df.iloc[0]['id']}")
print(f"Class ID: {df.iloc[0]['class_id']}")
print(f"Gold Index: {df.iloc[0]['gold_index']}")
print(f"Options: {df.iloc[0]['options']}")
print("\nInput content:")
print("-" * 50)
print(df.iloc[0]['input'])
print("-" * 50)

Full contents of the first record:
ID: 0
Class ID: 0
Gold Index: 1
Options: ['No', 'Yes']

Input content:
--------------------------------------------------
Headline: "Gold falls to Rs 30,800; silver down at Rs 41,200 per kg" Now answer this question: Does the news headline talk about price in the past? Yes

Headline: "gold futures add to gains after adp data" Now answer this question: Does the news headline talk about price? Yes

Headline: "Gold holds on to modest loss after data" Now answer this question: Does the news headline talk about price in the future? No

Headline: "spot gold quoted at $417.50, down 20c from new york" Now answer this question: Does the news headline talk about a general event (apart from prices) in the past? No

Headline: "gold hits new record high at $1,036.20 an ounce" Now answer this question: Does the news headline compare gold with any other asset? No

Headline: "gold may hit rs 31,500, but pullback rally may not sustain for long: experts" Now answer this question: Does the news h
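The record above suggests a template-independent way to split an entry. Demonstrations are separated by blank lines, and each one ends with a question mark followed by `Yes` or `No`. Following the notebook's own assumption in Cell 8, the last block is the unanswered target. Below is a minimal sketch under those assumptions. `split_blocks` is a hypothetical helper, and the sample string combines two blocks from the record above with an illustrative unanswered block.

```python
import re

# Assumed format (inferred from the record above): blocks are separated by blank lines,
# and an answered block ends with "?" followed by "Yes" or "No".
ANSWERED_RE = re.compile(r"^(?P<question>.*\?)\s*(?P<answer>Yes|No)\s*$", re.DOTALL)


def split_blocks(input_text):
    """Return (question_text, answer) tuples; answer is None when a block has no trailing Yes/No."""
    blocks = [b.strip() for b in re.split(r"\n\s*\n", input_text) if b.strip()]
    pairs = []
    for block in blocks:
        match = ANSWERED_RE.match(block)
        if match:
            pairs.append((match.group("question").strip(), match.group("answer")))
        else:
            pairs.append((block, None))  # e.g. the final, unanswered target question
    return pairs


# Shortened sample: two blocks from the record above plus an illustrative unanswered block.
sample = "\n\n".join([
    'Headline: "gold futures add to gains after adp data" Now answer this question: '
    "Does the news headline talk about price? Yes",
    'Headline: "Gold holds on to modest loss after data" Now answer this question: '
    "Does the news headline talk about price in the future? No",
    'Headline: "gold hits new record high at $1,036.20 an ounce" Now answer this question: '
    "Does the news headline compare gold with any other asset?",
])

for question, answer in split_blocks(sample):
    print(answer, "|", question)
```

This approach never looks for the `Headline:` prefix or for quotes, so entries written in another template are not silently dropped. Whether every template in the subset uses blank-line separators is an assumption to verify against the data.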

In [7]:
# Cell 6: Define the parsing function
# -----------------------------------
def parse_qa_pairs_from_input(input_text):
    """
    Parse every question-answer pair from an `input` entry.
    Returns a list of dicts with the headline, question, full question text, and answer.
    """
    qa_pairs = []

    # Split the text on "Headline:"
    sections = re.split(r'Headline:\s*', input_text)

    # Drop the empty leading element
    if sections and not sections[0].strip():
        sections = sections[1:]

    for section in sections:
        if not section.strip():
            continue

        # Extract the headline (only headlines wrapped in double quotes are matched)
        headline_match = re.match(r'"([^"]+)"', section)
        if not headline_match:
            continue

        headline = headline_match.group(1)

        # Locate the question and answer after the headline
        remaining_text = section[len(headline_match.group(0)):].strip()

        # Match the "Now answer this question:" pattern
        qa_pattern = r'Now answer this question:\s*(.+?)\?\s*(Yes|No|$)'
        qa_matches = re.finditer(qa_pattern, remaining_text, re.IGNORECASE)

        for match in qa_matches:
            question = match.group(1).strip()
            # group(2) is "" when the end-of-text branch matched (the unanswered target)
            answer = match.group(2) or None

            # Rebuild the full question text
            full_question = f'Headline: "{headline}" Now answer this question: {question}?'

            qa_pairs.append({
                'headline': headline,
                'question': question,
                'full_question': full_question,
                'answer': answer
            })

    return qa_pairs

print("QA-pair parsing function defined!")

# Cell 7: Test the parsing function
# ---------------------------------
print("Parsing the first record:")
test_qa_pairs = parse_qa_pairs_from_input(df.iloc[0]['input'])

print(f"Found {len(test_qa_pairs)} QA pairs:")
for i, pair in enumerate(test_qa_pairs):
    print(f"\nQA pair {i+1}:")
    print(f"  Headline: {pair['headline']}")
    print(f"  Question: {pair['question']}")
    print(f"  Answer: {pair['answer']}")
    print(f"  Full question: {pair['full_question']}...")


QA-pair parsing function defined!
Parsing the first record:
Found 6 QA pairs:

QA pair 1:
  Headline: Gold falls to Rs 30,800; silver down at Rs 41,200 per kg
  Question: Does the news headline talk about price in the past
  Answer: Yes
  Full question: Headline: "Gold falls to Rs 30,800; silver down at Rs 41,200 per kg" Now answer this question: Does the news headline talk about price in the past?...

QA pair 2:
  Headline: gold futures add to gains after adp data
  Question: Does the news headline talk about price
  Answer: Yes
  Full question: Headline: "gold futures add to gains after adp data" Now answer this question: Does the news headline talk about price?...

QA pair 3:
  Headline: Gold holds on to modest loss after data
  Question: Does the news headline talk about price in the future
  Answer: No
  Full question: Headline: "Gold holds on to modest loss after data" Now answer this question: Does the news headline talk about price in the future?...

QA pair 4:
  Headline: spot gold quoted at $417.50, down 20c from new york
  Question: Does the news headline talk about a general event (apart from prices) in the past
  Answer
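Cell 6 only accepts headlines wrapped in double quotes (`re.match(r'"([^"]+)"', section)`). Row 1 in the table above, however, starts with an unquoted `Headline: february gold rallies...`. The sketch below is a hypothetical variant of the pattern that makes the quotes optional. It assumes such entries still use the `Now answer this question:` phrasing, and the unquoted headline in the example is invented.

```python
import re

# Hypothetical variant of the Cell 6 pattern: the quotes around the headline are optional.
HEADLINE_QA_RE = re.compile(
    r'Headline:\s*"?(?P<headline>.+?)"?\s+Now answer this question:\s*'
    r"(?P<question>.+?\?)\s*(?P<answer>Yes|No)?\s*$",
    re.DOTALL,
)


def parse_block(block):
    """Return (headline, question, answer) for one block, or None if it does not match."""
    match = HEADLINE_QA_RE.match(block.strip())
    if match is None:
        return None
    return match.group("headline"), match.group("question"), match.group("answer")


print(parse_block('Headline: "Gold holds on to modest loss after data" Now answer this question: '
                  "Does the news headline talk about price in the future? No"))
print(parse_block("Headline: gold edges higher Now answer this question: "
                  "Does the news headline talk about price?"))
```

Unlike Cell 6, this version keeps the trailing `?` inside the question, and it returns `None` as the answer for an unanswered target question.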

In [9]:
# Cell 8: Define the full transformation function
# -----------------------------------------------
def extract_all_qa_pairs(df):
    """
    Extract every QA pair from the whole dataset,
    covering both the in-context example pairs and the target pairs.
    """
    print("Extracting all QA pairs...")
    start_time = time.time()

    all_qa_pairs = []
    qa_id_counter = 0

    for idx, row in df.iterrows():
        try:
            # Parse every QA pair in the `input` text
            qa_pairs = parse_qa_pairs_from_input(row['input'])

            if not qa_pairs:
                print(f"Warning: row {idx}: no QA pairs found")
                continue

            # Process each QA pair
            for i, pair in enumerate(qa_pairs):
                is_target = (i == len(qa_pairs) - 1)  # the last pair is the target question

                if is_target:
                    # For the target question, take the answer from gold_index
                    gold_index = row['gold_index']
                    options = row['options']

                    if isinstance(options, list) and len(options) > gold_index:
                        answer = options[gold_index]
                    else:
                        answer = "Yes" if gold_index == 1 else "No"
                else:
                    # For example questions, use the parsed answer
                    answer = pair['answer']

                # Build the JSON entry
                if answer:  # keep the pair only if it has an answer
                    qa_entry = {
                        "id": str(qa_id_counter),
                        "Question": pair['full_question'],
                        "Answer": answer
                    }

                    all_qa_pairs.append(qa_entry)
                    qa_id_counter += 1

            # Report progress
            if (idx + 1) % 100 == 0:
                print(f"Processed {idx + 1}/{len(df)} source records, extracted {len(all_qa_pairs)} QA pairs...")

        except Exception as e:
            print(f"Error while processing row {idx}: {e}")
            continue

    end_time = time.time()
    processing_time = end_time - start_time

    print("\nExtraction complete!")
    print(f"Source records: {len(df)}")
    print(f"Extracted QA pairs: {len(all_qa_pairs)}")
    print(f"Processing time: {processing_time:.2f} s")
    print(f"Average QA pairs per source record: {len(all_qa_pairs)/len(df):.1f}")

    return all_qa_pairs, {
        'original_records': len(df),
        'extracted_qa_pairs': len(all_qa_pairs),
        'processing_time': processing_time,
        'avg_pairs_per_record': len(all_qa_pairs)/len(df)
    }

# Run the full extraction
all_qa_data, extraction_stats = extract_all_qa_pairs(df)

# Cell 9: Inspect the extraction results
# --------------------------------------
print("Sample extraction results:")
print("="*50)

# Show the first 6 QA pairs
for i in range(min(6, len(all_qa_data))):
    print(f"\nQA pair {i+1} (ID: {all_qa_data[i]['id']}):")
    print(f"Question: {all_qa_data[i]['Question']}")
    print(f"Answer: {all_qa_data[i]['Answer']}")
    print("-" * 30)
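

Streaming output truncated; only the last 5000 lines are shown.
Warning: row 15155: no QA pairs found
Warning: row 15156: no QA pairs found
Warning: row 15158: no QA pairs found
Warning: row 15159: no QA pairs found
Warning: row 15160: no QA pairs found
Warning: row 15162: no QA pairs found
Warning: row 15163: no QA pairs found
Warning: row 15164: no QA pairs found
Warning: row 15165: no QA pairs found
Warning: row 15166: no QA pairs found
Warning: row 15167: no QA pairs found
Warning: row 15168: no QA pairs found
Warning: row 15169: no QA pairs found
Warning: row 15170: no QA pairs found
Warning: row 15171: no QA pairs found
Warning: row 15172: no QA pairs found
Warning: row 15174: no QA pairs found
Warning: row 15175: no QA pairs found
Warning: row 15176: no QA pairs found
Warning: row 15177: no QA pairs found
Warning: row 15179: no QA pairs found
Warning: row 15180: no QA pairs found
Warning: row 15181: no QA pairs found
Warning: row 15182: no QA pairs found
Warning: row 15183: no QA pairs found
Warning: row 15185: no QA pairs found
Warning: row 15186: no QA pairs found
Warning: row 15187: no QA pairs found
Warning: row 15188: no QA pairs found
Warning: row 15189: no QA pairs found
Warning: row 15190: no QA pairs found
Warning: row 15191: no QA pairs found
Warning: row 15192: no QA pairs found
Warning: row 15193: no QA pairs found
Warning: row 15194: no QA pairs found
Warning: row 15195: no QA pairs found
Warning: row 15196: no QA pairs found
Warning: row 15197: no QA pairs found
Warning: row 15198: no QA pairs found
Warning: row 15199: no QA pairs found
Warning: row 15200: no QA pairs found
Warning: row 15201: no QA pairs found
Warning: row 15202: no QA pairs found
Warning: row 15204: no QA pairs found
Warning: row 15205: no QA pairs found
Warning: row 15206: no QA pairs found
Warning: row 15208: no QA pairs found
Warning: row 15209: no QA pairs found
Warning: row 15210: no QA pairs found
Warning: row 15211: no QA pairs found
Warning: row 1521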
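Printing one warning per skipped row produced more than 5000 lines, so the notebook truncated the log above. That makes it hard to see how many rows were skipped or why. The sketch below aggregates the failures instead. `summarize_failures` is a hypothetical helper, and grouping by the first four words is only a rough proxy for which prompt template an entry uses.

```python
from collections import Counter


def summarize_failures(inputs, parse_fn, sample_size=3, top_n=5):
    """Count inputs for which parse_fn returns nothing, grouped by their opening words."""
    failed_rows = [i for i, text in enumerate(inputs) if not parse_fn(text)]
    # The first four words roughly identify which prompt template an entry uses.
    prefixes = Counter(" ".join(inputs[i].split()[:4]) for i in failed_rows)
    return {
        "total": len(inputs),
        "failed": len(failed_rows),
        "sample_rows": failed_rows[:sample_size],
        "top_prefixes": prefixes.most_common(top_n),
    }


# Toy stand-in parser that, like Cell 6, only accepts quoted headlines.
def toy_parser(text):
    return ["pair"] if text.startswith('Headline: "') else []


inputs = [
    'Headline: "gold futures add to gains after adp data" Now answer this question: ...',
    "Headline: gold edges higher Now answer this question: ...",
    "Some other prompt template ...",
]
print(summarize_failures(inputs, toy_parser))
```

On the real data, the call would be `summarize_failures(df["input"].tolist(), parse_qa_pairs_from_input)`.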

In [10]:
# Cell 10: Data quality check
# ---------------------------
def quality_check_extracted_data(data):
    """Check the quality of the extracted data."""
    print("Data quality check:")
    print("="*30)

    # Check the required fields
    required_fields = ['id', 'Question', 'Answer']
    field_missing_count = {field: 0 for field in required_fields}

    for entry in data:
        for field in required_fields:
            if field not in entry or not entry[field]:
                field_missing_count[field] += 1

    print("Field completeness:")
    for field, missing_count in field_missing_count.items():
        if missing_count == 0:
            print(f"  ✓ {field}: complete")
        else:
            print(f"  ✗ {field}: missing in {missing_count} records")

    # Check the answer distribution
    answers = [entry['Answer'] for entry in data if 'Answer' in entry]
    answer_counts = pd.Series(answers).value_counts()
    print("\nAnswer distribution:")
    for answer, count in answer_counts.items():
        print(f"  {answer}: {count} ({count/len(answers)*100:.1f}%)")

    # Check the question-length distribution
    question_lengths = [len(entry['Question']) for entry in data if 'Question' in entry]
    print("\nQuestion length statistics:")
    print(f"  Mean length: {sum(question_lengths)/len(question_lengths):.0f} characters")
    print(f"  Shortest: {min(question_lengths)} characters")
    print(f"  Longest: {max(question_lengths)} characters")

    return answer_counts

# Run the quality check
answer_distribution = quality_check_extracted_data(all_qa_data)

Data quality check:
Field completeness:
  ✓ id: complete
  ✓ Question: complete
  ✓ Answer: complete

Answer distribution:
  No: 6917 (67.5%)
  Yes: 3337 (32.5%)

Question length statistics:
  Mean length: 148 characters
  Shortest: 85 characters
  Longest: 238 characters
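The field and length checks above do not look for repeated questions. If the same in-context demonstrations recur across rows, the extracted set would contain duplicates that over-weight those examples during fine-tuning. Whether that happens here is something the data would have to confirm. Below is a sketch of a duplicate and conflict check; `duplicate_report` is a hypothetical helper.

```python
def duplicate_report(records):
    """Summarize repeated Question texts and whether their answers disagree."""
    answers_by_question = {}
    for rec in records:
        answers_by_question.setdefault(rec["Question"].strip(), []).append(rec["Answer"])
    duplicated = {q: a for q, a in answers_by_question.items() if len(a) > 1}
    return {
        "unique_questions": len(answers_by_question),
        "duplicated_questions": len(duplicated),
        "extra_copies": sum(len(a) - 1 for a in duplicated.values()),
        "conflicting_answers": sum(1 for a in duplicated.values() if len(set(a)) > 1),
    }


records = [
    {"id": "0", "Question": "Q1?", "Answer": "Yes"},
    {"id": "1", "Question": "Q1?", "Answer": "Yes"},
    {"id": "2", "Question": "Q2?", "Answer": "No"},
    {"id": "3", "Question": "Q2?", "Answer": "Yes"},
    {"id": "4", "Question": "Q3?", "Answer": "No"},
]
print(duplicate_report(records))
# {'unique_questions': 3, 'duplicated_questions': 2, 'extra_copies': 2, 'conflicting_answers': 1}
```

Running `duplicate_report(all_qa_data)` would show whether deduplication is needed before training. Conflicting answers in particular would point to a parsing or labeling problem.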


In [11]:
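# Cell 11: Check the working directory
# ------------------------------------
import os

# Get the current working directory
current_path = os.getcwd()
print("Current working directory:", current_path)

# List the directory contents (to confirm the path is correct)
print("\nDirectory contents:")
print(os.listdir())

Current working directory: /content

Directory contents:
['.config', 'sample_data']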


In [17]:
# Cell 12: Save the results
# -------------------------
def save_extracted_qa_data(data, filename='extracted_qa_pairs.json'):
    """Save the extracted QA pairs to a JSON file."""
    print(f"Saving data to {filename}...")

    try:
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)

        # Character count of a compact (unindented) dump, not the on-disk size of the indented file
        file_size_mb = len(json.dumps(data, ensure_ascii=False)) / (1024 * 1024)
        print(f"✓ Data saved to {filename}")
        print(f"Approximate size (compact JSON): {file_size_mb:.2f} MB")
        print(f"Records: {len(data)}")

    except Exception as e:
        print(f"Error while saving the file: {e}")

# Save the data
save_extracted_qa_data(all_qa_data)


Saving data to extracted_qa_pairs.json...
✓ Data saved to extracted_qa_pairs.json
Approximate size (compact JSON): 1.94 MB
Records: 10254
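Cell 12 computes its reported size from `json.dumps(data, ensure_ascii=False)`, which is a compact string measured in characters. The file on disk is written with `indent=2`, so it is larger; `os.path.getsize` returns the real byte count. The sketch below also writes JSON Lines (one object per line), a layout that many fine-tuning pipelines accept and that can be streamed record by record. The function names and the `.jsonl` file name are hypothetical.

```python
import json
import os
import tempfile


def save_jsonl(records, path):
    """Write one JSON object per line and return the resulting file size in bytes."""
    with open(path, "w", encoding="utf-8") as f:
        for rec in records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
    return os.path.getsize(path)  # actual on-disk size, unlike len() of an in-memory string


def load_jsonl(path):
    """Read records back, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


records = [{"id": "0", "Question": "Q?", "Answer": "Yes"}]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sample_qa_pairs.jsonl")  # hypothetical file name
    size = save_jsonl(records, path)
    loaded = load_jsonl(path)
print(size, loaded == records)
```

On the notebook's data, the call would be `save_jsonl(all_qa_data, "extracted_qa_pairs.jsonl")`.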


In [13]:
# Cell 13: Save the statistics
# ----------------------------
def save_extraction_statistics(stats, answer_dist, filename='extraction_statistics.json'):
    """Save the extraction statistics."""

    detailed_stats = {
        'extraction_summary': stats,
        'answer_distribution': dict(answer_dist),
        'schema': {
            'id': 'string - auto-generated unique identifier',
            'Question': 'string - full question text (headline plus question)',
            'Answer': 'string - Yes or No'
        },
        'data_sources': {
            'examples': 'parsed from the in-context example QA pairs in the input text',
            'targets': 'answer to the last question determined from gold_index'
        },
        'timestamp': datetime.now().isoformat()
    }

    try:
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(detailed_stats, f, indent=2, ensure_ascii=False)

        print(f"✓ Statistics saved to {filename}")

    except Exception as e:
        print(f"Error while saving statistics: {e}")

# Save the statistics
save_extraction_statistics(extraction_stats, answer_distribution)

# Cell 14: Final report
# ---------------------
print("="*60)
print("QA-pair extraction report")
print("="*60)
print(f"Source dataset records: {extraction_stats['original_records']}")
print(f"Extracted QA pairs: {extraction_stats['extracted_qa_pairs']}")
print(f"Average QA pairs per record: {extraction_stats['avg_pairs_per_record']:.1f}")
print(f"Total processing time: {extraction_stats['processing_time']:.2f} s")

print("\nAnswer distribution:")
for answer, count in answer_distribution.items():
    print(f"  {answer}: {count} ({count/len(all_qa_data)*100:.1f}%)")

print("\nOutput files:")
print("  - extracted_qa_pairs.json (main data file)")
print("  - extraction_statistics.json (statistics)")

print("\nFinal JSON format:")
print("""{
  "id": "<auto_generated_id>",
  "Question": "Headline: \"<headline>\" Now answer this question: <question>?",
  "Answer": "<Yes_or_No>"
}""")

print("\nExtraction notes:")
print("- Example QA pairs: answers parsed directly from the input text")
print("- Target QA pairs: answer determined from gold_index (1=Yes, 0=No)")
print("- ID: auto-incremented")

print("\nTask complete! ✓")

Error while saving statistics: Object of type int64 is not JSON serializable
QA-pair extraction report
Source dataset records: 20547
Extracted QA pairs: 10254
Average QA pairs per record: 0.5
Total processing time: 1.71 s

Answer distribution:
  No: 6917 (67.5%)
  Yes: 3337 (32.5%)

Output files:
  - extracted_qa_pairs.json (main data file)
  - extraction_statistics.json (statistics)

Final JSON format:
{
  "id": "<auto_generated_id>",
  "Question": "Headline: "<headline>" Now answer this question: <question>?",
  "Answer": "<Yes_or_No>"
}

Extraction notes:
- Example QA pairs: answers parsed directly from the input text
- Target QA pairs: answer determined from gold_index (1=Yes, 0=No)
- ID: auto-incremented

Task complete! ✓
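The `Object of type int64 is not JSON serializable` error above comes from `dict(answer_dist)`. The counts returned by `pandas.Series.value_counts()` are NumPy `int64` scalars, which the standard `json` module cannot encode, so the statistics file was not written successfully even though the report lists it. One fix is to pass a `default=` fallback to `json.dump`. The `json_default` function below is a hypothetical helper that relies on NumPy scalars and arrays (and pandas Series) exposing `.tolist()`. A stand-in class is used so the sketch runs without NumPy.

```python
import json


def json_default(obj):
    """Fallback for json.dump: convert objects exposing .tolist() (NumPy scalars/arrays, pandas Series)."""
    if hasattr(obj, "tolist"):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


class FakeInt64:
    """Stand-in for numpy.int64 so this sketch runs without NumPy installed."""

    def __init__(self, value):
        self.value = value

    def tolist(self):
        return self.value


print(json.dumps({"No": FakeInt64(6917), "Yes": FakeInt64(3337)}, default=json_default))
# {"No": 6917, "Yes": 3337}
```

In Cell 13, this would become `json.dump(detailed_stats, f, indent=2, ensure_ascii=False, default=json_default)`. Alternatively, build the distribution as `{str(k): int(v) for k, v in answer_dist.items()}` so that only built-in types reach `json`.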


In [15]:
# Cell 15: Verify the saved file
# ------------------------------
def verify_extracted_file(filename='extracted_qa_pairs.json'):
    """Reload the saved QA-pair file and print a few records."""
    try:
        with open(filename, 'r', encoding='utf-8') as f:
            loaded_data = json.load(f)

        print(f"✓ File {filename} verified")
        print(f"Loaded QA pairs: {len(loaded_data)}")

        # Show up to 12 complete examples
        print("\nComplete examples:")
        for i in range(min(12, len(loaded_data))):
            print(f"\nExample {i+1}:")
            print(json.dumps(loaded_data[i], indent=2, ensure_ascii=False))

        return loaded_data

    except Exception as e:
        print(f"Error while verifying the file: {e}")
        return None

# Verify the saved file
print("Verifying the saved file:")
verified_data = verify_extracted_file()

Verifying the saved file:
✓ File extracted_qa_pairs.json verified
Loaded QA pairs: 10254

Complete examples:

Example 1:
{
  "id": "0",
  "Question": "Headline: \"Gold falls to Rs 30,800; silver down at Rs 41,200 per kg\" Now answer this question: Does the news headline talk about price in the past?",
  "Answer": "Yes"
}

Example 2:
{
  "id": "1",
  "Question": "Headline: \"gold futures add to gains after adp data\" Now answer this question: Does the news headline talk about price?",
  "Answer": "Yes"
}

Example 3:
{
  "id": "2",
  "Question": "Headline: \"Gold holds on to modest loss after data\" Now answer this question: Does the news headline talk about price in the future?",
  "Answer": "No"
}

Example 4:
{
  "id": "3",
  "Question": "Headline: \"spot gold quoted at $417.50, down 20c from new york\" Now answer this question: Does the news headline talk about a general event (apart from prices) in the past?",
  "Answer": "No"
}

Example 5:
{
  "id": "4",
  "Question": "Headline: \"gold hits new record high at $1,036.20 an ounce\" Now answer this ques