# 1. Answer Generation

In [None]:
import json
import os
from collections import defaultdict

# ====== config ======
# Input JSONL: main() reads it line by line, one JSON sample per non-empty line.
IN_PATH = "../outputs/dataset_sample_with_paths.jsonl"
# Output JSONL: main() writes one chat-style SFT example per input sample.
OUT_PATH = "../outputs/answer_gen.jsonl"

# Number of highest-scoring paths kept per "pair" value when build_prompt()
# selects entity-graph / resolution-graph evidence (a value <= 0 keeps all).
TOPK_ENTITY_PER_PAIR = 2
TOPK_RES_PER_PAIR = 2
# Upper bound on the length (in characters) of the user prompt.
MAX_PROMPT_CHARS = 6000

# Content of the "system" message attached to every generated example.
SYSTEM = (
    "You are a DNS integrity anomaly detector. "
    "Use the DNS record and evidence paths. Return ONLY valid JSON."
)

def clip(s, n):
    """Return *s* limited to at most *n* characters.

    A string longer than *n* is cut and ends with "..." so the truncation is
    visible. When *n* is smaller than 3 there is no room for the ellipsis, so
    the string is simply cut to *n* characters (empty for n <= 0).
    """
    if len(s) <= n:
        return s
    if n < 3:
        # s[:n - 3] would use a negative index here and return a string
        # LONGER than n; fall back to a plain hard cut instead.
        return s[:max(n, 0)]
    return s[:n - 3] + "..."

def pick_topk_by_pair(paths, topk):
    """Keep the *topk* best-scoring paths of every pair, ordered by score.

    Paths are grouped by ``str(path.get("pair", ""))``. Within each group the
    *topk* entries with the highest ``score`` (missing score counts as 0.0)
    are retained; a non-positive *topk* retains the whole group. The combined
    selection is returned sorted by ascending score, so the strongest path
    comes last. ``None`` or an empty *paths* yields an empty list.
    """
    def _score(item):
        return float(item.get("score", 0.0))

    # Plain dict keeps first-seen order of the pairs.
    buckets = {}
    for entry in paths or []:
        buckets.setdefault(str(entry.get("pair", "")), []).append(entry)

    selected = []
    for members in buckets.values():
        ranked = sorted(members, key=_score)  # ascending, stable on ties
        if topk > 0:
            ranked = ranked[-topk:]  # tail of an ascending list = best topk
        selected += ranked

    # Final ascending sort across all pairs (best near the end).
    selected.sort(key=_score)
    return selected

def build_assistant_answer(label_norm, evidence_paths, *, confidence=0.75):
    """Build the assistant answer as a JSON string.

    Args:
        label_norm: classification label, "normal" or "anomalous".
        evidence_paths: iterable of evidence path strings; ``None`` is
            treated as "no evidence".
        confidence: value stored in the "confidence" field. Defaults to the
            fixed 0.75 used for every generated answer.

    Returns:
        A JSON object string with the keys "label", "confidence",
        "rationale" and "evidence". The rationale is a fixed sentence
        followed by the evidence paths, each prefixed with "- " and joined
        by single spaces.
    """
    # Materialise once: the evidence is used both for the rationale text and
    # for the "evidence" field, so a one-shot iterator would otherwise be
    # exhausted by the first use (and None would fail to iterate).
    evidence = list(evidence_paths or [])

    rationale = (
        f"The DNS record is classified as {label_norm}. "
        "Based on the evidence paths retrieved, the anomaly or normality is determined as follows: "
        + " ".join(f"- {path}" for path in evidence)
    )

    obj = {
        "label": label_norm,
        "confidence": confidence,
        "rationale": rationale,
        "evidence": evidence,
    }
    return json.dumps(obj, ensure_ascii=False)

def _render_evidence_section(header, items, start_idx):
    """Render one evidence section; return (lines, next free evidence index)."""
    lines = [header]
    if not items:
        lines.append("(none)")
        return lines, start_idx
    idx = start_idx
    for p in items:
        lines.append(f"[E{idx:02d}] score={float(p.get('score',0.0)):.6f} pair={p.get('pair','')}")
        lines.append(f"  {p.get('path_str','')}")
        idx += 1
    return lines, idx


