In [None]:
import json
import os
import re
from typing import Callable, Optional

from opencc import OpenCC
from transformers import AutoTokenizer

from crawler.utils import Logger
from curriculum_training.constants import (
    USE_VLLM, ALLOW_VLLM, MAX_INPUT_LENGTH, MAX_NEW_TOKENS, DATASET_V3_DIR,
    SUMMARY_FORMATTED_V3, SUMMARY_V3, ESSENTIALS_V3, NWR_V3
)
from news_with_rationale import NewsWithRationale as NWR
from utils import load_udn_news
from utils_vllm import (
    LLM,
    SamplingParams,
    init_vllm_model,
    filter_by_max_length,
    vllm_batch_generate,
)

# This notebook only supports the vLLM backend: both flags must be truthy.
assert ALLOW_VLLM
assert USE_VLLM

# Hugging Face identifier used for the tokenizer here and passed to
# ``init_vllm_model`` in a later cell.
MODELNAME = "Qwen/Qwen2.5-32B-Instruct"

gen_logger = Logger("data_gen", verbose_level=3)

# Placeholders; both are assigned by the ``init_vllm_model`` cell below and
# asserted non-None inside ``local_gen_response``.
model: Optional[LLM] = None
sampling_params: Optional[SamplingParams] = None
# Used to render chat messages into prompt strings before generation.
tokenizer = AutoTokenizer.from_pretrained(MODELNAME)

cc = OpenCC("s2twp")  # Simplified Chinese to Traditional Chinese

# A chat conversation: a list of {"role": ..., "content": ...} dicts.
Message = list[dict[str, str]]

# Make sure the dataset output directory exists. ``exist_ok=True`` replaces the
# separate exists-check, which could race with another process creating the
# directory between the check and the ``makedirs`` call.
os.makedirs(DATASET_V3_DIR, exist_ok=True)


def essential_aspects_prompt(nwr: NWR) -> Message:
    """
    Build the chat messages asking the model to extract essential aspects
    and triples from an article and its summary.

    The system message describes the expected bracketed output format; the
    user message carries ``nwr.article`` and ``nwr.summary``.
    """
    system_content = (
        "請根據以下新聞內容以及摘要，提取新聞的關鍵要素與三元組，關鍵要素應為關鍵短句、名詞或事實，"
        "三元組應為[實體1 | 關係 | 實體2]的格式，"
        "這些三元組用於構成摘要，請用繁體中文回答。"
        "請將每個關鍵要素與三元組用[]與、分隔。"
        "例如："
        "關鍵要素：\n[關鍵要素1]、[關鍵要素2]、[關鍵要素3]、...\n\n"
        "三元組：\n[實體1_1 | 關係_1 | 實體1_2]、[實體2_1 | 關係_2 | 實體2_2]、..."
    )
    user_content = f"新聞：\n{nwr.article}\n\n摘要：\n{nwr.summary}\n\n"

    system_message = {"role": "system", "content": system_content}
    user_message = {"role": "user", "content": user_content}
    return [system_message, user_message]


def summary_prompt(nwr: NWR) -> Message:
    """
    Build the chat messages asking the model for a short (within 100
    characters) Traditional Chinese summary of ``nwr.article``.
    """
    instructions = (
        "請根據以下新聞內容，為新聞生成一份100字內精簡的摘要，"
        "請用繁體中文回答。\n"
        "例如：\n"
        "生成摘要：\n"
    )
    article_text = f"新聞：\n{nwr.article}"

    messages = []
    messages.append({"role": "system", "content": instructions})
    messages.append({"role": "user", "content": article_text})
    return messages


def local_gen_response(nwrs: list[NWR], prompt_fn: Callable) -> list[dict]:
    """
    Generate one model response per NWR with the module-level vLLM model.

    ``prompt_fn`` turns an NWR into chat messages, which are rendered with
    the tokenizer's chat template. Prompts rejected by
    ``filter_by_max_length`` are dropped together with their NWR.

    Returns a list of ``{"id", "news", "response"}`` dicts, one per NWR that
    survived the length filter (empty when there is nothing to generate).
    """
    # The model cell must have been executed before calling this function.
    assert model is not None
    assert sampling_params is not None

    if len(nwrs) == 0:
        gen_logger.warning("No data need to generate.")
        return []

    rendered_prompts: list[str] = []
    for item in nwrs:
        rendered_prompts.append(
            tokenizer.apply_chat_template(
                prompt_fn(item),
                tokenize=False,
                add_generation_prompt=True,
            )
        )

    # Drop over-long prompts; the helper returns the matching NWR subset.
    kept_prompts, kept_nwrs = filter_by_max_length(
        MAX_INPUT_LENGTH, rendered_prompts, nwrs
    )
    if len(kept_prompts) == 0:
        return []

    generations = vllm_batch_generate(model, kept_prompts, sampling_params)
    texts = [generation.outputs[0].text for generation in generations]

    records: list[dict] = []
    for item, text in zip(kept_nwrs, texts):
        records.append(
            {"id": item.id, "news": item.article, "response": text}
        )
    return records


