In [6]:
import json
import os
from pathlib import Path
from typing import List, Dict, Any
from tqdm import tqdm
import shutil
from collections import defaultdict

# 配置GitHub信息
class Config:
    # 本地数据路径
    BASE_LOCAL_PATH = "/mnt/mydev2/M256374"
    
    # 数据集路径
    AOKVQA_PATH = f"{BASE_LOCAL_PATH}/A_OK_VQA"
    OKVQA_PATH = f"{BASE_LOCAL_PATH}/OK_VQA"
    FVQA_PATH = f"{BASE_LOCAL_PATH}/F_VQA"
    
    # ⚠️ GitHub配置 - 请修改为你的账户信息
    GITHUB_USERNAME = "HUGEOLab"  # 修改为你的GitHub用户名
    GITHUB_REPO = "unified-vqa-dataset"
    GITHUB_BRANCH = "main"
    BASE_URL = f"https://github.com/{GITHUB_USERNAME}/{GITHUB_REPO}/tree/{GITHUB_BRANCH}"
    
    # 输出路径
    OUTPUT_JSON = "unified_dataset.json"
    OUTPUT_STATS = "dataset_stats.json"
    
    # 本地GitHub仓库路径（如果已克隆）
    LOCAL_GITHUB_REPO = "/mnt/mydev2/M256374/unified-vqa-dataset"  # 例如: "/path/to/unified-vqa-dataset"

print("✓ 配置加载完成")
print(f"GitHub Base URL: {Config.BASE_URL}")

✓ 配置加载完成
GitHub Base URL: https://github.com/HUGEOLab/unified-vqa-dataset/tree/main


In [7]:

def load_json(path: str) -> Any:
    """加载JSON文件"""
    if not os.path.exists(path):
        raise FileNotFoundError(f"文件未找到: {path}")
    with open(path, 'r') as f:
        return json.load(f)

def save_json(data: Any, path: str) -> None:
    """保存JSON文件"""
    with open(path, 'w') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)
    print(f"✓ 已保存: {path}")

def copy_image_to_github_repo(src_path: str, relative_dest_path: str) -> bool:
    """
    复制图片到本地GitHub仓库
    
    Args:
        src_path: 源图片路径
        relative_dest_path: 相对于GitHub仓库的目标路径
    
    Returns:
        bool: 是否成功复制
    """
    if Config.LOCAL_GITHUB_REPO is None:
        return False
    
    dest_path = os.path.join(Config.LOCAL_GITHUB_REPO, relative_dest_path)
    dest_dir = os.path.dirname(dest_path)
    
    os.makedirs(dest_dir, exist_ok=True)
    
    if os.path.exists(src_path):
        if not os.path.exists(dest_path):
            shutil.copy2(src_path, dest_path)
            return True
    return False

print("✓ 工具函数定义完成")

✓ 工具函数定义完成


In [8]:
class AOKVQAProcessor:
    @staticmethod
    def process() -> List[Dict]:
        """处理A-OKVQA数据集"""
        print("\n" + "="*60)
        print("处理 A-OKVQA 数据集")
        print("="*60)
        
        unified_samples = []
        
        for split in ['train', 'val']:
            json_path = f"{Config.AOKVQA_PATH}/json_files/{split}_datasets.json"
            
            if not os.path.exists(json_path):
                print(f"⚠ 未找到: {json_path}")
                continue
            
            print(f"\n加载 {split} 集合...")
            data = load_json(json_path)
            
            if not isinstance(data, list):
                data = [data]
            
            print(f"处理 {len(data)} 个样本...")
            for sample in tqdm(data):
                # 提取原始image_path
                original_image_path = sample.get('image_path', '')
                filename = os.path.basename(original_image_path)
                image_year = "train2017" if split == "train" else "val2017"
                
                local_relative_path = f"aokvqa/{image_year}/{filename}"
                image_url = f"{Config.BASE_URL}/data/{local_relative_path}"
                
                # 如果有本地GitHub仓库，复制图片
                if Config.LOCAL_GITHUB_REPO:
                    src_image_path = os.path.join(Config.AOKVQA_PATH, image_year, filename)
                    copy_image_to_github_repo(src_image_path, local_relative_path)
                
                # 创建统一格式的样本
                unified_sample = {
                    "dataset": "aokvqa",
                    "split": sample.get('split', split),
                    "image_id": sample.get('image_id'),
                    "question_id": sample.get('question_id'),
                    "question": sample.get('question'),
                    "image_url_path": image_url,
                    "image_local_path": local_relative_path,
                    # 保留原始字段
                    "choices": sample.get('choices'),
                    "correct_choice_idx": sample.get('correct_choice_idx'),
                    "direct_answers": sample.get('direct_answers'),
                    "difficult_direct_answer": sample.get('difficult_direct_answer'),
                    "rationales": sample.get('rationales')
                }
                
                unified_samples.append(unified_sample)
        
        print(f"\n✓ A-OKVQA 处理完成: {len(unified_samples)} 个样本")
        return unified_samples

