In [9]:
import json
from collections import Counter
import random

# Stratified down-sampling of test.jsonl by "category-subcategory" pair.
# Small pairs are kept whole; larger pairs are thinned according to
# sample_size_for_test_pair(). The result is written to test_sampled.jsonl
# and the resulting distribution is printed.

# Fixed seed so that re-running this cell reproduces test_sampled.jsonl exactly.
TEST_SAMPLING_SEED = 42


def pair_label(record_line):
    """Return the "category-subcategory" label of one JSONL record line.

    Args:
        record_line: one line of JSON text whose object has a ``metadata``
            dict containing ``category`` and ``subcategory``.

    Raises:
        json.JSONDecodeError: if the line is not valid JSON.
        KeyError: if ``metadata``, ``category`` or ``subcategory`` is missing.
    """
    metadata = json.loads(record_line)['metadata']
    return f"{metadata['category']}-{metadata['subcategory']}"


def group_lines_by_pair(raw_lines):
    """Group raw JSONL lines by pair label, preserving first-seen order.

    Blank lines are skipped instead of crashing ``json.loads``. Every kept
    line is normalised to end with a single newline: the records are written
    back group by group, so a final input line without a terminator could
    otherwise land in the middle of the output and fuse with the next record.

    Args:
        raw_lines: iterable of text lines (e.g. an open file).

    Returns:
        dict mapping pair label -> list of newline-terminated record lines.
    """
    groups = {}
    for raw in raw_lines:
        if not raw.strip():
            continue
        record_line = raw.rstrip('\n') + '\n'
        groups.setdefault(pair_label(record_line), []).append(record_line)
    return groups


def sample_size_for_test_pair(count):
    """Return how many records to keep for a pair that has ``count`` records.

    * ``count <= 400``: keep everything.
    * ``400 < count <= 2000``: keep ratio falls linearly from 0.66 to 0.3.
    * ``count > 2000``: keep ratio is 0.3.

    NOTE(review): the ratio starts at 0.66 (not 1.0) just above 400, so a pair
    with 401 records keeps fewer records than a pair with 400. Kept as-is
    because it matches the original rule; confirm this is intended.
    """
    if count <= 400:
        return count
    if count > 2000:
        ratio = 0.3
    else:
        ratio = 0.66 - (count - 400) * 0.36 / (2000 - 400)
    return int(count * ratio)


def downsample_test_groups(groups, rng):
    """Apply the test-split sampling rule to every pair.

    Args:
        groups: dict mapping pair label -> list of record lines.
        rng: a ``random.Random`` instance; passing a seeded instance makes the
            selection reproducible without touching the global random state.

    Returns:
        dict mapping pair label -> list of kept record lines (same key order
        as ``groups``).
    """
    kept = {}
    for pair, pair_lines in groups.items():
        keep = sample_size_for_test_pair(len(pair_lines))
        if keep < len(pair_lines):
            kept[pair] = rng.sample(pair_lines, keep)
        else:
            kept[pair] = list(pair_lines)
    return kept


# Jupyter executes cells with __name__ == "__main__", so this runs normally in
# the notebook; the guard only keeps the helpers importable without file I/O.
if __name__ == "__main__":
    # Read and group the source records.
    with open('test.jsonl', 'r', encoding='utf-8') as f:
        pair_counts = group_lines_by_pair(f)

    # Sample each pair and flatten the result in group order.
    sampled_groups = downsample_test_groups(
        pair_counts, random.Random(TEST_SAMPLING_SEED)
    )
    sampled_data = [line for kept in sampled_groups.values() for line in kept]

    # Write the sampled records.
    with open('test_sampled.jsonl', 'w', encoding='utf-8') as f:
        f.writelines(sampled_data)

    # The per-pair counts are already known, so no second JSON parse is needed.
    distribution = Counter(
        {pair: len(kept) for pair, kept in sampled_groups.items()}
    )

    # Print the distribution after sampling, most frequent first.
    print("采样后的类别-子类别组合分布情况:")
    for pair, count in distribution.most_common():
        print(f"{pair}: {count}")


采样后的类别-子类别组合分布情况:
data understanding-fact checking: 624
data understanding-data identification: 624
data understanding-data comparison: 616
visual understanding-visual elements retrieval: 613
data understanding-data extraction with condition: 529
composite understanding-data identification: 426
data understanding-data counting with condition: 416
visual understanding-chart classification: 330
visual understanding-style detection: 324
composite understanding-fact checking: 294
composite understanding-data comparison: 278
composite understanding-data counting with condition: 234
composite understanding-data extraction with condition: 212


In [2]:
import json
from collections import Counter
import random

# Stratified down-sampling of train.jsonl by "category-subcategory" pair.
# Pairs with more than 10000 records are thinned according to
# sample_size_for_train_pair(); all other pairs are kept whole. The result is
# written to train_sampled.jsonl and the resulting distribution is printed.

# Fixed seed so that re-running this cell reproduces train_sampled.jsonl exactly.
TRAIN_SAMPLING_SEED = 42


def pair_label(record_line):
    """Return the "category-subcategory" label of one JSONL record line.

    Args:
        record_line: one line of JSON text whose object has a ``metadata``
            dict containing ``category`` and ``subcategory``.

    Raises:
        json.JSONDecodeError: if the line is not valid JSON.
        KeyError: if ``metadata``, ``category`` or ``subcategory`` is missing.
    """
    metadata = json.loads(record_line)['metadata']
    return f"{metadata['category']}-{metadata['subcategory']}"


