# extract paths to jsonl

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Read the input JSONL (default: ../data/datasets/dataset_sample.jsonl) and, for each record:
  - keep the original record unchanged in field "record"
  - extract entity paths (entity graph RAG)
  - extract resolution paths (resolution graph RAG)
Write to one output JSONL file (one record -> one JSON line).

Dependencies:
  - entity_rag.py (init, get_paths)
  - resolution_rag.py (init, get_paths_from_record, close)
"""

from __future__ import annotations

import json
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Optional

import entity_rag
import resolution_rag


# Dotted-quad shape: four groups of 1-3 decimal digits.
#   - ``\Z`` (rather than ``$``) anchors at the true end of the string, so a
#     trailing newline is not silently tolerated.
#   - ``re.ASCII`` restricts ``\d`` to 0-9; without it, non-ASCII Unicode digits
#     match the pattern and are then accepted by ``int()`` as well.
_IP_RE = re.compile(r"^\d{1,3}(\.\d{1,3}){3}\Z", re.ASCII)


def is_ipv4(ip: str) -> bool:
    """Return True if *ip* is a dotted-quad IPv4 address.

    The string must consist of exactly four groups of one to three ASCII
    digits separated by dots, and every group must be in the range 0..255.
    Leading zeros (e.g. ``"01.2.3.4"``) are accepted.  Empty, ``None`` and
    non-string values yield False.
    """
    if not isinstance(ip, str) or not _IP_RE.match(ip):
        return False
    # The pattern guarantees every part is 1-3 ASCII digits, so int() cannot
    # fail and the value cannot be negative; only the upper bound needs checking.
    return all(int(octet) <= 255 for octet in ip.split("."))


def extract_day_yyyymmdd(ts: str) -> Optional[str]:
    """Return the calendar day found in *ts* as a ``YYYYMMDD`` string.

    The timestamp is first parsed with ``datetime.fromisoformat``.  If that
    rejects the string, the first ``YYYY-MM-DD`` looking substring is used
    instead; that fallback only checks the digit layout, not whether the date
    actually exists.

    Returns None when *ts* is empty/None or contains no recognisable date.
    """
    if not ts:
        return None
    try:
        return datetime.fromisoformat(ts).strftime("%Y%m%d")
    except ValueError:
        # Not an ISO string this interpreter's fromisoformat accepts; fall
        # through to the looser pattern search below.
        pass

    m = re.search(r"(\d{4})-(\d{2})-(\d{2})", ts)
    if m is None:
        return None
    year, month, day = m.groups()
    return f"{year}{month}{day}"


def parse_record_basic(record: Dict[str, Any]) -> Dict[str, Any]:
    """Extract the minimal fields needed for calling the RAG modules.

    Reads from *record*:
      - ``name``            -> ``domain`` (stripped, lower-cased)
      - ``timestamp``       -> ``day`` (``YYYYMMDD`` or None, via extract_day_yyyymmdd)
      - ``data.resolver``   -> ``resolver_ip`` (text before the first ``:``)
      - ``data.answers``    -> ``answer_ips`` (valid IPv4 values of type ``A``
                               answers, de-duplicated, first-seen order kept)

    A ``data`` value that is not a dict, an ``answers`` value that is not a
    list/tuple, and answer entries that are not dicts are ignored rather than
    raising, so one malformed record does not abort a whole run.

    Returns a dict with keys ``domain``, ``resolver_ip``, ``answer_ips``, ``day``.
    """
    domain = (record.get("name") or "").strip().lower()
    ts = (record.get("timestamp") or "").strip()
    day = extract_day_yyyymmdd(ts)

    data = record.get("data")
    if not isinstance(data, dict):
        data = {}

    resolver_field = (data.get("resolver") or "").strip()
    # Drop a trailing ":port".  partition() returns "" for an empty field.
    # NOTE(review): a bracketed IPv6 resolver ("[::1]:53") would be cut at its
    # first colon here — confirm resolvers are always IPv4.
    resolver_ip = resolver_field.partition(":")[0].strip()

    answers = data.get("answers")
    if not isinstance(answers, (list, tuple)):
        answers = []

    # A dict keeps insertion order, giving an order-preserving de-duplication.
    unique_ips: Dict[str, None] = {}
    for ans in answers:
        if not isinstance(ans, dict):
            continue
        if str(ans.get("type", "")).upper() != "A":
            continue
        ip = str(ans.get("answer", "")).strip()
        if is_ipv4(ip):
            unique_ips.setdefault(ip, None)

    return {
        "domain": domain,
        "resolver_ip": resolver_ip,
        "answer_ips": list(unique_ips),
        "day": day,
    }


def entity_path_to_hops(nodes: List[str], edges: List[str]) -> List[Dict[str, Any]]:
    """Pair consecutive *nodes* with the relation type connecting them.

    ``edges[i]`` is taken as the relation between ``nodes[i]`` and
    ``nodes[i + 1]``.  Each hop is ``{"src": ..., "rtype": ..., "dst": ...}``.

    If the two lists do not line up (``len(edges) != len(nodes) - 1``), the
    result is truncated to the hops for which both end nodes and an edge
    exist instead of raising ``IndexError``; this matches the truncation done
    in format_resolution_paths.
    """
    return [
        {"src": src, "rtype": rtype, "dst": dst}
        for src, rtype, dst in zip(nodes, edges, nodes[1:])
    ]


def build_path_str_from_hops(hops: List[Dict[str, Any]]) -> str:
    """Render *hops* as ``src -[rtype]-> dst -[rtype]-> dst ...``.

    The string starts with the ``src`` of the first hop and then appends one
    `` -[rtype]-> dst`` segment per hop.  An empty hop list yields ``""``.
    """
    if hops:
        start = hops[0]["src"]
        segments = "".join(f" -[{hop['rtype']}]-> {hop['dst']}" for hop in hops)
        return start + segments
    return ""


def format_entity_paths(entity_paths_raw: List[Dict[str, Any]], limit: int) -> List[Dict[str, Any]]:
    """
    Convert the first *limit* raw entity paths into the output layout.

    Input items (from entity_rag.get_paths):
      {pair, score, nodes, edges}
    Output items:
      {pair, score, path_str, hops}

    Missing ``nodes``/``edges`` are treated as empty and a missing ``score``
    as 0.0.
    """

    def _format_one(raw: Dict[str, Any]) -> Dict[str, Any]:
        # Build the hop list once; it feeds both "path_str" and "hops".
        hop_list = entity_path_to_hops(raw.get("nodes") or [], raw.get("edges") or [])
        return {
            "pair": raw.get("pair"),
            "score": float(raw.get("score", 0.0)),
            "path_str": build_path_str_from_hops(hop_list),
            "hops": hop_list,
        }

    return [_format_one(raw) for raw in entity_paths_raw[:limit]]


def format_resolution_paths(res_paths_raw: List[Dict[str, Any]], limit: int) -> List[Dict[str, Any]]:
    """
    Convert the first *limit* raw resolution paths into the output layout.

    Input items (from resolution_rag.get_paths_from_record):
      {pair, score, nodes, edge_weights}
    Output items:
      {pair, score, path_str, hops, edge_weights}

    Hops are built only while both a following node and a weight exist, so a
    length mismatch between ``nodes`` and ``edge_weights`` truncates the hop
    list.  ``edge_weights`` in the output always carries every raw weight.
    """
    formatted: List[Dict[str, Any]] = []
    for raw in res_paths_raw[:limit]:
        node_seq = raw.get("nodes") or []
        weights = raw.get("edge_weights") or []

        # The raw path carries no relation type, so every hop is labelled
        # with the fixed placeholder "temporal_edge".
        hop_list = [
            {"src": src, "rtype": "temporal_edge", "dst": dst, "w": float(weight)}
            for src, dst, weight in zip(node_seq, node_seq[1:], weights)
        ]

        # Path string includes each hop's weight, rounded to three decimals.
        if hop_list:
            path_str = hop_list[0]["src"] + "".join(
                f" -[{hop['rtype']}(w={hop['w']:.3f})]-> {hop['dst']}" for hop in hop_list
            )
        else:
            path_str = ""

        formatted.append(
            {
                "pair": raw.get("pair"),
                "score": float(raw.get("score", 0.0)),
                "path_str": path_str,
                "hops": hop_list,
                "edge_weights": list(map(float, weights)),
            }
        )
    return formatted


def main(
    in_jsonl: str = "../data/datasets/dataset_sample.jsonl",
    out_jsonl: str = "../data/datasets/dataset_sample_with_paths.jsonl",
    entity_graph_path: str = "../outputs/entity_graph.gpickle",
    resolution_graph_path: str = "../outputs/resolution_graph.gpickle",
    asn_mmdb: str = "../data/GeoLite2-ASN_20250702/GeoLite2-ASN.mmdb",
    entity_limit: int = 10,
    resolution_limit: int = 10,
    max_records: Optional[int] = None,
) -> None:
    """Annotate every record of *in_jsonl* with RAG paths and write *out_jsonl*.

    For each non-blank input line that parses as a JSON object, one output
    line is written:

        {"record": <original record>, "paths": {"entity": [...], "resolution": [...]}}

    Entity paths are collected per answer IP (only when the record has a
    domain, a resolver IP and at least one IPv4 answer), merged and sorted by
    descending score.  Resolution paths are requested only when a day could
    be extracted from the record's timestamp.  Both lists are truncated to
    *entity_limit* / *resolution_limit* entries.

    Lines that are not valid JSON, or whose JSON value is not an object, are
    counted as read but skipped.  *max_records* caps the number of non-blank
    lines read (None means no cap).

    Raises:
        FileNotFoundError: if *in_jsonl* does not exist.
    """
    if not os.path.exists(in_jsonl):
        raise FileNotFoundError(in_jsonl)

    # init two RAG modules once
    entity_rag.init(graph_path=entity_graph_path, max_expand_edges=5, decay_alpha=0.6)
    resolution_rag.init(graph_path=resolution_graph_path, asn_mmdb=asn_mmdb, max_expand_neighbors=5, lam=0.25)

    n_in = 0
    n_out = 0

    try:
        with open(in_jsonl, "r", encoding="utf-8") as rf, open(out_jsonl, "w", encoding="utf-8") as wf:
            for line in rf:
                if max_records is not None and n_in >= max_records:
                    break

                line = line.strip()
                if not line:
                    continue

                n_in += 1
                try:
                    record = json.loads(line)
                except ValueError:
                    # json.JSONDecodeError is a ValueError: skip unparseable lines.
                    continue
                if not isinstance(record, dict):
                    # Valid JSON but not an object; nothing to extract from it.
                    continue

                basic = parse_record_basic(record)
                domain = basic["domain"]
                resolver_ip = basic["resolver_ip"]
                answer_ips = basic["answer_ips"]
                day = basic["day"]

                # --- entity paths: run per answer_ip then merge ---
                entity_paths_raw: List[Dict[str, Any]] = []
                if domain and resolver_ip and answer_ips:
                    for aip in answer_ips:
                        # entity_rag.get_paths returns list of {pair,score,nodes,edges}
                        entity_paths_raw.extend(
                            entity_rag.get_paths(domain=domain, resolver_ip=resolver_ip, answer_ip=aip)
                        )
                    entity_paths_raw.sort(key=lambda x: float(x.get("score", 0.0)), reverse=True)

                # --- resolution paths: use record-level API (it will parse day + answers + lookup ASN) ---
                res_paths_raw: List[Dict[str, Any]] = []
                if day:
                    res_paths_raw = resolution_rag.get_paths_from_record(record)

                out_obj = {
                    "record": record,  # keep original intact
                    "paths": {
                        "entity": format_entity_paths(entity_paths_raw, entity_limit),
                        "resolution": format_resolution_paths(res_paths_raw, resolution_limit),
                    },
                }
                wf.write(json.dumps(out_obj, ensure_ascii=False) + "\n")
                n_out += 1
    finally:
        # Release optional resources even if a record raised part-way through.
        # Deliberately best-effort: a failing or missing close() must not mask
        # the original error or fail an otherwise successful run.
        try:
            resolution_rag.close()
        except Exception:
            pass

    print(f"Done. Read={n_in}, Written={n_out}, Output={out_jsonl}")


# Entry point: run the export with the default paths and limits declared on
# main() when this module is executed directly.
if __name__ == "__main__":
    main()


Done. Read=5000, Written=5000, Output=data/dataset_sample_with_paths.jsonl
