## 1. Extract the training/validation/test data for the weight generator

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import math, random, pickle
from pathlib import Path
import pandas as pd
import networkx as nx

# Input / output locations (relative to the notebook's working directory).
csv_path   = Path('../new_company.csv')
graph_pkl  = Path('../graph_2022_invest.pkl')

pkl_a      = Path('../test_data_2022_basic.pkl')
pkl_b      = Path('../train_data_2022_basic.pkl')

out_dir    = csv_path.parent / 'stratified_sampling'
out_dir.mkdir(exist_ok=True, parents=True)

df = pd.read_csv(csv_path, dtype={'time': int, 'CompanyID': str})

print('[Info] Loading graph...')
# NOTE: unpickling can execute arbitrary code – only load trusted files.
G = pickle.loads(graph_pkl.read_bytes())
# company id -> node label. Ids are normalised to str because CompanyID is
# read as str above; a non-str node id would otherwise never match and the
# row would be silently dropped by the dropna below.
label_map = {str(d.get('id', n)): d.get('label') for n, d in G.nodes(data=True)}

df['label'] = df['CompanyID'].map(label_map)
# Companies absent from the graph (or without a label) are discarded.
df = df.dropna(subset=['label']).reset_index(drop=True)
df['label'] = df['label'].astype(int)

def stratified_sample(group: pd.DataFrame, frac: float, rng: random.Random):
    """Draw about ``frac`` of ``group`` while preserving its positive ratio.

    A row is positive when ``label == 1``; every other label value counts as
    negative. At least one row is always drawn. Single-class groups are
    sampled uniformly. Each ``DataFrame.sample`` call receives a fresh seed
    drawn from ``rng``, so the result is reproducible for a given ``rng``
    state.
    """
    is_pos = group['label'] == 1
    size = len(group)
    want = max(1, math.ceil(size * frac))
    pos_count = is_pos.sum()

    def _seed():
        return rng.randint(0, 2**32-1)

    # Only one class present: nothing to stratify.
    if pos_count == 0 or pos_count == size:
        return group.sample(n=min(want, size), random_state=_seed())

    want_pos = max(1, round(want * (pos_count / size)))
    want_neg = want - want_pos
    neg_count = size - pos_count

    # Positives are drawn first so the rng seed sequence stays pos -> neg.
    positives = group[is_pos].sample(n=min(want_pos, pos_count), random_state=_seed())
    negatives = group[~is_pos].sample(n=min(want_neg, neg_count), random_state=_seed())
    return pd.concat([positives, negatives])

rng = random.Random(42)

# Fraction drawn per time slice. NOTE(review): the output filename below says
# "10pct" although the fraction is 0.2; the name is kept because later cells
# read that file.
SAMPLE_FRAC = 0.2

# Ids that have basic-info records (test or train).
with pkl_a.open('rb') as f:
    keys_a = set(pickle.load(f).keys())
with pkl_b.open('rb') as f:
    keys_b = set(pickle.load(f).keys())
valid_ids = keys_a | keys_b

# Remove ineligible rows *before* stratifying. Unknown labels (-1) would
# otherwise be counted as negatives. Companies without basic info would be
# drawn and then thrown away, shrinking each slice's sample and skewing its
# positive ratio.
pool = df[(df['label'] != -1) & df['CompanyID'].isin(valid_ids)]

sampled_parts = [stratified_sample(g, SAMPLE_FRAC, rng) for _, g in pool.groupby('time')]
sampled_df = (pd.concat(sampled_parts)
                .sort_values(['time', 'CompanyID'])
                .reset_index(drop=True))

# The full labelled table is written unfiltered.
df.to_csv(out_dir / 'new_company_with_label.csv', index=False)
# NOTE(review): this lands in '../stratified_sampling/', while later cells read
# '../sampled_10pct_temporal_stratified.csv' – confirm the file is moved/copied.
sampled_df.to_csv(out_dir / 'sampled_10pct_temporal_stratified.csv', index=False)


## 2. For each sampled company, gather the information needed to build prompts from three perspectives: (1) companies with similar backgrounds; (2) lead-investor background analysis; (3) graph reasoning-path analysis.

### (1) Companies with similar backgrounds

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pickle, json, numpy as np, pandas as pd, torch
from pathlib import Path
from tqdm import tqdm
from sentence_transformers import SentenceTransformer

csv_path  = Path('../sampled_10pct_temporal_stratified.csv')
test_pkl  = Path('../test_data_2022_basic.pkl')
train_pkl  = Path('../train_data_2022_basic.pkl')
out_path  = Path('../test_to_similar.json')

# Keep only query companies whose time index lies in the inclusive window [50, 190].
df = pd.read_csv(csv_path, dtype={'CompanyID': str})
df = df[df['time'].between(50, 190)]
comp2time = dict(zip(df['CompanyID'], df['time']))

query_ids = list(comp2time)
print(f"Total query companies (time 50-190): {len(query_ids)}")


def _load_pickle(path):
    """Unpickle *path* (trusted local file)."""
    with path.open('rb') as fh:
        return pickle.load(fh)


test_data = _load_pickle(test_pkl)
hist_data = _load_pickle(train_pkl)

# Historical records first; test records override on duplicate ids.
all_data = {**hist_data, **test_data}

# Every windowed company with non-empty basic_info becomes a retrieval candidate.
candidate_keys, candidate_texts, candidate_times = [], [], []
missing_basic = 0
for cid, t in comp2time.items():
    info = all_data.get(cid) or {}
    text = info.get('basic_info')
    if not text:
        missing_basic += 1
        continue
    candidate_keys.append(cid)
    candidate_texts.append(text)
    candidate_times.append(t)

