# HII region expansion in a $1/r^2$ density profile

## TOC
* [Packages & Settings](#packages)
* [Snapshots](#snapshots)
* [Loading Snapshots](#loading)
* [Maps](#maps)
* [Lineouts](#lineouts)
* [I-front position](#pos-vel)

## Packages & Settings  <a class="anchor" id="packages"></a>

In [None]:
from pyrhyme import PyRhyme

from pathlib import Path

# !pip install numpy
import numpy as np

# !pip install astropy
from astropy import units as U
from astropy import constants as C
from astropy.cosmology import WMAP7

# !pip install scipy
from scipy.io import FortranFile
from scipy.interpolate import interp1d

# !pip install plotly
import plotly
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
from plotly.subplots import make_subplots
# Shared figure styling, plus the output directory every figure below is written to.
pio.templates.default = "simple_white"

from math import pi

FIGURE_DIR = './figures'
# Idempotent: creates missing parents and does not fail if the directory already exists.
Path(FIGURE_DIR).mkdir(exist_ok=True, parents=True)

## Snapshots  <a class="anchor" id="snapshots"></a>

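Run
```bash
wget -r --level=0 -E --ignore-length -x -k -p -erobots=off -np -N https://astronomy.sussex.ac.uk/~iti20/RT_comparison_project/RT_workshop_data/T6_results/
```
to download the Iliev-6 test snapshots. Because of `-x`, the files land in `astronomy.sussex.ac.uk/~iti20/RT_comparison_project/RT_workshop_data/T6_results`, the directory read below.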

In [None]:
ILIEV_6_DIR = 'astronomy.sussex.ac.uk/~iti20/RT_comparison_project/RT_workshop_data/T6_results'

# Snapshot times [Myr]; the k-th entry is stored in the files numbered k (1-based).
TIMES = [1, 3, 10, 25, 75]

# (display name, file-name prefix) of every comparison code to load.
# Commented-out codes are skipped.
SIMULATION_NAMES = [
    ('CAPREOLE+C2-RAY', 'c2ray'),
    # ('Flash', 'flash'),
    ('HART', 'hart'),
    # ('LICORICE', 'licorice'),
    ('RSPH', 'rsph'),
    ('RH1D', 'rh1d'),
    # ('ZEUS-MP', 'zeus'),
]


def _snapshot_files(abbr):
    """Map each time in TIMES to the paths of code `abbr`'s main and '_add' binaries."""
    orig = {t: {'path': f"{ILIEV_6_DIR}/{abbr}{ti+1}.bin"} for ti, t in enumerate(TIMES)}
    add = {t: {'path': f"{ILIEV_6_DIR}/{abbr}_add{ti+1}.bin"} for ti, t in enumerate(TIMES)}
    return {'orig': orig, 'add': add}


BINARIES = {sn: _snapshot_files(abbr) for sn, abbr in SIMULATION_NAMES}

# Rhyme runs: a single chombo file per case, keyed by 0.
RHYME = {
    'Rhyme CaseA': {0: {'path': './Iliev-6-15px/Iliev-6-logBins-CaseA/Iliev-6-logBins-CaseA-000000.chombo.h5'}},
    'Rhyme CaseB': {0: {'path': './Iliev-6-15px/Iliev-6-logBins-CaseB/Iliev-6-logBins-CaseB-000000.chombo.h5'}},
}

# Check if files exist
def check_if_exist(path):
    """Report whether `path` is an existing regular file.

    Prints ``Not found: <path>`` when it is not. (Previously the function
    ignored `path` and checked the global ``snap['path']`` instead.)

    Returns:
        True if the file exists, False otherwise.
    """
    if Path(path).is_file():
        return True
    print(f"Not found: {path}")
    return False
        
# Warn about every expected snapshot file that is missing on disk.
for sim_name, sim in BINARIES.items():
    for kind in ('orig', 'add'):
        for snap in sim[kind].values():
            check_if_exist(snap['path'])

for sim in RHYME.values():
    check_if_exist(sim[0]['path'])

## Loading snapshots  <a class="anchor" id="loading"></a>

In [None]:
import os
def time_to_snap_id(t, ds):
    """Return the index of the snapshot in ``ds.dataset`` whose time is closest to `t`.

    Snapshots are scanned in index order until the first one whose time is
    strictly greater than `t`. The closer of that snapshot and its predecessor
    is chosen; ties go to the predecessor. If no snapshot is later than `t`,
    the last index is returned. Assumes snapshot times increase with index
    (TODO confirm for the dataset format).

    Side effect: ``ds.dataset`` is left on the last snapshot it jumped to.
    """
    dataset = ds.dataset
    for idx in range(dataset.num_of_snapshots):
        dataset.jump_to(idx)
        later = dataset.time
        if not later > t:
            continue
        if idx == 0:
            return 0
        dataset.jump_to(idx - 1)
        earlier = dataset.time
        return idx if abs(t - later) < abs(t - earlier) else idx - 1
    return dataset.num_of_snapshots - 1


def load_rhyme(snap_id, rhyme, Gamma=5./3.):
    """Load Rhyme snapshot `snap_id` and convert it to the comparison-plot fields.

    Args:
        snap_id: index of the snapshot to jump to in ``rhyme.dataset``.
        rhyme: an open ``PyRhyme`` reader.
        Gamma: adiabatic index. It is forwarded to ``rhyme.calc_temperature``
            (previously hard-coded to 5/3 there) and used for the Mach number.

    Returns:
        dict with keys 'n', 'nHI', 'nHII', 'T', 'p', 'M', each shaped like
        ``rhyme.dataset.problem_domain``.
        - 'nHI' holds the ``ntr_frac_0`` variable (presumably the neutral
          fraction; confirm) and 'nHII' is ``1 - nHI``.
        - 'T' is the ``temp`` variable as stored.
    """
    rhyme.dataset.jump_to(snap_id)

    v = rhyme.load_variables(silent=True)
    domain = rhyme.dataset.problem_domain

    m_p = C.m_p.to(U.g).value
    velocity_unit = (1 * U.Mpc / U.Myr).to(U.cm / U.s).value

    rho = v['rho'][0].reshape(domain) * (1 * v['rho'][1]).to(U.cm**-3).value
    # NOTE(review): momenta are divided by the *converted* density; confirm the
    # code-unit momentum and the cm^-3 density are meant to combine this way.
    vx = v['rho_u'][0].reshape(domain) / rho * velocity_unit
    vy = v['rho_v'][0].reshape(domain) / rho * velocity_unit
    vz = v['rho_w'][0].reshape(domain) / rho * velocity_unit
    T_orig = v['temp'][0].reshape(domain)
    fHI = v['ntr_frac_0'][0].reshape(domain)

    # Only the pressure is needed from calc_temperature.
    p, _, _ = rhyme.calc_temperature(v, X=1.0, Y=0.0, Gamma=Gamma)
    pressure_unit = rhyme.dataset.active['h5']['attrs']['pressure_unit']
    # Out-of-place scaling, so an array owned by the reader is never modified.
    p = p.reshape(domain) * (m_p * (1 * pressure_unit).to(U.cm**-3 * U.cm**2 / U.s**2).value)

    M = np.sqrt(vx**2 + vy**2 + vz**2) / np.sqrt(Gamma * p / (m_p * rho))

    return {
        'n': rho,
        'nHI': fHI,
        'nHII': 1.0 - fHI,
        'T': T_orig,
        'p': p,
        'M': M,
    }


# Record layout of the comparison-project binaries, as (key, dtype) in file order.
# Each file starts with an int32 record holding the grid shape.
_ORIG_FIELDS = (('nHI', np.float64), ('p', np.float32), ('T', np.float32))
_ADD_FIELDS = (('n', np.float32), ('M', np.float32), ('nHII', np.float32))


def _read_fortran_fields(path, fields):
    """Read the grid shape, then each `(key, dtype)` record of `fields`, from `path`."""
    with FortranFile(path, 'r') as f:
        domain = f.read_ints(np.int32)
        return {key: f.read_reals(dtype).reshape(domain) for key, dtype in fields}


def _load_rhyme_times(path, times):
    """Load the Rhyme snapshot closest to each time in `times`.

    A fresh reader is opened for every time and always released afterwards.
    Previously only the reader of the last time was closed.
    """
    out = {}
    for t in times:
        r = PyRhyme(path)
        try:
            out[t] = load_rhyme(time_to_snap_id(t, r), r)
        finally:
            r.dataset.close_current()
            r.dataset.clean_all()
    return out


def loading_binary_files():
    """Load every comparison-code binary and every Rhyme snapshot.

    Returns:
        ``{time: {simulation_name: {field: ndarray}}}``.
        - Binary codes provide 'nHI', 'p' and 'T' (main file) and 'n', 'M'
          and 'nHII' ('_add' file).
        - Rhyme entries are the dicts returned by ``load_rhyme`` for the
          snapshot closest to each time in TIMES.
    """
    result = {}

    for simname, sim in BINARIES.items():
        print(simname)
        for kind, fields in (('orig', _ORIG_FIELDS), ('add', _ADD_FIELDS)):
            for snap_time, snap in sim[kind].items():
                entry = result.setdefault(snap_time, {}).setdefault(simname, {})
                entry.update(_read_fortran_fields(snap['path'], fields))

    for sn, sim in RHYME.items():
        for t, snapshot in _load_rhyme_times(sim[0]['path'], TIMES).items():
            result.setdefault(t, {})[sn] = snapshot

    return result


if __name__ == '__main__' and '__file__' not in globals():
    DATA = loading_binary_files()

## Maps <a class="anchor" id="maps"></a>

In [None]:
def _axis_refs(si):
    """Return the (xref, yref) names of the `si`-th subplot (0-based) in plotly's numbering."""
    if si == 0:
        return 'x', 'y'
    return f"x{si+1}", f"y{si+1}"


def _sector_of(values, d1, d2):
    """Return a float copy of `values` with cells outside the angular sector [d1, d2] set to NaN.

    Cell (i, j) lies at angle ``arctan2(j + 0.5, i + 0.5)``. The angle is
    measured from the first-index axis (the rows of a ``go.Heatmap`` z array,
    i.e. y) toward the second-index axis (x). The input is not modified.
    """
    sector = np.array(values, dtype=float)
    rows, cols = np.indices(sector.shape)
    angle = np.arctan2(cols + 0.5, rows + 0.5)
    sector[(angle < d1) | (angle > d2)] = np.nan
    return sector


def plot_maps(x, y, heatmaps, width, height):
    """Build the figure of quarter-disc sector maps comparing the simulations.

    Each panel shows one simulation. Its quadrant is cut into ``len(attrs)``
    equal angular sectors, and each sector shows one quantity on a fixed colour
    range (``attr_cs_rng``).

    Args:
        x, y: coordinates of the map columns / rows.
        heatmaps: ``{sim_name: {'row', 'col', 'nH', 'p', 'T', 'fHI'}}``. Each
            sim_name must be a key of the local ``titles`` table. The arrays
            are not modified.
        width, height: figure size in pixels.

    Returns:
        The plotly figure.
    """
    # 'num' is the 0-based subplot index plotly assigns to the panel.
    titles = {
        'Rhyme CaseA': {'name': 'Rhyme (Case A)', 'num': 0},
        'CAPREOLE+C2-RAY': {'name': 'C2Ray', 'num': 1},
        'Flash': {'name': 'Flash', 'num': 2},
        'HART': {'name': 'HART', 'num': 3},
        'LICORICE': {'name': 'Licorice', 'num': 4},
        'Rhyme CaseB': {'name': 'Rhyme (Case B)', 'num': 5},
        'RSPH': {'name': 'RSPH', 'num': 6},
        'RH1D': {'name': 'RH1D', 'num': 7},
        'ZEUS-MP': {'name': "Zeus-MP", 'num': 8}
    }

    fig = make_subplots(
        4, 4,
        column_widths=[1] * 4, row_heights=[1] * 4,
        specs=[
            [{"type": "heatmap", "rowspan": 2, "colspan": 2}, None, {"type": "heatmap"}, {"type": "heatmap"}],
            [None, None, {"type": "heatmap"}, {"type": "heatmap"}],
            [{"type": "heatmap", "rowspan": 2, "colspan": 2}, None, {"type": "heatmap"}, {"type": "heatmap"}],
            [None, None, {"type": "heatmap"}, None],
        ],
        vertical_spacing=0.025, horizontal_spacing=0.025,
        shared_xaxes=True, shared_yaxes=True,
        x_title='x [px]', y_title='y [px]',
    )

    # The first two annotations are the shared x/y axis titles added by make_subplots.
    fig.layout.annotations[0]["font"] = dict(size=18, color='black')
    fig.layout.annotations[1]["font"] = dict(size=18, color='black')

    attrs = ['nH', 'p', 'T', 'fHI']
    attr_titles = {'nH': 'n<sub>HI</sub>', 'p': 'p', 'T': 'T', 'fHI': 'f<sub>HI</sub>'}
    attr_cs = {'nH': 'Spectral', 'p': 'Viridis', 'T': 'Inferno', 'fHI': 'Spectral'}
    attr_cs_rng = {'nH': [-2, 1], 'p': [-17, -14.5], 'T': [2, 4.5], 'fHI': [-4, 0]}
    n_slices = len(attrs)

    tick_labels = dict(
        tickmode='array', tickvals=[0, 0.2, 0.4, 0.6, 0.8], ticktext=['0', '.2', '.4', '.6', '.8'],
    )
    # Radius of the sector divider lines: 0.8 plus 8 cells of width 0.8/128.
    l = 0.8 + 8 * 0.8/128

    for sn, h in heatmaps.items():
        si = titles[sn]['num']
        xref, yref = _axis_refs(si)
        row, col = h['row'], h['col']
        # Column-1 panels are the 2x2-spanning ones; they get larger fonts and lines.
        big = col == 1

        # Tick labels only on the outer edges of the grid and on the large Rhyme panels.
        if row == 4 or sn == 'Rhyme CaseB':
            fig.update_xaxes(
                row=row, col=col, mirror=True, titlefont=dict(size=18), tickfont=dict(size=14),
                range=(0, max(x)), **tick_labels,
            )
        else:
            fig.update_xaxes(row=row, col=col, mirror=True, showticklabels=False, range=(0, max(x)))

        if col == 1 or sn == 'Rhyme CaseA':
            fig.update_yaxes(
                row=row, col=col, mirror=True, titlefont=dict(size=18), tickfont=dict(size=14),
                range=(0, max(y)), **tick_labels,
            )
        else:
            fig.update_yaxes(row=row, col=col, mirror=True, showticklabels=False, range=(0, max(y)))

        for slc, attr in enumerate(attrs):
            d1, d2 = slc * pi / 2 / n_slices, (slc + 1) * pi / 2 / n_slices

            fig.add_trace(go.Heatmap(
                x=x, y=y, z=_sector_of(h[attr], d1, d2), name=f"{sn}: {attr}",
                zauto=False, zmin=attr_cs_rng[attr][0], zmax=attr_cs_rng[attr][1],
                colorbar=dict(
                    title=attr_titles[attr], titleside='bottom', titlefont=dict(size=18),
                    thicknessmode='pixels', thickness=10, tickfont=dict(size=18),
                    x=1.02, y=slc/n_slices, len=1/n_slices, xanchor='left', yanchor='bottom',
                ),
                colorscale=attr_cs[attr], showscale=sn == 'Rhyme CaseA',
            ), row, col)

            # Divider along the sector's upper boundary d2. The angle is measured from
            # the y axis, as in _sector_of; the last boundary is the x axis itself.
            if d2 < pi / 2:
                if d2 <= pi / 4:
                    xend, yend = l * np.tan(d2), l
                else:
                    xend, yend = l, l * np.tan(pi / 2 - d2)
                fig.add_shape(
                    type='line', xref=xref, yref=yref,
                    x0=0, y0=0, x1=xend, y1=yend,
                    line=dict(color="white", width=4 if big else 2.5), opacity=1.0,
                )

            # Sector label at 90% of the divider radius, on the sector's bisector.
            mid = (d1 + d2) / 2
            fig.add_annotation(
                x=.9 * l * np.sin(mid), y=.9 * l * np.cos(mid), xref=xref, yref=yref,
                text=attr_titles[attr], showarrow=False,
                font=dict(color='white', size=24 if big else 16),
                textangle=90 / (2 * n_slices) * (2 * slc + 1),
            )

        fig.add_annotation(
            x=0.73, y=0.02, xanchor='right', yanchor='bottom',
            xref=xref, yref=yref,
            text=titles[sn]['name'], showarrow=False,
            font=dict(color='white', size=22 if big else 14),
        )

    fig.update_layout(
        width=width, height=height,
        margin=dict(b=60, t=10, l=60, r=60),
    )

    return fig

In [None]:
# (row, col) of each simulation's panel in the plot_maps grid. This must agree with
# plot_maps' subplot numbering: plotly numbers the non-None specs row by row.
_MAP_LAYOUT = {
    'Rhyme CaseA': (1, 1),
    'CAPREOLE+C2-RAY': (1, 3),
    'Flash': (1, 4),
    'HART': (2, 3),
    'LICORICE': (2, 4),
    'Rhyme CaseB': (3, 1),
    'RSPH': (3, 3),
    'RH1D': (3, 4),
    'ZEUS-MP': (4, 3),
}


def maps(t=75, z_index=2):
    """Plot, save and show the sector maps of every simulation at time `t` [Myr].

    Panel positions come from _MAP_LAYOUT (BINARIES carries no 'row'/'col'
    entries, so the former lookup there raised KeyError). Rhyme runs other than
    Case A/B are skipped.

    Args:
        t: snapshot time; must be a key of DATA.
        z_index: index along the third axis of the 2D slice that is mapped.
    """
    if t not in DATA.keys():
        print(f"Don't have a snapshot at t = {t} Myr")
        return

    heatmaps = {}
    for sn, sim in DATA[t].items():
        if sn not in _MAP_LAYOUT and sn.startswith('Rhyme'):
            continue
        row, col = _MAP_LAYOUT[sn]

        nHI = sim['nHI'][:, :, z_index]
        nHII = sim['nHII'][:, :, z_index]
        heatmaps[sn] = {
            'row': row,
            'col': col,
            'nH': np.log10(nHI),
            'p': np.log10(sim['p'][:, :, z_index]),
            'T': np.log10(sim['T'][:, :, z_index]),
            'fHI': np.log10(nHI / (nHI + nHII)),
        }

    x = np.linspace(0, 0.8, 128)

    width, height = 1100, 1000
    fig = plot_maps(x, x, heatmaps, width, height)

    fig.write_image(f"{FIGURE_DIR}/iliev-6-maps-{int(t):02d}.svg", width=width, height=height)
    fig.write_image(f"{FIGURE_DIR}/iliev-6-maps-{int(t):02d}.png", width=width, height=height, scale=3)
    fig.show()

if __name__ == '__main__' and '__file__' not in globals():
#     maps(t=1)
#     maps(t=3)
#     maps(t=10)
    maps(t=25)
#     maps(t=75)

## Lineouts <a class="anchor" id="lineouts"></a>

In [None]:
def _log_ticks(vals):
    """Return (tickvals, ticktext) labelling log10 axis positions as powers of ten."""
    vals = list(vals)
    return vals, [f"10<sup>{x}</sup>" if x != 0 else '1' for x in vals]


def plot_lineouts(ts, attrs, titles):
    """Plot log10 lineouts along the first axis for every simulation, then save and show.

    The figure has one row per quantity in `attrs` and one column per time in
    `ts`. Each simulation keeps the same colour in every panel; the colour is
    chosen by name, not by position in ``DATA[t]``.

    Args:
        ts: times (keys of DATA) to plot.
        attrs: quantities among 'n', 'p', 'T', 'fHI'. 'fHI' is computed as
            nHI / (nHI + nHII).
        titles: y-axis titles, one per entry of `attrs`.
    """
    sim_names = list(dict.fromkeys(sn for snaps in DATA.values() for sn in snaps))
    colors = px.colors.sample_colorscale(px.colors.diverging.Portland_r, np.linspace(0, 1, len(sim_names)))
    color_of = dict(zip(sim_names, colors))

    ticks = {
        'p': _log_ticks(range(-20, -10)),
        'T': _log_ticks(range(2, 8)),
        'fHI': _log_ticks(range(-10, 1, 2)),
        'n': _log_ticks(range(-4, 2)),
    }

    # Rescaling applied to RH1D only.
    # NOTE(review): the exponents are unexplained in SOURCE; confirm their origin.
    rh1d_scale = {'p': 10**(-1.35+15.23), 'n': 10**(12.50+1.37)}

    # Index of the lineout in the second and third axes.
    l = 2

    len_t = len(ts)

    fig = make_subplots(
        len(attrs), len_t,
        column_widths=[1] * len_t, row_heights=[1] * len(attrs),
        vertical_spacing=0.025, horizontal_spacing=0.025,
        shared_xaxes=False, shared_yaxes=True,
        x_title='x / L<sub>box</sub>',
        subplot_titles=[f"t = {t} Myr" for t in ts],
    )

    # Annotations 0..len_t-1 are the subplot titles; annotation len_t is the x_title.
    for i in range(len_t):
        fig.layout.annotations[i]["font"] = dict(size=18, color='black')
        fig.layout.annotations[i]["y"] += 0.0125

    fig.layout.annotations[len_t]["font"] = dict(size=18, color='black')
    fig.layout.annotations[len_t]["y"] -= 0.0125

    for irow, attr in enumerate(attrs):
        tickvals, ticktext = ticks[attr]
        for icol, t in enumerate(ts):
            for sn, sim in DATA[t].items():
                x = np.linspace(0, 1, len(sim['nHI'][:, 0, 0]))
                if attr == 'fHI':
                    y = np.array(sim['nHI'][:, l, l] / (sim['nHI'][:, l, l] + sim['nHII'][:, l, l]))
                else:
                    y = np.array(sim[attr][:, l, l])

                if sn == 'RH1D' and attr in rh1d_scale:
                    y /= rh1d_scale[attr]

                fig.add_trace(go.Scatter(
                    x=x, y=np.log10(y), name=f"{sn}", mode='lines',
                    line=dict(color=color_of[sn], dash=None, width=1),
                    showlegend=irow == icol == 0,
                ), irow+1, icol+1)

            fig.update_xaxes(
                row=irow+1, col=icol+1, type='linear', mirror=True,
                titlefont=dict(size=18), tickfont=dict(size=14),
                tickmode='array', tickvals=[0, 0.5, 1], ticktext=['0', '.5', '1'],
                showticklabels=irow == len(attrs)-1,
            )

            fig.update_yaxes(
                row=irow+1, col=icol+1, mirror=True,
                title=titles[irow] if icol == 0 else None,
                titlefont=dict(size=18), tickfont=dict(size=14),
                tickmode='array', tickvals=tickvals, ticktext=ticktext,
                showticklabels=icol == 0,
            )

    width = 1000
    height = 1000

    fig.update_layout(
        width=width, height=height,
        margin=dict(t=60, r=60, b=70, l=80),
        legend=dict(orientation="h", yanchor="top", y=1.12, xanchor="center", x=.5),
    )

    fig.write_image(f"{FIGURE_DIR}/iliev-6-lineouts.svg", width=width, height=height)
    fig.write_image(f"{FIGURE_DIR}/iliev-6-lineouts.png", width=width, height=height, scale=3)
    fig.show()

if __name__ == '__main__' and '__file__' not in globals():
    plot_lineouts([1, 10, 25, 75], ['n', 'p', 'T', 'fHI'], ['n [cm<sup>-3</sup>]', 'p [g / m s<sup>2</sup>]', 'T [K]', 'f<sub>HI</sub>'])

## I-front position <a class="anchor" id="pos-vel"></a>

In [None]:
def i_front_get_position(xs, fHI, f0=.5):
    ydir = np.sign(np.diff(fHI))
    turning_points = 1 + np.where(np.diff(ydir) != 0)[0]
    
    signs = np.split(ydir, turning_points)
    xs_grp = np.split(xs, turning_points)
    fHI_grp = np.split(fHI, turning_points)
    fs = [
        (interp1d(y, x, bounds_error=False, fill_value=-1), sgn[0])
        for x, y, sgn in zip(xs_grp, fHI_grp, signs) if len(x) > 2
    ]
    
    return [(float(p[0]), p[1]) for p in [(f[0](f0), f[1]) for f in fs] if p[0] != -1]


def i_front_pos():
    """Locate the I-front along the box diagonal for every snapshot with t <= 25.

    For each simulation, nHI and nHII are taken along the main diagonal of the
    3D cube. The position where nHII / (nHI + nHII) crosses 0.9 is stored in
    ``sim['I-front']['pos_kpc']``. When there is not exactly one crossing, None
    is stored and the candidate crossings are printed.
    """
    # Distance along the cube diagonal: sqrt(3) times the per-axis coordinate.
    diag = np.sqrt(3) * np.linspace(0, 0.8, 128)

    for t in DATA:
        if t > 25:
            continue
        for sn, sim in DATA[t].items():
            # Drop any earlier result so the entry is rebuilt from scratch.
            sim.pop('I-front', None)
            front = sim['I-front'] = {}

            nHI_diag = np.einsum('iii->i', sim['nHI'])
            nHII_diag = np.einsum('iii->i', sim['nHII'])
            # NOTE(review): this is the nHII share nHII / (nHI + nHII), not the HI
            # fraction. The I-front is taken where it crosses 0.9; confirm intent.
            hii_share = nHII_diag / (nHI_diag + nHII_diag)

            pos = i_front_get_position(diag, hii_share, f0=0.9)
            if len(pos) != 1:
                front['pos_kpc'] = None
                print(f"{sn}-{t}: {pos}")
                continue
            front['pos_kpc'] = pos[0][0]

if __name__ == '__main__' and '__file__' not in globals():
    i_front_pos()

In [None]:
def i_front_pos_plot():
    """Plot the I-front position versus time for every simulation, then save and show.

    Only times whose simulations all carry an 'I-front' entry (the ones
    processed by i_front_pos) are plotted. The former ``list(DATA.keys())[:-1]``
    silently assumed the last time was the only one without data.
    """
    sim_names = list(DATA[10].keys())
    colors = px.colors.sample_colorscale(px.colors.diverging.Portland_r, np.linspace(0, 1, len(sim_names)))

    fig = go.Figure()

    times = [
        t for t, snaps in DATA.items()
        if snaps and all('I-front' in sim for sim in snaps.values())
    ]
    for color, sn in zip(colors, sim_names):
        # None entries (ambiguous fronts) are left as gaps in the line.
        ys = [DATA[t][sn]['I-front']['pos_kpc'] for t in times]

        fig.add_trace(go.Scatter(
            x=times, y=ys, name=sn,
            mode='lines+markers',
            line=dict(color=color, dash=None, width=1,),
        ))

    width, height = 500 + 150, 460

    fig.update_layout(
        margin=dict(b=10, t=10, l=10, r=10),
        width=width, height=height,
        xaxis=dict(
            type='linear', mirror=True, title='<i>t</i> [Myr]', range=(0, 26),
            titlefont=dict(size=18),
            tickmode='array',
            tickvals=TIMES,
            ticktext=[str(tt) for tt in TIMES],
            tickfont=dict(size=14),
        ),
        yaxis=dict(
            type='linear', mirror=True, title='<i>x</i><sub>I-front</sub> [kpc]',
            titlefont=dict(size=18),
            tickmode='array',
            tickvals=np.linspace(0.2, 1.2, 6),
            ticktext=[f"{yy:2.1f}" for yy in np.linspace(0.2, 1.2, 6)],
            tickfont=dict(size=14),
        ),
        legend=dict(orientation="h", yanchor="top", y=1.17, xanchor="center", x=.5),
    )

    fig.write_image(f"{FIGURE_DIR}/iliev-6-I-front-pos.svg", width=width, height=height)
    fig.write_image(f"{FIGURE_DIR}/iliev-6-I-front-pos.png", width=width, height=height, scale=3)
    fig.show()

if __name__ == '__main__' and '__file__' not in globals():
    i_front_pos_plot()