# 执行处理
aokvqa_samples = AOKVQAProcessor.process()


处理 A-OKVQA 数据集

加载 train 集合...
处理 17056 个样本...


100%|██████████| 17056/17056 [02:45<00:00, 102.99it/s]



加载 val 集合...
处理 1145 个样本...


100%|██████████| 1145/1145 [00:12<00:00, 94.76it/s] 


✓ A-OKVQA 处理完成: 18201 个样本





In [9]:
class OKVQAProcessor:
    @staticmethod
    def process() -> List[Dict]:
        """处理OK-VQA数据集"""
        print("\n" + "="*60)
        print("处理 OK-VQA 数据集")
        print("="*60)
        
        unified_samples = []
        
        # 修正路径
        base_path = "/mnt/mydev2/M256374/previous_work/for_research/selection"
        data_files = {
            'train': f"{base_path}/train_data.json",
            'val': f"{base_path}/val_data.json"
        }
        
        for split, json_path in data_files.items():
            if not os.path.exists(json_path):
                print(f"⚠ 未找到: {json_path}")
                continue
            
            print(f"\n加载 {split} 集合...")
            data = load_json(json_path)
            
            if not isinstance(data, list):
                data = [data]
            
            print(f"处理 {len(data)} 个样本...")
            for sample in tqdm(data):
                # 提取原始image_path并获取文件名
                original_image_path = sample.get('image_path', '')
                filename = os.path.basename(original_image_path)
                image_year = "train2014" if split == "train" else "val2014"
                
                local_relative_path = f"okvqa/{image_year}/{filename}"
                image_url = f"{Config.BASE_URL}/data/{local_relative_path}"
                
                # 如果有本地GitHub仓库，复制图片
                if Config.LOCAL_GITHUB_REPO:
                    src_image_path = os.path.join(Config.OKVQA_PATH, image_year, filename)
                    copy_image_to_github_repo(src_image_path, local_relative_path)
                
                # 创建统一格式的样本
                unified_sample = {
                    "dataset": "okvqa",
                    "split": split,
                    "question_id": sample.get('question_id'),
                    "question": sample.get('question'),
                    "image_url_path": image_url,
                    "image_local_path": local_relative_path,
                    # 保留原始字段
                    "answers": sample.get('answers')
                }
                
                unified_samples.append(unified_sample)
        
        print(f"\n✓ OK-VQA 处理完成: {len(unified_samples)} 个样本")
        return unified_samples

# 执行处理
okvqa_samples = OKVQAProcessor.process()


处理 OK-VQA 数据集

加载 train 集合...
处理 9009 个样本...


100%|██████████| 9009/9009 [02:52<00:00, 52.27it/s]



加载 val 集合...
处理 5046 个样本...


100%|██████████| 5046/5046 [00:58<00:00, 86.22it/s] 


✓ OK-VQA 处理完成: 14055 个样本





In [10]:

