## 1. Collect reference data for the three views of each test sample

### (1) Retrieve other companies with similar basic information

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pickle, json, numpy as np, pandas as pd, torch
from pathlib import Path
from tqdm import tqdm
from sentence_transformers import SentenceTransformer

csv_path  = Path('../test_data_sampled.csv')
test_pkl  = Path('../test_data_2022_basic.pkl')
train_pkl = Path('../train_data_2022_basic.pkl')
out_path  = Path('../test_to_similar.json')

# Inclusive query window on the `time` column. The log message below is built
# from these constants so it cannot drift from the filter actually applied.
TIME_MIN, TIME_MAX = 177, 192
TOP_K = 4  # number of most-similar reference companies kept per query

df = pd.read_csv(csv_path, dtype={'CompanyID': str})
df = df[(df['time'] >= TIME_MIN) & (df['time'] <= TIME_MAX)]
comp2time = dict(zip(df['CompanyID'], df['time']))

query_ids = list(comp2time.keys())
print(f"Total query companies (time {TIME_MIN}-{TIME_MAX}): {len(query_ids)}")

# NOTE(review): pickle.load executes arbitrary code; only load trusted files.
with test_pkl.open('rb') as f:
    test_data = pickle.load(f)
with train_pkl.open('rb') as f:
    hist_data = pickle.load(f)

# Test-set records take precedence over historical records with the same id.
all_data = {**hist_data, **test_data}

# Every query company with a non-empty `basic_info` is a candidate reference
# for the other queries.
candidate_keys, candidate_texts, candidate_times = [], [], []
missing_basic = 0
for cid, t in comp2time.items():
    info = all_data.get(cid)
    if info and info.get('basic_info'):
        candidate_keys.append(cid)
        candidate_texts.append(info['basic_info'])
        candidate_times.append(t)
    else:
        missing_basic += 1

print(f"Candidates w/ basic_info : {len(candidate_keys)}")
if missing_basic:
    print(f"‼  {missing_basic} companies skipped (basic_info missing)")

candidate_times = np.array(candidate_times, dtype=np.int16)
# Built once: the per-query loop previously rebuilt this array and used
# list.index / list membership, which made the search O(n^2).
candidate_keys_arr = np.array(candidate_keys)
cand_row = {cid: i for i, cid in enumerate(candidate_keys)}

local_model_dir = "../models--sentence-transformers--all-MiniLM-L6-v2"
if torch.cuda.is_available():
    # Prefer the second GPU as originally configured, but fall back to the
    # first one on single-GPU machines instead of failing on an invalid device.
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'

model = SentenceTransformer(
    local_model_dir,
    device=device,
    local_files_only=True,
)

batch_size = 512
if candidate_texts:
    # L2-normalised embeddings: a dot product below is cosine similarity.
    cand_embs = model.encode(
        candidate_texts,
        batch_size=batch_size,
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=True,
    ).astype('float32')
else:
    cand_embs = np.zeros((0, model.get_sentence_embedding_dimension()), dtype='float32')

print(f"Candidate embeddings shape: {cand_embs.shape}")

result = {}
for qid in tqdm(query_ids, desc='Searching'):
    self_idx = cand_row.get(qid)
    if self_idx is None:
        # No usable (non-empty) basic_info -> no embedding to search with.
        continue

    q_time = comp2time[qid]
    # Reuse the embedding already computed for this company as a candidate.
    q_emb = cand_embs[self_idx]

    # Only companies at or before the query's time may serve as references
    # (avoids look-ahead); a company is never its own reference.
    mask = candidate_times <= q_time
    mask[self_idx] = False
    if not mask.any():
        continue

    sims = cand_embs[mask] @ q_emb
    masked_keys = candidate_keys_arr[mask]

    top_k = min(TOP_K, sims.shape[0])
    top_idx = np.argpartition(-sims, top_k - 1)[:top_k]
    top_idx = top_idx[np.argsort(-sims[top_idx])]  # most similar first

    result[qid] = masked_keys[top_idx].tolist()

with out_path.open('w', encoding='utf-8') as f:
    json.dump(result, f, ensure_ascii=False, indent=2)

print(f"\nDone! Mapping saved → {out_path}")


### (2) Identify the lead investor of each target company

In [17]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pickle, json, networkx as nx, pandas as pd
from pathlib import Path

strat_csv     = Path('../test_data_sampled.csv')
time_csv      = Path('../company_time_id.csv')
graph_pkl     = Path('../graph_2022_invest.pkl')
out_json_path = Path('../test_company_to_investor.json')

df_q = pd.read_csv(strat_csv, dtype={'CompanyID': str})
df_q = df_q[(df_q['time'] >= 177) & (df_q['time'] <= 192)]
query_ids      = set(df_q['CompanyID'])
query_time_map = dict(zip(df_q['CompanyID'], df_q['time']))
print(f"Query companies: {len(query_ids)}")

df_time = pd.read_csv(time_csv, dtype={'CompanyID': str}).drop_duplicates(subset=['CompanyID'])
time_map = dict(zip(df_time['CompanyID'], df_time['time']))
time_map.update(query_time_map)  # the sampled query time takes precedence

# NOTE(review): pickle.load executes arbitrary code; only load trusted files.
with graph_pkl.open('rb') as f:
    G: nx.MultiGraph = pickle.load(f)

# real-id -> node-idx. Ids are normalised to str so they match the string
# CompanyIDs read from CSV (the other graph cells key their maps the same way).
id_to_node = {str(attr['id']): n for n, attr in G.nodes(data=True)}


def is_person_id(rid) -> bool:
    """Return True for person ids (suffix 'P'); non-str ids are stringified first."""
    return str(rid).endswith('P')


def _incident_edges(node):
    """Yield (u, v, key, data) for every edge touching *node*.

    On a directed graph ``G.edges(node)`` yields out-edges only, which could
    miss edges pointing at the company, so both directions are visited there.
    For an undirected graph this is exactly ``G.edges(node, keys=True, data=True)``.
    """
    if G.is_directed():
        yield from G.in_edges(node, keys=True, data=True)
        yield from G.out_edges(node, keys=True, data=True)
    else:
        yield from G.edges(node, keys=True, data=True)


result = {}   # company_id -> {'invest_person': ..., 'time': ...}

# Sorted iteration keeps the JSON output identical across runs (set order of
# str depends on PYTHONHASHSEED); key=str guards against non-str ids.
for cid in sorted(query_ids, key=str):
    c_time = time_map.get(cid)
    node   = id_to_node.get(cid)
    if (c_time is None) or (node is None):
        result[cid] = {'invest_person': None, 'time': None}
        continue

    # Keep the person neighbour with the latest edge date not after c_time.
    best_investor, best_date = None, -1

    for u, v, _key, attr in _incident_edges(node):
        edge_date = attr.get('edge_date')
        if edge_date is None:
            continue
        try:
            edge_date = int(edge_date)
        except (TypeError, ValueError):
            continue

        if edge_date > c_time:
            continue  # edge lies in the future relative to the query time

        other = v if u == node else u
        other_id = str(G.nodes[other]['id'])
        if not is_person_id(other_id):
            continue  # only person investors are considered

        if edge_date > best_date:
            best_date     = edge_date
            best_investor = other_id

    result[cid] = {'invest_person': best_investor, 'time': int(c_time)}

print(f"Finished. Got investors for {len(result)} companies.")

with out_json_path.open('w', encoding='utf-8') as f:
    json.dump(result, f, ensure_ascii=False, indent=2)

print(f"Saved → {out_json_path}")


Query companies: 2510
Finished. Got investors for 2510 companies.
Saved → /data/VC_LLM_Agent/multi_agent/test_data_sampled/4_agent/test_company_to_investor.json


### (3) Extract graph reasoning paths

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
from pathlib import Path
import pandas as pd
import math

input_csv = Path('../test_data_sampled.csv')
out_json = Path('../test_data_sampled_path.json')

df_in = pd.read_csv(input_csv, dtype={'CompanyID': str})

# Restrict to the [50, 190] window (inclusive) when a `time` column exists;
# otherwise every row is a target.
# NOTE(review): the other retrieval cells use the window [177, 192]; confirm
# that the wider window here is intentional.
has_time = 'time' in df_in.columns
df_q = (df_in[df_in['time'].between(50, 190)] if has_time else df_in).copy()

query_ids = set(df_q['CompanyID'])
print(f"Target companies (time in [50,190] if available): {len(query_ids)}")

