In [1]:
from pathlib import Path
import pandas as pd
from tqdm import tqdm
import os
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from dotenv import load_dotenv

In [2]:
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.embeddings import DashScopeEmbeddings

In [3]:
env_path = Path('.') / '.env'
load_dotenv(dotenv_path=env_path)


True

In [4]:
QWEN_API_KEY = os.environ.get("DASHSCOPE_API_KEY")
QWEN_BASE_URL = os.environ.get("DASHSCOPE_API_URL")
llm = ChatOpenAI(
        model="qwen-plus-latest",
        temperature=0.3,
        api_key=QWEN_API_KEY,
        base_url=QWEN_BASE_URL
    )

embedding_model = DashScopeEmbeddings(
        model=os.environ.get("EMBEDDING_MODEL")
        # (它会自动从 DASHSCOPE_API_KEY 环境变量中获取密钥)
        # (它会使用正确的、原生的阿里云 embedding 端点)
    )

In [5]:

DATASET1_PATH = "./data/input/sft_dataset_4000.json"
DATASET2_PATH = "./data/input/MITRE-ATTACK_dataset_test.json"
OUTPUT_DATASET_PATH = "./data/output/sft_dataset_with_qwen_compat_explanations.json"

VECTORDB_PERSIST_DIR = "./data/database/ttp_chroma_db_qwen_compat/"

In [6]:
# from langchain_core.documents import Document
# df_ttp = pd.read_json(DATASET2_PATH, orient="records")
#
# required_columns = ['ID', 'name', 'description']
# if not all(col in df_ttp.columns for col in required_columns):
#     raise ValueError(f"Dataset 2 JSON file must contain the following keys: {required_columns}")
#
#     # 将 DataFrame 转换为 Langchain Document 对象列表
# documents = []
# for index, row in df_ttp.iterrows():
#     doc = Document(
#         page_content=row['description'],
#         metadata={'id': row['ID'], 'name': row['name']}
#     )
#     documents.append(doc)
# print(f"✅ Loaded and converted {len(documents)} TTP documents.")


In [7]:
from langchain_core.documents import Document
try:
    df_ttp = pd.read_json(DATASET2_PATH, orient="records")

    required_columns = ['ID', 'name', 'description']
    if not all(col in df_ttp.columns for col in required_columns):
        raise ValueError(f"Dataset 2 JSON file must contain the following keys: {required_columns}")

    documents = []
    skipped_rows = 0
    # (这是您 naive_rag.ipynb 中的单元格 7 的内容，但增加了修复)
    for index, row in df_ttp.iterrows():
        description_content = row['description']

        # --- !! 关键修复点 (Robust Fix) !! ---
        # 我们必须跳过 None/NaN、非字符串类型 *以及* 空字符串
        if pd.isna(description_content) or not isinstance(description_content, str) or description_content.strip() == "":
            print(f"⚠️ WARNING: Skipping row {index} (ID: {row.get('ID', 'N/A')}) due to missing, empty, or non-string description.")
            skipped_rows += 1
            continue # 跳过这一行

        # 只有有效的、非空的字符串才会进入这里
        doc = Document(
            page_content=description_content, # 现在可以安全使用
            metadata={'id': row['ID'], 'name': row['name']}
        )
        documents.append(doc)

    print(f"✅ Loaded and converted {len(documents)} TTP documents.")
    if skipped_rows > 0:
        print(f"⚠️ Skipped {skipped_rows} rows due to invalid/missing descriptions.")

    if not documents:
        raise ValueError("No valid documents were loaded. Please check your JSON file for 'description' fields.")

except FileNotFoundError:
    print(f"❌ ERROR: Knowledge Base file not found at '{DATASET2_PATH}'")
    exit(1)
except ValueError as ve:
    print(f"❌ ERROR: {ve}")
    exit(1)
except Exception as e:
    print(f"❌ An unexpected error occurred while loading Dataset 2: {e}")
    exit(1)

✅ Loaded and converted 508 TTP documents.