class FVQAProcessor:
    @staticmethod
    def process() -> List[Dict]:
        """处理F-VQA数据集"""
        print("\n" + "="*60)
        print("处理 F-VQA 数据集")
        print("="*60)
        
        unified_samples = []
        
        json_path = f"{Config.FVQA_PATH}/input_files/all_data.json"
        
        if not os.path.exists(json_path):
            print(f"⚠ 未找到: {json_path}")
            return unified_samples
        
        print(f"加载数据...")
        data = load_json(json_path)
        
        # F-VQA的数据结构是字典，键是question_id
        items = data.items() if isinstance(data, dict) else enumerate(data)
        items_list = list(items)
        
        print(f"处理 {len(items_list)} 个样本...")
        for question_id, sample in tqdm(items_list):
            if isinstance(sample, dict):
                # 提取image_path
                original_image_path = sample.get('image_path', '')
                filename = os.path.basename(original_image_path)
                
                local_relative_path = f"fvqa/images/{filename}"
                image_url = f"{Config.BASE_URL}/data/{local_relative_path}"
                
                # 如果有本地GitHub仓库，复制图片
                if Config.LOCAL_GITHUB_REPO:
                    src_image_path = os.path.join(Config.FVQA_PATH, 'new_dataset_release', 'images', filename)
                    copy_image_to_github_repo(src_image_path, local_relative_path)
                
                # 创建统一格式的样本
                unified_sample = {
                    "dataset": "fvqa",
                    "split": "FULL",
                    "question_id": str(question_id),
                    "question": sample.get('question'),
                    "answer": sample.get('answer'),
                    "image_url_path": image_url,
                    "image_local_path": local_relative_path,
                    # 保留其他原始字段
                    "fact_surface": sample.get('fact_surface'),
                    "ans_source": sample.get('ans_source'),
                    "visual_concept": sample.get('visual_concept'),
                    "kb_source": sample.get('kb_source'),
                    "fact": sample.get('fact'),
                    "img_file": sample.get('img_file')
                }
                
                unified_samples.append(unified_sample)
        
        print(f"\n✓ F-VQA 处理完成: {len(unified_samples)} 个样本")
        return unified_samples

# 执行处理
fvqa_samples = FVQAProcessor.process()


处理 F-VQA 数据集
加载数据...
处理 5826 个样本...


100%|██████████| 5826/5826 [00:13<00:00, 423.91it/s]


✓ F-VQA 处理完成: 5826 个样本





In [11]:

# 合并所有样本
all_samples = aokvqa_samples + okvqa_samples + fvqa_samples

print("\n" + "="*60)
print("数据集合并完成")
print("="*60)
print(f"A-OKVQA: {len(aokvqa_samples)} 个样本")
print(f"OK-VQA: {len(okvqa_samples)} 个样本")
print(f"F-VQA: {len(fvqa_samples)} 个样本")
print(f"总计: {len(all_samples)} 个样本")

# ============================================================================
# Cell 7: 数据验证
# ============================================================================

def validate_dataset(samples: List[Dict]) -> None:
    """验证数据集完整性"""
    print("\n" + "="*60)
    print("数据验证")
    print("="*60)
    
    required_fields = ['dataset', 'split', 'question', 'image_url_path']
    
    missing_count = 0
    for i, sample in enumerate(samples):
        for field in required_fields:
            if field not in sample or sample[field] is None:
                print(f"✗ 样本 {i}: 缺少字段 '{field}'")
                missing_count += 1
    
    if missing_count == 0:
        print("✓ 所有必要字段完整")
    else:
        print(f"⚠ 发现 {missing_count} 个缺失字段")

validate_dataset(all_samples)


数据集合并完成
A-OKVQA: 18201 个样本
OK-VQA: 14055 个样本
F-VQA: 5826 个样本
总计: 38082 个样本

数据验证
✓ 所有必要字段完整


In [12]:
# ============================================================================
# Cell 8: 显示样本示例
# ============================================================================

def show_samples(samples: List[Dict], dataset: str = None, num: int = 2) -> None:
    """显示样本"""
    print("\n" + "="*60)
    if dataset:
        print(f"显示 {dataset.upper()} 数据集的样本")
        display_samples = [s for s in samples if s['dataset'] == dataset]
    else:
        print("显示随机样本")
        display_samples = samples
    print("="*60)
    
    for i, sample in enumerate(display_samples[:num]):
        print(f"\n样本 {i+1}:")
        print(f"  数据集: {sample['dataset'].upper()}")
        print(f"  分割: {sample.get('split', 'N/A')}")
        print(f"  问题: {sample['question'][:80]}...")
        if 'choices' in sample and sample['choices']:
            print(f"  选项: {sample['choices']}")
        if 'answers' in sample and sample['answers']:
            print(f"  答案: {sample['answers'][:3]}...")
        if 'answer' in sample:
            print(f"  答案: {sample['answer']}")
        print(f"  图片URL: {sample['image_url_path']}")

