In [5]:
import math
import random
from typing import List, Dict, Set, Tuple
import csv
import numpy as np
import multilayerGM as gm  # 来自仓库


def build_layer_node_sets(
    N: int,
    L: int,
    init_size: int,
    join_rate: float,
    leaving_rate: float,
    *,
    stochastic: str = "binomial",  # "binomial" 或 "expected"
    seed: int = 42,
):
    """
    按速率构造多层节点集合 S_l：
    - 第0层：从 [0..N-1] 中采样 init_size 个节点
    - 每步 l->l+1：
        * 每个上一层节点以 (1 - leaving_rate) 概率存活
        * 从候选池以 join_rate 决定新进入数量，然后采样加入
      forbid_reentry=True 时，候选池为“从未出现过的节点”；否则为“当前不在层内的所有节点”（允许回流）
    返回：
      layer_sets: List[Set[int]]
      stats:      每层的统计信息（规模、存活数、新进入数、Jaccard等）
    """
    assert 0 <= init_size <= N
    assert 0 <= join_rate <= 1 and 0 <= leaving_rate <= 1
    assert L >= 1

    py_rng = random.Random(seed)
    rng = np.random.default_rng(seed)

    universe = list(range(N))
    layer_sets: List[Set[int]] = []

    # 第0层
    S0 = set(py_rng.sample(universe, init_size))
    layer_sets.append(S0)

    seen = set(S0)  # 历史出现过的节点
    stats = [{
        "layer": 0,
        "size": len(S0),
        "survivors": len(S0),  # 定义为与上一层交集，这里等于自身
        "entrants": len(S0),   # 初始层全是新进入
        "jaccard_with_prev": 1.0  # 定义方便，一层时无意义
    }]

    for l in range(1, L):
        prev = layer_sets[-1]
        a = len(prev)

        # 存活采样
        survive_prob = 1.0 - leaving_rate
        survivors = {u for u in prev if rng.random() < survive_prob}
        S_count = len(survivors)
        candidates = [u for u in universe if (u not in survivors)]

        # 进入数
        if stochastic == "binomial":
            n_join = int(rng.binomial(n=len(candidates), p=join_rate))
        elif stochastic == "expected":
            # 以上一层规模 a 为基数，按期望比例进入，再截断到候选池大小
            n_join = int(round(join_rate * a))
        else:
            raise ValueError("stochastic 必须是 'binomial' 或 'expected'")

        n_join = min(n_join, len(candidates))  # 候选不足时截断

        # 抽取进入者并形成本层
        entrants = set(py_rng.sample(candidates, n_join)) if n_join > 0 else set()
        Sl = survivors | entrants

        layer_sets.append(Sl)
        seen |= entrants

        # 统计
        inter = len(prev & Sl)
        union = len(prev | Sl)
        jacc = (inter / union) if union > 0 else 0.0

        stats.append({
            "layer": l,
            "size": len(Sl),
            "survivors": S_count,
            "entrants": n_join,
            "jaccard_with_prev": jacc
        })

    return layer_sets, stats


def sample_layer_labels_inherit(
    layer_sets: List[Set[int]],
    n_sets: int = 4,
    theta: float = 1.0,
    p_stay: float = 0.8,   # 继承概率（0~1）
    seed: int = 42,
) -> Dict[Tuple[int, int], int]:
    """
    基于继承概率 p_stay 的跨层标签过程。
    层0按 Dirichlet 抽；层l>0 对“存活节点”以 p_stay 保留，否则按层l分布重采样；
    “新进入节点”总是按层l分布采样。
    """
    rng = np.random.default_rng(seed)
    partition: Dict[Tuple[int, int], int] = {}

    # 层0：生成分布并采样
    probs0 = rng.dirichlet(alpha=[theta] * n_sets)
    for u in layer_sets[0]:
        partition[(u, 0)] = int(rng.choice(n_sets, p=probs0))

    # 后续各层
    for l in range(1, len(layer_sets)):
        prev_nodes = layer_sets[l - 1]
        cur_nodes  = layer_sets[l]
        probs_l = rng.dirichlet(alpha=[theta] * n_sets)

        survivors = cur_nodes & prev_nodes
        entrants  = cur_nodes - prev_nodes

        # 存活节点：以 p_stay 继承，否则重采样
        for u in survivors:
            if rng.random() < p_stay:
                partition[(u, l)] = partition[(u, l - 1)]
            else:
                partition[(u, l)] = int(rng.choice(n_sets, p=probs_l))

        # 新进入节点：按本层分布采样
        for u in entrants:
            partition[(u, l)] = int(rng.choice(n_sets, p=probs_l))

    return partition