print(f"Candidates w/ basic_info : {len(candidate_keys)}")
if missing_basic:
    print(f"‼  {missing_basic} companies skipped (basic_info missing)")

candidate_times = np.array(candidate_times, dtype=np.int16)

model_name = 'sentence-transformers/all-MiniLM-L6-v2'
# Prefer the second GPU when one exists; a fixed 'cuda:1' fails on hosts with
# a single GPU.
if torch.cuda.is_available():
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'
model  = SentenceTransformer(model_name, device=device)

# np.vstack([]) would fail with an obscure error further down.
if not candidate_texts:
    raise RuntimeError("No candidate companies with basic_info – nothing to encode")

batch_size = 512
cand_embs  = []
for i in tqdm(range(0, len(candidate_texts), batch_size), desc='Encoding candidates'):
    batch = candidate_texts[i:i+batch_size]
    # L2-normalised, so a dot product later equals cosine similarity.
    emb   = model.encode(batch,
                         batch_size=len(batch),
                         convert_to_numpy=True,
                         normalize_embeddings=True,
                         show_progress_bar=False)
    cand_embs.append(emb)
cand_embs = np.vstack(cand_embs).astype('float32')   # [N, dim]

TOP_K = 4   # number of similar companies kept per query

# Built once: per-query list scans (`in`, `.index`, np.array(...)) made the
# search loop O(N^2).
cand_index = {cid: i for i, cid in enumerate(candidate_keys)}
cand_keys_arr = np.array(candidate_keys)

result = {}
for qid in tqdm(query_ids, desc='Searching'):
    q_info = all_data.get(qid)
    if (not q_info) or ('basic_info' not in q_info):
        continue

    q_time = comp2time[qid]
    q_emb  = model.encode(q_info['basic_info'],
                          convert_to_numpy=True,
                          normalize_embeddings=True)

    # Only candidates not later than the query (no look-ahead), never itself.
    mask = candidate_times <= q_time
    self_idx = cand_index.get(qid)
    if self_idx is not None:
        mask[self_idx] = False

    if not mask.any():
        continue

    sims = cand_embs[mask] @ q_emb          # cosine similarity (normalised vectors)
    masked_keys = cand_keys_arr[mask]

    top_k = min(TOP_K, sims.shape[0])
    top_idx = np.argpartition(-sims, top_k - 1)[:top_k]
    top_idx = top_idx[np.argsort(-sims[top_idx])]     # most similar first

    result[qid] = masked_keys[top_idx].tolist()

with out_path.open('w', encoding='utf-8') as f:
    json.dump(result, f, ensure_ascii=False, indent=2)


### (2) Lead investor background analysis

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pickle, json, networkx as nx, pandas as pd
from pathlib import Path

strat_csv     = Path('../sampled_10pct_temporal_stratified.csv')
time_csv      = Path('../company_time_id.csv')
graph_pkl     = Path('../graph_2022.pkl')
out_json_path = Path('../test_company_to_investor.json')

# Query companies: the stratified sample restricted to time in [50, 190].
df_q = pd.read_csv(strat_csv, dtype={'CompanyID': str})
df_q = df_q[(df_q['time'] >= 50) & (df_q['time'] <= 190)]
query_ids      = set(df_q['CompanyID'])
query_time_map = dict(zip(df_q['CompanyID'], df_q['time']))
print(f"Query companies: {len(query_ids)}")

# Generic CompanyID -> time table; the query's own times take precedence.
df_time = pd.read_csv(time_csv, dtype={'CompanyID': str}).drop_duplicates(subset=['CompanyID'])
time_map = dict(zip(df_time['CompanyID'], df_time['time']))
time_map.update(query_time_map)

with graph_pkl.open('rb') as f:
    # NOTE(review): annotated MultiGraph here, but a later cell annotates the
    # same file as MultiDiGraph – confirm whether the graph is directed.
    G: nx.MultiGraph = pickle.load(f)

# real-id → node-idx. Ids are normalised to str because CompanyID is read as
# str above; a non-str node id would otherwise never be found.
id_to_node = {str(attr['id']): n for n, attr in G.nodes(data=True)}

def is_person_id(rid: str) -> bool:
    """Return True when *rid* is a person (investor) id: a str ending in 'P'.

    Non-str ids (e.g. ints stored on graph nodes) are reported as non-person
    instead of raising AttributeError. This matches the helper of the same
    name in the path-prompt cell.
    """
    return isinstance(rid, str) and rid.endswith('P')

def _incident_edges(nd):
    """Yield ``(u, v, data)`` for every edge touching *nd*, in both directions.

    On a directed graph ``G.edges(nd)`` yields only out-edges, so an
    investor -> company edge would be invisible from the company node.
    """
    if G.is_directed():
        yield from G.in_edges(nd, data=True)
        yield from G.out_edges(nd, data=True)
    else:
        yield from G.edges(nd, data=True)


result = {}   # company_id → {'invest_person': ..., 'time': ...}

# Sorted so the JSON key order is reproducible (set order varies per run).
for cid in sorted(query_ids):
    c_time = time_map.get(cid)
    node   = id_to_node.get(cid)
    if (c_time is None) or (node is None):
        result[cid] = {'invest_person': None, 'time': None}
        continue

    best_investor, best_date = None, -1

    for u, v, attr in _incident_edges(node):
        edge_date = attr.get('edge_date')
        if edge_date is None:
            continue
        try:
            edge_date = int(edge_date)
        except (TypeError, ValueError):
            continue   # unparsable date

        if edge_date > c_time:
            continue   # edge lies after the query time – would leak the future

        other = v if u == node else u
        other_id = str(G.nodes[other]['id'])
        if not is_person_id(other_id):
            continue   # only person (investor) neighbours qualify

        # Most recent qualifying edge wins; ties keep the first one seen.
        if edge_date > best_date:
            best_date     = edge_date
            best_investor = other_id

    result[cid] = {'invest_person': best_investor, 'time': int(c_time)}

