# In [23]:
import os

from graphviz import Digraph

"""
Graphviz DAG for Cepheid-based H0 inference.
Nodes represent priors, deterministic calculations, likelihoods, and observed data.
Cluster highlights the velocity field and Hubble flow submodel.
"""

dot = Digraph(format='pdf')
dot.attr(rankdir='LR')
dot.attr('node', fontsize='24')
# dot.attr(concentrate='true')
dot.attr(splines='true')
# dot.attr(ranksep='2.0')
dot.attr(nodesep='0.75')

# --- Priors: drawn as filled light-blue ellipses ---
prior_nodes = dict(
    H0='Hubble constant',
    M_W='CPLR normalization',
    b_W='CPLR slope',
    Z_W='CPLR metallicity correction',
    sigma_v='Velocity dispersion',
    beta='Velocity field scaling',
    mu_host='37 host distances',
    mu_N4258='NGC 4258 distance',
    mu_LMC='LMC distance',
    mu_M31='M31 distance',
)

prior_style = {'shape': 'ellipse', 'style': 'filled', 'fillcolor': 'lightblue'}
for var, label in prior_nodes.items():
    dot.node(var, label, **prior_style)

# --- External data: input fields with no parents, drawn as white notes ---
external_fields = (
    ('density_field', 'Density field'),
    ('velocity_field', 'Velocity field'),
)
for var, label in external_fields:
    dot.node(var, label, shape='note', style='filled', fillcolor='white')

# --- Deterministic nodes: derived quantities, drawn as light-yellow boxes ---
deterministic_nodes = [
    ('Vpec', 'Host peculiar velocity'),
    ('cz_pred', 'Predicted host redshift'),
    ('mu_cepheid', 'Cepheid distances'),
    ('mag_pred', 'Predicted Cepheid magnitudes'),
]

deterministic_style = dict(shape='box', style='filled', fillcolor='lightyellow')
for var, label in deterministic_nodes:
    dot.node(var, label, **deterministic_style)

# --- Likelihood nodes: drawn as light-coral diamonds ---
likelihood_nodes = [
    ('mu_N4258_ll', 'NGC 4258 geometric anchor'),
    ('mu_LMC_ll', 'LMC geometric anchor'),
    ('redshift_ll', 'Host redshift likelihood'),
    ('cepheid_ll', 'Cepheid magnitude likelihood'),
    ('MW_ll', 'MW CPLR likelihood'),
]

likelihood_style = dict(shape='diamond', style='filled', fillcolor='lightcoral')
for var, label in likelihood_nodes:
    dot.node(var, label, **likelihood_style)

# --- Observed data nodes: drawn as light-gray "component" shapes ---
data_nodes = [
    ('logP', 'Cepheid periods'),
    ('OH', 'Cepheid metallicities'),
    ('cz_obs', 'Observed host redshift'),
    ('mag_obs', 'Observed Cepheid magnitudes'),
    ('NGC4258_geo', 'NGC 4258 anchor data'),
    ('LMC_geo', 'LMC anchor data'),
    ('MW_data', 'MW CPLR data'),
]

data_style = dict(shape='component', style='filled', fillcolor='lightgray')
for var, label in data_nodes:
    dot.node(var, label, **data_style)

# --- Cluster for velocity / H0 submodel ---
# The nodes below were already declared (and styled) above; naming them again
# inside the cluster only moves them into it, without adding attributes.
velocity_block_members = (
    'velocity_field', 'beta', 'H0', 'Vpec',
    'cz_pred', 'redshift_ll', 'sigma_v', 'cz_obs',
)
with dot.subgraph(name='cluster_velocity_block') as c:
    c.attr(
        label="Hubble constant inference",
        color='black',
        style='dashed',
        fontsize='24',
    )
    for member in velocity_block_members:
        c.node(member)

# --- Rank constraints for anchors ---
# Put the two geometric-anchor likelihoods on the same rank. This must be a
# plain (non-cluster) subgraph: under dot's default ranking the `rank`
# attribute is ignored on subgraphs whose name starts with "cluster". A plain
# subgraph draws no border or label, so no invisible styling is needed.
with dot.subgraph(name="anchor_ranks") as s:
    s.attr(rank='same', ordering='out')
    s.node('mu_N4258_ll')
    s.node('mu_LMC_ll')