def parse_essential_and_triple(response: dict) -> tuple[list[str], list[str]]:
    """
    Parse the essential aspects and triples from a generation record.

    ``response["response"]`` is searched for two consecutive runs of
    ``[...]`` items (optionally introduced by the "關鍵要素：" and "三元組："
    headers). The first run is returned as essentials, the second as triples,
    each without the surrounding brackets.

    Items are extracted bracket by bracket rather than by splitting on "、",
    so an item that itself contains "、" (e.g. ``[甲、乙]``) stays in one piece.

    Returns ``([], [])`` when the text does not contain both runs.
    """
    pattern = (
        r"(?:關鍵要素：\s*)?((?:\[[^\[\]]+\]、?)+)\n*"
        r"(?:三元組：\s*)?((?:\[[^\[\]]+\]、?)+)"
    )
    match = re.search(pattern, response["response"])
    if match is None:
        return [], []

    # Each captured run is a sequence of "[item]" separated by optional "、";
    # capturing the bracket contents yields the items directly.
    item_pattern = r"\[([^\[\]]+)\]"
    essentials = re.findall(item_pattern, match.group(1))
    triples = re.findall(item_pattern, match.group(2))
    return essentials, triples


def parse_summary(response: dict) -> str:
    """
    Extract the summary text from a generation record.

    The record's ``"response"`` string is stripped, an optional leading
    "生成摘要：" header is skipped, and the remainder of that line is returned
    (stripped). Returns an empty string when nothing matches.
    """
    text = response["response"].strip()
    found = re.search(r"(?:生成摘要：\s*)?(.+)", text)
    return found.group(1).strip() if found else ""


def parse_summaries(nwrs: list[NWR], responses: list[dict]) -> list[NWR]:
    """
    Attach parsed summaries to the NWR objects they belong to.

    Responses are matched to NWRs by ``id``. For every NWR with a non-empty
    parsed summary, the OpenCC-converted text is stored in both ``summary``
    and ``rationale_summary``. NWRs without a usable response are left
    untouched. The same ``nwrs`` list is returned.
    """
    by_id: dict = {}
    for record in responses:
        by_id[record["id"]] = record

    for item in nwrs:
        record = by_id.get(item.id)
        if record:
            parsed = parse_summary(record)
            if parsed:
                converted = cc.convert(parsed)
                item.summary = converted
                item.rationale_summary = converted

    return nwrs


def parse_esss_and_tris(nwrs: list[NWR], responses: list[dict]) -> list[NWR]:
    """
    Attach parsed essential aspects and triples to their NWR objects.

    Responses are matched to NWRs by ``id``. Each non-empty parsed list is
    converted item by item with OpenCC and assigned to
    ``essential_aspects`` / ``triples``. NWRs without a usable response keep
    their current values. The same ``nwrs`` list is returned.
    """
    by_id: dict = {}
    for record in responses:
        by_id[record["id"]] = record

    for item in nwrs:
        record = by_id.get(item.id)
        if record:
            parsed_essentials, parsed_triples = parse_essential_and_triple(
                record
            )
            if parsed_essentials:
                item.essential_aspects = [
                    cc.convert(text) for text in parsed_essentials
                ]
            if parsed_triples:
                item.triples = [cc.convert(text) for text in parsed_triples]

    return nwrs


def load_data(filename: str) -> list[dict]:
    """
    Read a JSON Lines file into a list of records.

    A missing file is not an error: a warning is logged and an empty list is
    returned. The number of loaded records is always logged.
    """
    records: list[dict] = []
    try:
        with open(filename, "r", encoding="utf-8") as handle:
            records.extend(json.loads(raw) for raw in handle)
    except FileNotFoundError:
        gen_logger.warning(f"{filename} not found, starting from scratch.")

    gen_logger.info(f"Loaded {len(records)} data from {filename}")
    return records


