## 1. Select training points that are rich in connections

In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os, math, json, pickle, random
from collections import deque
from typing import Dict, Any, List
import networkx as nx
import numpy as np
from tqdm import tqdm

# Input graph pickle and the directory that receives the split files and richness stats.
# NOTE(review): the path-group export step later in this file reads the splits from
# "../graph_retrieval_data" — confirm both paths point at the same directory.
GRAPH_PKL  = "../graph_2022.pkl"
OUT_DIR    = "../path_selector/graph_retrieval_data"

# Number of top-ranked companies written to each split, in ranking order.
TRAIN_SIZE = 1400
VALID_SIZE = 400
TEST_SIZE  = 200

RANDOM_SEED              = 42     # seeds both `random` and `numpy.random` in main()
FAST_SAMPLE_ENABLED      = True   # run the Stage A pre-filter; otherwise take a plain random sample
FAST_SAMPLE_OVERSAMPLE   = 1.35   # Stage A collects need_total * this many candidates
FAST_MIN_DEG1            = 1      # Stage A skips nodes with fewer direct neighbors than this
FAST_ACCEPT_RATIO        = 0.7    # a node is accepted when fast_score >= threshold * this
FAST_PERCENTILE          = 0.50   # batch percentile (as a fraction) used as the threshold
FAST_BATCH_SIZE          = 500    # number of fast scores gathered before the threshold is refreshed
MAX_FAST_ITER_MULTIPLIER = 8      # Stage A inspects at most need_total * this many nodes

MAX_HOPS    = 4      # BFS depth of the neighborhood used for the richness metrics
LAMBDA_TIME = 0.01   # decay rate in exp(-LAMBDA_TIME * max(T_REF - edge_date, 0))
T_REF       = 192    # reference time for the decay term (unit not visible here — TODO confirm)

# Weights of the richness score terms in compute_metrics().
W_DEG1    = 0.4   # log(1 + direct degree)
W_EDGES   = 0.3   # log(1 + edges inside the neighborhood)
W_NODES   = 0.2   # log(1 + nodes inside the neighborhood)
W_ENTROPY = 0.1   # company/investor type entropy
W_DECAY   = 0.05  # summed time-decay of dated edges

# Weights of the cheap Stage A score in fast_score().
F_W_DEG1       = 0.5   # direct degree
F_W_COMP_PROP  = 0.2   # share of company nodes among direct neighbors
F_W_TWOHOP     = 0.3   # precomputed two-hop neighborhood size

ENABLE_DIVERSITY_POST   = False   # when True, main() applies diversity_postprocess() to the ranking
INV_OVERLAP_THRESHOLD   = 0.7     # a node is kept only if its investor overlap ratio is below this

MAX_SUBGRAPH_NODES = 5000   # the BFS in k_hop_subgraph_nodes() stops once this many nodes are visited

def is_company(g, node_key) -> bool:
    """Tell whether ``node_key`` refers to a company node.

    A node is treated as a company when the string form of its ``'id'``
    attribute does not end with ``'P'``.

    Raises:
        ValueError: if the node's ``'id'`` attribute is missing or ``None``.
    """
    attrs = g.nodes[node_key]
    identifier = attrs.get('id')
    if identifier is not None:
        return not str(identifier).endswith('P')
    raise ValueError(f"Node {node_key} missing 'id' attribute")

def k_hop_subgraph_nodes(g: nx.Graph, root, max_hops: int):
    """Collect the nodes reachable from ``root`` within ``max_hops`` hops (BFS).

    The search returns early as soon as the visited set holds
    ``MAX_SUBGRAPH_NODES`` nodes, so the result may be a truncated neighborhood.

    Returns:
        The set of visited node keys, always containing ``root``.
    """
    reached = {root}
    queue = deque()
    queue.append((root, 0))
    while queue:
        node, depth = queue.popleft()
        if depth != max_hops:
            next_depth = depth + 1
            for nbr in g.neighbors(node):
                if nbr in reached:
                    continue
                reached.add(nbr)
                # Hard cap on the neighborhood size: stop expanding immediately.
                if len(reached) >= MAX_SUBGRAPH_NODES:
                    return reached
                queue.append((nbr, next_depth))
    return reached

def type_entropy(comp_cnt: int, inv_cnt: int) -> float:
    """Return the natural-log entropy of the company/investor count split.

    Zero counts contribute nothing; an empty neighborhood yields ``0.0``.
    A tiny epsilon is added inside the logarithm.
    """
    total = comp_cnt + inv_cnt
    if total == 0:
        return 0.0
    entropy = 0.0
    for count in (comp_cnt, inv_cnt):
        if count <= 0:
            continue
        share = count / total
        entropy -= share * math.log(share + 1e-12)
    return entropy

def compute_metrics(g: nx.Graph, node_id, is_company_flag) -> Dict[str, Any]:
    """Compute the richness metrics of the ``MAX_HOPS`` neighborhood around ``node_id``.

    Args:
        g: graph to inspect; node keys must be mutually orderable because each
            undirected edge is counted once via the ``u > v`` test.
        node_id: key of the root node.
        is_company_flag: mapping node key -> truthy when the node is a company.

    Returns:
        A dict with the keys ``node_key``, ``deg1``, ``hop_nodes``, ``hop_edges``,
        ``investor_ratio``, ``type_entropy``, ``avg_edge_date`` (``-1.0`` when no
        edge carries an ``edge_date``), ``time_decay_sum`` and ``richness_score``.
    """
    sub_nodes = k_hop_subgraph_nodes(g, node_id, MAX_HOPS)

    # Single pass over the induced edges: count them and accumulate the
    # time-decay statistics at the same time (previously two identical walks).
    edge_cnt = 0
    decay_sum = 0.0
    edge_dates = []
    for v in sub_nodes:
        for u in g.neighbors(v):
            if u not in sub_nodes or not (u > v):
                continue
            edge_cnt += 1
            data = g.get_edge_data(v, u)
            ed = data.get("edge_date") if isinstance(data, dict) else None
            if ed is not None:
                edge_dates.append(ed)
                # Edges dated after T_REF are clamped to zero age.
                decay_sum += math.exp(-LAMBDA_TIME * max(T_REF - ed, 0))
    avg_edge_date = float(np.mean(edge_dates)) if edge_dates else -1.0

    comp_cnt = sum(is_company_flag[n] for n in sub_nodes)
    inv_cnt = len(sub_nodes) - comp_cnt
    ent = type_entropy(comp_cnt, inv_cnt)

    # Prefer the precomputed neighbor cache; fall back to the graph so the
    # function also works when precompute_fast_features() has not been run.
    if node_id in neighbors:
        deg1 = len(neighbors[node_id])
    else:
        deg1 = sum(1 for _ in g.neighbors(node_id))

    richness = (
        math.log(1 + deg1)           * W_DEG1 +
        math.log(1 + edge_cnt)       * W_EDGES +
        math.log(1 + len(sub_nodes)) * W_NODES +
        ent * W_ENTROPY +
        decay_sum * W_DECAY
    )

    return dict(
        node_key=node_id,
        deg1=deg1,
        hop_nodes=len(sub_nodes),
        hop_edges=edge_cnt,
        investor_ratio=inv_cnt / max(1, len(sub_nodes)),
        type_entropy=ent,
        avg_edge_date=avg_edge_date,
        time_decay_sum=decay_sum,
        richness_score=richness
    )

def diversity_postprocess(g: nx.Graph, metrics: List[Dict[str, Any]], need_total: int):
    """Greedily pick metric records whose investor neighbors are not yet well covered.

    Records are visited in the given order. A record is kept when the share of
    its investor neighbors (neighbors whose ``'id'`` ends with ``'P'``) that was
    already seen is below ``INV_OVERLAP_THRESHOLD``. Records without any
    investor neighbor are skipped.

    Args:
        g: graph providing the neighbors and node ``'id'`` attributes.
        metrics: records as produced by ``compute_metrics`` (the node key is read
            from ``"node_key"``; a legacy ``"id"`` field is accepted as fallback).
        need_total: maximum number of records to select.

    Returns:
        The selected records, at most ``need_total`` of them.
    """
    selected = []
    seen_investors = set()
    for m in metrics:
        if len(selected) >= need_total:
            break
        # compute_metrics() stores the graph key under "node_key"; the previous
        # lookup of m["id"] raised KeyError for every record.
        node_key = m["node_key"] if "node_key" in m else m["id"]
        inv_neigh = {
            nbr for nbr in g.neighbors(node_key)
            if str(g.nodes[nbr].get('id', '')).endswith('P')
        }
        if not inv_neigh:
            continue
        overlap_ratio = len(seen_investors & inv_neigh) / len(inv_neigh)
        if overlap_ratio < INV_OVERLAP_THRESHOLD:
            selected.append(m)
            seen_investors |= inv_neigh
    return selected

# Module-level caches filled by precompute_fast_features() and read by the scoring helpers.
neighbors = {}         # node key -> list of neighbor node keys
is_company_flag = {}   # node key -> True when the node's 'id' does not end with 'P'
two_hop_size = {}      # company node key -> number of distinct nodes among its neighbors and their neighbors

def precompute_fast_features(g, company_nodes):
    """Fill the module-level ``neighbors``, ``is_company_flag`` and ``two_hop_size`` caches.

    Every node of ``g`` gets its neighbor list and company flag; only the nodes
    in ``company_nodes`` get a two-hop neighborhood size.
    """
    global neighbors, is_company_flag, two_hop_size
    for node in g.nodes():
        neighbors[node] = list(g.neighbors(node))
        node_id_text = str(g.nodes[node].get('id', ''))
        is_company_flag[node] = not node_id_text.endswith('P')
    for company in company_nodes:
        first_hop = neighbors[company]
        reachable = set(first_hop)
        for mid in first_hop:
            # Second-hop nodes, excluding the walk back to the company itself.
            reachable.update(far for far in neighbors[mid] if far != company)
        two_hop_size[company] = len(reachable)

