In [None]:
# -*- coding: utf-8 -*-
"""
ABM: Demographic-aware Network Evolution (Mesa 3.3.0 / Python 3.12)
- 新增：断点续跑（checkpoint / resume）
- 新增：导出节点属性 nodes.csv、最终边 edges.csv、周期边快照 edges_step_XXXX.parquet
- 仍使用 Mesa 3.x 新接口：model.agents.shuffle_do("step")（不引用 mesa.time）
- 带心跳输出

运行:
  # 全新开始
  python abm_demographic_network_mesa3_resume.py --run

  # 断点续跑（会从 OUTPUT_DIR 下的最近 checkpoint 继续）
  python abm_demographic_network_mesa3_resume.py --resume

  # 自定义输出目录、步数等
  python abm_demographic_network_mesa3_resume.py --run --out ./abm_out --steps 2000 --N 6000
"""

import os
import sys
import time
import math
import random
import argparse
import gzip
import pickle
from dataclasses import dataclass, asdict
from typing import Dict, Any, List, Tuple

import numpy as np
import pandas as pd
import networkx as nx
import mesa  # Mesa 3.x

# -----------------------------
# 配置
# -----------------------------
@dataclass
class SimConfig:
    """All tunable parameters of one simulation run.

    A single module-level instance, ``CONFIG``, is read throughout this file;
    the CLI entry point and the checkpoint loader overwrite its fields in place.
    """

    # --- population size and run length ---
    N: int = 4_000                      # number of agents (= graph nodes)
    STEPS: int = 1_500                  # hard cap on simulation steps
    SEED: int = 42                      # seed passed to the model's RNGs

    # --- behavioural hyper-parameters ---
    CANDIDATES_PER_AGENT: int = 6       # random candidate partners sampled per connection attempt
    MAX_NEW_EDGES_PER_STEP: int = 2     # cap on edges one agent adds in the formation phase of its step
    EDGE_DECAY_P: float = 0.002         # chance an agent drops a given incident edge during its step
    REWIRE_P: float = 0.25              # chance that a dropped edge triggers a rewiring attempt

    W_HOMOPHILY: float = 1.2            # weight of the homophily term in the link score
    W_PREF_ATTACH: float = 0.9          # weight of log(1 + degree) (preferential attachment)
    W_RANDOM_NOISE: float = 0.2         # weight of the uniform noise term

    HOMO_NATIONALITY: float = 1.0       # homophily bonus when nationalities match
    HOMO_RELIGION: float = 0.8          # ... when religions match
    HOMO_GENDER: float = 0.5            # ... when genders match
    HOMO_AGE_BIN: float = 0.7           # ... when age bins match

    # --- run control ---
    EARLY_STOP_WINDOW: int = 50         # stop once the edge count has been flat for this many steps
    HEARTBEAT_EVERY: int = 50           # progress-print interval, in steps

    # --- persistence ---
    OUTPUT_DIR: str = "./abm_output"    # where checkpoints and exports are written
    CHECKPOINT_EVERY: int = 100         # checkpoint interval, in steps
    SNAPSHOT_EVERY: int = 0             # edge-snapshot interval in steps (0 = disabled), e.g. 200

CONFIG = SimConfig()

# Demographic attribute space; every attribute is drawn uniformly from its list.
NATIONALITIES = ["A", "B", "C", "D"]
RELIGIONS = ["X", "Y", "Z"]
GENDERS = ["M", "F"]
AGE_BINS = ["18-29", "30-44", "45-64", "65+"]

def sample_demographics(rng: random.Random) -> Dict[str, Any]:
    """Draw one random demographic profile.

    Each attribute is picked uniformly with ``rng.choice`` in a fixed order
    (nationality, religion, gender, age_bin), so a seeded ``rng`` always yields
    the same profile and consumes the same amount of randomness.
    """
    attribute_space = (
        ("nationality", NATIONALITIES),
        ("religion", RELIGIONS),
        ("gender", GENDERS),
        ("age_bin", AGE_BINS),
    )
    return {key: rng.choice(options) for key, options in attribute_space}

