In [None]:
# -*- coding: utf-8 -*-
# @Time    : 2025/5/3 21:36
# @Author  : Maoyuan Li
# @File    : batch_debate_generate.py
# @Software: PyCharm
import os
import json
from pathlib import Path
from transformers import AutoTokenizer, AutoModel
import numpy as np
import torch

from rag_for_longchain.retriever.faiss_retriever import retrieve_candidates
from rag_for_longchain.generator.testgen import generate_counterargument_via_api

# -------------------- Model & tokenizer initialisation --------------------
# Loaded once at import time; the module-level names `tokenizer` and `model`
# are read by the helper functions below.
MODEL_NAME = "hfl/chinese-roberta-wwm-ext-large"
tokenizer, model = (
    AutoTokenizer.from_pretrained(MODEL_NAME),
    AutoModel.from_pretrained(MODEL_NAME),
)

# -------------------- 重用辅助函数 --------------------
def extract_utterances(data):
    """Flatten arbitrarily nested JSON-like data into a list of non-empty text lines.

    Strings are split on line breaks; each stripped, non-empty line becomes one
    utterance. Lists and tuples are walked in order; dicts contribute their values
    (keys are ignored). ``None`` (JSON ``null``) is skipped rather than turned into
    the literal text "None". Any other scalar (numbers, booleans, ...) is converted
    with ``str`` and then treated like a string.

    Args:
        data: A string, list/tuple, dict, None, or other scalar, typically the
            result of ``json.loads``.

    Returns:
        list[str]: Utterances in depth-first document order.
    """
    if data is None:
        return []
    if isinstance(data, str):
        stripped = (line.strip() for line in data.splitlines())
        return [line for line in stripped if line]
    if isinstance(data, (list, tuple)):
        items = data
    elif isinstance(data, dict):
        items = data.values()
    else:
        # Best effort: stringify other scalars; skip objects whose __str__ fails.
        try:
            return extract_utterances(str(data))
        except Exception:
            return []
    utterances = []
    for item in items:
        utterances.extend(extract_utterances(item))
    return utterances

def recursive_split(text, max_length=512):
    tokens = tokenizer.tokenize(text)
    if len(tokens) <= max_length:
        return [text]
    mid = len(tokens) // 2
    part1 = tokenizer.convert_tokens_to_string(tokens[:mid])
    part2 = tokenizer.convert_tokens_to_string(tokens[mid:])
    return recursive_split(part1, max_length) + recursive_split(part2, max_length)

def generate_embeddings(chunks, tok=None, mdl=None):
    """Embed each text chunk as the mean of the encoder's last hidden states.

    Chunks are encoded one at a time, so no padding is involved. Each chunk is
    truncated by the tokenizer to ``max_length=512``.

    Args:
        chunks: Iterable of strings.
        tok: Tokenizer callable returning model inputs. Defaults to the module-level
            ``tokenizer``.
        mdl: Encoder whose output has ``last_hidden_state`` and whose ``config``
            has ``hidden_size``. Defaults to the module-level ``model``.

    Returns:
        np.ndarray with one row per chunk. For empty ``chunks`` an empty
        ``(0, hidden_size)`` float32 array is returned instead of letting
        ``np.vstack`` raise on an empty list.
    """
    tok = tokenizer if tok is None else tok
    mdl = model if mdl is None else mdl
    embs = []
    # Inference only: one no_grad context covers the whole loop.
    with torch.no_grad():
        for chunk in chunks:
            inputs = tok(chunk, return_tensors="pt", truncation=True, max_length=512)
            outputs = mdl(**inputs)
            # Assumes last_hidden_state is (1, seq_len, hidden) for a single input:
            # take the only row and average over the token axis.
            embs.append(outputs.last_hidden_state[0].mean(dim=0).cpu().numpy())
    if not embs:
        return np.empty((0, mdl.config.hidden_size), dtype=np.float32)
    return np.vstack(embs)

def process_chunks(chunks):
    """Collect chunk texts and pick the debate skill with the highest summed score.

    Each chunk is a dict with a ``"text"`` entry and an optional ``"labels"``
    mapping of ``skill -> {"评分": score, ...}``. Scores are summed per skill
    across all chunks. A score that is missing, non-numeric, or whose label entry
    is not a dict counts as 0. A missing or null ``"labels"`` is treated as empty.

    Args:
        chunks: Iterable of chunk dicts as described above.

    Returns:
        dict with keys:
            "所有文本": list of chunk texts in input order;
            "最佳辩论技巧": skill with the highest total (the first-seen one wins
                ties), or None when no chunk carries labels;
            "最高评分": that skill's total score, or 0.
    """
    all_texts = []
    skill_scores = {}
    for c in chunks:
        all_texts.append(c["text"])
        # `labels` may be absent or explicitly null in the retrieved metadata.
        labels = c.get("labels") or {}
        for skill, info in labels.items():
            try:
                score = float(info.get("评分", 0))
            except (AttributeError, TypeError, ValueError, OverflowError):
                # Non-dict label payloads or unparsable scores contribute nothing.
                score = 0
            skill_scores[skill] = skill_scores.get(skill, 0) + score
    if skill_scores:
        best_skill = max(skill_scores, key=skill_scores.get)
        best_score = skill_scores[best_skill]
    else:
        best_skill, best_score = None, 0
    return {
        "所有文本": all_texts,
        "最佳辩论技巧": best_skill,
        "最高评分": best_score
    }