def build_prompt(sample):
    """Build the user prompt for one sample.

    The prompt consists of the task description, the raw DNS record (with its
    "label" key removed so the ground truth does not leak into the prompt),
    the selected entity-graph and resolution-graph evidence paths numbered
    E01, E02, ... across both sections, and the required output JSON schema.

    Args:
        sample: dict with a "record" dict and an optional "paths" dict holding
            "entity" and "resolution" lists of path dicts.

    Returns:
        The prompt text, at most MAX_PROMPT_CHARS characters long. If the
        prompt would exceed that limit, the part before the output-format
        instruction is truncated (ending in "...") and the instruction itself
        is kept intact.

    Raises:
        KeyError: if *sample* has no "record" key.
    """
    record = sample["record"]
    paths = sample.get("paths", {}) or {}

    ent = pick_topk_by_pair(paths.get("entity", []), TOPK_ENTITY_PER_PAIR)
    res = pick_topk_by_pair(paths.get("resolution", []), TOPK_RES_PER_PAIR)

    parts = [
        "Task: Detect DNS integrity anomalies. Specifically, we focus on anomalies where the DNS resolution process is manipulated or hijacked. These anomalies may include situations where the DNS resolution returns incorrect or inaccessible responses due to malicious interventions such as cache poisoning, resolver manipulation, or unauthorized redirection.",
        "DNS Integrity Anomalies: These are DNS records where the resolution path has been altered without the domain owner's consent. Such anomalies can occur due to malicious activities like DNS hijacking, DNS manipulation, or network-level censorship. In this task, we aim to detect such anomalies by analyzing DNS records in the context of entity relationships and resolution behaviors.",
        "The two paths used in this model are as follows:",
        "1) **Entity Graph Path**: This path represents the relationships between DNS entities such as domains, resolvers, countries, and policies. It captures the structural associations among the static attributes of DNS entities and is used to understand long-term relationships between entities that might indicate DNS manipulation.",
        "2) **Resolution Graph Path**: This path captures the historical resolution behavior of DNS records, which includes the domains' resolution patterns over time. It tracks how a domain's resolution has evolved, including any abnormal shifts that may indicate issues like DNS hijacking or manipulation.",
    ]

    # Show the record without its label; build a new dict so the caller's
    # record is left untouched.
    record_copy = {k: v for k, v in record.items() if k != "label"}
    parts.append("\nDNS Record (raw JSON):")
    parts.append(json.dumps(record_copy, ensure_ascii=False))
    parts.append("")

    parts.append("Evidence Paths (sorted by increasing score; strongest near the end):")
    # Evidence ids run continuously across both sections.
    ent_lines, next_idx = _render_evidence_section("[EntityGraph]", ent, 1)
    res_lines, _ = _render_evidence_section("[ResolutionGraph]", res, next_idx)
    parts.extend(ent_lines)
    parts.extend(res_lines)

    body = "\n".join(parts)
    # Blank line + output schema; identical to joining "" and the schema
    # onto the other parts with "\n".
    tail = "\n\n" + (
        "Return ONLY this JSON:\n"
        "{\n"
        '  "label": "normal|anomalous",\n'
        '  "confidence": 0.0-1.0,\n'
        '  "rationale": "brief",\n'
        '  "evidence": ["E01","E03"]\n'
        "}"
    )

    budget = MAX_PROMPT_CHARS - len(tail)
    if budget < 3:
        # Limit too small to reserve room for the instruction: fall back to
        # clipping the whole prompt.
        return clip(body + tail, MAX_PROMPT_CHARS)
    # Clip only the body so an over-long evidence list can never cut off the
    # output-format instruction at the end of the prompt. When everything
    # fits, clip() returns the body unchanged.
    return clip(body, budget) + tail

def main():
    """Convert every sample in IN_PATH into one chat-style SFT example in OUT_PATH.

    Each non-empty input line is parsed as JSON. For every sample a
    system/user/assistant message triple is written as one JSON line,
    together with a "metadata" object (qname, timestamp, resolver, label).
    The output directory is created if needed, and a summary line with the
    number of written examples is printed at the end.
    """
    # Records whose "label" field carries one of these dataset names are
    # treated as anomalous; everything else is labelled normal.
    anomalous_sources = ("dataset1", "dataset2", "dataset3", "dataset4")

    os.makedirs(os.path.dirname(OUT_PATH) or ".", exist_ok=True)
    n = 0
    with open(IN_PATH, "r", encoding="utf-8") as fin, open(OUT_PATH, "w", encoding="utf-8") as fout:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            s = json.loads(line)

            # Tolerate explicit nulls the same way build_prompt() does for
            # "paths": a JSON null would otherwise raise AttributeError here.
            record = s.get("record") or {}
            paths = s.get("paths") or {}

            # Derive the normalised label from the dataset name stored in the record.
            label_norm = "anomalous" if record.get("label") in anomalous_sources else "normal"

            prompt = build_prompt(s)

            # Evidence paths cited by the assistant answer.
            # NOTE(review): this cites ALL paths as full strings, whereas the
            # prompt only shows the top-k paths per pair and asks for ids such
            # as "E01" -- confirm this mismatch is intended for the SFT target.
            evidence_paths = [
                f"Entity Graph: {path.get('path_str', '')}"
                for path in paths.get("entity") or []
            ]
            evidence_paths += [
                f"Resolution Graph: {path.get('path_str', '')}"
                for path in paths.get("resolution") or []
            ]

            # SFT example
            out = {
                "messages": [
                    {"role": "system", "content": SYSTEM},
                    {"role": "user", "content": prompt},
                    # Add the assistant answer.
                    {"role": "assistant", "content": build_assistant_answer(label_norm, evidence_paths)},
                ],
                "metadata": {
                    "qname": record.get("name", ""),
                    "timestamp": record.get("timestamp", ""),
                    "resolver": (record.get("data", {}) or {}).get("resolver", ""),
                    # Keep the label separately from the messages for SFT.
                    "label": label_norm,
                },
            }
            fout.write(json.dumps(out, ensure_ascii=False) + "\n")
            n += 1

    print(f"[OK] wrote {OUT_PATH}  n={n}")

# Run the generation only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()


[OK] wrote data/answer_gen.jsonl  n=5000


# 2. FINE-TUNING

We use LLaMA-Factory for fine-tuning and testing.

```
git clone https://github.com/hiyouga/LLaMA-Factory.git
cd LLaMA-Factory
```