In [5]:
import nbformat as nbf
from nbformat.v4 import new_notebook, new_code_cell, new_markdown_cell
import json, os, textwrap, uuid, pathlib

nb = new_notebook()

# Markdown header cell of the generated notebook: title, what it does, and
# which packages it needs. The fragments are joined into one markdown source.
md_intro = new_markdown_cell("".join([
    "# Poincaré Disk Visualization of Node Trajectories\n",
    "This notebook loads the saved embedding trajectory `traj.pt` ",
    "(generated during HIM training) and visualizes selected nodes' ",
    "movement on the Poincaré disk.\n\n",
    "**Requirements**: `torch`, `matplotlib`\n",
]))

# First code cell of the generated notebook: the imports the later cells use.
# Built line by line so the cell source has no leading/trailing blank lines.
code_imports = new_code_cell("\n".join([
    "import torch",
    "import matplotlib.pyplot as plt",
    "from pathlib import Path",
]))

# Second code cell: helper that maps Lorentz-model points to the Poincaré disk
# by dividing the spatial coordinates x[..., 1:] by (x0 + sqrt(gamma)).
# NOTE(review): the result lies in the unit disk when the inputs satisfy
# x0**2 - |x[1:]|**2 == gamma — confirm this matches the training convention.
code_helper = new_code_cell("\n".join([
    "def lorentz_to_poincare(x: torch.Tensor, gamma: float = 1.0) -> torch.Tensor:",
    '    """Convert Lorentz coordinates (x0, x1, …) to Poincaré disk coordinates."""',
    "    sqrt_g = gamma ** 0.5",
    "    return x[..., 1:] / (x[..., :1] + sqrt_g)",
]))

# Third code cell: user configuration plus loading of the saved trajectory.
# - map_location='cpu' lets a file saved from GPU training load on a CPU-only
#   machine and keeps the tensors on the CPU for plotting.
# - A missing file raises FileNotFoundError rather than relying on `assert`,
#   which is skipped under `python -O`.
# - An empty trajectory gets a clear error instead of an IndexError on traj[0].
code_load = new_code_cell(textwrap.dedent("""
    # === Configuration ===
    TRAJ_PATH = Path('traj.pt')   # same directory
    NODE_IDS  = [0, 7, 23, 42]    # edit as you like
    GAMMA     = 1.0               # curvature parameter used in training

    # explicit raise instead of `assert`, which is skipped under `python -O`
    if not TRAJ_PATH.exists():
        raise FileNotFoundError(f'{TRAJ_PATH} not found')

    # map_location='cpu': a trajectory saved from GPU training still loads
    # (and can be plotted) on a CPU-only machine
    traj = torch.load(TRAJ_PATH, map_location='cpu')   # expected: one (N, d+1) tensor per epoch
    if len(traj) == 0:
        raise ValueError(f'{TRAJ_PATH} contains no epochs')
    print(f'Loaded {len(traj)} epochs, {traj[0].shape[0]} nodes.')
    """).strip())

# Fourth code cell: draw each selected node's path inside the unit disk.
# - Points are detached and moved to the CPU as NumPy arrays before plotting;
#   matplotlib cannot convert tensors that require grad or live on a GPU.
# - The boundary circle is a Patch, so it is added with add_patch.
# - Each node's epochs are stacked first and converted in one call (the helper
#   broadcasts over leading dims, so the result equals per-epoch conversion).
# - A star marks the final epoch so the direction of travel is visible.
code_plot = new_code_cell(textwrap.dedent("""
    fig, ax = plt.subplots(figsize=(6,6))
    # boundary of the unit Poincaré disk
    ax.add_patch(plt.Circle((0,0), 1.0, fill=False, linestyle='--'))

    for nid in NODE_IDS:
        # this node's Lorentz point at every epoch, mapped into the disk;
        # detach/cpu/numpy so matplotlib can plot it even if the tensors
        # require grad or live on a GPU
        lorentz = torch.stack([epoch_emb[nid] for epoch_emb in traj])
        pts = lorentz_to_poincare(lorentz, GAMMA).detach().cpu().numpy()
        # only the first two disk coordinates are drawn (complete only when d == 2)
        line, = ax.plot(pts[:,0], pts[:,1], marker='o', label=f'node {nid}')
        # star on the last epoch so the direction of travel is visible
        ax.plot(pts[-1,0], pts[-1,1], marker='*', markersize=14, color=line.get_color())

    ax.set_aspect('equal')
    ax.set_xlim(-1.05,1.05)
    ax.set_ylim(-1.05,1.05)
    ax.set_title('Node trajectories in the Poincaré disk')
    ax.legend(loc='best')
    plt.show()
    """).strip())

# Assemble the cells in display order, then serialise the notebook as UTF-8
# (the markdown and code contain non-ASCII characters such as "é" and "…").
nb.cells = [
    md_intro,
    code_imports,
    code_helper,
    code_load,
    code_plot,
]

out_path = "viz_poincare.ipynb"
with open(out_path, mode="w", encoding="utf-8") as f:
    nbf.write(nb, f)

# Bare expression as the cell's last statement, so Jupyter echoes the path.
out_path


'viz_poincare.ipynb'