def fast_score(n):
    """Return the cheap Stage A score of node ``n`` from the precomputed caches.

    The score is a weighted sum of the direct degree, the two-hop neighborhood
    size and the share of company nodes among the direct neighbors; isolated
    nodes score ``0.0``.
    """
    adjacent = neighbors[n]
    deg1 = len(adjacent)
    if not deg1:
        return 0.0
    company_neighbors = sum(is_company_flag[u] for u in adjacent)
    comp_prop = company_neighbors / deg1
    # Nodes without a precomputed two-hop size fall back to their degree.
    two_hop = two_hop_size.get(n, deg1)
    score = deg1 * F_W_DEG1
    score = score + two_hop * F_W_TWOHOP
    score = score + comp_prop * F_W_COMP_PROP
    return score

def stageA_random_collect(g, company_nodes, need_total):
    """Stage A: randomly scan company nodes and keep those passing an adaptive score threshold.

    The shuffled nodes are scanned until ``need_total * FAST_SAMPLE_OVERSAMPLE``
    nodes are accepted or the iteration budget is used up. The threshold starts
    at 0 (accept everything) and is refreshed from each full batch of
    ``FAST_BATCH_SIZE`` fast scores. ``g`` is not used by this function.

    Returns:
        The accepted node keys, at most the oversampled target.
    """
    target = int(need_total * FAST_SAMPLE_OVERSAMPLE)
    print(f"[INFO] StageA target (oversampled) = {target}")
    pool = company_nodes[:]
    random.shuffle(pool)

    collected = []
    recent_scores = []
    threshold = 0.0

    budget = min(len(pool), need_total * MAX_FAST_ITER_MULTIPLIER)
    idx = 0
    for n in pool[:max(budget, 0)]:
        if len(collected) >= target:
            break
        idx += 1
        if len(neighbors[n]) < FAST_MIN_DEG1:
            continue
        s = fast_score(n)
        recent_scores.append(s)

        # Refresh the threshold once a full batch of scores has been seen.
        if len(recent_scores) >= FAST_BATCH_SIZE:
            threshold = np.percentile(recent_scores, FAST_PERCENTILE * 100)
            recent_scores.clear()

        if threshold == 0.0 or s >= threshold * FAST_ACCEPT_RATIO:
            collected.append(n)

    print(f"[INFO] StageA collected={len(collected)} after {idx} iterations (threshold≈{threshold:.3f})")
    if len(collected) < need_total:
        print("[WARN] StageA collected fewer than need_total; will proceed.")
    return collected[:target]

def main():
    """Rank company nodes by neighborhood richness and export train/valid/test ID lists.

    Steps: load the graph, precompute the fast features, optionally pre-filter
    candidates (Stage A), compute the precise richness metrics (Stage B), sort by
    richness, optionally apply the diversity filter, then write the split files
    and the per-node metrics (``richness_stats.jsonl``) into ``OUT_DIR``.
    """
    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)

    os.makedirs(OUT_DIR, exist_ok=True)
    print(f"[INFO] Loading graph: {GRAPH_PKL}")
    # NOTE: pickle can execute arbitrary code while loading; only use trusted graph files.
    with open(GRAPH_PKL, "rb") as f:
        g = pickle.load(f)
    print(f"[INFO] Loaded graph |V|={g.number_of_nodes()} |E|={g.number_of_edges()}")

    company_nodes = [n for n in g.nodes() if is_company(g, n)]
    print(f"[INFO] Company nodes total: {len(company_nodes)}")
    need_total = TRAIN_SIZE + VALID_SIZE + TEST_SIZE

    print("[INFO] Precomputing fast features...")
    precompute_fast_features(g, company_nodes)
    print("[INFO] Fast features done.")

    if FAST_SAMPLE_ENABLED:
        stageA = stageA_random_collect(g, company_nodes, need_total)
    else:
        stageA = company_nodes[:]
        random.shuffle(stageA)
        stageA = stageA[:need_total]

    print(f"[INFO] StageB precise richness on {len(stageA)} nodes (this is main cost).")
    metrics = []
    for n in tqdm(stageA, desc="StageB compute richness"):
        metrics.append(compute_metrics(g, n, is_company_flag))

    metrics.sort(key=lambda x: x["richness_score"], reverse=True)

    if ENABLE_DIVERSITY_POST:
        filtered = diversity_postprocess(g, metrics, need_total)
        if len(filtered) >= need_total:
            # Put the diverse picks first and keep every remaining record exactly
            # once; slicing by position duplicated some records and dropped others.
            picked = {m["node_key"] for m in filtered}
            metrics = filtered + [m for m in metrics if m["node_key"] not in picked]
        else:
            print("[WARN] diversity_postprocess insufficient; fallback to original ranking.")

    top_metrics = metrics[:need_total]
    if len(top_metrics) < need_total:
        print(f"[WARN] Only {len(top_metrics)} candidates available for {need_total} requested; "
              "the later splits will be short or empty.")

    def real_id(n):
        rid = g.nodes[n].get('id')
        if rid is None:
            raise ValueError(f"Node {n} missing 'id' attribute when exporting.")
        return str(rid)

    train_keys = [m["node_key"] for m in top_metrics[:TRAIN_SIZE]]
    valid_keys = [m["node_key"] for m in top_metrics[TRAIN_SIZE:TRAIN_SIZE+VALID_SIZE]]
    test_keys  = [m["node_key"] for m in top_metrics[TRAIN_SIZE+VALID_SIZE:TRAIN_SIZE+VALID_SIZE+TEST_SIZE]]

    train_ids = [real_id(k) for k in train_keys]
    valid_ids = [real_id(k) for k in valid_keys]
    test_ids  = [real_id(k) for k in test_keys]

    with open(os.path.join(OUT_DIR, "train_companies.txt"), "w") as f:
        f.write("\n".join(train_ids))
    with open(os.path.join(OUT_DIR, "valid_companies.txt"), "w") as f:
        f.write("\n".join(valid_ids))
    if TEST_SIZE > 0:
        with open(os.path.join(OUT_DIR, "test_companies.txt"), "w") as f:
            f.write("\n".join(test_ids))

    with open(os.path.join(OUT_DIR, "richness_stats.jsonl"), "w") as f:
        for m in metrics:
            f.write(json.dumps(m, ensure_ascii=False) + "\n")

    print(f"[DONE] Train={len(train_ids)} Valid={len(valid_ids)} Test={len(test_ids)}")
    print("[HINT] Top-3 richness examples:")
    for m in top_metrics[:3]:
        print(m)

if __name__ == "__main__":
    main()


## 2. Construct basic text information for the selected points

In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
export_path_groups_3hop.py
--------------------------
Starting from target companies in the train/valid/test splits, perform "random-branch 3-hop" sampling.
For each parent path, generate a set (≤3) of candidate extensions. The output is training data for
subsequent "in-group ranking & selection" (an abstract representation of three binary pairs per group).

Each group structure:
{
  "target_real_id": <root company>,
  "hop": <1|2|3>,                        # depth of this extension (length of the parent path)
  "parent_path": ["A", "B", ...],        # sequence of real IDs along the parent path
  "parent_label": "A|B|...",             # compressed parent path label (joined by '|')
  "candidates": ["E","F","G"],           # ≤3 new candidate real IDs
  "group_list": ["A|B", "E", "F", "G"]   # list-form representation consistent with your example
}

Output files:
  - path_groups_3hop.jsonl : one JSON group per line
  - missing_ids.txt        : root IDs from the splits not present in the graph (if any)
  - stats.json             : brief statistics (number of groups, per-hop distribution, etc.)

Paths & parameters are fixed; run directly: python export_path_groups_3hop.py
"""

import os
import json
import pickle
import random
from typing import List, Dict, Any, Set
from tqdm import tqdm
import networkx as nx

# Input graph, directory holding the split ID lists, and output directory for the path groups.
GRAPH_PKL   = "../graph_2022.pkl"
SPLITS_DIR  = "../graph_retrieval_data"
OUT_DIR     = "../graph_retrieval_data/context"

# Split files read from SPLITS_DIR; a missing file is treated as an empty list.
TRAIN_FILE  = "train_companies.txt"
VALID_FILE  = "valid_companies.txt"
TEST_FILE   = "test_companies.txt"

MAX_HOPS    = 3          # number of expansion rounds from each root
SAMPLE_K    = 3          # maximum number of children sampled per parent path
RANDOM_SEED = 42         # seed for the `random` module
PATH_JOIN   = "|"        # separator used when a path is compressed into a label

SHOW_TOP_EXAMPLES = 5          # number of output lines echoed at the end of main()
ALLOW_DUP_IN_SIBLINGS = False  # when False, sample_children() de-duplicates the sampled siblings

def load_graph(path: str):
    """Unpickle the graph stored at ``path`` and verify it is a NetworkX graph.

    Raises:
        TypeError: if the unpickled object is not one of the NetworkX graph classes.
    """
    with open(path, "rb") as handle:
        loaded = pickle.load(handle)
    graph_types = (nx.Graph, nx.DiGraph, nx.MultiGraph, nx.MultiDiGraph)
    if isinstance(loaded, graph_types):
        return loaded
    raise TypeError(f"Loaded object type {type(loaded)} is not a NetworkX graph.")

def load_id_list(path: str) -> List[str]:
    """Read one ID per line from ``path``, stripped, skipping blank lines.

    A path that does not exist yields an empty list instead of an error.
    """
    ids: List[str] = []
    if os.path.exists(path):
        with open(path, "r") as handle:
            for raw in handle:
                cleaned = raw.strip()
                if cleaned:
                    ids.append(cleaned)
    return ids

def build_realid_index(g: nx.Graph) -> Dict[str, List[Any]]:
    """Map each real ID (string form of the node ``"id"`` attribute) to its node keys.

    Nodes without an ``"id"`` are left out. A warning is printed when some real
    ID maps to more than one node key.
    """
    index: Dict[str, List[Any]] = {}
    for node_key in g.nodes():
        rid = g.nodes[node_key].get("id")
        if rid is not None:
            bucket = index.setdefault(str(rid), [])
            bucket.append(node_key)
    dups = len([keys for keys in index.values() if len(keys) > 1])
    if dups > 0:
        print(f"[WARN] {dups} real ids map to multiple internal node keys (use the first).")
    return index

def sample_children(g: nx.Graph, node_key, visited_keys: Set[Any], k: int) -> List[Any]:
    """Randomly pick up to ``k`` unvisited neighbors of ``node_key``.

    The unvisited neighbors are shuffled with the module-level ``random`` state
    and the first ``k`` are taken; duplicates among them are removed unless
    ``ALLOW_DUP_IN_SIBLINGS`` is set.
    """
    candidates = []
    for nbr in g.neighbors(node_key):
        if nbr not in visited_keys:
            candidates.append(nbr)
    if len(candidates) == 0:
        return []
    random.shuffle(candidates)
    chosen = candidates[:k]
    if ALLOW_DUP_IN_SIBLINGS or len(set(chosen)) == len(chosen):
        return chosen
    # Order-preserving de-duplication.
    return list(dict.fromkeys(chosen))

def path_label(path_nodes: List[str]) -> str:
    """Compress a path of real IDs into a single label joined by ``PATH_JOIN``."""
    separator = PATH_JOIN
    return separator.join(path_nodes)

def extract_real_id(g: nx.Graph, node_key) -> str:
    """Return the node's ``"id"`` attribute as a string, or ``None`` when it is absent."""
    rid = g.nodes[node_key].get("id")
    if rid is None:
        return None
    return str(rid)