# Missing Path1/Path2 columns are treated as all-None columns.
path1_series, path2_series = (
    df_in.get(col, pd.Series([None] * len(df_in))) for col in ('Path1', 'Path2')
)
# CompanyID -> (raw Path1, raw Path2); a later duplicate id overrides earlier ones.
path_map = {
    company: pair
    for company, pair in zip(df_in['CompanyID'], zip(path1_series, path2_series))
}

PATH_SEP = '|'  # separator between ids inside a Path1/Path2 cell


def _is_nan(x) -> bool:
    """Return True for None or a float NaN (pandas' marker for an empty cell)."""
    if x is None:
        return True
    return isinstance(x, float) and math.isnan(x)


def parse_path(path_str):
    """Split a PATH_SEP-delimited id string into a list of stripped ids.

    Returns None for None/NaN or blank input. Empty tokens are dropped, so a
    string consisting only of separators yields an empty list.
    """
    if _is_nan(path_str):
        return None
    text = str(path_str).strip()
    if not text:
        return None
    tokens = (piece.strip() for piece in text.split(PATH_SEP))
    return [tok for tok in tokens if tok]


def check_alternating(path_ids):
    """Return True if company ids and person ids (suffix 'P') strictly alternate.

    Paths with fewer than two ids are trivially valid. Stops at the first
    pair of neighbouring ids of the same kind.
    """
    if not path_ids or len(path_ids) < 2:
        return True
    kinds = (rid.endswith('P') for rid in path_ids)
    previous = next(kinds)
    for current in kinds:
        if current == previous:
            return False
        previous = current
    return True

result = {}
missing_in_paths = 0
bad_alt_count = 0   # counted per path, so one company can contribute up to 2
both_ok = 0         # companies with both Path1 and Path2 present

# Iterate in sorted order: `query_ids` is a set of str whose iteration order
# depends on PYTHONHASHSEED, so the exported JSON would differ run to run.
# key=str guards against a NaN id mixed in with string ids.
for cid in sorted(query_ids, key=str):
    p1_str, p2_str = path_map.get(cid, (None, None))
    p1 = parse_path(p1_str)
    p2 = parse_path(p2_str)

    if (p1 is None) and (p2 is None):
        missing_in_paths += 1

    # Alternation violations are only reported; the paths are still exported.
    if p1 and not check_alternating(p1):
        bad_alt_count += 1
    if p2 and not check_alternating(p2):
        bad_alt_count += 1
    if p1 and p2:
        both_ok += 1

    result[cid] = {'path1': p1, 'path2': p2}

print(f"Finished. Companies exported: {len(result)}")
print(f" - Not found (both paths missing): {missing_in_paths}")
print(f" - Alternation rule violations (non-fatal): {bad_alt_count}")
print(f" - Both paths present: {both_ok}")

out_json.parent.mkdir(parents=True, exist_ok=True)
with out_json.open('w', encoding='utf-8') as f:
    json.dump(result, f, ensure_ascii=False, indent=2)
print(f"Wrote JSON to: {out_json.resolve()}")

## 2. Constructing three angle prompts

### (1) Build prompts from companies with similar backgrounds

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import pickle
from pathlib import Path

mapping_json = Path('../test_to_similar.json')
test_pkl     = Path('../test_data_2022_basic.pkl')
hist_pkl     = Path('../train_data_2022_basic.pkl')
out_pkl      = Path('../text_similar_company_prompt.pkl')

# test company id -> ordered list of similar reference company ids
mapping = json.loads(mapping_json.read_text(encoding='utf-8'))
test_ids = [*mapping]


def _read_pickle(path):
    """Unpickle a trusted local artefact."""
    with path.open('rb') as fh:
        return pickle.load(fh)


test_data = _read_pickle(test_pkl)
hist_data = _read_pickle(hist_pkl)

# Test records override historical records that share an id.
all_data = {**hist_data, **test_data}

intro = (
    "You are a seasoned venture-capital investor. "
    "A target company has just secured its Series-A financing. "
    "Your task is to predict whether it will obtain a second round of financing, "
    "IPO, or be acquired within the next 12 months. "
    "Below are reference companies (Q) and their outcomes (A). "
    "In the labels, True means the company succeeded within one year; False means it did not.\n"
)

# One list entry per output line. The previous version used adjacent string
# literals, which Python concatenates WITHOUT any separator: the heading, the
# numbered steps and the output-format template were fused into a single line
# ("### InstructionsBased on ...") and the intended blank lines vanished.
closing = "\n".join([
    "### Instructions",
    "Based on the reference Q-A pairs above and the target company information,",
    "1. First output your **single-word judgment** on whether the company will obtain a second round, IPO, or be acquired within 12 months.",
    "   Use exactly one of the following formats:",
    "      Prediction: True",
    "      Prediction: False",
    "",
    "2. Immediately after that, provide your explanation covering both **positive factors** and **negative factors** that led you to this judgment.",
    "   You may write as many sentences as you feel necessary.",
    "",
    "### Output Format (exactly)",
    "Prediction: True/False",
    "Analysis:",
    "<your detailed analysis here>",
])


def bool_str(label: int) -> str:
    """Render an outcome label as the prompt literal: 1 -> "True", anything else -> "False"."""
    return "True" if label == 1 else "False"

prompts = {}   # company id -> {"input_prompt": str}

for cid in test_ids:
    # At most four reference companies, in similarity order.
    ref_ids = (mapping.get(cid) or [])[:4]
    qa_blocks = []

    for rid in ref_ids:
        ref_obj = all_data.get(rid)
        if not ref_obj:
            continue
        # `or ''` also covers a stored None, which .strip() would crash on.
        q = (ref_obj.get('basic_info') or '').strip()
        if not q:
            continue  # a reference without a profile gives nothing to compare
        a = bool_str(ref_obj.get('label', 0))
        qa_blocks.append(f"Q: {q}\nA: {a}\n")

    target_obj = all_data.get(cid) or {}
    target_q   = (target_obj.get('basic_info') or '').strip()
    target_section = (
        "Below is the target company's basic information:\n"
        f"Q: {target_q}\n"
    )

    prompt = (
        f"{intro}\n"
        + "\n".join(qa_blocks)
        + "\n"
        + target_section
        + "\n"
        + closing
    )

    prompts[cid] = {"input_prompt": prompt}

out_pkl.parent.mkdir(parents=True, exist_ok=True)
with out_pkl.open('wb') as f:
    pickle.dump(prompts, f)

print(f"Generated prompts for {len(prompts)} companies → {out_pkl}")


### (2) Build prompts from the lead investor's track record

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from pathlib import Path
import json, pickle, random, sys
from collections import defaultdict

import networkx as nx

company_json = Path('../test_company_to_investor.json')
train_pkl    = Path('../train_data_2022_basic.pkl')
test_pkl     = Path('../test_data_2022_basic.pkl')
graph_pkl    = Path('../graph_2022.pkl')
out_pkl      = Path('../company_prompts_single_investor.pkl')

# company id -> {"invest_person": person id or None, "time": int or None}
company_map = json.loads(company_json.read_text())


def _load_pickle(path: Path):
    """Unpickle a trusted local artefact."""
    with path.open("rb") as fh:
        return pickle.load(fh)


train_data = _load_pickle(train_pkl)
test_data = _load_pickle(test_pkl)
# Test records override training records that share an id.
basic_info_db = {**train_data, **test_data}

G: nx.MultiDiGraph = _load_pickle(graph_pkl)

# Stringified external id -> internal graph node.
id2index = {str(attrs["id"]): node for node, attrs in G.nodes(data=True)}

def label_str(lbl: int) -> str:
    """Render an outcome label: 1 -> "True", 0 -> "False", anything else -> "Unknown"."""
    return "True" if lbl == 1 else ("False" if lbl == 0 else "Unknown")


def fetch_basic(cid: str, db=None):
    """Return ``(basic_info_text, label_string)`` for company *cid*.

    Args:
        cid: company id.
        db: optional id -> record mapping; defaults to the module-level
            ``basic_info_db``.

    Returns:
        The stripped ``basic_info`` ("" when the record, the field, or its value
        is missing/None) and ``label_str`` of the record's ``label``.
    """
    source = basic_info_db if db is None else db
    # `or {}` / `or ""` also cover stored None values, which .get(...).strip()
    # would previously crash on.
    rec = source.get(cid) or {}
    return (rec.get("basic_info") or "").strip(), label_str(rec.get("label"))


