In [None]:
import json
from pathlib import Path
from tqdm import tqdm

def _count_sample_images(data, image_root):
    """Count the images one sample references and how many are missing on disk.

    Args:
        data: Parsed JSON object (dict) for a single sample. The ``id`` field
            names the sub-directory of ``image_root``; ``visual_demo`` and
            ``stage_to_estimate`` list image file names inside it.
        image_root: ``Path`` of the image root directory.

    Returns:
        A ``(checked, missing)`` tuple of image counts for this sample.

    Raises:
        TypeError: If the id, an image list, or an image name has a type that
            cannot be iterated or joined into a filesystem path.
    """
    sample_id = data.get('id', '')
    checked = 0
    missing = 0
    for key in ('visual_demo', 'stage_to_estimate'):
        # `or []` treats an explicit JSON null the same as an absent key.
        for img_name in data.get(key) or []:
            checked += 1
            if not (image_root / sample_id / img_name).exists():
                missing += 1
    return checked, missing


def filter_complete_samples(input_jsonl, output_jsonl, image_root) -> None:
    """Copy only the samples whose referenced images all exist to a new JSONL file.

    Each input line is a JSON object. Images are looked up at
    ``image_root / <id> / <image name>`` for every name listed under
    ``visual_demo`` and ``stage_to_estimate``. A sample is written to the
    output (its original line, unchanged) only when none of its images is
    missing. Blank lines are ignored. Lines that are not valid JSON objects,
    or whose fields cannot be turned into paths, are skipped, counted and
    reported instead of being silently discarded. A summary is printed at
    the end.

    Args:
        input_jsonl: Path of the input JSONL file.
        output_jsonl: Path of the output JSONL file (overwritten).
        image_root: Root directory that contains one sub-directory per sample id.

    Raises:
        ValueError: If ``input_jsonl`` and ``output_jsonl`` are the same file;
            opening the output for writing would truncate the input.
        OSError: If the input cannot be read or the output cannot be written.
    """
    image_root = Path(image_root)

    # Opening the output in 'w' mode truncates it immediately, so filtering a
    # file onto itself would destroy the data before it is read.
    if Path(input_jsonl).resolve() == Path(output_jsonl).resolve():
        raise ValueError("input_jsonl and output_jsonl must be different files")

    # Statistics
    total_samples = 0
    valid_samples = 0
    removed_samples = 0
    total_images = 0
    missing_images = 0
    skipped_lines = []  # 1-based line numbers of records that could not be processed

    # First pass: count lines so the progress bar has a total.
    with open(input_jsonl, 'r', encoding='utf-8') as f:
        total_lines = sum(1 for _ in f)

    # Second pass: filter.
    with open(input_jsonl, 'r', encoding='utf-8') as f_in, \
         open(output_jsonl, 'w', encoding='utf-8') as f_out:

        progress = tqdm(f_in, total=total_lines, desc="处理样本")
        for line_no, line in enumerate(progress, start=1):
            stripped = line.strip()
            if not stripped:
                continue  # blank line, nothing to parse

            try:
                data = json.loads(stripped)
            except json.JSONDecodeError:
                skipped_lines.append(line_no)
                continue

            if not isinstance(data, dict):
                # Valid JSON, but not an object: there are no fields to inspect.
                skipped_lines.append(line_no)
                continue

            try:
                checked, missing = _count_sample_images(data, image_root)
            except TypeError:
                # e.g. a non-string id or image name; keep going (best effort)
                # but make the dropped record visible in the summary.
                skipped_lines.append(line_no)
                continue

            total_samples += 1
            total_images += checked
            missing_images += missing

            if missing == 0:
                # Write the original line so kept records stay byte-identical.
                f_out.write(line)
                valid_samples += 1
            else:
                removed_samples += 1

    # Build the ratios first so the label is printed even when the count is zero.
    if total_samples > 0:
        keep_ratio = f"{valid_samples/total_samples*100:.2f}%"
    else:
        keep_ratio = "N/A"
    if total_images > 0:
        image_ratio = f"{(total_images-missing_images)/total_images*100:.2f}%"
    else:
        image_ratio = "N/A"

    # Print the summary
    print(f"\n{'='*80}")
    print("处理完成!")
    print(f"{'='*80}")
    print(f"输入样本数: {total_samples}")
    print(f"输出样本数: {valid_samples}")
    print(f"移除样本数: {removed_samples}")
    print(f"保留比例: {keep_ratio}")
    print(f"\n总图片数: {total_images}")
    print(f"缺失图片数: {missing_images}")
    print(f"图片完整性: {image_ratio}")
    print(f"\n输出文件: {output_jsonl}")

    if skipped_lines:
        preview = ', '.join(str(n) for n in skipped_lines[:10])
        if len(skipped_lines) > 10:
            preview += ', ...'
        print(f"\n⚠️  跳过了 {len(skipped_lines)} 行无法解析的数据 (行号: {preview})")

    if removed_samples == 0:
        print("\n✅ 所有样本图片完整!")
    else:
        print(f"\n⚠️  移除了 {removed_samples} 个有缺失图片的样本")


# if __name__ == '__main__':
#     # 设置路径
#     INPUT_JSONL = '/home/vcj9002/jianshu/workspace/code/ProgressLM/data/raw/visual_demo/visual_h5_franka_3rgb_raw.jsonl'      # 输入JSONL文件
#     OUTPUT_JSONL = '/home/vcj9002/jianshu/workspace/code/ProgressLM/data/raw/visual_demo/visual_h5_franka_3rgb_train.jsonl'    # 输出JSONL文件
#     IMAGE_ROOT = '/home/vcj9002/jianshu/workspace/data/robomind/data/images'  # 图片根目录
    
#     # 运行过滤
#     filter_complete_samples(INPUT_JSONL, OUTPUT_JSONL, IMAGE_ROOT)

if __name__ == '__main__':
    # Paths for this run: the JSONL file to read, the filtered JSONL file to
    # write, and the root directory the image files are looked up under.
    INPUT_JSONL = '/home/vcj9002/jianshu/workspace/code/ProgressLM/data/eval/visual/visual_h5_franka_3rgb_t.jsonl'
    OUTPUT_JSONL = '/home/vcj9002/jianshu/workspace/code/ProgressLM/data/eval/visual/visual_h5_franka_3rgb_test.jsonl'
    IMAGE_ROOT = '/home/vcj9002/jianshu/workspace/data/robomind/data/images'

    # Run the filter.
    filter_complete_samples(
        input_jsonl=INPUT_JSONL,
        output_jsonl=OUTPUT_JSONL,
        image_root=IMAGE_ROOT,
    )

处理样本: 100%|██████████| 36195/36195 [00:06<00:00, 5764.22it/s]


处理完成!
输入样本数: 36195
输出样本数: 33186
移除样本数: 3009
保留比例: 91.69%

总图片数: 252099
缺失图片数: 21879
图片完整性: 91.32%

输出文件: /home/vcj9002/jianshu/workspace/code/ProgressLM/data/raw/visual_demo/visual_h5_franka_3rgb_train.jsonl

⚠️  移除了 3009 个有缺失图片的样本



