# In [12]:
import json
from contextlib import ExitStack

from datasets import load_dataset

# Load the "jiawei-ucas/ConsistentChat" dataset through the `datasets` library
# and keep only its "train" split; every example is processed below.
ds = load_dataset("jiawei-ucas/ConsistentChat")["train"]

# -------------------------
# Pair conversations
# -------------------------
def to_pairs(conv):
    """Group a flat list of turns into ``[human, gpt]`` pairs.

    Walks the conversation left to right. Whenever a turn whose ``"from"``
    is ``"human"`` is immediately followed by one whose ``"from"`` is
    ``"gpt"``, both are collected as a pair and the scan resumes after them.
    Any turn that does not start such a pair is skipped.

    Args:
        conv: Sequence of turn dicts, each carrying a ``"from"`` key.

    Returns:
        A list of two-element lists ``[human_turn, gpt_turn]``.
    """
    matched = []
    idx = 0
    last = len(conv) - 1
    while idx < last:
        asker, answerer = conv[idx], conv[idx + 1]
        if asker["from"] != "human" or answerer["from"] != "gpt":
            # Not the start of a human->gpt exchange; slide forward by one.
            idx += 1
            continue
        matched.append([asker, answerer])
        idx += 2
    return matched

# -------------------------
# Merge helpers
# -------------------------
def merge_pairs(pairs):
    """Render every pair as a ``human:`` line followed by a ``gpt:`` line.

    Args:
        pairs: Iterable of ``(human_turn, gpt_turn)`` items; each turn is a
            dict with a ``"value"`` key.

    Returns:
        One newline-joined string covering all pairs (empty if no pairs).
    """
    rendered = []
    for human_turn, gpt_turn in pairs:
        rendered.append(f"human: {human_turn['value']}\ngpt: {gpt_turn['value']}")
    return "\n".join(rendered)

def merge_human(pairs):
    """Render only the human side of each pair, one ``human:`` line per pair.

    Args:
        pairs: Iterable of ``(human_turn, gpt_turn)`` items; the human turn
            is a dict with a ``"value"`` key.

    Returns:
        One newline-joined string (empty if no pairs).
    """
    rendered = []
    for human_turn, _ in pairs:
        rendered.append(f"human: {human_turn['value']}")
    return "\n".join(rendered)

def merge_gpt(pairs):
    """Render only the gpt side of each pair, one ``gpt:`` line per pair.

    Args:
        pairs: Iterable of ``(human_turn, gpt_turn)`` items; the gpt turn
            is a dict with a ``"value"`` key.

    Returns:
        One newline-joined string (empty if no pairs).
    """
    rendered = []
    for _, gpt_turn in pairs:
        rendered.append(f"gpt: {gpt_turn['value']}")
    return "\n".join(rendered)

# -------------------------
# First + Last
# -------------------------
def first_last(pairs):
    """Keep only the first and the last pair of a conversation.

    Args:
        pairs: List of pairs.

    Returns:
        ``pairs`` itself (the same object) when it holds zero or one pair,
        otherwise a new two-element list ``[pairs[0], pairs[-1]]``.
    """
    return pairs if len(pairs) < 2 else [pairs[0], pairs[-1]]

# -------------------------
# Writer
# -------------------------
# Directory that receives every generated JSONL file. Kept in one place so the
# output location can be changed with a single edit.
OUTPUT_DIR = "/home/hyang/BigML_sharding_dataset/Qwen3-VL-Embedding/data/consistent_chat"

# paths[variant][role] -> output file path.
#   variant: "full" (all pairs) or "fl" (first + last pair only)
#   role:    "pair" (both sides), "human" (human side), "gpt" (gpt side)
# File names follow "<role>_<suffix>.jsonl", with suffix "full" or "first_last".
paths = {
    variant: {
        role: f"{OUTPUT_DIR}/{role}_{suffix}.jsonl"
        for role in ("pair", "human", "gpt")
    }
    for variant, suffix in (("full", "full"), ("fl", "first_last"))
}

# -------------------------
# Process
# -------------------------
# Dispatch table: output role -> function that renders a list of pairs as text.
mergers = {"pair": merge_pairs, "human": merge_human, "gpt": merge_gpt}

# ExitStack guarantees that every file opened so far is closed, even when a
# later open() fails or an exception is raised while processing the dataset.
with ExitStack() as stack:
    # files[variant][role] -> open text handle, mirroring the layout of `paths`.
    files = {
        k: {
            kk: stack.enter_context(open(vv, "w", encoding="utf-8"))
            for kk, vv in v.items()
        }
        for k, v in paths.items()
    }

    for ex in ds:
        pairs = to_pairs(ex["conversations"])

        # "full" keeps every pair; "fl" keeps only the first and last pair.
        subsets = {"full": pairs, "fl": first_last(pairs)}

        for variant, subset in subsets.items():
            for role, merge in mergers.items():
                # Copy the example and swap the structured conversation for
                # its merged text form; all other fields are kept unchanged.
                record = {**ex, "conversations": merge(subset)}
                files[variant][role].write(json.dumps(record) + "\n")