def homophily_score(a: Dict[str, Any], b: Dict[str, Any]) -> float:
    """Similarity bonus between two demographic dicts.

    Adds ``CONFIG.HOMO_<ATTR>`` for every attribute (nationality, religion,
    gender, age_bin) on which ``a`` and ``b`` agree; returns 0.0 when none match.
    """
    weighted_attrs = (
        ("nationality", CONFIG.HOMO_NATIONALITY),
        ("religion", CONFIG.HOMO_RELIGION),
        ("gender", CONFIG.HOMO_GENDER),
        ("age_bin", CONFIG.HOMO_AGE_BIN),
    )
    total = 0.0
    # Accumulate left-to-right in this fixed order so float results are reproducible.
    for key, weight in weighted_attrs:
        if a[key] == b[key]:
            total += weight
    return total

# -----------------------------
# Agent / Model
# -----------------------------
class Person(mesa.Agent):
    """An individual in the society: one graph node with fixed demographics.

    On each activation the agent (1) may drop some of its current edges, each
    drop optionally followed by a rewiring attempt, and (2) scores a few random
    candidates and links to the best of them with a logistic acceptance
    probability.
    """

    # Scoring weights used by ``_try_connect`` (the rewiring path). They differ
    # from the CONFIG.W_* weights used in ``step``; they are named here instead
    # of being inline magic numbers, with the original values preserved.
    REWIRE_W_HOMOPHILY: float = 1.2
    REWIRE_W_PREF_ATTACH: float = 1.0
    REWIRE_W_NOISE: float = 0.1

    def __init__(self, model: "SocietyModel", attrs: Dict[str, Any]):
        # Mesa 3.x assigns ``unique_id`` and registers the agent with the model.
        super().__init__(model)
        # Demographic dict (nationality / religion / gender / age_bin keys, as
        # produced by sample_demographics).
        self.attrs = attrs

    @property
    def uid(self) -> int:
        """Alias for ``unique_id``; also this agent's node id in ``model.G``."""
        return self.unique_id

    def step(self):
        """One activation: edge decay (+ optional rewire), then new-edge formation."""
        G = self.model.G
        rng = self.model._rng

        # 1) Edge decay: each incident edge is dropped with prob EDGE_DECAY_P during
        #    this activation; every drop triggers a rewiring attempt with prob
        #    REWIRE_P. Iterate over a snapshot because the adjacency changes here.
        if G.degree(self.uid) > 0:
            neighbors = list(G.neighbors(self.uid))
            for nb in neighbors:
                if rng.random() < CONFIG.EDGE_DECAY_P:
                    G.remove_edge(self.uid, nb)
                    if rng.random() < CONFIG.REWIRE_P:
                        self._try_connect(rewire=True)

        # 2) New edges: score sampled candidates by homophily + preferential
        #    attachment + noise, then try them best-first.
        new_edges = 0
        candidates = self._sample_candidates()
        deg = G.degree
        scored: List[Tuple[float, int]] = []
        for cid in candidates:
            if cid == self.uid or G.has_edge(self.uid, cid):
                continue
            score = (
                CONFIG.W_HOMOPHILY * homophily_score(self.attrs, self.model.people[cid].attrs)
                + CONFIG.W_PREF_ATTACH * math.log1p(deg(cid))
                + CONFIG.W_RANDOM_NOISE * rng.random()
            )
            scored.append((score, cid))
        scored.sort(reverse=True)

        for sc, cid in scored:
            if new_edges >= CONFIG.MAX_NEW_EDGES_PER_STEP:
                break
            # Logistic acceptance probability. With non-negative weights every
            # score is >= 0, so p >= 0.5.
            p = 1 / (1 + math.exp(-sc))
            if rng.random() < p:
                G.add_edge(self.uid, cid)
                new_edges += 1

    def _sample_candidates(self) -> List[int]:
        """Draw up to CONFIG.CANDIDATES_PER_AGENT distinct agent ids uniformly at random.

        Assumes agent ids are exactly 1..N, because ids are drawn with
        ``randrange(1, N + 1)``. The result may contain this agent's own id;
        callers skip it. The sample size is capped at N. Without the cap, a
        population smaller than CANDIDATES_PER_AGENT could never fill the set
        and the loop would spin forever (and N == 0 would make randrange raise).
        """
        rng = self.model._rng
        N = len(self.model.people)
        k = min(CONFIG.CANDIDATES_PER_AGENT, N)
        picks = set()
        while len(picks) < k:
            picks.add(rng.randrange(1, N + 1))
        return list(picks)

    def _try_connect(self, rewire: bool = False):
        """Link to the single best-scoring sampled candidate, if any is eligible.

        Unlike ``step`` there is no acceptance draw: the argmax candidate is
        always connected. ``rewire`` is currently unused; it is kept so existing
        call sites remain valid.
        """
        G = self.model.G
        rng = self.model._rng
        candidates = self._sample_candidates()
        best = None
        best_score = -math.inf
        for cid in candidates:
            if cid == self.uid or G.has_edge(self.uid, cid):
                continue
            score = (
                self.REWIRE_W_HOMOPHILY * homophily_score(self.attrs, self.model.people[cid].attrs)
                + self.REWIRE_W_PREF_ATTACH * math.log1p(G.degree(cid))
                + self.REWIRE_W_NOISE * rng.random()
            )
            if score > best_score:
                best_score = score
                best = cid
        if best is not None:
            G.add_edge(self.uid, best)