In [8]:
try:
    # (我们现在使用 Chroma.from_documents()，因为 DashScopeEmbeddings 类可以正确处理它)
    if os.path.exists(VECTORDB_PERSIST_DIR) and os.listdir(VECTORDB_PERSIST_DIR):
        print(f"    Loading existing vector database from '{VECTORDB_PERSIST_DIR}'...")
        vectordb = Chroma(
            persist_directory=VECTORDB_PERSIST_DIR,
            embedding_function=embedding_model # 现在这是 DashScopeEmbeddings
        )
        print("✅ Existing vector database loaded successfully.")
    else:
        print(f"    No existing database found. Creating new vector database at '{VECTORDB_PERSIST_DIR}'...")
        print("    (This may take a few minutes, calling Chroma.from_documents()...)")
        vectordb = Chroma.from_documents(
            documents=documents, # 清理过的列表
            embedding=embedding_model, # 现在这是 DashScopeEmbeddings
            persist_directory=VECTORDB_PERSIST_DIR
        )
        print(f"✅ New vector database created and persisted.")

    retriever = vectordb.as_retriever(search_kwargs={"k": 3})
    print("✅ Retriever created successfully.")

except Exception as e:
    print(f"❌ ERROR: Failed during vector database creation/loading: {e}")
    print("    This error is now coming from the native DashScopeEmbeddings class.")
    exit(1)

    Loading existing vector database from './data/database/ttp_chroma_db_qwen_compat/'...
✅ Existing vector database loaded successfully.
✅ Retriever created successfully.


In [9]:
len(vectordb.get()['documents'])

508

In [10]:
from langchain_core.prompts import PromptTemplate
prompt_template_str = """
You are a top-tier cybersecurity CTI analyst and an expert in the MITRE ATT&CK framework.
Your task is to establish a clear, logical connection between a CTI description and its corresponding ATT&CK technique (TTP).

Please follow these steps strictly:
1.  Analyze the provided **[CTI Input]**, identifying the key actions, tools, or targets (e.g., "used macros", "download and deploy", "SOCKS proxy").
2.  Review the **[Retrieved TTP Context]** to understand the official definitions of the retrieved techniques.
3.  Generate a detailed **[Reasoning Process]** that explains exactly why the key actions in the **[CTI Input]** match the definition of the **[Target TTP]**. You must specify which words or phrases from the input correspond to which aspects of the technique's definition.
4.  Your response must **only** contain the detailed **[Reasoning Process]**.

---
[CTI Input]:
{input_cti}

[Target TTP]:
{target_ttp}

[Retrieved TTP Context]:
{context}
---

Please output only your **[Reasoning Process]**:
"""
prompt = PromptTemplate.from_template(prompt_template_str)

In [11]:
from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers import StrOutputParser
# --- 4.2 辅助函数 ---
def format_docs(docs):
    """Formats retrieved documents for the prompt."""
    formatted = []
    for doc in docs:
        metadata = doc.metadata
        formatted.append(f"ID: {metadata.get('id', 'N/A')}\nName: {metadata.get('name', 'N/A')}\nDescription: {doc.page_content}")
    return "\n---\n".join(formatted)

In [12]:
# # --- 4.3 构建链 (LCEL) ---
# retriever_chain = (
#     RunnableLambda(lambda x: x['input'])
#     | retriever
#     | format_docs
# )
#
# generation_chain = (
#     {
#         "context": retriever_chain,
#         "original_input": RunnablePassthrough()
#     }
#     | RunnableParallel(
#         input_cti=lambda x: x['original_input']['input'],
#         target_ttp=lambda x: x['original_input']['output'],
#         context=lambda x: x['context']
#       )
#     | prompt
#     | llm # (这仍然是 ChatOpenAI 包装器)
#     | StrOutputParser()
# )
#
# print("✅ RAG chain built successfully.")

# 链 1: 只负责检索 (将在单线程循环中运行)
retriever_chain = (
    RunnableLambda(lambda x: x['input']) # 接收 {'input': ...}
    | retriever
    | format_docs # 返回上下文(context)字符串
)
print("✅ Retriever chain (for sequential retrieval) built successfully.")