def load_essentials_and_triples() -> tuple[
    dict[int, list[str]], dict[int, list[str]]
]:
    """
    Load ``ESSENTIALS_V3`` and parse every record into essentials and triples.

    Returns two dicts keyed by record id. When an id occurs more than once,
    the first record wins and a warning is logged for each later one.
    """
    essentials_by_id: dict[int, list[str]] = {}
    triples_by_id: dict[int, list[str]] = {}

    for record in load_data(ESSENTIALS_V3):
        record_id = record["id"]
        if record_id in essentials_by_id:
            gen_logger.warning(f"Duplicated essential id: {record_id}")
            continue
        parsed = parse_essential_and_triple(record)
        essentials_by_id[record_id] = parsed[0]
        triples_by_id[record_id] = parsed[1]

    return essentials_by_id, triples_by_id


def load_summary(filename) -> dict[int, str]:
    """
    Load a summary JSON Lines file and parse each record's summary text.

    Returns a dict mapping record id to the parsed summary. When an id occurs
    more than once, the first record wins and a warning is logged.
    """
    parsed_by_id: dict[int, str] = {}
    for record in load_data(filename):
        record_id = record["id"]
        if record_id not in parsed_by_id:
            parsed_by_id[record_id] = parse_summary(record)
        else:
            gen_logger.warning(f"Duplicated summary id: {record_id}")
    return parsed_by_id


def get_ids_from_file(filename: str) -> set[int]:
    """
    Collect the distinct record ids found in a JSON Lines file.

    A warning is logged for every repeated id.
    """
    seen: set[int] = set()
    for record in load_data(filename):
        record_id = record["id"]
        if record_id not in seen:
            seen.add(record_id)
        else:
            gen_logger.warning(f"Duplicated id: {record_id}")
    return seen


def get_finished_ids() -> tuple[set[int], set[int], set[int], set[int]]:
    """
    Collect the ids that already have generated output on disk.

    Returns ``(essential_ids, triple_ids, summary_ids, nwr_ids)``:
    essentials and triples share the same file (``ESSENTIALS_V3``), so the
    triple set is a copy of the essential set; summary ids come from
    ``SUMMARY_FORMATTED_V3``; ``nwr_ids`` is the intersection of all three.
    """
    # Essentials and triples are stored in the same records.
    essential_ids = get_ids_from_file(ESSENTIALS_V3)

    # Summaries count as finished once present in the formatted file.
    summary_ids = get_ids_from_file(SUMMARY_FORMATTED_V3)

    triple_ids = set(essential_ids)
    nwr_ids: set[int] = essential_ids & triple_ids & summary_ids

    return essential_ids, triple_ids, summary_ids, nwr_ids


In [None]:
# Create the vLLM model and its sampling parameters, replacing the ``None``
# placeholders declared with the module configuration. The generation
# functions assert that both are set.
model, sampling_params = init_vllm_model(
    model_name=MODELNAME,
    max_input_length=MAX_INPUT_LENGTH,
    max_new_tokens=MAX_NEW_TOKENS
)

In [None]:

# Load the news articles; each article's position in this list becomes its id.
news_list: list[str] = load_udn_news()
# news_list = news_list[:10]
gen_logger.info(f"Loaded {len(news_list)} news")

# Collect the ids that already have generated output on disk.
essential_ids, triple_ids, summary_ids, nwr_ids = get_finished_ids()
gen_logger.info(f"Finished essential count: {len(essential_ids)}")
gen_logger.info(f"Finished triple count: {len(triple_ids)}")
gen_logger.info(f"Finished summary count: {len(summary_ids)}")
gen_logger.info(f"Finished NWR count: {len(nwr_ids)}")

# Parse the previously generated content, keyed by id.
essential_data, triple_data = load_essentials_and_triples()
summary_data = load_summary(SUMMARY_FORMATTED_V3)

# assert len(essential_data) == len(list(essential_ids))
# assert len(triple_data) == len(list(triple_ids))
# assert len(summary_data) == len(list(summary_ids))

# Build one NWR per article and restore whatever was generated earlier,
# converting restored text to Traditional Chinese with OpenCC.
nwrs = [NWR(news, id=i) for i, news in enumerate(news_list)]
for nwr in nwrs:
    if nwr.id in essential_ids:
        nwr.essential_aspects = [
            cc.convert(ess) for ess in essential_data[nwr.id]
        ]
    if nwr.id in triple_ids:
        nwr.triples = [cc.convert(tri) for tri in triple_data[nwr.id]]
    if nwr.id in summary_ids:
        nwr.summary = cc.convert(summary_data[nwr.id])
        # The rationale summary mirrors the plain summary.
        nwr.rationale_summary = nwr.summary

