In [2]:
from __future__ import annotations

import math
import pickle
from dataclasses import dataclass
from datetime import datetime, date
from typing import Any, Dict, List, Optional, Tuple

import networkx as nx


# -----------------------------
# Utilities
# -----------------------------

def nodeD(domain: str) -> str:
    """Return the graph node id for a domain: ``"D:"`` + the stripped, lower-cased name."""
    normalized = domain.strip().lower()
    return "D:" + normalized

def nodeS(asn: int | str) -> str:
    """Return the graph node id for an AS number (``int`` or numeric string), e.g. ``"S:43298"``."""
    number = int(asn)
    return "S:" + str(number)

def parse_day_yyyymmdd(day: str) -> date:
    """Parse a compact ``YYYYMMDD`` day string (e.g. ``"20250904"``) into a ``date``.

    Raises:
        ValueError: if ``day`` does not match the ``%Y%m%d`` format.
    """
    parsed = datetime.strptime(day, "%Y%m%d")
    return parsed.date()

def clamp_nonneg(x: int) -> int:
    """Return ``x`` unchanged when it is >= 0, otherwise 0."""
    if x >= 0:
        return x
    return 0


# -----------------------------
# Edge temporal weight
# -----------------------------

def edge_temporal_weight(
    G: nx.DiGraph,
    u: str,
    v: str,
    query_day: str,
    lam: float = 0.25,
    include_future: bool = True,
) -> float:
    """
    Compute the temporal reliability weight of the edge between ``u`` and ``v``.

    The weight is the count-weighted mean of ``exp(-lam * age)`` over the edge's
    observation days, where ``age`` is the number of days from the observation
    to ``query_day``. For ``lam >= 0`` it lies in [0, 1].

    Edge attributes read:
      - day_counts: dict {YYYYMMDD: count}
      - days: set/frozenset/list/tuple of YYYYMMDD. This is used only when
        ``day_counts`` is missing or empty; each day then counts once.
    Entries whose day does not parse as YYYYMMDD, or whose count does not
    convert to a positive int, are skipped.

    The graph is directed D->S, but retrieval walks are undirected, so the edge
    is looked up as (u, v) first and then as (v, u).

    Args:
        G: Graph holding the edge attributes above.
        u, v: Endpoints of the edge (either direction).
        query_day: Reference day, ``YYYYMMDD``.
        lam: Exponential decay rate per day of age.
        include_future: When True (default), observations dated after
            ``query_day`` are clamped to age 0 and count at full weight. When
            False they are ignored, which avoids look-ahead if the graph was
            built from data extending past ``query_day``.

    Returns:
        The weight. Returns 0.0 if the edge is absent or has no usable
        observations.

    Raises:
        ValueError: if ``query_day`` is malformed and the edge has day data.
    """
    # Fetch edge data regardless of direction.
    if G.has_edge(u, v):
        data = G[u][v]
    elif G.has_edge(v, u):
        data = G[v][u]
    else:
        return 0.0

    day_counts = data.get("day_counts")
    if not isinstance(day_counts, dict) or not day_counts:
        # Fallback: only a collection of observation days exists; count each once.
        days = data.get("days")
        if isinstance(days, (set, frozenset, list, tuple)) and days:
            day_counts = {str(d): 1 for d in days}
        else:
            return 0.0

    qd = parse_day_yyyymmdd(query_day)

    num = 0.0
    den = 0.0
    for d, c in day_counts.items():
        try:
            dd = parse_day_yyyymmdd(str(d))
        except ValueError:
            continue  # malformed day key
        try:
            c = int(c)
        except (TypeError, ValueError, OverflowError):
            continue  # non-numeric, NaN or infinite count
        if c <= 0:
            continue

        age = (qd - dd).days
        if age < 0 and not include_future:
            continue  # observed after query_day: excluded on request
        w = math.exp(-lam * clamp_nonneg(age))
        num += c * w
        den += c

    return (num / den) if den > 0 else 0.0


# -----------------------------
# Beam search path retrieval
# -----------------------------