def collect_person_history(
    person_nd: int, cutoff_time: int, target_cid: str, top_k: int = 5, rng=None
):
    """Sample a person's past investments and positions strictly before cutoff_time.

    Edges whose ``edge_type`` is "0" count as investments; every other edge
    counts as a board/executive position. Edges are ignored when
    ``edge_date`` is missing or not integer-convertible, is at or after
    ``cutoff_time``, or when they lead to the target company itself. Several
    edges to the same neighbour collapse into one entry (latest qualifying date).

    Args:
        person_nd: graph node of the person.
        cutoff_time: exclusive upper bound on edge dates (no look-ahead).
        target_cid: id of the company being assessed (excluded from history).
        top_k: maximum entries returned per category.
        rng: optional ``random.Random`` for sampling. By default one is seeded
            from (person_nd, target_cid) so the same prompt is produced every run.

    Returns:
        ``(invest_nodes, position_nodes)``: at most ``top_k`` neighbour nodes
        each, ordered oldest first so the prompt's "(chronological)" header holds.
    """
    if rng is None:
        # str seeds are hashed deterministically (independent of PYTHONHASHSEED).
        rng = random.Random(f"{person_nd}|{target_cid}")

    invest_last, position_last = {}, {}

    # NOTE(review): on a directed graph G.edges(n) yields out-edges only; this
    # assumes person -> company edge direction — confirm for graph_2022.pkl.
    for _, nbr, data in G.edges(person_nd, data=True):
        raw_date = data.get("edge_date")
        if raw_date is None:
            continue
        try:
            edge_time = int(raw_date)  # also rejects NaN and non-numeric strings
        except (TypeError, ValueError):
            continue
        if edge_time >= cutoff_time:
            continue

        if str(G.nodes[nbr]["id"]) == target_cid:
            continue

        bucket = invest_last if str(data.get("edge_type")) == "0" else position_last
        bucket[nbr] = max(edge_time, bucket.get(nbr, edge_time))

    def _pick(last_seen):
        """Randomly keep top_k neighbours, then order them by date (oldest first)."""
        nodes = list(last_seen)
        rng.shuffle(nodes)
        return sorted(nodes[:top_k], key=lambda n: (last_seen[n], str(n)))

    return _pick(invest_last), _pick(position_last)


def build_prompt(person_id: str, target_cid: str, target_time: int, target_info: str):
    """Build the lead-investor analysis prompt for one target company.

    Args:
        person_id: external id of the lead investor.
        target_cid: id of the target company (excluded from the history).
        target_time: cutoff; only history strictly before it is shown.
        target_info: the target company's basic information text.

    Returns:
        The prompt string.

    Raises:
        RuntimeError: if the investor is not in the graph, or every history
            entry lacks profile text (the prompt would be empty).
    """
    p_nd = id2index.get(person_id)
    if p_nd is None:
        raise RuntimeError("Investor node not found in graph")

    invest_nodes, pos_nodes = collect_person_history(
        p_nd, cutoff_time=target_time, target_cid=target_cid
    )

    def _describe(nodes, prefix):
        """Render one bullet per node that has profile text."""
        lines = []
        for n in nodes:
            info, lbl = fetch_basic(str(G.nodes[n]["id"]))
            if info:
                # Number only the entries actually shown, so skipped nodes
                # (no profile text) don't leave gaps like "Deal 1, Deal 3".
                lines.append(f"  • {prefix} {len(lines) + 1}: {info} (Label: {lbl})")
        return lines

    invest_lines = _describe(invest_nodes, "Deal")
    position_lines = _describe(pos_nodes, "Role")

    if not invest_lines and not position_lines:
        raise RuntimeError("All history filtered – empty prompt")

    return "\n".join(
        [
            "You are an independent venture-capital analyst.",
            f"Below is the historical track record of **Investor {person_id}**, "
            f"who is the **lead investor** in the target company's just-closed Series-A round.",
            f"Only deals *before month {target_time}* are shown to avoid forward-looking bias.",
            "",
            "=== Investment History (chronological) ===",
            *(invest_lines or ["  (none)"]),
            "",
            "=== Board / Executive Positions ===",
            *(position_lines or ["  (none)"]),
            "",
            "=" * 60,
            "",
            "### Target Company",
            f"{target_info}",
            "",
            "### Task",
            "Using **only** the information above, assess whether the lead investor’s past experience ",
            "suggests that the target company is likely to achieve a successful second fund-raise, IPO, ",
            "or acquisition **within the next 12 months**.",
            "",
            "Please discuss:",
            "• Positive signals from the investor’s track record",
            "• Potential red flags or weaknesses",
            "• Any open questions or uncertainties",
            "",
            "### Output format (exactly)",
            "Analysis:",
            "<your analysis here>",
            "Prediction: True/False",
        ]
    )

company_prompts = {}   # company id -> {"input_prompt": str ("" when not buildable)}
for cid, meta in company_map.items():
    investor_id = meta.get("invest_person")
    # The investor step writes "time": null for unresolved companies; int(None)
    # used to abort the whole loop, so such entries now map to -1 (skipped).
    raw_time = meta.get("time")
    try:
        tgt_time = int(raw_time) if raw_time is not None else -1
    except (TypeError, ValueError):
        tgt_time = -1
    tgt_info, _ = fetch_basic(cid)

    company_prompts[cid] = {"input_prompt": ""}

    if tgt_time < 0 or (not investor_id) or (not tgt_info):
        continue

    try:
        company_prompts[cid]["input_prompt"] = build_prompt(investor_id, cid, tgt_time, tgt_info)
    except Exception as e:  # best effort: log and keep the prompt empty
        print(f"[WARN] skip {cid} ({investor_id}): {e}", file=sys.stderr)

print(f"\nTotal companies: {len(company_prompts)}  "
      f"Non-empty prompts: {sum(bool(v['input_prompt']) for v in company_prompts.values())}")

# Persist the prompts: the LLM step loads this exact file, which was
# previously never written.
out_pkl.parent.mkdir(parents=True, exist_ok=True)
with out_pkl.open("wb") as f:
    pickle.dump(company_prompts, f)
print(f"Saved → {out_pkl}")


### (3) Build prompts from the graph reasoning paths

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from pathlib import Path
import json
import pickle
from typing import List, Optional, Tuple, Dict

import networkx as nx

paths_json = Path("../test_data_sampled_path.json")
train_pkl = Path("../train_data_2022_basic.pkl")
test_pkl = Path("../test_data_2022_basic.pkl")
graph_pkl = Path("../graph_2022.pkl")
out_pkl = Path("../test_data_sampled_prompts_merged.pkl")

# company id -> {"path1": [ids] or None, "path2": [ids] or None}
company_paths: Dict[str, Dict[str, Optional[List[str]]]] = json.loads(
    paths_json.read_text(encoding="utf-8")
)

loaded = []
for pkl_file in (train_pkl, test_pkl):
    with pkl_file.open("rb") as fh:
        loaded.append(pickle.load(fh))
train_data, test_data = loaded

# Test records override training records that share an id.
basic_info_db = {**train_data, **test_data}

with graph_pkl.open("rb") as fh:
    G: nx.MultiDiGraph = pickle.load(fh)

# Stringified external id -> internal graph node.
id2node = {str(attrs["id"]): node for node, attrs in G.nodes(data=True)}

def is_person_id(rid: str) -> bool:
    """Return True when *rid* is a str ending in 'P' (person id); False otherwise."""
    if not isinstance(rid, str):
        return False
    return rid.endswith("P")


def label_str(lbl: Optional[int]) -> str:
    """Render an outcome label: 1 -> "True", 0 -> "False", anything else -> "Unknown"."""
    if lbl == 1:
        return "True"
    if lbl == 0:
        return "False"
    return "Unknown"


def company_display_name(cid: str) -> str:
    """Return the company's stripped ``name`` from basic_info_db, or *cid* if absent/blank."""
    record = basic_info_db.get(cid) or {}
    name = (record.get("name") or "").strip()
    return name if name else cid


def company_profile_and_label(cid: str) -> Tuple[str, str]:
    """Return (stripped basic_info or "", label_str(label)) for company *cid*."""
    record = basic_info_db.get(cid) or {}
    text = (record.get("basic_info") or "").strip()
    return text, label_str(record.get("label"))