def generate_groups_for_root(g: nx.Graph, root_real_id: str, realid_index: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Expand random branches from one root company and emit one group per parent path.

    At every hop (1..``MAX_HOPS``) each current parent path samples up to
    ``SAMPLE_K`` unvisited neighbors of its last node. The sampled children form
    the group's candidates, and each child extends the parent into a new path for
    the next hop. Nodes already used on any path of this root are not revisited.

    Args:
        g: the graph to walk.
        root_real_id: real ID of the root; must be a key of ``realid_index``.
        realid_index: real ID -> node keys; the first key is used for the root.

    Returns:
        A list of group dicts with the keys ``target_real_id``, ``hop``,
        ``parent_path``, ``parent_label``, ``candidates`` and ``group_list``.
    """
    groups = []
    root_key = realid_index[root_real_id][0]
    visited_keys: Set[Any] = {root_key}

    parents = [{
        "keys": [root_key],
        "real_ids": [root_real_id],
        "last_key": root_key
    }]

    for hop in range(1, MAX_HOPS + 1):
        next_parents = []
        for parent in parents:
            children_keys = sample_children(g, parent["last_key"], visited_keys, SAMPLE_K)
            if not children_keys:
                continue

            parent_path_real = parent["real_ids"]  # list[str]
            parent_lbl = path_label(parent_path_real)

            # Keep each node key paired with its real ID. Zipping the unfiltered
            # key list with the filtered ID list shifted the pairs whenever a
            # sampled child had no ID.
            children = []
            for ck in children_keys:
                rid = extract_real_id(g, ck)
                if rid is not None:
                    children.append((ck, rid))
            if not children:
                continue

            child_real_ids = [rid for _, rid in children]
            groups.append({
                "target_real_id": root_real_id,
                "hop": hop,
                "parent_path": parent_path_real,
                "parent_label": parent_lbl,
                "candidates": child_real_ids,
                "group_list": [parent_lbl] + child_real_ids
            })

            for ck, crid in children:
                visited_keys.add(ck)
                next_parents.append({
                    "keys": parent["keys"] + [ck],
                    "real_ids": parent_path_real + [crid],
                    "last_key": ck
                })

        parents = next_parents
        if not parents:
            break

    return groups

def main():
    """Export random-branch path groups for every split company found in the graph.

    Reads the train/valid/test ID lists from ``SPLITS_DIR``, writes the IDs that
    are absent from the graph to ``missing_ids.txt``, generates the groups for
    the remaining roots into ``path_groups_3hop.jsonl`` and stores brief
    statistics in ``stats.json`` (all inside ``OUT_DIR``).
    """
    random.seed(RANDOM_SEED)

    os.makedirs(OUT_DIR, exist_ok=True)
    print(f"[INFO] Loading graph: {GRAPH_PKL}")
    g = load_graph(GRAPH_PKL)
    print(f"[INFO] Graph loaded |V|={g.number_of_nodes()} |E|={g.number_of_edges()}")

    train_ids = load_id_list(os.path.join(SPLITS_DIR, TRAIN_FILE))
    valid_ids = load_id_list(os.path.join(SPLITS_DIR, VALID_FILE))
    test_ids  = load_id_list(os.path.join(SPLITS_DIR, TEST_FILE))
    # Order-preserving de-duplication across the three splits.
    target_ids = list(dict.fromkeys(train_ids + valid_ids + test_ids))

    print(f"[INFO] Split counts: train={len(train_ids)} valid={len(valid_ids)} test={len(test_ids)} unique_total={len(target_ids)}")
    if not target_ids:
        # load_id_list() returns [] for a missing file, so a wrong SPLITS_DIR
        # would otherwise silently produce an empty output.
        print(f"[WARN] No target ids loaded from {SPLITS_DIR}; check that the split files exist.")

    realid_index = build_realid_index(g)

    missing = [rid for rid in target_ids if rid not in realid_index]
    if missing:
        miss_path = os.path.join(OUT_DIR, "missing_ids.txt")
        with open(miss_path, "w") as f:
            f.write("\n".join(missing))
        print(f"[WARN] {len(missing)} root ids not found. Wrote {miss_path}")
    target_ids = [rid for rid in target_ids if rid in realid_index]

    out_groups_path = os.path.join(OUT_DIR, "path_groups_3hop.jsonl")
    # One counter per possible hop, so changing MAX_HOPS cannot raise KeyError.
    hop_counts = {hop: 0 for hop in range(1, MAX_HOPS + 1)}
    total_groups = 0

    print(f"[INFO] Generating 3-hop path groups (SAMPLE_K={SAMPLE_K}) ...")
    with open(out_groups_path, "w") as fout:
        for rid in tqdm(target_ids, desc="Targets"):
            root_groups = generate_groups_for_root(g, rid, realid_index)
            for gp in root_groups:
                hop_counts[gp["hop"]] += 1
                total_groups += 1
                fout.write(json.dumps(gp, ensure_ascii=False) + "\n")

    stats = {
        "total_targets": len(target_ids),
        "total_groups": total_groups,
        "groups_per_hop": hop_counts,
        "max_hops": MAX_HOPS,
        "sample_k": SAMPLE_K,
        "random_seed": RANDOM_SEED
    }
    with open(os.path.join(OUT_DIR, "stats.json"), "w") as f:
        json.dump(stats, f, ensure_ascii=False, indent=2)

    print(f"[INFO] Wrote groups to: {out_groups_path}")
    print(f"[INFO] Stats: {stats}")

    if SHOW_TOP_EXAMPLES > 0:
        print("[HINT] Example groups:")
        with open(out_groups_path, "r") as f:
            for i, line in enumerate(f):
                if i >= SHOW_TOP_EXAMPLES:
                    break
                print(line.strip())

    print("[DONE] export_path_groups_3hop complete.")

if __name__ == "__main__":
    main()


## 3. The following is stratified sampling

In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import json
import math
import random
from collections import defaultdict

# Directory containing the input groups; the sampled output and its stats are written there too.
CONTEXT_DIR   = "../graph_retrieval_data/context"
INPUT_FILE    = "path_groups_3hop.jsonl"
OUTPUT_FILE   = "path_groups_3hop_sampled.jsonl"
STATS_FILE    = "path_groups_3hop_sampled_stats.json"

RANDOM_SEED   = 42   # seed for the `random` module

# Fraction of the eligible groups kept per hop; hops not listed here are kept in full.
SAMPLE_RATIO = {
    1: 1.0,
    2: 0.5,
    3: 1/3
}

# True: the per-hop sample size is floor(ratio * n); False: round(ratio * n).
USE_FLOOR = True

def main():
    """Down-sample the exported path groups per hop and write the result plus statistics.

    Only groups with exactly three candidates are eligible. For every hop the
    share given by ``SAMPLE_RATIO`` is drawn at random (at least one group when
    any is eligible); ratios of 1.0 or more keep all groups. The sampled groups
    are sorted by target, hop and parent label before being written.

    Raises:
        FileNotFoundError: if the input JSONL file does not exist.
    """
    random.seed(RANDOM_SEED)
    src_path = os.path.join(CONTEXT_DIR, INPUT_FILE)
    dst_path = os.path.join(CONTEXT_DIR, OUTPUT_FILE)
    summary_path = os.path.join(CONTEXT_DIR, STATS_FILE)

    if not os.path.exists(src_path):
        raise FileNotFoundError(f"Input file not found: {src_path}")

    eligible = defaultdict(list)
    origin_counts = defaultdict(int)

    print(f"[INFO] Loading groups from {src_path}")
    with open(src_path, "r") as fin:
        for raw in fin:
            raw = raw.strip()
            if raw:
                record = json.loads(raw)
                hop = record.get("hop")
                candidates = record.get("candidates", [])
                origin_counts[hop] += 1
                if len(candidates) == 3:
                    eligible[hop].append(record)

    filtered_counts = {hop: len(bucket) for hop, bucket in eligible.items()}

    to_int = math.floor if USE_FLOOR else round
    sampled = []
    sampled_counts = {}
    for hop, bucket in eligible.items():
        chosen = bucket
        ratio = SAMPLE_RATIO.get(hop, 1.0)
        if not ratio >= 1.0:
            size = len(bucket)
            k = to_int(ratio * size)
            # Never drop a hop entirely when it has eligible groups.
            if size > 0 and k == 0:
                k = 1
            if k < size:
                chosen = random.sample(bucket, k)
        sampled += chosen
        sampled_counts[hop] = len(chosen)

    sampled = sorted(sampled, key=lambda x: (x["target_real_id"], x["hop"], x["parent_label"]))

    with open(dst_path, "w") as fout:
        fout.writelines(json.dumps(record, ensure_ascii=False) + "\n" for record in sampled)

    stats = {
        "input_file": src_path,
        "output_file": dst_path,
        "random_seed": RANDOM_SEED,
        "sample_ratio": SAMPLE_RATIO,
        "origin_counts_per_hop": dict(origin_counts),
        "filtered_counts_candidates_eq_3": filtered_counts,
        "sampled_counts_per_hop": sampled_counts,
        "total_sampled": len(sampled)
    }
    with open(summary_path, "w") as fjson:
        json.dump(stats, fjson, ensure_ascii=False, indent=2)

    print("[INFO] Origin counts (all groups):", dict(origin_counts))
    print("[INFO] After length==3 filter:", filtered_counts)
    print("[INFO] Sampled counts:", sampled_counts)
    print(f"[INFO] Total sampled groups written: {len(sampled)}")
    print(f"[INFO] Stats written to: {summary_path}")

    print("[HINT] Sampled examples:")
    for record in sampled[:5]:
        print(json.dumps(record, ensure_ascii=False))

    print("[DONE] Sampling complete.")

if __name__ == "__main__":
    main()


## 3. Generate basic-information prompts for all nodes involved (four prompts per group)

In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
build_prompts_from_groups.py
----------------------------
Given grouped path expansions (parent_label + 3 candidates),
construct 4 English prompts per group:
  * base prompt (parent path only)
  * 3 extended prompts (parent path + each candidate)

Input file  (JSONL):
  path_groups_3hop_sampled.jsonl
  Each line has: {
      "target_real_id": "...",
      "hop": 1|2|3,
      "parent_path": [...],      # list of real ids (strings) (optional but we reconstruct from parent_label)
      "parent_label": "A|B|...",
      "candidates": ["C","D","E"],
      ...
  }

External info sources:
  COMPANY_INFO_PKL: dict[id] -> {"basic_info": str, "label": 0/1 or None}
  INVESTOR_INFO_PKL: dict[id] -> str OR dict with key "basic_info"

Conventions:
  * First id of any path is always the target company (do NOT reveal its label).
  * Paths alternate types company / investor / company / investor / ...
  * Candidate list nodes are all the opposite type of the last node in parent path.
  * For company nodes (except the target) we verbalize label:
      1 -> "previously succeeded"
      0 -> "previously failed"
      missing -> "outcome unknown"

Output JSONL:
  prompts_3hop_groups.jsonl
Each line:
  {
    "target_id": "...",
    "hop": ...,
    "parent_label": "...",
    "parent_path_ids": [...],
    "base_prompt": "...",
    "extended": [
        {"candidate_id": "...", "extended_path_ids": [...], "prompt": "..."},
        ...
    ]
  }

Adjust PATHS section below as needed.
"""

import os
import json
import pickle
from typing import Dict, Any, List


# Directory holding the sampled groups (input) and the generated prompt records (output).
CONTEXT_DIR          = "../graph_retrieval_data/context"
GROUPS_FILE          = "path_groups_3hop_sampled.jsonl"
OUTPUT_FILE          = "prompts_3hop_groups.jsonl"

# Pickled lookup tables for the company and investor descriptions.
COMPANY_INFO_PKL     = "../train_data_2022_basic.pkl"
INVESTOR_INFO_PKL    = "../person_biography_dict.pkl"

PATH_SEPARATOR       = "|"     # separator used to split parent_label back into IDs
MAX_SHOW_EXAMPLES    = 3       # number of output records previewed at the end of process()

# Text placed at the very start of every prompt.
SYSTEM_PREAMBLE = (
    "You are a seasoned venture-capital investor. "
    "A target company has just secured its Series-A financing. "
    "Your task is to predict whether it will obtain a second round of financing, IPO, "
    "or be acquired within the next 12 months.\n"
)

# Text appended after the path narrative; it fixes the expected answer format.
GUIDANCE_SUFFIX = (
    "\nAbove is the target company's basic profile and a structured investment linkage context. "
    "Based on ONLY this information, output your single-word judgment on whether the company will obtain a "
    "second round, IPO, or be acquired within 12 months.\n"
    "Use exactly one of the following formats:\n"
    "  Prediction: True\n"
    "  Prediction: False"
)

def is_investor(node_id: str) -> bool:
    """Return True when ``node_id`` ends with ``'P'``, the marker used for investor IDs."""
    ends_with_marker = node_id.endswith('P')
    return ends_with_marker

def load_pickle(path: str):
    """Unpickle and return the object stored in the file at ``path``."""
    with open(path, "rb") as handle:
        payload = pickle.load(handle)
    return payload

def get_company_basic(company_info: Dict[str, Any], cid: str) -> str:
    """Return the basic-information text of company ``cid``.

    A string entry is returned as is; a dict entry contributes its
    ``"basic_info"`` value. Missing entries, empty ``"basic_info"`` values and
    entries of any other type yield a fixed "unavailable" sentence.
    """
    fallback = f"Basic information unavailable for company {cid}."
    record = company_info.get(cid)
    if isinstance(record, dict):
        return record.get("basic_info") or fallback
    if isinstance(record, str):
        return record
    return fallback

def get_company_label_text(company_info: Dict[str, Any], cid: str) -> str:
    """Verbalize the stored label of company ``cid``.

    Label 1 reads "previously succeeded", label 0 "previously failed"; a missing
    entry, a non-dict entry or any other label reads "outcome unknown".
    """
    unknown = "outcome unknown"
    record = company_info.get(cid)
    if not record:
        return unknown
    if not isinstance(record, dict):
        return unknown
    label = record.get("label")
    if label == 1:
        verdict = "previously succeeded"
    elif label == 0:
        verdict = "previously failed"
    else:
        verdict = unknown
    return verdict

def get_investor_info(investor_info: Dict[str, Any], iid: str) -> str:
    """Return the background text of investor ``iid``.

    A string entry is returned as is; a dict entry contributes its
    ``"basic_info"`` value, or its string form when that value is empty or
    absent. Any other entry is converted with ``str``; a missing entry yields a
    fixed "unavailable" sentence.
    """
    record = investor_info.get(iid)
    if record is None:
        return f"Background information unavailable for investor {iid}."
    if isinstance(record, dict):
        described = record.get("basic_info")
        return described if described else str(record)
    return record if isinstance(record, str) else str(record)

def describe_path_segment(path_ids: List[str],
                          company_info: Dict[str, Any],
                          investor_info: Dict[str, Any]) -> str:
    """Turn a path of IDs into a newline-separated English narrative.

    The first ID is described as the target company (without its label). Every
    later ID gets one sentence that depends on whether it and its predecessor
    are investors or companies; two consecutive nodes of the same type produce
    a "NOTE: ... (unexpected)" line. An empty path yields an empty string.
    """
    if not path_ids:
        return ""
    target_id = path_ids[0]
    target_basic = get_company_basic(company_info, target_id)
    narrative = [f"Target company {target_id}: {target_basic}"]

    # Walk the path as consecutive (previous, current) pairs.
    for prev_id, cur_id in zip(path_ids, path_ids[1:]):
        prev_is_inv = is_investor(prev_id)
        cur_is_inv = is_investor(cur_id)

        if prev_is_inv and cur_is_inv:
            sentence = f"NOTE: Two consecutive investors {prev_id}->{cur_id} (unexpected)."
        elif not (prev_is_inv or cur_is_inv):
            sentence = f"NOTE: Two consecutive companies {prev_id}->{cur_id} (unexpected)."
        elif cur_is_inv:
            inv_info = get_investor_info(investor_info, cur_id)
            if prev_id == target_id:
                sentence = (
                    f"Investor {cur_id} is one of the investors in target {target_id}. "
                    f"Profile: {inv_info}"
                )
            else:
                sentence = (
                    f"Investor {cur_id} has also backed company {prev_id} in the chain. "
                    f"Profile: {inv_info}"
                )
        else:
            comp_basic = get_company_basic(company_info, cur_id)
            comp_label_text = get_company_label_text(company_info, cur_id)
            sentence = (
                f"Company {cur_id} is another portfolio company linked via investor {prev_id}. "
                f"{comp_basic} Historical outcome: {comp_label_text}."
            )
        narrative.append(sentence)
    return "\n".join(narrative)

def assemble_prompt(path_ids: List[str],
                    company_info: Dict[str, Any],
                    investor_info: Dict[str, Any]) -> str:
    """Build the full prompt: preamble, intro line, path narrative and answer guidance."""
    pieces = [
        SYSTEM_PREAMBLE,
        "Below is investment-chain context you may consult to form your prediction.\n",
        describe_path_segment(path_ids, company_info, investor_info),
        GUIDANCE_SUFFIX,
    ]
    return "".join(pieces)

def process():
    """Build one base prompt and one extended prompt per candidate for every group.

    Reads the sampled groups from ``CONTEXT_DIR/GROUPS_FILE`` and writes one JSON
    record per group to ``CONTEXT_DIR/OUTPUT_FILE``. Groups with an empty
    ``parent_label`` are skipped. Groups whose parent path does not start at
    ``target_real_id`` are still processed, but are counted and reported.

    Raises:
        FileNotFoundError: if the groups file does not exist.
    """
    company_info = load_pickle(COMPANY_INFO_PKL)
    investor_info = load_pickle(INVESTOR_INFO_PKL)

    groups_path = os.path.join(CONTEXT_DIR, GROUPS_FILE)
    out_path    = os.path.join(CONTEXT_DIR, OUTPUT_FILE)

    if not os.path.exists(groups_path):
        raise FileNotFoundError(f"Groups file not found: {groups_path}")

    total_groups = 0
    written = 0
    mismatched_roots = 0

    with open(groups_path, "r") as fin, open(out_path, "w") as fout:
        for line in fin:
            line = line.strip()
            if not line:
                continue
            group_obj = json.loads(line)
            total_groups += 1

            parent_label = group_obj["parent_label"]
            candidates   = group_obj.get("candidates", [])
            hop          = group_obj.get("hop")
            target_id    = group_obj.get("target_real_id")

            parent_path_ids = parent_label.split(PATH_SEPARATOR) if parent_label else []
            if not parent_path_ids:
                continue
            if parent_path_ids[0] != target_id:
                # The prompt treats the first path ID as the target; make the
                # inconsistency visible instead of silently ignoring it.
                mismatched_roots += 1

            base_prompt = assemble_prompt(parent_path_ids, company_info, investor_info)

            extended_prompts = []
            for cand_id in candidates:
                extended_path = parent_path_ids + [cand_id]
                ext_prompt = assemble_prompt(extended_path, company_info, investor_info)
                extended_prompts.append({
                    "candidate_id": cand_id,
                    "extended_path_ids": extended_path,
                    "prompt": ext_prompt
                })

            out_record = {
                "target_id": target_id,
                "hop": hop,
                "parent_label": parent_label,
                "parent_path_ids": parent_path_ids,
                "base_prompt": base_prompt,
                "extended": extended_prompts
            }
            fout.write(json.dumps(out_record, ensure_ascii=False) + "\n")
            written += 1

    if mismatched_roots:
        print(f"[WARN] {mismatched_roots} groups have a parent_label whose first id differs from target_real_id.")
    print(f"[DONE] Processed groups: {total_groups}, wrote prompt records: {written}")
    print(f"[OUTPUT] {out_path}")

    if MAX_SHOW_EXAMPLES > 0:
        print("[HINT] Example prompt records:")
        with open(out_path, "r") as f:
            for i, l in enumerate(f):
                if i >= MAX_SHOW_EXAMPLES:
                    break
                print(l.strip()[:500] + ("..." if len(l.strip()) > 500 else ""))

if __name__ == "__main__":
    process()


## 4. Obtain offline reasoning structures for different information combinations

In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import json
import math
from typing import Dict, List

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

# Location of the prompt records to score.
CONTEXT_DIR = "../graph_retrieval_data/context"
PROMPTS_FILE = "prompts_3hop_groups.jsonl"

# Local checkpoint directory; both loaders below use local_files_only=True.
MODEL_PATH = "../LLama/Meta-Llama-3.1-8B-Instruct"

DEVICE_MAP = "auto"          # passed as device_map when the model is loaded
DTYPE = torch.bfloat16       # passed as torch_dtype when the model is loaded
MAX_SHOW_TOKENS = 2          # max_new_tokens used by debug_generate()
ADD_DEBUG_GENERATE = False   # when True, main() also prints greedy generations

# The tokenizer and model are loaded at module import time and shared by the functions below.
print("[INFO] Loading model...")
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    use_fast=False,
    local_files_only=True
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=DTYPE,
    device_map=DEVICE_MAP,
    local_files_only=True
)
model.eval()
print("[INFO] Model loaded.")