class SocietyModel(mesa.Model):
    """Mesa model: a population of ``Person`` agents evolving an undirected social graph.

    Main state:
      * ``G``            - networkx graph, one node per agent (node attrs = demographics)
      * ``people``       - unique_id -> Person lookup used by agents when scoring candidates
      * ``metrics_rows`` - one dict of graph statistics per completed step
    Supports checkpoint/resume (gzip-compressed pickle) and CSV/parquet exports.
    """

    def __init__(self, n_agents: int, seed: int | None = None):
        super().__init__(seed=seed)
        # Seed as passed in; stored in checkpoints for provenance.
        self._init_seed = seed
        # Behavioural RNG shared by all agents. Mesa's own ``self.random`` is a
        # separate generator; both states are saved in checkpoints.
        self._rng = random.Random(seed)
        self.G = nx.Graph()
        self.people: Dict[int, Person] = {}
        self.metrics_rows: List[Dict[str, Any]] = []
        # Edge counts over the last EARLY_STOP_WINDOW steps (see _early_stop).
        self._last_edge_counts: List[int] = []
        self.running = True

        # Create the population: one agent and one graph node per person.
        for _ in range(n_agents):
            attrs = sample_demographics(self._rng)
            p = Person(self, attrs)
            self.people[p.unique_id] = p
            self.G.add_node(p.unique_id, **attrs)

        self._seed_random_edges(p0=0.0005)

    def _seed_random_edges(self, p0: float = 0.0005):
        """Add sparse initial edges: try N * CANDIDATES_PER_AGENT random pairs, keeping each with prob ``p0``."""
        ids = list(self.people.keys())
        rng = self._rng
        trials = int(len(ids) * CONFIG.CANDIDATES_PER_AGENT)
        for _ in range(trials):
            a = rng.choice(ids)
            b = rng.choice(ids)
            if a != b and not self.G.has_edge(a, b) and rng.random() < p0:
                self.G.add_edge(a, b)

    def step(self):
        """Advance one tick: activate every agent once in shuffled order, record metrics, check early stop."""
        self.agents.shuffle_do("step")
        self._collect_metrics()
        if self._early_stop():
            self.running = False

    def _collect_metrics(self):
        """Append one row of graph statistics, print a heartbeat, and update the early-stop window."""
        G = self.G
        n = G.number_of_nodes()
        m = G.number_of_edges()
        avg_deg = (2 * m / n) if n > 0 else 0.0
        # NOTE(review): average_clustering runs every step and grows costly as the
        # graph densifies; consider computing it less often if runs are slow.
        cc = nx.average_clustering(G) if n > 1 else 0.0

        row = dict(
            step=len(self.metrics_rows) + 1,
            num_nodes=n,
            num_edges=m,
            avg_degree=avg_deg,
            clustering=cc,
        )
        self.metrics_rows.append(row)

        if row["step"] % CONFIG.HEARTBEAT_EVERY == 0 or row["step"] <= 3:
            print(f"[heartbeat] step={row['step']:>5} | edges={m:>8} | "
                  f"avg_deg={avg_deg:5.2f} | clustering={cc:5.3f}",
                  flush=True)

        # Keep only the most recent EARLY_STOP_WINDOW edge counts.
        self._last_edge_counts.append(m)
        if len(self._last_edge_counts) > CONFIG.EARLY_STOP_WINDOW:
            self._last_edge_counts.pop(0)

    def _early_stop(self) -> bool:
        """True once the edge count has been identical for a full EARLY_STOP_WINDOW steps."""
        w = CONFIG.EARLY_STOP_WINDOW
        if len(self._last_edge_counts) < w:
            return False
        return (max(self._last_edge_counts) - min(self._last_edge_counts)) == 0

    # ========= Persistence =========
    def save_checkpoint(self, out_dir: str, tag: str | None = None):
        """Serialize the model state into ``out_dir`` as a gzip-compressed pickle.

        The file is ``checkpoint_step_XXXXXX.pkl.gz``, where the step is the number
        of completed steps, or ``checkpoint_<tag>.pkl.gz`` when ``tag`` is given.
        Data is first written to ``<file>.tmp`` and then moved into place with
        ``os.replace``. An interrupt mid-write therefore never leaves a truncated
        checkpoint for ``--resume`` to pick up.
        """
        os.makedirs(out_dir, exist_ok=True)
        step = len(self.metrics_rows)
        fname = f"checkpoint_step_{step:06d}.pkl.gz" if tag is None else f"checkpoint_{tag}.pkl.gz"
        path = os.path.join(out_dir, fname)
        tmp_path = path + ".tmp"

        # Serialized content: graph, node attributes, metrics, RNG states, config.
        data = {
            "version": "v1",
            "config": asdict(CONFIG),
            "seed": self._init_seed,  # previously a bound method (self.random.seed) by mistake
            "rng_state": self._rng.getstate(),
            # Mesa's RNG (drives the shuffle_do activation order); restoring it
            # keeps a resumed run on the same trajectory as an uninterrupted one.
            "mesa_random_state": self.random.getstate(),
            "nodes": {n: dict(self.G.nodes[n]) for n in self.G.nodes()},
            "edges": list(self.G.edges()),
            "metrics_rows": self.metrics_rows,
            "last_edge_counts": self._last_edge_counts,
            "step": step,
        }
        try:
            with gzip.open(tmp_path, "wb") as f:
                pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
            os.replace(tmp_path, path)
        except BaseException:
            # Clean up the partial temp file (also on Ctrl-C), then propagate.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
        print(f"[checkpoint] saved → {path}", flush=True)

    @staticmethod
    def load_from_checkpoint(path: str) -> "SocietyModel":
        """Rebuild a model from a checkpoint written by ``save_checkpoint``.

        Side effect: overwrites the global CONFIG with the config stored in the
        checkpoint.

        SECURITY: uses ``pickle.load``, which can execute arbitrary code. Only load
        checkpoint files you produced yourself.
        """
        with gzip.open(path, "rb") as f:
            data = pickle.load(f)

        cfg = data.get("config", None)
        if cfg is not None:
            # Replace the global CONFIG with the checkpoint's configuration.
            for k, v in cfg.items():
                setattr(CONFIG, k, v)

        nodes = data["nodes"]
        edges = data["edges"]
        metrics_rows = data["metrics_rows"]
        last_edge_counts = data["last_edge_counts"]
        rng_state = data["rng_state"]

        # Rebuild with the same number of agents. The constructor's random
        # population and seed edges are replaced below.
        model = SocietyModel(n_agents=len(nodes), seed=CONFIG.SEED)
        model._rng.setstate(rng_state)
        mesa_state = data.get("mesa_random_state")
        if mesa_state is not None:  # absent in checkpoints written by older versions
            model.random.setstate(mesa_state)

        # Rebuild the graph exactly as saved.
        model.G.clear()
        for n, attrs in nodes.items():
            model.G.add_node(n, **attrs)
        model.G.add_edges_from(edges)

        # Reuse the agents the constructor already registered with Mesa, giving
        # each one the id and demographics of one checkpointed node. Creating
        # additional Person objects here would register a second set of agents
        # in ``model.agents``. Every node would then be stepped twice per tick,
        # once with the constructor's re-sampled (wrong) attributes.
        agents = list(model.people.values())
        model.people.clear()
        for agent, (n, attrs) in zip(agents, nodes.items()):
            agent.unique_id = n
            agent.attrs = dict(attrs)
            model.people[n] = agent

        model.metrics_rows = metrics_rows
        model._last_edge_counts = last_edge_counts
        model.running = True
        return model

    # ========= Export =========
    def export_results(self, out_dir: str):
        """Write metrics.csv, nodes.csv (attributes + degree) and edges.csv (final edge list) to ``out_dir``."""
        os.makedirs(out_dir, exist_ok=True)
        metrics_df = pd.DataFrame(self.metrics_rows)
        metrics_df.to_csv(os.path.join(out_dir, "metrics.csv"), index=False)

        # Node attributes + degree.
        degs = dict(self.G.degree())
        rows = []
        for n, attrs in self.G.nodes(data=True):
            r = {"node": n, "degree": degs.get(n, 0)}
            r.update(attrs)
            rows.append(r)
        pd.DataFrame(rows).to_csv(os.path.join(out_dir, "nodes.csv"), index=False)

        # Final edge set.
        edges_df = pd.DataFrame(list(self.G.edges()), columns=["source", "target"])
        edges_df.to_csv(os.path.join(out_dir, "edges.csv"), index=False)

        print(f"[done] metrics.csv / nodes.csv / edges.csv saved to: {out_dir}", flush=True)

    def export_edge_snapshot(self, out_dir: str):
        """Export the current edge list every SNAPSHOT_EVERY steps (no-op when it is <= 0).

        Writes ``edges_step_XXXXXX.parquet``. If pandas has no parquet engine
        installed, it writes ``edges_step_XXXXXX.csv.gz`` instead of aborting a
        long-running simulation.
        """
        if CONFIG.SNAPSHOT_EVERY <= 0:
            return
        step = len(self.metrics_rows)
        if step % CONFIG.SNAPSHOT_EVERY == 0:
            os.makedirs(out_dir, exist_ok=True)
            df = pd.DataFrame(list(self.G.edges()), columns=["source", "target"])
            path = os.path.join(out_dir, f"edges_step_{step:06d}.parquet")
            try:
                df.to_parquet(path, index=False)
            except ImportError:
                # No pyarrow/fastparquet: fall back to compressed CSV (gzip inferred from suffix).
                path = os.path.join(out_dir, f"edges_step_{step:06d}.csv.gz")
                df.to_csv(path, index=False)
            print(f"[snapshot] edges → {path}", flush=True)