In [None]:


# -------------------- 单场辩论处理 --------------------
def process_debate_file(debate_path: Path):
    """Run the retrieval + generation pipeline for a single debate JSON file.

    Steps:
        1. load the JSON and flatten it into utterances;
        2. split utterances into tokenizer-sized pieces and embed them;
        3. retrieve the top-5 candidate chunks;
        4. aggregate their labels to find the best debate technique;
        5. generate one speech for each stance ("pro", "con") via the API helper.

    Args:
        debate_path: Path to a UTF-8 encoded JSON file.

    Returns:
        dict with keys "topic", "model", "pro" and "con". "topic" is the
        top-level "topic" field when the JSON root is an object that has one,
        otherwise the file stem.
    """
    data = json.loads(debate_path.read_text(encoding="utf-8"))

    # 1-2. Flatten utterances -> split into model-sized pieces -> embed.
    chunks = [piece for utt in extract_utterances(data) for piece in recursive_split(utt)]
    # Compact progress line instead of dumping every chunk to stdout.
    print(f"[{debate_path.name}] {len(chunks)} chunks")
    embeddings = generate_embeddings(chunks)

    # 3. Retrieve the top-k candidate chunks.
    # NOTE(review): the second return value (labels) is not used here; presumably
    # the same labels are carried inside each chunk — confirm against the retriever.
    best_chunks, _best_labels = retrieve_candidates(embeddings, top_k=5)

    # 4. Aggregate label scores to find the best debate technique.
    result = process_chunks(best_chunks)
    all_texts_variable = "\n\n".join(result["所有文本"])

    # 5. One speech per stance. A fresh prompt_meta dict is built per call in case
    #    the API helper mutates it.
    outputs = {
        stance: generate_counterargument_via_api(
            best_chunks,
            all_texts_variable,
            {"最佳辩论技巧": result["最佳辩论技巧"]},
            stance=stance,
        )
        for stance in ("pro", "con")
    }

    # The JSON root may be a list (extract_utterances supports that), which has no .get().
    topic = data.get("topic", debate_path.stem) if isinstance(data, dict) else debate_path.stem
    return {
        "topic": topic,
        "model": MODEL_NAME,
        "pro": outputs["pro"],
        "con": outputs["con"]
    }


In [None]:

# -------------------- 主入口 --------------------
def main(input_dir: str, output_file: str):
    """Process one debate file per first-level sub-folder and write all results as JSON.

    Every immediate sub-directory of ``input_dir`` is visited in sorted order. In
    each one, the alphabetically first ``*.json`` file other than ``last_two.json``
    is passed to ``process_debate_file``. Sub-folders without such a file are
    skipped. The list of results is written to ``output_file`` (UTF-8,
    non-ASCII kept, indent=4), and its parent directory is created if needed.

    Args:
        input_dir: Root folder containing one sub-folder per debate.
        output_file: Destination path of the aggregated JSON list.
    """
    root = Path(input_dir)
    results = []
    # iterdir()/glob() order is filesystem-dependent; sort for reproducible runs.
    for sub in sorted(p for p in root.iterdir() if p.is_dir()):
        candidates = sorted(f for f in sub.glob("*.json") if f.name != "last_two.json")
        if candidates:
            results.append(process_debate_file(candidates[0]))

    out_path = Path(output_file)
    # Avoid losing the whole (expensive) batch because the output folder is missing.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=4)
    print(f"已生成汇总文件：{output_file}")

if __name__ == "__main__":
    # Paths can be overridden with the DEBATE_INPUT_DIR / DEBATE_OUTPUT_FILE
    # environment variables; the defaults are the original local paths.
    input_dir = os.environ.get(
        "DEBATE_INPUT_DIR",
        r"D:\converstional_rag\23acldata\input_data\LLM实验测试",
    )
    output_file = os.environ.get(
        "DEBATE_OUTPUT_FILE",
        r"D:\converstional_rag\23acldata\output_data\debaterRAG\all_debates_output.json",
    )
    main(input_dir, output_file)

