# 1. Build the IP pool (ip_pool)

In [None]:
import os, json, ipaddress
from collections import defaultdict
import maxminddb


def find_root(start=None):
    """Return the nearest ancestor of ``start`` (default: CWD) holding both ``data/`` and ``DRFA/``.

    The search walks upward one directory at a time. If no ancestor qualifies,
    the absolute form of the starting directory is returned.
    """
    origin = os.path.abspath(start or os.getcwd())
    candidate = origin
    while not all(os.path.isdir(os.path.join(candidate, name)) for name in ("data", "DRFA")):
        parent = os.path.dirname(candidate)
        if parent == candidate:  # reached the filesystem root without a match
            return origin
        candidate = parent
    return candidate


def is_ip(x: str) -> bool:
    """Return True if ``x`` parses as an IPv4 or IPv6 address."""
    try:
        ipaddress.ip_address(x)
    except (ValueError, TypeError):
        return False
    return True


def ips_from_record(rec: dict) -> set[str]:
    """Collect the IP addresses referenced by one DNS record.

    Two sources are read:
    - ``rec["data"]["resolver"]``, accepted as "ip", "ip:port", a bare IPv6
      address, or "[ipv6]:port";
    - the ``answer`` of every A/AAAA entry in ``rec["data"]["answers"]``.

    Values that are not valid IPs are ignored. Malformed containers are also
    ignored instead of raising: a non-dict record or ``data``, a non-list
    ``answers``, or a non-string ``answer``.
    """
    d = rec.get("data") if isinstance(rec, dict) else None
    if not isinstance(d, dict):
        return set()
    ips = set()

    r = d.get("resolver")
    if isinstance(r, str) and r:
        host = r.strip()
        if not is_ip(host):
            if host.startswith("["):  # bracketed IPv6 with port: "[2001:db8::1]:53"
                host = host[1:].partition("]")[0]
            else:  # "ip:port" -- split on the LAST colon so IPv6 colons survive
                host = host.rpartition(":")[0]
            host = host.strip()
        if is_ip(host):
            ips.add(host)

    answers = d.get("answers")
    for a in answers if isinstance(answers, list) else []:
        if not isinstance(a, dict):
            continue
        if str(a.get("type") or "").upper() not in ("A", "AAAA"):
            continue
        ip = a.get("answer")
        if isinstance(ip, str) and is_ip(ip.strip()):
            ips.add(ip.strip())
    return ips


def _geo_lookup(asn_r, city_r, ctry_r, ip):
    """Look up ``ip`` in the opened GeoLite2 readers.

    Returns ``(asn, country, org)`` as strings; any missing value becomes
    "Unknown". The country ISO code is taken from the City database first and
    falls back to the Country database when the City record has none.
    """
    ar = asn_r.get(ip) or {}
    asn = str(ar.get("autonomous_system_number") or "Unknown")
    org = str(ar.get("autonomous_system_organization") or "Unknown")

    country = ((city_r.get(ip) or {}).get("country") or {}).get("iso_code")
    if not country:
        country = ((ctry_r.get(ip) or {}).get("country") or {}).get("iso_code")
    return asn, str(country or "Unknown"), org


def _sets_to_lists(x):
    """Recursively turn nested dicts of sets into dicts of sorted lists (JSON-serializable)."""
    if isinstance(x, set):
        return sorted(x)
    if isinstance(x, dict):
        return {k: _sets_to_lists(v) for k, v in x.items()}
    return x