print(f"Finished. Got investors for {len(result)} companies.")

with out_json_path.open('w', encoding='utf-8') as f:
    json.dump(result, f, ensure_ascii=False, indent=2)


### (3) Graph reasoning path analysis

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
from pathlib import Path
import pandas as pd
import math

strat_csv   = Path('../sampled_10pct_temporal_stratified.csv')
paths_csv   = Path('../company_two_paths.csv')
out_json    = Path('../company_two_paths.json')

# Target companies: sampled ids whose time index is within [50, 190].
df_q = pd.read_csv(strat_csv, dtype={'CompanyID': str})
in_window = df_q['time'].between(50, 190)
df_q = df_q[in_window]
query_ids = set(df_q['CompanyID'])
print(f"Target companies (time in [50,190]): {len(query_ids)}")

df_paths = pd.read_csv(paths_csv, dtype={'CompanyID': str})

# CompanyID -> (raw Path1, raw Path2); a missing column contributes None.
_empty_col = pd.Series([None] * len(df_paths))
path_pairs = zip(df_paths.get('Path1', _empty_col), df_paths.get('Path2', _empty_col))
path_map = dict(zip(df_paths['CompanyID'], path_pairs))

PATH_SEP = '|'

def _is_nan(x) -> bool:
    return x is None or (isinstance(x, float) and math.isnan(x))

def parse_path(path_str):
    if _is_nan(path_str):
        return None
    s = str(path_str).strip()
    if not s:
        return None
    return [tok.strip() for tok in s.split(PATH_SEP) if tok.strip()]

def check_alternating(path_ids):
    """Return True if person ids (ending in 'P') and other ids strictly alternate.

    Paths with fewer than two ids are trivially alternating.
    """
    if not path_ids or len(path_ids) < 2:
        return True
    person_flags = (rid.endswith('P') for rid in path_ids)
    previous = next(person_flags)
    for current in person_flags:
        if current == previous:
            return False
        previous = current
    return True


result = {}

missing_in_paths = 0   # companies with neither path
bad_alt_count = 0      # individual paths breaking person/company alternation
both_ok = 0            # companies that have both paths

for cid in query_ids:
    raw_first, raw_second = path_map.get(cid, (None, None))
    first, second = parse_path(raw_first), parse_path(raw_second)

    if first is None and second is None:
        missing_in_paths += 1

    # Alternation violations are only counted; the paths are still exported.
    bad_alt_count += sum(1 for p in (first, second) if p and not check_alternating(p))

    if first and second:
        both_ok += 1

    result[cid] = {'path1': first, 'path2': second}

print(f"Finished. Companies exported: {len(result)}")
print(f" - Not found (both paths missing): {missing_in_paths}")
print(f" - Alternation rule violations (non-fatal): {bad_alt_count}")
print(f" - Both paths present: {both_ok}")

with out_json.open('w', encoding='utf-8') as f:
    json.dump(result, f, ensure_ascii=False, indent=2)
print(f"Wrote JSON to: {out_json.resolve()}")

## 3. Construct the prompts for the three perspectives; the corresponding predictions are generated from them in Section 4

### Prompt based on similar companies retrieved by text similarity

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import pickle
from pathlib import Path

mapping_json = Path('../qualified_test_to_similar.json')
test_pkl     = Path('../test_data_2022_basic.pkl')
hist_pkl     = Path('../train_data_2022_basic.pkl')
out_pkl      = Path('../text_similar_company_prompt.pkl')

# target company id -> ranked list of similar reference company ids
mapping = json.loads(mapping_json.read_text(encoding='utf-8'))
test_ids = [*mapping]

test_data = pickle.loads(test_pkl.read_bytes())
hist_data = pickle.loads(hist_pkl.read_bytes())

# Historical records first; test records override on duplicate ids.
all_data = dict(hist_data)
all_data.update(test_data)

# Prompt preamble: task description and meaning of the reference labels.
intro = (
    "You are a seasoned venture-capital investor. "
    "A target company has just secured its Series-A financing. "
    "Your task is to predict whether it will obtain a second round of financing, "
    "IPO, or be acquired within the next 12 months. "
    "Below are reference companies (Q) and their outcomes (A). "
    "In the labels, True means the company succeeded within one year; False means it did not.\n"
)

# Prompt tail: instructions and the exact output format. Every literal ends in
# "\n" – adjacent literals are concatenated verbatim, so without them the
# instructions and the "exact" output format collapse into one run-on line.
closing = (
    "### Instructions\n"
    "Based on the reference Q-A pairs above and the target company information,\n"
    "1. First output your **single-word judgment** on whether the company will obtain a second round, IPO, or be acquired within 12 months.\n"
    "   Use exactly one of the following formats:\n"
    "      Prediction: True\n"
    "      Prediction: False\n"
    "\n"
    "2. Immediately after that, provide your explanation covering both **positive factors** and **negative factors** that led you to this judgment.\n"
    "   You may write as many sentences as you feel necessary.\n"
    "\n"
    "### Output Format (exactly)\n"
    "Prediction: True/False\n"
    "Analysis:\n"
    "<your detailed analysis here>"
)

