In [14]:
import json
import random
from collections import defaultdict
import os

# ======= 用户自定义配置区域 =======
json_path = 'arxivqa.jsonl'                # 输入数据路径（支持 .json or .jsonl）
train_cats = ['physics', 'cs']             # 训练和验证用的两个分类
zeroshot_cat = 'physics'                  # Zero-Shot 测试的分类
output_dir = 'split_result'        # 输出目录
seed = 42                                  # 随机种子
max_samples_per_class = 1000               # 每个训练/验证类别最大样本数
max_zeroshot_samples = 200                # Zero-shot 分类最大样本数
# ===================================

def extract_category(example):
    return example['id'].split('-')[0]

def load_json_or_jsonl(path):
    with open(path, 'r', encoding='utf-8') as f:
        first_line = f.readline()
        f.seek(0)
        if first_line.strip().startswith('{'):
            return [json.loads(line.strip()) for line in f if line.strip()]
        else:
            return json.load(f)

def split_data(data, train_cats, zeroshot_cat, seed=42, max_per_class=None, max_zero=None):
    random.seed(seed)
    categorized = defaultdict(list)

    for item in data:
        cat = extract_category(item)
        categorized[cat].append(item)

    train_data = []
    val_data = []
    test_data = []

    for cat in train_cats:
        items = categorized[cat]
        random.shuffle(items)
        if max_per_class is not None:
            items = items[:max_per_class]
        n = len(items)
        train_end = int(0.8 * n)
        val_end = int(0.9 * n)
        train_data.extend(items[:train_end])
        val_data.extend(items[train_end:val_end])

    # Zero-shot 测试数据（也限制最大数量）
    test_items = categorized[zeroshot_cat]
    random.shuffle(test_items)
    if max_zero is not None:
        test_items = test_items[:max_zero]
    test_data = test_items

    return train_data, val_data, test_data

def save_as_jsonl(data, path):
    with open(path, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

def main():
    data = load_json_or_jsonl(json_path)
    train_data, val_data, test_data = split_data(
        data,
        train_cats,
        zeroshot_cat,
        seed=seed,
        max_per_class=max_samples_per_class,
        max_zero=max_zeroshot_samples
    )

    os.makedirs(output_dir, exist_ok=True)

    save_as_jsonl(train_data, os.path.join(output_dir, 'train.jsonl'))
    save_as_jsonl(val_data, os.path.join(output_dir, 'val.jsonl'))
    save_as_jsonl(test_data, os.path.join(output_dir, 'test_zeroshot.jsonl'))

    print(f"✅ 数据划分完成，输出目录：{output_dir}")
    print(f"训练集大小: {len(train_data)}")
    print(f"验证集大小: {len(val_data)}")
    print(f"Zero-Shot 测试集大小: {len(test_data)}")

if __name__ == '__main__':
    main()


✅ 数据划分完成，输出目录：split_result
训练集大小: 1600
验证集大小: 200
Zero-Shot 测试集大小: 200


In [21]:
import json
from collections import Counter


def extract_category(example):
    return example['id'].split('-')[0]

def count_categories(jsonl_path):
    counter = Counter()
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                item = json.loads(line)
                category = extract_category(item)
                counter[category] += 1
    return counter

if __name__ == "__main__":
    jsonl_path1 = "I:/ArxivQA/split_result/test_zeroshot.jsonl"  # 替换成你的文件路径
    counts1 = count_categories(jsonl_path1)
    print("每个zero shot:\n")
    for cat, count in counts1.items():
        print(f"{cat:15s}: {count}\n")
    
    jsonl_path2 = 'I:/ArxivQA\split_result/train.jsonl'
    counts2 = count_categories(jsonl_path2)
    print("每个train:\n")
    for cat, count in counts2.items():
        print(f"{cat:15s}: {count}")
        
    jsonl_path3 = 'I:/ArxivQA\split_result/val.jsonl'
    counts3 = count_categories(jsonl_path3)
    print("每个val:\n")
    for cat, count in counts3.items():
        print(f"{cat:15s}: {count}")
        
        

每个zero shot:

physics        : 200

每个train:

physics        : 800
cs             : 800
每个val:

physics        : 100
cs             : 100


In [26]:
with open("I:/ArxivQA/split_result/test_zeroshot.jsonl", 'r') as fr:
  arxiv_qa = [ json.loads(line.strip()) for line in fr]
len(arxiv_qa)

200

In [28]:
with open("I:/ArxivQA/split_result/test_zeroshot.jsonl", 'r') as fr:
  arxiv_qa = [ json.loads(line.strip()) for line in fr]
  
test_dataset = []
for i in range(len(arxiv_qa)):
    sample = arxiv_qa[i]
    try:
        if sample['question'] is None:
            print(f"第 {i} 个样本没有问题，跳过")
            continue

        test_dataset.append({
            'id':sample['id'],
            "image": sample['image'], 
            "question": sample["question"],
            "choices": sample["options"],
            "answer": sample["label"],
            "solution": sample["rationale"]
        })
    except Exception as e:
        print(f"跳过第 {i} 个样本，错误：{e}")
        continue

In [29]:
test_dataset[1]

{'id': 'physics-40887',
 'image': 'images/1904.13123_0.jpg',
 'question': 'Which computational method consistently estimates the energy closest to the FCI benchmark across all C-H bond lengths for the \\( ^1\\Sigma^+ \\) state?',
 'choices': ['EOM-CCSD',
  'EOM-pCCD-LCCSD(HF)',
  'EOM-pCCD-LCCSD(opt)',
  'All methods estimate equally close to FCI',
  '## Question 2'],
 'answer': 'C. EOM-pCCD-LCCSD(opt)',
 'solution': 'In the figure for the \\( ^1\\Sigma^+ \\) state, the EOM-pCCD-LCCSD(opt) data points are consistently nearest to the FCI dashed line across the range of C-H bond lengths, indicating that this method provides the closest estimates to the FCI benchmark.'}

In [1]:
import os
import json
import shutil

# ======= 用户配置区域 =======
jsonl_files = [
    "split_result/train.jsonl",
    "split_result/val.jsonl",
    "split_result/test_zeroshot.jsonl"
]  # 你的三个jsonl路径

image_root = '.'              # 原始图片所在的根目录（如 images/ 在这个目录下）
output_dir = 'using_image'     # 新建图片保存目录
# ============================

def collect_image_paths(jsonl_paths):
    image_paths = set()
    for path in jsonl_paths:
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip(): continue
                data = json.loads(line)
                image_path = data.get("image")
                if image_path:
                    image_paths.add(image_path)
    return sorted(image_paths)

def copy_images(image_paths, src_root, dst_dir):
    os.makedirs(dst_dir, exist_ok=True)
    missing = []

    for rel_path in image_paths:
        src_path = os.path.join(src_root, rel_path)
        dst_path = os.path.join(dst_dir, os.path.basename(rel_path))

        if os.path.exists(src_path):
            shutil.copy2(src_path, dst_path)
        else:
            missing.append(rel_path)

    print(f"✅ 已复制 {len(image_paths) - len(missing)} 张图片到：{dst_dir}")
    if missing:
        print(f"⚠️ 有 {len(missing)} 张图片未找到:")
        for m in missing[:5]:
            print(f" - {m}")
        if len(missing) > 5:
            print("...")

def main():
    all_image_paths = collect_image_paths(jsonl_files)
    copy_images(all_image_paths, image_root, output_dir)

if __name__ == '__main__':
    main()


✅ 已复制 1933 张图片到：using_image