def build_ip_pool():
    """Build ``DRFA/ip_pool.json``, which indexes every IP seen in the sample dataset.

    Steps:
    1. Locate the project root (the directory holding ``data/`` and ``DRFA/``).
    2. Read ``data/datasets/dataset_sample.jsonl``.
    3. Extract the resolver and A/AAAA answer IPs of each record.
    4. Annotate each unique IP with its ASN number, ASN organization and
       country from the GeoLite2 databases under ``data/``.

    The written JSON contains:
    - ``meta``: source path, non-empty lines read, unique IPs, skipped lines
      (malformed JSON or non-object records)
    - ``by_asn_country_org``: asn -> country -> org -> sorted list of IPs
    - ``by_asn`` / ``by_country`` / ``by_org``: flat index -> sorted list of IPs

    Raises:
        FileNotFoundError: if any GeoLite2 database or the input file is missing.
    """
    root = find_root()
    data = os.path.join(root, "data")
    drfa = os.path.join(root, "DRFA")

    asn_db = os.path.join(data, "GeoLite2-ASN_20250702", "GeoLite2-ASN.mmdb")
    city_db = os.path.join(data, "GeoLite2-City_20250702", "GeoLite2-City.mmdb")
    ctry_db = os.path.join(data, "GeoLite2-Country_20250702", "GeoLite2-Country.mmdb")
    in_path = os.path.join(data, "datasets", "dataset_sample.jsonl")
    out_path = os.path.join(drfa, "ip_pool.json")

    for p in (asn_db, city_db, ctry_db, in_path):
        if not os.path.exists(p):
            raise FileNotFoundError(f"Missing: {p}")

    by_aco = defaultdict(lambda: defaultdict(lambda: defaultdict(set)))  # asn->country->org->ips
    by_asn, by_ctry, by_org = defaultdict(set), defaultdict(set), defaultdict(set)
    seen, n_lines, n_skipped = set(), 0, 0

    with maxminddb.open_database(asn_db) as asn_r, \
         maxminddb.open_database(city_db) as city_r, \
         maxminddb.open_database(ctry_db) as ctry_r, \
         open(in_path, "r", encoding="utf-8") as fin:

        for line in fin:
            line = line.strip()
            if not line:
                continue
            n_lines += 1
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                n_skipped += 1  # malformed line: skip, but keep count in meta
                continue
            if not isinstance(rec, dict):
                n_skipped += 1  # valid JSON, but not a record object
                continue

            for ip in ips_from_record(rec):
                if ip in seen:
                    continue
                seen.add(ip)

                asn, country, org = _geo_lookup(asn_r, city_r, ctry_r, ip)
                by_aco[asn][country][org].add(ip)
                by_asn[asn].add(ip)
                by_ctry[country].add(ip)
                by_org[org].add(ip)

    out = {
        "meta": {
            "source": in_path,
            "n_lines_read": n_lines,
            "n_unique_ips": len(seen),
            "n_skipped_lines": n_skipped,
        },
        "by_asn_country_org": _sets_to_lists(by_aco),
        "by_asn": _sets_to_lists(by_asn),
        "by_country": _sets_to_lists(by_ctry),
        "by_org": _sets_to_lists(by_org),
    }

    os.makedirs(drfa, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False)

    print(f"[OK] {out_path}  lines={n_lines}  unique_ips={len(seen)}  skipped={n_skipped}")


build_ip_pool()


[OK] /home/dc/jhr/dns-hijack/llm/artifact/DRFA/ip_pool.json  lines=5000  unique_ips=4914


# 2. Build the data-augmentation prompt

In [None]:
import json, random, ipaddress

# Perturbation types the LLM may be asked to apply. applicable_fields() filters
# this list per record, and build_drfa_prompt() draws its tasks from the
# filtered result.
FIELDS = [
    "shuffle_answers",    # reorder answers; kept only when the record has >= 2 answers
    "ttl_jitter",         # kept only when some answer dict carries a "ttl" key
    "resolver_replace",   # kept only when data["resolver"] is a non-empty string
    "answer_ip_replace",  # kept only when the record has at least one answer
    "timestamp_jitter",   # always applicable
    "port_jitter",        # always applicable
]

def is_ip(x: str) -> bool:
    """Return True if ``x`` parses as an IPv4 or IPv6 address, else False."""
    try:
        ipaddress.ip_address(x)
    except (ValueError, TypeError):  # only "not an address" errors mean False
        return False
    return True

def build_ip2key(ip_pool: dict):
    """Invert ``ip_pool["by_asn_country_org"]`` into a map ``ip -> (asn, country, org)``.

    ``None`` entries at any nesting level are treated as empty. If an IP
    appears under several keys, the last one encountered wins.
    """
    return {
        ip: (asn, country, org)
        for asn, countries in ip_pool.get("by_asn_country_org", {}).items()
        for country, orgs in (countries or {}).items()
        for org, members in (orgs or {}).items()
        for ip in (members or [])
    }