def bool_str(label: int) -> str:
    """Render an outcome label: 'True' when it equals 1, 'False' for anything else."""
    if label == 1:
        return "True"
    return "False"

def _basic_text(record) -> str:
    """Return the stripped ``basic_info`` of *record*; '' when absent or None."""
    return (record.get('basic_info') or '').strip()


prompts = {}

for cid in test_ids:
    # At most four reference companies per target (a None entry is treated as none).
    ref_ids = (mapping.get(cid) or [])[:4]
    qa_blocks = []

    for rid in ref_ids:
        ref_obj = all_data.get(rid)
        if not ref_obj:
            continue
        q = _basic_text(ref_obj)
        # NOTE(review): a missing label is rendered as False.
        a = bool_str(ref_obj.get('label', 0))
        qa_blocks.append(f"Q: {q}\nA: {a}\n")

    target_q = _basic_text(all_data.get(cid) or {})
    target_section = (
        "Below is the target company's basic information:\n"
        f"Q: {target_q}\n"
    )

    prompt = (
        f"{intro}\n"
        + "\n".join(qa_blocks)
        + "\n"
        + target_section
        + "\n"
        + closing
    )

    prompts[cid] = {"input_prompt": prompt}

with out_pkl.open('wb') as f:
    pickle.dump(prompts, f)

print(f"Generated prompts for {len(prompts)} companies → {out_pkl}")


### Prompt for analysing the lead investor's background

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pathlib import Path
import json, pickle, random, sys
from collections import defaultdict

company_json = Path('../qualified_test_company_to_investor.json')
train_pkl    = Path('../train_data_2022_basic.pkl')
test_pkl     = Path('../test_data_2022_basic.pkl')
graph_pkl    = Path('../filtered_graph_2022_updated_labels.pkl')
out_pkl      = Path('../company_prompts_single_investor.pkl')

# company id -> {'invest_person': ..., 'time': ...}
company_map = json.loads(company_json.read_text())

train_data = pickle.loads(train_pkl.read_bytes())
test_data = pickle.loads(test_pkl.read_bytes())

# Test records override training records that share an id.
basic_info_db = dict(train_data)
basic_info_db.update(test_data)

import networkx as nx

with graph_pkl.open("rb") as fh:
    G: nx.MultiDiGraph = pickle.load(fh)

# Real (string) id -> graph node index.
id2index = {}
for node, attrs in G.nodes(data=True):
    id2index[str(attrs["id"])] = node

def label_str(lbl: int) -> str:
    """Map an outcome label to text: 1 -> 'True', 0 -> 'False', else 'Unknown'."""
    if lbl == 1:
        return "True"
    if lbl == 0:
        return "False"
    return "Unknown"


def fetch_basic(cid: str):
    """Return ``(basic_info, label_text)`` for company *cid*.

    Unknown companies yield ``("", "Unknown")``. A stored ``None`` record or
    ``basic_info`` is treated as empty instead of raising on ``.strip()``.
    """
    rec = basic_info_db.get(cid) or {}
    return (rec.get("basic_info") or "").strip(), label_str(rec.get("label"))


def collect_person_history(
    person_nd: int, cutoff_time: int, target_cid: str, top_k: int = 5
):
    """Split an investor's pre-cutoff edges into investments and positions.

    Args:
        person_nd: graph node index of the investor.
        cutoff_time: only edges with ``edge_date < cutoff_time`` are kept.
        target_cid: company id excluded from the history.
        top_k: maximum number of neighbours returned per category.

    Returns:
        ``(invest_nodes, position_nodes)``: at most ``top_k`` neighbour node
        indices each, randomly chosen via ``random.shuffle``. An
        ``edge_type`` of "0" counts as an investment; any other type counts
        as a position.
    """
    invest, position = [], []

    # NOTE(review): G is annotated MultiDiGraph, so this yields only the
    # investor's out-edges – assumes edges are stored person -> company.
    for _, nbr, data in G.edges(person_nd, data=True):
        # edge_date may be missing or stored as a str; comparing a str with the
        # int cutoff would raise TypeError, so parse it and skip bad values.
        try:
            edge_time = int(data.get("edge_date"))
        except (TypeError, ValueError):
            continue
        if edge_time >= cutoff_time:
            continue   # strictly earlier edges only

        cid_nbr = str(G.nodes[nbr]["id"])
        if cid_nbr == target_cid:
            continue   # the target itself must not appear in its own history

        edge_type = str(data.get("edge_type"))
        if edge_type == "0":
            invest.append(nbr)
        else:
            position.append(nbr)

    # Unseeded shuffle: the chosen subset varies between runs.
    random.shuffle(invest)
    random.shuffle(position)
    return invest[:top_k], position[:top_k]