# 链 2: 只负责生成 (将在并发批处理中运行)
# 这条链需要一个字典，包含所有 Prompt 所需的信息
llm_chain = (
    prompt # (Prompt 模板保持不变)
    | llm
    | StrOutputParser()
)
print("✅ LLM chain (for batch generation) built successfully.")

✅ Retriever chain (for sequential retrieval) built successfully.
✅ LLM chain (for batch generation) built successfully.


In [13]:
# --- 4.4 测试 RAG 链 ---
# print("    Testing RAG chain with a sample input...")
# try:
#     test_chain_input = {
#         'input': "TrickBot has used macros in Excel documents to download and deploy the malware on the user’s machine.",
#         'output': "T1059: Command and Scripting Interpreter"
#     }
#     test_explanation = generation_chain.invoke(test_chain_input)
#     print("✅ RAG chain test successful.")
#     print("--- Sample Explanation (from chain test) ---")
#     print(test_explanation)
#     print("------------------------------------------")
# except Exception as e:
#     print(f"⚠️ WARNING: RAG chain test failed: {e}")

In [14]:

import json

# try:
#     with open(DATASET1_PATH, 'r', encoding='utf-8') as f:
#         dataset1 = json.load(f)
#     print(f"✅ Dataset 1 loaded successfully with {len(dataset1)} records.")
# except FileNotFoundError:
#     print(f"❌ ERROR: Dataset 1 file not found at '{DATASET1_PATH}'")
#     exit(1)
# except json.JSONDecodeError:
#     print(f"❌ ERROR: Dataset 1 file '{DATASET1_PATH}' is not a valid JSON file.")
#     exit(1)
# except Exception as e:
#     print(f"❌ An unexpected error occurred while loading Dataset 1: {e}")
#     exit(1)

try:
    with open(DATASET1_PATH, 'r', encoding='utf-8') as f:
        dataset1 = json.load(f)
    print(f"✅ Dataset 1 loaded successfully with {len(dataset1)} records.")
except FileNotFoundError:
    print(f"❌ ERROR: Dataset 1 file not found at '{DATASET1_PATH}'")
    exit(1)
except Exception as e:
    print(f"❌ An unexpected error occurred while loading Dataset 1: {e}")
    exit(1)

# 这是“阶段 B” (LLM) 所需的输入列表
inputs_for_llm_batch = []
valid_items_for_zip = [] # 仍然需要这个来在最后拼装

print(f"--- Starting Phase A: Retrieving context for {len(dataset1)} items (Single-threaded)... ---")
# 这个循环非常快 (几分钟)，因为它只是在查询本地数据库
for item in tqdm(dataset1, desc="Phase A: Retrieving Context"):
    if not isinstance(item, dict) or 'input' not in item or 'output' not in item:
        print(f"⚠️ WARNING: Skipping malformed data item: {item}")
        continue

    try:
        # --- !! 关键修复：我们在这里手动执行 retriever_chain !! ---

        # 1. 准备纯字符串查询
        query_string = item['input']
        if not query_string: # 额外的安全检查
             print(f"⚠️ WARNING: Skipping item with empty 'input' string.")
             continue

        # 2. (手动) 调用 retriever (这会调用 DashScopeEmbeddings)
        #    retriever 是在 [Step 3] 中创建的
        retrieved_docs = retriever.invoke(query_string)

        # 3. (手动) 调用 format_docs (这是在 [Step 4.2] 中定义的)
        retrieved_context_str = format_docs(retrieved_docs)

        # --- 修复结束 ---

        # 4. 准备 LLM 链的输入
        llm_input_dict = {
            "input_cti": item['input'],
            "target_ttp": item['output'],
            "context": retrieved_context_str
        }

        # 5. 将准备好的数据添加到列表中
        inputs_for_llm_batch.append(llm_input_dict)
        valid_items_for_zip.append(item) # 保持 item 同步

    except Exception as e:
        # 捕获 `retriever.invoke` 或 `format_docs` 中的任何错误
        print(f"❌ ERROR during retrieval for item '{item.get('input', 'N/A')[:50]}...': {e}")
        # (我们跳过这个失败的条目)