def pick_candidates(ip_pool: dict, asn: str, country: str, org: str, original_ip: str, k=10, rng=None):
    """Return up to ``k`` shuffled replacement IPs considered equivalent to ``original_ip``.

    Candidates come from the first tier below that is non-empty AFTER
    ``original_ip`` itself has been removed:

    1. the same (asn, country, org) bucket of ``ip_pool["by_asn_country_org"]``
    2. every org under the same (asn, country)
    3. ``ip_pool["by_asn"][asn]``
    4. ``ip_pool["by_country"][country]``

    Removing ``original_ip`` before the emptiness check matters. Otherwise an
    IP that is alone in its bucket gets no candidates at all, instead of
    falling back to a wider tier.

    Args:
        ip_pool: the structure written by build_ip_pool().
        asn, country, org: the equivalence key of ``original_ip``.
        original_ip: the IP being replaced; never returned.
        k: maximum number of candidates to return.
        rng: optional ``random.Random`` for reproducible shuffling. When None,
            the module-level ``random`` state is used.

    Returns:
        A list of at most ``k`` IP strings (possibly empty).
    """
    aco = ip_pool.get("by_asn_country_org") or {}
    org_map = (aco.get(asn) or {}).get(country) or {}
    if not isinstance(org_map, dict):
        org_map = {}

    def _others(ips):
        # Drop the IP being replaced; tolerate None buckets.
        return [x for x in (ips or []) if x != original_ip]

    cands = _others(org_map.get(org))
    if not cands:
        cands = _others([ip for ips in org_map.values() for ip in (ips or [])])
    if not cands:
        cands = _others((ip_pool.get("by_asn") or {}).get(asn))
    if not cands:
        cands = _others((ip_pool.get("by_country") or {}).get(country))

    (rng or random).shuffle(cands)
    return cands[:k]

def applicable_fields(record):
    """Return the entries of FIELDS that can meaningfully be applied to ``record``.

    Feasibility rules:
    - ``answer_ip_replace`` needs at least one answer (``answers`` must be a list);
    - ``shuffle_answers`` needs at least two answers;
    - ``ttl_jitter`` needs at least one dict answer carrying a "ttl" key;
    - ``resolver_replace`` needs ``data["resolver"]`` to be a non-empty string.

    Every other field is always kept. The result preserves the FIELDS order.
    """
    payload = record.get("data") or {}
    answers = payload.get("answers") or []
    n_answers = len(answers) if isinstance(answers, list) else 0
    resolver = payload.get("resolver")

    feasible = {
        "answer_ip_replace": n_answers > 0,
        "shuffle_answers": n_answers >= 2,
        "ttl_jitter": n_answers > 0 and any(isinstance(a, dict) and "ttl" in a for a in answers),
        "resolver_replace": isinstance(resolver, str) and bool(resolver),
    }
    return [name for name in FIELDS if feasible.get(name, True)]

def _split_resolver(resolver):
    """Split a resolver string into ``(ip, port)``; the port defaults to "53".

    Accepted forms are "ip", "ip:port", a bare IPv6 address ("2001:db8::1"),
    and a bracketed IPv6 address with port ("[2001:db8::1]:53"). Non-string
    input yields ``("", "53")``.
    """
    if not isinstance(resolver, str):
        return "", "53"
    text = resolver.strip()
    if is_ip(text):  # bare address without a port (also covers unbracketed IPv6)
        return text, "53"
    if text.startswith("["):
        host, _, tail = text[1:].partition("]")
        port = tail[1:] if tail.startswith(":") else ""
        return host.strip(), port.strip() or "53"
    host, sep, port = text.rpartition(":")  # last colon separates the port
    if not sep:
        return text, "53"
    return host.strip(), port.strip() or "53"


