# In [5]:
from graphviz import Digraph

# Builds a Graphviz diagram (rendered to PDF) of the TFR model's dependency
# graph: priors, observed/external inputs, sampled latents, deterministic
# quantities and likelihood terms, plus a colour legend as the graph label.

# Palette of hex colours, some with a trailing alpha byte. COLS[0] is not used
# by this figure.
COLS = ["#87193dD9", "#80A1D4", "#893168D9", "#75C9C8", "#DED9E2"]

# One fill colour per node role. The legend swatches reuse exactly these
# values, so the key cannot drift from the colours the nodes are drawn with
# (the likelihood swatch previously omitted the alpha byte used on the nodes).
PRIOR_FILL = COLS[1]
LIKELIHOOD_FILL = COLS[2]
DETERMINISTIC_FILL = COLS[3]
OBSERVED_FILL = COLS[-1]
LEG_PRIOR, LEG_LIKE, LEG_DET, LEG_DATA = (
    PRIOR_FILL, LIKELIHOOD_FILL, DETERMINISTIC_FILL, OBSERVED_FILL)

fontsize = "20"
FONTNAME = 'CMU Serif'

# Output path without extension; the chosen format's suffix is appended on
# render. NOTE(review): machine-specific absolute path -- adjust when running
# elsewhere.
OUTPUT_STEM = '/Users/rstiskalek/Downloads/TFR_DAG'


def _add_nodes(graph, fillcolor, nodes):
    """Add filled elliptical nodes that share one fill colour.

    ``nodes`` is an iterable of ``(name, label)`` pairs; they are added in the
    order given.
    """
    for name, label in nodes:
        graph.node(name, label, shape='ellipse', style='filled',
                   fillcolor=fillcolor)


def _legend_entry(color, text):
    """Return the HTML-like markup for one legend cell (swatch + caption)."""
    return (
        '        <TD><TABLE BORDER="0" CELLBORDER="0" CELLSPACING="4">\n'
        f'            <TR><TD BGCOLOR="{color}" WIDTH="30" HEIGHT="12"></TD>'
        f'<TD ALIGN="LEFT"><FONT POINT-SIZE="{fontsize}">{text}</FONT></TD></TR>\n'
        '        </TABLE></TD>'
    )


dot = Digraph(format='pdf')
dot.attr(rankdir='TB', splines='true', fontsize=fontsize, nodesep='0.6',
         ranksep='1.0', concentrate='true')
dot.attr(fontname=FONTNAME)
dot.attr('node', fontname=FONTNAME, fontsize=fontsize)
dot.attr('edge', fontname=FONTNAME, fontsize=fontsize)

# NOTE: nodes and edges are emitted in the same order as before the refactor;
# statement order can influence the layout, so it is deliberately preserved.

# ---------- Priors ----------
_add_nodes(dot, PRIOR_FILL, [
    ('TFR_params', '<a<SUB>TFR</SUB>, b<SUB>TFR</SUB>,<BR/>c<SUB>TFR</SUB>>'),
    ('sigma_int', '<σ<SUB>int</SUB>>'),
    ('beta_vext', '<β, V<SUB>ext</SUB>>'),
    ('sigma_v', '<σ<SUB>v</SUB>>'),
    ('bias_params', '<b<SUB>1</SUB>>'),
    ('pqR', '<p, q, R>'),
])
# This edge references 'r' before its node statement further down; the node's
# attributes are supplied there.
dot.edge('pqR', 'r')

# η hyperprior (merged)
_add_nodes(dot, PRIOR_FILL, [
    ('eta_prior', '<η̂ , w<sub>η</sub>>'),
])

# ---------- Data / Observed ----------
_add_nodes(dot, OBSERVED_FILL, [
    ('eta_obs', '<η<SUB>obs</SUB>>'),
    ('mag_obs', '<m<SUB>obs</SUB>>'),
    ('czcmb', '<z<SUB>obs</SUB>>'),
    ('los_vel_r', '<Velocity<BR/>field>'),
    ('density_field', '<Density<BR/>field>'),
])