# Token-id sequences of the two answer words, encoded without special tokens.
# Each may be one or several tokens; two_token_conditional_prob() handles both cases.
TRUE_TOKENS = tokenizer.encode("True", add_special_tokens=False)
FALSE_TOKENS = tokenizer.encode("False", add_special_tokens=False)
print(f"[DEBUG] TRUE_TOKENS={TRUE_TOKENS}, FALSE_TOKENS={FALSE_TOKENS}")

def ensure_prediction_suffix(prompt: str) -> str:
    """Make ``prompt`` end with ``"Prediction: "`` so the next token is the answer.

    If the prompt already ends with ``"Prediction:"`` (ignoring trailing
    whitespace), the trailing whitespace is normalised to a single space.
    Otherwise a new ``"Prediction: "`` line is appended.
    """
    stripped = prompt.rstrip()
    # This also covers a prompt that already ends with "Prediction: ", so the
    # former separate check for that case could never be reached.
    if stripped.endswith("Prediction:"):
        return stripped + " "
    if not prompt.endswith("\n"):
        prompt += "\n"
    return prompt + "Prediction: "

@torch.no_grad()
def two_token_conditional_prob(prompt: str) -> Dict[str, float]:
    """Score ``prompt`` by the model's relative probability of "True" versus "False".

    The prompt is first completed with the ``"Prediction: "`` suffix. When both
    answers are single tokens, their next-token probabilities are read from one
    forward pass; otherwise each answer's probability is the product of its
    token probabilities under teacher forcing. The two raw probabilities are
    then renormalised so that they sum to one.

    Returns:
        A dict with ``prompt_final`` (the prompt actually scored), ``p_true``,
        ``p_false``, ``log_odds`` (log of p_true / p_false), ``margin``
        (p_true - p_false) and ``pred_label`` ("True" when p_true >= 0.5).
    """
    prompt = ensure_prediction_suffix(prompt)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_ids = inputs["input_ids"]

    out = model(**inputs)
    # Upcast before the softmax: the logits come out in the model's reduced
    # precision, which is too coarse for comparing nearly equal probabilities.
    logits_last = out.logits[:, -1, :].float()  # (1, vocab)
    probs_first = F.softmax(logits_last, dim=-1)

    if len(TRUE_TOKENS) == 1 and len(FALSE_TOKENS) == 1:
        p_true_raw = probs_first[0, TRUE_TOKENS[0]]
        p_false_raw = probs_first[0, FALSE_TOKENS[0]]
    else:
        def seq_prob(prefix_ids: torch.Tensor, seq: List[int]) -> torch.Tensor:
            # Probability of the whole token sequence, feeding each answer token
            # back into the context before scoring the next one.
            cur = prefix_ids.clone()
            logp = 0.0
            for tid in seq:
                out_step = model(input_ids=cur)
                logits_step = out_step.logits[:, -1, :].float()
                probs_step = F.softmax(logits_step, dim=-1)
                p_tok = probs_step[0, tid]
                logp += torch.log(p_tok + 1e-12)
                next_tok = torch.tensor([[tid]], device=cur.device)
                cur = torch.cat([cur, next_tok], dim=1)
            return torch.exp(logp)

        p_true_raw = seq_prob(input_ids, TRUE_TOKENS)
        p_false_raw = seq_prob(input_ids, FALSE_TOKENS)

    # Renormalise over the two answers only; the epsilon guards against 0/0.
    denom = p_true_raw + p_false_raw + 1e-12
    p_true = (p_true_raw / denom).item()
    p_false = (p_false_raw / denom).item()

    log_odds = math.log((p_true + 1e-12) / (p_false + 1e-12))
    margin = p_true - p_false

    return {
        "prompt_final": prompt,
        "p_true": p_true,
        "p_false": p_false,
        "log_odds": log_odds,
        "margin": margin,
        "pred_label": "True" if p_true >= 0.5 else "False"
    }