def build_drfa_prompt(record: dict, ip_pool: dict, ip2key: dict, augmentation_n=3, extra_candidates=5):
    """Build a DRFA perturbation-guided prompt (a strict JSON string) for one DNS record.

    The prompt asks the LLM to do three things:
    - generate ``augmentation_n + extra_candidates`` candidate augmentations,
      one per randomly chosen applicable perturbation task;
    - score each candidate for semantic consistency (integer 1-100);
    - return only the top ``augmentation_n`` candidates, sorted by score.

    Notes:
    - MaxMind is not read here. IP equivalence comes from ``ip_pool`` and
      ``ip2key``.
    - ``resolver_replace`` and ``answer_ip_replace`` tasks get up to 10
      replacement IPs per original IP via pick_candidates(). An IP missing
      from ``ip2key`` gets no candidates, so the answer candidate map may be
      empty.
    - The record's "label" key is removed so it is not leaked to the LLM.

    Args:
        record: one dataset record, with a "data" dict that may carry
            "resolver" and "answers".
        ip_pool: the structure written by build_ip_pool().
        ip2key: mapping ``ip -> (asn, country, org)`` from build_ip2key().
        augmentation_n: number of augmentations the LLM must return.
        extra_candidates: additional candidates generated beyond
            ``augmentation_n`` so that the LLM has something to select from.

    Returns:
        The prompt serialized with ``json.dumps(..., ensure_ascii=False)``.
    """
    augmentation_n_prompt = augmentation_n + extra_candidates
    rec = dict(record)  # shallow copy so the caller's record keeps its label
    rec.pop("label", None)  # avoid leaking label

    d = rec.get("data") or {}
    resolver_ip, resolver_port = _split_resolver(d.get("resolver") or "")

    answers = d.get("answers") or []
    answer_ips = []
    for a in answers:
        if isinstance(a, dict) and (a.get("type") or "").upper() in ("A", "AAAA"):
            ip = (a.get("answer") or "").strip()
            if is_ip(ip):
                answer_ips.append(ip)

    avail = applicable_fields(rec)
    used = set()
    tasks = []

    for i in range(augmentation_n_prompt):
        # Prefer fields not used yet; once all are used, allow repeats.
        cand = [f for f in avail if f not in used] or avail
        f = random.choice(cand)
        used.add(f)

        task = {"id": i + 1, "field": f}

        if f == "resolver_replace" and is_ip(resolver_ip) and resolver_ip in ip2key:
            asn, country, org = ip2key[resolver_ip]
            task["resolver_port_keep"] = resolver_port
            task["resolver_candidates"] = pick_candidates(ip_pool, asn, country, org, resolver_ip, k=10)

        if f == "answer_ip_replace" and answer_ips:
            mp = {}
            for ip in sorted(set(answer_ips)):
                if ip in ip2key:
                    asn, country, org = ip2key[ip]
                    mp[ip] = pick_candidates(ip_pool, asn, country, org, ip, k=10)
            task["answer_candidates"] = mp  # may be empty: the pool has no mapping for these IPs

        tasks.append(task)

    prompt = {
        "task": (
            "Generate multiple candidate DNS record augmentations using field-level perturbations. "
            "Then self-evaluate each candidate for semantic consistency with the original record, "
            "and SELECT the top augmentations with the highest consistency scores."
        ),

        "augmentation_policy": {
            "answer_ip_replace": {
                "constraint": "Select ONLY from provided equivalent IP candidates (same ASN, country, organization).",
                "semantic_requirement": "Resolution semantics must remain unchanged."
            },
            "resolver_replace": {
                "constraint": "Select ONLY from provided resolver IP candidates (same country and organization).",
                "semantic_requirement": "Resolver behavior must remain semantically equivalent."
            },
            "ttl_jitter": {
                "constraint": "Apply bounded proportional perturbation.",
                "allowed_range": "TTL must remain > 0 and within a reasonable operational range.",
                "semantic_requirement": "Must not alter caching semantics."
            },
            "shuffle_answers": {
                "constraint": "Only reorder answer entries.",
                "semantic_requirement": "Answer set must remain identical."
            },
            "port_jitter": {
                "constraint": "Select from common DNS ports.",
                "allowed_ports": [53, 5353, 853],
                "semantic_requirement": "Protocol semantics must remain unchanged."
            },
            "timestamp_jitter": {
                "constraint": "Apply small local temporal shift.",
                "semantic_requirement": "Must not affect long-term resolution behavior."
            }
        },

        "rules": [
            "Return STRICT JSON only.",
            f"First, generate {augmentation_n_prompt} candidate augmented records.",
            "Each candidate must apply ONLY the specified perturbation task.",
            "Do NOT modify qname or domain identity.",
            "If candidate sets are provided, replacements must be chosen exclusively from them.",
            "Assign each candidate a semantic consistency score (integer 1–100).",
            f"Then, SELECT and RETURN ONLY the top {augmentation_n} candidates with the highest scores.",
            "Returned candidates must be sorted by descending consistency score."
        ],

        "original_record": rec,

        "augmentation_tasks": tasks,

        "output_schema": {
            "selected_augmented_records": [
                {
                    "record": "<augmented_dns_record_json>",
                    "consistency_score": "integer in [1, 100]",
                    "applied_task_id": "integer"
                }
            ]
        }
    }

    return json.dumps(prompt, ensure_ascii=False)