def build_prompt(person_id: str, target_cid: str, target_time: int, target_info: str):
    """Build the lead-investor prompt for *target_cid*.

    Lists earlier investments and board/executive positions of investor
    *person_id* (up to ``collect_person_history``'s default ``top_k`` each,
    restricted to edges before *target_time*). Each entry shows the company's
    basic info and outcome label. The target's info and the instructions
    follow.

    Raises:
        RuntimeError: if the investor is not in the graph, or if no history
            entry has basic info (the prompt would carry no evidence).
    """
    p_nd = id2index.get(person_id)
    if p_nd is None:
        raise RuntimeError("Investor node not found in graph")

    invest_nodes, pos_nodes = collect_person_history(
        p_nd, cutoff_time=target_time, target_cid=target_cid
    )

    def _history_lines(nodes, tag):
        # Number only the entries actually listed (companies without basic
        # info are dropped), so the list reads 1, 2, 3 … without gaps.
        lines = []
        for n in nodes:
            info, lbl = fetch_basic(str(G.nodes[n]["id"]))
            if info:
                lines.append(f"  • {tag} {len(lines) + 1}: {info} (Label: {lbl})")
        return lines

    invest_lines = _history_lines(invest_nodes, "Deal")
    position_lines = _history_lines(pos_nodes, "Role")

    if not invest_lines and not position_lines:
        raise RuntimeError("All history filtered – empty prompt")

    # NOTE(review): unlike the other two prompt builders, this one asks for
    # Analysis before Prediction – confirm the downstream parser handles both.
    return "\n".join(
        [
            "You are an independent investment analysis expert.",
            f"The following is Investor {person_id}'s track record *prior to month {target_time}*.",
            "",
            "=== Investment History ===",
            *(invest_lines or ["  (none)"]),
            "",
            "=== Board / Executive Positions ===",
            *(position_lines or ["  (none)"]),
            "",
            "=" * 60,
            "",
            "Target company that has just closed its Series-A:",
            f"Q: {target_info}",
            "",
            "### Instructions",
            "• Analyse how the investor's experience relates to the target company.",
            "• Discuss positive factors, negative factors, and open questions.",
            "",
            "Output format (exactly):",
            "Analysis:",
            "<your analysis here>",
            "Prediction: True/False",
        ]
    )


company_prompts = {}
fail_cnt = 0

for cid, meta in company_map.items():
    investor_id = meta.get("invest_person")
    # The investor-extraction cell writes {"time": None} for unresolved
    # companies; int(None) would raise TypeError, so treat it as missing (-1).
    raw_time = meta.get("time")
    tgt_time = -1 if raw_time is None else int(raw_time)
    tgt_info, _ = fetch_basic(cid)

    # Skip companies lacking a valid time, an investor or target text.
    if tgt_time < 0 or not investor_id or not tgt_info:
        continue

    try:
        prompt = build_prompt(investor_id, cid, tgt_time, tgt_info)
        company_prompts[cid] = {"input_prompt": prompt}
    except Exception as e:
        # Per-company failures are logged and counted, not fatal.
        fail_cnt += 1
        print(f"[WARN] skip {cid} ({investor_id}): {e}", file=sys.stderr)

print(
    f"\nGenerated prompts for {len(company_prompts)} companies. "
    f"Skipped: {fail_cnt}"
)

out_pkl.parent.mkdir(parents=True, exist_ok=True)
with out_pkl.open("wb") as f:
    pickle.dump(company_prompts, f)

print("Saved →", out_pkl)


### Prompt built from the graph reasoning paths

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from pathlib import Path
import json
import pickle
from typing import List, Optional, Tuple, Dict

import networkx as nx

paths_json = Path("../company_two_paths.json")
train_pkl  = Path("../train_data_2022_basic.pkl")
test_pkl   = Path("../test_data_2022_basic.pkl")
graph_pkl  = Path("../graph_2022.pkl")
out_pkl    = Path("../company_path_prompts_merged.pkl")

# company id -> {"path1": [...], "path2": [...]}
company_paths: Dict[str, Dict[str, List[str]]] = json.loads(paths_json.read_text(encoding="utf-8"))

train_data = pickle.loads(train_pkl.read_bytes())
test_data = pickle.loads(test_pkl.read_bytes())

# Test records override training records that share an id.
basic_info_db = {**train_data, **test_data}

G: nx.MultiDiGraph = pickle.loads(graph_pkl.read_bytes())

# Real (string) id -> graph node index.
id2node = {str(attrs["id"]): node for node, attrs in G.nodes(data=True)}

def is_person_id(rid: str) -> bool:
    """True iff *rid* is a str whose last character is the person marker 'P'."""
    if not isinstance(rid, str):
        return False
    return rid[-1:] == "P"

def label_str(lbl: Optional[int]) -> str:
    """Return 'True' for 1, 'False' for 0 and 'Unknown' for anything else (incl. None)."""
    if lbl == 1:
        text = "True"
    elif lbl == 0:
        text = "False"
    else:
        text = "Unknown"
    return text

def company_display_name(cid: str) -> str:
    """Return the company's recorded name, or *cid* when the name is blank or missing."""
    name = ((basic_info_db.get(cid) or {}).get("name") or "").strip()
    return name if name else cid

def company_profile_and_label(cid: str) -> Tuple[str, str]:
    """Return ``(basic_info text, label text)`` for *cid*; missing data gives ('', 'Unknown')."""
    record = basic_info_db.get(cid) or {}
    text = record.get("basic_info") or ""
    return text.strip(), label_str(record.get("label"))

def person_profile(pid: str) -> str:
    """Describe investor *pid* from its graph-node attributes.

    Returns ``"name — basic_info"`` when both are present, whichever one is
    non-empty otherwise, and ``"Investor <pid>"`` when the node is unknown or
    carries neither.
    """
    fallback = f"Investor {pid}"
    node = id2node.get(pid)
    if node is None:
        return fallback
    attrs = G.nodes[node]
    parts = [(attrs.get(key) or "").strip() for key in ("name", "basic_info")]
    parts = [p for p in parts if p]
    return " — ".join(parts) if parts else fallback

def format_path(ids: Optional[List[str]]) -> str:
    """Render a path as 'a -> b -> c', or '(none)' when it is missing or empty."""
    return " -> ".join(ids) if ids else "(none)"