def person_profile(pid: str) -> str:
    """Describe a person node as "name — basic_info" (whichever parts exist).

    Falls back to "Investor <pid>" when the node is missing or both
    attributes are empty.
    """
    fallback = f"Investor {pid}"
    node = id2node.get(pid)
    if node is None:
        return fallback
    attrs = G.nodes[node]
    parts = [(attrs.get(key) or "").strip() for key in ("name", "basic_info")]
    parts = [part for part in parts if part]
    return " — ".join(parts) if parts else fallback


def format_path(ids: Optional[List[str]]) -> str:
    """Join path ids with " -> ", or return "(none)" for a missing/empty path."""
    return " -> ".join(ids) if ids else "(none)"


def collect_entities_both_paths(target_cid: str, p1: Optional[List[str]], p2: Optional[List[str]]):
    """List distinct company and person ids along both paths, in first-seen order.

    The target company itself is excluded from the company list.

    Returns:
        (company_ids, person_ids)
    """
    companies, persons = [], []
    seen_companies, seen_persons = set(), set()
    for rid in (p1 or []) + (p2 or []):
        if is_person_id(rid):
            bucket, seen = persons, seen_persons
        elif rid == target_cid:
            continue
        else:
            bucket, seen = companies, seen_companies
        if rid not in seen:
            seen.add(rid)
            bucket.append(rid)
    return companies, persons

def build_merged_prompt(target_cid: str, path1: Optional[List[str]], path2: Optional[List[str]]) -> Optional[str]:
    """Assemble the graph-path analysis prompt for one target company.

    Args:
        target_cid: id of the company to assess.
        path1, path2: retrieved investment paths as id lists (either may be
            None or empty).

    Returns:
        The prompt text, or None when neither path is available.
    """
    if not (path1 or path2):
        return None

    target_name = company_display_name(target_cid)
    target_profile, _ = company_profile_and_label(target_cid)  # target label is never shown

    # At least one path is non-empty here, so at most one of these is "(none)".
    path_a = format_path(path1)
    path_b = format_path(path2)

    comp_ids, person_ids = collect_entities_both_paths(target_cid, path1, path2)

    company_profile_lines = []
    label_pairs = []
    for cid in comp_ids:
        info, lbl = company_profile_and_label(cid)
        name = company_display_name(cid)
        if info:
            company_profile_lines.append(f"  • {name} ({cid}): {info} (Outcome Label: {lbl})")
        else:
            company_profile_lines.append(f"  • {name} ({cid}) (Outcome Label: {lbl})")
        label_pairs.append(f"{cid}:{lbl}")
    labels_summary = ", ".join(label_pairs) if label_pairs else "(none)"

    investor_lines = [f"  • {person_profile(pid)} ({pid})" for pid in person_ids]

    tgt_block = (target_profile or f"(No profile text available for {target_name})").strip()

    # Task and label wording must match the label definition shared by the
    # other two agents' prompts (the company has just closed its Series-A;
    # True = second round / IPO / acquisition within 12 months). The former
    # "seed/angel -> Series A" wording contradicted the labels being shown.
    lines = [
        "Role: You are a senior venture-capital analyst who excels at step-by-step reasoning over investment paths to judge whether a start-up that has just closed its Series-A round is likely to obtain a second round of financing, IPO, or be acquired within the next 12 months.",
        "",
        "You are given the following information blocks:",
        f"(1) High-value investment paths retrieved for {target_name} ({target_cid}):",
        f"    (A) {path_a}",
        f"    (B) {path_b}",
        "(2) Company profiles appearing in the paths (each with outcome labels; True = obtained a second round of financing, IPO, or acquisition within 12 months after its Series-A round, False = did not):",
        *(company_profile_lines or ["  (none)"]),
        f"    Success/Failure: {labels_summary}",
        "(3) Investor profiles appearing in the paths:",
        *(investor_lines or ["  (none)"]),
        "(4) Target company profile:",
        f"    {tgt_block}",
        "",
        "Task:",
        f"• Analyse the evidence and predict whether {target_name} will obtain a second round of financing, IPO, or be acquired within 12 months.",
        "",
        "Output exactly in the format:",
        "Prediction: True/False",
        "Analysis: <your step-by-step reasoning>",
        "",
        "If evidence is insufficient, reason cautiously but still decide.",
    ]
    return "\n".join(lines)

out: Dict[str, Dict[str, object]] = {}
skipped = 0

for cid, entry in company_paths.items():
    # Normalise empty lists to None so both "missing" forms look the same.
    path1 = entry.get("path1") or None
    path2 = entry.get("path2") or None

    prompt = build_merged_prompt(cid, path1, path2)
    if prompt:
        out[cid] = {
            "paths": {"path1": path1, "path2": path2},
            "input_prompt": prompt,
        }
    else:
        skipped += 1

print(f"Built merged prompts for {len(out)} companies; skipped {skipped} without usable content.")

out_pkl.parent.mkdir(parents=True, exist_ok=True)
with out_pkl.open("wb") as fh:
    pickle.dump(out, fh)

print("Saved →", out_pkl.resolve())

## 3. Query the LLM in parallel with the analysis prompts

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os, pickle, json
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from openai import OpenAI

IN_PKLS = [
    Path('../text_similar_company_prompt.pkl'),
    Path('../company_prompts_single_investor.pkl'),
    Path('../test_data_sampled_prompts_merged.pkl'),
]


def out_path(in_path: Path) -> Path:
    """Return the output/checkpoint path ``<stem>_with_pred.pkl`` next to *in_path*."""
    return in_path.with_name(in_path.stem + "_with_pred.pkl")


# Optional custom endpoint. An unset/empty value is passed to the client as
# None so it uses its default endpoint rather than "" as the base URL.
BASE_URL = os.getenv("BASE_URL", "")
# Model name sent with every request, from the LLM_MODEL environment
# variable ("" when unset, which the API will reject — set it before running).
MODEL    = os.getenv("LLM_MODEL", "")

# Up to eight keys: OPENAI_API_KEY, OPENAI_API_KEY_2 ... OPENAI_API_KEY_8.
api_keys = [os.getenv("OPENAI_API_KEY")]
api_keys += [os.getenv(f"OPENAI_API_KEY_{i}") for i in range(2, 9)]
api_keys = [k for k in api_keys if k]

# One client per configured key; requests are distributed round-robin.
clients = [OpenAI(api_key=k, base_url=BASE_URL or None) for k in api_keys]
MAX_WORKERS = 4

def call_llm(prompt: str, client: OpenAI) -> str:
    """Send *prompt* as a single user message (temperature 0) and return the reply text.

    Args:
        prompt: full prompt text.
        client: OpenAI client to issue the request with.

    Returns:
        The stripped text of the first choice.

    Raises:
        RuntimeError: if the response has no choices or no text content, so the
            caller can log a clear error instead of an AttributeError on None.
    """
    resp = client.chat.completions.create(
        model       = MODEL,
        messages    = [{"role": "user", "content": prompt}],
        temperature = 0.0,
    )
    if not resp.choices:
        raise RuntimeError("LLM response contained no choices")
    content = resp.choices[0].message.content
    if content is None:
        raise RuntimeError("LLM response contained no text content")
    return content.strip()

SAVE_EVERY_N = 50  # checkpoint after every N completed requests


def _dump_atomic(obj, path: Path) -> None:
    """Pickle *obj* to *path* through a temp file plus rename.

    The checkpoint is re-loaded on resume, so an interrupted write must never
    leave a truncated file in place.
    """
    tmp = path.with_name(path.name + ".tmp")
    with tmp.open('wb') as f:
        pickle.dump(obj, f)
    os.replace(tmp, path)


