# Rate-Distortion Guided Knowledge Graph Refinement with LLM-Assisted Operations

## This notebook contains the functions that produce the following actions:
- Read in lecture notes JSON into leaf elements
- Build fused distance on the lecture side: chronology + logic + semantics
- Build fused distance on the KG side: structure + semantics
- Compute Fused Gromov–Wasserstein (FGW) coupling if POT is installed
  (optional; exact). Otherwise, compute a proxy coupling via entropic OT
  over a feature+structure cost surrogate.
- Execute greedy refinement ops that minimize $L = r + \beta * d$:
  - add concepts from under-represented content
  - split/merge concepts from coupling patterns
  - add/rewire/remove relationships
- Refine the KG by LLM based on domain knowledge
- Output:
  - Refined KG JSON (same schema as input)
  - Refinement report JSON (iterations, objective, rates, distortions)
  - Coupling analysis JSON (top-k matches)
  - History of refinements
  - Rate–Distortion curve PNG


## Libraries:
- Required: numpy, networkx, scipy (for distance + kmeans), scikit-learn
- Optional: sentence-transformers (for embeddings; else falls back to TF-IDF)
- Optional: POT (Python Optimal Transport, `pip install POT`) for exact FGW

## Notes:
- The script is deterministic up to random seeds for kmeans; you can set --seed.
- If POT is present, exact FGW is used. If not, we compute an entropic OT coupling
  with a proxy cost that mixes feature distance and a local-structure penalty.

## Import Libraries

In [None]:
#!pip install -q POT
#!pip install -q sentence-transformers

In [1]:
import collections
import dataclasses
import itertools
import json
import logging
import math
import os
import random
import re
import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

%matplotlib inline

# Optional deps
try:
    import networkx as nx
except Exception as e:
    raise SystemExit("This script requires 'networkx'. Please install it: pip install networkx") from e

try:
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity
    from sklearn.cluster import KMeans
except Exception as e:
    raise SystemExit("This script requires scikit-learn. Please install it: pip install scikit-learn") from e

try:
    from scipy.spatial.distance import cdist
    from scipy.special import rel_entr
except Exception as e:
    raise SystemExit("This script requires scipy. Please install it: pip install scipy") from e

# Optional embeddings
_SENTENCE_TF = None
try:
    from sentence_transformers import SentenceTransformer
    _SENTENCE_TF = SentenceTransformer
except Exception:
    _SENTENCE_TF = None

# Optional POT for exact FGW
_HAS_POT = False
try:
    import ot
    from ot.gromov import fused_gromov_wasserstein
    _HAS_POT = True
except Exception:
    _HAS_POT = False


In [2]:
_SENTENCE_TF, _HAS_POT

(sentence_transformers.SentenceTransformer.SentenceTransformer, True)

## Utilitiy and Data Structures

In [3]:
# -----------------------------
# Utility & Data Structures
# -----------------------------

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)

@dataclass
class HyperParams:
    alpha_chron: float = 0.2
    alpha_logic: float = 0.3
    alpha_sem: float = 0.5
    gamma_struct: float = 0.4
    gamma_sem: float = 0.6
    lambda_feat: float = 0.6     # feature balance inside FGW objective
    beta: float = 20.0           # rate–distortion trade-off
    theta_add: float = 0.06
    theta_split: float = 0.35
    theta_merge: float = 0.12
    theta_relate: float = 0.25
    max_iterations: int = 10
    convergence_threshold: float = 0.01
    sinkhorn_eps: float = 0.05
    sinkhorn_iter: int = 300

@dataclass
class LectureElement:
    idx: int
    id_path: str
    section_path: List[str]
    text: str
    typ: str

@dataclass
class KGNode:
    id: str
    label: str
    type: str
    definition: str
    aliases: List[str]
    provenance: List[Dict[str, Any]]
    attributes: Dict[str, Any]
    confidence: float
    rationale: str

@dataclass
class KGEdge:
    id: str
    source: str
    target: str
    relation: str
    definition: str
    provenance: List[Dict[str, Any]]
    confidence: float
    rationale: str

@dataclass
class RefinementOutcome:
    iterations: int
    final_objective: float
    rate: float
    distortion: float
    fgw_distance: float
    history: List[Dict[str, Any]]
    operations: Dict[str, int]


## I/O Helpers

In [4]:
# -----------------------------
# IO Helpers
# -----------------------------

def read_json(path: str) -> Any:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def write_json(obj: Any, path: str):
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, indent=2, ensure_ascii=False)

def ensure_dir(p: str):
    Path(p).mkdir(parents=True, exist_ok=True)


## Read KG with Edge Text Added

In [5]:
# return definition text for an edge
def get_edge_definition(kg, source_id, target_id, relation):
    
    slabel = source_id
    tlabel = target_id
    
    for anode in kg['nodes']:
        
        if anode['id'].lower() == source_id.lower():
            slabel = anode['label']
                
        if anode['id'].lower() == target_id.lower():
            tlabel = anode['label']

        
    return slabel + " " + relation + " " + tlabel

# add relations where coupled lecture elements are near
def get_edge_rationale(kg, source_id, target_id, relation):

    slabel = source_id
    tlabel = target_id
    
    for anode in kg['nodes']:
        
        if anode['id'].lower() == source_id.lower():
            slabel = anode['label']
                
        if anode['id'].lower() == target_id.lower():
            tlabel = anode['label']

    retionale = f'{slabel} {relation} {tlabel} because their coupled lecture elements are near.'
    
    return rationale
    
# read KG with definition edges
def read_kg(path: str) -> Any:
    kg_json = read_json(path)

    for edge in kg_json['edges']:
        if ('definition' not in edge) or (not edge['definition']):
            sid = edge['source']
            tid = edge['target']
            e_def = get_edge_definition(kg_json, sid, tid, edge['relation'])

            edge['definition'] = e_def

    return kg_json 

## Lecture JSON Parsing

In [6]:
# -----------------------------
# Lecture JSON Parsing
# -----------------------------

def _flatten_lecture_json(node: Dict[str, Any],
                          section_stack: Optional[List[str]] = None,
                          out: Optional[List[LectureElement]] = None,
                          idx_offset: int = 0) -> List[LectureElement]:
    """
    Generic flattener for lecture JSONs with fields:
      - id, level, type, title, content, children (list)
    or where leaf 'elements' already exist (preferred).
    """
    if out is None:
        out = []
    if section_stack is None:
        section_stack = []

    title = node.get("title") or node.get("label") or node.get("name") or ""
    typ = node.get("type") or "section"
    id_str = str(node.get("id", ""))
    content = node.get("content", "")

    # If explicit elements exist (pre-cleaned file), use them
    elements = node.get("elements") or []
    if elements and isinstance(elements, list):
        for e in elements:
            text = e.get("text", "").strip()
            if not text:
                continue
            etype = e.get("type", "text")
            out.append(LectureElement(
                idx=idx_offset + len(out),
                id_path=id_str,
                section_path=section_stack + [title] if title else section_stack[:],
                text=text,
                typ=etype
            ))
    else:
        # Fallback: split content into crude sentence-ish lines
        #
        # This part repeats "Break Down to Individual Elements"
        # in "RD_FGW_source_markdown_json.ipynb".
        # Assume HERE that :code and :markdwon suffixes have been removed
        #
        lines = [ln.strip() for ln in str(content).splitlines() if ln.strip()]
        for ln in lines:
            out.append(LectureElement(
                idx=idx_offset + len(out),
                id_path=id_str,
                section_path=section_stack + [title] if title else section_stack[:],
                text=ln,
                typ="text"
            ))

    # Recurse into children
    for ch in node.get("children", []) or []:
        _flatten_lecture_json(ch, section_stack + ([title] if title else []),
                              out, idx_offset)

    return out

import re
from copy import copy

def clean_up_elements(elements):
    out = []
    for e in elements:
        etext = (e.get('text') if isinstance(e, dict) else getattr(e, 'text', '')) or ''
        if re.fullmatch(r"\s*(?:`{3,}|~{3,}|_{3,}|-{3,})\s*", etext):
            continue
        t = re.sub(r"\b(?:\:code|\:markdown|markdown|exercise|example|agenda|summary)\b", "", etext, flags=re.I)
        t = re.sub(r"^\s*[#>*+\-`~\u2013\u2014\_]*\s*(?:\d+[.)]\s*)?", "", t)
        t = re.sub(r"(?:`{3,}|~{3,}|_{3,}|-{3,})", " ", t)
        t = re.sub(r"\s+", " ", t).strip().lower()
        if t and not re.fullmatch(r"[\W_]+", t):
            if isinstance(e, dict):
                ne = dict(e); ne['text'] = t
            else:
                ne = copy(e); setattr(ne, 'text', t)
            out.append(ne)
    return out


def parse_lecture_elements(lecture_json_path: str) -> List[LectureElement]:
    data = read_json(lecture_json_path)
    # If already a list of sections with elements, flatten all
    if isinstance(data, dict) and data.get("id") is not None:
        elements = _flatten_lecture_json(data)
    elif isinstance(data, list):
        elements = []
        for x in data:
            elements.extend(_flatten_lecture_json(x))
    else:
        raise ValueError("Unsupported lecture JSON format.")

    # Attach sequential chronological index
    for i, el in enumerate(elements):
        el.idx = i

    # clean up the elements' text
    cleaned_elements = clean_up_elements(elements)

    return cleaned_elements

## Embeddings

In [7]:
# -----------------------------
# Embeddings
# -----------------------------

class Embedder:
    def __init__(self, model_name: str = "all-MiniLM-L6-v2", use_sentence_tf: bool = True):
        self.model_name = model_name
        self.use_sentence_tf = use_sentence_tf and (_SENTENCE_TF is not None)
        self.model = None
        self.tfidf = None

        if self.use_sentence_tf:
            try:
                self.model = _SENTENCE_TF(self.model_name)
            except Exception as e:
                logging.warning("SentenceTransformer init failed; falling back to TF-IDF: %s", e)
                self.use_sentence_tf = False

        if not self.use_sentence_tf:
            self.tfidf = TfidfVectorizer(max_features=4096)

    def fit_transform(self, texts: List[str]) -> np.ndarray:
        if self.use_sentence_tf:
            print("Use SentenceTransformer.")
            return np.asarray(self.model.encode(texts, normalize_embeddings=True))
        else:
            print("Use TFIDF to embed -> Check.")
            X = self.tfidf.fit_transform(texts)
            # Normalize L2
            X = X.astype(np.float32)
            norms = np.sqrt((X.power(2)).sum(axis=1)).A1 + 1e-12
            X = X.multiply(1/norms[:, None])
            return X.toarray()

    def transform(self, texts: List[str]) -> np.ndarray:
        if self.use_sentence_tf:
            print("Use SentenceTransformer.")
            return np.asarray(self.model.encode(texts, normalize_embeddings=True))
        else:
            print("Use TFIDF to embed -> check.")
            X = self.tfidf.transform(texts)
            # Normalize L2
            X = X.astype(np.float32)
            norms = np.sqrt((X.power(2)).sum(axis=1)).A1 + 1e-12
            X = X.multiply(1/norms[:, None])
            return X.toarray()