# with open(NWR_V3, "w", encoding="utf-8") as f:
#     for nwr in nwrs:
#         if nwr.id in nwr_ids:
#             f.write(json.dumps(nwr.to_dict(), ensure_ascii=False) + "\n")

gen_logger.info(f"Loaded {len(nwr_ids)} NWRs")


In [None]:
# Generate summaries for every NWR whose id is not in ``summary_ids`` (the ids
# found in SUMMARY_FORMATTED_V3), then parse them back onto the NWRs.
_nwrs = [nwr for nwr in nwrs if nwr.id not in summary_ids]
gen_logger.info(f"Generating {len(_nwrs)} summaries")

responses = local_gen_response(_nwrs, summary_prompt)
nwrs = parse_summaries(nwrs, responses)  # update the NWRs
gen_logger.info(f"Generated {len(responses)} summary responses")

# Append the raw responses to SUMMARY_V3 (append mode keeps earlier records).
# NOTE(review): finished ids come from SUMMARY_FORMATTED_V3 while output is
# appended to SUMMARY_V3, so re-running may append duplicate ids — confirm.
with open(SUMMARY_V3, "a", encoding="utf-8") as f:
    for dat in responses:
        f.write(json.dumps(dat, ensure_ascii=False) + "\n")

gen_logger.info(f"{len([nwr for nwr in nwrs if nwr.summary == ''])} "
                f"NWRs could not generate summaries")

# Drop the NWRs that still have an empty summary; later cells only see the
# remaining ones.
nwrs = [nwr for nwr in nwrs if nwr.summary != ""]
gen_logger.info(f"remaining {len(nwrs)} NWRs")

In [None]:
# Generate essential aspects and triples for every remaining NWR whose id is
# not already in ``essential_ids``, then parse them back onto the NWRs.
_nwrs = [nwr for nwr in nwrs if nwr.id not in essential_ids]
gen_logger.info(f"Generating {len(_nwrs)} essential aspects")
responses = local_gen_response(_nwrs, essential_aspects_prompt)
nwrs = parse_esss_and_tris(nwrs, responses)  # update the NWRs
gen_logger.info(f"Generated {len(responses)} essential and triples")

# Append the raw responses to ESSENTIALS_V3 (append mode keeps earlier records).
with open(ESSENTIALS_V3, "a", encoding="utf-8") as f:
    for dat in responses:
        f.write(json.dumps(dat, ensure_ascii=False) + "\n")

# Report, then drop, the NWRs that still have no essential aspects.
gen_logger.info(
    f"{len([nwr for nwr in nwrs if nwr.essential_aspects == []])} "
    f"NWRs could not generate essential aspects"
)
nwrs = [nwr for nwr in nwrs if nwr.essential_aspects != []]
gen_logger.info(f"remaining {len(nwrs)} NWRs")

# # save the NWRs to file
# with open(NWR_V3, "w", encoding="utf-8") as f:
#     for nwr in nwrs:
#         f.write(json.dumps(nwr.to_dict(), ensure_ascii=False) + "\n")


In [None]:
import json

# import ollama
import re
from opencc import OpenCC
from ollama import chat
# from tqdm import tqdm

from crawler.utils import Logger
from curriculum_training.constants import (
    MAX_INPUT_LENGTH,
    USE_VLLM,
    ALLOW_VLLM,
    NWR_FORMATTED_V3,
    SUMMARY_V3,
    ESSENTIALS_V3,
)
from news_with_rationale import NewsWithRationale as NWR
from utils import int_set_str

from transformers import AutoTokenizer, PreTrainedTokenizer
from utils_vllm import (
    filter_by_max_length,
    vllm_batch_generate,
)

# The formatting pipeline requires vLLM to be allowed.
assert ALLOW_VLLM

# The same model under its Hugging Face name (tokenizer / vLLM) and its
# Ollama tag (``process_summ_ollama``).
FORMAT_MODEL = "Qwen/Qwen2.5-32B-Instruct"
FORMAT_MODEL_OLLAMA = "qwen2.5:32b-instruct"


logger = Logger("data_format")
cc = OpenCC("s2twp")  # Simplified Chinese to Traditional Chinese

# Character counts of the text expected in front of the first item of the
# essentials section ("關鍵要素：\n[") and of the triples section ("\n[").
ESS_START = len("關鍵要素：\n[")
TRI_START = len("\n[")