def process_pkl(in_pkl: Path):
    """Fill ``prediction`` for every record in *in_pkl* by querying the LLM.

    Resumable: non-empty predictions already in the output file are reused,
    so only missing or failed ones (stored as "") are requested again. Empty
    prompts get ``prediction = ""`` without a request. Progress is
    checkpointed every SAVE_EVERY_N completions and at the end.

    Raises:
        RuntimeError: if requests are needed but no API client is configured.
    """
    out_pkl = out_path(in_pkl)
    with in_pkl.open('rb') as f:
        data = pickle.load(f)

    if out_pkl.exists():
        with out_pkl.open('rb') as f:
            saved = pickle.load(f)
        for k, v in saved.items():
            if isinstance(v, dict) and v.get('prediction'):
                data.setdefault(k, {}).update({'prediction': v['prediction']})

    tasks, skipped_empty = [], 0
    for cid, rec in data.items():
        if rec.get('prediction'):
            continue
        prompt = (rec.get('input_prompt') or "").strip()
        if not prompt:
            rec['prediction'] = ""
            skipped_empty += 1
            continue
        if not clients:
            raise RuntimeError(
                "No OpenAI clients configured: set OPENAI_API_KEY "
                "(optionally OPENAI_API_KEY_2..8)."
            )
        # Round-robin over the clients that actually exist. Indexing by
        # MAX_WORKERS raised IndexError whenever fewer keys were configured.
        client = clients[len(tasks) % len(clients)]
        tasks.append((cid, prompt, client))

    print(f"\n[{in_pkl.name}] empty prompt skipped: {skipped_empty} | to infer: {len(tasks)}")

    if not tasks:
        _dump_atomic(data, out_pkl)
        return

    completed = 0
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as exe:
        futures = {exe.submit(call_llm, p, c): cid for cid, p, c in tasks}
        with tqdm(total=len(futures), desc=f"LLM predicting ({in_pkl.name})") as pbar:
            # Results are consumed in the main thread, so mutating `data` and
            # checkpointing here needs no locking.
            for fut in as_completed(futures):
                cid = futures[fut]
                try:
                    data[cid]['prediction'] = fut.result()
                except Exception as e:  # best effort: record failure, retried on resume
                    print(f"[Error] {cid}: {e}")
                    data[cid]['prediction'] = ""
                completed += 1
                pbar.update(1)

                if completed % SAVE_EVERY_N == 0:
                    _dump_atomic(data, out_pkl)

    _dump_atomic(data, out_pkl)
    print(f"[Done] {in_pkl.name} → {out_pkl}")

# Process each prompt file in turn; missing inputs are reported and skipped.
for pkl_path in IN_PKLS:
    if not pkl_path.exists():
        print(f"[WARN] {pkl_path} not found, skipping")
        continue
    process_pkl(pkl_path)


## 4. Vectorize the LLM outputs

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import pickle, numpy as np, re
from pathlib import Path
from tqdm import tqdm
from sentence_transformers import SentenceTransformer

root = Path('../3_agent_prompt')
IN_PKL  = root / 'combined_predictions.pkl'
EMB_NPZ = root / 'processed_vector/MiniLM_text_embed.npz'
LBL_NPY = root / 'processed_vector/labels.npy'

VIEW_KEYS  = ['c', 'cc', 'pp']
N_VIEWS    = len(VIEW_KEYS)
MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'

# NOTE(review): pickle.loads executes arbitrary code; only load trusted files.
raw = pickle.loads(IN_PKL.read_bytes())
company_ids = sorted(raw.keys())
N = len(company_ids)
print(f"Total samples: {N}")
if N == 0:
    raise ValueError(f"No samples found in {IN_PKL}")

# Every per-view array is sized by N_VIEWS. A hard-coded 4 (with 3 view keys)
# made the (3*N)-row embedding matrix reshape silently into (N, 4, 288)
# vectors mixing different companies, and added an all -1 label column.
empty_mask = np.zeros((N, N_VIEWS), dtype=bool)
texts  = []
labels = np.full((N, N_VIEWS), fill_value=-1, dtype=np.int8)  # -1 = no 0/1 label

for i, cid in enumerate(company_ids):
    rec = raw[cid]
    for j, vk in enumerate(VIEW_KEYS):
        view = rec.get(vk, {}) or {}
        txt  = (view.get('text') or '').strip()
        lbl  = view.get('label')

        # Appended in (company, view) row-major order, matching the reshape below.
        texts.append(txt)
        if lbl in (0, 1):
            labels[i, j] = lbl
        if txt == '':
            empty_mask[i, j] = True

model = SentenceTransformer(MODEL_NAME)
emb = model.encode(
    texts,
    batch_size=64,
    convert_to_numpy=True,
    normalize_embeddings=True,
    show_progress_bar=True,
)

vectors_text = emb.reshape(N, N_VIEWS, -1).astype('float32')

# Empty texts still received an embedding; zero them so they carry no signal.
vectors_text[empty_mask] = 0.0

print("vectors_text shape :", vectors_text.shape)
print("labels shape       :", labels.shape)
print(f"Empty texts count  : {empty_mask.sum()}")

EMB_NPZ.parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(
    EMB_NPZ,
    # dtype=str sizes the unicode width to the longest id; a fixed '<U20'
    # silently truncated longer ids and broke id lookups downstream.
    company_ids=np.array(company_ids, dtype=str),
    view_keys=np.array(VIEW_KEYS),
    vectors_text=vectors_text,
)
np.save(LBL_NPY, labels)

print(f"Saved embeddings → {EMB_NPZ}")
print(f"Saved labels     → {LBL_NPY}")


## 5. Vectorize the test companies' attribute features

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import pandas as pd, numpy as np, json, torch
from pathlib import Path

ID2JSON   = Path('../ID2index.json')
COMP_CSV  = Path('../test_data_sampled.csv')
COMP_RAW  = Path('../Company.csv')
DEAL_RAW  = Path('../Deal.csv')
# Written to '..' like every other artefact in this notebook; the gating cell
# loads them from root = Path('..').
ATTR_NPY_OUT  = Path('../test_sample_company_attr_target.npy')
ORDER_NPY_OUT = Path('../test_sample_company_id_order.npy')

INDUSTRY_DIM = 39  # fixed width of the industry one-hot block

ID2index   = json.loads(ID2JSON.read_text())
target_ids = pd.read_csv(COMP_CSV, dtype=str)['CompanyID'].tolist()
target_ids = [cid for cid in target_ids if cid in ID2index]
print(f"Target companies in mapping: {len(target_ids)}")
target_set = set(target_ids)

# Read CompanyID as str so it matches target_ids / ID2index keys; numeric-looking
# ids would otherwise be parsed as numbers and never match in `isin`.
company_df = pd.read_csv(COMP_RAW, dtype={'CompanyID': str})
deal_df    = pd.read_csv(DEAL_RAW, dtype={'CompanyID': str})
# .copy() so the column assignments below act on real frames, not slices.
company_df = company_df[company_df['CompanyID'].isin(target_set)].copy()
# Only each company's first deal (DealNo == 1) contributes deal features.
deal_df    = deal_df[deal_df['CompanyID'].isin(target_set) & (deal_df['DealNo'] == 1)].copy()

company_df['Index'] = company_df['CompanyID'].map(ID2index)
deal_df['Index']    = deal_df['CompanyID'].map(ID2index)
company_df = company_df.set_index('Index', drop=False)
deal_df    = deal_df.set_index('Index', drop=False)

# NOTE(review): the industry / deal-type vocabularies, and therefore the column
# layout and width `dim`, are derived from this *test* sample. They must match
# the layout used when the gate model was trained (its comp_dim) — confirm
# against the training feature code.
all_industries = sorted(company_df['PrimaryIndustryGroup'].dropna().unique())
industry_list  = all_industries[:INDUSTRY_DIM]  # alphabetical cap; extra groups get no one-hot bit
Industry2idx   = {name: i for i, name in enumerate(industry_list)}

dealtype_list  = sorted(deal_df['DealType'].dropna().unique())
DealType2idx   = {name: i for i, name in enumerate(dealtype_list)}

print(f"Industry categories (capped) : {len(industry_list)}/{INDUSTRY_DIM}")
print(f"DealType categories          : {len(DealType2idx)}")

# Column layout: [YearFounded, DealSize, industry one-hot, deal-type one-hot].
dim = 2 + INDUSTRY_DIM + len(DealType2idx)
# Allocate rows only for the target companies instead of for all of ID2index.
uniq_rows = sorted({ID2index[cid] for cid in target_ids})
slot = {gidx: s for s, gidx in enumerate(uniq_rows)}
mat = np.zeros((len(uniq_rows), dim), dtype=np.float32)

for idx, row in company_df.iterrows():
    s = slot[idx]
    mat[s, 0] = row['YearFounded'] if not pd.isna(row['YearFounded']) else 0
    ind = row['PrimaryIndustryGroup']
    if ind in Industry2idx:
        mat[s, 2 + Industry2idx[ind]] = 1

base_dt = 2 + INDUSTRY_DIM
for idx, row in deal_df.iterrows():
    s = slot[idx]
    mat[s, 1] = row['DealSize'] if not pd.isna(row['DealSize']) else 0
    dt = row['DealType']
    if dt in DealType2idx:
        mat[s, base_dt + DealType2idx[dt]] = 1

# One row per target id, in target_ids order (duplicates repeat their row).
attr_matrix = mat[[slot[ID2index[cid]] for cid in target_ids]].astype('float32')
np.save(ATTR_NPY_OUT, attr_matrix)
np.save(ORDER_NPY_OUT, np.array(target_ids))
print("Saved attr matrix :", attr_matrix.shape)