def build_multilayer_network(
    partition: Dict[Tuple[int, int], int],
    mu: float = 0.1,
    k_min: int = 5,
    k_max: int = 70,
    t_k: float = -2.0,
):
    """
    用 MultilayerGM 的 DCSBM 基准模型按给定 partition 生成多层网络
    返回 MultilayerGraph（节点形如 (u, layer)），并带 'mesoset' 属性
    """
    # gm.multilayer_DCSBM_network 会读取节点上的 'mesoset'（我们会在内部设置）
    # 需要把 partition 转换为节点属性字典：
    # 其内部会创建节点并设置 'mesoset'，我们只需传 mapping 即可。
    multinet = gm.multilayer_DCSBM_network(
        partition, mu=mu, k_min=k_min, k_max=k_max, t_k=t_k
    )
    return multinet


def export_edges_csv(
    multinet,
    out_path: str = "edges.csv",
):
    """
    将多层网络的“同层边”导出为 CSV：
    u,v,layer,u_label,v_label
    其中 layer 从 0 开始计数；u_label/v_label 取自节点属性 'mesoset'
    """
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["u", "v", "layer", "u_label", "v_label"])

        # multinet 的节点是 (u, layer) 的元组
        for (u_node, v_node) in multinet.edges():
            # 只导出“同层边”（一般 DCSBM 生成的就是同层边）
            if u_node[1] != v_node[1]:
                continue
            u_phys, layer_u = u_node[0], u_node[1]
            v_phys, layer_v = v_node[0], v_node[1]
            # 取标签
            u_lab = multinet.nodes[u_node].get("mesoset", -1)
            v_lab = multinet.nodes[v_node].get("mesoset", -1)
            writer.writerow([u_phys, v_phys, layer_u, u_lab, v_lab])

In [15]:
# 参数
N = 500
L = 5
init_size = 200
join_rate = 0.8      # 加入概率
leaving_rate = 0.8    # 离开概率

layer_sets, stats = build_layer_node_sets(
    N=N, L=L, init_size=init_size,
    join_rate=join_rate, 
    leaving_rate=leaving_rate,
    stochastic="binomial",    # 或 "expected"
    seed=123
)
print(stats)
# 后续沿用你的代码
partition = sample_layer_labels_inherit(layer_sets, n_sets=10, theta=1.0, seed=123, p_stay=0.2)

# 造图后“裁剪到 partition.keys()”，避免库把整层补满（防扩张）
multinet = build_multilayer_network(partition, mu=0.25, k_min=3, k_max=30, t_k=-2.0)
allowed = set(partition.keys())
to_remove = [n for n in list(multinet.nodes) if n not in allowed]
if to_remove:  # 只保留我们指定的状态节点
    multinet.remove_nodes_from(to_remove)
for n in multinet.nodes:
    multinet.nodes[n]['mesoset'] = partition[n]

export_edges_csv(multinet, out_path="sync_data/test.csv")

[{'layer': 0, 'size': 200, 'survivors': 200, 'entrants': 200, 'jaccard_with_prev': 1.0}, {'layer': 1, 'size': 404, 'survivors': 41, 'entrants': 363, 'jaccard_with_prev': 0.38215102974828374}, {'layer': 2, 'size': 415, 'survivors': 71, 'entrants': 344, 'jaccard_with_prev': 0.7027027027027027}, {'layer': 3, 'size': 415, 'survivors': 80, 'entrants': 335, 'jaccard_with_prev': 0.7291666666666666}, {'layer': 4, 'size': 401, 'survivors': 80, 'entrants': 321, 'jaccard_with_prev': 0.7035490605427975}]
