In [1]:
import json
import os
############### Load all chunk data
def get_chunk_dict(directory):
    """Load every chunk record from the ``chunks*.jsonl`` files in *directory*.

    Only files directly inside *directory* whose name starts with ``chunks``
    and ends with ``.jsonl`` are read; each non-blank line must be one JSON
    object containing a ``chunk_id`` key.

    Args:
        directory: Folder to scan (non-recursively).

    Returns:
        dict mapping each record's ``chunk_id`` to the full parsed record.
        Files are visited in sorted filename order, so when the same
        ``chunk_id`` occurs more than once the winner (the record read last)
        is the same on every run.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no ``chunk_id``.
    """
    chunk_dict = {}
    # os.listdir order is arbitrary; sort so duplicate ids resolve deterministically.
    for filename in sorted(os.listdir(directory)):
        if not (filename.startswith('chunks') and filename.endswith('.jsonl')):
            continue
        file_path = os.path.join(directory, filename)
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing empty line) instead of crashing.
                    continue
                item = json.loads(line)
                chunk_dict[item['chunk_id']] = item
    return chunk_dict
# Global chunk_id -> chunk record lookup, read by chunk_info_all below.
chunk_dict = get_chunk_dict('/root/autodl-fs/OURMRAG/data/docs/new_chunks/')

In [2]:
def read_all_json_files(directory):
    """Parse every ``*.json`` file directly inside *directory*.

    Files are visited in sorted filename order. That makes the returned list,
    and any mapping later built from it where later entries overwrite earlier
    ones, reproducible across runs and filesystems.

    Args:
        directory: Folder to scan (non-recursively).

    Returns:
        list of parsed JSON documents, one per matching file.

    Raises:
        json.JSONDecodeError: if a ``.json`` file does not contain valid JSON.
    """
    data_list = []
    # os.listdir order is arbitrary; sort for deterministic output.
    for filename in sorted(os.listdir(directory)):
        if filename.endswith('.json'):
            file_path = os.path.join(directory, filename)
            with open(file_path, 'r', encoding='utf-8') as f:
                data_list.append(json.load(f))
    return data_list
# Load every image-metadata JSON file and build an image-id -> 'image_path' lookup.
# Ids are normalised with str(); if the same id appears in several files, the
# entry from the file read later overwrites the earlier one.
directory_path = '/root/autodl-fs/IMAGE/images_info'
all_image_info = read_all_json_files(directory_path)
image2format = {
    str(image_id): info[image_id]['image_path']
    for info in all_image_info
    for image_id in info
}

In [3]:
def flatten_and_deduplicate(image_lists):
    """Flatten a list of rows into one list, keeping only the first occurrence of each item.

    Items are taken row by row, left to right; an item already seen earlier is
    skipped. Items must be hashable.

    Args:
        image_lists: Iterable of iterables (e.g. per-chunk image id lists).

    Returns:
        list of unique items in first-seen order.
    """
    # dict keeps insertion order and ignores repeated keys, giving an
    # order-stable de-duplication in one pass.
    first_seen = dict.fromkeys(img for row in image_lists for img in row)
    return list(first_seen)

In [4]:
def merge_candidate_docs(text_candidate_docs, image_candidate_docs):
    """Merge two ranked candidate-doc lists into one duplicate-free ranking.

    Order of the result:
      1. Docs present in both lists, in ``image_candidate_docs`` order.
      2. The remaining docs, alternating one from ``image_candidate_docs`` then
         one from ``text_candidate_docs``. Once the shorter side runs out, the
         rest of the longer side follows in its original order.

    Repeated entries within either input are collapsed to their first
    occurrence, so every doc appears at most once in the result.

    Args:
        text_candidate_docs: Ranked doc ids (hashable) from the text side.
        image_candidate_docs: Ranked doc ids (hashable) from the image side.

    Returns:
        list: merged doc ids.
    """
    # Order-preserving de-duplication of each input.
    text_docs = list(dict.fromkeys(text_candidate_docs))
    image_docs = list(dict.fromkeys(image_candidate_docs))

    # 1. Docs retrieved by both sides, ranked by the image side.
    text_set = set(text_docs)
    common = [doc for doc in image_docs if doc in text_set]

    # 2. Docs unique to each side (set lookup keeps this linear, not O(n*m)).
    common_set = set(common)
    text_unique = [doc for doc in text_docs if doc not in common_set]
    image_unique = [doc for doc in image_docs if doc not in common_set]

    # 3. Alternate image/text picks; the longer list's tail is appended as-is.
    interleaved = []
    for i in range(max(len(text_unique), len(image_unique))):
        if i < len(image_unique):
            interleaved.append(image_unique[i])
        if i < len(text_unique):
            interleaved.append(text_unique[i])

    # 4. Shared docs first, then the interleaved remainder.
    return common + interleaved