# --- Graph edges ---
# Each pair is (tail, head): an arrow from a quantity to the node it feeds.
# The pairs keep their original order, so the emitted DOT source is unchanged.
model_edges = [
    # Inputs to the host peculiar velocity
    ('mu_host', 'Vpec'),
    ('beta', 'Vpec'),
    ('velocity_field', 'Vpec'),
    # The density field is a parent of all four distance priors
    ('density_field', 'mu_host'),
    ('density_field', 'mu_N4258'),
    ('density_field', 'mu_LMC'),
    ('density_field', 'mu_M31'),
    # Predicted host redshift and the redshift likelihood
    ('mu_host', 'cz_pred'),
    ('Vpec', 'cz_pred'),
    ('H0', 'cz_pred'),
    ('cz_pred', 'redshift_ll'),
    ('sigma_v', 'redshift_ll'),
    # Geometric anchors
    ('mu_N4258', 'mu_N4258_ll'),
    ('NGC4258_geo', 'mu_N4258_ll'),
    ('mu_LMC', 'mu_LMC_ll'),
    ('LMC_geo', 'mu_LMC_ll'),
    # All four distances feed the Cepheid distances
    ('mu_host', 'mu_cepheid'),
    ('mu_N4258', 'mu_cepheid'),
    ('mu_LMC', 'mu_cepheid'),
    ('mu_M31', 'mu_cepheid'),
    # CPLR parameters and Cepheid data -> predicted magnitudes -> likelihood
    ('mu_cepheid', 'mag_pred'),
    ('M_W', 'mag_pred'),
    ('b_W', 'mag_pred'),
    ('Z_W', 'mag_pred'),
    ('logP', 'mag_pred'),
    ('OH', 'mag_pred'),
    ('mag_pred', 'cepheid_ll'),
    # Observed data to likelihoods
    ('cz_obs', 'redshift_ll'),
    ('mag_obs', 'cepheid_ll'),
    # MW constraint on M_W
    ('MW_data', 'MW_ll'),
    ('M_W', 'MW_ll'),
]
dot.edges(model_edges)

# --- Legend cluster with background colour and extra spacing ---
with dot.subgraph(name='cluster_legend') as legend:
    legend.attr(
        label='Legend',
        fontsize='24',
        style='filled,solid',
        bgcolor="lavender",
        # For clusters Graphviz reads `margin` in points (default 8), unlike
        # the root graph where it is in inches: 28.8 pt = 0.4 in of padding.
        margin='28.8',
    )

    # No `rank` attribute is set: under dot's default ranking, rank=same/sink
    # is ignored on subgraphs named "cluster_*". With rankdir='LR', the
    # example nodes instead follow the edge chain below from left to right.

    # Example nodes with the same shapes and colours as the model nodes
    legend.node('prior_ex', 'Prior', shape='ellipse', style='filled', fillcolor='lightblue')
    legend.node('deterministic_ex', 'Deterministic node', shape='box', style='filled', fillcolor='lightyellow')
    legend.node('likelihood_ex', 'Likelihood', shape='diamond', style='filled', fillcolor='lightcoral')
    legend.node('data_ex', 'Observed data', shape='component', style='filled', fillcolor='lightgray')
    legend.node('external_ex', 'External field', shape='note', style='filled', fillcolor='white')

    # external -> prior -> deterministic -> likelihood <- data
    legend.edges([
        ('prior_ex', 'deterministic_ex'),
        ('deterministic_ex', 'likelihood_ex'),
        ('data_ex', 'likelihood_ex'),
        ('external_ex', 'prior_ex'),
    ])
# --- Render ---
# Write the DOT source and the PDF to ~/Downloads (CH0_DAG and CH0_DAG.pdf),
# then open the PDF in a viewer. Building the path from "~" gives the same
# location as before on the author's machine and also works for other users.
# As the cell's last expression, render()'s return value (the PDF path) is echoed.
output_stem = os.path.join(os.path.expanduser('~'), 'Downloads', 'CH0_DAG')
dot.render(output_stem, view=True)

# Out[23]: '/Users/rstiskalek/Downloads/CH0_DAG.pdf'