In [None]:
import json
import random
from collections import defaultdict
import os

# ======= 用户自定义配置区域 =======
json_path = 'arxivqa.jsonl'                # 输入数据路径（支持 .json or .jsonl）
train_cats = ['math', 'cs']             # 训练和验证用的两个分类
zeroshot_cat = 'physics'                   # Zero-Shot 测试的分类
output_dir = 'split_result'                # 输出目录
seed = 42                                  # 随机种子
max_samples_per_class = 2000               # 每个训练类别最大样本数
val_samples_per_class = 500                # 每个验证类别最大样本数
max_zeroshot_samples = 500                 # Zero-shot 分类最大样本数
# ===================================

def extract_category(example):
    return example['id'].split('-')[0]

def load_json_or_jsonl(path):
    with open(path, 'r', encoding='utf-8') as f:
        first_line = f.readline()
        f.seek(0)
        if first_line.strip().startswith('{'):
            return [json.loads(line.strip()) for line in f if line.strip()]
        else:
            return json.load(f)

def split_data(data, train_cats, zeroshot_cat, seed=42, max_per_class=None, val_per_class=None, max_zero=None):
    random.seed(seed)
    categorized = defaultdict(list)

    for item in data:
        cat = extract_category(item)
        categorized[cat].append(item)

    train_data = []
    val_data = []
    test_data = []

    for cat in train_cats:
        items = categorized[cat]
        random.shuffle(items)

        required_total = (max_per_class or 0) + (val_per_class or 0)
        if len(items) < required_total:
            raise ValueError(f"分类 {cat} 的样本不足 {required_total} 条，当前只有 {len(items)} 条")

        train_items = items[:max_per_class]
        val_items = items[max_per_class:max_per_class + val_per_class]

        train_data.extend(train_items)
        val_data.extend(val_items)

        print(f"[{cat}] 训练集: {len(train_items)} 条，验证集: {len(val_items)} 条，总样本: {len(items)}")

    # Zero-shot 测试数据
    test_items = categorized[zeroshot_cat]
    random.shuffle(test_items)
    if max_zero is not None:
        test_items = test_items[:max_zero]
    test_data = test_items

    print(f"[{zeroshot_cat}] Zero-shot 测试集: {len(test_data)} 条")

    return train_data, val_data, test_data