@dataclass
class RPath:
    """A candidate path: the node sequence and the temporal weight of every hop."""

    nodes: List[str]                  # visited nodes, start node first
    edge_weights: List[float]         # one weight per traversed edge

    def length(self) -> int:
        """Return the number of hops (edges) in the path."""
        return len(self.edge_weights)

    def score(self) -> float:
        """
        Return the length-penalised path score: mean(edge_weights) / sqrt(hops).

        An empty path (no hops) scores 0.0.
        """
        hops = self.length()
        if hops > 0:
            mean_weight = sum(self.edge_weights) / hops
            return mean_weight / math.sqrt(hops)
        return 0.0


class ResolutionPathRetriever:
    """
    Beam-search retriever for paths between nodes of a resolution graph.

    The graph is traversed as undirected. Each hop is weighted by
    ``edge_temporal_weight`` relative to a query day, and paths are ranked by
    ``RPath.score``.
    """

    def __init__(
        self,
        G: nx.DiGraph,
        lam: float = 0.25,
        max_expand_neighbors: int = 5,
        avoid_cycles: bool = True,
    ):
        """
        Args:
            G: Resolution graph (edges stored directed, traversed undirected).
            lam: Temporal decay rate passed to ``edge_temporal_weight``.
            max_expand_neighbors: Number of best-weighted neighbours expanded
                per node (the search target is always expanded when adjacent).
            avoid_cycles: If True, a path never revisits one of its own nodes.
        """
        self.G = G
        self.lam = lam
        self.max_expand_neighbors = max_expand_neighbors
        self.avoid_cycles = avoid_cycles

    def _neighbors_undirected(self, u: str) -> List[str]:
        """
        Return successors and predecessors of ``u`` combined, in sorted order.

        Treats the directed graph as undirected. Returns [] if ``u`` is not in
        the graph.
        """
        if u not in self.G:
            return []
        nbrs = set(self.G.successors(u))
        nbrs.update(self.G.predecessors(u))
        # Set iteration order for str depends on hash randomisation; sorting
        # keeps downstream tie-breaking reproducible across runs.
        return sorted(nbrs, key=str)

    def _top_neighbors(
        self,
        u: str,
        query_day: str,
        visited: Optional[set] = None,
        must_include: Optional[str] = None,
    ) -> List[Tuple[str, float]]:
        """
        Return up to ``max_expand_neighbors`` (neighbour, weight) pairs of ``u``.

        Pairs are ordered by decreasing edge temporal weight. Weight ties are
        broken by node id. Neighbours in ``visited`` are skipped when
        ``avoid_cycles`` is set. If ``must_include`` is an eligible neighbour
        that fell outside the cut-off, it is appended, so at most one extra
        pair is returned.
        """
        candidates: List[Tuple[str, float]] = []
        for v in self._neighbors_undirected(u):
            if self.avoid_cycles and visited is not None and v in visited:
                continue
            w = edge_temporal_weight(self.G, u, v, query_day, lam=self.lam)
            candidates.append((v, w))

        # Highest weight first; ties ordered by node id so results are reproducible.
        candidates.sort(key=lambda vw: (-vw[1], str(vw[0])))
        top = candidates[: self.max_expand_neighbors]

        # Never prune the search target just because its edge weight is low.
        if must_include is not None and all(v != must_include for v, _ in top):
            for v, w in candidates:
                if v == must_include:
                    top.append((v, w))
                    break
        return top

    def beam_search(
        self,
        start: str,
        target: str,
        query_day: str,
        max_hops: int = 4,
        beam_width: int = 30,
        top_k: int = 5,
    ) -> List[RPath]:
        """
        Find up to ``top_k`` paths from ``start`` to ``target``.

        Traversal is undirected, and paths are scored by temporal decay weights
        along their edges. Each frontier path is extended through its best
        ``max_expand_neighbors`` neighbours. When ``target`` is adjacent it is
        always included, so a low-weight final hop cannot be pruned. A path
        stops growing once it reaches ``target``. After each depth the frontier
        is trimmed to the ``beam_width`` best-scoring paths.

        Returns:
            Completed paths sorted by descending score, or [] if either
            endpoint is missing from the graph.
        """
        if start not in self.G or target not in self.G:
            return []

        frontier: List[RPath] = [RPath(nodes=[start], edge_weights=[])]
        completed: List[RPath] = []

        for _depth in range(1, max_hops + 1):
            next_frontier: List[RPath] = []

            for p in frontier:
                u = p.nodes[-1]
                visited = set(p.nodes) if self.avoid_cycles else None

                for v, w in self._top_neighbors(u, query_day, visited, must_include=target):
                    newp = RPath(nodes=p.nodes + [v], edge_weights=p.edge_weights + [w])
                    if v == target:
                        completed.append(newp)
                    else:
                        next_frontier.append(newp)

            if not next_frontier:
                break

            # Every frontier path has the same hop count here, so this ranks by mean weight.
            next_frontier.sort(key=lambda x: x.score(), reverse=True)
            frontier = next_frontier[:beam_width]

        completed.sort(key=lambda x: x.score(), reverse=True)
        return completed[:top_k]