## Building Distance Matrix for Lecture Notes

In [8]:
# --------------------------------------
# Distance Utilities for Lecture Notes
# --------------------------------------

def lcp_length(a: List[str], b: List[str]) -> int:
    n = min(len(a), len(b))
    i = 0
    while i < n and a[i] == b[i]:
        i += 1
    return i

def build_lecture_distance(elements: List[LectureElement],
                           embeddings: np.ndarray,
                           alpha_chron: float,
                           alpha_logic: float,
                           alpha_sem: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Returns:
      D_L: fused NxN distance
      D_chron, D_logic, D_sem: component distances
    """
    N = len(elements)
    idxs = np.array([el.idx for el in elements], dtype=float)
    max_idx = max(1, int(np.max(idxs)))
    # Chronological: normalized absolute index difference
    D_chron = np.abs(idxs[:, None] - idxs[None, :]) / max(1.0, float(max_idx))

    # Logic: 1 - normalized LCP length
    # Precompute section paths
    paths = [el.section_path for el in elements]
    max_depth = max((len(p) for p in paths), default=1)
    D_logic = np.zeros((N, N), dtype=float)
    for i in range(N):
        for j in range(i, N):
            l = lcp_length(paths[i], paths[j])
            d = 1.0 - (l / max(1.0, float(max_depth)))
            D_logic[i, j] = D_logic[j, i] = d

    # Semantic: 1 - cosine similarity (clipped)
    S = cosine_similarity(embeddings)
    D_sem = np.clip(1.0 - S, 0.0, 2.0)

    D_L = alpha_chron * D_chron + alpha_logic * D_logic + alpha_sem * D_sem
    # Normalize to [0,1]
    D_L = (D_L - D_L.min()) / (D_L.max() - D_L.min() + 1e-12)
    return D_L, D_chron, D_logic, D_sem

## Building Distance Matrix for Knowledge Graph

In [9]:
# --------------------------------------
# Distance Utilities for Knowledge Graph
# --------------------------------------


def build_kg_graph(kg: Dict[str, Any]) -> nx.Graph:

    G = nx.Graph()
    for n in kg.get("nodes", []):
        G.add_node(n["id"], **n)
    for e in kg.get("edges", []):
        # Treat as undirected for structure metrics; keep relation on edge data
        G.add_edge(e["source"], e["target"], **e)
    return G

def node_text_for_embedding(node: Dict[str, Any]) -> str:
    parts = [node.get("label", ""), node.get("definition", "")]
    aliases = node.get("aliases") or []
    if isinstance(aliases, list):
        parts.extend(aliases[:3])
    return ". ".join([p for p in parts if p]).strip()

def build_kg_distance(kg: Dict[str, Any],
                      embedder: Embedder,
                      gamma_struct: float,
                      gamma_sem: float) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[str]]:
    nodes = kg.get("nodes", [])
    node_ids = [n["id"] for n in nodes]
    texts = [node_text_for_embedding(n) or n.get("label", n["id"]) for n in nodes]
    X = embedder.transform(texts) if getattr(embedder, "tfidf", None) else embedder.fit_transform(texts)

    # Semantic distance
    S = cosine_similarity(X)
    D_sem = np.clip(1.0 - S, 0.0, 2.0)

    # Structural distance via shortest paths
    G = build_kg_graph(kg)
    # Precompute APSP lengths (unweighted)
    # For disconnected pairs, set to large value
    n = len(node_ids)
    D_struct = np.full((n, n), fill_value=np.inf, dtype=float)
    sp = dict(nx.all_pairs_shortest_path_length(G))
    for i, u in enumerate(node_ids):
        D_struct[i, i] = 0.0
        for j, v in enumerate(node_ids):
            if v in sp.get(u, {}):
                D_struct[i, j] = float(sp[u][v])
    # Replace inf with max finite + 1
    finite = D_struct[np.isfinite(D_struct)]
    max_f = float(finite.max()) if finite.size else 1.0
    D_struct[~np.isfinite(D_struct)] = max_f + 1.0

    # Normalize each component to [0,1]
    def _norm(M):
        return (M - M.min()) / (M.max() - M.min() + 1e-12)
    D_struct = _norm(D_struct)
    D_sem = _norm(D_sem)

    D_K = gamma_struct * D_struct + gamma_sem * D_sem
    D_K = _norm(D_K)
    return D_K, D_struct, D_sem, X, node_ids

## Normalized Measure and Degree Centrality

In [10]:
def normalized_measure(n: int) -> np.ndarray:
    return np.full((n,), 1.0 / max(1, n), dtype=float)

def degree_centrality_measure(kg: Dict[str, Any], node_ids: List[str]) -> np.ndarray:
    G = build_kg_graph(kg)
    deg = np.array([G.degree(nid) for nid in node_ids], dtype=float)
    if deg.sum() <= 0:
        return normalized_measure(len(node_ids))
    return deg / deg.sum()

## Compute FGW Distance and Coupling

In [11]:
# -----------------------------
# FGW / Proxy Coupling
# -----------------------------

def compute_feature_cost(E_L: np.ndarray, E_K: np.ndarray) -> np.ndarray:
    # Squared Euclidean between embeddings; both are L2-normalized, so this ~ 2 - 2*cosine
    C = cdist(E_L, E_K, metric="sqeuclidean")
    # Normalize to [0,1]
    C = (C - C.min()) / (C.max() - C.min() + 1e-12)
    return C

def local_structure_fingerprint(D: np.ndarray, k: int = 5) -> np.ndarray:
    """
    For each row i in a distance matrix D (N x N), return a vector of the
    sorted distances to its k nearest neighbors (excluding self).
    """
    N = D.shape[0]
    fp = np.zeros((N, k), dtype=float)
    for i in range(N):
        row = D[i].copy()
        row[i] = np.inf
        idx = np.argsort(row)[:k]
        fp[i] = np.sort(row[idx])
    # Normalize per-column
    fp = (fp - fp.min(axis=0, keepdims=True)) / (fp.max(axis=0, keepdims=True) - fp.min(axis=0, keepdims=True) + 1e-12)
    return fp

def compute_proxy_cost(D_L: np.ndarray, D_K: np.ndarray,
                       E_L: np.ndarray, E_K: np.ndarray,
                       lam_feat: float = 0.6) -> np.ndarray:
    """
    Build a surrogate cost for OT that mixes:
      - feature cost between embeddings
      - local structure fingerprints between spaces
    """
    C_feat = compute_feature_cost(E_L, E_K)

    # ensure the k is no greater than the mininum number of elements in
    # either lecture notes or the kg
    min_k = min(D_L.shape[0]-1, D_K.shape[0]-1)

    fp_L = local_structure_fingerprint(D_L, k=min(min_k, max(1, D_L.shape[0]-1)))
    fp_K = local_structure_fingerprint(D_K, k=min(min_k, max(1, D_K.shape[0]-1)))
    #print("fp_L shape: {}".format(fp_L.shape))
    #print("fp_K shape: {}".format(fp_K.shape))

    # Structure penalty between local fingerprints
    C_struct = cdist(fp_L, fp_K, metric="sqeuclidean")
    C_struct = (C_struct - C_struct.min()) / (C_struct.max() - C_struct.min() + 1e-12)
    C = lam_feat * C_feat + (1.0 - lam_feat) * C_struct
    C = (C - C.min()) / (C.max() - C.min() + 1e-12)
    return C

def sinkhorn_ot(mu: np.ndarray, nu: np.ndarray, C: np.ndarray, eps: float = 0.05, n_iter: int = 300) -> np.ndarray:
    """
    Simple entropic OT via Sinkhorn-Knopp on kernel K = exp(-C/eps).
    Returns coupling matrix P (mu x nu).
    """
    K = np.exp(-C / max(1e-8, eps))
    u = np.ones_like(mu)
    v = np.ones_like(nu)
    for _ in range(n_iter):
        u = mu / (K @ v + 1e-12)
        v = nu / (K.T @ u + 1e-12)
    P = np.diag(u) @ K @ np.diag(v)
    return P

def fgw_distance_proxy(P: np.ndarray,
                       D_L: np.ndarray, D_K: np.ndarray,
                       E_L: np.ndarray, E_K: np.ndarray,
                       lam_feat: float = 0.6) -> Tuple[float, float, float]:
    """
    Compute an FGW-like loss using current coupling P:
      d = sum |D_L[i,k] - D_K[j,l]|^2 * P[i,j]*P[k,l] + lam_feat * sum ||E_L[i]-E_K[j]||^2 * P[i,j]
    Returns:
      total_loss, structure_term, feature_term
    """
    # Structure term
    DL = D_L
    DK = D_K
    # Efficient einsum formulation
    # A[i,k] = D_L[i,k]^2
    # B[j,l] = D_K[j,l]^2
    # structure term = sum_{i,k,j,l} (D_L[i,k] - D_K[j,l])^2 * P[i,j] * P[k,l]
    # = sum A[i,k]*P[i,:].sum_j P[k,:].sum_l + sum B[j,l]*P[:,j].sum_i P[:,l].sum_k - 2 * sum D_L[i,k] D_K[j,l] P[i,j] P[k,l]
    A = DL ** 2
    B = DK ** 2
    Pi_row = P.sum(axis=1)  # size N
    Pi_col = P.sum(axis=0)  # size M

    term1 = (A * np.outer(Pi_row, Pi_row)).sum()
    term2 = (B * np.outer(Pi_col, Pi_col)).sum()
    # Cross term
    term3 = 0.0
    # Compute M1 = D_L @ P @ D_K.T @ P.T ??
    # We need sum_{i,k,j,l} D_L[i,k] * D_K[j,l] * P[i,j] * P[k,l]
    # This equals trace(D_L^T (P D_K P^T))
    # Compute M = P @ D_K @ P.T  -> size N x N
    M = P @ DK @ P.T
    term3 = np.sum(DL * M)
    structure = term1 + term2 - 2.0 * term3

    # Feature term
    # ||E_L[i] - E_K[j]||^2 = a + b - 2 <...>, but we can compute directly
    C_feat = cdist(E_L, E_K, metric="sqeuclidean")
    feature = float((C_feat * P).sum())
    total = structure + lam_feat * feature
    return total, structure, feature

def compute_coupling_and_distance(D_L: np.ndarray, D_K: np.ndarray,
                                  E_L: np.ndarray, E_K: np.ndarray,
                                  mu: np.ndarray, nu: np.ndarray,
                                  lam_feat: float, sink_eps: float, sink_iter: int,
                                  use_pot: bool) -> Tuple[np.ndarray, float, float, float]:
    """
    Returns:
      P (coupling), total_fgw_like, structure_term, feature_term
    """
    if use_pot and _HAS_POT:
        print("Use POT.")
        # Exact FGW using POT
        # Feature matrices are E_L and E_K; costs are cosine or Euclidean
        # POT expects feature cost matrices; we give squared Euclidean
        C1 = D_L
        C2 = D_K
        M_feat = cdist(E_L, E_K, metric="sqeuclidean")
        # FGW coupling
        #P = fused_gromov_wasserstein(M_feat, C1, C2, mu, nu, alpha=lam_feat,
        # verbose=False)
        # FGW loss
        #total = float(fgw_loss(P, M_feat, C1, C2, lam_feat))

        P, log = fused_gromov_wasserstein(M_feat, C1, C2, mu, nu,
                                              loss_fun="square_loss",
                                               alpha=lam_feat, log=True)
        total = float(log.get("fgw_dist", np.nan))
        # Decompose approximately (for reporting)
        # We recompute proxy structure/feature with same P for interpretability
        total_proxy, structure, feature = fgw_distance_proxy(P, D_L, D_K, E_L, E_K, lam_feat)
        return P, total, structure, feature
    else:
        print("Use Proxy OT.")
        # Proxy: build surrogate cost and do entropic OT
        C = compute_proxy_cost(D_L, D_K, E_L, E_K, lam_feat)
        P = sinkhorn_ot(mu, nu, C, eps=sink_eps, n_iter=sink_iter)
        total, structure, feature = fgw_distance_proxy(P, D_L, D_K, E_L, E_K, lam_feat)
        return P, total, structure, feature

## LLM Assisted Concept and Relationship Generation

In [12]:
import os
from dotenv import load_dotenv
load_dotenv()

True

### LLM Helper OpenAI

In [13]:
# ---------------------------
# LLM helper (optional)
# ---------------------------
def llm_call(prompt: str, model: str = "gpt-4o-mini", max_tokens: int = 700, temperature: float = 0.2) -> str:
    """Call OpenAI chat completion. If no key or pkg, raises RuntimeError."""
    api_key = os.getenv("OPENAI_API_KEY")
    #api_key = userdata.get('OPENAI_API_KEY')
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY not set")

    try:
        import openai  # type: ignore
        client = openai.OpenAI(api_key=api_key)
    except Exception as e:
        raise RuntimeError(f"OpenAI client not available: {e}")

    try:
        resp = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are a precise assistant for ontology construction."},
                {"role": "user", "content": prompt},
            ]
        )
        return resp.choices[0].message.content.strip()
    except Exception as e:
        raise RuntimeError(f"LLM call failed: {e}")


## LLM Helper Gemini

In [14]:
import google.generativeai as genai

api_key = os.getenv("GOOGLE_API_KEY")
#api_key = userdata.get('GOOGLE_API_KEY') # Assuming this is how you access secrets
if not api_key:
    raise RuntimeError("GOOGLE_API_KEY not set")

try:
    model: str = "gemini-2.5-flash-lite"
    genai.configure(api_key=api_key)
    # For chat-based models, you'd typically use genai.GenerativeModel
    # and start a chat session. For simple prompt-response, direct generate_content works.
    gemini_model = genai.GenerativeModel(model)
except Exception as e:
    raise RuntimeError(f"Gemini client not available: {e}")

In [15]:
def llm_call_gemini(prompt: str, model_instance=gemini_model,
    max_output_tokens=1000, temperature=0.2) -> str:
    """Call Google Gemini chat completion. If no key or pkg, raises RuntimeError."""
    try:
        # Gemini's generate_content takes a prompt directly
        resp = model_instance.generate_content(
            contents=[
                {"role": "user", "parts": [
                    "You are a precise assistant for ontology construction.",
                    prompt
                ]}
            ]
        )
        # Accessing the text content from the response
        return resp.text.strip()
    except Exception as e:
        raise RuntimeError(f"LLM call failed: {e}")

### LLM Concept Naming

In [16]:
# ---------------------------
# LLM concept naming
# ---------------------------
def make_concept_label_from_text_LLM(text: str) -> str:

    print("Use LLM to name new concept.")
    
    prompt = """
        Generate a single, meaningful concept that represents
        the main idea of the TEXT.
        The concept should be specific
        and represents the main phrase and terms in the TEXT.\n
        return ONLY the concept name. \n
        If the TEXT provided is empty and does not contain any content to
        derive a concept from, return ''.\n
        TEXT: {}.\n\n
        """


    #name = llm_call(prompt.format(text)).strip()

    name = llm_call_gemini(prompt.format(text)).strip()

    return name

def make_edges_between_text_LLM(kg: Dict[str, Any], new_node,
                               allowed_relations: List[str]):

    print("Use LLM to add new edges for a new concept.")

    """
    Uses LLM to find and add the most likely relation(s) between a new node
    and existing nodes in the KG, based on concept proximity and allowed relations.
    This version attempts to make a single LLM call for efficiency.
    """
    new_node_info = f"New Concept: \"{new_node['label']}\" (Definition: \"{new_node['definition']}\")"
    existing_nodes_info = "\nExisting Concepts:\n" + "\n".join([
        f"- \"{n['label']}\" (Definition: \"{n.get('definition', '')}\") [ID: {n['id']}]"
        for n in kg["nodes"] if n["id"] != new_node['id']
    ])

    prompt = f"""
    Given the new node and existing nodes below, identify the most likely relationship(s)
    from the NEW NODE to each EXISTING NODE from the following list of allowed relations:
    {', '.join(allowed_relations)}.
    For each relevant relationship, output the original EXISTING NODE's ID, the relation name, and a
    probability score between 0 and 1, and rationale about adding this edge,
    separated by colons, one relationship per line.
    Example output format:
    [EXISTING_NODE_ID]: [RELATION_NAME]: [PROBABILITY]: [RATIONALE]
    [EXISTING_NODE_ID]: [RELATION_NAME]: [PROBABILITY]: [RATIONALE]
    ...

    If no suitable relation is found for an existing concept with a high probability,
    do not include it in the output.

    {new_node_info}
    {existing_nodes_info}

    Output:
    """

    candidate_edges = []
    try:
        #llm_response = llm_call(prompt, max_tokens=5000).strip() # Increase max_tokens

        llm_response = llm_call_gemini(prompt, max_output_tokens=5000).strip() # Increase max_tokens
        # for potentially longer output Parse the LLM response
        for line in llm_response.splitlines():
            parts = line.split(':')
            if len(parts) == 4:
                target_id = parts[0].strip()
                llm_relation = parts[1].strip()
                
                try:
                    llm_confidence = float(parts[2].strip())
                except ValueError:
                    llm_confidence = 0.5 # Default confidence if parsing fails
                
                llm_rationale = parts[3].strip()
                
                # Validate target_id and relation
                if any(n["id"].lower() == target_id.lower() for n in kg["nodes"]) and llm_relation in allowed_relations:
                    
                    # get edge defintion
                    e_def = get_edge_definition(kg, new_node['id'], target_id, llm_relation)
                    
                    candidate_edges.append({
                        "source": new_node['id'],
                        "target": target_id,
                        "relation": llm_relation,
                        "definition": e_def,
                        "confidence": llm_confidence,
                        "rationale": llm_rationale
                    })
                else:
                     print(f"Warning: Invalid target_id or relation from LLM: {line}.")

    except RuntimeError as e:
        print(f"LLM call failed for relation prediction: {e}")
        # If LLM fails, add a default 'relatedTo' edge to all existing nodes with low confidence
        # This is a fallback to ensure some connectivity, adjust confidence as needed.
        # for existing_node in kg["nodes"]:
        #     if existing_node["id"] != new_node['id']:
        #         candidate_edges.append({
        #             "source": new_node['id'],
        #             "target": existing_node["id"],
        #             "relation": "relatedTo",
        #             "confidence": 0.2
        #         })


    return candidate_edges

## Task-Oriented All-Allowed Relations

In [17]:
_ALLOWED_RELATIONS_DEFAULT = [
    "isA","partOf","prerequisiteOf","dependsOn","relatedTo","synonymOf","antonymOf",
    "contrastsWith","defines","uses","usedBy","appliesTo","exampleOf","counterexampleOf",
    "illustratedBy","causes","resultsIn","prevents","assumes","implies","equivalentTo",
    "parameterOf","hasParameter","propertyOf","hasProperty","measuredBy","unitOf",
    "representedBy","notationFor","formulaFor","provedBy","theoremOf","algorithmFor",
    "stepOf","produces","consumes","advantageOf","limitationOf","commonErrorIn",
    "misconceptionOf","commonlyConfusedWith","assessedBy"
]

## TF-IDF Top Terms Generation

In [18]:
def tfidf_top_terms(texts: List[str], topk: int = 3) -> List[str]:
    """
    Return the top-k terms by mean TF-IDF across documents.
    Robust to empty/stopword-only inputs and None/NaN values.
    """
    if not texts:
        return []

    # Coerce to strings and strip
    cleaned = [(t if isinstance(t, str) else "").strip() for t in texts]
    if all(not t for t in cleaned):
        return []

    # First attempt: standard English stopwords
    vec = TfidfVectorizer(max_features=4096, stop_words="english")
    try:
        X = vec.fit_transform(cleaned)
    except ValueError:

        return []

    # Mean TF-IDF per term across docs
    scores = np.asarray(X.mean(axis=0)).ravel()

    # Top-k indices (safe if topk > n_features)
    k = min(topk, scores.size)
    if k == 0:
        return []

    # Use argpartition for efficiency, then sort those k by score desc
    top_idx = np.argpartition(-scores, range(k))[:k]
    top_idx = top_idx[np.argsort(-scores[top_idx])]

    feature_names = vec.get_feature_names_out()
    return [feature_names[i] for i in top_idx]

## Rate, KL_Divergence, Coupling_Entropy, and Objective

In [19]:
def rate_complexity(kg: Dict[str, Any]) -> float:
    n = len(kg.get("nodes", []))
    m = len(kg.get("edges", []))
    return float(n) + 0.5 * float(m)

def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
    # Symmetrized KL
    p = p + 1e-12
    q = q + 1e-12
    kl_pq = float(np.sum(rel_entr(p, q)))
    kl_qp = float(np.sum(rel_entr(q, p)))
    return 0.5 * (kl_pq + kl_qp)

def coupling_entropy(col: np.ndarray) -> float:
    p = col / (col.sum() + 1e-12)
    p = p + 1e-12
    return float(-np.sum(p * np.log(p)))

def compute_objective(kg: Dict[str, Any], beta: float, fgw_total: float) -> float:
    return rate_complexity(kg) + beta * fgw_total

## KG Refinement Operations

In [20]:
# -----------------------------
# KG Refinement Operations
# -----------------------------

"""
def make_concept_label_from_text(text: str) -> str:
    # Simple heuristic: TF-IDF top terms or fallback to first 5 words
    toks = re.findall(r"[A-Za-z_][A-Za-z0-9_]+", text)[:20]
    if len(toks) >= 3:
        return " ".join(toks[:3]).lower()
    return (text[:30] + "...").strip()
"""

def add_concept_from_element(kg: Dict[str, Any],
                             element: LectureElement,
                             embedder: Embedder,
                             allowed_relations: List[str],
                             connect_to: Optional[str] = None,
                             relation: str = "relatedTo") -> KGNode:
    label = make_concept_label_from_text_LLM(element.text)

    if label != "''" and label != '""' or (not label):
        node_id = re.sub(r"[^a-zA-Z0-9_]+", "_", label).strip("_")
        if not node_id:
            node_id = f"new_node_{len(kg['nodes'])+1}"
        # Ensure unique
        base = node_id
        c = 1
        existing_ids = {n["id"] for n in kg["nodes"]}
        while node_id in existing_ids:
            c += 1
            node_id = f"{base}_{c}"
        node = {
            "id": node_id,
            "label": label,
            "type": "Concept",
            "definition": element.text,
            "aliases": [],
            "attributes": {},
            "provenance": [{
                "section_path": element.section_path,
                "line_start": element.idx,
                "line_end": element.idx,
                "text_excerpt": element.text
            }],
            "confidence": 0.6,
            "rationale": "Added to reduce distortion; under-represented lecture content."
        }
        kg["nodes"].append(node)

        # Use LLM to make edges for the new node
        candidate_edges = make_edges_between_text_LLM(kg, node, allowed_relations)

        edges = []

        # Rank candidate edges by confidence and add the top one(s)
        if candidate_edges:
            candidate_edges.sort(key=lambda x: x["confidence"], reverse=True)
            top_confidence = candidate_edges[0]["confidence"]
            edges_to_add = [
                # Consider ties within 5%
                edge for edge in candidate_edges if edge["confidence"] >= top_confidence * 0.95
            ]

            for edge_info in edges_to_add:
                edge = {
                    "id": f"e_{edge_info['source']}_{edge_info['relation']}_{edge_info['target']}",
                    "source": edge_info['source'],
                    "target": edge_info['target'],
                    "relation": edge_info['relation'],
                    "definition": edge_info['definition'],
                    "provenance": [{
                        "section_path": element.section_path,
                        "line_start": element.idx,
                        "line_end": element.idx,
                        "text_excerpt": element.text
                    }],
                    "confidence": edge_info['confidence'],
                    "rationale": edge_info['rationale']
                }
                # Avoid duplicates
                edge_sig = (edge["source"].lower(), edge["target"].lower(), edge["relation"].lower())
                if not any((e["source"].lower(), e["target"].lower(), e["relation"].lower()) == edge_sig for e in kg["edges"]):
                    kg["edges"].append(edge)
                    edges.append(edge)
        else:
            if connect_to and relation in allowed_relations:
                # get edge definition
                e_def = get_edge_definition(kg, node_id, connect_to, relation)
                
                edge = {
                    "id": f"e_{node_id}_{relation}_{connect_to}",
                    "source": node_id,
                    "target": connect_to,
                    "relation": relation,
                    "definition": e_def,
                    "provenance": [{
                        "section_path": element.section_path,
                        "line_start": element.idx,
                        "line_end": element.idx,
                        "text_excerpt": element.text
                    }],
                    "confidence": 0.6,
                    "rationale": "Heuristic relation based on semantic proximity."
                }
                # Avoid duplicates
                edge_sig = (edge["source"].lower(), edge["target"].lower(), edge["relation"].lower())
                if not any((e["source"].lower(), e["target"].lower(), e["relation"].lower()) == edge_sig for e in kg["edges"]):
                    kg["edges"].append(edge)
                    edges.append(edge)

        return KGNode(**node), edges

    else:
        return None, None

def kmeans_split_concept(col: np.ndarray,
                         E_L: np.ndarray,
                         elements: List[LectureElement],
                         kg: Dict[str, Any],
                         node_idx: int) -> Tuple[Optional[KGNode], Optional[KGNode], List[int], List[int]]:
    """
    Split a concept into two based on the embedding clusters of its coupled
    lecture elements.
    Returns new nodes (as dicts) and index sets assigned to each.
    """
    weights = col
    sel = np.where(weights > weights.mean())[0]
    if len(sel) < 4:
        return None, None, [], []
    X = E_L[sel]
    try:
        km = KMeans(n_clusters=2, n_init=10, random_state=42)
        labs = km.fit_predict(X)
    except Exception:
        return None, None, [], []

    node = kg["nodes"][node_idx]
    base_id = node["id"]

    a_idx = [int(sel[i]) for i in range(len(sel)) if labs[i] == 0]
    b_idx = [int(sel[i]) for i in range(len(sel)) if labs[i] == 1]

    # Concatenate the text of a_idx and b_idx
    text_a = "\n".join([elements[i].text for i in a_idx])
    text_b = "\n".join([elements[i].text for i in b_idx])

    # Use LLM to make a label
    label_a = make_concept_label_from_text_LLM(text_a)
    if label_a == "''" or label_a == '""' or (not label_a):
        nA_txt = tfidf_top_terms([elements[i].text for i in a_idx], topk=3)
        label_a = (node["label"] + " " + " ".join(nA_txt)).strip()

    label_b = make_concept_label_from_text_LLM(text_b)
    if label_b == "''" or label_b == '""' or (not label_b):
        nB_txt = tfidf_top_terms([elements[i].text for i in b_idx], topk=3)
        label_b = (node["label"] + " " + " ".join(nB_txt)).strip()

    def _new(name_suffix, label, txt):

        new_id = re.sub(r"[^a-zA-Z0-9_]+", "_", label).strip("_")
        
        if not new_id:
            new_id = f"{base_id}_{name_suffix}"
            
        # Ensure unique
        base = new_id
        c = 1
        existing_ids = {n["id"] for n in kg["nodes"]}
        while new_id in existing_ids:
            c += 1
            new_id = f"{base}_{c}"
        
        new_node = dict(node)
        new_node["id"] = new_id
        new_node["label"] = label
        new_node["confidence"] = 0.55
        new_node["definition"] = txt
        new_node["provenance"][0]['text_excerpt'] = txt
        new_node["rationale"] = "Split concept with diverse coupling."
        return new_node

    A = _new("a", label_a, text_a)
    B = _new("b", label_b, text_b)
    return KGNode(**A), KGNode(**B), a_idx, b_idx

def merge_if_redundant(i: int, j: int,
                       P: np.ndarray,
                       E_K: np.ndarray,
                       sim_thresh: float,
                       kl_thresh: float) -> bool:
    """
    Return True if concepts i and j should be merged (high semantic sim & low KL).
    """
    v1 = E_K[i]; v2 = E_K[j]
    cos = float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-12))
    if cos < sim_thresh:
        return False
    pi = P[:, i]; pj = P[:, j]
    kl = kl_divergence(pi, pj)
    return kl < kl_thresh


## KG Final Refinement by Domain Knowledge

In [21]:
def _extract_json_list(text: str) -> str:
    """
    Return a JSON string representing a list.
    Handles raw JSON or fenced code blocks like:
    ```json
    [ ... ]
    ```
    """
    # Try direct parse first
    try:
        obj = json.loads(text)
        if isinstance(obj, list):
            return json.dumps(obj)
    except Exception:
        pass

    # Look for fenced block ```json ... ``` or ``` ...
    fence_match = re.search(r"```(?:json)?\s*(\[.*?\])\s*```", text, flags=re.DOTALL | re.IGNORECASE)
    if fence_match:
        return fence_match.group(1)

    # Fallback: capture first bracketed list (greedy but safe-ish)
    bracket_match = re.search(r"(\[.*\])", text, flags=re.DOTALL)
    if bracket_match:
        return bracket_match.group(1)

    raise ValueError("No JSON array found in LLM output.")

def apply_domain_knowledge_LLM(kg: Dict[str, Any], allowed_relations: List[str]) -> List[KGEdge]:
    """
    Refine the knowledge graph by adding missing edges based on domain knowledge
    encoded in the KG, using an LLM guided by a structured prompt.

    Args:
        kg: Knowledge graph as a dict with nodes and edges.
        allowed_relations: List of allowed relation types.

    Returns:
        A list of KGEdge objects proposed by the LLM.
    """

    # Prompt template
    prompt_template = """
    You are an expert Knowledge Graph Refinement Agent.
    Your goal is to propose NEW edges that make it easier to generate high-quality
    multiple-choice questions (MCQs) that produce
    high-quality distractors and har-to-guess answers. 

    ### Inputs
    1. Knowledge Graph (KG): {kg_json}
    2. Allowed Relations: {allowed_relations}

    ### Task Focus
    - Prioritize the edges that help build strong distractors.
    - Typical contrastive relations include: contrastsWith, antonymOf, differentFrom, opposes, distinguishesFrom, versus/vs.
    - Only use relation types that are in the Allowed Relations list.

    ### Evidence Constraints
    - Base your proposals STRICTLY on domain knowledge already encoded in the KG
      (node labels/definitions, existing relations, hierarchies, examples, contrasts).
    - Do NOT invent facts beyond what the KG implies.

    ### Quality Bar (MCQ Utility)
    For each proposed edge:
    - It should enable question writers to craft MCQs where wrong options 
    are **plausible** yet **incorrect**.
    - Prefer pairs that are commonly confused or that share overlapping 
      properties but differ on key dimensions.
    - Avoid trivial or redundant contrasts.
    
    ### Output
    - A list of new edges in the following structure:
    e = {{
        "id": "<unique_id>",
        "source": "<src_node_id>",
        "target": "<dst_node_id>",
        "relation": "<relation_type>",
        "definition": "<edge_definition>",
        "provenance": [],
        "confidence": <number between 0 and 1>,
        "rationale": "<explanation of why this relation was added>"
    }}

    Do not repeat existing edges. Only include new, refined edges.
    Ensure src_node_id and dst_node_id are existing nodes in the knowledge graph.

    ---

    ### Chain-of-Thought Refinement Instructions

    1. **Deconstruct the KG**
       - Read all nodes and existing edges.
       - Note the domain knowledge explicitly encoded in the graph 
         (hierarchies, prerequisites, usage patterns, examples, contrasts).

    2. **Diagnose Gaps**
       - Identify pairs of nodes that appear semantically, hierarchically, or logically related but lack an explicit relation.
       - Use only domain knowledge present in the KG (structure, relations, node labels) to justify candidate edges.

    3. **Develop Candidate Edges**
       - For each missing link:
         - Choose the most appropriate relation type from the allowed set.
         - Ensure the edge does not already exist.

    4. **Reason Explicitly**
       - For each new edge, provide a rationale:
         - Why the two nodes are connected.
         - Why this relation type is correct.
         - Why the confidence score (0–1) is appropriate.

       Example reasoning:
       Nodes: "Linear Regression" → "Gradient Descent"
       Evidence: In the KG, Gradient Descent is used to optimize models; Linear Regression requires optimization.
       Relation: uses
       Confidence: 0.9
       Rationale: Domain knowledge in the KG indicates optimization is part of regression training.

    5. **Deliver Final List**
       - Output only the list of new edge objects in the specified structure.
       - Each edge must include id, source_id, target_id, relation, provenance, confidence, rationale.
       - Ensure the source_id and target_id are in the knowledge graph.
    """

    # Fill in the placeholders
    prompt = prompt_template.format(
        kg_json=json.dumps(kg, indent=2),
        allowed_relations=allowed_relations
    )

    # ---- Call your LLM here ----
    # Replace this with your actual LLM call (OpenAI, Anthropic, etc.)
    # For example, if using OpenAI:
    # response = openai.ChatCompletion.create(
    #     model="gpt-4o",
    #     messages=[{"role": "user", "content": prompt}]
    # )
    # raw_output = response.choices[0].message["content"]

    #raw_output = llm_call(prompt)  # <-- Replace with your LLM wrapper

    raw_output = llm_call_gemini(prompt, max_output_tokens=50000)

    # Try parsing JSON safely
    json_str = _extract_json_list(raw_output)
    try:
        edges: List[KGEdge] = json.loads(json_str)
    except json.JSONDecodeError:
        print("LLM output could not be parsed as JSON:\n" + json_str)
        return None

    return edges


## Add the Edges Refined by Domain Knowledge LLM

In [22]:
def add_refined_domain_edges(kg: Dict[str, Any], edges: List[KGEdge]) -> Dict[str, Any]:
    """
    - Ensure each edge's source/target node exists by id; if not, create a new node.
    - New nodes follow your requested schema and use the edge's rationale.
    - Append normalized edges to kg['edges'] (avoid duplicates).
    """

    added_node_counter = 0
    added_edge_counter = 0
    
    kg.setdefault("nodes", [])
    kg.setdefault("edges", [])

    def node_exists(node_id: str) -> bool:
        nid = node_id.lower()
        return any(str(n.get("id", "")).lower() == nid for n in kg["nodes"])

    def add_node_from_edge(node_id: str, edge: KGEdge) -> None:
        kg["nodes"].append({
            "id": node_id,
            "label": node_id,  # per requirement: use id for label
            "type": "Concept",
            "definition": edge.get("rationale", "") or "",
            "aliases": [],  # you didn't request aliases here
            "provenance": [{"text": edge.get("rationale", "") or ""}],
            "attributes": {},
            "confidence": float(edge.get("confidence", 0.5) or 0.5),
            "rationale": "Add unmatched node by domain knowledge from LLM"
        })

    def edge_exists(src_id: str, rel: str, dst_id: str) -> bool:
        s, r, t = src_id.lower(), rel, dst_id.lower()
        return any(
            str(e.get("source", "")).lower() == s and
            str(e.get("relation", "")) == r and
            str(e.get("target", "")).lower() == t
            for e in kg["edges"]
        )

    for e in edges:
        rel = str(e.get("relation", "")).strip()
        if not rel:
            continue

        src_id = str(e.get("source", "")).strip()
        dst_id = str(e.get("target", "")).strip()
        if not src_id or not dst_id:
            continue

        e_def = str(e.get("definition", "")).strip()
        if not e_def:
            continue

        if not node_exists(src_id):
            add_node_from_edge(src_id, e)
            added_node_counter += 1
        if not node_exists(dst_id):
            add_node_from_edge(dst_id, e)
            added_node_counter += 1

        # Normalize provenance to list[dict] with "text" if needed
        norm_prov = []
        norm_prov.append({"text_excerpt": e.get("rationale", "")})

        if not edge_exists(src_id, rel, dst_id):
            kg["edges"].append({
                "id": str(e.get("id", "")) or f"e_{abs(hash(src_id + '::' + rel + '::' + dst_id))}",
                "source": src_id,
                "target": dst_id,
                "relation": rel,
                "definition": e_def,
                "provenance": norm_prov,
                "confidence": float(e.get("confidence", 0.5) or 0.5),
                "rationale": str(e.get("rationale", "")).strip(),
            })

            added_edge_counter += 1

    return added_node_counter, added_edge_counter

# Pipeline

In [52]:
from pathlib import Path

lecture_json_path = "./data/lecture_notes_8.json" #type=str, required=True, help="Path to lecture notes JSON")
kg_json_path = "./data/kg8.json" #type=str, required=True, help="Path to initial KG JSON")
out_dir = "./data/output" #type=str, required=True, help="Output directory")
#beta = 20.0 #type=float, default=20.0)
beta = 100.0
alpha = "0.2,0.3,0.5" #help="alpha weights for lecture distance (chron,logic,semantic)"
gamma = "0.4,0.6" #help="gamma weights for KG distance (struct,semantic)")
lambda_feat = 0.6 #help="FGW feature balance (alpha in POT)")
max_iterations = 10
convergence_threshold = 0.01
#theta_add = 0.06
theta_add = 0.02
theta_split = 0.35
theta_merge = 0.12
theta_relate = 0.25
sinkhorn_eps = 0.05
sinkhorn_iter = 300
disable_domain_rules = True
no_pot =  False #help="Disable POT even if installed (use proxy)")
seed = 42


try:
    a_chron, a_logic, a_sem = [float(x) for x in alpha.split(",")]
    g_struct, g_sem = [float(x) for x in gamma.split(",")]
except Exception:
    raise SystemExit("Please provide --alpha like '0.2,0.3,0.5' and --gamma like '0.4,0.6'")

hp = HyperParams(
    alpha_chron=a_chron,
    alpha_logic=a_logic,
    alpha_sem=a_sem,
    gamma_struct=g_struct,
    gamma_sem=g_sem,
    lambda_feat=lambda_feat,
    beta=beta,
    theta_add=theta_add,
    theta_split=theta_split,
    theta_merge=theta_merge,
    theta_relate=theta_relate,
    max_iterations=max_iterations,
    convergence_threshold=convergence_threshold,
    sinkhorn_eps=sinkhorn_eps,
    sinkhorn_iter=sinkhorn_iter
)

use_pot = (not no_pot) and _HAS_POT
if use_pot:
    print("Using POT for exact FGW: %s", use_pot)
else:
    print("Use POT proxy.")

ensure_dir(out_dir)

Using POT for exact FGW: %s True


In [25]:
elements = parse_lecture_elements(lecture_json_path)
elements

[LectureElement(idx=2, id_path='1.1', section_path=['__ROOT__', 'Data Science Programming', 'Week 9: Lecture 1: Time Series Data Analysis'], text='apply techiques to time series data', typ='text'),
 LectureElement(idx=3, id_path='2', section_path=['__ROOT__', 'Time Series'], text='time series data is an important form of structured data in many different fields, such', typ='text'),
 LectureElement(idx=4, id_path='2', section_path=['__ROOT__', 'Time Series'], text='as finance, economics, ecology, neuroscience, and physics. anything that is observed', typ='text'),
 LectureElement(idx=5, id_path='2', section_path=['__ROOT__', 'Time Series'], text='or measured at many points in time forms a time series. many time series are fixed', typ='text'),
 LectureElement(idx=6, id_path='2', section_path=['__ROOT__', 'Time Series'], text='frequency, which is to say that data points occur at regular intervals according to some', typ='text'),
 LectureElement(idx=7, id_path='2', section_path=['__ROOT__',

In [26]:
# Build embeddings (fit once on lecture + KG text to align spaces for TF-IDF fallback)
embedder = Embedder()
lecture_texts = [e.text for e in elements]
E_L = embedder.fit_transform(lecture_texts)
D_L, D_chron, D_logic, D_semL = build_lecture_distance(
    elements, E_L, hp.alpha_chron, hp.alpha_logic, hp.alpha_sem
)

Use SentenceTransformer.


## Refinement Starts Here!!

In [27]:
kg_history = {} # store the individual knowledge graph at each step

In [28]:
# Load inputs
#kg = read_json(kg_json_path)
kg = read_kg(kg_json_path)

if "allowed_relations" not in kg.get("meta", {}):
    if "meta" not in kg:
        kg["meta"] = {}
    kg["meta"]["allowed_relations"] = _ALLOWED_RELATIONS_DEFAULT.copy()

allowed_relations = _ALLOWED_RELATIONS_DEFAULT.copy()

In [29]:
len(kg["nodes"]), len(kg['edges'])

(5, 4)

In [30]:
from copy import deepcopy
kg_history[0] = deepcopy(kg)

In [31]:
len(kg_history[0]['nodes']), len(kg_history[0]['edges'])

(5, 4)

In [32]:
D_K, D_struct, D_semK, E_K, node_ids = build_kg_distance(kg, embedder, hp.gamma_struct, hp.gamma_sem)

# Measures
mu = normalized_measure(len(elements))
nu = degree_centrality_measure(kg, node_ids)

Use SentenceTransformer.


In [33]:
# Initial coupling and distance
P, fgw_total, struct_term, feat_term = compute_coupling_and_distance(
    D_L, D_K, E_L, E_K, mu, nu, hp.lambda_feat, hp.sinkhorn_eps, hp.sinkhorn_iter, use_pot
)

Use POT.


In [34]:
best_KG = json.loads(json.dumps(kg))
best_P = P.copy()
best_fgw = fgw_total
best_L = compute_objective(kg, hp.beta, best_fgw)

history = [{
    "iter": 0,
    "rate": rate_complexity(kg),
    "fgw_total": best_fgw,
    "fgw_structure": struct_term,
    "fgw_feature": feat_term,
    "objective": best_L,
    "ops": {"add":0,"split":0,"merge":0,"edge_add":0,"edge_remove":0}
}]

ops_counters = collections.Counter()

In [35]:
history

[{'iter': 0,
  'rate': 7.0,
  'fgw_total': 0.6135204083965918,
  'fgw_structure': 0.13721822574533182,
  'fgw_feature': 1.3279736823734816,
  'objective': 68.35204083965918,
  'ops': {'add': 0, 'split': 0, 'merge': 0, 'edge_add': 0, 'edge_remove': 0}}]

In [36]:
# Iterative refinement
from tqdm import tqdm

import time

for it in tqdm(range(1, hp.max_iterations + 1)):
    improved = False
    local_ops = collections.Counter()

    # --- Operation A: Add new concepts for under-represented content
    # mass over lecture elements
    row_mass = P.sum(axis=1)  # size N (lecture)
    add_candidates = np.where(row_mass < hp.theta_add)[0].tolist()
    random.shuffle(add_candidates)
    added_in_this_iter = 0
    added_edge_in_this_iter = 0
    for i in add_candidates[:5]:  # cap additions per iter
        # connect to nearest existing concept by semantic similarity
        v = E_L[i]
        sims = E_K @ v
        j = int(np.argmax(sims))
        connect_to = node_ids[j]
        added_node, added_edges = add_concept_from_element(
            kg, elements[i], embedder, allowed_relations,
            connect_to=connect_to,
            relation="relatedTo"
        )

        if added_node:
            added_in_this_iter += 1
        if added_edges:
            added_edge_in_this_iter += len(added_edges)

        time.sleep(0.5)

    local_ops["add"] += added_in_this_iter
    local_ops["edge_add"] += added_edge_in_this_iter

    # Rebuild KG side after additions (embeddings, distances, measures)
    D_K, D_struct, D_semK, E_K, node_ids = build_kg_distance(kg, embedder, hp.gamma_struct, hp.gamma_sem)
    nu = degree_centrality_measure(kg, node_ids)

    # --- Operation B: Split concepts with high coupling entropy
    # Recompute coupling to assess splits
    P, fgw_total_tmp, struct_term, feat_term = compute_coupling_and_distance(
        D_L, D_K, E_L, E_K, mu, nu, hp.lambda_feat, hp.sinkhorn_eps, hp.sinkhorn_iter, use_pot
    )
    cols_entropy = [coupling_entropy(P[:, j]) for j in range(P.shape[1])]
    split_targets = [j for j, H in enumerate(cols_entropy) if H > hp.theta_split]
    random.shuffle(split_targets)
    splits_done = 0
    for j in split_targets[:2]:  # cap splits per iter
        A, B, a_idx, b_idx = kmeans_split_concept(P[:, j], E_L, elements, kg, j)
        if A is None:
            continue
        # Commit split: remove old node j and add A,B
        old_id = kg["nodes"][j]["id"]
        # Add A, B
        kg["nodes"].append(dataclasses.asdict(A))
        kg["nodes"].append(dataclasses.asdict(B))
        # Rewire edges connected to old_id: duplicate to A and B
        new_edges = []
        for e in kg["edges"]:
            if e["source"] == old_id:
                for nid in [A.id, B.id]:
                    new_e = dict(e); new_e["source"] = nid; new_edges.append(new_e)
            elif e["target"] == old_id:
                for nid in [A.id, B.id]:
                    new_e = dict(e); new_e["target"] = nid; new_edges.append(new_e)
            else:
                new_edges.append(e)
        kg["edges"] = new_edges
        # Remove old node
        del kg["nodes"][j]
        splits_done += 1
    local_ops["split"] += splits_done

    # --- Operation C: Merge redundant concept pairs
    # Rebuild KG matrices after splits
    D_K, D_struct, D_semK, E_K, node_ids = build_kg_distance(kg, embedder, hp.gamma_struct, hp.gamma_sem)
    nu = degree_centrality_measure(kg, node_ids)
    P, fgw_total_tmp, struct_term, feat_term = compute_coupling_and_distance(
        D_L, D_K, E_L, E_K, mu, nu, hp.lambda_feat, hp.sinkhorn_eps, hp.sinkhorn_iter, use_pot
    )
    to_merge = []
    m = len(node_ids)
    for i in range(m):
        for j in range(i+1, m):
            if merge_if_redundant(i, j, P, E_K, sim_thresh=0.92, kl_thresh=hp.theta_merge):
                to_merge.append((i, j))
    # Commit merges (greedy, no transitive handling beyond sequential)
    merges_done = 0
    merged_ids = set()
    for (i, j) in to_merge[:2]:  # cap merges per iter
        if i in merged_ids or j in merged_ids:
            continue
        id_i = kg["nodes"][i]["id"]; id_j = kg["nodes"][j]["id"]
        # Merge j into i: keep i's label/def; absorb aliases
        ni = kg["nodes"][i]; nj = kg["nodes"][j]
        ali = set(ni.get("aliases", []) + [nj.get("label", "")] + nj.get("aliases", []))
        ni["aliases"] = sorted([a for a in ali if a])
        # Rewire edges from j to i
        for e in kg["edges"]:
            if e["source"] == id_j: e["source"] = id_i
            if e["target"] == id_j: e["target"] = id_i
        # Remove node j
        del kg["nodes"][j]
        merges_done += 1
        merged_ids.add(i); merged_ids.add(j)
    local_ops["merge"] += merges_done

    # --- Operation D: Relationship updates (add/rewire/remove)
    # Add edges based on coupling + D_L proximity
    D_K, D_struct, D_semK, E_K, node_ids = build_kg_distance(kg, embedder, hp.gamma_struct, hp.gamma_sem)
    nu = degree_centrality_measure(kg, node_ids)
    P, fgw_total_tmp, struct_term, feat_term = compute_coupling_and_distance(
        D_L, D_K, E_L, E_K, mu, nu, hp.lambda_feat, hp.sinkhorn_eps, hp.sinkhorn_iter, use_pot
    )

    # Add relations where coupled lecture elements are near
    edge_adds = 0
    N, M = P.shape
    # For each concept j, find top lecture elements and propose relations with 
    # other concepts k
    for j in range(M):
        top_i = np.argsort(-P[:, j])[:5]
        for k in range(M):
            if j == k:
                continue
            # If top elements for j are close to top elements for k in lecture space, 
            # relate
            top_k = np.argsort(-P[:, k])[:5]
            # Compute average cross-distance
            pairs = list(itertools.product(top_i, top_k))
            if not pairs:
                continue
            avg_d = float(np.mean([D_L[i1, i2] for i1, i2 in pairs]))
            if avg_d < hp.theta_relate:
                a = kg["nodes"][j]["id"]; b = kg["nodes"][k]["id"]
                sig = (a, b, "relatedTo")
                if not any((e["source"], e["target"], e["relation"]) == sig for e in kg["edges"]):
                    # get edge definition
                    e_def = get_edge_definition(kg, a, b, "relatedTo")

                    e_rationale = get_edge_rationale(kg, a, b, "relatedTo")
                    
                    kg["edges"].append({
                        "id": f"rel_{a}_{b}",
                        "source": a,
                        "target": b,
                        "relation": "relatedTo",
                        "definition": e_def,
                        "provenance": [],
                        "confidence": 0.6,
                        "rationale": e_rationale
                    })
                    edge_adds += 1
    local_ops["edge_add"] += edge_adds

    # Remove weak edges not supported by coupling
    edge_removes = 0
    keep_edges = []
    for e in kg["edges"]:
        try:
            a = node_ids.index(e["source"])
            b = node_ids.index(e["target"])
        except ValueError:
            # Node removed; drop edge
            edge_removes += 1
            continue
        # Support measure ~ product of marginals around a,b
        supp = float(P[:, a].sum() * P[:, b].sum())
        if supp < 1e-4:
            edge_removes += 1
            continue
        keep_edges.append(e)
    kg["edges"] = keep_edges
    local_ops["edge_remove"] += edge_removes

    # Domain rules (optional)
    #if not disable_domain_rules:
    #    added = apply_domain_rules(kg, allowed_relations)
    #    local_ops["edge_add"] += len(added)

    # Evaluate objective
    D_K, D_struct, D_semK, E_K, node_ids = build_kg_distance(kg, embedder, hp.gamma_struct, hp.gamma_sem)
    nu = degree_centrality_measure(kg, node_ids)
    P, fgw_total, struct_term, feat_term = compute_coupling_and_distance(
        D_L, D_K, E_L, E_K, mu, nu, hp.lambda_feat, hp.sinkhorn_eps, hp.sinkhorn_iter, use_pot
    )
    r = rate_complexity(kg)
    L = r + hp.beta * fgw_total

    history.append({
        "iter": it,
        "rate": r,
        "fgw_total": fgw_total,
        "fgw_structure": struct_term,
        "fgw_feature": feat_term,
        "objective": L,
        "ops": dict(local_ops)
    })

    kg_history[it] = deepcopy(kg)
    print("The KG sizes at iteration {}: nodes:{}, edges:{}".format(it, 
                                                                    len(kg_history[it]['nodes']),
                                                                    len(kg_history[it]['edges'])))

    if L + 1e-9 < best_L:
        print("Reassign best values.")
        best_L = L
        best_KG = json.loads(json.dumps(kg))
        best_P = P.copy()
        best_fgw = fgw_total
        improved = True
        for k, v in local_ops.items():
            ops_counters[k] += v
    else:
        print("best values remain.")
        for k, v in local_ops.items():
            ops_counters[k] += v

    if not improved:
        # If no improvement, consider stopping early
        if abs(history[-1]["objective"] - history[-2]["objective"]) < hp.convergence_threshold:
            break


  0%|                                                                                                                            | 0/10 [00:00<?, ?it/s]

Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use SentenceTransformer.
Use POT.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.


 10%|███████████▌                                                                                                        | 1/10 [00:14<02:11, 14.65s/it]

Use POT.
The KG sizes at iteration 1: nodes:12, edges:16
Reassign best values.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use SentenceTransformer.
Use POT.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.


 20%|███████████████████████▏                                                                                            | 2/10 [00:30<02:01, 15.14s/it]

Use POT.
The KG sizes at iteration 2: nodes:19, edges:33
best values remain.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use SentenceTransformer.
Use POT.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.
Use POT.


 30%|██████████████████████████████████▊                                                                                 | 3/10 [00:53<02:11, 18.80s/it]

The KG sizes at iteration 3: nodes:26, edges:48
best values remain.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use SentenceTransformer.
Use POT.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.


 40%|██████████████████████████████████████████████▍                                                                     | 4/10 [01:20<02:12, 22.16s/it]

Use POT.
The KG sizes at iteration 4: nodes:33, edges:65
best values remain.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use SentenceTransformer.
Use POT.
Use LLM to name new concept.
Use LLM to name new concept.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.


 50%|██████████████████████████████████████████████████████████                                                          | 5/10 [01:40<01:46, 21.27s/it]

Use POT.
The KG sizes at iteration 5: nodes:39, edges:83
best values remain.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use SentenceTransformer.
Use POT.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.


 60%|█████████████████████████████████████████████████████████████████████▌                                              | 6/10 [02:05<01:30, 22.64s/it]

Use POT.
The KG sizes at iteration 6: nodes:46, edges:113
best values remain.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use SentenceTransformer.
Use POT.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.


 70%|█████████████████████████████████████████████████████████████████████████████████▏                                  | 7/10 [02:40<01:20, 26.78s/it]

Use POT.
The KG sizes at iteration 7: nodes:52, edges:139
best values remain.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use SentenceTransformer.
Use POT.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.


 80%|████████████████████████████████████████████████████████████████████████████████████████████▊                       | 8/10 [03:07<00:53, 26.88s/it]

Use POT.
The KG sizes at iteration 8: nodes:58, edges:176
best values remain.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use SentenceTransformer.
Use POT.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.


 90%|████████████████████████████████████████████████████████████████████████████████████████████████████████▍           | 9/10 [03:35<00:26, 26.97s/it]

Use POT.
The KG sizes at iteration 9: nodes:64, edges:217
best values remain.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use LLM to name new concept.
Use LLM to add new edges for a new concept.
Use SentenceTransformer.
Use POT.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use LLM to name new concept.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.
Use POT.
Use SentenceTransformer.


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [04:00<00:00, 24.01s/it]

Use POT.
The KG sizes at iteration 10: nodes:69, edges:245
best values remain.





In [37]:
ops_counters

Counter({'edge_add': 114,
         'add': 50,
         'edge_remove': 27,
         'split': 19,
         'merge': 5})

In [38]:
len(kg['nodes']), len(kg['edges'])

(69, 245)

In [39]:
#+++++++++++++++++
# Domain knowledge graph refinement
#
#+++++++++++++++++
added_edges_domain_refinement = apply_domain_knowledge_LLM(kg, allowed_relations)

added_node_counter, added_edge_counter = add_refined_domain_edges(kg, added_edges_domain_refinement)


In [40]:
added_node_counter, added_edge_counter

(0, 35)

In [41]:
#+++++++++++++++++
# Domain knowledge graph refinement
# Update history
#+++++++++++++++++

it = it + 1

local_ops = collections.Counter()
local_ops['add'] += added_node_counter
local_ops['edge_add'] += added_edge_counter

# Evaluate objective
D_K, D_struct, D_semK, E_K, node_ids = build_kg_distance(kg, embedder, hp.gamma_struct, hp.gamma_sem)
nu = degree_centrality_measure(kg, node_ids)

P, fgw_total, struct_term, feat_term = compute_coupling_and_distance(
    D_L, D_K, E_L, E_K, mu, nu, hp.lambda_feat, hp.sinkhorn_eps, hp.sinkhorn_iter, use_pot
)

r = rate_complexity(kg)
L = r + hp.beta * fgw_total

history.append({
    "iter": it,
    "rate": r,
    "fgw_total": fgw_total,
    "fgw_structure": struct_term,
    "fgw_feature": feat_term,
    "objective": L,
    "ops": dict(local_ops)
})

for k, v in local_ops.items():
    ops_counters[k] += v

Use SentenceTransformer.
Use POT.


In [42]:
ops_counters

Counter({'edge_add': 149,
         'add': 50,
         'edge_remove': 27,
         'split': 19,
         'merge': 5})

In [43]:
# Build outcome
outcome = RefinementOutcome(
    iterations=len(history)-1,
    final_objective=float(best_L),
    rate=float(rate_complexity(best_KG)),
    distortion=float(best_fgw),   # report main FGW term as 'distortion'
    fgw_distance=float(best_fgw),
    history=history,
    operations={
        "concepts_added": int(ops_counters.get("add", 0)),
        "merged": int(ops_counters.get("merge", 0)),
        "concepts_split": int(ops_counters.get("split", 0)),
        "relationships_added": int(ops_counters.get("edge_add", 0)),
        "relationships_removed": int(ops_counters.get("edge_remove", 0)),
    }
)

In [44]:
outcome.operations

{'concepts_added': 50,
 'merged': 5,
 'concepts_split': 19,
 'relationships_added': 149,
 'relationships_removed': 27}

In [45]:
# Coupling analysis (top-k)
# For interpretability, recompute matrices for best_KG
D_K, D_struct, D_semK, E_K, node_ids = build_kg_distance(best_KG, embedder, hp.gamma_struct, hp.gamma_sem)
nu = degree_centrality_measure(best_KG, node_ids)
P, best_fgw, struct_term, feat_term = compute_coupling_and_distance(
    D_L, D_K, E_L, E_K, mu, nu, hp.lambda_feat, hp.sinkhorn_eps, hp.sinkhorn_iter, use_pot
)
coupling_report = []

topk_align = 12

N, M = P.shape
for i in range(N):
    topj = np.argsort(-P[i, :])[:min(topk_align, M)]
    row = []
    for j in topj:
        row.append({
            "lecture_idx": int(i),
            "lecture_text": elements[i].text[:240],
            "concept_id": node_ids[j],
            "concept_label": next((n["label"] for n in best_KG["nodes"] if n["id"] == node_ids[j]), node_ids[j]),
            "weight": float(P[i, j])
        })
    coupling_report.append(row)

Use SentenceTransformer.
Use POT.


In [53]:
# Save artifacts
file_name = "objective_rate_distortion_curve.png"
rd_curve_path = os.path.join(out_dir, os.path.basename(lecture_json_path).replace(".json", f"_{file_name}"))
rd_curve_path

'./data/output/lecture_notes_8_objective_rate_distortion_curve.png'

In [54]:

try:
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    # One chart, no styles/colors
    xs = [h["iter"] for h in history]
    rates = [h["rate"] for h in history]
    dists = [h["fgw_total"] for h in history]
    objs = [h["objective"] for h in history]
    plt.figure()
    plt.plot(xs, objs, label="Objective L")
    plt.plot(xs, rates, label="Rate r")
    plt.plot(xs, dists, label="Distortion d (FGW)")
    plt.xlabel("Iteration")
    plt.ylabel("Value")
    plt.legend()
    plt.tight_layout()
    plt.savefig(rd_curve_path, dpi=160)
    plt.close()
except Exception as e:
    logging.warning("Could not generate RD curve plot: %s", e)
    rd_curve_path = None


In [55]:
# Package refined KG meta
if "meta" not in best_KG:
    best_KG["meta"] = {}
best_KG["meta"].setdefault("refinement", {})
best_KG["meta"]["refinement"].update({
    "iterations": outcome.iterations,
    "final_objective": outcome.final_objective,
    "rate": outcome.rate,
    "distortion": outcome.distortion,
    "fgw_distance": outcome.fgw_distance,
    "used_pot_fgw": bool(use_pot and _HAS_POT),
    "alpha": {
        "chronological": hp.alpha_chron,
        "logical": hp.alpha_logic,
        "semantic": hp.alpha_sem
    },
    "gamma": {
        "structural": hp.gamma_struct,
        "semantic": hp.gamma_sem
    },
    "lambda_feature_balance": hp.lambda_feat,
    "beta_rate_distortion": hp.beta,
    "thresholds": {
        "theta_add": hp.theta_add,
        "theta_split": hp.theta_split,
        "theta_merge": hp.theta_merge,
        "theta_relate": hp.theta_relate
    }
})

# Plot
- Computes rate R = nodes + 0.5 * edges.
- Computes objective L = R + beta * D for a user-chosen beta (default 20).
- Finds the geometric knee (max perpendicular distance to the line between the
  first and last points in (D, R) space).
- Writes a CSV of the iteration history and saves a PNG plot.

In [56]:
@dataclass
class RDHistory:
    iters: np.ndarray       # iteration indices
    D: np.ndarray           # distortion per iteration
    R: np.ndarray

    #nodes: np.ndarray       # node count per iteration
    #edges: np.ndarray       # edge count per iteration


def objective_L(R: np.ndarray, D: np.ndarray, beta: float) -> np.ndarray:
    return R + beta * D


def knee_index_max_distance(D: np.ndarray, R: np.ndarray) -> int:
    """
    Knee = point of maximum perpendicular distance to the chord from the
    first to last point in (D, R) space.
    """
    A = np.array([D[0], R[0]], dtype=float)
    B = np.array([D[-1], R[-1]], dtype=float)
    AB = B - A
    ab_norm = np.linalg.norm(AB)
    if ab_norm == 0:
        return 0

    def dist_to_line(p: np.ndarray) -> float:
        # In 2D, || (B - A) x (P - A) || / ||B - A||
        return np.abs(np.cross(AB, p - A)) / ab_norm

    distances = np.array([dist_to_line(np.array([D[i], R[i]], dtype=float)) for i in range(len(D))])
    return int(np.argmax(distances))


def beta_threshold_vs_seed(R: np.ndarray, D: np.ndarray) -> np.ndarray:
    """
    For each iteration i, solve R_i + beta*D_i = R_0 + beta*D_0 for beta.
    beta_i = (R0 - Ri) / (Di - D0) = (Ri - R0) / (D0 - Di)
    """
    R0, D0 = R[0], D[0]
    betas = []
    for i in range(len(D)):
        if i == 0:
            betas.append(np.nan)
            continue
        denom = (D0 - D[i])
        betas.append((R[i] - R0) / denom if denom != 0 else np.inf)
    return np.array(betas, dtype=float)


def build_dataframe(hist: RDHistory, beta: float) -> pd.DataFrame:
    R = hist.R
    L = objective_L(R, hist.D, beta)
    beta_tie = beta_threshold_vs_seed(R, hist.D)
    df = pd.DataFrame({
        "iter": hist.iters,
        "distortion_D": hist.D,
        "rate_R": R,
        #"nodes": hist.nodes,
        #"edges": hist.edges,
        f"objective_L(beta={beta:g})": L,
        "beta_threshold_vs_seed": beta_tie
    })
    return df


def plot_rd_curve(
    hist: RDHistory,
    beta: float,
    savepath: str,
    show_iso_L: bool = False,
    iso_L_count: int = 4
) -> Tuple[int, int]:
    """
    Plot R vs D, annotate iterations, knee, and min-L point (for given beta).
    Optionally overlay iso-L lines.
    """
    D, R = hist.D, hist.R
    L = objective_L(R, D, beta)
    knee_idx = knee_index_max_distance(D, R)
    lopt_idx = int(np.argmin(L))

    plt.figure(figsize=(6.0, 5.0))
    # R–D curve
    plt.plot(D, R, marker='o')  # default styling

    # annotate points t0..tn
    for i in range(len(D)):
        plt.annotate(f"t{i}", (D[i], R[i]), xytext=(5, 5),
                     textcoords='offset points', fontsize=8)

    # highlight knee and L-optimum
    plt.scatter([D[knee_idx]], [R[knee_idx]], marker='s', s=90, label=f"Knee ~ t{knee_idx}")
    plt.scatter([D[lopt_idx]], [R[lopt_idx]], marker='*', s=160, label=f"Min L (β={beta:g}) ~ t{lopt_idx}")

    # optional: iso-L lines through a few L values
    if show_iso_L:
        # pick a span of L values centered around median
        Lmin, Lmax = float(np.min(L)), float(np.max(L))
        Lvals = np.linspace(Lmin, Lmax, num=iso_L_count)
        x_span = np.linspace(float(np.min(D)) - 0.02, float(np.max(D)) + 0.02, 100)
        for Lv in Lvals:
            # R = L - beta * D
            y_line = Lv - beta * x_span
            plt.plot(x_span, y_line, linestyle='--', linewidth=0.8)

    plt.xlabel("Distortion D")
    plt.ylabel("Rate R")
    plt.title("Rate–Distortion Curve (D on x, R on y)")
    plt.grid(True)
    plt.legend(loc="best")
    plt.tight_layout()
    plt.savefig(savepath, dpi=150)
    plt.close()
    return knee_idx, lopt_idx


In [57]:
plot_file_name = "rate_distortion_plot_curve.png" #help="Base filename (without extension) for outputs")
iso = True # help="Overlay iso-L lines on the plot")

In [58]:
iters = []
D = []
R = []
for h in history:
    iters.append(h['iter'])
    R.append(h['rate'])
    D.append(h['fgw_total'])

In [59]:
hist = RDHistory(np.array(iters), np.array(D), np.array(R))

In [60]:
df = build_dataframe(hist, beta)

In [61]:
df

Unnamed: 0,iter,distortion_D,rate_R,objective_L(beta=100),beta_threshold_vs_seed
0,0,0.61352,7.0,68.352041,
1,1,0.474801,20.0,67.480114,93.714448
2,2,0.432905,35.5,78.790532,157.794124
3,3,0.410908,50.0,91.090846,212.228348
4,4,0.396742,65.5,105.174217,269.86103
5,5,0.372485,80.5,117.748481,304.934217
6,6,0.377586,102.5,140.258605,404.773604
7,7,0.365558,121.5,158.055805,461.763638
8,8,0.358528,146.0,181.852827,545.114835
9,9,0.347876,172.5,207.287637,623.014166


### Save Rate_Distortion History

In [62]:
history_name = "rate_distortion_history.csv"
history_path = os.path.join(out_dir, os.path.basename(lecture_json_path).replace(".json", f"_{history_name}"))
history_path

'./data/output/lecture_notes_8_rate_distortion_history.csv'

In [63]:
df.to_csv(history_path)

### Save KG History and KG_Knee to Files

In [64]:
png_path = os.path.join(out_dir, os.path.basename(lecture_json_path).replace(".json", f"_{plot_file_name}"))

In [65]:
png_path

'./data/output/lecture_notes_8_rate_distortion_plot_curve.png'

In [66]:
knee_idx, lopt_idx = plot_rd_curve(hist, beta, png_path, show_iso_L=iso, iso_L_count=5)

In [67]:
knee_idx, lopt_idx

(5, 1)

In [68]:
# Console summary
R, D = hist.R, hist.D
L = objective_L(R, D, beta)
print("\n=== Rate–Distortion Summary ===")
print(f"beta (β): {beta:g}")
print(f"Knee index (max distance): t{knee_idx}  [D={D[knee_idx]:.3f}, R={R[knee_idx]:.3f}]")
print(f"Min L index:               t{lopt_idx}  [D={D[lopt_idx]:.3f}, R={R[lopt_idx]:.3f}, L={L[lopt_idx]:.3f}]")

# Optional quick thresholds report
beta_tie = beta_threshold_vs_seed(R, D)
print("\nβ thresholds to tie seed (t0) in L (higher → favors higher-fidelity iterations):")
for i, b in enumerate(beta_tie):
    if i == 0:
        print(f"  t{i}: —")
    else:
        if math.isfinite(b):
            print(f"  t{i}: β ≈ {b:.2f}")
        else:
            print(f"  t{i}: β = ∞ (no finite tie)")


=== Rate–Distortion Summary ===
beta (β): 100
Knee index (max distance): t5  [D=0.372, R=80.500]
Min L index:               t1  [D=0.475, R=20.000, L=67.480]

β thresholds to tie seed (t0) in L (higher → favors higher-fidelity iterations):
  t0: —
  t1: β ≈ 93.71
  t2: β ≈ 157.79
  t3: β ≈ 212.23
  t4: β ≈ 269.86
  t5: β ≈ 304.93
  t6: β ≈ 404.77
  t7: β ≈ 461.76
  t8: β ≈ 545.11
  t9: β ≈ 623.01
  t10: β ≈ 732.72
  t11: β ≈ 806.09


In [69]:
# The KG at knee point
kg_knee = kg_history[knee_idx]

In [70]:
len(kg_knee['nodes']), len(kg_knee['edges'])

(39, 83)

In [71]:
history_file_name = "kg_history.json"
kg_history_path = os.path.join(out_dir, os.path.basename(lecture_json_path).replace(".json", f"_{history_file_name}"))

In [72]:
kg_history_path

'./data/output/lecture_notes_8_kg_history.json'

In [73]:
# Save to disk
with open(kg_history_path, "w") as f:
    json.dump(kg_history, f, indent=2)

In [74]:
knee_file_name = "kg_knee.json"
kg_knee_path = os.path.join(out_dir, os.path.basename(lecture_json_path).replace(".json", f"_{knee_file_name}"))

In [75]:
kg_knee_path

'./data/output/lecture_notes_8_kg_knee.json'

In [76]:
# Save to disk
with open(kg_knee_path, "w") as f:
    json.dump(kg_knee, f, indent=2)

## Checking Coverage of KG vs. Lecture 

In [77]:
# Given lecture elements and KG, evaluate the KG's coverage
# within tolerated feature distortion
def kg_coverage_with_feature_distortion(elements, kg, embedder, hp, us_pot, thresh):
    # build lecture metric-measure space 
    lecture_texts = [e.text for e in elements]
    E_L = embedder.fit_transform(lecture_texts)
    D_L, D_chron, D_logic, D_semL = build_lecture_distance(
        elements, E_L, hp.alpha_chron, hp.alpha_logic, hp.alpha_sem
    )

    # build KG metri-measure space
    D_K, D_struct, D_semK, E_K, node_ids = build_kg_distance(kg, 
                                                             embedder, 
                                                             hp.gamma_struct, hp.gamma_sem)

    # build lecture elements and kg nodes distributions
    mu = normalized_measure(len(elements))
    nu = degree_centrality_measure(kg, node_ids)

    # compute coupling
    P, fgw_total_tmp, struct_term, feat_term = compute_coupling_and_distance(
        D_L, D_K, E_L, E_K, mu, nu, hp.lambda_feat, hp.sinkhorn_eps, hp.sinkhorn_iter, use_pot
    )

    # compute feature distortion
    M_feat = cdist(E_L, E_K, metric="sqeuclidean")

    # compute feature distortion tolerance
    tau = 0
    feat_ravel = np.asarray(M_feat, dtype=float).ravel()
    feat_ravel = feat_ravel[np.isfinite(feat_ravel)]
    if feat_ravel.size != 0:
        tau = float(np.quantile(feat_ravel, thresh))

    # compute coverage
    good = (M_feat <= tau).astype(float)

    per_pair = P * good
    return per_pair.sum()  # already in [0,1]

In [78]:
lecture_json_path = "./data/lecture_notes_8.json"
kg_json_path = "./data/kg8.json"
kg_knee_path = "./data/output/lecture_notes_8_kg_knee.json"

In [79]:
elements = parse_lecture_elements(lecture_json_path)

In [80]:
kg = read_kg(kg_json_path)

In [81]:
kg_cov = kg_coverage_with_feature_distortion(elements, kg, embedder, hp, True, 0.3)
kg_cov

Use SentenceTransformer.
Use SentenceTransformer.
Use POT.


0.5306603773584907

In [82]:
kg_knee_file = read_json(kg_knee_path)

In [83]:
kg_knee_cov = kg_coverage_with_feature_distortion(elements, kg_knee_file, embedder, hp, True, 0.3)
kg_knee_cov

Use SentenceTransformer.
Use SentenceTransformer.
Use POT.


0.896112752898386

In [84]:
lecture_list = []
kg_cov_list  = []
kg_knee_cov_list = []
tau_list = []

In [85]:
lecture_list.append("lecture_notes_8")
kg_cov_list.append(kg_cov)
kg_knee_cov_list.append(kg_knee_cov)
tau_list.append("30-th quantile of feature distortion")

In [86]:
coverage_pd = pd.DataFrame({"lecture":lecture_list, "initial_kg_coverage":kg_cov_list,
              "knee_kg_coverage":kg_knee_cov_list, "feature_distortion_tolerance":tau_list})
coverage_pd

Unnamed: 0,lecture,initial_kg_coverage,knee_kg_coverage,feature_distortion_tolerance
0,lecture_notes_8,0.53066,0.896113,30-th quantile of feature distortion


In [87]:
coverage_pd.to_csv("./data/output/kg_coverages.csv", index=None)