# 3. Generate augmented records with the LLM

In [None]:
import os, json
import requests

# ========= project paths =========
def find_root(start=None):
    """Locate the project root: the closest directory at or above ``start`` containing ``data/`` and ``DRFA/``.

    ``start`` defaults to the current working directory. When no ancestor
    contains both subdirectories, the absolute starting path is returned.
    """
    base = os.path.abspath(start or os.getcwd())

    def _ancestors(path):
        # Yield path, its parent, ... up to and including the filesystem root.
        while True:
            yield path
            parent = os.path.dirname(path)
            if parent == path:
                return
            path = parent

    for cand in _ancestors(base):
        if os.path.isdir(os.path.join(cand, "data")) and os.path.isdir(os.path.join(cand, "DRFA")):
            return cand
    return base

root = find_root()
data_dir = os.path.join(root, "data")
drfa_dir = os.path.join(root, "DRFA")

# Read the same dataset that section 1 used to build ip_pool.json
# (data/datasets/dataset_sample.jsonl), so every record IP can be found in the
# pool. Fall back to the legacy location data/dataset_sample.jsonl only when
# the former does not exist.
IN_PATH = os.path.join(data_dir, "datasets", "dataset_sample.jsonl")
if not os.path.exists(IN_PATH):
    IN_PATH = os.path.join(data_dir, "dataset_sample.jsonl")
IP_POOL_PATH = os.path.join(drfa_dir, "ip_pool.json")
OUT_PATH = os.path.join(drfa_dir, "drfa_llm_outputs.jsonl")

# ========= API Configuration =========
# Prefer the DEEPSEEK_API_KEY environment variable so the secret is not stored
# in the notebook; the placeholder is kept as the fallback.
API_KEY = os.environ.get("DEEPSEEK_API_KEY", "your API key here")
BASE_URL = "https://api.deepseek.com"
MODEL = "deepseek-chat"

# ========= load ip_pool =========
with open(IP_POOL_PATH, "r", encoding="utf-8") as f:
    ip_pool = json.load(f)

def build_ip2key(ip_pool: dict):
    """Map each IP in ``ip_pool["by_asn_country_org"]`` to its ``(asn, country, org)`` key.

    Missing or ``None`` sub-maps are skipped. For an IP listed under several
    keys, the last key visited is kept.
    """
    def _walk(tree):
        for asn, countries in tree.items():
            for country, orgs in (countries or {}).items():
                for org, members in (orgs or {}).items():
                    key = (asn, country, org)
                    for ip in members or []:
                        yield ip, key

    return dict(_walk(ip_pool.get("by_asn_country_org", {})))

ip2key = build_ip2key(ip_pool=ip_pool)  # reverse index: IP -> (asn, country, org)