def save_as_jsonl(data, path):
    with open(path, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

def main():
    data = load_json_or_jsonl(json_path)
    train_data, val_data, test_data = split_data(
        data,
        train_cats,
        zeroshot_cat,
        seed=seed,
        max_per_class=max_samples_per_class,
        val_per_class=val_samples_per_class,
        max_zero=max_zeroshot_samples
    )

    os.makedirs(output_dir, exist_ok=True)

    save_as_jsonl(train_data, os.path.join(output_dir, 'train.jsonl'))
    save_as_jsonl(val_data, os.path.join(output_dir, 'val.jsonl'))
    save_as_jsonl(test_data, os.path.join(output_dir, 'test_zeroshot.jsonl'))

    print(f"\n✅ 数据划分完成，输出目录：{output_dir}")
    print(f"训练集总大小: {len(train_data)}")
    print(f"验证集总大小: {len(val_data)}")
    print(f"Zero-Shot 测试集大小: {len(test_data)}")

if __name__ == '__main__':
    main()


✅ 数据划分完成，输出目录：split_result
训练集大小: 1600
验证集大小: 200
Zero-Shot 测试集大小: 200


In [None]:
import json
from collections import Counter


def extract_category(example):
    return example['id'].split('-')[0]

def count_categories(jsonl_path):
    counter = Counter()
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                item = json.loads(line)
                category = extract_category(item)
                counter[category] += 1
    return counter

if __name__ == "__main__":
    jsonl_path1 = "split_result/test_zeroshot.jsonl"  # 替换成你的文件路径
    counts1 = count_categories(jsonl_path1)
    print("每个zero shot:")
    for cat, count in counts1.items():
        print(f"{cat:15s}: {count}")
    print('\n')
     
    jsonl_path2 = 'split_result/train.jsonl'
    counts2 = count_categories(jsonl_path2)
    print("每个train:")
    for cat, count in counts2.items():
        print(f"{cat:15s}: {count}")
    print('\n')  
    
    jsonl_path3 = 'split_result/val.jsonl'
    counts3 = count_categories(jsonl_path3)
    print("每个val:")
    for cat, count in counts3.items():
        print(f"{cat:15s}: {count}")

每个zero shot:

physics        : 200

每个train:

physics        : 800
cs             : 800
每个val:

physics        : 100
cs             : 100


In [None]:
with open("split_result/test_zeroshot.jsonl", 'r') as fr:
  arxiv_qa = [ json.loads(line.strip()) for line in fr]
len(arxiv_qa)

200

In [None]:
with open("split_result/test_zeroshot.jsonl", 'r') as fr:
  arxiv_qa = [ json.loads(line.strip()) for line in fr]
  
test_dataset = []
for i in range(len(arxiv_qa)):
    sample = arxiv_qa[i]
    try:
        if sample['question'] is None:
            print(f"第 {i} 个样本没有问题，跳过")
            continue

        test_dataset.append({
            'id':sample['id'],
            "image": sample['image'], 
            "question": sample["question"],
            "choices": sample["options"],
            "answer": sample["label"],
            "solution": sample["rationale"]
        })
    except Exception as e:
        print(f"跳过第 {i} 个样本，错误：{e}")
        continue

In [29]:
test_dataset[1]

{'id': 'physics-40887',
 'image': 'images/1904.13123_0.jpg',
 'question': 'Which computational method consistently estimates the energy closest to the FCI benchmark across all C-H bond lengths for the \\( ^1\\Sigma^+ \\) state?',
 'choices': ['EOM-CCSD',
  'EOM-pCCD-LCCSD(HF)',
  'EOM-pCCD-LCCSD(opt)',
  'All methods estimate equally close to FCI',
  '## Question 2'],
 'answer': 'C. EOM-pCCD-LCCSD(opt)',
 'solution': 'In the figure for the \\( ^1\\Sigma^+ \\) state, the EOM-pCCD-LCCSD(opt) data points are consistently nearest to the FCI dashed line across the range of C-H bond lengths, indicating that this method provides the closest estimates to the FCI benchmark.'}

In [1]:
import os
import json
import shutil

# ======= 用户配置区域 =======
jsonl_files = [
    "split_result/train.jsonl",
    "split_result/val.jsonl",
    "split_result/test_zeroshot.jsonl"
]  # 你的三个jsonl路径

image_root = '.'              # 原始图片所在的根目录（如 images/ 在这个目录下）
output_dir = 'using_image'     # 新建图片保存目录
# ============================

def collect_image_paths(jsonl_paths):
    image_paths = set()
    for path in jsonl_paths:
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip(): continue
                data = json.loads(line)
                image_path = data.get("image")
                if image_path:
                    image_paths.add(image_path)
    return sorted(image_paths)

def copy_images(image_paths, src_root, dst_dir):
    os.makedirs(dst_dir, exist_ok=True)
    missing = []

    for rel_path in image_paths:
        src_path = os.path.join(src_root, rel_path)
        dst_path = os.path.join(dst_dir, os.path.basename(rel_path))

        if os.path.exists(src_path):
            shutil.copy2(src_path, dst_path)
        else:
            missing.append(rel_path)

    print(f"✅ 已复制 {len(image_paths) - len(missing)} 张图片到：{dst_dir}")
    if missing:
        print(f"⚠️ 有 {len(missing)} 张图片未找到:")
        for m in missing[:5]:
            print(f" - {m}")
        if len(missing) > 5:
            print("...")

def main():
    all_image_paths = collect_image_paths(jsonl_files)
    copy_images(all_image_paths, image_root, output_dir)

if __name__ == '__main__':
    main()


✅ 已复制 1933 张图片到：using_image


In [2]:
import tarfile
import os

def compress_to_tgz(input_paths, output_tgz_path):
    with tarfile.open(output_tgz_path, "w:gz") as tar:
        for path in input_paths:
            if not os.path.exists(path):
                print(f"⚠️ 文件或目录不存在：{path}")
                continue
            tar.add(path, arcname=os.path.basename(path))
    print(f"📦 压缩完成：{output_tgz_path}")

# ======= 用户配置区 =======
# 可以是文件或文件夹的路径列表
input_files = [
    'using_image'
]
output_tgz = 'using_image.tgz'
# ==========================

if __name__ == '__main__':
    compress_to_tgz(input_files, output_tgz)



📦 压缩完成：using_image.tgz


In [2]:
import tarfile
import os

def extract_tgz(tgz_path, output_dir):
    if not os.path.exists(tgz_path):
        print(f"❌ 找不到文件: {tgz_path}")
        return

    with tarfile.open(tgz_path, "r:gz") as tar:
        tar.extractall(path=output_dir)
    print(f"✅ 解压完成：{tgz_path} → {output_dir}")

# ======= 用户配置区 =======
tgz_file = 'using_image.tgz'        # 要解压的 .tgz 文件路径
extract_to = 'using_image'        # 解压到哪个目录
# ===========================

if __name__ == '__main__':
    extract_tgz(tgz_file, extract_to)


✅ 解压完成：using_image.tgz → using_image