## 6. Compute each agent's weight from the three agents' outputs and the company feature vector

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from pathlib import Path
import csv
import json
import pickle
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

root = Path("..")
EMB_NPZ = root / "processed_vector" / "MiniLM_text_embed.npz"
LBL_NPY = root / "processed_vector" / "labels.npy"
TEST_CSV = root / "test_data_sampled.csv"

ATTR_NPY = root / "test_sample_company_attr_target.npy"
ATTR_ID_ORDER_OPT = root / "test_sample_company_id_order.npy"  # optional row -> CompanyID order

MODEL_CKPT = Path("./best_gate.pt")
OUT_CSV = root / "test_view_weights.csv"

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 256

npz = np.load(EMB_NPZ, allow_pickle=True)
company_ids_all = npz["company_ids"]
V_text_all = npz["vectors_text"].astype(np.float32)

bool_raw = np.load(LBL_NPY).astype(np.float32)
bool_all = bool_raw[..., None]  # trailing channel axis, concatenated onto the text vectors

N, n_views, D_text = V_text_all.shape
# Explicit checks rather than `assert`, which is stripped under `python -O`.
if n_views != 3:
    raise ValueError(f"Expect 3 views, got {n_views}")
if bool_raw.shape != (N, n_views):
    raise ValueError(
        f"Label array shape {bool_raw.shape} does not match embeddings ({N}, {n_views})"
    )

id2idx = {str(cid): i for i, cid in enumerate(company_ids_all)}

df_test = pd.read_csv(TEST_CSV, dtype={"CompanyID": str})
test_ids = df_test["CompanyID"].astype(str).tolist()

kept = [cid for cid in test_ids if cid in id2idx]
missing = [cid for cid in test_ids if cid not in id2idx]
if missing:
    print(f"[WARN] {len(missing)} test CompanyID(s) not in embeddings; will skip.")

idxs = [id2idx[cid] for cid in kept]

V_text = V_text_all[idxs]
B_bool = bool_all[idxs]
cids_sub = np.array(kept)

attr_mat = np.load(ATTR_NPY)
K = attr_mat.shape[1]

if ATTR_ID_ORDER_OPT.exists():
    # Align attribute rows to cids_sub by id; unknown ids get a zero row.
    attr_ids = np.load(ATTR_ID_ORDER_OPT)
    id2attr = {str(cid): i for i, cid in enumerate(attr_ids)}
    attr_rows = []
    not_found_attr = []
    for cid in cids_sub:
        j = id2attr.get(cid)
        if j is None:
            not_found_attr.append(cid)
            attr_rows.append(np.zeros((K,), dtype=np.float32))
        else:
            attr_rows.append(attr_mat[j])
    if not_found_attr:
        print(f"[WARN] {len(not_found_attr)} companies missing in ATTR id order; filled zeros.")
    C_feat = np.stack(attr_rows, axis=0).astype(np.float32)
else:
    # No id order available: rows are assumed to be in cids_sub order already
    # (positional alignment) — truncate or zero-pad to fit.
    if attr_mat.shape[0] != len(cids_sub):
        print(f"[WARN] ATTR rows ({attr_mat.shape[0]}) != #test companies ({len(cids_sub)}). "
              f"Will truncate/pad with zeros as needed.")
    if attr_mat.shape[0] >= len(cids_sub):
        C_feat = attr_mat[:len(cids_sub)].astype(np.float32)
    else:
        pad = np.zeros((len(cids_sub) - attr_mat.shape[0], K), dtype=np.float32)
        C_feat = np.vstack([attr_mat, pad]).astype(np.float32)

class MoEGate(nn.Module):
    """Softmax gate that pools several per-view vectors and scores the result.

    Each view vector is the view's text embedding concatenated with that
    view's boolean flag, scaled by a learnable factor ``exp(log_wb)``.
    ``gate`` sees all view vectors, flattened, plus the company features ``C``
    and emits one softmax weight per view. The weighted sum of the view
    vectors is scored by ``clf`` to give a single logit per sample.

    The attribute names ``log_wb``, ``gate`` and ``clf`` are the state-dict
    keys used when loading the trained checkpoint, so they must stay as-is.
    """

    def __init__(self, text_dim=384, comp_dim=1, n_views=3):
        super().__init__()
        self.n_views = n_views
        # Log of the boolean-flag scale; exp(1.1) is roughly 3.0 at init.
        self.log_wb = nn.Parameter(torch.tensor(1.1))

        per_view_dim = text_dim + 1  # text embedding + one scaled flag
        gate_in_dim = n_views * per_view_dim + comp_dim

        self.gate = nn.Sequential(
            nn.Linear(gate_in_dim, 128),
            nn.ReLU(),
            nn.Linear(128, n_views),
        )
        self.clf = nn.Sequential(
            nn.Linear(per_view_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, V_text, B_bool, C):
        """Return ``(logit, alpha)`` for a batch.

        V_text is (batch, n_views, text_dim) and B_bool is (batch, n_views, 1);
        both are implied by the layer sizes. C is (batch, comp_dim).
        ``logit`` is (batch,), and ``alpha`` holds the (batch, n_views) view weights.
        """
        views = torch.cat([V_text, self.log_wb.exp() * B_bool], dim=-1)
        batch = views.size(0)

        gate_input = torch.cat([views.reshape(batch, -1), C], dim=-1)
        alpha = self.gate(gate_input).softmax(dim=-1)

        pooled = (alpha[..., None] * views).sum(dim=1)
        logit = self.clf(pooled).squeeze(-1)
        return logit, alpha

# Rebuild the gate with the dimensions of the loaded data, restore the trained
# weights onto DEVICE and switch to inference mode.
state = torch.load(MODEL_CKPT, map_location=DEVICE)
model = MoEGate(text_dim=D_text, comp_dim=K, n_views=3)
model.to(DEVICE)
model.load_state_dict(state)
model.eval()

class InferDS(Dataset):
    """Map-style dataset yielding ``(V, B, C, company_id)`` per row.

    The three numpy arrays are converted once to float32 tensors (via
    ``torch.from_numpy(...).float()``), and ``ids`` is materialised as a list.
    """

    def __init__(self, V, B, C, ids):
        def as_float_tensor(arr):
            return torch.from_numpy(arr).float()

        self.V, self.B, self.C = (as_float_tensor(a) for a in (V, B, C))
        self.ids = list(ids)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i):
        return (self.V[i], self.B[i], self.C[i], self.ids[i])

