In [None]:
import os
import torch
from torch_geometric.data import HeteroData
from tqdm.notebook import tqdm

In [None]:
def load_num_nodes(pt_folder):
    """Return the summed ``len()`` of every ``.pt`` chunk found in ``pt_folder``.

    Each file whose name ends in ``.pt`` is loaded onto the CPU with
    ``torch.load`` and its length is added to the result. Other files are
    ignored. A folder without any ``.pt`` file yields ``0``.
    """
    chunk_paths = (
        os.path.join(pt_folder, name)
        for name in sorted(os.listdir(pt_folder))
        if name.endswith(".pt")
    )
    # Chunks are loaded one at a time so only a single chunk is held at once.
    return sum(len(torch.load(path, map_location="cpu")) for path in chunk_paths)


def load_edge_index(pt_folder):
    """Load every ``.pt`` chunk in ``pt_folder`` and join them along dim 1.

    Files are visited in sorted filename order and loaded onto the CPU.
    Returns ``None`` when the folder contains no ``.pt`` file; otherwise
    returns the ``torch.cat(..., dim=1)`` of all loaded chunks.
    """
    chunk_paths = [
        os.path.join(pt_folder, name)
        for name in sorted(os.listdir(pt_folder))
        if name.endswith(".pt")
    ]

    # Nothing to concatenate: signal "no edges" to the caller.
    if not chunk_paths:
        return None

    pieces = [torch.load(path, map_location="cpu") for path in chunk_paths]

    # Presumably each chunk is shaped (2, E_i), giving (2, E_total) - confirm
    # against the code that wrote the chunks.
    return torch.cat(pieces, dim=1)


def build_hetero_graph(node_dirs: dict, edge_dirs: dict, out_path: str):
    """
    Assemble a ``HeteroData`` graph from folders of ``.pt`` chunks and save it.

    For every node type the node count comes from ``load_num_nodes``; for
    every edge type the chunks are merged by ``load_edge_index``. An edge type
    whose folder yields no chunks is skipped with a warning. Progress is
    printed along the way and the finished object is written with
    ``torch.save``.

    Args:
        node_dirs: maps node type name -> folder of chunk files, e.g.
            {
                "nasabah": "node_nasabah_pt",
                "pekerja": "node_pekerja_pt",
                ...
            }
        edge_dirs: maps (src_type, relation, dst_type) -> folder of
            edge_index chunk files, e.g.
            {
                ("nasabah", "is_pekerja", "pekerja"): "edge_nasabah_is_pekerja_pt",
                ("nasabah", "memiliki_simp", "simpanan"): "edge_nasabah_memiliki_simp_pt",
                ...
            }
        out_path: file path the resulting ``HeteroData`` is saved to.

    Returns:
        None. The graph is only persisted to ``out_path``.
    """

    data = HeteroData()

    print("=== LOADING NODE COUNTS ===")
    for ntype, folder in node_dirs.items():
        num_nodes = load_num_nodes(folder)
        print(f"{ntype:<12} → {num_nodes:,} nodes")

        data[ntype].num_nodes = num_nodes
        data[ntype].x = None   # placeholder (optional): no node features are loaded here

    print("\n=== LOADING EDGE INDEX ===")
    for (src, rel, dst), folder in edge_dirs.items():
        print(f"{src} -[{rel}]-> {dst}")

        # An edge type that references a node type without a registered node
        # count usually means a typo in the config; report it but keep going.
        for endpoint in dict.fromkeys((src, dst)):
            if endpoint not in node_dirs:
                print(f"WARNING: node type '{endpoint}' has no entry in node_dirs")

        edge_index = load_edge_index(folder)

        if edge_index is None:
            print(f"WARNING: No edges found in {folder}")
            continue

        data[(src, rel, dst)].edge_index = edge_index
        print(f"  Loaded edges: {edge_index.shape[1]:,}")

    print("\nSaving final hetero graph:", out_path)
    torch.save(data, out_path)
    print("Done.")

In [None]:
# Every node type stores its chunks in a folder named "./node_<type>_pt".
node_dirs = {
    ntype: f"./node_{ntype}_pt"
    for ntype in ("nasabah", "pekerja", "simpanan", "pinjaman", "transaksi")
}

# (source type, relation, destination type) -> folder of edge_index chunks.
edge_dirs = {
    # simpanan -> transaksi relations
    ("simpanan", "rek_credit", "transaksi"): "./edge_rek_credit_pt",
    ("simpanan", "rek_debit", "transaksi"): "./edge_rek_debit_pt",
    # relations originating from nasabah
    ("nasabah", "is_pekerja", "pekerja"): "./edge_nasabah_is_pekerja_pt",
    ("nasabah", "memiliki_simp", "simpanan"): "./edge_nasabah_memiliki_simp_pt",
    ("nasabah", "memiliki_pinj", "pinjaman"): "./edge_nasabah_memiliki_pinj_pt",
}

# Build the graph and write it to the working directory.
build_hetero_graph(
    node_dirs=node_dirs,
    edge_dirs=edge_dirs,
    out_path="hetero_graph.pt",
)