@torch.no_grad()
def debug_generate(prompt: str):
    """Greedily generate up to ``MAX_SHOW_TOKENS`` tokens after the prediction suffix.

    Returns:
        The decoded continuation (special tokens removed, whitespace stripped).
    """
    prompt = ensure_prediction_suffix(prompt)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Greedy decoding; temperature is irrelevant without sampling, so it is not passed.
    gen = model.generate(
        **inputs,
        max_new_tokens=MAX_SHOW_TOKENS,
        do_sample=False
    )
    # generate() returns a plain tensor of token ids unless
    # return_dict_in_generate=True is requested; accept either form.
    sequences = gen.sequences if hasattr(gen, "sequences") else gen
    full = sequences[0]
    new_tokens = full[inputs["input_ids"].shape[1]:]
    txt = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return txt.strip()

def load_first_group(path: str):
    """Return the first JSON record of a JSONL file, or ``None`` if it has none.

    Lines that are empty after stripping whitespace are ignored; the first
    remaining line is parsed with ``json.loads`` and returned.
    """
    with open(path, "r") as handle:
        non_blank = (text for text in (raw.strip() for raw in handle) if text)
        for text in non_blank:
            # Only the first record is wanted, so stop reading right here.
            return json.loads(text)
    return None

def main():
    """Score the first prompt group of the prompts file and print the result.

    Loads one group from ``CONTEXT_DIR/PROMPTS_FILE``, runs
    ``two_token_conditional_prob`` on its base prompt and on every extended
    (candidate) prompt, and prints the collected scores as JSON. When
    ``ADD_DEBUG_GENERATE`` is set, the generated continuation of each prompt
    is printed as well.
    """
    groups_path = os.path.join(CONTEXT_DIR, PROMPTS_FILE)
    group = load_first_group(groups_path)
    if group is None:
        print("[ERROR] No group found.")
        return

    # "extended" is optional: read it once with a default so the summary and
    # the scoring loop agree, instead of the loop raising KeyError for a
    # group that has no candidates.
    candidates = group.get("extended", [])
    # Fields copied from each scoring result into the printed record.
    score_keys = ("p_true", "p_false", "log_odds", "margin", "pred_label")

    print("[INFO] Loaded one group:")
    print(json.dumps({
        "target_id": group["target_id"],
        "hop": group["hop"],
        "parent_label": group["parent_label"],
        "num_candidates": len(candidates)
    }, ensure_ascii=False, indent=2))

    base_prompt = group["base_prompt"]
    base_res = two_token_conditional_prob(base_prompt)

    result_record = {
        "target_id": group["target_id"],
        "hop": group["hop"],
        "parent_label": group["parent_label"],
        "base": {key: base_res[key] for key in score_keys},
        "extended": []
    }

    if ADD_DEBUG_GENERATE:
        dbg = debug_generate(base_prompt)
        print(f"[DEBUG] Base generated: {dbg}")

    for ext in candidates:
        cand_id = ext["candidate_id"]
        prompt_ext = ext["prompt"]
        ext_res = two_token_conditional_prob(prompt_ext)
        if ADD_DEBUG_GENERATE:
            dbg2 = debug_generate(prompt_ext)
            print(f"[DEBUG] Ext {cand_id} generated: {dbg2}")

        entry = {"candidate_id": cand_id}
        entry.update((key, ext_res[key]) for key in score_keys)
        result_record["extended"].append(entry)

    print("\n[RESULT]")
    print(json.dumps(result_record, ensure_ascii=False, indent=2))