loader = DataLoader(
    InferDS(V_text, B_bool, C_feat, cids_sub),
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# One output row per company: the gate weight of each view (w1..w3) and the
# sigmoid probability of the final logit.
rows = []
with torch.no_grad():
    for v_batch, b_batch, c_batch, batch_ids in loader:
        logits, alpha = model(v_batch.to(DEVICE), b_batch.to(DEVICE), c_batch.to(DEVICE))
        prob_np = torch.sigmoid(logits).cpu().numpy()
        alpha_np = alpha.cpu().numpy()
        for cid, p, wi in zip(batch_ids, prob_np, alpha_np):
            w1, w2, w3 = (float(x) for x in wi[:3])
            rows.append({"CompanyID": cid, "w1": w1, "w2": w2, "w3": w3, "prob": float(p)})

df_out = pd.DataFrame(rows, columns=["CompanyID", "w1", "w2", "w3", "prob"])
OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
df_out.to_csv(OUT_CSV, index=False, encoding="utf-8")
print(f"Saved gating weights → {OUT_CSV.resolve()}")

## 7. Build the Manager Agent prompts from the three agents' outputs and the per-sample view weights obtained above

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from pathlib import Path
import pickle, json, re
from typing import Dict, Any
import numpy as np
import pandas as pd

ROOT = Path("..")

# Agent outputs, loaded below in this order as: similar-company, lead-investor,
# path. Path objects are used throughout, matching the first cell, because the
# loaders call .read_bytes()/.open() and the save step uses .parent.
IN_PKLS = [
    ROOT / "text_similar_company_prompt.pkl",
    ROOT / "company_prompts_single_investor.pkl",
    ROOT / "test_data_sampled_prompts_merged.pkl",
]

# Per-company view weights (w1, w2, w3) produced by the gating step.
WEIGHT_CSV = ROOT / "test_view_weights.csv"

TRAIN_BASIC_PKL = ROOT / "train_data_2022_basic.pkl"
TEST_BASIC_PKL  = ROOT / "test_data_2022_basic.pkl"

OUT_PKL = ROOT / "manager_prompts.pkl"

def norm_pred(x: Any) -> str:
    """Map a raw prediction value to "True", "False" or "Unknown".

    The value is stringified, stripped and lower-cased. Then:
      * "1"/"true"/"yes"/"y"/"positive" -> "True";
      * "0"/"false"/"no"/"n"/"negative" -> "False";
      * text containing "prediction: true|false" -> that verdict;
      * anything else (including None) -> "Unknown".
    """
    if x is None:
        return "Unknown"
    token = str(x).strip().lower()
    if token in ("1", "true", "yes", "y", "positive"):
        return "True"
    if token in ("0", "false", "no", "n", "negative"):
        return "False"
    hit = re.search(r"\b(prediction)\s*:\s*(true|false)\b", token, re.I)
    return hit.group(2).capitalize() if hit else "Unknown"

PRED_RE = re.compile(r"(?im)^\s*prediction\s*:\s*(true|false)\s*$")
ANAL_RE = re.compile(r"(?is)\banalysis\s*:\s*(.+)\Z")

def parse_pred_analysis_from_text(txt: str) -> Dict[str,str]:
    """Extract the Prediction / Analysis fields from a block of LLM output text.

    ``prediction`` is "True"/"False" when a line consists solely of
    ``Prediction: true|false`` (case-insensitive), otherwise "Unknown".
    ``analysis`` is everything after the first ``Analysis:`` marker, stripped,
    or "" when there is no marker. Empty/None input yields the defaults.
    """
    result = {"prediction": "Unknown", "analysis": ""}
    if not txt:
        return result

    pred_hit = PRED_RE.search(txt)
    if pred_hit:
        result["prediction"] = pred_hit.group(1).capitalize()

    ana_hit = ANAL_RE.search(txt)
    if ana_hit:
        result["analysis"] = ana_hit.group(1).strip()

    return result

def lower_keys(d: Dict[str, Any]) -> Dict[str, Any]:
    """Return a shallow copy of *d* with string keys lower-cased.

    Non-string keys are kept unchanged. If two keys collide after lowering,
    the later one's value wins.
    """
    lowered: Dict[str, Any] = {}
    for key, value in d.items():
        new_key = key.lower() if isinstance(key, str) else key
        lowered[new_key] = value
    return lowered

def load_agent_pkl(pkl_path: Path) -> Dict[str, Dict[str, str]]:
    """Load one agent's output pickle into ``{CompanyID: {"prediction", "analysis"}}``.

    Args:
        pkl_path: path to a pickle holding a dict keyed by CompanyID; a plain
            ``str`` is accepted too.

    Each value may be:
      * a ``str`` of raw LLM text, parsed for ``Prediction:``/``Analysis:``;
      * a ``dict`` with (case-insensitive) ``prediction``/``analysis`` keys,
        where prediction is normalised via ``norm_pred``;
      * a ``dict`` with an ``output`` string, parsed like raw text;
      * anything else, which becomes ``{"prediction": "Unknown", "analysis": ""}``.

    Returns:
        Mapping with CompanyIDs stringified.

    Raises:
        ValueError: if the pickle does not contain a dict.
    """
    pkl_path = Path(pkl_path)  # IN_PKLS may hold plain strings
    # NOTE: unpickling executes code from the file; only use trusted inputs.
    data = pickle.loads(pkl_path.read_bytes())
    if not isinstance(data, dict):
        raise ValueError(f"{pkl_path} should be a dict keyed by CompanyID")

    unknown = {"prediction": "Unknown", "analysis": ""}
    out: Dict[str, Dict[str, str]] = {}
    for cid, v in data.items():
        cid_str = str(cid)

        if isinstance(v, str):
            out[cid_str] = parse_pred_analysis_from_text(v)
            continue
        if not isinstance(v, dict):
            out[cid_str] = dict(unknown)
            continue

        lv = lower_keys(v)
        pred = lv.get("prediction")
        ana = lv.get("analysis")
        if pred is not None or ana is not None:
            out[cid_str] = {
                "prediction": norm_pred(pred),
                "analysis": ("" if ana is None else str(ana)).strip(),
            }
            continue

        # Fall back to a raw "output" string, if present.
        raw = lv.get("output")
        if isinstance(raw, str):
            out[cid_str] = parse_pred_analysis_from_text(raw)
        else:
            out[cid_str] = dict(unknown)

    return out

def load_basic_db(train_path: "str | Path | None" = None,
                  test_path: "str | Path | None" = None) -> Dict[str, Dict[str, Any]]:
    """Load and merge the train/test company "basic" pickles.

    Args:
        train_path: train pickle; defaults to ``TRAIN_BASIC_PKL``.
        test_path: test pickle; defaults to ``TEST_BASIC_PKL``.

    Returns:
        ``{**train_db, **test_db}``, so on duplicate CompanyIDs the test
        record wins.

    Both paths are wrapped in ``Path`` so plain-string constants also work.
    """
    train_path = Path(TRAIN_BASIC_PKL if train_path is None else train_path)
    test_path = Path(TEST_BASIC_PKL if test_path is None else test_path)
    # NOTE: pickle.load executes code from the file; only use trusted inputs.
    with train_path.open("rb") as f:
        train_db = pickle.load(f)
    with test_path.open("rb") as f:
        test_db = pickle.load(f)
    return {**train_db, **test_db}

def company_profile(db: Dict[str, Dict[str, Any]], cid: str) -> str:
    """Return the stripped ``basic_info`` text for *cid*, or "" if missing/empty."""
    text = db.get(cid, {}).get("basic_info")
    if not text:
        return ""
    return text.strip()

def company_name(db: Dict[str, Dict[str, Any]], cid: str) -> str:
    """Return the stripped ``name`` for *cid*, falling back to *cid* itself when blank."""
    name = (db.get(cid, {}).get("name") or "").strip()
    if name:
        return name
    return cid

wdf = pd.read_csv(WEIGHT_CSV, dtype={"CompanyID": str})
# Case-insensitive column lookup: lower-cased header -> header as written in the CSV.
cols = dict((c.lower(), c) for c in wdf.columns)
col_cid = cols.get("companyid", "CompanyID")
w1c, w2c, w3c = (cols.get(name, name) for name in ("w1", "w2", "w3"))

def normalize_row(row):
    """Return the row's (w1, w2, w3) weights rescaled to sum to 1.

    If the sum is not positive (including NaN), equal 1/3 weights are returned.
    """
    raw = np.asarray([row[w1c], row[w2c], row[w3c]], dtype=float)
    total = raw.sum()
    if total > 0:
        return (raw / total).astype(float)
    return np.full(3, 1 / 3, dtype=float)

weights_map = dict((str(r[col_cid]), normalize_row(r)) for _, r in wdf.iterrows())

# Per-agent verdicts keyed by CompanyID, in IN_PKLS order:
# similar-company, lead-investor, path. Path(...) guards against IN_PKLS
# holding plain strings, which have no .read_bytes().
sim_map  = load_agent_pkl(Path(IN_PKLS[0]))
inv_map  = load_agent_pkl(Path(IN_PKLS[1]))
path_map = load_agent_pkl(Path(IN_PKLS[2]))

basic_db = load_basic_db()

def build_manager_prompt(cid: str, w_path: float, w_sim: float, w_inv: float) -> str:
    """Assemble the Manager-Agent prompt for company *cid*.

    The prompt combines the path, similar-company and lead-investor agents'
    verdicts (missing entries fall back to "Unknown" / empty analysis), the
    three view weights (formatted to 3 decimals) and the target's profile
    text. It asks for a final ``Prediction: True/False`` + ``Analysis:`` answer.
    """
    def verdict(agent_map):
        rec = agent_map.get(cid, {})
        return rec.get("prediction", "Unknown"), rec.get("analysis", "")

    def expert_section(title, pred, ana):
        return [title, f"• Prediction: {pred}", f"• Analysis  : {ana}", ""]

    path_pred, path_ana = verdict(path_map)
    sim_pred, sim_ana = verdict(sim_map)
    inv_pred, inv_ana = verdict(inv_map)

    tgt_name = company_name(basic_db, cid)
    tgt_prof = company_profile(basic_db, cid) or f"(No profile text available for {tgt_name})"
    weights_text = f"[Path={w_path:.3f}, Similar-company={w_sim:.3f}, Lead-investor={w_inv:.3f}]"

    lines = [
        "Role:",
        "You are a senior venture-capital analyst who excels at synthesizing other experts' viewpoints to decide whether a seed/angel-stage start-up will secure Series-A funding within the next year.",
        "",
        "You are given:",
        "",
    ]
    lines += expert_section("(1) Path-analyst verdict", path_pred, path_ana)
    lines += expert_section("(2) Similar-company analyst verdict", sim_pred, sim_ana)
    lines += expert_section("(3) Lead-investor analyst verdict", inv_pred, inv_ana)
    lines += [
        "(4) Aggregate-weight advice",
        "The historical importance of the three perspectives is",
        weights_text,
        "",
        "(5) Target company profile",
        tgt_prof,
        "",
        "Task:",
        "• Produce a single, final prediction on whether the target will raise a Series-A round within 12 months.",
        "• Output **exactly** in the format:",
        "",
        "Prediction: True/False",
        "Analysis: <your step-by-step reasoning>",
        "",
        "• If evidence is insufficient, reason cautiously but still decide.",
    ]
    return "\n".join(lines)

out = {}
skipped = 0

for cid, w_vec in weights_map.items():
    # Weights are mapped here in CSV column order w1, w2, w3 -> path, similar, investor.
    w_path, w_sim, w_inv = float(w_vec[0]), float(w_vec[1]), float(w_vec[2])

    try:
        prompt = build_manager_prompt(cid, w_path, w_sim, w_inv)
        path_rec = path_map.get(cid, {})
        sim_rec = sim_map.get(cid, {})
        inv_rec = inv_map.get(cid, {})
        out[cid] = {
            "input_prompt": prompt,
            "weights": [w_path, w_sim, w_inv],
            "meta": {
                "path_pred": path_rec.get("prediction", "Unknown"),
                "path_analysis": path_rec.get("analysis", ""),
                "sim_pred":  sim_rec.get("prediction", "Unknown"),
                "sim_analysis": sim_rec.get("analysis", ""),
                "inv_pred":  inv_rec.get("prediction", "Unknown"),
                "inv_analysis": inv_rec.get("analysis", ""),
                "target_profile": company_profile(basic_db, cid),
            },
        }
    except Exception as e:  # best effort: one bad record must not stop the batch
        skipped += 1
        print(f"[WARN] skip {cid}: {e}")

print(f"Built manager prompts: {len(out)}  (skipped {skipped})")

# OUT_PKL may be a plain string, so wrap it to get .parent/.open/.resolve.
manager_out_pkl = Path(OUT_PKL)
manager_out_pkl.parent.mkdir(parents=True, exist_ok=True)
with manager_out_pkl.open("wb") as f:
    pickle.dump(out, f)

print("Saved →", manager_out_pkl.resolve())

## 8. Run the Manager Agent with parallel LLM calls to obtain the final predictions

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import re
import pickle
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from openai import OpenAI

IN_PKL  = Path("../manager_prompts.pkl")
OUT_PKL = IN_PKL.with_name(IN_PKL.stem + "_with_pred.pkl")

BASE_URL = os.getenv("BASE_URL", "")
MODEL    = os.getenv("LLM_MODEL", "")
# Fail fast: with an empty model name every request errors, and the run would
# silently record "Unknown" for every company.
if not MODEL:
    raise RuntimeError("No model name found in env (LLM_MODEL).")

# Up to eight keys: OPENAI_API_KEY, OPENAI_API_KEY_2 ... OPENAI_API_KEY_8.
API_KEY_ENV_VARS = ["OPENAI_API_KEY"] + [f"OPENAI_API_KEY_{n}" for n in range(2, 9)]
api_keys = [k for k in (os.getenv(name) for name in API_KEY_ENV_VARS) if k]
if not api_keys:
    raise RuntimeError("No API key found in env (OPENAI_API_KEY, ...).")

# Pass None (not "") when BASE_URL is unset. The openai SDK falls back to
# OPENAI_BASE_URL or its default endpoint only when base_url is None; an empty
# string would be used verbatim as the request prefix.
clients = [OpenAI(api_key=k, base_url=BASE_URL or None) for k in api_keys]
MAX_WORKERS = min(8, len(clients))

PRED_RE = re.compile(r"(?im)^\s*prediction\s*:\s*(true|false)\s*$")
ANAL_RE = re.compile(r"(?is)\banalysis\s*:\s*(.+)\Z")

def parse_prediction_analysis(text: str):
    """Split a Manager-Agent reply into ``(prediction, analysis)``.

    ``prediction`` is "True"/"False" when some line is exactly
    ``Prediction: true|false`` (case-insensitive), else "Unknown".
    ``analysis`` is the stripped text after the first ``Analysis:`` marker,
    or "" when absent. Empty/None input gives ``("Unknown", "")``.
    """
    if not text:
        return "Unknown", ""
    pred_match = PRED_RE.search(text)
    ana_match = ANAL_RE.search(text)
    prediction = pred_match.group(1).capitalize() if pred_match else "Unknown"
    analysis = ana_match.group(1).strip() if ana_match else ""
    return prediction, analysis

def call_llm(prompt: str, client: "OpenAI", model: "str | None" = None) -> str:
    """Send *prompt* as a single user message and return the stripped reply.

    Args:
        prompt: full Manager-Agent prompt.
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        model: chat model name; defaults to the module-level ``MODEL``.

    Returns:
        The reply text with surrounding whitespace removed.

    Raises:
        ValueError: if the response carries no text content
            (``message.content`` is None). This replaces an opaque
            ``AttributeError`` from ``None.strip()``.
    """
    resp = client.chat.completions.create(
        model       = MODEL if model is None else model,
        messages    = [{"role": "user", "content": prompt}],
        temperature = 0.0,  # deterministic-as-possible decoding
    )
    content = resp.choices[0].message.content
    if content is None:
        raise ValueError("LLM response contained no text content")
    return content.strip()

SAVE_EVERY_N = 50

def main():
    if not IN_PKL.exists():
        raise FileNotFoundError(f"Input PKL not found: {IN_PKL}")

    with IN_PKL.open("rb") as f:
        data = pickle.load(f) 

    if OUT_PKL.exists():
        with OUT_PKL.open("rb") as f:
            saved = pickle.load(f)
        for cid, rec in saved.items():
            if isinstance(rec, dict) and rec.get("prediction"):
                data.setdefault(cid, {}).update({
                    "prediction": rec.get("prediction"),
                    "manager_prediction": rec.get("manager_prediction"),
                    "manager_analysis": rec.get("manager_analysis"),
                })

    tasks = []
    skipped_empty = 0
    already_done  = 0
    for i, (cid, rec) in enumerate(data.items()):
        if rec.get("prediction"):        
            already_done += 1
            continue
        prompt = (rec.get("input_prompt") or "").strip()
        if not prompt:
            data[cid]["prediction"] = ""
            data[cid]["manager_prediction"] = "Unknown"
            data[cid]["manager_analysis"] = ""
            skipped_empty += 1
            continue
        client = clients[i % max(1, len(clients))]
        tasks.append((cid, prompt, client))

    print(f"\n[{IN_PKL.name}] already_done: {already_done} | empty skipped: {skipped_empty} | to infer: {len(tasks)}")

    if tasks:
        completed = 0
        with ThreadPoolExecutor(max_workers=MAX_WORKERS) as exe:
            futures = {exe.submit(call_llm, p, c): cid for cid, p, c in tasks}
            pbar = tqdm(total=len(futures), desc=f"LLM predicting ({IN_PKL.name})")
            for fut in as_completed(futures):
                cid = futures[fut]
                try:
                    raw = fut.result()
                    pred, ana = parse_prediction_analysis(raw)
                    data[cid]["prediction"] = raw                  
                    data[cid]["manager_prediction"] = pred        
                    data[cid]["manager_analysis"] = ana           
                except Exception as e:
                    print(f"[Error] {cid}: {e}")
                    data[cid]["prediction"] = ""
                    data[cid]["manager_prediction"] = "Unknown"
                    data[cid]["manager_analysis"] = ""
                completed += 1
                pbar.update(1)
                if completed % SAVE_EVERY_N == 0:
                    with OUT_PKL.open("wb") as f:
                        pickle.dump(data, f)
            pbar.close()

    with OUT_PKL.open("wb") as f:
        pickle.dump(data, f)
    print(f"[Done] {IN_PKL.name} → {OUT_PKL}")

if __name__ == "__main__":
    main()