In [None]:

def format_ess_tri_sys_prompt() -> str:
    """
    Return the system prompt that tells the model how to reformat essential
    aspects and triples into the bracketed, "、"-separated layout.
    """
    prompt_lines = [
        "請將以下內容轉換為關鍵要素和三元組的格式：",
        "1. 你會收到若干個關鍵要素，請將每個要素以[]與、分隔，並移除不必要的符號，如 '1.'、'；'等。",
        "2. 你會收到若干個三元組，請將其以頓號分隔，並移除不必要的符號，如 '1.'、'；'等。",
        "範例：",
        "關鍵要素：",
        "[關鍵要素1]、[關鍵要素2]、[關鍵要素3]...",
        "三元組：",
        "[三元組1_1 | 三元組1_2 | 三元組1_3]、[三元組2_1 | 三元組2_2 | 三元組2_3]...",
    ]
    # Lines are newline-separated with no trailing newline.
    return "\n".join(prompt_lines)


def format_ess_tri_user_prompt(nwr: NWR) -> str:
    """
    Return the user prompt carrying the NWR's essential aspects and triples
    (as rendered by ``essential_aspects_str()`` and ``triples_str()``) for
    reformatting.
    """
    essentials_text = nwr.essential_aspects_str()
    triples_text = nwr.triples_str()
    return (
        "請將以下內容轉換為指定格式：\n"
        f"關鍵要素：\n{essentials_text}\n"
        f"三元組：\n{triples_text}"
    )


def format_summ_sys_prompt() -> str:
    """
    Return the system prompt asking the model to judge a summary.

    The model is told to answer "符合" followed by a cleaned summary when the
    summary is good, or "不符合" otherwise.
    """
    prompt_lines = [
        "你會收到一個新聞文章以及其摘要，請評估該摘要是否為良好的摘要。",
        '若摘要良好（符合且文章內容），請輸出"符合"並將無關的內容移除。',
        '若摘要出現重複、雜亂的訊息、有未完成的句子，或是摘要本身不完整，請輸出"不符合"。',
        "範例輸入：",
        "新聞：",
        "新聞內容",
        "",
        "摘要：",
        "摘要內容：",
        "",
        "範例輸出：",
        "符合",
        "乾淨版摘要",
        "",
        "或者",
        "不符合",
    ]
    # Empty entries produce the blank lines between sections.
    return "\n".join(prompt_lines)


def format_summ_user_prompt(data: dict) -> str:
    """
    Return the user prompt for judging one summary record.

    ``data`` must provide the article under ``"news"`` and the generated
    summary text under ``"response"``.
    """
    header = "請評估以下內容的摘要是否是良好的新聞摘要：\n"
    news_section = f"新聞：\n{data['news']}\n\n"
    summary_section = f"摘要：\n{data['response']}\n\n"
    return header + news_section + summary_section


def process_summ_ollama(data: dict) -> str:
    """
    Ask the Ollama-hosted formatting model to judge one summary record and
    return the raw text of its reply.

    ``data`` is a summary record with ``"news"`` and ``"response"`` keys.
    """
    conversation = [
        {"role": "system", "content": format_summ_sys_prompt()},
        {"role": "user", "content": format_summ_user_prompt(data)},
    ]

    # Generate the response using the model from Ollama
    reply = chat(model=FORMAT_MODEL_OLLAMA, messages=conversation)
    content = reply.message.content
    assert content is not None

    return content