def unique_in_order(items: List[str]) -> List[str]:
    """Return *items* without duplicates, keeping the first occurrence of each."""
    return list(dict.fromkeys(items))

def collect_entities_both_paths(target_cid: str, p1: Optional[List[str]], p2: Optional[List[str]]):
    """Split the ids on both paths into ``(company_ids, person_ids)``.

    Ids are de-duplicated and keep first-appearance order (path1 before
    path2). The target company is excluded from the company list.
    """
    comp_ids, person_ids = [], []
    comp_seen, pers_seen = set(), set()
    for rid in [*(p1 or []), *(p2 or [])]:
        if is_person_id(rid):
            bucket, seen = person_ids, pers_seen
        elif rid != target_cid:
            bucket, seen = comp_ids, comp_seen
        else:
            continue
        if rid in seen:
            continue
        seen.add(rid)
        bucket.append(rid)
    return comp_ids, person_ids

def build_merged_prompt(target_cid: str, path1: Optional[List[str]], path2: Optional[List[str]]) -> Optional[str]:
    """Build one prompt combining both reasoning paths of *target_cid*.

    The prompt contains:
      * the two paths;
      * the profile and outcome label of every company on them, excluding
        the target;
      * the profiles of the persons on them;
      * the target's own profile.

    Returns:
        The prompt text, or None when neither path is present.
    """
    # Neither path present: nothing to reason over. format_path renders both as
    # "(none)" exactly in this case, so no later re-check is needed.
    if not (path1 or path2):
        return None

    target_name = company_display_name(target_cid)
    target_profile, _ = company_profile_and_label(target_cid)

    path_a = format_path(path1)
    path_b = format_path(path2)

    comp_ids, person_ids = collect_entities_both_paths(target_cid, path1, path2)

    company_profile_lines = []
    label_pairs = []
    for cid in comp_ids:
        info, lbl = company_profile_and_label(cid)
        name = company_display_name(cid)
        if info:
            company_profile_lines.append(f"  • {name} ({cid}): {info} (Outcome Label: {lbl})")
        else:
            company_profile_lines.append(f"  • {name} ({cid}) (Outcome Label: {lbl})")
        label_pairs.append(f"{cid}:{lbl}")
    labels_summary = ", ".join(label_pairs) if label_pairs else "(none)"

    investor_lines = [f"  • {person_profile(pid)} ({pid})" for pid in person_ids]

    tgt_block = (target_profile or f"(No profile text available for {target_name})").strip()

    # NOTE(review): this prompt frames the task as seed/angel -> Series-A within
    # a year, whereas the other two prompt builders describe a company that has
    # just closed its Series-A and ask about a follow-on round / IPO /
    # acquisition. Confirm which framing matches the labels.
    lines = [
        "Role: You are a senior venture-capital analyst who excels at step-by-step reasoning over investment paths to judge whether a seed/angel-stage start-up is likely to secure Series-A funding within the next year.",
        "",
        "You are given the following information blocks:",
        f"(1) High-value investment paths retrieved for {target_name} ({target_cid}):",
        f"    (A) {path_a}",
        f"    (B) {path_b}",
        "(2) Company profiles appearing in the paths (each with outcome labels; True = raised Series A within 12 months after seed/angel, False = did not):",
        *(company_profile_lines or ["  (none)"]),
        f"    Success/Failure: {labels_summary}",
        "(3) Investor profiles appearing in the paths:",
        *(investor_lines or ["  (none)"]),
        "(4) Target company profile:",
        f"    {tgt_block}",
        "",
        "Task:",
        f"• Analyse the evidence and predict whether {target_name} will raise a Series-A round within 12 months.",
        "",
        "Output exactly in the format:",
        "Prediction: True/False",
        "Analysis: <your step-by-step reasoning>",
        "",
        "If evidence is insufficient, reason cautiously but still decide.",
    ]
    return "\n".join(lines)

out: Dict[str, Dict[str, object]] = {}
skipped = 0   # companies for which no prompt could be built (no usable path)

for cid, obj in company_paths.items():
    # Empty lists are normalised to None.
    p1 = obj.get("path1") or None
    p2 = obj.get("path2") or None

    prompt = build_merged_prompt(cid, p1, p2)
    if not prompt:
        skipped += 1
        continue

    out[cid] = {
        "paths": {"path1": p1, "path2": p2},
        "input_prompt": prompt,
    }

# Report the outcome, including how many companies had no usable path.
print(f"Built prompts for {len(out)} companies; skipped {skipped} (no path).")

out_pkl.parent.mkdir(parents=True, exist_ok=True)
with out_pkl.open("wb") as f:
    pickle.dump(out, f)

print("Saved →", out_pkl.resolve())

## 4. With the prompts prepared, call the LLM on them in parallel and collect the responses

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os, pickle, json
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from openai import OpenAI

# Prompt files produced by Section 3, one per perspective.
IN_PKLS = [
    Path(f'../{stem}.pkl')
    for stem in (
        'text_similar_company_prompt',
        'company_path_prompts_merged',
        'company_prompts_single_investor',
    )
]


def out_path(in_path: Path) -> Path:
    """Return the sibling ``<stem>_with_pred.pkl`` that stores predictions for *in_path*."""
    return in_path.with_name(f"{in_path.stem}_with_pred.pkl")

# An unset/empty BASE_URL is passed as None so the client uses its default
# endpoint; the OpenAI client only falls back when base_url is None, not "".
BASE_URL = os.getenv("BASE_URL") or None
# Model name, configurable via the MODEL environment variable.
# NOTE(review): the default is still "" – set MODEL before running.
MODEL    = os.getenv("MODEL", "")

