In [None]:
# demo_knn_neighbors_distinct_ids.py
from types import SimpleNamespace
from collections import defaultdict
from typing import Dict, List, Tuple, Optional, Literal
import random

# Answer label recorded for one (question, case) pair.
Ans = Literal["A", "B"]


class Demo:
    """Build a K-nearest-neighbor graph between cases from observed answers.

    Observations are supplied as ``{qid: {cid: "A" | "B"}}``.  The resulting
    graph is stored on ``self.graph.neighbor`` as ``{node_id: [neighbor_id, ...]}``.
    """

    def __init__(self, observed_dict: Dict[str, Dict[str, Ans]]):
        """Store the observations, create an empty graph, and validate the input.

        Args:
            observed_dict: Mapping ``{qid: {cid: "A" | "B"}}``.

        Raises:
            ValueError: If a qid or cid is not a ``str``, or an answer is
                neither ``"A"`` nor ``"B"``.
        """
        # observed_dict: { qid: { cid: "A"/"B" } }
        self.observed_dict = observed_dict
        # Namespace holding the neighbor mapping; empty until
        # update_neighbors_info() is called.
        self.graph = SimpleNamespace(neighbor={})
        self._validate()

    def _validate(self) -> None:
        """Check key types and answer labels of ``self.observed_dict``.

        Raises:
            ValueError: On the first non-``str`` qid/cid or the first answer
                that is not ``"A"`` or ``"B"``.
        """
        for qid, cmap in self.observed_dict.items():
            if not isinstance(qid, str):
                raise ValueError(f"qid must be str, got {type(qid)}")
            for cid, ans in cmap.items():
                if not isinstance(cid, str):
                    raise ValueError(f"cid must be str, got {type(cid)} (qid={qid})")
                if ans not in ("A", "B"):
                    raise ValueError(f"ans must be 'A' or 'B', got {ans} (qid={qid}, cid={cid})")

    @staticmethod
    def _agreement(
        answers_a: Dict[str, Ans],
        answers_b: Dict[str, Ans],
    ) -> Optional[float]:
        """Return the fraction of shared questions answered identically.

        Returns ``None`` when the two cases have no question in common, so
        that the pair is skipped instead of dividing by zero.
        """
        shared = answers_a.keys() & answers_b.keys()
        if not shared:
            return None
        matches = sum(answers_a[q] == answers_b[q] for q in shared)
        return matches / len(shared)

    def _random_neighbors(self, K: int, seed: Optional[int]) -> Dict[str, List[str]]:
        """Pick up to ``K`` random other keys of ``observed_dict`` for each key.

        NOTE(review): the keys of ``observed_dict`` are question IDs, whereas
        the similarity path keys the graph by case ID -- confirm this is the
        intended node set for the fallback.
        """
        rng = random.Random(seed)
        graph_nodes = list(self.observed_dict)
        neighbors: Dict[str, List[str]] = {}
        for node in graph_nodes:
            pool = [other for other in graph_nodes if other != node]
            rng.shuffle(pool)
            neighbors[node] = sorted(pool[:K])  # sort only for stable display
        return neighbors

    def update_neighbors_info(
        self,
        K: int = 2,
        seed: Optional[int] = 42,
    ) -> Dict[str, List[str]]:
        """Rebuild ``self.graph.neighbor`` with the top-``K`` neighbors per case.

        Similarity between two cases is the fraction of their shared
        questions on which they gave the same answer.  Cases with no shared
        question are never neighbors.  Candidates are ranked by descending
        similarity, ties broken by ascending case ID, so the result is
        deterministic.

        If no answer has been observed at all, each key of ``observed_dict``
        instead receives up to ``K`` randomly chosen other keys.

        Args:
            K: Maximum number of neighbors per node; must be non-negative.
            seed: Seed for the random fallback only; unused when any answer
                has been observed.

        Returns:
            The freshly built mapping, which is also stored on
            ``self.graph.neighbor``.

        Raises:
            ValueError: If ``K`` is negative.
        """
        if K < 0:
            # A negative K would make the [:K] slices drop items from the end.
            raise ValueError(f"K must be non-negative, got {K}")

        # Invert {qid: {cid: ans}} into {cid: {qid: ans}}.
        by_case: Dict[str, Dict[str, Ans]] = defaultdict(dict)
        for qid, case_map in self.observed_dict.items():
            for cid, ans in case_map.items():
                by_case[cid][qid] = ans

        if not by_case:
            self.graph.neighbor = self._random_neighbors(K, seed)
            return self.graph.neighbor

        # Build into a fresh dict so entries from an earlier call do not
        # survive for cases that are no longer observed.
        neighbors: Dict[str, List[str]] = {}
        for cid, answers in by_case.items():
            cand: List[Tuple[str, float]] = []
            for nid, other_answers in by_case.items():
                if nid == cid:
                    continue
                score = self._agreement(answers, other_answers)
                if score is not None:
                    cand.append((nid, score))

            # Highest similarity first; case ID breaks ties deterministically.
            cand.sort(key=lambda pair: (-pair[1], pair[0]))
            neighbors[cid] = [nid for nid, _ in cand[:K]]

        self.graph.neighbor = neighbors
        return self.graph.neighbor


def pretty(title: str, d: Dict[str, List[str]]) -> None:
    """Print a titled listing of a neighbor mapping, one key per line.

    A blank line and a ``=== title ===`` banner are printed first, followed
    by one ``key -> value`` line per entry in ascending key order.  Keys are
    right-aligned to a minimum width of three characters.

    Args:
        title: Text shown inside the banner.
        d: Mapping to display.
    """
    banner = f"\n=== {title} ==="
    print(banner)
    for node_id, neighbor_ids in sorted(d.items(), key=lambda entry: entry[0]):
        print(f"{node_id:>3} -> {neighbor_ids}")

if __name__ == "__main__":
    # Demo driver: build the neighbor graph with K=3 and print it.
    #
    # `observed_dict` is expected to have the shape {qid: {cid: ans}}.
    # It is not assigned anywhere in this block, so it must already exist
    # (for example by uncommenting the sample below) before this runs.
    #
    # Sample with distinct case IDs (U1..U5) and answers 'A'/'B':
    # observed_dict = {
    #     "q1": {"U1":"A","U2":"A","U3":"B","U4":"A","U5":"A"},
    #     "q2": {"U1":"A","U2":"A","U3":"B","U4":"A","U5":"A"},
    #     "q3": {"U1":"B","U2":"B","U3":"A","U4":"B","U5":"B"},
    #     "q4": {"U1":"A","U2":"A","U3":"B","U4":"B","U5":"A"},
    # }
    # In this sample U1, U2 and U5 answer identically, so similarity ties occur.
    demo = Demo(observed_dict)

    demo.update_neighbors_info(seed=42, K=3)
    pretty(title="K=3, seed=42", d=demo.graph.neighbor)



=== K=3, seed=42 ===
 q1 -> ['q2', 'q3', 'q4']
 q2 -> ['q1', 'q3', 'q4']
 q3 -> ['q1', 'q2', 'q4']
 q4 -> ['q1', 'q2', 'q3']
