In [1]:
import os
import sys
import json
import yaml
import pickle
from tqdm import tqdm
sys.path.append("../../../")
from mllm.utils import (mapping_dict_keys,
                        json2tokenV2,
                        token2jsonV2,
                        load_jsonl_file,
                        random_select_list,
                        save_jsonl_file)

In [2]:
# Label files (JSONL) to convert; one set of train/val/test outputs is
# written per entry.
label_files = ["/mnt/n/data/mllm-data/mllm-finetune-data/trainticket/metadata.jsonl"]


# Maps the English field names used in the extraction results to the
# Chinese field names emitted in the converted responses.
field_mapping = dict(
    starting_station="起始站",
    destination_station="终点站",
    seat_category="座位等级",
    ticket_rates="票据价格",
    ticket_num="票据号码",
    train_num="车次",
    date="出发日期",
    name="姓名",
)

In [3]:
def convert_swift_data(ex_data, image_path, query):
    """Build one query/response/image record from an extraction result.

    Args:
        ex_data: Mapping of field name -> value. Every key must be present
            in the module-level ``field_mapping``.
        image_path: Path of the image the record refers to.
        query: Prompt text stored under ``"query"``.

    Returns:
        dict with ``"query"``, ``"response"`` (the renamed fields passed
        through ``json2tokenV2``) and ``"image_path"`` (a one-element list).

    Raises:
        KeyError: If ``ex_data`` contains a key missing from ``field_mapping``.
    """
    # Rename every field to its mapped name, keeping the original order.
    renamed_fields = {field_mapping[key]: value for key, value in ex_data.items()}
    return {
        "query": query,
        "response": json2tokenV2(renamed_fields),
        "image_path": [image_path],
    }

# Root directory that non-training image paths in the label file are
# resolved against.
DATA_ROOT = '/mnt/n/data/mllm-data/mllm-finetune-data/trainticket'

def convert(convert_task_list:list, data_root="/mnt/", val_size=300):
    """Convert label files into train/val/test JSONL files.

    For every label file, each row is routed by its ``"数据用途"`` value:
    ``"训练"`` goes to train, ``"验证"`` goes to val, anything else goes to
    test. Rows whose image file does not exist on disk are reported and
    skipped. The three splits are written to a ``swift_label`` folder next
    to the label file as ``<name>_train.jsonl``, ``<name>_val.jsonl`` and
    ``<name>_test.jsonl``.

    Args:
        convert_task_list: Paths of the JSONL label files to convert.
        data_root: Unused; kept so existing callers keep working. Image
            paths are resolved against the module-level ``DATA_ROOT``.
        val_size: Target size of the validation split. A smaller split is
            topped up with rows selected from the test split; a larger one
            is truncated to its first ``val_size`` rows.

    NOTE(review): rows used to top up the validation split are not removed
    from ``test_data`` here, so val and test appear to overlap -- confirm
    this is intended.
    """
    # Loop-invariant values, assigned once instead of once per row.
    train_image_dir = '/mnt/n/data/mllm-data/mllm-finetune-data/trainticket/train/hcp_aug_2/'
    prompt = "请抽取火车票中的起始站、终点站、座位等级、票据价格、票据号码、出发日期、车次及姓名等字段"

    for label_file_path in tqdm(convert_task_list):
        save_label_folder = os.path.join(os.path.dirname(label_file_path), "swift_label")
        label_file_name = os.path.splitext(os.path.basename(label_file_path))[0]
        # exist_ok avoids the separate existence check before creating.
        os.makedirs(save_label_folder, exist_ok=True)

        ori_label_data = load_jsonl_file(label_file_path)
        train_data, val_data, test_data = [], [], []
        for row_data in ori_label_data:
            usage = row_data["数据用途"]
            if usage == "训练":
                # Training images are looked up by bare file name in the
                # training image directory.
                file_name = os.path.basename(row_data['图片路径'])
                image_path = os.path.join(train_image_dir, file_name)
            else:
                image_path = os.path.join(DATA_ROOT, row_data['图片路径'])
            if not os.path.exists(image_path):
                print(f"image_path:{image_path} not exist, continue!")
                continue

            ex_data = json.loads(row_data["抽取结果"])
            converted_data = convert_swift_data(ex_data, image_path, prompt)
            if usage == "训练":
                train_data.append(converted_data)
            elif usage == "验证":
                val_data.append(converted_data)
            else:
                test_data.append(converted_data)

        if len(val_data) <= val_size:
            # Top up from the test split, bounded by what is available.
            missing = min(val_size - len(val_data), len(test_data))
            val_data.extend(random_select_list(test_data, missing))
        else:
            val_data = val_data[:val_size]

        for split_name, split_rows in (("train", train_data),
                                       ("val", val_data),
                                       ("test", test_data)):
            save_path = f"{save_label_folder}/{label_file_name}_{split_name}.jsonl"
            print(save_path)
            save_jsonl_file(split_rows, save_path)