# OPENAI_API_KEY plus OPENAI_API_KEY_2 … OPENAI_API_KEY_8; unset ones are dropped.
api_keys = [os.getenv("OPENAI_API_KEY")] + [os.getenv(f"OPENAI_API_KEY_{i}") for i in range(2, 9)]
api_keys = [k for k in api_keys if k]
if not api_keys:
    # Fail early with a clear message instead of a ZeroDivisionError /
    # ThreadPoolExecutor(max_workers=0) error later on.
    raise RuntimeError("No OPENAI_API_KEY* environment variables are set")

# Two client objects per key; one worker thread per client.
clients = [OpenAI(api_key=k, base_url=BASE_URL) for k in api_keys for _ in range(2)]
MAX_WORKERS = len(clients)

def call_llm(prompt: str, client: OpenAI) -> str:
    """Send *prompt* as a single user message and return the stripped reply.

    Uses ``temperature=0.0``. A reply whose ``content`` is None yields ""
    instead of raising AttributeError on ``.strip()``.
    """
    resp = client.chat.completions.create(
        model=MODEL,
        messages=[{"role": "user", "content": prompt}],
        temperature=0.0,
    )
    return (resp.choices[0].message.content or "").strip()

SAVE_EVERY_N = 50      # checkpoint after this many completed requests


def _save_atomic(obj, path: Path) -> None:
    """Pickle *obj* to *path* through a temp file and ``os.replace``.

    An interrupted write therefore never leaves a truncated checkpoint,
    which would break resuming.
    """
    tmp = path.with_name(path.name + ".tmp")
    with tmp.open('wb') as f:
        pickle.dump(obj, f)
    os.replace(tmp, path)


def process_pkl(in_pkl: Path):
    """Fill in an LLM ``prediction`` for every record of *in_pkl*.

    Records look like ``{cid: {'input_prompt': str, ...}}``. Results are
    written to ``out_path(in_pkl)``. Non-empty predictions already stored
    there are reused, so an interrupted run can resume. Failed calls are
    stored as "" and retried on the next run. Progress is checkpointed every
    ``SAVE_EVERY_N`` completions.
    """
    out_pkl = out_path(in_pkl)
    with in_pkl.open('rb') as f:
        data = pickle.load(f)

    # Resume: carry over non-empty predictions from a previous run.
    if out_pkl.exists():
        with out_pkl.open('rb') as f:
            saved = pickle.load(f)
        for k, v in saved.items():
            if isinstance(v, dict) and v.get('prediction'):
                data.setdefault(k, {}).update({'prediction': v['prediction']})

    tasks = []
    for i, (cid, rec) in enumerate(data.items()):
        if rec.get('prediction'):
            continue
        prompt = (rec.get('input_prompt') or '').strip()
        if not prompt:
            rec['prediction'] = ""
            continue
        # Round-robin records over the client pool.
        tasks.append((cid, prompt, clients[i % MAX_WORKERS]))

    if not tasks:
        print(f"[Skip] {in_pkl.name} ")
        return

    completed = 0
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as exe:
        futures = {exe.submit(call_llm, p, c): cid for cid, p, c in tasks}
        # `data` is only mutated here, on the main thread; workers just return strings.
        with tqdm(total=len(futures), desc=f"LLM predicting ({in_pkl.name})") as pbar:
            for fut in as_completed(futures):
                cid = futures[fut]
                try:
                    data[cid]['prediction'] = fut.result()
                except Exception as e:
                    print(f"[Error] {cid}: {e}")
                    data[cid]['prediction'] = ""   # retried on the next run
                completed += 1
                pbar.update(1)

                if completed % SAVE_EVERY_N == 0:
                    _save_atomic(data, out_pkl)

    _save_atomic(data, out_pkl)
    print(f"[Done] {in_pkl.name} → {out_pkl}")

# Run every prompt file that exists; warn about the missing ones.
for pkl_path in IN_PKLS:
    if pkl_path.exists():
        process_pkl(pkl_path)
    else:
        print(f"[WARN] {pkl_path}")


## 5. Encode all LLM outputs with a sentence-transformer text encoder so they can be used to train the weight generator.

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import pickle, numpy as np
from pathlib import Path
from tqdm import tqdm
from sentence_transformers import SentenceTransformer

IN_PKL   = Path('../combined_predictions.pkl')
EMB_NPZ  = Path('../text_embed.npz')
LBL_NPY  = Path('../labels.npy')

VIEW_KEYS = ['c', 'p', 'pp']      # one key per prediction view
MODEL_NAME = 'sentence-transformers/all-MiniLM-L6-v2'

with IN_PKL.open('rb') as f:
    # assumes {cid: {view_key: {'text': str, 'label': int-like}}} per the indexing below
    raw = pickle.load(f)

company_ids = sorted(raw.keys())
N = len(company_ids)
print(f"Total samples: {N}")

# texts is filled company-major (all views of company 0, then company 1, …),
# which is the order the later reshape to [N, n_views, dim] relies on.
texts = []
# One label column per view, sized from VIEW_KEYS rather than a fixed 3.
labels = np.zeros((N, len(VIEW_KEYS)), dtype=np.int8)

for i, cid in enumerate(company_ids):
    for j, vk in enumerate(VIEW_KEYS):
        view = raw[cid][vk]
        texts.append(view['text'])
        labels[i, j] = int(view['label'])

model = SentenceTransformer(MODEL_NAME)

emb = model.encode(
    texts,
    batch_size=64,
    convert_to_numpy=True,
    normalize_embeddings=True,
    show_progress_bar=True
)

