# In[1]:
from graphviz import Digraph

# In[2]:
# Directed graph of the model, laid out left-to-right and rendered as PDF.
dot = Digraph(format='pdf')
dot.attr(rankdir='LR')
dot.attr('node', fontsize='12')

# One style per node category, so every node of a category is drawn the same
# way and a category's look is changed in exactly one place.
prior_style = dict(shape='ellipse', style='filled', fillcolor='lightblue')
external_style = dict(shape='note', style='filled', fillcolor='white')
deterministic_style = dict(shape='box', style='filled', fillcolor='lightyellow')
likelihood_style = dict(shape='diamond', style='filled', fillcolor='lightcoral')
data_style = dict(shape='component', style='filled', fillcolor='lightgray')


def _add_nodes(graph, labels, style):
    """Add one node per ``name: label`` entry of *labels*, all using *style*."""
    for name, label in labels.items():
        graph.node(name, label, **style)


# Priors
prior_labels = {
    'H0': 'Hubble constant',
    'M_W': 'Cepheid PLR normalization',
    'b_W': 'Cepheid PLR slope',
    'Z_W': 'Cepheid PLR metallicity correction',
    'sigma_v': 'Velocity dispersion',
    'beta': 'Velocity field scaling',
    'mu_host': '37 host distances',
    'mu_N4258': 'NGC 4258 distance',
    'mu_LMC': 'LMC distances',
    'mu_M31': 'M31 distance',
}
_add_nodes(dot, prior_labels, prior_style)

# External data nodes (density and velocity fields kept as separate nodes)
_add_nodes(dot, {
    'density_field': 'Density field',
    'velocity_field': 'Velocity field',
}, external_style)

# Deterministic nodes
_add_nodes(dot, {
    'Vpec': 'Host peculiar velocity',
    'cz_pred': 'Predicted host redshift',
    'mu_cepheid': 'Cepheid distances',
    'mag_pred': 'Predicted Cepheid magnitudes',
}, deterministic_style)

# Likelihood nodes
_add_nodes(dot, {
    'mu_N4258_ll': 'NGC 4258 geometric anchor',
    'mu_LMC_ll': 'LMC geometric anchor',
    'redshift_ll': 'Host redshift likelihood',
    'cepheid_ll': 'Cepheid magnitude likelihood',
    'MW_ll': 'MW PLR likelihood',
}, likelihood_style)

# Data nodes (component shape to set them apart from the other categories)
_add_nodes(dot, {
    'logP': 'Cepheid periods',
    'OH': 'Cepheid metallicities',
    'cz_obs': 'Observed host redshift',
    'mag_obs': 'Observed Cepheid magnitudes',
    'NGC4258_geo': 'NGC 4258 anchor data',
    'LMC_geo': 'LMC anchor data',
    'MW_data': 'MW Cepheid PLR data',
}, data_style)

# Put the two geometric-anchor likelihood nodes on a common rank. With
# rankdir='LR' this places them in the same column of the layout.
anchor_likelihood_nodes = ('mu_N4258_ll', 'mu_LMC_ll')
with dot.subgraph() as same_rank:
    same_rank.attr(rank='same')
    for anchor_node in anchor_likelihood_nodes:
        same_rank.node(anchor_node)

# Every edge of the DAG as a (parent, child) pair. Keep the listed order:
# Graphviz layout may depend on the order in which edges are declared.
model_edges = [
    # Host peculiar velocity
    ('mu_host', 'Vpec'),
    ('beta', 'Vpec'),
    ('velocity_field', 'Vpec'),
    # Density field feeding the distance nodes
    ('density_field', 'mu_host'),
    ('density_field', 'mu_N4258'),
    ('density_field', 'mu_LMC'),
    ('density_field', 'mu_M31'),
    # Predicted host redshift and its likelihood
    ('mu_host', 'cz_pred'),
    ('Vpec', 'cz_pred'),
    ('H0', 'cz_pred'),
    ('cz_pred', 'redshift_ll'),
    ('sigma_v', 'redshift_ll'),
    # Geometric anchors
    ('mu_N4258', 'mu_N4258_ll'),
    ('NGC4258_geo', 'mu_N4258_ll'),
    ('mu_LMC', 'mu_LMC_ll'),
    ('LMC_geo', 'mu_LMC_ll'),
    # Cepheid distances from every host/anchor distance
    ('mu_host', 'mu_cepheid'),
    ('mu_N4258', 'mu_cepheid'),
    ('mu_LMC', 'mu_cepheid'),
    ('mu_M31', 'mu_cepheid'),
    # Predicted Cepheid magnitudes (PLR parameters and Cepheid data)
    ('mu_cepheid', 'mag_pred'),
    ('M_W', 'mag_pred'),
    ('b_W', 'mag_pred'),
    ('Z_W', 'mag_pred'),
    ('logP', 'mag_pred'),
    ('OH', 'mag_pred'),
    ('mag_pred', 'cepheid_ll'),
    # Observed data to likelihoods
    ('cz_obs', 'redshift_ll'),
    ('mag_obs', 'cepheid_ll'),
    # MW constraint on M_W
    ('MW_data', 'MW_ll'),
    ('M_W', 'MW_ll'),
]
for parent, child in model_edges:
    dot.edge(parent, child)

# Optional legend showing one example node per node category.
# Off by default so the rendered graph is unchanged; set SHOW_LEGEND = True to
# draw it as a separate cluster.
SHOW_LEGEND = False
if SHOW_LEGEND:
    with dot.subgraph(name='cluster_legend') as legend:
        legend.attr(label='Legend', fontsize='18')
        legend.attr(style='solid')

        # Example nodes, styled like the real nodes of each category
        legend.node('prior_ex', 'Prior', shape='ellipse', style='filled', fillcolor='lightblue')
        legend.node('external_ex', 'External data', shape='note', style='filled', fillcolor='white')
        legend.node('deterministic_ex', 'Deterministic node', shape='box', style='filled', fillcolor='lightyellow')
        legend.node('likelihood_ex', 'Likelihood', shape='diamond', style='filled', fillcolor='lightcoral')
        legend.node('data_ex', 'Observed data', shape='component', style='filled', fillcolor='lightgray')

        # Keep the examples on one rank and link them so they stay together
        legend.attr(rank='same')
        legend.edges([
            ('external_ex', 'prior_ex'),
            ('prior_ex', 'deterministic_ex'),
            ('deterministic_ex', 'likelihood_ex'),
            ('data_ex', 'likelihood_ex'),
        ])

# Save and view. The graph is written to ~/Downloads/CH0_DAG (plus the
# rendered CH0_DAG.pdf), and view=True opens the result in the default viewer.
# Expanding "~" avoids hard-coding one user's absolute home directory. For a
# user whose home is /Users/rstiskalek this is the same path as before.
import os

dot.render(os.path.expanduser('~/Downloads/CH0_DAG'), view=True)


# Out[2]: '/Users/rstiskalek/Downloads/CH0_DAG.pdf'