print(f"✅ Phase A complete. Prepared {len(inputs_for_llm_batch)} items for LLM generation.")-

✅ Dataset 1 loaded successfully with 4000 records.
--- Starting Phase A: Retrieving context for 4000 items (Single-threaded)... ---


Phase A: Retrieving Context: 100%|██████████| 4000/4000 [42:10<00:00,  1.58it/s]

✅ Phase A complete. Prepared 4000 items for LLM generation.





In [15]:
# import time
# new_dataset_with_explanation = []
# error_count = 0
# max_retries = 3
#
# # --- 关键修改 1: 我们使用 enumerate() 来获取索引 (i) ---
# for i, item in enumerate(tqdm(dataset1, desc="Generating Explanations")):
#     if not isinstance(item, dict) or 'input' not in item or 'output' not in item:
#         # (使用 tqdm.write 打印警告)
#         tqdm.write(f"⚠️ WARNING: Skipping malformed data item: {item}")
#         continue
#
#     chain_input = {"input": item['input'], "output": item['output']}
#
#     explanation = None
#     retries = 0
#     while retries < max_retries:
#         try:
#             explanation = generation_chain.invoke(chain_input)
#             break
#         except Exception as e:
#             retries += 1
#             # (使用 tqdm.write 打印错误)
#             tqdm.write(f"⚠️ WARNING: Error processing '{item.get('input', 'N/A')[:50]}...' (Attempt {retries}/{max_retries}): {e}")
#             if retries < max_retries:
#                 tqdm.write("    Retrying after 5 seconds...")
#                 time.sleep(5)
#             else:
#                 tqdm.write(f"❌ ERROR: Max retries reached. Failed to generate explanation for this item.")
#                 error_count += 1
#
#     if explanation:
#         final_output_string = f"[Reasoning Process]:\n{explanation.strip()}\n\n[Final Answer]:\n{item['output']}"
#
#         # --- 关键修改 2: 在这里添加实时打印 ---
#         tqdm.write("\n" + "="*80)
#         tqdm.write(f"--- [Item {i+1}/{len(dataset1)}] ---")
#         tqdm.write(f"Input:    {item['input'][:150]}...") # 打印Input（截断）
#         tqdm.write(f"Target:   {item['output']}")
#         tqdm.write(f"--- Generated Explanation (Real-time) ---")
#         tqdm.write(explanation.strip()) # 打印RAG模型的实时输出
#         tqdm.write("="*80 + "\n")
#         # --- 实时打印结束 ---
#
#         new_sample = {
#             "instruction": item.get('instruction', "Find the techniques and ID from MITRE ATT&CK framework."),
#             "input": item['input'],
#             "output": final_output_string
#         }
#         new_dataset_with_explanation.append(new_sample)

In [16]:
# --- 辅助函数：创建“小批次” ---
def create_mini_batches(list_data, batch_size):
    for i in range(0, len(list_data), batch_size):
        yield list_data[i:i + batch_size]

# --- 设置并发数 ---
# (我们现在可以安全地使用高并发，因为它只调用 Qwen API)
MAX_CONCURRENCY = 30  # (从 10 增加到 30，以接近 22 分钟的总时间)
MINI_BATCH_SIZE = 30  # (TQDM 进度条每 30 个条目更新一次)

new_dataset_with_explanation = []
error_count = 0

# 为 LLM 输入和原始 item 创建小批次
llm_input_batches = list(create_mini_batches(inputs_for_llm_batch, MINI_BATCH_SIZE))
item_batches = list(create_mini_batches(valid_items_for_zip, MINI_BATCH_SIZE))

print(f"    Total items: {len(inputs_for_llm_batch)}")
print(f"    Mini-batch size (per progress bar update): {MINI_BATCH_SIZE}")
print(f"    Number of mini-batches (total progress bar steps): {len(llm_input_batches)}")
print(f"    Internal concurrency (API calls): {MAX_CONCURRENCY}")
print(f"    Estimated time per step: ~10-15 seconds.")
print(f"    Estimated total time: ~{ (len(llm_input_batches) * 12) / 60 :.0f} minutes.")