In [5]:
def chunk_info_all(chunks, chunk_lookup=None):
    """Collect the text and image ids of each chunk id in *chunks*.

    Args:
        chunks: Iterable of chunk ids, in the order results should be returned.
        chunk_lookup: Mapping ``chunk_id -> record`` where each record has
            ``'text'`` and ``'images'`` keys. Defaults to the notebook-level
            ``chunk_dict`` (resolved at call time), preserving the original
            behaviour while allowing an explicit lookup for reuse and testing.

    Returns:
        Tuple ``(chunk_text, chunk_images)`` of parallel lists. Each
        ``chunk_images[i]`` is the record's ``'images'`` object itself (not a copy).

    Raises:
        KeyError: if a chunk id is missing from the lookup or a record lacks a key.
    """
    if chunk_lookup is None:
        chunk_lookup = chunk_dict
    chunk_text = []
    chunk_images = []
    for chunk in chunks:
        record = chunk_lookup[chunk]  # single lookup per chunk
        chunk_text.append(record['text'])
        chunk_images.append(record['images'])
    return chunk_text, chunk_images

In [6]:
def map_images_to_indices(find_images, all_images_list):
    """Translate groups of image ids into their positions within *all_images_list*.

    Args:
        find_images: Iterable of groups (iterables) of image ids.
        all_images_list: Reference sequence of image ids; an id's position here
            is its index. If an id occurs more than once, its last position is used.

    Returns:
        list[list[int]]: one index list per input group, in the same order.

    Raises:
        KeyError: if an id in *find_images* is absent from *all_images_list*.
    """
    position_of = dict((name, pos) for pos, name in enumerate(all_images_list))
    return [[position_of[name] for name in group] for group in find_images]

In [7]:
import os
import json
# Input: base-retrieval results, one JSON object per line.
# Output: the same records enriched with retrieved context and chat `messages`.
input_file='/root/autodl-fs/OURMRAG/data/base_retrieval/mramg_train_base.jsonl'
outputfile='/root/autodl-fs/OURMRAG/data/sft/mramg_withrag_train.jsonl'
new_data=[]
all_data=[]
with open(input_file, 'r', encoding='utf-8') as f:
    for line in f:
        if not line.strip():
            # Skip blank lines (e.g. a trailing empty line) instead of failing in json.loads.
            continue
        data = json.loads(line)
        all_data.append(data)
# Build one SFT record per input example: merge the two retrieval result lists,
# gather each retrieved chunk's text and images, and render a user/assistant
# chat pair. Each `data` dict is mutated in place and also appended to new_data.
for data in all_data:
    # NOTE(review): merge_candidate_docs's signature is
    # (text_candidate_docs, image_candidate_docs), but image-retrieved chunks are
    # passed as the *text* argument and query-retrieved chunks as the *image*
    # argument. Net effect: chunks found by both are ordered by query_find_chunks,
    # and the interleaved remainder starts with a query-only chunk. Confirm this
    # ordering is intended before regenerating the dataset.
    all_chunk=merge_candidate_docs(data['image_find_chunks'],data['query_find_chunks'])
    # Key names ('rereieval_...', 'retriver_...') are part of the output schema; keep as-is.
    data['rereieval_all_chunk']=all_chunk
    # Parallel lists: the text of each merged chunk and that chunk's image ids.
    data['retriver_text'],data['retriver_images']=chunk_info_all(data['rereieval_all_chunk'])
    # Unique image ids across all retrieved chunks, in first-seen order; this order
    # defines the 1-based "Image [k]" numbering used in the prompt below.
    data["all_images_list"]=flatten_and_deduplicate(data['retriver_images'])
    # Absolute image path = images root + upper-cased dataset name + image2format entry.
    img_abs_path_prefix = "/root/autodl-fs/IMAGE/images/" + data['dataset'].upper() + '/'
    all_abs_list=[img_abs_path_prefix+image2format[str(img)] for img in data['all_images_list']]
    data["all_images_path"]=all_abs_list
    # For each chunk, the positions of its images within all_images_list.
    image_index=map_images_to_indices(data['retriver_images'],data['all_images_list'])
    report_info = ""
    for i, (report, image_indices) in enumerate(zip(data['retriver_text'], image_index)):
        # Render each image as an "Image [N]" label (indices shifted to start at 1).
        image_str = ", ".join([f"Image [{idx + 1}]" for idx in image_indices])
        report_info += f"Reference Document [{i + 1}]: {report}\n  The illustration of the Reference Document [{i+1}] : {image_str}\n\n"
    # NOTE(review): this prompt is written verbatim into the training data, including
    # "responses.You" (missing space) and trailing whitespace. Changing it changes the
    # SFT prompt format; presumably any inference-time prompt must match — confirm first.
    instruction=f"""# Task
Imagine you are an expert in handling multimodal input queries and producing coherent responses.You will receive:    
1. Query: The user query to be answered.
2. Contexts.    
3. A set of images.
# Input   
Query: {data['question']} 
Contexts: {report_info}    
"""
    # Chat-format sample: the user turn carries the prompt, the assistant turn the reference answer.
    data['messages']=[
        {
            "role": "user",
            "content": instruction
        },{
            "role":"assistant",
            "content":data["answer"]
        }
    ]
    new_data.append(data)
# Write one JSON object per line; ensure_ascii=False keeps non-ASCII text human-readable.
# Create the output directory first so a fresh environment does not fail with FileNotFoundError.
os.makedirs(os.path.dirname(outputfile) or '.', exist_ok=True)
with open(outputfile, 'w', encoding='utf-8') as f:
    for item in new_data:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')