# Entry point: run the single-group demo only when this code is executed as
# the main module, not when it is imported.
if __name__ == "__main__":
    main()


## 5. Batch inference over all prompt groups

In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
batch_infer_prompts.py
----------------------
Batch inference for all prompt groups:
  - Reads prompts_3hop_groups.jsonl
  - Computes conditional probabilities p(True), p(False) for base + each extended prompt
  - (Optional) Computes Δ gain for each extended w.r.t. base if target label is available
  - Writes predictions_3hop_groups.jsonl

  CE(y,p) = - [ y log p + (1-y) log(1-p) ]
  Δ = (CE_base - CE_ext) + λ * ( |p_ext - 0.5| - |p_base - 0.5| )

Requires:
  - Llama causal LM
  - (Optional) company_info.pkl: {company_id: {"label": 0/1, ...}}

Resume:
  If output exists and RESUME=True, already processed keys won't be recomputed.
"""

import os
import json
import math
import pickle
import hashlib
from typing import List, Dict, Any, Tuple
from collections import OrderedDict

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm

# --- Input / output -------------------------------------------------------
# Directory that holds both the input prompts file and the output predictions.
CONTEXT_DIR   = "../graph_retrieval_data/context"
INPUT_FILE    = "prompts_3hop_groups.jsonl"
OUTPUT_FILE   = "predictions_3hop_groups.jsonl"

# --- Labels and delta -----------------------------------------------------
# Pickle consulted for ground-truth labels (looked up by target_id).
COMPANY_INFO_PKL = "../train_data_2022_basic.pkl"
# When False, no company info is loaded and no label is looked up.
USE_COMPANY_LABEL = True
# Weight of the confidence term |p - 0.5| in compute_delta.
LAMBDA_CONF = 0.2

# --- Model ----------------------------------------------------------------
# Passed to from_pretrained in model_setup (loaded with local_files_only=True).
MODEL_PATH   = "/data/VC_LLM_Agent/LLama/Meta-Llama-3.1-8B-Instruct"
DTYPE        = torch.bfloat16
DEVICE_MAP   = "auto"

# --- Run control ----------------------------------------------------------
# Pending prompts are sent to the model once at least this many are queued.
# The check runs once per group, so a single batch can be larger than this.
BATCH_SIZE   = 1
# When True, the output file is appended to and groups already in it are skipped.
RESUME       = True
# Text every prompt must end with before the next-token probabilities are read.
PRED_SUFFIX  = "Prediction: "
# Maximum number of entries in the in-memory probability cache; when exceeded,
# the oldest inserted entry is dropped.
CACHE_MAX    = 200000

# When True (and a 0/1 label is available), a delta is computed per candidate.
COMPUTE_DELTA = True
# NOTE(review): not referenced by the code visible in this cell (ce_loss has
# its own eps default) — confirm whether it is still needed.
EPS = 1e-12

# Progress reporting: tqdm bars on/off, and how often a "Written N groups"
# line is printed.
SHOW_PROGRESS = True
PRINT_EVERY   = 2000

# --- Truncation (see maybe_truncate) --------------------------------------
# Prompts longer than MAX_CTX tokens keep their first KEEP_HEAD tokens and
# their last MAX_CTX - KEEP_HEAD tokens; the middle is dropped.
MAX_CTX     = 2048
KEEP_HEAD   = 1024

def ensure_prediction_suffix(prompt: str) -> str:
    """Return *prompt* guaranteed to end with ``PRED_SUFFIX``.

    Three cases, checked in order:
      * the prompt already ends with the exact suffix: returned unchanged;
      * it ends with the suffix apart from trailing whitespace: the trailing
        whitespace is replaced by a single space;
      * otherwise: the suffix is appended on a new line.
    """
    if prompt.endswith(PRED_SUFFIX):
        return prompt

    trimmed = prompt.rstrip()
    if trimmed.endswith(PRED_SUFFIX.strip()):
        return trimmed + " "

    # Put the suffix on its own line unless the prompt already ends with one.
    separator = "" if prompt.endswith("\n") else "\n"
    return prompt + separator + PRED_SUFFIX

def load_company_info(path: str) -> Dict[str, Any]:
    """Load the company-info pickle used for ground-truth labels.

    Returns an empty dict when labels are disabled (``USE_COMPANY_LABEL`` is
    false), when *path* does not exist (a warning is printed), or when the
    pickle does not contain a dict.
    """
    if USE_COMPANY_LABEL and os.path.exists(path):
        # NOTE: unpickling can execute arbitrary code; only point this at a
        # file you trust.
        with open(path, "rb") as handle:
            loaded = pickle.load(handle)
        if isinstance(loaded, dict):
            return loaded
        return {}

    if USE_COMPANY_LABEL:
        # Labels were requested but the file is absent.
        print(f"[WARN] company info file not found: {path}; Δ will be skipped.")
    return {}
    

def get_label(company_info: Dict[str, Any], cid: str):
    """Return the ``"label"`` stored for company *cid*, or ``None``.

    ``None`` is returned when the company is unknown, when its record is empty
    or not a dict, or when the record has no ``"label"`` key.
    """
    entry = company_info.get(cid)
    if entry and isinstance(entry, dict):
        return entry.get("label")
    return None

def ce_loss(y: int, p: float, eps=1e-12):
    """Binary cross-entropy of predicted probability *p* against label *y*.

    *p* is clamped into ``[eps, 1 - eps]`` first so neither logarithm sees 0.
    """
    clamped = min(max(p, eps), 1 - eps)
    positive_term = y * math.log(clamped)
    negative_term = (1 - y) * math.log(1 - clamped)
    return -(positive_term + negative_term)

def compute_delta(y: int, p_base: float, p_ext: float, lambda_conf: float) -> float:
    """Gain of the extended prompt over the base prompt for label *y*.

    The gain is the drop in cross-entropy (base minus extended) plus
    ``lambda_conf`` times the increase in confidence, where confidence is the
    distance of the predicted probability from 0.5.
    """
    ce_gain = ce_loss(y, p_base) - ce_loss(y, p_ext)
    confidence_gain = abs(p_ext - 0.5) - abs(p_base - 0.5)
    return ce_gain + lambda_conf * confidence_gain

def model_setup():
    """Load the tokenizer and causal LM and tokenize the two answer words.

    Both are loaded from ``MODEL_PATH`` using local files only. The model is
    put in eval mode with its KV cache disabled. If the tokenizer has no pad
    token, the EOS token is used for padding.

    Returns:
        ``(tokenizer, model, true_tokens, false_tokens)``; the last two are
        the token-id lists of "True" and "False" encoded without special
        tokens.
    """
    print("[INFO] Loading model...")
    offline = {"local_files_only": True}
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False, **offline)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=DTYPE,
        device_map=DEVICE_MAP,
        **offline,
    )
    model.eval()
    model.config.use_cache = False

    # Batched tokenization needs a pad token; fall back to EOS when none is set.
    pad_missing = tokenizer.pad_token is None
    if pad_missing:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.eos_token_id
        print(f"[INFO] Set pad_token to eos_token (id={tokenizer.pad_token_id}).")

    true_tokens, false_tokens = [
        tokenizer.encode(word, add_special_tokens=False)
        for word in ("True", "False")
    ]
    print(f"[INFO] TRUE tokens={true_tokens}, FALSE tokens={false_tokens}")
    return tokenizer, model, true_tokens, false_tokens


def maybe_truncate(prompt: str, tok: AutoTokenizer) -> str:
    """Shorten *prompt* to at most ``MAX_CTX`` tokens by cutting out its middle.

    Prompts that already fit are returned unchanged. Longer prompts keep their
    first ``KEEP_HEAD`` tokens and as many trailing tokens as still fit in
    ``MAX_CTX``; the ids are then decoded back to text.

    Args:
        prompt: Prompt text.
        tok: Tokenizer used to count and decode tokens (special tokens are
            not added while counting).

    Returns:
        The original prompt, or the decoded head + tail when it was too long.
    """
    ids = tok.encode(prompt, add_special_tokens=False)
    if len(ids) <= MAX_CTX:
        return prompt

    # Clamp so a misconfigured KEEP_HEAD >= MAX_CTX cannot produce more than
    # MAX_CTX tokens.
    head_len = max(0, min(KEEP_HEAD, MAX_CTX))
    tail_len = MAX_CTX - head_len

    head = ids[:head_len]
    # Slice from an explicit start index: `ids[-tail_len:]` with tail_len == 0
    # would be `ids[-0:]`, i.e. the whole list instead of an empty tail.
    tail = ids[len(ids) - tail_len:]
    return tok.decode(head + tail, skip_special_tokens=True)

def seq_raw_prob(model, tokenizer, prefix_ids: torch.Tensor, target_ids: List[int]) -> torch.Tensor:
    """Probability that *model* continues *prefix_ids* with exactly *target_ids*.

    The sequence probability is the product of the per-token conditional
    probabilities. All of them are read from a single forward pass over
    ``prefix + target[:-1]``: for a causal LM the logits at a position depend
    only on the tokens up to that position, so this yields the same
    distributions as feeding the target tokens one at a time while running
    the long prefix only once.

    Args:
        model: Callable as ``model(input_ids=..., use_cache=False)`` returning
            an object with ``.logits`` of shape (1, seq_len, vocab).
        tokenizer: Unused; kept so existing callers do not need to change.
        prefix_ids: Token ids of the prompt, shape (1, prefix_len).
        target_ids: Token ids of the continuation to score.

    Returns:
        A 0-dim tensor holding the sequence probability. An empty
        *target_ids* gives 1.0 (the empty continuation is certain).
    """
    if not target_ids:
        return torch.tensor(1.0, device=prefix_ids.device)

    n_targets = len(target_ids)
    full_ids = prefix_ids
    if n_targets > 1:
        # Teacher forcing: every target token except the last is part of the
        # input, so the model conditions on it when predicting the next one.
        forced = torch.tensor([list(target_ids[:-1])], device=prefix_ids.device)
        full_ids = torch.cat([prefix_ids, forced], dim=1)

    out = model(input_ids=full_ids, use_cache=False)
    # The last n_targets positions predict target_ids[0], target_ids[1], ...
    step_probs = F.softmax(out.logits[0, -n_targets:, :], dim=-1)

    logp = 0.0
    for step, tid in enumerate(target_ids):
        # Small epsilon keeps log() finite if a probability underflows to 0.
        logp = logp + torch.log(step_probs[step, tid] + 1e-12)
    return torch.exp(logp)

def batch_boolean_probs(model,
                        tokenizer,
                        prompts: List[str],
                        true_tokens: List[int],
                        false_tokens: List[int]) -> List[Dict[str, float]]:
    """Compute normalized P(True) / P(False) for each prompt.

    Every prompt gets the prediction suffix and is truncated to the context
    limit first. The raw probabilities of the "True" and "False"
    continuations are then renormalized so that they sum to 1.

    * If both answers are single tokens, all prompts are scored in one padded
      forward pass and each result also carries ``"first_token"``: the
      decoded arg-max next token.
    * Otherwise each prompt is scored separately with ``seq_raw_prob`` and
      the result has only ``"p_true"`` and ``"p_false"``.

    Args:
        model: Causal LM.
        tokenizer: Its tokenizer (must have a pad token for batched input).
        prompts: Prompt texts.
        true_tokens: Token ids of the "True" answer.
        false_tokens: Token ids of the "False" answer.

    Returns:
        One dict per prompt, in input order.
    """
    if not prompts:
        return []

    fixed_prompts = [
        maybe_truncate(ensure_prediction_suffix(p), tokenizer)
        for p in prompts
    ]
    results = []

    multi = (len(true_tokens) > 1) or (len(false_tokens) > 1)
    if multi:
        # Inference only: without this, every forward pass would record an
        # autograd graph for the whole model.
        with torch.inference_mode():
            for pr_final in fixed_prompts:
                enc = tokenizer(pr_final, return_tensors="pt").to(model.device)
                p_true_raw = seq_raw_prob(model, tokenizer, enc["input_ids"], true_tokens)
                p_false_raw = seq_raw_prob(model, tokenizer, enc["input_ids"], false_tokens)
                denom = p_true_raw + p_false_raw + 1e-12
                results.append({
                    "p_true": (p_true_raw / denom).item(),
                    "p_false": (p_false_raw / denom).item(),
                })
        return results

    enc = tokenizer(fixed_prompts, return_tensors="pt", padding=True).to(model.device)
    with torch.inference_mode():
        out = model(**enc, use_cache=False)
    logits = out.logits  # (B, L, V)

    # Index of the last real (non-pad) token of each row. Weighting the mask
    # by position and taking the arg-max works for both right- and
    # left-padded batches, unlike `mask.sum() - 1`, which assumes right padding.
    attn = enc["attention_mask"]
    positions = torch.arange(attn.shape[1], device=attn.device)
    last_positions = (attn * positions).argmax(dim=1).tolist()

    true_id = true_tokens[0]
    false_id = false_tokens[0]

    for row, pos in enumerate(last_positions):
        # Upcast before softmax: in half precision the probabilities carry
        # only a few significant digits, which is noise at the scale of the
        # base-vs-extended differences computed from them.
        z = logits[row, pos, :].float()
        probs = F.softmax(z, dim=-1)
        p_true_raw = probs[true_id]
        p_false_raw = probs[false_id]
        denom = p_true_raw + p_false_raw + 1e-12
        first_tok_id = int(torch.argmax(z))
        first_tok = tokenizer.decode([first_tok_id], skip_special_tokens=True).strip()
        results.append({"p_true": (p_true_raw / denom).item(),
                        "p_false": (p_false_raw / denom).item(),
                        "first_token": first_tok})
    return results

def prompt_hash(s: str) -> str:
    """Return the hex MD5 digest of *s* (UTF-8 encoded); used as a cache key."""
    digest = hashlib.md5()
    digest.update(s.encode("utf-8"))
    return digest.hexdigest()

def load_existing_predictions(path: str) -> Dict[str, Any]:
    """Read an existing predictions JSONL file into a dict keyed by group.

    The key of each record is ``"<target_id>|<hop>|<parent_label>"``. Blank
    lines are skipped, and a later record with the same key replaces an
    earlier one. A missing file gives an empty dict (nothing is printed).
    """
    if not os.path.exists(path):
        return {}

    existing: Dict[str, Any] = {}
    with open(path, "r") as handle:
        for raw in handle:
            text = raw.strip()
            if text:
                rec = json.loads(text)
                group_key = "{}|{}|{}".format(
                    rec["target_id"], rec["hop"], rec["parent_label"]
                )
                existing[group_key] = rec

    print(f"[INFO] Resume enabled: loaded {len(existing)} existing group records.")
    return existing

def _build_group_record(group: Dict[str, Any],
                        base_probs: Dict[str, float],
                        ext_probs: List[Dict[str, float]],
                        label) -> Dict[str, Any]:
    """Assemble the output JSON object of one group from its probabilities.

    *ext_probs* must be aligned with ``group["extended"]``. *label* is 0, 1 or
    ``None``; a delta is computed per candidate only when ``COMPUTE_DELTA`` is
    on and a label is available.
    """
    p_base_true = base_probs["p_true"]

    extended_out = []
    for ext, probs in zip(group.get("extended", []), ext_probs):
        p_ext_true = probs["p_true"]
        delta_val = None
        if COMPUTE_DELTA and label is not None:
            delta_val = compute_delta(label, p_base_true, p_ext_true, LAMBDA_CONF)
        extended_out.append({
            "candidate_id": ext["candidate_id"],
            "p_true": p_ext_true,
            "p_false": probs["p_false"],
            "pred_label": "True" if p_ext_true >= 0.5 else "False",
            "delta": delta_val
        })

    return {
        "target_id": group["target_id"],
        "hop": group["hop"],
        "parent_label": group["parent_label"],
        "p_true_base": p_base_true,
        "p_false_base": base_probs["p_false"],
        "pred_label_base": "True" if p_base_true >= 0.5 else "False",
        "label": label,
        "lambda_conf": LAMBDA_CONF if COMPUTE_DELTA else None,
        "extended": extended_out
    }


def main():
    """Run batch inference over all prompt groups and write one record each.

    For every group in ``CONTEXT_DIR/INPUT_FILE`` the base prompt and all
    extended prompts are scored, and a record is written to
    ``CONTEXT_DIR/OUTPUT_FILE``.

    Records are written right after the batch that completes them, so an
    interrupted run keeps its finished groups and ``RESUME`` can continue from
    there. Probabilities are cached by prompt hash so that a prompt shared by
    several groups is scored once; the cache is bounded by ``CACHE_MAX``.
    Because each group copies the probabilities it needs when it is read, an
    evicted cache entry only costs a recomputation later, never a failure.

    If the same group key occurs more than once in the input, only the first
    occurrence is processed.
    """
    tokenizer, model, true_tokens, false_tokens = model_setup()
    company_info = load_company_info(COMPANY_INFO_PKL) if USE_COMPANY_LABEL else {}

    input_path = os.path.join(CONTEXT_DIR, INPUT_FILE)
    output_path = os.path.join(CONTEXT_DIR, OUTPUT_FILE)
    if not os.path.exists(input_path):
        raise FileNotFoundError(input_path)

    # Only the keys matter here: they mark groups that must not be recomputed.
    done_keys = set(load_existing_predictions(output_path)) if RESUME else set()

    prob_cache: OrderedDict[str, Dict[str, float]] = OrderedDict()

    def cache_set(ph, val):
        prob_cache[ph] = val
        if len(prob_cache) > CACHE_MAX:
            # Drop the oldest inserted entry.
            prob_cache.popitem(last=False)

    total_groups = 0
    processed_groups = 0
    missing_label = 0

    # Prompts waiting for the next model call, their hashes, and the set of
    # those hashes (so the same prompt is not queued twice in one batch).
    batch_prompts: List[str] = []
    batch_hashes: List[str] = []
    queued = set()
    # Groups read but not yet written: (group, [base_hash, ext_hash, ...],
    # {hash: probs already known from the cache when the group was read}).
    pending: List[Tuple[Dict[str, Any], List[str], Dict[str, Dict[str, float]]]] = []

    def request_probs(raw_prompt, resolved):
        """Return the hash of the final prompt; take it from cache or queue it."""
        final = maybe_truncate(ensure_prediction_suffix(raw_prompt), tokenizer)
        ph = prompt_hash(final)
        cached = prob_cache.get(ph)
        if cached is not None:
            resolved[ph] = cached
        elif ph not in queued:
            queued.add(ph)
            batch_prompts.append(final)
            batch_hashes.append(ph)
        return ph

    def flush_batch(fout):
        """Score the queued prompts, then write every pending group."""
        nonlocal processed_groups, missing_label
        fresh = {}
        if batch_prompts:
            # Inference only; guards the whole call against autograd tracking.
            with torch.inference_mode():
                probs_list = batch_boolean_probs(
                    model, tokenizer,
                    batch_prompts, true_tokens, false_tokens
                )
            for ph, probs in zip(batch_hashes, probs_list):
                fresh[ph] = probs
                cache_set(ph, probs)
            batch_prompts.clear()
            batch_hashes.clear()
            queued.clear()

        for group, hashes, resolved in pending:
            group_probs = [resolved[h] if h in resolved else fresh[h] for h in hashes]

            label = None
            if COMPUTE_DELTA and USE_COMPANY_LABEL:
                label = get_label(company_info, group["target_id"])
                if label in (0, 1):
                    # Normalize to a plain int so json.dumps also accepts
                    # labels stored as e.g. numpy integers.
                    label = int(label)
                else:
                    label = None
                    missing_label += 1

            out_obj = _build_group_record(group, group_probs[0], group_probs[1:], label)
            fout.write(json.dumps(out_obj, ensure_ascii=False) + "\n")
            processed_groups += 1
            if processed_groups % PRINT_EVERY == 0:
                print(f"[INFO] Written {processed_groups} groups...")
        pending.clear()
        # Push finished records to disk so a crash does not lose them.
        fout.flush()

    with open(input_path, "r") as fin, open(output_path, "a" if RESUME else "w") as fout:
        for line in tqdm(fin, desc="Groups", disable=not SHOW_PROGRESS):
            line = line.strip()
            if not line:
                continue
            g = json.loads(line)
            total_groups += 1

            gkey = f"{g['target_id']}|{g['hop']}|{g['parent_label']}"
            if gkey in done_keys:
                continue
            done_keys.add(gkey)

            resolved = {}
            hashes = [request_probs(g["base_prompt"], resolved)]
            hashes.extend(
                request_probs(ext["prompt"], resolved)
                for ext in g.get("extended", [])
            )
            pending.append((g, hashes, resolved))

            # Flush when enough prompts are queued, or when nothing is queued
            # at all (every pending group is then already complete).
            if not batch_prompts or len(batch_prompts) >= BATCH_SIZE:
                flush_batch(fout)

        flush_batch(fout)

    print(f"\n[DONE] Total groups in file      : {total_groups}")
    print(f"[DONE] Newly processed groups    : {processed_groups}")
    if COMPUTE_DELTA and USE_COMPANY_LABEL:
        print(f"[INFO] Groups missing label (Δ skipped): {missing_label}")
    print(f"[OUTPUT] {output_path}")

# Entry point: run the batch inference only when this code is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    main()


## 6. Compute text embeddings for each information segment (base and extended prompts), so that the selector can later be trained to judge information gain from semantic content

In [None]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os, json, pickle, hashlib
from collections import OrderedDict
from typing import List, Dict

import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm


# Input prompt groups (one JSON object per line) and the embedding cache
# written by this cell.
PROMPT_JSONL = "../graph_retrieval_data/context/prompts_3hop_groups.jsonl"
SAVE_DIR     = "../graph_retrieval_data/context"
OUT_PKL      = os.path.join(SAVE_DIR, "prompt_emb.pkl")

# Queued texts are embedded once at least this many are waiting. The check
# runs once per group, so a single batch can be larger than this.
BATCH_SIZE   = 1
# Half precision on GPU, full precision on CPU.
DTYPE        = torch.float16 if torch.cuda.is_available() else torch.float32
# NOTE(review): hard-codes GPU index 1; presumably fails on a machine with a
# single GPU — confirm.
DEVICE       = "cuda:1" if torch.cuda.is_available() else "cpu"

# Texts longer than MAX_TOKENS tokens keep their first and last HALF_TOKENS
# tokens (see trim_text_to_max_tokens).
MAX_TOKENS   = 4096
HALF_TOKENS  = MAX_TOKENS // 2

# Load the embedding tokenizer and model once at start-up, move the model to
# DEVICE in DTYPE, and switch it to eval mode.
# NOTE(review): trust_remote_code=True is passed to both loaders; confirm the
# model repository's custom code is trusted.
print("[INFO] loading jinaai/jina-embeddings-v2-base-en …")
tok = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v2-base-en", trust_remote_code=True)
# NOTE(review): torch_dtype is hard-coded to float16 while the .to() call
# below converts to DTYPE; passing DTYPE here would look more consistent —
# confirm.
model = AutoModel.from_pretrained(
    "jinaai/jina-embeddings-v2-base-en",
    torch_dtype=torch.float16,
    trust_remote_code=True
).to(DEVICE, dtype=DTYPE)
model.eval()
print("[INFO] model loaded.")

# Start from the embeddings of a previous run if the cache file exists, so
# only prompts not seen before are embedded. Keys are MD5 hashes of the prompt
# text; values are the embedding vectors.
# NOTE: pickle.load can execute arbitrary code; only load a cache file that
# this script produced.
if os.path.exists(OUT_PKL):
    with open(OUT_PKL, "rb") as f:
        emb_cache: Dict[str, np.ndarray] = pickle.load(f)
    print(f"[INFO] cache loaded: {len(emb_cache)} embeddings")
else:
    emb_cache = OrderedDict()

def md5(s: str) -> str:
    """Return the hex MD5 digest of *s* (UTF-8 encoded); used as a cache key."""
    hasher = hashlib.md5()
    hasher.update(s.encode("utf-8"))
    return hasher.hexdigest()

def trim_text_to_max_tokens(text: str) -> str:
    """Return *text* cut down to fit the embedding model's token budget.

    The text is tokenized without special tokens. If it is longer than
    ``MAX_TOKENS``, only the first and last ``HALF_TOKENS`` tokens are kept.
    The ids are always decoded back to a string, whether or not anything was
    removed.
    """
    token_ids = tok.encode(text, add_special_tokens=False)
    too_long = len(token_ids) > MAX_TOKENS
    kept = token_ids[:HALF_TOKENS] + token_ids[-HALF_TOKENS:] if too_long else token_ids
    return tok.decode(kept, skip_special_tokens=True, clean_up_tokenization_spaces=False)

def encode_batch(texts: List[str]) -> List[np.ndarray]:
    """Embed *texts* and return one L2-normalized float32 vector per text.

    Each text is trimmed with ``trim_text_to_max_tokens``, the batch is
    tokenized with padding (and truncation at ``MAX_TOKENS``), and the hidden
    state at position 0 of every sequence is taken as its embedding.

    NOTE(review): this uses the first-token hidden state as the sentence
    embedding; the jina-embeddings-v2 model card describes mean pooling —
    confirm the choice is intentional before changing it, since cached
    embeddings depend on it.
    """
    trimmed = list(map(trim_text_to_max_tokens, texts))
    with torch.no_grad():
        batch = tok(
            trimmed,
            padding=True,
            truncation=True,
            max_length=MAX_TOKENS,
            return_tensors="pt",
        ).to(DEVICE)
        hidden = model(**batch).last_hidden_state
        first_token = hidden[:, 0, :]
        unit = torch.nn.functional.normalize(first_token, p=2, dim=1)
        # Move to CPU and upcast so the stored arrays are float32 regardless
        # of the model dtype.
        return [row.cpu().float().numpy() for row in unit]

# Texts waiting to be embedded by the next flush(), and their hashes
# (parallel lists).
batch_txt, batch_hash = [], []
# Number of embeddings computed during this run.
new_cnt = 0

def flush():
    """Embed all queued texts, store them in ``emb_cache`` and reset the queue.

    Does nothing when the queue is empty. After embedding, the CUDA allocator
    cache is released.
    """
    global batch_txt, batch_hash, new_cnt
    if batch_txt:
        emb_cache.update(zip(batch_hash, encode_batch(batch_txt)))
        new_cnt += len(batch_txt)
        batch_txt, batch_hash = [], []
        torch.cuda.empty_cache()

# Walk every prompt group, queue each base / extended prompt that has no
# embedding yet, and embed the queue whenever it reaches BATCH_SIZE.
with open(PROMPT_JSONL) as f:
    for line in tqdm(f, desc="Scanning prompts"):
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # crashing in json.loads.
            continue
        obj = json.loads(line)

        # A group without candidates simply contributes its base prompt.
        group_prompts = [obj["base_prompt"]]
        group_prompts.extend(ext["prompt"] for ext in obj.get("extended", []))

        for text in group_prompts:
            h = md5(text)
            # Skip prompts that are already embedded or already queued, so
            # a prompt repeated within a batch is embedded (and counted) once.
            if h in emb_cache or h in batch_hash:
                continue
            batch_txt.append(text)
            batch_hash.append(h)

        if len(batch_txt) >= BATCH_SIZE:
            flush()

# Embed whatever is still queued after the last group.
flush()

# Report how many embeddings were computed in this run and the cache size.
print(f"[INFO] new embeddings added: {new_cnt}; total: {len(emb_cache)}")

# Dump to a temporary file first and then swap it into place with os.replace,
# so an interrupted dump does not leave a truncated OUT_PKL behind.
tmp_out = OUT_PKL + ".tmp"
with open(tmp_out, "wb") as f:
    pickle.dump(emb_cache, f, protocol=pickle.HIGHEST_PROTOCOL)
os.replace(tmp_out, OUT_PKL)
print(f"[DONE] embeddings saved to {OUT_PKL}")