# 显示各数据集的样本
for dataset in ['aokvqa', 'okvqa', 'fvqa']:
    show_samples(all_samples, dataset=dataset, num=1)


显示 AOKVQA 数据集的样本

样本 1:
  数据集: AOKVQA
  分割: train
  问题: What is the man by the bags awaiting?...
  选项: ['skateboarder', 'train', 'delivery', 'cab']
  图片URL: https://github.com/HUGEOLab/unified-vqa-dataset/tree/main/data/aokvqa/train2017/000000299207.jpg

显示 OKVQA 数据集的样本

样本 1:
  数据集: OKVQA
  分割: train
  问题: What is the hairstyle of the blond called?...
  答案: ['pony tail', 'pony tail', 'pony tail']...
  图片URL: https://github.com/HUGEOLab/unified-vqa-dataset/tree/main/data/okvqa/train2014/COCO_train2014_000000051606.jpg

显示 FVQA 数据集的样本

样本 1:
  数据集: FVQA
  分割: FULL
  问题: Which object can be found in a jazz club...
  答案: trumpet
  图片URL: https://github.com/HUGEOLab/unified-vqa-dataset/tree/main/data/fvqa/images/ILSVRC2012_test_00050748.JPEG


In [13]:
# ============================================================================
# Cell 9: 生成统计信息
# ============================================================================

def generate_statistics(samples: List[Dict]) -> Dict:
    """生成数据统计"""
    stats = {
        "total_samples": len(samples),
        "by_dataset": {},
        "by_split": {},
        "sample_size": {}
    }
    
    # 按数据集统计
    for dataset in ['aokvqa', 'okvqa', 'fvqa']:
        count = len([s for s in samples if s['dataset'] == dataset])
        stats['by_dataset'][dataset] = count
    
    # 按分割统计
    for split in ['train', 'val', 'test']:
        count = len([s for s in samples if s.get('split') == split])
        if count > 0:
            stats['by_split'][split] = count
    
    # 计算JSON大小
    json_str = json.dumps(samples)
    stats['sample_size']['json_bytes'] = len(json_str.encode('utf-8'))
    stats['sample_size']['json_mb'] = len(json_str.encode('utf-8')) / (1024 * 1024)
    
    return stats

# 生成统计
stats = generate_statistics(all_samples)

print("\n" + "="*60)
print("数据统计信息")
print("="*60)
print(f"\n总样本数: {stats['total_samples']}")
print(f"\n按数据集统计:")
for dataset, count in stats['by_dataset'].items():
    print(f"  {dataset.upper()}: {count} 个样本")

print(f"\n按分割统计:")
for split, count in stats['by_split'].items():
    print(f"  {split}: {count} 个样本")

print(f"\n数据量:")
print(f"  JSON大小: {stats['sample_size']['json_mb']:.2f} MB")

# ============================================================================
# Cell 10: 保存统一数据集和统计信息
# ============================================================================

# 保存统一数据集
save_json(all_samples, Config.OUTPUT_JSON)

# 保存统计信息
save_json(stats, Config.OUTPUT_STATS)

print("\n✓ 数据已保存完成")

# ============================================================================
# Cell 11: 加载和验证保存的数据
# ============================================================================

# 加载验证
with open(Config.OUTPUT_JSON, 'r') as f:
    loaded_data = json.load(f)

print(f"\n✓ 验证: 成功加载 {len(loaded_data)} 个样本")
print(f"✓ 第一个样本的数据集: {loaded_data[0]['dataset']}")
print(f"✓ 最后一个样本的数据集: {loaded_data[-1]['dataset']}")


数据统计信息

总样本数: 38082

按数据集统计:
  AOKVQA: 18201 个样本
  OKVQA: 14055 个样本
  FVQA: 5826 个样本

按分割统计:
  train: 26065 个样本
  val: 6191 个样本

数据量:
  JSON大小: 22.85 MB
✓ 已保存: unified_dataset.json
✓ 已保存: dataset_stats.json