# One embedding per (company, view), in company-major order. The view count
# must come from VIEW_KEYS: a hard-coded 4 with 3 views does not fail, it
# silently re-chunks each company's 3*384 values into 4*288 garbage vectors.
vectors_text = emb.reshape(N, len(VIEW_KEYS), -1).astype('float32')

print("vectors_text shape :", vectors_text.shape)
print("labels shape       :", labels.shape)

np.savez_compressed(
    EMB_NPZ,
    # dtype=str sizes the unicode width to the longest id; a fixed '<U20'
    # would silently truncate longer ids.
    company_ids=np.array(company_ids, dtype=str),
    view_keys=np.array(VIEW_KEYS),
    vectors_text=vectors_text
)

np.save(LBL_NPY, labels)

print(f"Saved embeddings → {EMB_NPZ}")
print(f"Saved labels     → {LBL_NPY}")


## 6. Quantify each company's basic attributes; the weight generator uses these features to produce company-specific weights.

In [None]:
import pandas as pd

# Raw relational tables: companies and their industry, financial, investor and
# employee-history relations.
csv1_path = '../Company.csv'
csv2_path = '../CompanyIndustryRelation.csv'
csv3_path = '../CompanyFinancialRelation.csv'
csv4_path = '../CompanyInvestorRelation.csv'
csv5_path = '../CompanyEmployeeHistoryRelation.csv'

# nrows=0 reads only the header row: df1..df5 are empty DataFrames whose
# .columns expose each table's schema (presumably for manual inspection –
# nothing visible here uses them further).
df1 = pd.read_csv(csv1_path, nrows=0)
df2 = pd.read_csv(csv2_path, nrows=0)
df3 = pd.read_csv(csv3_path, nrows=0)
df4 = pd.read_csv(csv4_path, nrows=0)
df5 = pd.read_csv(csv5_path, nrows=0)

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import pandas as pd, numpy as np, json, torch
from pathlib import Path

ID2JSON = Path('../ID2index.json')
COMP_CSV = Path('../qualified_new_company.csv')
COMP_RAW = Path('../Company.csv')
DEAL_RAW = Path('../Deal.csv')

# CompanyID -> row index in the attribute matrix (values assumed to be ints).
ID2index = json.loads(ID2JSON.read_text())
target_ids = pd.read_csv(COMP_CSV, dtype=str)['CompanyID'].tolist()

target_ids = [cid for cid in target_ids if cid in ID2index]
print(f"Target companies in mapping: {len(target_ids)}")

# Read CompanyID as str so it compares equal to target_ids (also str); a
# numeric-looking id parsed as a number would make isin() silently miss it.
company_df = pd.read_csv(COMP_RAW, dtype={'CompanyID': str})
deal_df    = pd.read_csv(DEAL_RAW, dtype={'CompanyID': str})

# .copy() so the column assignments below act on independent frames rather
# than on slices of the originals (pandas' SettingWithCopy problem).
company_df = company_df[company_df['CompanyID'].isin(target_ids)].copy()
# Only each company's first deal (DealNo == 1).
deal_df    = deal_df[deal_df['CompanyID'].isin(target_ids) & (deal_df['DealNo'] == 1)].copy()

company_df['Index'] = company_df['CompanyID'].map(ID2index)
deal_df['Index']    = deal_df['CompanyID'].map(ID2index)

company_df = company_df.set_index('Index', drop=False)
deal_df    = deal_df.set_index('Index', drop=False)

# Category vocabularies (sorted for a stable one-hot column order).
industry_list   = sorted(company_df['PrimaryIndustryGroup'].dropna().unique())
dealtype_list   = sorted(deal_df['DealType'].dropna().unique())

Industry2idx = {name: i for i, name in enumerate(industry_list)}
DealType2idx = {name: i for i, name in enumerate(dealtype_list)}

print("Industry categories :", len(Industry2idx))
print("DealType categories :", len(DealType2idx))

# Column layout: [0] YearFounded, [1] DealSize, then the industry one-hot,
# then the deal-type one-hot.
dim = 2 + len(Industry2idx) + len(DealType2idx)
# Size the row axis by the largest mapped index (+1) instead of the number of
# ids, so non-contiguous ID2index values cannot index past the end. For a
# contiguous 0..n-1 mapping this equals len(ID2index).
n_rows = max(ID2index.values(), default=-1) + 1
mat = torch.zeros(n_rows, dim, dtype=torch.float32)

# Company features: column 0 = YearFounded (0 if missing) + industry one-hot.
for row_i, company in company_df.iterrows():
    founded = company['YearFounded']
    mat[row_i, 0] = 0 if pd.isna(founded) else founded
    industry_col = Industry2idx.get(company['PrimaryIndustryGroup'])
    if industry_col is not None:
        mat[row_i, 2 + industry_col] = 1

# Deal features: column 1 = DealSize (0 if missing) + deal-type one-hot.
base_dt = 2 + len(Industry2idx)
for row_i, deal in deal_df.iterrows():
    size = deal['DealSize']
    mat[row_i, 1] = 0 if pd.isna(size) else size
    dealtype_col = DealType2idx.get(deal['DealType'])
    if dealtype_col is not None:
        mat[row_i, base_dt + dealtype_col] = 1

# Keep only the target companies' rows, in target_ids order.
row_idx = [ID2index[cid] for cid in target_ids]
attr_matrix = mat[row_idx].numpy()

np.save('company_attr_target.npy', attr_matrix)
np.save('company_id_order.npy', np.array(target_ids))

print("Saved attr matrix :", attr_matrix.shape)