# -----------------------------
# Convenience: retrieval for a DNS record
# -----------------------------

def retrieve_resolution_paths_for_record(
    retriever: ResolutionPathRetriever,
    domain: str,
    answer_asns: List[int],
    query_day: str,
    max_hops: int = 4,
    beam_width: int = 30,
    top_k_each_asn: int = 3,
) -> List[Dict[str, Any]]:
    """
    Retrieve resolution paths between a record's domain and each of its answer ASNs.

    For a DNS record at ``query_day``, runs ``retriever.beam_search`` from
    ``D:domain`` to every distinct ``S:asn`` node. The results from all
    searches are merged and sorted by descending score.

    Args:
        retriever: Configured path retriever.
        domain: Queried domain name (normalised by ``nodeD``).
        answer_asns: ASNs of the answers. Duplicates, including equivalent
            int/str forms, are searched only once.
        query_day: Reference day, ``YYYYMMDD``.
        max_hops, beam_width: Forwarded to ``beam_search``.
        top_k_each_asn: Maximum number of paths kept per distinct ASN.

    Returns:
        Dicts with keys ``pair``, ``score``, ``nodes``, ``edge_weights``,
        sorted by descending ``score``.
    """
    start = nodeD(domain)
    all_paths: List[Dict[str, Any]] = []

    # answer_asns may repeat an ASN (e.g. several answer IPs in one AS).
    # Search each distinct target once so the same paths are not reported twice.
    targets = list(dict.fromkeys(nodeS(asn) for asn in answer_asns))

    for target in targets:
        paths = retriever.beam_search(
            start=start,
            target=target,
            query_day=query_day,
            max_hops=max_hops,
            beam_width=beam_width,
            top_k=top_k_each_asn,
        )
        all_paths.extend(
            {
                "pair": f"{start} <-> {target}",
                "score": p.score(),
                "nodes": p.nodes,
                "edge_weights": p.edge_weights,
            }
            for p in paths
        )

    all_paths.sort(key=lambda x: x["score"], reverse=True)
    return all_paths


# -----------------------------
# Example usage
# -----------------------------
if __name__ == "__main__":
    # Demo: load the pickled resolution graph and print the top-ranked
    # resolution paths for a single DNS record.
    # NOTE: pickle.load can execute arbitrary code; only open graph files
    # produced by this project's own pipeline.
    with open("../outputs/resolution_graph.gpickle", "rb") as graph_file:
        RG: nx.DiGraph = pickle.load(graph_file)

    retriever = ResolutionPathRetriever(
        RG,
        lam=0.25,
        max_expand_neighbors=5,
        avoid_cycles=True,
    )

    # The record under inspection: queried domain, query day and answer ASNs.
    domain, query_day, answer_asns = "mygiftcard.ru", "20250506", [43298]

    paths = retrieve_resolution_paths_for_record(
        retriever,
        domain=domain,
        answer_asns=answer_asns,
        query_day=query_day,
        max_hops=4,
        beam_width=30,
        top_k_each_asn=3,
    )

    for rank, hit in enumerate(paths[:10], start=1):
        print(f"[{rank}] score={hit['score']:.6f}  {hit['pair']}")
        print("    " + " -> ".join(hit["nodes"]))
        print("    edge_w=" + " ".join(format(w, ".3f") for w in hit["edge_weights"]))


[1] score=0.666667  D:mygiftcard.ru <-> S:43298
    D:mygiftcard.ru -> S:43298
    edge_w=0.667