✓ 数据已保存完成

✓ 验证: 成功加载 38082 个样本
✓ 第一个样本的数据集: aokvqa
✓ 最后一个样本的数据集: fvqa


In [14]:
# ============================================================================
# Cell 12: GitHub上传指南
# ============================================================================

print("\n" + "="*60)
print("GitHub上传指南")
print("="*60)

if Config.LOCAL_GITHUB_REPO:
    print(f"""
✓ 本地图片已复制到: {Config.LOCAL_GITHUB_REPO}

请执行以下命令上传到GitHub:

1. 进入仓库目录:
   cd {Config.LOCAL_GITHUB_REPO}

2. 查看状态:
   git status

3. 添加所有文件:
   git add .

4. 提交更改:
   git commit -m "Add unified VQA dataset with {len(all_samples)} samples"

5. 上传到GitHub:
   git push origin main

6. 验证上传:
   访问: https://github.com/{Config.GITHUB_USERNAME}/{Config.GITHUB_REPO}
""")
else:
    print(f"""
⚠ 未设置本地GitHub仓库路径

如果要自动复制图片到GitHub仓库，请:
1. 创建GitHub仓库: {Config.GITHUB_REPO}
2. 克隆到本地: git clone https://github.com/{Config.GITHUB_USERNAME}/{Config.GITHUB_REPO}.git
3. 修改Config.LOCAL_GITHUB_REPO = "/path/to/{Config.GITHUB_REPO}"
4. 重新运行脚本

或者手动执行:
1. 上传 unified_dataset.json 到仓库
2. 上传 dataset_stats.json 到仓库
3. 创建对应的目录结构并上传图片
""")




GitHub上传指南

✓ 本地图片已复制到: /mnt/mydev2/M256374/unified-vqa-dataset

请执行以下命令上传到GitHub:

1. 进入仓库目录:
   cd /mnt/mydev2/M256374/unified-vqa-dataset

2. 查看状态:
   git status

3. 添加所有文件:
   git add .

4. 提交更改:
   git commit -m "Add unified VQA dataset with 38082 samples"

5. 上传到GitHub:
   git push origin main

6. 验证上传:
   访问: https://github.com/HUGEOLab/unified-vqa-dataset



In [6]:
import json

INPUT = "/mnt/mydev2/M256374/unified-vqa-dataset/unified_dataset.json"
OUTPUT = "/mnt/mydev2/M256374/unified-vqa-dataset/unified_dataset.json"

HF_BASE = "https://huggingface.co/datasets/Geojx/unified-vqa-images/resolve/main"

def to_rel_path(s: str) -> str:
    """Turn various forms into a clean relative path like aokvqa/train2017/xxx.jpg"""
    if not isinstance(s, str):
        return ""
    s = s.strip()

    # Case 1: already relative path
    if "://" not in s:
        return s.lstrip("/")

    # Case 2: github URL like .../data/aokvqa/train2017/xxx.jpg
    # Extract the part after "/data/" if present.
    marker = "/data/"
    if marker in s:
        return s.split(marker, 1)[1].lstrip("/")

    # Otherwise: unknown URL, return empty to avoid making garbage HF urls
    return ""

with open(INPUT, "r") as f:
    data = json.load(f)

fixed = 0
skipped = 0

for item in data:
    # 1) best source: image_local_path
    rel = to_rel_path(item.get("image_local_path", ""))

    # 2) fallback: some datasets might store local path in other keys
    if not rel:
        rel = to_rel_path(item.get("image_path", ""))

    # 3) last resort: try image_url_path (github url)
    if not rel:
        rel = to_rel_path(item.get("image_url_path", ""))

    if not rel:
        skipped += 1
        continue

    # optional: fix typo extension
    if rel.endswith(".jpp"):
        rel = rel[:-4] + ".jpg"

    item["image_local_path"] = rel              # normalize it
    item["image_url"] = f"{HF_BASE}/{rel}"      # hf url

    fixed += 1

with open(OUTPUT, "w") as f:
    json.dump(data, f, ensure_ascii=False, indent=2)

print(f"✓ fixed {fixed} samples, skipped {skipped}")



✓ fixed 38082 samples, skipped 0