def get_finished_id(filename) -> set[int]:
    """
    Collect the ids of all records in a JSON Lines file.

    Args:
        filename: Path of the JSON Lines file; each non-blank line must be a
            JSON object with an ``"id"`` key.

    Returns:
        The set of ids found. A missing file yields an empty set (logged at
        info level).

    A line that is not valid JSON is logged and the process exits. The
    offending line itself is reported: parsing is wrapped per line, so the
    handler never refers to an earlier line or to an unbound variable when
    the very first line is malformed. Whitespace-only lines are skipped.
    """
    finished_ids: set[int] = set()
    try:
        with open(filename, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # A blank line carries no record.
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    logger.error(f"Error decoding JSON in line: {line}.")
                    exit()
                finished_ids.add(data["id"])
    except FileNotFoundError:
        logger.info(f"{filename} not found, starting from scratch.")
    return finished_ids


def read_summ_file(filename: str, finished_ids=None) -> list[dict]:
    """
    Load summary records from a JSON Lines file, skipping finished ids.

    Records whose ``"id"`` is in ``finished_ids`` are left out; with the
    default ``None`` every record is returned. A missing file is logged and
    yields an empty list.
    """
    skip_ids = set() if finished_ids is None else finished_ids
    records: list[dict] = []
    try:
        with open(filename, "r", encoding="utf-8") as handle:
            for raw in handle:
                record = json.loads(raw)
                if record["id"] not in skip_ids:
                    records.append(record)
    except FileNotFoundError:
        logger.info(f"{filename} not found, starting from scratch.")
    return records


def read_ess_tri_file(filename: str, finished_ids=None) -> list[dict]:
    """
    Load essentials/triples records from a JSON Lines file.

    Records whose ``"id"`` is in ``finished_ids`` are left out; with the
    default ``None`` every record is returned. A missing file is logged and
    yields an empty list.
    """
    excluded = finished_ids if finished_ids is not None else set()
    kept: list[dict] = []
    try:
        with open(filename, "r", encoding="utf-8") as handle:
            for parsed in map(json.loads, handle):
                if parsed["id"] in excluded:
                    continue
                kept.append(parsed)
    except FileNotFoundError:
        logger.info(f"{filename} not found, starting from scratch.")
    return kept


def _split_bracket_items(section: str, header: str = "") -> list[str]:
    """Split a "[a]、[b]、..." section into its items, ignoring outer whitespace."""
    body = section.strip()
    if header and body.startswith(header):
        body = body[len(header):].lstrip()
    # Drop the outermost brackets so that splitting on the "]、[" separators
    # leaves bare item texts.
    if body.startswith("["):
        body = body[1:]
    if body.endswith("]"):
        body = body[:-1]
    if not body:
        return []
    return re.split(r"\]\s*、\s*\[", body)


def extract_essentials_and_triples(
    response: str
) -> tuple[list[str], list[str]]:
    """
    Extract essential aspects and triples from a formatted model response.

    The response is converted with OpenCC and split at the "三元組：" marker.
    The part before it (optionally starting with "關鍵要素：") holds the
    essentials, the part after it holds the triples; both are
    ``[item]、[item]、...`` lists.

    Items are isolated by trimming surrounding whitespace and the outer
    brackets instead of slicing fixed character counts, so the first and last
    items are not truncated when the response has a different number of
    leading or trailing newlines than the canonical layout.

    Returns:
        ``(essentials, triples)`` as lists of strings without brackets.

    If the response does not contain exactly one "三元組：" marker, the error
    is logged and the process exits.
    """
    formatted_response = cc.convert(response)

    try:
        ess_str, tri_str = formatted_response.split("三元組：")
    except ValueError:
        logger.error(f"Invalid response format: {response}")
        exit()

    new_essentials = _split_bracket_items(ess_str, header="關鍵要素：")
    new_triples = _split_bracket_items(tri_str)

    return new_essentials, new_triples


def extract_summ(response: str) -> str:
    """
    Return the summary text that follows the "生成摘要：" header line.

    The response is first converted with OpenCC. Raises ``ValueError`` (after
    logging) when the converted response does not start with the header.
    """
    converted = cc.convert(response)
    header = "生成摘要：\n"

    if not converted.startswith(header):
        logger.error(f"Invalid summ response format: {converted}")
        raise ValueError(f"Invalid summ response format: {converted}")

    return converted[len(header):].strip()


def extract_summ_acceptance(response: str) -> tuple[bool, str]:
    """
    Interpret the judge model's verdict on a summary.

    Returns ``(True, clean_summary)`` when the response starts with "符合"
    followed by a summary line (an optional "乾淨版摘要：" label line is
    skipped), and ``(False, "")`` when it starts with "不符合" or has an
    unrecognised shape (the latter is logged as an error).
    """
    verdict = re.match(r"^(符合)(?:\n乾淨版摘要：)?\n(.+)|^(不符合)", response)

    if verdict is None:
        logger.error(f"Invalid response format: {response}")
        return False, ""

    # Group 1 is only set by the "符合" alternative.
    if verdict.group(1) is None:
        return False, ""
    return True, verdict.group(2)


def format_summ(summ_file: str, output_file: str) -> None:
    """
    Judge raw summaries with the formatting model and persist the accepted ones.

    Records from ``summ_file`` whose id is not yet in ``output_file`` are sent
    to the model together with the judging prompt. Every record the model
    accepts is added to the records already in ``output_file`` and the file is
    rewritten.

    Args:
        summ_file: JSON Lines file of raw summary records
            (``{"id", "news", "response"}``).
        output_file: JSON Lines file of accepted records; created if missing.

    Raises:
        NotImplementedError: if ``USE_VLLM`` is false (only the vLLM backend
            is implemented). Raised before any file is read.

    Notes:
        * Previously accepted records are loaded without any id filter.
          Filtering them by the ids of ``output_file`` itself would always
          give an empty list, and rewriting the file would then drop all
          earlier results.
        * The stored ``"response"`` is the original response text of the
          record, not the record dict, so the output keeps the same schema as
          the input and stays readable by the summary parsers.
        * Relies on the notebook-level ``tokenizer``, ``model`` and
          ``sampling_params`` being initialised by earlier cells.
    """
    if not USE_VLLM:
        raise NotImplementedError("Ollama processing not implemented")

    # Load the finished ids from the formatted file
    finished_ids = get_finished_id(output_file)
    logger.info(f"Finished ids count: {len(finished_ids)}")
    logger.info(f"Finished ids: {int_set_str(finished_ids)}")

    # Pending records exclude finished ids; accepted records are kept in full
    # so the rewrite below preserves them.
    summ_todo: list[dict] = read_summ_file(summ_file, finished_ids)
    summ_finished: list[dict] = read_summ_file(output_file)

    logger.info(f"Loaded {len(summ_finished)} summ from {output_file}")
    logger.info(f"Total {len(summ_todo)} summ to process")

    # summ_todo = summ_todo[:10]  # for demonstration

    ids = [summ["id"] for summ in summ_todo]
    logger.info(f"Remains {len(ids)} summ to process")
    logger.info(f"Remains summ ids: {int_set_str(set(ids))}")

    sys_prompt = format_summ_sys_prompt()
    prompts = [
        tokenizer.apply_chat_template(
            [
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": format_summ_user_prompt(summ)},
            ],
            tokenize=False,
            add_generation_prompt=True,
        )
        for summ in summ_todo
    ]
    assert len(prompts) == len(summ_todo)

    # Filter out prompts that are too long
    prompts, summ_todo = filter_by_max_length(
        MAX_INPUT_LENGTH, prompts, summ_todo
    )

    output_strs: list[str] = []
    if len(prompts) > 0:
        responses = vllm_batch_generate(model, prompts, sampling_params)
        output_strs = [response.outputs[0].text for response in responses]

    for summ, output in zip(summ_todo, output_strs):
        try:
            success, _clean_summary = extract_summ_acceptance(output)
        except Exception as e:
            # Best effort: one unreadable verdict must not abort the batch.
            logger.error(f"Error processing id {summ['id']}: {e}")
            logger.error(f"Formatted response:\n{output}")
            continue

        if success:
            # Keep the original response text of the accepted record.
            summ_finished.append(
                {
                    "id": summ["id"],
                    "news": summ["news"],
                    "response": summ["response"],
                }
            )

    with open(output_file, "w", encoding="utf-8") as f:
        for summ in summ_finished:
            f.write(json.dumps(summ, ensure_ascii=False) + "\n")


In [None]:

# Tokenizer of the formatting model; ``format_summ`` uses it to render chat
# messages into prompt strings.
tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(FORMAT_MODEL)

In [None]:
# Assemble NWR records from the raw summary and essentials files.

# Ids present in the summary file, and ids present in both files.
finished_summ_ids = get_finished_id(SUMMARY_V3)

finished_ids = finished_summ_ids.intersection(get_finished_id(ESSENTIALS_V3))


summ_data: list[dict] = read_summ_file(SUMMARY_V3)
ess_data: list[dict] = read_ess_tri_file(ESSENTIALS_V3)

summ_data = [summ for summ in summ_data if summ["id"] in finished_summ_ids]
ess_data = [ess for ess in ess_data if ess["id"] in finished_ids]

print(f"Finished summ ids count: {len(finished_summ_ids)}")
print(f"Finished ids count: {len(finished_ids)}")

print(f"Save {len(summ_data)} summ to {SUMMARY_V3}")
print(f"Save {len(ess_data)} ess to {ESSENTIALS_V3}")

# with open(SUMMARY_V3, "w", encoding="utf-8") as f:
#     for summ in summ_data:
#         f.write(json.dumps(summ, ensure_ascii=False) + "\n")

# with open(ESSENTIALS_V3, "w", encoding="utf-8") as f:
#     for ess in ess_data:
#         f.write(json.dumps(ess, ensure_ascii=False) + "\n")

# Index both record sets by id for the join below.
formatted_summ_data: list[dict] = read_summ_file(SUMMARY_V3)
summ_dict: dict[int, dict] = {summ["id"]: summ for summ in formatted_summ_data}
ess_dict = {ess["id"]: ess for ess in ess_data}

nwr_list: list[NWR] = []

# ``news_id`` rather than ``id``: the loop variable lives at notebook global
# scope and would otherwise shadow the ``id`` builtin for all later cells.
for news_id in finished_ids:
    if news_id not in summ_dict or news_id not in ess_dict:
        logger.warning(f"Missing data for id {news_id}")
        continue

    summ = extract_summ(summ_dict[news_id]["response"])
    ess, tri = extract_essentials_and_triples(
        ess_dict[news_id]["response"]
    )

    nwr = NWR(
        article=summ_dict[news_id]["news"],
        summary=summ,
        id=news_id,
        rationale_summary=summ,
        essential_aspects=ess,
        triples=tri,
    )

    nwr_list.append(nwr)

logger.info(f"Loaded {len(nwr_list)} NWRs")

# sort the NWRs by id
nwr_list.sort(key=lambda x: x.id)

with open(NWR_FORMATTED_V3, "w", encoding="utf-8") as f:
    for nwr in nwr_list:
        f.write(json.dumps(nwr.to_dict(), ensure_ascii=False) + "\n")

logger.info(f"Saved {len(nwr_list)} NWRs to {NWR_FORMATTED_V3}")




In [None]:
# Judge the raw summaries in SUMMARY_V3 and write the accepted records to
# SUMMARY_FORMATTED_V3.
format_summ(
    SUMMARY_V3,
    SUMMARY_FORMATTED_V3
)

In [None]:
# # clean up invalid data

# finished_summ_ids = get_finished_id(SUMMARY_FORMATTED_V3)
# # print(f"Finished summ ids count: {len(finished_summ_ids)}")
# finished_summ_ids = finished_summ_ids.intersection(get_finished_id(SUMMARY_V3))

# finished_ids = finished_summ_ids.intersection(get_finished_id(ESSENTIALS_V3))


# summ_data: list[dict] = read_summ_file(SUMMARY_V3)
# ess_data: list[dict] = read_ess_tri_file(ESSENTIALS_V3)

# summ_data = [summ for summ in summ_data if summ["id"] in finished_summ_ids]
# ess_data = [ess for ess in ess_data if ess["id"] in finished_ids]

# print(f"Finished summ ids count: {len(finished_summ_ids)}")
# print(f"Finished ids count: {len(finished_ids)}")

# print(f"Save {len(summ_data)} summ to {SUMMARY_V3}")
# print(f"Save {len(ess_data)} ess to {ESSENTIALS_V3}")

# # with open(SUMMARY_V3, "w", encoding="utf-8") as f:
# #     for summ in summ_data:
# #         f.write(json.dumps(summ, ensure_ascii=False) + "\n")

# # with open(ESSENTIALS_V3, "w", encoding="utf-8") as f:
# #     for ess in ess_data:
# #         f.write(json.dumps(ess, ensure_ascii=False) + "\n")

# formatted_summ_data: list[dict] = read_summ_file(SUMMARY_FORMATTED_V3)
# summ_dict: dict[int, dict] = {summ["id"]: summ for summ in formatted_summ_data}
# ess_dict = {ess["id"]: ess for ess in ess_data}

# nwr_list: list[NWR] = []

# for id in finished_ids:
#     if id not in summ_dict or id not in ess_dict:
#         logger.warning(f"Missing data for id {id}")
#         continue

#     ess, tri = extract_essentials_and_triples(
#         ess_dict[id]["response"]
#     )

#     nwr = NWR(
#         article=summ_dict[id]["news"],
#         summary=summ_dict[id]["response"],
#         id=id,
#         rationale_summary=summ_dict[id]["response"],
#         essential_aspects=ess,
#         triples=tri,
#     )

#     nwr_list.append(nwr)

# logger.info(f"Loaded {len(nwr_list)} NWRs")

# # sort the NWRs by id
# nwr_list.sort(key=lambda x: x.id)

# with open(NWR_FORMATTED_V3, "w", encoding="utf-8") as f:
#     for nwr in nwr_list:
#         f.write(json.dumps(nwr.to_dict(), ensure_ascii=False) + "\n")

# logger.info(f"Saved {len(nwr_list)} NWRs to {NWR_FORMATTED_V3}")