# ---------- Sampled latent variables ----------
# Drawn in the same colour as the priors (the legend groups them together).
_add_nodes(dot, PRIOR_FILL, [
    ('r', 'r'),  # sampled distance
    ('eta_true', '<η<SUB>true</SUB>>'),
])

# ---------- Deterministic / internal ----------
_add_nodes(dot, DETERMINISTIC_FILL, [
    ('z_cosmo', '<z<SUB>cosmo</SUB>>'),
    ('z_pec', '<z<SUB>pec</SUB>>'),
    ('z_pred', '<z<SUB>pred</SUB>>'),
    ('mu_r', 'μ'),
    ('M_eta', '<M>'),
    ('m_pred', '<m<SUB>pred</SUB>>'),
])

# ---------- Likelihoods ----------
_add_nodes(dot, LIKELIHOOD_FILL, [
    ('cz_ll', '<z<SUB>obs</SUB> | z<SUB>pred</SUB>>'),
    ('mag_ll', '<m<SUB>obs</SUB> | m<SUB>pred</SUB>>'),
    ('eta_ll', '<η<SUB>obs</SUB> | η<SUB>true</SUB>>'),
])

# ---------- Edges ----------
EDGES = [
    # Distance sampling from empirical prior + bias
    ('density_field', 'r'),
    ('bias_params', 'r'),
    # Redshift pieces
    ('r', 'z_cosmo'),
    ('r', 'z_pec'),        # allow scale dependence if needed
    ('beta_vext', 'z_pec'),
    ('los_vel_r', 'z_pec'),
    # Combine to predicted redshift
    ('z_cosmo', 'z_pred'),
    ('z_pec', 'z_pred'),
    # Redshift likelihood
    ('z_pred', 'cz_ll'),
    ('sigma_v', 'cz_ll'),
    ('czcmb', 'cz_ll'),
    # η hyperprior → latent → likelihood
    ('eta_prior', 'eta_true'),
    ('eta_true', 'eta_ll'),
    ('eta_obs', 'eta_ll'),
    # Magnitudes: m_pred from μ(r) and M(η); likelihood uses σ_int
    ('r', 'mu_r'),
    ('eta_true', 'M_eta'),
    ('TFR_params', 'M_eta'),
    ('mu_r', 'm_pred'),
    ('M_eta', 'm_pred'),
    ('m_pred', 'mag_ll'),
    ('mag_obs', 'mag_ll'),
    ('sigma_int', 'mag_ll'),
]
for tail, head in EDGES:
    dot.edge(tail, head)

# ---------- Legend ----------
# The legend is an HTML-like table used as the graph label: one bordered row
# of swatch/caption cells, followed by a small spacer row.
legend_entries = "\n".join(
    _legend_entry(color, text)
    for color, text in (
        (LEG_PRIOR, "Priors / Latent samples"),
        (LEG_DET, "Deterministic"),
        (LEG_LIKE, "Likelihoods"),
        (LEG_DATA, "Observed / External"),
    )
)
legend_label = f"""<
<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">
  <TR><TD>
    <TABLE BORDER="0" CELLBORDER="1" CELLSPACING="0" CELLPADDING="3">
      <TR>
{legend_entries}
      </TR>
    </TABLE>
  </TD></TR>
  <TR><TD><FONT POINT-SIZE="6">&#160;</FONT></TD></TR>
</TABLE>
>"""
# Place the legend at the top of the figure, right-justified.
dot.attr(label=legend_label, labelloc='t', labeljust='r', margin='0.01')

# Render to OUTPUT_STEM + '.pdf'; view=True asks graphviz to open the result.
# Kept as the final bare expression so a notebook cell still echoes the path.
dot.render(OUTPUT_STEM, view=True)

# Cell output: '/Users/rstiskalek/Downloads/TFR_DAG.pdf'