# --- 我们在自己的循环中调用 .batch()，并用 tqdm 包装它 ---
try:
    for i in tqdm(range(len(llm_input_batches)), desc="Phase B: Generating Explanations"):
        llm_input_batch = llm_input_batches[i] # 这是包含 {input_cti, ...} 的字典列表
        original_item_batch = item_batches[i]

        try:
            # 在“小批次”上调用 llm_chain.batch()
            # (这不再调用 ChromaDB，只调用 Qwen)
            results = llm_chain.batch(
                llm_input_batch, # (注意：我们现在传递的是 llm_input_batch)
                config={"max_concurrency": MAX_CONCURRENCY}
            )

            # 处理这个小批次的结果
            for item, explanation in zip(original_item_batch, results):
                if isinstance(explanation, Exception):
                    tqdm.write(f"\n⚠️ WARNING (in batch {i}): Error processing '{item['input'][:50]}...': {explanation}\n")
                    error_count += 1
                else:
                    final_output_string = f"[Reasoning Process]:\n{explanation.strip()}\n\n[Final Answer]:\n{item['output']}"
                    new_sample = {
                        "instruction": item.get('instruction', "Find the techniques and ID from MITRE ATT&CK framework."),
                        "input": item['input'],
                        "output": final_output_string
                    }
                    new_dataset_with_explanation.append(new_sample)

        except Exception as batch_error:
            tqdm.write(f"\n❌ FATAL ERROR during mini-batch {i} (Qwen API Error?): {batch_error}")
            tqdm.write("    Skipping this entire batch...")
            error_count += len(llm_input_batch)

except KeyboardInterrupt:
    print("\n\n--- User interrupted. Stopping batch processing... ---")
    pass
except Exception as e:
    print(f"\n❌ ERROR: An unexpected fatal error occurred: {e}")
    exit(1)

    Total items: 4000
    Mini-batch size (per progress bar update): 30
    Number of mini-batches (total progress bar steps): 134
    Internal concurrency (API calls): 30
    Estimated time per step: ~10-15 seconds.
    Estimated total time: ~27 minutes.


Phase B: Generating Explanations: 100%|██████████| 134/134 [39:13<00:00, 17.56s/it]


In [17]:
print(f"\n--- [Step 7] Generation Complete ---")
print(f"    Successfully generated explanations: {len(new_dataset_with_explanation)} records")
print(f"    Failed items: {error_count} records")

print(f"\n--- [Step 8] Saving new dataset to '{OUTPUT_DATASET_PATH}'... ---")
try:
    with open(OUTPUT_DATASET_PATH, 'w', encoding='utf-8') as f:
        json.dump(new_dataset_with_explanation, f, indent=2, ensure_ascii=False)
    print(f"✅ New dataset saved successfully!")
    if new_dataset_with_explanation:
      print("\n--- Example of new dataset (first item) ---")
      print(json.dumps(new_dataset_with_explanation[0], indent=2, ensure_ascii=False))
except Exception as e:
    print(f"❌ ERROR: Failed to save the new dataset: {e}")

print("\n--- Script Finished ---")


--- [Step 7] Generation Complete ---
    Successfully generated explanations: 4000 records
    Failed items: 0 records

--- [Step 8] Saving new dataset to './data/output/sft_dataset_with_qwen_compat_explanations.json'... ---
✅ New dataset saved successfully!

--- Example of new dataset (first item) ---
{
  "instruction": "Find the techniques and ID from MITRE ATT&CK framework.",
  "input": "TrickBot has used macros in Excel documents to download and deploy the malware on the user’s machine.",
  "output": "[Reasoning Process]:\n[Reasoning Process]: The CTI input states that \"TrickBot has used macros in Excel documents to download and deploy the malware on the user’s machine.\" The key action here is the use of **macros**—specifically within an **Excel document**—to execute malicious behavior, namely downloading and deploying malware. Macros in Microsoft Office applications, including Excel, are implemented using **Visual Basic for Applications (VBA)**, which is a scripting language th