# -----------------------------
# 运行 & 续跑
# -----------------------------
def latest_checkpoint(out_dir: str) -> str | None:
    if not os.path.isdir(out_dir):
        return None
    cks = [f for f in os.listdir(out_dir) if f.startswith("checkpoint_") and f.endswith(".pkl.gz")]
    if not cks:
        # 兼容老命名
        cks = [f for f in os.listdir(out_dir) if f.endswith(".pkl.gz")]
    if not cks:
        return None
    cks.sort()
    return os.path.join(out_dir, cks[-1])

def main(argv: List[str] | None = None):
    """CLI entry point: start fresh (or resume), run the simulation loop, then export.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``. Unrecognized
            arguments are reported and ignored instead of aborting. This lets the
            function also run inside a Jupyter kernel, whose ``sys.argv`` carries
            the kernel's own ``-f <connection file>`` option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run", action="store_true", help="start a fresh run")
    parser.add_argument("--resume", action="store_true", help="resume from latest checkpoint")
    parser.add_argument("--out", type=str, default=CONFIG.OUTPUT_DIR, help="output directory")
    parser.add_argument("--N", type=int, default=None, help="number of agents")
    parser.add_argument("--steps", type=int, default=None, help="max steps")
    args, unknown = parser.parse_known_args(argv)
    if unknown:
        print(f"[args] ignoring unrecognized arguments: {unknown}", flush=True)

    def apply_run_overrides() -> None:
        # Settings the command line always controls, even when resuming.
        if args.steps is not None:
            CONFIG.STEPS = args.steps
        CONFIG.OUTPUT_DIR = args.out

    # Optional overrides of the global config.
    if args.N is not None:
        CONFIG.N = args.N
    apply_run_overrides()
    os.makedirs(CONFIG.OUTPUT_DIR, exist_ok=True)

    # Decide whether to resume. Without --resume a fresh model is built
    # (--run is accepted for readability; fresh start is the default).
    model = None
    start_step = 0
    if args.resume:
        ck = latest_checkpoint(CONFIG.OUTPUT_DIR)
        if ck:
            print(f"[resume] loading checkpoint: {ck}")
            model = SocietyModel.load_from_checkpoint(ck)
            # load_from_checkpoint replaces CONFIG with the checkpoint's saved
            # config. Re-apply the CLI so `--resume --steps 3000` can extend a run
            # and `--out` is honoured for new checkpoints and exports.
            apply_run_overrides()
            if args.N is not None and args.N != CONFIG.N:
                print(f"[resume] --N={args.N} ignored; population is fixed by the checkpoint (N={CONFIG.N})")
            start_step = len(model.metrics_rows)
            print(f"[resume] resume from step={start_step}")
        else:
            print("[resume] no checkpoint found, start fresh.")
    if model is None:
        print("[run] start fresh model")
        model = SocietyModel(n_agents=CONFIG.N, seed=CONFIG.SEED)

    print("=== Demographic Network Evolution (Mesa 3.x, checkpointable) ===")
    print(f"Config: N={CONFIG.N}, STEPS={CONFIG.STEPS}, OUT={CONFIG.OUTPUT_DIR}")
    print(f"Heartbeat every {CONFIG.HEARTBEAT_EVERY} steps, checkpoint every {CONFIG.CHECKPOINT_EVERY} steps")
    print("----------------------------------------------------------------")

    t0 = time.time()
    # Main loop: hard step cap plus early stop.
    while model.running and len(model.metrics_rows) < CONFIG.STEPS:
        model.step()

        # Periodic edge snapshot (no-op unless SNAPSHOT_EVERY > 0).
        model.export_edge_snapshot(CONFIG.OUTPUT_DIR)

        # Periodic checkpoint.
        if CONFIG.CHECKPOINT_EVERY > 0 and (len(model.metrics_rows) % CONFIG.CHECKPOINT_EVERY == 0):
            model.save_checkpoint(CONFIG.OUTPUT_DIR)

    # Wrap-up: one last checkpoint plus the full export.
    model.save_checkpoint(CONFIG.OUTPUT_DIR, tag="final")
    model.export_results(CONFIG.OUTPUT_DIR)

    dt = (time.time() - t0) / 60.0
    print(f"[total] steps={len(model.metrics_rows)} | elapsed={dt:.2f} min")

if __name__ == "__main__":
    # Entry point both for `python <script>.py ...` and for executing this
    # notebook cell (where __name__ is also "__main__"). A Ctrl-C / kernel
    # interrupt is reported with a short message instead of a traceback;
    # checkpoints already written remain on disk.
    try:
        main()
    except KeyboardInterrupt:
        print(
            "\n[interrupt] 手动终止，已尽量保留到目前为止的输出与最近 checkpoint。",
            flush=True,
        )


=== Demographic Network Evolution (Mesa 3.x) ===
Config: N=4000, STEPS=1500, CANDIDATES=6, DECAY_P=0.002
Tip: Mesa 3.x 使用 model.agents.shuffle_do('step')，不再导入 mesa.time.*
------------------------------------------------
[heartbeat] step=    1 | edges=    8000 | avg_deg= 4.00 | clustering=0.001
[heartbeat] step=    2 | edges=   15961 | avg_deg= 7.98 | clustering=0.002
[heartbeat] step=    3 | edges=   23910 | avg_deg=11.96 | clustering=0.004


In [None]:
# Install the Mesa release this notebook targets into the running kernel's
# environment (%pip, unlike !pip, always uses the kernel's own interpreter).
%pip install mesa==3.3.0

In [3]:
# Environment check: which Mesa installation does this kernel import, and is it
# the 3.x line the model code needs (it calls `model.agents.shuffle_do`)?
import inspect
import sys

import mesa

mesa_version = str(getattr(mesa, "__version__", "unknown"))
print("mesa file:", inspect.getfile(mesa))
print("mesa version:", mesa_version)
print("sys.path[0]:", sys.path[0])

# Fail fast here rather than later inside the simulation loop.
if mesa_version != "unknown" and not mesa_version.startswith("3."):
    raise RuntimeError(f"This notebook requires Mesa 3.x, but found mesa {mesa_version}")

  from .autonotebook import tqdm as notebook_tqdm


mesa file: d:\anaconda3\envs\torchcu\Lib\site-packages\mesa\__init__.py
mesa version: 3.3.0
sys.path[0]: d:\anaconda3\envs\torchcu\python312.zip


In [2]:
# Install the Mesa release this notebook targets into the running kernel's
# environment (%pip, unlike !pip, always uses the kernel's own interpreter).
%pip install mesa==3.3.0

Collecting mesa
  Downloading mesa-3.3.0-py3-none-any.whl.metadata (11 kB)
Downloading mesa-3.3.0-py3-none-any.whl (239 kB)
Installing collected packages: mesa
Successfully installed mesa-3.3.0