def group_lines_by_pair(raw_lines):
    """Group raw JSONL lines by pair label, preserving first-seen order.

    Blank lines are skipped instead of crashing ``json.loads``. Every kept
    line is normalised to end with a single newline: the records are written
    back group by group, so a final input line without a terminator could
    otherwise land in the middle of the output and fuse with the next record.

    Args:
        raw_lines: iterable of text lines (e.g. an open file).

    Returns:
        dict mapping pair label -> list of newline-terminated record lines.
    """
    groups = {}
    for raw in raw_lines:
        if not raw.strip():
            continue
        record_line = raw.rstrip('\n') + '\n'
        groups.setdefault(pair_label(record_line), []).append(record_line)
    return groups


def sample_size_for_train_pair(count):
    """Return how many records to keep for a pair that has ``count`` records.

    Pairs with at most 10000 records are kept whole. Larger pairs are cut to
    half their size, but never below 10000 records.
    """
    if count <= 10000:
        return count
    return max(10000, count // 2)


def downsample_train_groups(groups, rng):
    """Apply the train-split sampling rule to every pair.

    Args:
        groups: dict mapping pair label -> list of record lines.
        rng: a ``random.Random`` instance; passing a seeded instance makes the
            selection reproducible without touching the global random state.

    Returns:
        dict mapping pair label -> list of kept record lines (same key order
        as ``groups``).
    """
    kept = {}
    for pair, pair_lines in groups.items():
        keep = sample_size_for_train_pair(len(pair_lines))
        if keep < len(pair_lines):
            kept[pair] = rng.sample(pair_lines, keep)
        else:
            kept[pair] = list(pair_lines)
    return kept


# Jupyter executes cells with __name__ == "__main__", so this runs normally in
# the notebook; the guard only keeps the helpers importable without file I/O.
if __name__ == "__main__":
    # Read and group the source records.
    with open('train.jsonl', 'r', encoding='utf-8') as f:
        pair_counts = group_lines_by_pair(f)

    # Sample each pair and flatten the result in group order.
    sampled_groups = downsample_train_groups(
        pair_counts, random.Random(TRAIN_SAMPLING_SEED)
    )
    sampled_data = [line for kept in sampled_groups.values() for line in kept]

    # Write the sampled records.
    with open('train_sampled.jsonl', 'w', encoding='utf-8') as f:
        f.writelines(sampled_data)

    # The per-pair counts are already known, so no second JSON parse is needed.
    distribution = Counter(
        {pair: len(kept) for pair, kept in sampled_groups.items()}
    )

    # Print the distribution after sampling, most frequent first.
    print("采样后类别-子类别组合的分布情况:")
    for pair, count in distribution.most_common():
        print(f"{pair}: {count}")


采样后类别-子类别组合的分布情况:
data understanding-data comparison: 21271
data understanding-data identification: 19397
data understanding-fact checking: 19147
data understanding-data extraction with condition: 13489
visual understanding-visual elements retrieval: 11352
data understanding-data counting with condition: 10000
composite understanding-data identification: 7184
visual understanding-chart classification: 6786
visual understanding-style detection: 6579
composite understanding-fact checking: 4740
composite understanding-data comparison: 2873
composite understanding-data counting with condition: 2435
composite understanding-data extraction with condition: 2273


In [5]:
import json
import re

# Lower-case the answer-option text of "chart classification" questions in
# train_sampled.jsonl. The file is read completely and then rewritten in place.

# The option block: from the first "Select" up to (not including) the blank
# line that precedes "Answer".
OPTION_BLOCK_RE = re.compile(r'Select.*?(?=\n\nAnswer)', re.DOTALL)

# One option per line, labelled "A.", "A)" or "(A)".
# Group 1 is the label with its surrounding spaces, group 2 is the option text.
# Anchoring at line start keeps capital A-D letters inside the instruction
# sentence (e.g. "Select the Chart type") from being mistaken for a label.
# NOTE(review): assumes options are newline-separated, as the original
# extraction pattern did - confirm against the data.
OPTION_LINE_RE = re.compile(r'^([ \t]*\(?[A-D][.)][ \t]*)(.*)$', re.MULTILINE)


def lowercase_option_text(question_text):
    """Return ``question_text`` with the text of each A-D option lower-cased.

    Only the first "Select ... <blank line> Answer" block is touched; option
    labels keep their case and everything outside the block is returned
    unchanged. If no such block exists the input is returned as-is.

    The edited block is spliced back by slicing rather than being passed to
    ``re.sub`` as a replacement template, so backslashes in the question are
    preserved literally instead of being interpreted as escapes.

    Args:
        question_text: full text of the question turn.

    Returns:
        The question text with option contents converted to lower case.
    """
    block = OPTION_BLOCK_RE.search(question_text)
    if block is None:
        return question_text
    lowered = OPTION_LINE_RE.sub(
        lambda m: m.group(1) + m.group(2).lower(), block.group()
    )
    return question_text[:block.start()] + lowered + question_text[block.end():]


# Jupyter executes cells with __name__ == "__main__", so this runs normally in
# the notebook; the guard only keeps the helper importable without file I/O.
if __name__ == "__main__":
    modified_data = []

    # Load every record before the file is reopened for writing below.
    with open('train_sampled.jsonl', 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines instead of failing in json.loads
            data = json.loads(line)
            # Only chart-classification questions are rewritten.
            if "chart classification" in data['metadata']['subcategory'].lower():
                first_turn = data['conversation'][0]
                first_turn['content'] = lowercase_option_text(first_turn['content'])
            modified_data.append(data)

    # Write the (possibly modified) records back to the same file.
    with open('train_sampled.jsonl', 'w', encoding='utf-8') as f:
        for data in modified_data:
            f.write(json.dumps(data, ensure_ascii=False) + '\n')

    print("已完成选项小写转换")


已完成选项小写转换