# ========= call API =========
def call_llm_api(prompt: str, *, timeout: float = 120.0, debug: bool = True):
    """Send ``prompt`` to the chat-completions endpoint and return the reply text.

    Args:
        prompt: user-message content (e.g. the JSON produced by build_drfa_prompt).
        timeout: seconds passed to ``requests.post``. Without a timeout, a
            stalled connection blocks the notebook indefinitely.
        debug: when True (the default, matching earlier behavior), print the
            request payload and the raw response.

    Returns:
        The first choice's message content, stripped. Returns "" when the
        response has no choices or the content is missing/null.

    Raises:
        RuntimeError: on a non-200 HTTP status (the message includes the status
            code and body).
        requests.exceptions.RequestException: on transport errors or timeout.
    """
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": MODEL,
        "messages": [
            {"role": "system", "content": "You are a DNS data augmentation assistant. Return STRICT JSON only."},
            {"role": "user", "content": prompt}
        ]
    }
    if debug:
        print("[DEBUG] API Request Payload:", json.dumps(payload, ensure_ascii=False, indent=2))
    response = requests.post(
        f"{BASE_URL}/v1/chat/completions", headers=headers, json=payload, timeout=timeout
    )
    if debug:
        print("[DEBUG] API Response Status Code:", response.status_code)
        print("[DEBUG] API Response Body:", response.text)
    if response.status_code != 200:
        raise RuntimeError(f"API call failed with status code {response.status_code}: {response.text}")

    result = response.json()
    # Guard each level: "choices" may be empty and "content" may be null.
    choices = result.get("choices") or [{}]
    message = (choices[0] or {}).get("message") or {}
    content = message.get("content") or ""
    return content.strip()


# ========= main loop =========
# Only the first MAX_LINES physical lines of the input are processed (blank
# lines count toward the limit), which bounds API cost during experiments.
MAX_LINES = 10

os.makedirs(drfa_dir, exist_ok=True)

n_written = n_failed = 0
with open(IN_PATH, "r", encoding="utf-8") as fin, \
     open(OUT_PATH, "w", encoding="utf-8") as fout:

    for line_id, line in enumerate(fin):
        if line_id >= MAX_LINES:
            break

        line = line.strip()
        if not line:
            continue

        try:
            record = json.loads(line)
        except json.JSONDecodeError as err:
            print(f"[WARN] line {line_id}: invalid JSON ({err}); skipped")
            continue
        if not isinstance(record, dict):
            print(f"[WARN] line {line_id}: not a JSON object; skipped")
            continue

        prompt = build_drfa_prompt(
            record=record,
            ip_pool=ip_pool,
            ip2key=ip2key,
            augmentation_n=3
        )

        out = {"line_id": line_id, "record": record, "prompt": prompt}
        try:
            out["llm_output"] = call_llm_api(prompt)
        except Exception as err:  # batch boundary: log, record, and continue with the next record
            print(f"[WARN] line {line_id}: LLM call failed: {err}")
            out["llm_output"] = None
            out["error"] = str(err)
            n_failed += 1

        fout.write(json.dumps(out, ensure_ascii=False) + "\n")
        fout.flush()  # keep completed results on disk even if the kernel dies mid-run
        n_written += 1

print(f"[OK] DRFA LLM outputs saved to {OUT_PATH}  written={n_written}  failed={n_failed}")

[DEBUG] API Request Payload: {
  "model": "deepseek-chat",
  "messages": [
    {
      "role": "system",
      "content": "You are a DNS data augmentation assistant. Return STRICT JSON only."
    },
    {
      "role": "user",
      "content": "{\"task\": \"Generate multiple candidate DNS record augmentations using field-level perturbations. Then self-evaluate each candidate for semantic consistency with the original record, and SELECT the top augmentations with the highest consistency scores.\", \"augmentation_policy\": {\"answer_ip_replace\": {\"constraint\": \"Select ONLY from provided equivalent IP candidates (same ASN, country, organization).\", \"semantic_requirement\": \"Resolution semantics must remain unchanged.\"}, \"resolver_replace\": {\"constraint\": \"Select ONLY from provided resolver IP candidates (same country and organization).\", \"semantic_requirement\": \"Resolver behavior must remain semantically equivalent.\"}, \"ttl_jitter\": {\"constraint\": \"Apply bounded pro

KeyboardInterrupt: 