convert(label_files)

100%|██████████| 1/1 [00:08<00:00,  8.16s/it]

/mnt/n/data/mllm-data/mllm-finetune-data/trainticket/swift_label/metadata_train.jsonl
/mnt/n/data/mllm-data/mllm-finetune-data/trainticket/swift_label/metadata_val.jsonl
/mnt/n/data/mllm-data/mllm-finetune-data/trainticket/swift_label/metadata_test.jsonl





In [4]:
synth_label_1 = "/mnt/n/data/mllm-data/mllm-finetune-data/trainticket/real_1920.pkl"
synth_label_2 = "/mnt/n/data/mllm-data/mllm-finetune-data/trainticket/synth_300k.pkl"


def _load_pickle(path):
    """Return the object deserialized from the pickle file at ``path``."""
    # SECURITY: pickle.load can execute arbitrary code while deserializing;
    # only use this on files from a trusted source.
    with open(path, 'rb') as fi:
        return pickle.load(fi)


synth_data_1 = _load_pickle(synth_label_1)
synth_data_2 = _load_pickle(synth_label_2)
# Preview the first (key, value) entry without building the full key list.
next(iter(synth_data_2.items()))

('output_10/images_0/hcp_2061',
 {'starting_station': '北京南站',
  'destination_station': '济南西站',
  'seat_category': '新空调硬座',
  'ticket_rates': '¥353.5元',
  'ticket_num': 'H6978778',
  'date': '2018年01月17日',
  'train_num': 'G13',
  'name': '林玉霜'})

In [5]:
# Keys that appear in both synth_data_1 and synth_data_2.
keys_1 = set(synth_data_1.keys())
keys_2 = set(synth_data_2.keys())
a = list(keys_1 & keys_2)

In [6]:
# Synthetic data is processed separately.
DATA_ROOT_SYNTH = "/mnt/n/data/mllm-data/mllm-finetune-data/trainticket/train"
save_synth_swift_label = "/mnt/n/data/mllm-data/mllm-finetune-data/trainticket/swift_label/metadata_train_synth_300k.txt"
def convert_synth(synth_data:dict, data_root="/mnt/"):
    """Convert a dict of synthetic labels into records and save them.

    Args:
        synth_data: Mapping of image name (without the ``.jpg`` extension,
            relative to ``DATA_ROOT_SYNTH``) -> extraction fields.
        data_root: Unused; image paths are built from the module-level
            ``DATA_ROOT_SYNTH``.

    Entries whose image file does not exist are reported and skipped. The
    converted records are written to ``save_synth_swift_label`` via
    ``save_jsonl_file``.
    """
    prompt = "请抽取火车票中的起始站、终点站、座位等级、票据价格、票据号码、出发日期、车次及姓名等字段"
    records = []
    for image_stem, fields in synth_data.items():
        image_path = f"{DATA_ROOT_SYNTH}/{image_stem}.jpg"
        if os.path.exists(image_path):
            records.append(convert_swift_data(fields, image_path, prompt))
        else:
            print(f"image_path:{image_path} not exist, continue!")
    save_jsonl_file(records, save_synth_swift_label)
convert_synth(synth_data_2)